예제 #1
0
파일: multi_test.go 프로젝트: kortschak/opt
// TestExampleModel shows how a badly conditioned Hessian slows steepest
// descent. The iterates zig-zag and make little headway along the flat
// direction. An L-BFGS run started from the steepest-descent result then
// refines the estimate.
//
// The model is built with Hessian A = diag(kappa, 1), linear term
// b = -2·A·xOpt and offset c = -0.5·b·xOpt. Assuming opt.NewQuadratic
// evaluates x'Ax + b'x + c (TODO confirm against package opt), this places
// the minimum at xOpt = (1, 2) with value 0.
//
// NOTE(review): this test only prints the iterates and never asserts on
// them, so it cannot fail on a wrong answer.
func TestExampleModel(t *testing.T) {
	// Condition number of the Hessian; a large value is what makes steepest
	// descent zig-zag.
	const kappa = 100.0
	xOpt := mat.Vec{1, 2}

	hessian := mat.NewFromArray([]float64{kappa, 0, 0, 1}, true, 2, 2)
	linear := mat.Vec{-2 * xOpt[0] * kappa, -2 * xOpt[1]}
	offset := -0.5 * mat.Dot(linear, xOpt)

	objective := opt.NewQuadratic(hessian, linear, offset)

	// Initial estimate: a fresh length-2 vector (presumably the origin).
	start := NewSolution(mat.NewVec(2))

	// Both solvers share these termination parameters. The iteration budget
	// is kept small so the slow progress of steepest descent is visible.
	params := NewParams()
	params.IterMax = 5

	// Stage 1: steepest descent.
	steep := NewSteepestDescent()
	res := steep.Solve(objective, start, params, NewDisplay(1))

	// The optimum is [1 2]. Because of the bad conditioning, the second
	// coordinate lags far behind after only a few iterations.
	fmt.Println("x =", res.X)

	// Stage 2: refine with L-BFGS, starting where steepest descent stopped.
	res = NewLBFGS().Solve(objective, res.Solution, params, NewDisplay(1))

	fmt.Println("x =", res.X)
}
예제 #2
0
파일: multi_test.go 프로젝트: kortschak/opt
// TestQuadratic runs several unconstrained and constrained solvers on a
// random convex quadratic and checks that each drives the objective to (near)
// zero.
//
// The objective is built from a random n×n matrix A and xStar (a vector
// filled with ones via AddSc(1)):
//
//	Q = A'A,  b = -2·A'A·xStar,  c = ||A·xStar||²
//
// Assuming opt.NewQuadratic evaluates x'Qx + b'x + c (TODO confirm against
// package opt), this equals ||A(x - xStar)||². Its minimum value is 0, attained
// at xStar. xStar also lies in the nonnegative orthant, which is presumably
// what opt.RealPlus projects onto, so the constrained optimum is 0 as well.
func TestQuadratic(t *testing.T) {
	mat.Register(cops)
	n := 10
	xStar := mat.NewVec(n)
	xStar.AddSc(1)
	// NOTE(review): A is random and no seed is visible here; an
	// ill-conditioned draw could make this test flaky — confirm RandN's
	// seeding.
	A := mat.RandN(n)
	At := A.TrView()
	AtA := mat.New(n)
	AtA.Mul(At, A)

	// bTmp = A·xStar is reused for both b and c.
	bTmp := mat.NewVec(n)
	bTmp.Apply(A, xStar)
	b := mat.NewVec(n)
	b.Apply(At, bTmp)
	b.Scal(-2)

	c := bTmp.Nrm2Sq()

	// Define input arguments.
	obj := opt.NewQuadratic(AtA, b, c)
	p := NewParams()
	// NOTE(review): the same sol is handed to every solver below. If Solve
	// mutates it in place, later solvers do not start from the initial
	// estimate — confirm.
	sol := NewSolution(mat.NewVec(n))

	// Steepest descent with Armijo line search.
	stDesc := NewSteepestDescent()
	res1 := stDesc.Solve(obj, sol, p, NewDisplay(100))

	t.Log(res1.ObjX, res1.FunEvals, res1.GradEvals, res1.Status)

	// Steepest descent with a quadratic-interpolation line search.
	// NOTE(review): unkeyed composite literal of an imported struct type;
	// `go vet` (composites) flags this — consider naming the field.
	stDesc.LineSearch = uni.DerivWrapper{uni.NewQuadratic()}
	res2 := stDesc.Solve(obj, sol, p, NewDisplay(100))

	t.Log(res2.ObjX, res2.FunEvals, res2.GradEvals, res2.Status)

	// L-BFGS with Armijo line search.
	lbfgs := NewLBFGS()
	res3 := lbfgs.Solve(obj, sol, p, NewDisplay(10))

	t.Log(res3.ObjX, res3.FunEvals, res3.GradEvals, res3.Status)

	// Constrained problem; the constraint set is described by a projection.
	projGrad := NewProjGrad()

	res4 := projGrad.Solve(obj, opt.RealPlus{}, sol, p, NewDisplay(100))

	t.Log(res4.ObjX, res4.FunEvals, res4.GradEvals, res4.Status)

	// Every solver should reach an objective value within tol of the true
	// minimum 0. t.Errorf marks the test failed and keeps going, like the
	// bare t.Fail it replaces, but reports which solver missed and by how
	// much.
	const tol = 0.01
	checks := []struct {
		name string
		objX float64
	}{
		{"steepest descent (armijo)", res1.ObjX},
		{"steepest descent (quadratic)", res2.ObjX},
		{"lbfgs (armijo)", res3.ObjX},
		{"projected gradient (RealPlus)", res4.ObjX},
	}
	for _, chk := range checks {
		if math.Abs(chk.objX) > tol {
			t.Errorf("%s: objective = %g, want |f| <= %g", chk.name, chk.objX, tol)
		}
	}
}