func TestQuality(t *testing.T) { // On a certain function we know that the prediction is good // confirm that it is nDim := 2 nTrain := 160 xTrain, yTrain := generateRandomSamples(nTrain, nDim) nTest := 1000 xTest, yTest := generateRandomSamples(nTest, nDim) nFeatures := 300 sigmaSq := 0.01 kernel := &IsoSqExp{LogScale: math.Log(sigmaSq)} sink := NewTrainer(nDim, nDim, nFeatures, kernel) parameters := train.LinearSolve(sink, nil, xTrain, yTrain, nil, regularize.None{}) sink.SetParameters(parameters) /* batchGrad := train.NewBatchGradBased(sink, true, xTrain, yTrain, nil, loss.SquaredDistance{}, regularize.None{}) derivative := make([]float64, sink.NumParameters()) batchGrad.ObjGrad(parameters, derivative) fmt.Println("Quality derivative") //fmt.Println(derivative) fmt.Println("sum deriv = ", floats.Sum(derivative)) sink.RandomizeParameters() sink.Parameters(parameters) batchGrad.ObjGrad(parameters, derivative) fmt.Println("sum deriv 2 = ", floats.Sum(derivative)) */ // Predict on new values pred, err := sink.PredictBatch(xTest, nil) if err != nil { t.Errorf(err.Error()) } for i := 0; i < nTest; i++ { for j := 0; j < nDim; j++ { diff := pred.At(i, j) - yTest.At(i, j) if math.Abs(diff) > 1e-9 { t.Errorf("Mismatch sample %v, output %v. Expected %v, Found %v", i, j, yTest.At(i, j), pred.At(i, j)) } } } }
// TestLinearsolveAndDeriv compares the optimal weights found from gradient-based optimization with those found // from computing a linear solve func TestLinearsolveAndDeriv(t *testing.T, linear train.LinearTrainable, inputs, trueOutputs common.RowMatrix, name string) { // Compare with no weights rows, cols := trueOutputs.Dims() predOutLinear := mat64.NewDense(rows, cols, nil) parametersLinearSolve := train.LinearSolve(linear, nil, inputs, trueOutputs, nil, nil) linear.SetParameters(parametersLinearSolve) linear.Predictor().PredictBatch(inputs, predOutLinear) //fmt.Println("Pred out linear", predOutLinear) linear.RandomizeParameters() parameters := linear.Parameters(nil) batch := train.NewBatchGradBased(linear, true, inputs, trueOutputs, nil, loss.SquaredDistance{}, regularize.None{}) problem := batch settings := multivariate.DefaultSettings() settings.GradAbsTol = 1e-11 //settings. = 0 result, err := multivariate.OptimizeGrad(problem, parameters, settings, nil) if err != nil { t.Errorf("Error training: %v", err) } parametersDeriv := result.Loc deriv := make([]float64, linear.NumParameters()) loss1 := batch.ObjGrad(parametersDeriv, deriv) linear.SetParameters(parametersDeriv) predOutDeriv := mat64.NewDense(rows, cols, nil) linear.Predictor().PredictBatch(inputs, predOutDeriv) linear.RandomizeParameters() init2 := linear.Parameters(nil) batch2 := train.NewBatchGradBased(linear, true, inputs, trueOutputs, nil, loss.SquaredDistance{}, regularize.None{}) problem2 := batch2 result2, err := multivariate.OptimizeGrad(problem2, init2, settings, nil) parametersDeriv2 := result2.Loc //fmt.Println("starting deriv2 loss") deriv2 := make([]float64, linear.NumParameters()) loss2 := batch2.ObjGrad(parametersDeriv2, deriv2) //fmt.Println("starting derivlin loss") derivlinear := make([]float64, linear.NumParameters()) lossLin := batch2.ObjGrad(parametersLinearSolve, derivlinear) _ = loss1 _ = loss2 _ = lossLin /* fmt.Println("param deriv 1 =", parametersDeriv) fmt.Println("param deriv2 =", parametersDeriv2) fmt.Println("linear params =", parametersLinearSolve) fmt.Println("deriv1 loss =", loss1) fmt.Println("deriv2 loss =", loss2) fmt.Println("lin loss =", lossLin) fmt.Println("deriv =", deriv) fmt.Println("deriv2 =", deriv2) fmt.Println("linderiv =", derivlinear) //fmt.Println("Pred out deriv", predOutDeriv) */ /* for i := 0; i < rows; i++ { fmt.Println(predOutLinear.RowView(i), predOutBatch.RowView(i)) } */ if !floats.EqualApprox(parametersLinearSolve, parametersDeriv, 1e-8) { t.Errorf("Parameters don't match for gradient based and linear solve.") //for i := range parametersDeriv { // fmt.Printf("index %v: Deriv = %v, linsolve = %v, diff = %v\n", i, parametersDeriv[i], parametersLinearSolve[i], parametersDeriv[i]-parametersLinearSolve[i]) //} } }