Ejemplo n.º 1
0
// MMMultTransB calculates C = alpha*A*B.T + beta*C, where C is M*N, A is M*P
// and B is N*P.
//
// Operand dimensions are not validated here. The product is computed by a
// single calgo.DMult call when at most one worker is configured or when the
// estimated work size M*N*P does not exceed limitOne; otherwise the result
// matrix is partitioned with divideWork and the pieces are handed to
// scheduleWork. The function always returns nil.
func MMMultTransB(C, A, B *matrix.FloatMatrix, alpha, beta float64) error {
	// Common (inner) dimension of A and B.T. B is N*P, so B.Rows() would be N
	// and must not be used here; A.Cols() is P.
	P := A.Cols()

	// Widen each factor before multiplying so the estimate cannot overflow int.
	psize := int64(C.NumElements()) * int64(P)

	Ar := A.FloatArray()
	ldA := A.LeadingIndex()
	Br := B.FloatArray()
	ldB := B.LeadingIndex()
	Cr := C.FloatArray()
	ldC := C.LeadingIndex()

	if nWorker <= 1 || psize <= limitOne {
		// Single call covering all columns and rows of C.
		calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.TRANSB, ldC, ldA, ldB, P,
			0, C.Cols(), 0, C.Rows(), vpLen, nB, mB)
		return nil
	}

	// More than one worker is available: each worker updates the block of C
	// bounded by [cstart, cend) x [rstart, rend) and then signals on ready.
	worker := func(cstart, cend, rstart, rend int, ready chan int) {
		calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.TRANSB, ldC, ldA, ldB, P,
			cstart, cend, rstart, rend, vpLen, nB, mB)
		ready <- 1
	}
	colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker)
	scheduleWork(colworks, rowworks, C.Cols(), C.Rows(), worker)
	return nil
}
Ejemplo n.º 2
0
// Mult is the generic matrix-matrix multiply (blas.GEMM). It calculates
//
//	C = beta*C + alpha*A*B     (default)
//	C = beta*C + alpha*A.T*B   flags&TRANSA
//	C = beta*C + alpha*A*B.T   flags&TRANSB
//	C = beta*C + alpha*A.T*B.T flags&(TRANSA|TRANSB)
//
// C is M*N, A is M*P or P*M if flags&TRANSA. B is P*N or N*P if flags&TRANSB.
//
// It returns nil without touching C when the result would have no rows or no
// columns, and an error when the operand sizes do not agree with the
// requested operation. Otherwise the product is computed by a single
// calgo.DMult call when at most one worker is configured or the estimated
// work size M*N*P does not exceed limitOne; larger problems are partitioned
// with divideWork and handed to scheduleWork.
func Mult(C, A, B *matrix.FloatMatrix, alpha, beta float64, flags Flags) error {
	var ok, empty bool

	// Size checks must take the transpose flags into account.
	ar, ac := A.Size()
	br, bc := B.Size()
	cr, cc := C.Size()
	switch flags & (TRANSA | TRANSB) {
	case TRANSA | TRANSB:
		empty = ac == 0 || br == 0
		ok = cr == ac && cc == br && ar == bc
	case TRANSA:
		empty = ac == 0 || bc == 0
		ok = cr == ac && cc == bc && ar == br
	case TRANSB:
		empty = ar == 0 || br == 0
		ok = cr == ar && cc == br && ac == bc
	default:
		empty = ar == 0 || bc == 0
		ok = cr == ar && cc == bc && ac == br
	}
	if empty {
		return nil
	}
	if !ok {
		return errors.New("Mult: size mismatch")
	}

	// P is the common (inner) dimension of op(A) and op(B).
	P := A.Cols()
	if flags&TRANSA != 0 {
		P = A.Rows()
	}

	// Work estimate M*N*P. It uses P rather than A.Cols() so the estimate is
	// also right when A is transposed; factors are widened before multiplying.
	psize := int64(C.NumElements()) * int64(P)

	Ar := A.FloatArray()
	ldA := A.LeadingIndex()
	Br := B.FloatArray()
	ldB := B.LeadingIndex()
	Cr := C.FloatArray()
	ldC := C.LeadingIndex()

	if nWorker <= 1 || psize <= limitOne {
		// Single call covering all columns and rows of C.
		calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.Flags(flags), ldC, ldA, ldB, P,
			0, C.Cols(), 0, C.Rows(),
			vpLen, nB, mB)
		return nil
	}

	// More than one worker is available: each worker updates the block of C
	// bounded by [cstart, cend) x [rstart, rend) and then signals on ready.
	worker := func(cstart, cend, rstart, rend int, ready chan int) {
		calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.Flags(flags), ldC, ldA, ldB, P,
			cstart, cend, rstart, rend, vpLen, nB, mB)
		ready <- 1
	}
	colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker)
	scheduleWork(colworks, rowworks, C.Cols(), C.Rows(), worker)
	return nil
}