func blockedCHOL(A *cmat.FloatMatrix, flags int, conf *gomas.Config) *gomas.Error { var err, firstErr *gomas.Error var ATL, ATR, ABL, ABR cmat.FloatMatrix var A00, A01, A02, A10, A11, A12, A20, A21, A22 cmat.FloatMatrix nb := conf.LB err = nil firstErr = nil util.Partition2x2( &ATL, &ATR, &ABL, &ABR, A, 0, 0, util.PTOPLEFT) for m(A)-m(&ATL) > nb { util.Repartition2x2to3x3(&ATL, &A00, &A01, &A02, &A10, &A11, &A12, &A20, &A21, &A22, A, nb, util.PBOTTOMRIGHT) if flags&gomas.LOWER != 0 { // A11 = chol(A11) err = unblockedLowerCHOL(&A11, flags, m(&ATL)) // A21 = A21 * tril(A11).-1 blasd.SolveTrm(&A21, &A11, 1.0, gomas.RIGHT|gomas.LOWER|gomas.TRANSA, conf) // A22 = A22 - A21*A21.T blasd.UpdateSym(&A22, &A21, -1.0, 1.0, gomas.LOWER, conf) } else { // A11 = chol(A11) err = unblockedUpperCHOL(&A11, flags, m(&ATL)) // A12 = triu(A11).-1 * A12 blasd.SolveTrm(&A12, &A11, 1.0, gomas.UPPER|gomas.TRANSA, conf) // A22 = A22 - A12.T*A12 blasd.UpdateSym(&A22, &A12, -1.0, 1.0, gomas.UPPER|gomas.TRANSA, conf) } if err != nil && firstErr == nil { firstErr = err } util.Continue3x3to2x2( &ATL, &ATR, &ABL, &ABR, &A00, &A11, &A22, A, util.PBOTTOMRIGHT) } if m(&ATL) < m(A) { // last block if flags&gomas.LOWER != 0 { unblockedLowerCHOL(&ABR, flags, 0) } else { unblockedUpperCHOL(&ABR, flags, 0) } } return firstErr }
func TestDSyrkUpper(t *testing.T) { var ok bool conf := gomas.NewConf() A := cmat.NewMatrix(N, N) A0 := cmat.NewMatrix(N, N) B := cmat.NewMatrix(N, K) Bt := cmat.NewMatrix(K, N) ones := cmat.NewFloatConstSource(1.0) zeromean := cmat.NewFloatUniformSource() _, _ = ones, zeromean A.SetFrom(ones, cmat.UPPER) A0.Copy(A) B.SetFrom(ones) Bt.Transpose(B) // B = A*B blasd.UpdateSym(A, B, 1.0, 1.0, gomas.UPPER, conf) blasd.Mult(A0, B, B, 1.0, 1.0, gomas.TRANSB) cmat.TriU(A0, cmat.NONE) ok = A0.AllClose(A) t.Logf("UpdateSym(A, B, U|N) == TriU(Mult(A, B, B.T)) : %v\n", ok) if N < 10 { t.Logf("UpdateSym(A, B)\n%v\n", A) t.Logf("Mult(A, B.T, B)\n%v\n", A0) } A.SetFrom(ones, cmat.UPPER) A0.Copy(A) blasd.UpdateSym(A, Bt, 1.0, 1.0, gomas.UPPER|gomas.TRANSA, conf) blasd.Mult(A0, Bt, Bt, 1.0, 1.0, gomas.TRANSA) cmat.TriU(A0, cmat.NONE) ok = A0.AllClose(A) t.Logf("UpdateSym(A, B, U|T) == TriU(Mult(A, B.T, B)) : %v\n", ok) if N < 10 { t.Logf("UpdateSym(A, B)\n%v\n", A) t.Logf("Mult(A, B.T, B)\n%v\n", A0) } }