示例#1
0
文件: chol.go 项目: hrautila/gomas
// blockedCHOL computes the Cholesky factorization of A in place using a
// right-looking blocked algorithm with block size conf.LB.
//
// If flags has gomas.LOWER set the lower triangle is factored (A = L*L.T),
// otherwise the upper triangle is factored (A = U.T*U). For each diagonal
// block A11 the unblocked kernel factors A11, the off-diagonal panel is solved
// against that factor, and the trailing submatrix A22 is updated with a
// symmetric rank-k update. Whatever remains after the last full block
// (at most nb rows) is factored with the unblocked kernel.
//
// The factorization does not stop at the first failing block; the first error
// returned by any diagonal-block factorization (including the final block) is
// returned. Returns nil if no block reported an error.
func blockedCHOL(A *cmat.FloatMatrix, flags int, conf *gomas.Config) *gomas.Error {
	var err, firstErr *gomas.Error
	var ATL, ATR, ABL, ABR cmat.FloatMatrix
	var A00, A01, A02, A10, A11, A12, A20, A21, A22 cmat.FloatMatrix

	// NOTE(review): assumes conf.LB > 0; with nb == 0 the repartitioning
	// below would never advance. Callers presumably pick the blocked path
	// only when LB > 0 — confirm.
	nb := conf.LB
	util.Partition2x2(
		&ATL, &ATR,
		&ABL, &ABR, A, 0, 0, util.PTOPLEFT)

	for m(A)-m(&ATL) > nb {
		util.Repartition2x2to3x3(&ATL,
			&A00, &A01, &A02,
			&A10, &A11, &A12,
			&A20, &A21, &A22, A, nb, util.PBOTTOMRIGHT)

		// The third kernel argument is A11's row offset within A (m(&ATL));
		// presumably used to report the global index of a failing pivot.
		if flags&gomas.LOWER != 0 {
			// A11 = chol(A11)
			err = unblockedLowerCHOL(&A11, flags, m(&ATL))
			// A21 = A21 * tril(A11).-T
			blasd.SolveTrm(&A21, &A11, 1.0, gomas.RIGHT|gomas.LOWER|gomas.TRANSA, conf)
			// A22 = A22 - A21*A21.T
			blasd.UpdateSym(&A22, &A21, -1.0, 1.0, gomas.LOWER, conf)
		} else {
			// A11 = chol(A11)
			err = unblockedUpperCHOL(&A11, flags, m(&ATL))
			// A12 = triu(A11).-T * A12
			blasd.SolveTrm(&A12, &A11, 1.0, gomas.UPPER|gomas.TRANSA, conf)
			// A22 = A22 - A12.T*A12
			blasd.UpdateSym(&A22, &A12, -1.0, 1.0, gomas.UPPER|gomas.TRANSA, conf)
		}
		if err != nil && firstErr == nil {
			firstErr = err
		}

		util.Continue3x3to2x2(
			&ATL, &ATR,
			&ABL, &ABR, &A00, &A11, &A22, A, util.PBOTTOMRIGHT)
	}

	if m(&ATL) < m(A) {
		// Last block: ABR begins at row m(&ATL) of A, so pass that offset
		// exactly as the in-loop calls do, and record its error like the
		// in-loop blocks so a failure here is not silently lost.
		if flags&gomas.LOWER != 0 {
			err = unblockedLowerCHOL(&ABR, flags, m(&ATL))
		} else {
			err = unblockedUpperCHOL(&ABR, flags, m(&ATL))
		}
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
示例#2
0
// TestDSyrkUpper checks the upper-triangular symmetric rank-k update
// blasd.UpdateSym against the upper triangle of the equivalent general
// product computed with blasd.Mult, for both operand orientations:
//
//	U|N:  A = A + B*B.T     (B is N x K)
//	U|T:  A = A + Bt.T*Bt   (Bt is K x N)
//
// The test fails if either result is not AllClose to its reference. When
// N < 10 the computed and reference matrices are also logged.
func TestDSyrkUpper(t *testing.T) {
	conf := gomas.NewConf()

	A := cmat.NewMatrix(N, N)
	A0 := cmat.NewMatrix(N, N)
	B := cmat.NewMatrix(N, K)
	Bt := cmat.NewMatrix(K, N)

	ones := cmat.NewFloatConstSource(1.0)

	A.SetFrom(ones, cmat.UPPER)
	A0.Copy(A)
	B.SetFrom(ones)
	Bt.Transpose(B)

	// A = A + B*B.T (upper triangle); reference is TriU(A0 + B*B.T).
	blasd.UpdateSym(A, B, 1.0, 1.0, gomas.UPPER, conf)
	blasd.Mult(A0, B, B, 1.0, 1.0, gomas.TRANSB)
	cmat.TriU(A0, cmat.NONE)
	ok := A0.AllClose(A)
	t.Logf("UpdateSym(A, B, U|N) == TriU(Mult(A, B, B.T)) : %v\n", ok)
	if N < 10 {
		t.Logf("UpdateSym(A, B)\n%v\n", A)
		t.Logf("Mult(A, B, B.T)\n%v\n", A0)
	}
	if !ok {
		t.Errorf("UpdateSym(A, B, U|N) differs from TriU(Mult(A, B, B.T))")
	}

	// Reset A and its reference for the transposed case.
	A.SetFrom(ones, cmat.UPPER)
	A0.Copy(A)

	// A = A + Bt.T*Bt (upper triangle); reference is TriU(A0 + Bt.T*Bt).
	blasd.UpdateSym(A, Bt, 1.0, 1.0, gomas.UPPER|gomas.TRANSA, conf)
	blasd.Mult(A0, Bt, Bt, 1.0, 1.0, gomas.TRANSA)
	cmat.TriU(A0, cmat.NONE)
	ok = A0.AllClose(A)
	t.Logf("UpdateSym(A, B, U|T) == TriU(Mult(A, B.T, B)) : %v\n", ok)
	if N < 10 {
		t.Logf("UpdateSym(A, B)\n%v\n", A)
		t.Logf("Mult(A, B.T, B)\n%v\n", A0)
	}
	if !ok {
		t.Errorf("UpdateSym(A, B, U|T) differs from TriU(Mult(A, B.T, B))")
	}
}