Ejemplo n.º 1
0
// TestQuality fits the model built by NewTrainer (an IsoSqExp kernel with
// nFeatures features) to nTrain random samples using a closed-form linear
// solve with no regularization. It then checks that predictions on nTest
// freshly generated samples match the true outputs to within 1e-9 for every
// output dimension.
//
// NOTE(review): the very tight tolerance presumably relies on
// generateRandomSamples drawing from a function this model can represent
// (almost) exactly — confirm against generateRandomSamples.
func TestQuality(t *testing.T) {
	nDim := 2
	nTrain := 160
	xTrain, yTrain := generateRandomSamples(nTrain, nDim)

	nTest := 1000
	xTest, yTest := generateRandomSamples(nTest, nDim)

	nFeatures := 300

	sigmaSq := 0.01

	kernel := &IsoSqExp{LogScale: math.Log(sigmaSq)}

	sink := NewTrainer(nDim, nDim, nFeatures, kernel)

	// Fit the parameters in closed form; regularize.None{} disables the regularizer.
	parameters := train.LinearSolve(sink, nil, xTrain, yTrain, nil, regularize.None{})
	sink.SetParameters(parameters)

	// Predict on new values. After an error pred cannot be trusted (it may be
	// nil), so stop here instead of indexing into it below.
	pred, err := sink.PredictBatch(xTest, nil)
	if err != nil {
		t.Fatalf("PredictBatch: %v", err)
	}
	for i := 0; i < nTest; i++ {
		for j := 0; j < nDim; j++ {
			diff := pred.At(i, j) - yTest.At(i, j)
			if math.Abs(diff) > 1e-9 {
				t.Errorf("Mismatch sample %v, output %v. Expected %v, Found %v", i, j, yTest.At(i, j), pred.At(i, j))
			}
		}
	}
}
Ejemplo n.º 2
0
// TestLinearsolveAndDeriv compares the optimal weights found from gradient-based optimization with those found
// from computing a linear solve.
//
// The gradient-based fit minimizes loss.SquaredDistance with regularize.None{}, starting from random parameters.
// The two parameter vectors must agree to within 1e-8 (floats.EqualApprox). name identifies the case in failure
// messages. On a mismatch, the reported diagnostics help tell a non-converged optimizer apart from a genuinely
// different optimum:
//   - the training loss at each solution, and
//   - the L2 distance between the two sets of predictions.
//
// NOTE(review): because of the Test prefix with a non-standard signature, `go test`/`go vet` reject this
// function if it lives in a _test.go file; presumably it sits in a shared, non-test helper package — confirm.
func TestLinearsolveAndDeriv(t *testing.T, linear train.LinearTrainable, inputs, trueOutputs common.RowMatrix, name string) {
	// Reference solution: closed-form linear solve with no weights.
	parametersLinearSolve := train.LinearSolve(linear, nil, inputs, trueOutputs, nil, nil)

	// Gradient-based solution from a random starting point.
	linear.RandomizeParameters()
	initial := linear.Parameters(nil)

	batch := train.NewBatchGradBased(linear, true, inputs, trueOutputs, nil, loss.SquaredDistance{}, regularize.None{})
	settings := multivariate.DefaultSettings()
	settings.GradAbsTol = 1e-11

	result, err := multivariate.OptimizeGrad(batch, initial, settings, nil)
	if err != nil {
		// result is not trustworthy (it may be nil) after a failed run, so stop before using result.Loc.
		t.Fatalf("%s: error training: %v", name, err)
	}
	parametersDeriv := result.Loc

	if floats.EqualApprox(parametersLinearSolve, parametersDeriv, 1e-8) {
		return
	}

	// Mismatch: gather diagnostics only now, so the passing path does no extra work.
	derivGrad := make([]float64, linear.NumParameters())
	lossGrad := batch.ObjGrad(parametersDeriv, derivGrad)
	derivLinear := make([]float64, linear.NumParameters())
	lossLinear := batch.ObjGrad(parametersLinearSolve, derivLinear)

	rows, cols := trueOutputs.Dims()
	predOutLinear := mat64.NewDense(rows, cols, nil)
	linear.SetParameters(parametersLinearSolve)
	linear.Predictor().PredictBatch(inputs, predOutLinear)

	predOutDeriv := mat64.NewDense(rows, cols, nil)
	linear.SetParameters(parametersDeriv)
	linear.Predictor().PredictBatch(inputs, predOutDeriv)

	// Both matrices are freshly allocated with the same shape, so their raw data slices line up element-wise.
	predDist := floats.Distance(predOutLinear.RawMatrix().Data, predOutDeriv.RawMatrix().Data, 2)

	t.Errorf("%s: parameters don't match for gradient based and linear solve: loss(gradient) = %v, loss(linear solve) = %v, L2 distance between predictions = %v",
		name, lossGrad, lossLinear, predDist)
}