コード例 #1
0
ファイル: tsolve_test.go プロジェクト: sguzwf/algorithm
// solveMVTest checks the DSolveUnblkMV / DSolveBlkMV triangular solvers
// against blas.TrsvFloat.
//
// A is the triangular coefficient matrix and X0 the right-hand side; X0 is
// overwritten in place with the BLAS solution. flags selects the triangle
// (LOWER, otherwise upper) and the diagonal kind (UNIT, otherwise non-unit).
// When bN == bNB the unblocked solver is used, otherwise the blocked one
// (bNB presumably being the block size). The test is marked failed when the
// two solutions differ; for small problems (bN < 8) the operands are logged.
func solveMVTest(t *testing.T, A, X0 *matrix.FloatMatrix, flags Flags, bN, bNB int) {
	// X0 is overwritten by TrsvFloat and X1 by the DSolve* routine, so keep
	// the original right-hand side separately for the failure diagnostics.
	origX0 := X0.Copy()
	X1 := X0.Copy()

	uplo := linalg.OptUpper
	diag := linalg.OptNonUnit
	if flags&LOWER != 0 {
		uplo = linalg.OptLower
	}
	if flags&UNIT != 0 {
		diag = linalg.OptUnit
	}

	// Reference solution.
	blas.TrsvFloat(A, X0, uplo, diag)

	// Solution under test, computed on the raw backing arrays. The literal 1
	// is passed through unchanged (presumably the vector increment).
	Ar := A.FloatArray()
	Xr := X1.FloatArray()
	if bN == bNB {
		DSolveUnblkMV(Xr, Ar, flags, 1, A.LeadingIndex(), bN)
	} else {
		DSolveBlkMV(Xr, Ar, flags, 1, A.LeadingIndex(), bN, bNB)
	}

	ok := X1.AllClose(X0)
	t.Logf("X1 == X0: %v\n", ok)
	if ok {
		return
	}
	if bN < 8 {
		t.Logf("A=\n%v\n", A)
		t.Logf("X0=\n%v\n", origX0)
		t.Logf("blas: X0\n%v\n", X0)
		t.Logf("X1:\n%v\n", X1)
	}
	t.Errorf("solve MV (flags=%v, bN=%d, bNB=%d): DSolve result differs from blas.TrsvFloat", flags, bN, bNB)
}
コード例 #2
0
ファイル: kkt.go プロジェクト: sguzwf/algorithm
// kktChol2 builds the 'chol2' KKT solver. It is usable only for problems
// whose constraints are nonlinear or componentwise ('l' cone); it returns an
// error when dims contains second-order ('q') or semidefinite ('s') cones.
//
// G is the linear inequality matrix, A the p x n equality matrix and mnl the
// number of nonlinear constraints (rows of Df). The returned kktFactor,
// called with the scaling W, an optional term H (may be nil) and Df, forms
//
//	Gs  = diag(W.di) * G,  Dfs = diag(W.dnli) * Df
//	S   = Gs'*Gs + Dfs'*Dfs + H          (+ A'*A once S was found singular)
//	K   = Asct'*Asct,  Asct = L^{-1}*A'   where S = L*L'
//
// and Cholesky-factors S and K. If the very first factorization of S fails,
// S is flagged singular and A'*A is added to it on this and every later call.
// A failed factorization of S or K is returned as an error (with a nil
// KKTFunc); otherwise the returned KKTFunc solves the KKT system in place as
// described in its body.
func kktChol2(G *matrix.FloatMatrix, dims *sets.DimensionSet, A *matrix.FloatMatrix, mnl int) (kktFactor, error) {
	if len(dims.At("q")) > 0 || len(dims.At("s")) > 0 {
		return nil, errors.New("'chol2' solver only for problems with no second-order or " +
			"semidefinite cone constraints")
	}

	p, n := A.Size()
	ml := dims.At("l")[0]
	F := &chol2Data{firstcall: true, singular: false, A: A, G: G, dims: dims}

	factor := func(W *sets.FloatMatrixSet, H, Df *matrix.FloatMatrix) (KKTFunc, error) {
		minor := 0
		if !checkpnt.MinorEmpty() {
			minor = checkpnt.MinorTop()
		}
		if F.firstcall {
			// Workspace is allocated on the first call and reused afterwards.
			F.Gs = matrix.FloatZeros(F.G.Size())
			if mnl > 0 {
				F.Dfs = matrix.FloatZeros(Df.Size())
			}
			F.S = matrix.FloatZeros(n, n)
			F.K = matrix.FloatZeros(p, p)
			checkpnt.AddMatrixVar("Gs", F.Gs)
			checkpnt.AddMatrixVar("Dfs", F.Dfs)
			checkpnt.AddMatrixVar("S", F.S)
			checkpnt.AddMatrixVar("K", F.K)
		}

		// Dfs := diag(dnli) * Df
		if mnl > 0 {
			dnli := matrix.FloatZeros(mnl, mnl)
			dnli.SetIndexesFromArray(W.At("dnli")[0].FloatArray(), matrix.DiagonalIndexes(dnli)...)
			if err := blas.GemmFloat(dnli, Df, F.Dfs, 1.0, 0.0); err != nil {
				return nil, err
			}
		}
		checkpnt.Check("02factor_chol2", minor)
		// Gs := diag(di) * G
		di := matrix.FloatZeros(ml, ml)
		di.SetIndexesFromArray(W.At("di")[0].FloatArray(), matrix.DiagonalIndexes(di)...)
		if err := blas.GemmFloat(di, G, F.Gs, 1.0, 0.0); err != nil {
			return nil, err
		}
		checkpnt.Check("06factor_chol2", minor)

		if F.firstcall {
			blas.SyrkFloat(F.Gs, F.S, 1.0, 0.0, la.OptTrans)
			if mnl > 0 {
				blas.SyrkFloat(F.Dfs, F.S, 1.0, 1.0, la.OptTrans)
			}
			if H != nil {
				F.S.Plus(H)
			}
			checkpnt.Check("10factor_chol2", minor)
			var ferr error
			if err := lapack.Potrf(F.S); err != nil {
				// S is not positive definite: remember that for later calls,
				// rebuild S from scratch (the failed Potrf may have partially
				// overwritten it) and factor S + A'*A instead.
				F.singular = true
				// original code recreates F.S as dense if it is sparse and
				// A is dense, we don't do it as currently no sparse matrices
				//F.S = matrix.FloatZeros(n, n)
				//checkpnt.AddMatrixVar("S", F.S)
				blas.SyrkFloat(F.Gs, F.S, 1.0, 0.0, la.OptTrans)
				if mnl > 0 {
					blas.SyrkFloat(F.Dfs, F.S, 1.0, 1.0, la.OptTrans)
				}
				checkpnt.Check("14factor_chol2", minor)
				blas.SyrkFloat(F.A, F.S, 1.0, 1.0, la.OptTrans)
				if H != nil {
					F.S.Plus(H)
				}
				// If S + A'*A is not positive definite either, report it
				// instead of continuing with a partial factor.
				ferr = lapack.Potrf(F.S)
			}
			F.firstcall = false
			checkpnt.Check("20factor_chol2", minor)
			if ferr != nil {
				return nil, ferr
			}
		} else {
			blas.SyrkFloat(F.Gs, F.S, 1.0, 0.0, la.OptTrans)
			if mnl > 0 {
				blas.SyrkFloat(F.Dfs, F.S, 1.0, 1.0, la.OptTrans)
			}
			if H != nil {
				F.S.Plus(H)
			}
			checkpnt.Check("40factor_chol2", minor)
			if F.singular {
				blas.SyrkFloat(F.A, F.S, 1.0, 1.0, la.OptTrans)
			}
			ferr := lapack.Potrf(F.S)
			checkpnt.Check("50factor_chol2", minor)
			if ferr != nil {
				return nil, ferr
			}
		}

		// Asct := L^{-1}*A'.  Factor K = Asct'*Asct.
		Asct := F.A.Transpose()
		blas.TrsmFloat(F.S, Asct, 1.0)
		blas.SyrkFloat(Asct, F.K, 1.0, 0.0, la.OptTrans)
		kerr := lapack.Potrf(F.K)
		checkpnt.Check("90factor_chol2", minor)
		if kerr != nil {
			// K is not positive definite (e.g. A does not have full row rank);
			// a solver built on it would return meaningless results.
			return nil, kerr
		}

		solve := func(x, y, z *matrix.FloatMatrix) (err error) {
			// Solve
			//
			//     [ H          A'  GG'*W^{-1} ]   [ ux   ]   [ bx        ]
			//     [ A          0   0          ] * [ uy   ] = [ by        ]
			//     [ W^{-T}*GG  0   -I         ]   [ W*uz ]   [ W^{-T}*bz ]
			//
			// and return ux, uy, W*uz.
			//
			// If not F['singular']:
			//
			//     K*uy = A * S^{-1} * ( bx + GG'*W^{-1}*W^{-T}*bz ) - by
			//     S*ux = bx + GG'*W^{-1}*W^{-T}*bz - A'*uy
			//     W*uz = W^{-T} * ( GG*ux - bz ).
			//
			// If F['singular']:
			//
			//     K*uy = A * S^{-1} * ( bx + GG'*W^{-1}*W^{-T}*bz + A'*by )
			//            - by
			//     S*ux = bx + GG'*W^{-1}*W^{-T}*bz + A'*by - A'*y.
			//     W*uz = W^{-T} * ( GG*ux - bz ).

			// The checkpoint minor is re-read here because solve may be
			// called later than factor.
			minor := 0
			if !checkpnt.MinorEmpty() {
				minor = checkpnt.MinorTop()
			}

			// z := W^{-1} * z = W^{-1} * bz
			scale(z, W, true, true)
			checkpnt.Check("10solve_chol2", minor)

			// If not F['singular']:
			//     x := L^{-1} * P * (x + GGs'*z)
			//        = L^{-1} * P * (x + GG'*W^{-1}*W^{-T}*bz)
			//
			// If F['singular']:
			//     x := L^{-1} * P * (x + GGs'*z + A'*y))
			//        = L^{-1} * P * (x + GG'*W^{-1}*W^{-T}*bz + A'*y)
			if mnl > 0 {
				blas.GemvFloat(F.Dfs, z, x, 1.0, 1.0, la.OptTrans)
			}
			blas.GemvFloat(F.Gs, z, x, 1.0, 1.0, la.OptTrans, &la.IOpt{"offsetx", mnl})
			//checkpnt.Check("20solve_chol2", minor)
			if F.singular {
				blas.GemvFloat(F.A, y, x, 1.0, 1.0, la.OptTrans)
			}
			checkpnt.Check("30solve_chol2", minor)
			blas.TrsvFloat(F.S, x)
			//checkpnt.Check("50solve_chol2", minor)

			// y := K^{-1} * (Asc*x - y)
			//    = K^{-1} * (A * S^{-1} * (bx + GG'*W^{-1}*W^{-T}*bz) - by)
			//      (if not F['singular'])
			//    = K^{-1} * (A * S^{-1} * (bx + GG'*W^{-1}*W^{-T}*bz +
			//      A'*by) - by)
			//      (if F['singular']).
			blas.GemvFloat(Asct, x, y, 1.0, -1.0, la.OptTrans)
			//checkpnt.Check("55solve_chol2", minor)
			lapack.Potrs(F.K, y)
			//checkpnt.Check("60solve_chol2", minor)

			// x := P' * L^{-T} * (x - Asc'*y)
			//    = S^{-1} * (bx + GG'*W^{-1}*W^{-T}*bz - A'*y)
			//      (if not F['singular'])
			//    = S^{-1} * (bx + GG'*W^{-1}*W^{-T}*bz + A'*by - A'*y)
			//      (if F['singular'])
			blas.GemvFloat(Asct, y, x, -1.0, 1.0)
			blas.TrsvFloat(F.S, x, la.OptTrans)
			checkpnt.Check("70solve_chol2", minor)

			// W*z := GGs*x - z = W^{-T} * (GG*x - bz)
			if mnl > 0 {
				blas.GemvFloat(F.Dfs, x, z, 1.0, -1.0)
			}
			blas.GemvFloat(F.Gs, x, z, 1.0, -1.0, &la.IOpt{"offsety", mnl})

			checkpnt.Check("90solve_chol2", minor)
			return nil
		}
		return solve, nil
	}
	return factor, nil
}