示例#1
0
文件: deprec.go 项目: hrautila/matops
// Mult0 runs the calgo.DMult0 matrix-multiply kernel on C, A and B, forwarding
// alpha, beta and flags unchanged.
//
// It returns an error when A.Cols() != B.Rows(); the size of C is not
// validated here.
//
// When at most one worker is configured, or C.NumElements()*A.Cols() does not
// exceed limitOne, the whole of C is handled by a single kernel call.
// Otherwise the work is split with divideWork and run through scheduleWork.
func Mult0(C, A, B *matrix.FloatMatrix, alpha, beta float64, flags Flags) error {
	inner := B.Rows()
	if A.Cols() != inner {
		return errors.New("A.cols != B.rows: size mismatch")
	}

	aData, aLead := A.FloatArray(), A.LeadingIndex()
	bData, bLead := B.FloatArray(), B.LeadingIndex()
	cData, cLead := C.FloatArray(), C.LeadingIndex()
	kflags := calgo.Flags(flags)

	// compute invokes the kernel for the column range c0..c1 and the row
	// range r0..r1 of C.
	compute := func(c0, c1, r0, r1 int) {
		calgo.DMult0(cData, aData, bData, alpha, beta, kflags, cLead, aLead, bLead, inner,
			c0, c1, r0, r1, vpLen, nB, mB)
	}

	work := int64(C.NumElements()) * int64(A.Cols())
	if nWorker > 1 && work > limitOne {
		// more than one worker available and the problem is large enough
		colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker)
		scheduleWork(colworks, rowworks, C.Cols(), C.Rows(),
			func(cstart, cend, rstart, rend int, ready chan int) {
				compute(cstart, cend, rstart, rend)
				ready <- 1
			})
		return nil
	}

	compute(0, C.Cols(), 0, C.Rows())
	return nil
}
示例#2
0
文件: mmult.go 项目: hrautila/matops
// Mult is the generic matrix-matrix multiply (blas.GEMM). It calculates
//   C = beta*C + alpha*A*B     (default)
//   C = beta*C + alpha*A.T*B   flags&TRANSA
//   C = beta*C + alpha*A*B.T   flags&TRANSB
//   C = beta*C + alpha*A.T*B.T flags&(TRANSA|TRANSB)
//
// C is M*N, A is M*P or P*M if flags&TRANSA. B is P*N or N*P if flags&TRANSB.
//
// If the result dimensions implied by A and B (M or N) are zero, Mult returns
// nil without further checks. Otherwise an error is returned when the sizes of
// C, A and B do not agree for the requested transposition.
//
// When at most one worker is configured, or the work estimate M*N*P does not
// exceed limitOne, the product is computed with a single kernel call.
// Otherwise the work is split with divideWork and run through scheduleWork.
func Mult(C, A, B *matrix.FloatMatrix, alpha, beta float64, flags Flags) error {
	var ok, empty bool

	// Which dimensions must agree depends on the transpose flags.
	ar, ac := A.Size()
	br, bc := B.Size()
	cr, cc := C.Size()
	switch flags & (TRANSA | TRANSB) {
	case TRANSA | TRANSB:
		empty = ac == 0 || br == 0
		ok = cr == ac && cc == br && ar == bc
	case TRANSA:
		empty = ac == 0 || bc == 0
		ok = cr == ac && cc == bc && ar == br
	case TRANSB:
		empty = ar == 0 || br == 0
		ok = cr == ar && cc == br && ac == bc
	default:
		empty = ar == 0 || bc == 0
		ok = cr == ar && cc == bc && ac == br
	}
	if empty {
		return nil
	}
	if !ok {
		return errors.New("Mult: size mismatch")
	}

	// P is the dimension shared by A and B: the rows of A when A is
	// transposed, its columns otherwise.
	P := A.Cols()
	if flags&TRANSA != 0 {
		P = A.Rows()
	}

	// The work estimate must use P rather than A.Cols(); with TRANSA the
	// column count of A is M, not the shared dimension, and would skew the
	// single-worker threshold decision.
	psize := int64(C.NumElements()) * int64(P)

	Ar := A.FloatArray()
	ldA := A.LeadingIndex()
	Br := B.FloatArray()
	ldB := B.LeadingIndex()
	Cr := C.FloatArray()
	ldC := C.LeadingIndex()

	if nWorker <= 1 || psize <= limitOne {
		calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.Flags(flags), ldC, ldA, ldB, P,
			0, C.Cols(), 0, C.Rows(),
			vpLen, nB, mB)
		return nil
	}
	// here we have more than one worker available
	worker := func(cstart, cend, rstart, rend int, ready chan int) {
		calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.Flags(flags), ldC, ldA, ldB, P,
			cstart, cend, rstart, rend, vpLen, nB, mB)
		ready <- 1
	}
	colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker)
	scheduleWork(colworks, rowworks, C.Cols(), C.Rows(), worker)
	return nil
}
示例#3
0
文件: mmult.go 项目: hrautila/matops
// MultSym is the symmetric matrix multiply (blas.SYMM).
//   C = beta*C + alpha*A*B     (default)
//   C = beta*C + alpha*A.T*B   flags&TRANSA
//   C = beta*C + alpha*A*B.T   flags&TRANSB
//   C = beta*C + alpha*A.T*B.T flags&(TRANSA|TRANSB)
//
// C is N*P, A is N*N symmetric matrix. B is N*P or P*N if flags&TRANSB.
//
// If the result dimensions implied by A and B are zero, MultSym returns nil
// without further checks. Otherwise an error is returned when A is not square
// or the sizes of C, A and B do not agree for the requested transposition.
func MultSym(C, A, B *matrix.FloatMatrix, alpha, beta float64, flags Flags) error {
	ar, ac := A.Size()
	br, bc := B.Size()
	cr, cc := C.Size()

	// rows/cols are the result dimensions implied by A and B; innerA/innerB
	// are the dimensions of A and B that must agree with each other.
	var rows, cols, innerA, innerB int
	switch flags & (TRANSA | TRANSB) {
	case TRANSA | TRANSB:
		rows, cols, innerA, innerB = ac, br, ar, bc
	case TRANSA:
		rows, cols, innerA, innerB = ac, bc, ar, br
	case TRANSB:
		rows, cols, innerA, innerB = ar, br, ac, bc
	default:
		rows, cols, innerA, innerB = ar, bc, ac, br
	}
	if rows == 0 || cols == 0 {
		return nil
	}
	if ar != ac || cr != rows || cc != cols || innerA != innerB {
		return errors.New("MultSym: size mismatch")
	}

	n := A.Cols()
	aData, aLead := A.FloatArray(), A.LeadingIndex()
	bData, bLead := B.FloatArray(), B.LeadingIndex()
	cData, cLead := C.FloatArray(), C.LeadingIndex()
	kflags := calgo.Flags(flags)

	// compute invokes the kernel for the column range c0..c1 and the row
	// range r0..r1 of C.
	compute := func(c0, c1, r0, r1 int) {
		calgo.DMultSymm(cData, aData, bData, alpha, beta, kflags, cLead, aLead, bLead,
			n, c0, c1, r0, r1, vpLen, nB, mB)
	}

	work := int64(C.NumElements()) * int64(n)
	if nWorker > 1 && work > limitOne {
		// more than one worker available and the problem is large enough
		colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker)
		scheduleWork(colworks, rowworks, C.Cols(), C.Rows(),
			func(cstart, cend, rstart, rend int, ready chan int) {
				compute(cstart, cend, rstart, rend)
				ready <- 1
			})
		return nil
	}

	compute(0, C.Cols(), 0, C.Rows())
	return nil
}
示例#4
0
文件: mmult.go 项目: hrautila/matops
// SolveTrm solves multiple right sides. If flags&UNIT then A diagonal is
// assumed to be unit and is not referenced. (blas.TRSM)
//      alpha*B = A.-1*B if flags&LEFT
//      alpha*B = A.-T*B if flags&(LEFT|TRANS)
//      alpha*B = B*A.-1 if flags&RIGHT
//      alpha*B = B*A.-T if flags&(RIGHT|TRANS)
//
// Matrix A is N*N triangular matrix defined with flags bits as follow
//  LOWER       non-unit lower triangular
//  LOWER|UNIT  unit lower triangular
//  UPPER       non-unit upper triangular
//  UPPER|UNIT  unit upper triangular
//
// Matrix B is N*P if flags&LEFT or P*N if flags&RIGHT.
//
// Sizes are validated only when exactly one of LEFT or RIGHT is set; any
// other combination is passed to the kernel unchecked.
func SolveTrm(B, A *matrix.FloatMatrix, alpha float64, flags Flags) error {
	bRows, bCols := B.Size()
	aRows, aCols := A.Size()
	square := aCols == aRows

	switch flags & (LEFT | RIGHT) {
	case LEFT:
		if bRows == 0 {
			return nil
		}
		if bRows != aCols || !square {
			return onError("A, B size mismatch")
		}
	case RIGHT:
		if bCols == 0 {
			return nil
		}
		if bCols != aRows || !square {
			return onError("A, B size mismatch")
		}
	}

	// Extent of B along the dimension not tied to A: columns of B unless
	// flags&RIGHT, in which case rows of B.
	extent := bCols
	if flags&RIGHT != 0 {
		extent = bRows
	}

	// if more workers available can divide to tasks by B columns if flags&LEFT or by
	// B rows if flags&RIGHT.
	calgo.DSolveBlk(B.FloatArray(), A.FloatArray(), alpha, calgo.Flags(flags),
		B.LeadingIndex(), A.LeadingIndex(), aCols, 0, extent, nB)
	return nil
}
示例#5
0
文件: mmult.go 项目: hrautila/matops
// ScalePlus scales A in place and adds scaled B to it.
//   A = alpha*A + beta*B
//   A = alpha*A + beta*B.T  if flags&TRANSB
//
// The kernel is invoked over the ranges 0..A.Cols() and 0..A.Rows(). The size
// of B is not validated here, and the returned error is always nil.
func ScalePlus(A, B *matrix.FloatMatrix, alpha, beta float64, flags Flags) error {
	calgo.DScalePlus(
		A.FloatArray(), B.FloatArray(), alpha, beta, calgo.Flags(flags),
		A.LeadingIndex(), B.LeadingIndex(),
		0, A.Cols(),
		0, A.Rows(),
	)
	return nil
}
示例#6
0
文件: mmult.go 项目: hrautila/matops
// RankUpdateSym is the rank update for symmetric lower or upper matrix
// (blas.SYRK).
//
// NOTE(review): the original comment gave the formula
//      C = beta*C + alpha*A*A.T + alpha*A.T*A
// but the TRANSA handling below suggests the A*A.T and A.T*A forms are
// alternatives selected by flags; confirm against calgo.DSymmRankBlk.
//
// Returns an error if C is not square. The size of A is not validated here.
func RankUpdateSym(C, A *matrix.FloatMatrix, alpha, beta float64, flags Flags) error {
	n := C.Rows()
	if n != C.Cols() {
		return onError("C not a square matrix")
	}

	// Dimension of A handed to the kernel: rows of A if flags&TRANSA,
	// columns of A otherwise.
	var inner int
	if flags&TRANSA != 0 {
		inner = A.Rows()
	} else {
		inner = A.Cols()
	}

	// if more workers available C can be divided to blocks [S:E, S:E] along diagonal
	// and updated in separate tasks.
	calgo.DSymmRankBlk(C.FloatArray(), A.FloatArray(), alpha, beta, calgo.Flags(flags),
		C.LeadingIndex(), A.LeadingIndex(), inner, 0, n, vpLen, nB)
	return nil
}
示例#7
0
// update is a test helper that updates C1 and C2 in place, using W as
// workspace and logging intermediate results through t. Per the step notes
// below, on completion C = (I - Y*T*Y.T).T * C, with C given as the parts
// C1, C2 and Y as the parts Y1, Y2.
//
// It panics if W.Rows() != C1.Cols(). Errors returned by ScalePlus and Mult
// abort the test via t.Fatalf instead of being silently dropped.
func update(t *testing.T, Y1, Y2, C1, C2, T, W *matrix.FloatMatrix) {
	if W.Rows() != C1.Cols() {
		panic("W.Rows != C1.Cols")
	}
	// W = C1.T
	if err := ScalePlus(W, C1, 0.0, 1.0, TRANSB); err != nil {
		t.Fatalf("W = C1.T: %v", err)
	}
	// W = C1.T*Y1
	// (done with the unblocked kernel directly rather than
	// MultTrm(W, Y1, 1.0, LOWER|UNIT|RIGHT))
	Wr := W.FloatArray()
	Y1r := Y1.FloatArray()
	ldW := W.LeadingIndex()
	ldY := Y1.LeadingIndex()
	calgo.DTrmmUnblk(Wr, Y1r, 1.0, calgo.Flags(LOWER|UNIT|RIGHT),
		ldW, ldY, Y1.Cols(), 0, W.Rows(), 0)
	t.Logf("W = C1.T*Y1:\n%v\n", W)
	// W = W + C2.T*Y2
	if err := Mult(W, C2, Y2, 1.0, 1.0, TRANSA); err != nil {
		t.Fatalf("W = W + C2.T*Y2: %v", err)
	}
	t.Logf("W = W + C2.T*Y2:\n%v\n", W)

	// --- here: W == C.T*Y ---
	// W = W*T
	// NOTE(review): MultTrm is defined outside this view; its result, if
	// any, is left unchecked here as in the original.
	MultTrm(W, T, 1.0, UPPER|RIGHT)
	t.Logf("W = C.T*Y*T:\n%v\n", W)

	// --- here: W == C.T*Y*T ---
	// C2 = C2 - Y2*W.T
	if err := Mult(C2, Y2, W, -1.0, 1.0, TRANSB); err != nil {
		t.Fatalf("C2 = C2 - Y2*W.T: %v", err)
	}
	t.Logf("C2 = C2 - Y2*W.T:\n%v\n", C2)
	//  W = Y1*W.T ==> W.T = W*Y1.T
	MultTrm(W, Y1, 1.0, LOWER|UNIT|TRANSA|RIGHT)
	t.Logf("W.T = W*Y1.T:\n%v\n", W)

	// C1 = C1 - W.T
	if err := ScalePlus(C1, W, 1.0, -1.0, TRANSB); err != nil {
		t.Fatalf("C1 = C1 - W.T: %v", err)
	}

	// --- here: C = (I - Y*T*Y.T).T * C ---
}