예제 #1
0
파일: deprec.go 프로젝트: sguzwf/algorithm
// MMSymmUpper calculates C = alpha*A*B + beta*C, where C is M*N, A is M*M
// and B is M*N.
//
// A is handed to the kernel with calgo.LEFT|calgo.UPPER, i.e. it is applied
// from the left; presumably only its upper triangle is referenced -- confirm
// against calgo.DMultSymm.
//
// An error is returned when A is not square. No other dimension is validated
// here, so callers must pass conforming B and C.
func MMSymmUpper(C, A, B *matrix.FloatMatrix, alpha, beta float64) error {

	if A.Rows() != A.Cols() {
		return errors.New("A matrix not square matrix.")
	}
	// Widen each factor before multiplying: the product of two ints can wrap
	// around where int is 32 bits wide, which would make a huge problem look
	// small and wrongly select the single-threaded path.
	psize := int64(C.NumElements()) * int64(A.Cols())
	Ar := A.FloatArray()
	ldA := A.LeadingIndex()
	Br := B.FloatArray()
	ldB := B.LeadingIndex()
	Cr := C.FloatArray()
	ldC := C.LeadingIndex()

	// Small problem or no extra workers: one kernel call over all of C
	// (columns [0, C.Cols()), rows [0, C.Rows())).
	if nWorker <= 1 || psize <= limitOne {
		calgo.DMultSymm(Cr, Ar, Br, alpha, beta, calgo.LEFT|calgo.UPPER, ldC, ldA, ldB,
			A.Cols(), 0, C.Cols(), 0, C.Rows(), vpLen, nB, mB)
		return nil
	}
	// here we have more than one worker available
	// Each worker updates one column/row window of C and then signals
	// completion on the ready channel.
	worker := func(cstart, cend, rstart, rend int, ready chan int) {
		calgo.DMultSymm(Cr, Ar, Br, alpha, beta, calgo.LEFT|calgo.UPPER, ldC, ldA, ldB,
			A.Cols(), cstart, cend, rstart, rend, vpLen, nB, mB)
		ready <- 1
	}
	colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker)
	scheduleWork(colworks, rowworks, C.Cols(), C.Rows(), worker)
	return nil
}
예제 #2
0
파일: mmult.go 프로젝트: sguzwf/algorithm
// MultSym performs the symmetric matrix multiply (blas.SYMM) and stores the
// result in C:
//
//	C = beta*C + alpha*A*B      (default)
//	C = beta*C + alpha*A.T*B    flags&TRANSA
//	C = beta*C + alpha*A*B.T    flags&TRANSB
//	C = beta*C + alpha*A.T*B.T  flags&(TRANSA|TRANSB)
//
// C is N*P and A is an N*N symmetric matrix. B is N*P, or P*N when
// flags&TRANSB is set.
//
// It returns nil without calling the kernel when op(A) has no rows or op(B)
// has no columns, and an error when the operand sizes do not conform.
func MultSym(C, A, B *matrix.FloatMatrix, alpha, beta float64, flags Flags) error {
	aRows, aCols := A.Size()
	bRows, bCols := B.Size()
	cRows, cCols := C.Size()

	// Reduce the transpose bits to two booleans.
	var transA, transB bool
	switch flags & (TRANSA | TRANSB) {
	case TRANSA | TRANSB:
		transA, transB = true, true
	case TRANSA:
		transA = true
	case TRANSB:
		transB = true
	}

	// Shape of op(A) is outRows x innerA, shape of op(B) is innerB x outCols.
	outRows, innerA := aRows, aCols
	if transA {
		outRows, innerA = aCols, aRows
	}
	innerB, outCols := bRows, bCols
	if transB {
		innerB, outCols = bCols, bRows
	}

	// The emptiness test deliberately precedes the conformance test.
	if outRows == 0 || outCols == 0 {
		return nil
	}
	// A must be square, C must match op(A)*op(B), and the inner dimensions
	// of op(A) and op(B) must agree.
	conforms := aRows == aCols && cRows == outRows && cCols == outCols && innerA == innerB
	if !conforms {
		return errors.New("MultSym: size mismatch")
	}

	// Work estimate used to decide between the serial and parallel paths.
	work := int64(C.NumElements()) * int64(A.Cols())
	aData, aLead := A.FloatArray(), A.LeadingIndex()
	bData, bLead := B.FloatArray(), B.LeadingIndex()
	cData, cLead := C.FloatArray(), C.LeadingIndex()
	kernelFlags := calgo.Flags(flags)

	if nWorker > 1 && work > limitOne {
		// More than one worker is available and the problem is large enough:
		// each task updates one column/row window of C and then signals
		// completion on the ready channel.
		task := func(cstart, cend, rstart, rend int, ready chan int) {
			calgo.DMultSymm(cData, aData, bData, alpha, beta, kernelFlags, cLead, aLead, bLead,
				A.Cols(), cstart, cend, rstart, rend, vpLen, nB, mB)
			ready <- 1
		}
		colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker)
		scheduleWork(colworks, rowworks, C.Cols(), C.Rows(), task)
		return nil
	}

	// Serial path: a single kernel call covering all of C.
	calgo.DMultSymm(cData, aData, bData, alpha, beta, kernelFlags, cLead, aLead, bLead,
		A.Cols(), 0, C.Cols(), 0, C.Rows(), vpLen, nB, mB)
	return nil
}