示例#1
0
func TestDTrsm3(t *testing.T) {
	const N = 31
	const K = 4

	A := cmat.NewMatrix(N, N)
	B := cmat.NewMatrix(N, K)
	B0 := cmat.NewMatrix(N, K)

	ones := cmat.NewFloatConstSource(1.0)
	zeromean := cmat.NewFloatUniformSource(1.0, -0.5)

	A.SetFrom(zeromean, cmat.UPPER)
	B.SetFrom(ones)
	B0.Copy(B)
	// B = A*B
	blasd.MultTrm(B, A, 1.0, gomas.UPPER|gomas.LEFT)
	blasd.SolveTrm(B, A, 1.0, gomas.UPPER|gomas.LEFT)
	ok := B0.AllClose(B)
	t.Logf("B == trsm(trmm(B, A, L|U|N), A, L|U|N) : %v\n", ok)

	B.Copy(B0)
	// B = A.T*B
	blasd.MultTrm(B, A, 1.0, gomas.UPPER|gomas.LEFT|gomas.TRANSA)
	blasd.SolveTrm(B, A, 1.0, gomas.UPPER|gomas.LEFT|gomas.TRANSA)
	ok = B0.AllClose(B)
	t.Logf("B == trsm(trmm(B, A, L|U|T), A, L|U|T) : %v\n", ok)
}
示例#2
0
func TestDTrsm1(t *testing.T) {
	nofail := true

	const N = 31
	const K = 4

	A := cmat.NewMatrix(N, N)
	B := cmat.NewMatrix(N, K)
	B0 := cmat.NewMatrix(N, K)

	ones := cmat.NewFloatConstSource(1.0)
	zeromean := cmat.NewFloatUniformSource(1.0, -0.5)

	A.SetFrom(zeromean, cmat.LOWER)
	B.SetFrom(ones)
	B0.Copy(B)
	// B = A*B
	blasd.MultTrm(B, A, 1.0, gomas.LOWER|gomas.LEFT)
	blasd.SolveTrm(B, A, 1.0, gomas.LOWER|gomas.LEFT)
	ok := B0.AllClose(B)
	nofail = nofail && ok
	t.Logf("B == trsm(trmm(B, A, L|L|N), A, L|L|N) : %v\n", ok)
	if !ok {
		t.Logf("B|B0:\n%v\n", cmat.NewJoin(cmat.AUGMENT, B, B0))
	}

	B.Copy(B0)
	// B = A.T*B
	blasd.MultTrm(B, A, 1.0, gomas.LOWER|gomas.LEFT|gomas.TRANSA)
	blasd.SolveTrm(B, A, 1.0, gomas.LOWER|gomas.LEFT|gomas.TRANSA)
	ok = B0.AllClose(B)
	nofail = nofail && ok
	t.Logf("B == trsm(trmm(B, A, L|L|T), A, L|L|T) : %v\n", ok)
}
示例#3
0
文件: qrtmult.go 项目: hrautila/gomas
/*
 * QRTSolve solves a system of linear equations with general M-by-N
 * matrix A using the QR factorization computed by DecomposeQRT().
 *
 * If flags&gomas.TRANS != 0:
 *   find the minimum norm solution of an underdetermined system A.T * X = B,
 *   i.e. min ||X|| s.t. A.T*X = B
 *
 * Otherwise:
 *   find the least squares solution of an overdetermined system, i.e.,
 *   solve the least squares problem: min || B - A*X ||.
 *
 * Arguments:
 *  B     On entry, the right hand side N-by-P matrix B. On exit, the solution matrix X.
 *
 *  A     The elements on and above the diagonal contain the min(M,N)-by-N upper
 *        trapezoidal matrix R. The elements below the diagonal with the matrix 'T',
 *        represent the orthogonal matrix Q as product of elementary reflectors.
 *        Matrix A and T are as returned by DecomposeQRT()
 *
 *  T     The block reflector computed from elementary reflectors as returned by
 *        DecomposeQRT() or computed from elementary reflectors and scalar coefficients
 *        by BuildT()
 *
 *  W     Workspace, size as returned by WorkspaceMultQT()
 *
 *  flags Indicator flag
 *
 *  confs Optional blocking configuration
 *
 * Returns the first error reported by the triangular solve or by the
 * multiplication with Q; nil on success.
 *
 * Compatible with lapack.GELS (the m >= n part)
 */
func QRTSolve(B, A, T, W *cmat.FloatMatrix, flags int, confs ...*gomas.Config) *gomas.Error {
	var R, BT cmat.FloatMatrix
	conf := gomas.CurrentConf(confs...)

	if flags&gomas.TRANS != 0 {
		// Solve underdetermined system A.T*X = B

		// B' = R.-T*B
		R.SubMatrix(A, 0, 0, n(A), n(A))
		BT.SubMatrix(B, 0, 0, n(A), n(B))
		// stop here on failure; previously this error was overwritten by
		// the result of the QRTMult call below
		if err := blasd.SolveTrm(&BT, &R, 1.0, gomas.LEFT|gomas.UPPER|gomas.TRANSA, conf); err != nil {
			return err
		}

		// Clear bottom part of B
		BT.SubMatrix(B, n(A), 0)
		BT.SetFrom(cmat.NewFloatConstSource(0.0))

		// X = Q*B'
		return QRTMult(B, A, T, W, gomas.LEFT, conf)
	}

	// solve least square problem min ||A*X - B||

	// B' = Q.T*B
	if err := QRTMult(B, A, T, W, gomas.LEFT|gomas.TRANS, conf); err != nil {
		return err
	}

	// X = R.-1*B'
	R.SubMatrix(A, 0, 0, n(A), n(A))
	BT.SubMatrix(B, 0, 0, n(A), n(B))
	return blasd.SolveTrm(&BT, &R, 1.0, gomas.LEFT|gomas.UPPER, conf)
}
示例#4
0
func TestDTrms2(t *testing.T) {
	const N = 31
	const K = 4

	nofail := true

	A := cmat.NewMatrix(N, N)
	B := cmat.NewMatrix(K, N)
	B0 := cmat.NewMatrix(K, N)

	ones := cmat.NewFloatConstSource(1.0)
	zeromean := cmat.NewFloatUniformSource(1.0, -0.5)

	A.SetFrom(zeromean, cmat.LOWER)
	B.SetFrom(ones)
	B0.Copy(B)

	// B = B*A
	blasd.MultTrm(B, A, 1.0, gomas.LOWER|gomas.RIGHT)
	blasd.SolveTrm(B, A, 1.0, gomas.LOWER|gomas.RIGHT)
	ok := B0.AllClose(B)
	nofail = nofail && ok
	t.Logf("B == trsm(trmm(B, A, R|L|N), A, R|L|N) : %v\n", ok)

	B.Copy(B0)
	// B = B*A.T
	blasd.MultTrm(B, A, 1.0, gomas.LOWER|gomas.RIGHT|gomas.TRANSA)
	blasd.SolveTrm(B, A, 1.0, gomas.LOWER|gomas.RIGHT|gomas.TRANSA)
	ok = B0.AllClose(B)
	nofail = nofail && ok
	t.Logf("B == trsm(trmm(B, A, R|L|T), A, R|L|T) : %v\n", ok)
}
示例#5
0
文件: lu.go 项目: hrautila/gomas
// blockedLUnoPiv computes the blocked LU decomposition of A without pivoting
// (FLAME LU nopivots variant 5). A is overwritten in place, nb is the
// diagonal block size and conf is forwarded to the unblocked kernel and to
// the triangular solves. This function itself always returns nil.
func blockedLUnoPiv(A *cmat.FloatMatrix, nb int, conf *gomas.Config) *gomas.Error {
	var ATL, ATR, ABL, ABR cmat.FloatMatrix
	var A00, A01, A02, A10, A11, A12, A20, A21, A22 cmat.FloatMatrix

	util.Partition2x2(
		&ATL, &ATR,
		&ABL, &ABR, A, 0, 0, util.PTOPLEFT)

	for m(&ATL) < m(A)-nb {
		util.Repartition2x2to3x3(&ATL,
			&A00, &A01, &A02,
			&A10, &A11, &A12,
			&A20, &A21, &A22, A, nb, util.PBOTTOMRIGHT)

		// A11 = LU(A11)
		unblockedLUnoPiv(&A11, conf)
		// A12 = trilu(A11).-1*A12  (TRSM)
		blasd.SolveTrm(&A12, &A11, 1.0, gomas.LEFT|gomas.LOWER|gomas.UNIT, conf)
		// A21 = A21*triu(A11).-1 (TRSM)
		blasd.SolveTrm(&A21, &A11, 1.0, gomas.RIGHT|gomas.UPPER, conf)
		// A22 = A22 - A21*A12
		blasd.Mult(&A22, &A21, &A12, -1.0, 1.0, gomas.NONE)

		util.Continue3x3to2x2(
			&ATL, &ATR,
			&ABL, &ABR, &A00, &A11, &A22, A, util.PBOTTOMRIGHT)
	}
	// last block
	if m(&ATL) < m(A) {
		unblockedLUnoPiv(&ABR, conf)
	}
	return nil
}
示例#6
0
文件: lu.go 项目: hrautila/gomas
/*
 * LUSolve solves a system of linear equations A*X = B or A.T*X = B with general
 * N-by-N matrix A using the LU factorization computed by LUFactor().
 *
 * Arguments:
 *  B      On entry, the right hand side matrix B. On exit, the solution matrix X.
 *
 *  A      The factor L and U from the factorization A = P*L*U as computed by
 *         LUFactor()
 *
 *  pivots The pivot indices from LUFactor(). If nil, no row interchanges are applied.
 *
 *  flags  The indicator of the form of the system of equations.
 *         If flags&TRANSA then system is transposed. All other values
 *         indicate non transposed system.
 *
 *  confs  Optional blocking configuration; the first entry is used, otherwise
 *         the default configuration.
 *
 * Returns an ENOTSQUARE error if A is not square, an ESIZE error if the row
 * count of B does not match A, or the first error from the triangular solves.
 *
 * Compatible with lapack.DGETRS.
 */
func LUSolve(B, A *cmat.FloatMatrix, pivots Pivots, flags int, confs ...*gomas.Config) *gomas.Error {
	conf := gomas.DefaultConf()
	if len(confs) > 0 {
		conf = confs[0]
	}
	ar, ac := A.Size()
	br, _ := B.Size()
	if ar != ac {
		return gomas.NewError(gomas.ENOTSQUARE, "SolveLU")
	}
	if br != ac {
		return gomas.NewError(gomas.ESIZE, "SolveLU")
	}
	if pivots != nil {
		// NOTE(review): for the transposed system the row interchanges
		// belong after the triangular solves and in reverse order (see
		// LAPACK DGETRS). No reverse-pivot helper is visible here, so the
		// interchange is left where it was; confirm the TRANSA path when
		// pivots are in use.
		applyPivots(B, pivots)
	}
	if flags&gomas.TRANSA != 0 {
		// transposed: A.T = (L*U).T = U.T*L.T, so
		// X = A.-T*B == (U.T*L.T).-1*B == L.-T*(U.-T*B); solve with U.T first
		if err := blasd.SolveTrm(B, A, 1.0, gomas.UPPER|gomas.TRANSA, conf); err != nil {
			return err
		}
		return blasd.SolveTrm(B, A, 1.0, gomas.LOWER|gomas.UNIT|gomas.TRANSA, conf)
	}

	// non-transposed X = A.-1*B == (L*U).-1*B == U.-1*(L.-1*B)
	if err := blasd.SolveTrm(B, A, 1.0, gomas.LOWER|gomas.UNIT, conf); err != nil {
		return err
	}
	return blasd.SolveTrm(B, A, 1.0, gomas.UPPER, conf)
}
示例#7
0
文件: chol.go 项目: hrautila/gomas
// blockedCHOL computes the blocked Cholesky factorization of A in place using
// diagonal blocks of size conf.LB. When flags&gomas.LOWER is set the lower
// triangular factor is computed, otherwise the upper one. The factorization
// continues after a failing diagonal block; the first error reported by the
// unblocked kernels, including the one for the trailing block, is returned.
func blockedCHOL(A *cmat.FloatMatrix, flags int, conf *gomas.Config) *gomas.Error {
	var err, firstErr *gomas.Error
	var ATL, ATR, ABL, ABR cmat.FloatMatrix
	var A00, A01, A02, A10, A11, A12, A20, A21, A22 cmat.FloatMatrix

	nb := conf.LB
	util.Partition2x2(
		&ATL, &ATR,
		&ABL, &ABR, A, 0, 0, util.PTOPLEFT)

	for m(A)-m(&ATL) > nb {
		util.Repartition2x2to3x3(&ATL,
			&A00, &A01, &A02,
			&A10, &A11, &A12,
			&A20, &A21, &A22, A, nb, util.PBOTTOMRIGHT)

		if flags&gomas.LOWER != 0 {
			// A11 = chol(A11)
			err = unblockedLowerCHOL(&A11, flags, m(&ATL))
			// A21 = A21 * tril(A11).-T
			blasd.SolveTrm(&A21, &A11, 1.0, gomas.RIGHT|gomas.LOWER|gomas.TRANSA, conf)
			// A22 = A22 - A21*A21.T
			blasd.UpdateSym(&A22, &A21, -1.0, 1.0, gomas.LOWER, conf)
		} else {
			// A11 = chol(A11)
			err = unblockedUpperCHOL(&A11, flags, m(&ATL))
			// A12 = triu(A11).-T * A12
			blasd.SolveTrm(&A12, &A11, 1.0, gomas.UPPER|gomas.TRANSA, conf)
			// A22 = A22 - A12.T*A12
			blasd.UpdateSym(&A22, &A12, -1.0, 1.0, gomas.UPPER|gomas.TRANSA, conf)
		}
		if err != nil && firstErr == nil {
			firstErr = err
		}

		util.Continue3x3to2x2(
			&ATL, &ATR,
			&ABL, &ABR, &A00, &A11, &A22, A, util.PBOTTOMRIGHT)
	}

	if m(&ATL) < m(A) {
		// last block; its error used to be dropped, so a failure here was
		// reported as success. The offset argument mirrors the in-loop calls.
		if flags&gomas.LOWER != 0 {
			err = unblockedLowerCHOL(&ABR, flags, m(&ATL))
		} else {
			err = unblockedUpperCHOL(&ABR, flags, m(&ATL))
		}
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
示例#8
0
文件: lqmult.go 项目: hrautila/gomas
/*
 * LQSolve solves a system of linear equations with general M-by-N
 * matrix A using the LQ factorization computed by LQFactor().
 *
 * If flags&TRANS != 0:
 *   find the least squares solution of an overdetermined system, i.e.,
 *   solve the least squares problem: min || B - A.T*X ||.
 *
 * Otherwise:
 *   find the minimum norm solution of an underdetermined system A * X = B,
 *   i.e. min ||X|| s.t. A*X = B
 *
 * Arguments:
 *  B     On entry, the right hand side N-by-P matrix B. On exit, the solution matrix X.
 *
 *  A     The elements on and below the diagonal contain the M-by-min(M,N) lower
 *        trapezoidal matrix L. The elements right of the diagonal with the vector 'tau',
 *        represent the orthogonal matrix Q as product of elementary reflectors.
 *        Matrix A is as returned by LQFactor()
 *
 *  tau   The vector of N scalar coefficients that together with trilu(A) define
 *        the orthogonal matrix Q as Q = H(N)H(N-1)...H(1)
 *
 *  W     Workspace, required size is returned by WorksizeMultLQ().
 *
 *  flags Indicator flags
 *
 *  confs Optional blocking configuration. If not given default will be used. Unblocked
 *        invocation is indicated with conf.LB == 0.
 *
 * Returns an EWORK error when the workspace is too small, otherwise the first
 * error reported by the triangular solve or by the multiplication with Q.
 *
 * Compatible with lapack.GELS (the m < n part)
 */
func LQSolve(B, A, tau, W *cmat.FloatMatrix, flags int, confs ...*gomas.Config) *gomas.Error {
	var L, BL cmat.FloatMatrix

	conf := gomas.CurrentConf(confs...)

	wsmin := wsMultLQLeft(B, 0)
	if W.Len() < wsmin {
		return gomas.NewError(gomas.EWORK, "SolveLQ", wsmin)
	}

	if flags&gomas.TRANS != 0 {
		// solve: MIN ||A.T*X - B||

		// B' = Q*B
		if err := LQMult(B, A, tau, W, gomas.LEFT, conf); err != nil {
			return err
		}

		// X = L.-T*B'
		L.SubMatrix(A, 0, 0, m(A), m(A))
		BL.SubMatrix(B, 0, 0, m(A), n(B))
		return blasd.SolveTrm(&BL, &L, 1.0, gomas.LEFT|gomas.LOWER|gomas.TRANSA, conf)
	}

	// Solve underdetermined system A*X = B

	// B' = L.-1*B
	L.SubMatrix(A, 0, 0, m(A), m(A))
	BL.SubMatrix(B, 0, 0, m(A), n(B))
	// stop here on failure; previously this error was overwritten by the
	// result of the LQMult call below
	if err := blasd.SolveTrm(&BL, &L, 1.0, gomas.LEFT|gomas.LOWER, conf); err != nil {
		return err
	}

	// Clear bottom part of B
	BL.SubMatrix(B, m(A), 0)
	BL.SetFrom(cmat.NewFloatConstSource(0.0))

	// X = Q.T*B'
	return LQMult(B, A, tau, W, gomas.LEFT|gomas.TRANS, conf)
}
示例#9
0
文件: chol.go 项目: hrautila/gomas
/*
 * CHOLSolve solves a system of linear equations A*X = B with symmetric positive
 * definite matrix A using the Cholesky factorization A = U.T*U or A = L*L.T
 * computed by DecomposeCHOL().
 *
 * Arguments:
 *  B   On entry, the right hand side matrix B. On exit, the solution
 *      matrix X.
 *
 *  A   The triangular factor U or L from Cholesky factorization as computed by
 *      DecomposeCHOL().
 *
 *  flags Indicator of which factor is stored in A. If flags&UPPER then upper
 *        triangle of A is stored. If flags&LOWER then lower triangle of A is
 *        stored. UPPER takes precedence when both are set. If neither is
 *        set, B is left untouched and nil is returned.
 *
 *  confs Optional blocking configuration; the first entry is used, otherwise
 *        the default configuration.
 *
 * Returns an ESIZE error if A is not square or the row count of B does not
 * match A, otherwise the first error reported by the triangular solves.
 *
 * Compatible with lapack.DPOTRS.
 */
func CHOLSolve(B, A *cmat.FloatMatrix, flags int, confs ...*gomas.Config) *gomas.Error {
	conf := gomas.DefaultConf()
	if len(confs) > 0 {
		conf = confs[0]
	}
	ar, ac := A.Size()
	br, _ := B.Size()
	if ac != br || ar != ac {
		return gomas.NewError(gomas.ESIZE, "SolveCHOL")
	}

	switch {
	case flags&gomas.UPPER != 0:
		// X = (U.T*U).-1*B => U.-1*(U.-T*B)
		if err := blasd.SolveTrm(B, A, 1.0, gomas.UPPER|gomas.TRANSA, conf); err != nil {
			return err
		}
		return blasd.SolveTrm(B, A, 1.0, gomas.UPPER, conf)
	case flags&gomas.LOWER != 0:
		// X = (L*L.T).-1*B = L.-T*(L.-1*B)
		if err := blasd.SolveTrm(B, A, 1.0, gomas.LOWER, conf); err != nil {
			return err
		}
		return blasd.SolveTrm(B, A, 1.0, gomas.LOWER|gomas.TRANSA, conf)
	}
	// NOTE(review): neither UPPER nor LOWER given; nothing is solved and no
	// error is raised. Confirm whether this should default to one factor.
	return nil
}
示例#10
0
文件: lu.go 项目: hrautila/gomas
// unblockedLUpiv computes the unblocked LU decomposition with partial
// pivoting (FLAME LU variant 3, left-looking). A is overwritten in place and
// the chosen pivot rows are stored into p. offset is added to the row index
// reported in the singularity error so that callers working on a sub-block
// can pass the block's starting row. The factorization continues past a zero
// pivot; the first ESINGULAR error encountered is returned, nil otherwise.
func unblockedLUpiv(A *cmat.FloatMatrix, p *Pivots, offset int, conf *gomas.Config) *gomas.Error {
	var err *gomas.Error
	var ATL, ATR, ABL, ABR cmat.FloatMatrix
	var A00, a01, A02, a10, a11, a12, A20, a21, A22 cmat.FloatMatrix
	var AL, AR, A0, a1, A2, aB1, AB0 cmat.FloatMatrix
	var pT, pB, p0, p1, p2 Pivots

	util.Partition2x2(
		&ATL, &ATR,
		&ABL, &ABR, A, 0, 0, util.PTOPLEFT)
	util.Partition1x2(
		&AL, &AR, A, 0, util.PLEFT)
	partitionPivot2x1(
		&pT,
		&pB, *p, 0, util.PTOP)

	for m(&ATL) < m(A) && n(&ATL) < n(A) {
		util.Repartition2x2to3x3(&ATL,
			&A00, &a01, &A02,
			&a10, &a11, &a12,
			&A20, &a21, &A22 /**/, A, 1, util.PBOTTOMRIGHT)
		util.Repartition1x2to1x3(&AL,
			&A0, &a1, &A2 /**/, A, 1, util.PRIGHT)
		repartPivot2x1to3x1(&pT,
			&p0, &p1, &p2 /**/, *p, 1, util.PBOTTOM)

		// apply previously computed pivots on current column
		applyPivots(&a1, p0)

		// a01 = trilu(A00) \ a01 (TRSV)
		blasd.MVSolveTrm(&a01, &A00, 1.0, gomas.LOWER|gomas.UNIT)
		// a11 = a11 - a10 *a01
		aval := a11.Get(0, 0) - blasd.Dot(&a10, &a01)
		a11.Set(0, 0, aval)
		// a21 = a21 -A20*a01
		blasd.MVMult(&a21, &A20, &a01, -1.0, 1.0, gomas.NONE)

		// pivot index on current column [a11, a21].T
		aB1.Column(&ABR, 0)
		p1[0] = pivotIndex(&aB1)
		// pivots to current column
		applyPivots(&aB1, p1)

		// a21 = a21 / a11
		// Test the value that pivoting moved onto the diagonal, not the
		// pre-pivot one: a zero diagonal that gets swapped with a non-zero
		// entry below it is a valid pivot and must still be scaled by.
		pivot := a11.Get(0, 0)
		if pivot == 0.0 {
			if err == nil {
				ij := offset + m(&ATL) + p1[0] - 1
				err = gomas.NewError(gomas.ESINGULAR, "DecomposeLU", ij)
			}
		} else {
			blasd.InvScale(&a21, pivot)
		}

		// apply pivots to previous columns
		AB0.SubMatrix(&ABL, 0, 0)
		applyPivots(&AB0, p1)
		// scale last pivots to origin matrix row numbers
		p1[0] += m(&ATL)

		util.Continue3x3to2x2(
			&ATL, &ATR,
			&ABL, &ABR, &A00, &a11, &A22, A, util.PBOTTOMRIGHT)
		util.Continue1x3to1x2(
			&AL, &AR, &A0, &a1, A, util.PRIGHT)
		contPivot3x1to2x1(
			&pT,
			&pB, p0, p1, *p, util.PBOTTOM)
	}
	if n(&ATL) < n(A) {
		// more columns than rows: finish the remaining right-hand block
		applyPivots(&ATR, *p)
		blasd.SolveTrm(&ATR, &ATL, 1.0, gomas.LEFT|gomas.UNIT|gomas.LOWER, conf)
	}
	return err
}
示例#11
0
文件: lu.go 项目: hrautila/gomas
// blockedLUpiv computes the blocked LU decomposition with partial pivoting
// (FLAME LU variant 3, left-looking). A is overwritten in place, the chosen
// pivot rows are stored into p, nb is the panel width and conf is forwarded
// to the unblocked panel factorization and to the triangular solves. The
// factorization continues past a failing panel; the first error reported by
// the panel factorization is returned, nil otherwise.
func blockedLUpiv(A *cmat.FloatMatrix, p *Pivots, nb int, conf *gomas.Config) *gomas.Error {
	var err *gomas.Error
	var ATL, ATR, ABL, ABR cmat.FloatMatrix
	var A00, A01, A02, A10, A11, A12, A20, A21, A22 cmat.FloatMatrix
	var AL, AR, A0, A1, A2, AB1, AB0 cmat.FloatMatrix
	var pT, pB, p0, p1, p2 Pivots

	util.Partition2x2(
		&ATL, &ATR,
		&ABL, &ABR, A, 0, 0, util.PTOPLEFT)
	util.Partition1x2(
		&AL, &AR, A, 0, util.PLEFT)
	partitionPivot2x1(
		&pT,
		&pB, *p, 0, util.PTOP)

	for m(&ATL) < m(A) && n(&ATL) < n(A) {
		util.Repartition2x2to3x3(&ATL,
			&A00, &A01, &A02,
			&A10, &A11, &A12,
			&A20, &A21, &A22 /**/, A, nb, util.PBOTTOMRIGHT)
		util.Repartition1x2to1x3(&AL,
			&A0, &A1, &A2 /**/, A, nb, util.PRIGHT)
		repartPivot2x1to3x1(&pT,
			&p0, &p1, &p2 /**/, *p, nb, util.PBOTTOM)

		// apply previously computed pivots
		applyPivots(&A1, p0)

		// A01 = trilu(A00) \ A01 (TRSM)
		blasd.SolveTrm(&A01, &A00, 1.0, gomas.LOWER|gomas.UNIT, conf)
		// A11 = A11 - A10*A01
		blasd.Mult(&A11, &A10, &A01, -1.0, 1.0, gomas.NONE)
		// A21 = A21 - A20*A01
		blasd.Mult(&A21, &A20, &A01, -1.0, 1.0, gomas.NONE)

		// LU_piv(AB1, p1); keep the first panel error instead of dropping it
		AB1.SubMatrix(&ABR, 0, 0, m(&ABR), n(&A11))
		if perr := unblockedLUpiv(&AB1, &p1, m(&ATL), conf); perr != nil && err == nil {
			err = perr
		}

		// apply pivots to previous columns
		AB0.SubMatrix(&ABL, 0, 0)
		applyPivots(&AB0, p1)
		// scale last pivots to origin matrix row numbers
		for k := range p1 {
			p1[k] += m(&ATL)
		}

		util.Continue3x3to2x2(
			&ATL, &ATR,
			&ABL, &ABR /**/, &A00, &A11, &A22, A, util.PBOTTOMRIGHT)
		util.Continue1x3to1x2(
			&AL, &AR /**/, &A0, &A1, A, util.PRIGHT)
		contPivot3x1to2x1(
			&pT,
			&pB /**/, p0, p1, *p, util.PBOTTOM)
	}
	if n(&ATL) < n(A) {
		// more columns than rows: finish the remaining right-hand block
		applyPivots(&ATR, *p)
		blasd.SolveTrm(&ATR, &ATL, 1.0, gomas.LEFT|gomas.UNIT|gomas.LOWER, conf)
	}
	return err
}