Beispiel #1
0
// TestPredictAndBatch checks that p behaves correctly on the given data set.
//
// For every row of inputs it verifies that Predict returns the matching row of
// trueOutputs (to within 1e-14). It also verifies that Predict gives the same
// answer whether the output slice is nil or caller-supplied, and that it does
// not modify its input. It then checks that Predict errors on mis-sized input
// and output slices.
//
// Finally it checks PredictBatch:
//   - calls with a nil and a non-nil output matrix agree;
//   - the nil-output result has nSamples x OutputDim() dimensions;
//   - the inputs are left unmodified;
//   - mis-sized input or output matrices produce an error.
//
// name is included in failure messages to identify the predictor under test.
// The function panics if inputs/trueOutputs do not match p's dimensions or each
// other, since that is a bug in the calling test rather than in p.
func TestPredictAndBatch(t *testing.T, p Predictor, inputs, trueOutputs common.RowMatrix, name string) {
	nSamples, inputDim := inputs.Dims()
	if inputDim != p.InputDim() {
		panic("input Dim doesn't match predictor input dim")
	}
	nOutSamples, outputDim := trueOutputs.Dims()
	if outputDim != p.OutputDim() {
		panic("outpuDim doesn't match predictor outputDim")
	}
	if nOutSamples != nSamples {
		panic("inputs and outputs have different number of rows")
	}

	// First, test Predict one row at a time. On the first failure we stop the
	// per-row loop but still run the size and batch checks below.
	for i := 0; i < nSamples; i++ {
		trueOut := make([]float64, outputDim)
		for j := range trueOut {
			trueOut[j] = trueOutputs.At(i, j)
		}
		input := make([]float64, inputDim)
		for j := range input {
			input[j] = inputs.At(i, j)
		}
		// Pristine copy used to detect Predict mutating its input.
		inputCpy := make([]float64, inputDim)
		copy(inputCpy, input)

		// Predict with a nil output; the result is taken from the return value.
		out1, err := p.Predict(input, nil)
		if err != nil {
			t.Errorf("%v: error predicting with nil output for row %v: %v", name, i, err)
			break
		}
		if !floats.Equal(input, inputCpy) {
			t.Errorf("%v: input changed with nil output for row %v", name, i)
			break
		}

		// Predict into a caller-supplied slice pre-filled with random values. The
		// comparison below reads out2 directly, so a predictor that does not
		// overwrite every element of it is caught.
		out2 := make([]float64, outputDim)
		for j := range out2 {
			out2[j] = rand.NormFloat64()
		}
		if _, err = p.Predict(input, out2); err != nil {
			t.Errorf("%v: error predicting with non-nil output for row %v: %v", name, i, err)
			break
		}
		if !floats.Equal(input, inputCpy) {
			t.Errorf("%v: input changed with non-nil output for row %v", name, i)
			break
		}

		if !floats.Equal(out1, out2) {
			t.Errorf("%v: different answers with nil and non-nil predict for row %v", name, i)
			break
		}
		if !floats.EqualApprox(out1, trueOut, 1e-14) {
			t.Errorf("%v: predicted output doesn't match for row %v. Expected %v, found %v", name, i, trueOut, out1)
			break
		}
	}

	// Predict must reject input and output slices of the wrong length. A valid
	// input/output pair is built from row 0 so only one argument is wrong at a time.
	input := make([]float64, inputDim)
	for i := range input {
		input[i] = inputs.At(0, i)
	}
	output := make([]float64, outputDim)
	for i := range output {
		output[i] = trueOutputs.At(0, i)
	}

	badOutput := make([]float64, outputDim+1)
	if _, err := p.Predict(input, badOutput); err == nil {
		t.Errorf("%v: Predict did not throw an error with an output too large", name)
	}
	// Only shrink when the result is still non-empty; a zero-length slice is
	// left out of the check, presumably to avoid overlap with the nil-output case.
	if outputDim > 1 {
		badOutput := make([]float64, outputDim-1)
		if _, err := p.Predict(input, badOutput); err == nil {
			t.Errorf("%v: Predict did not throw an error with an output too small", name)
		}
	}

	badInput := make([]float64, inputDim+1)
	if _, err := p.Predict(badInput, output); err == nil {
		t.Errorf("%v: Predict did not err when input is too large", name)
	}
	if inputDim > 1 {
		badInput := make([]float64, inputDim-1)
		if _, err := p.Predict(badInput, output); err == nil {
			t.Errorf("%v: Predict did not err when input is too small", name)
		}
	}

	// Now test PredictBatch, first with a nil output matrix.
	inputCpy := &mat64.Dense{}
	inputCpy.Clone(inputs)
	predOutput, err := p.PredictBatch(inputs, nil)
	if err != nil {
		// The returned matrix is not meaningful after an error (it may be nil),
		// so the remaining batch checks cannot run.
		t.Errorf("%v: error batch predicting with nil output: %v", name, err)
		return
	}
	if !inputCpy.Equals(inputs) {
		t.Errorf("%v: inputs changed during call to PredictBatch", name)
	}
	predOutputRows, predOutputCols := predOutput.Dims()
	if predOutputRows != nSamples || predOutputCols != outputDim {
		t.Errorf("%v: dimension mismatch after PredictBatch with nil output", name)
	}

	// Then with a caller-supplied output matrix; both results must agree.
	outputs := mat64.NewDense(nSamples, outputDim, nil)
	if _, err := p.PredictBatch(inputs, outputs); err != nil {
		t.Errorf("%v: error batch predicting with non-nil output: %v", name, err)
	}
	pd, ok := predOutput.(*mat64.Dense)
	switch {
	case !ok:
		t.Errorf("%v: PredictBatch with nil output returned %T, expected *mat64.Dense", name, predOutput)
	case !pd.Equals(outputs):
		t.Errorf("%v: different outputs from predict batch with nil and non-nil", name)
	}

	// PredictBatch must reject mis-sized matrices.
	badInputs := mat64.NewDense(nSamples, inputDim+1, nil)
	if _, err := p.PredictBatch(badInputs, outputs); err == nil {
		t.Errorf("%v: PredictBatch did not err when input dim too large", name)
	}
	badInputs = mat64.NewDense(nSamples+1, inputDim, nil)
	if _, err := p.PredictBatch(badInputs, outputs); err == nil {
		t.Errorf("%v: PredictBatch did not err with row mismatch", name)
	}
	badOutputs := mat64.NewDense(nSamples, outputDim+1, nil)
	if _, err := p.PredictBatch(inputs, badOutputs); err == nil {
		t.Errorf("%v: PredictBatch did not err with output dim too large", name)
	}
}