// Calculate C = alpha*A*B.T + beta*C, C is M*N, A is M*P and B is N*P func MMMultTransB(C, A, B *matrix.FloatMatrix, alpha, beta float64) error { psize := int64(C.NumElements() * A.Cols()) Ar := A.FloatArray() ldA := A.LeadingIndex() Br := B.FloatArray() ldB := B.LeadingIndex() Cr := C.FloatArray() ldC := C.LeadingIndex() if nWorker <= 1 || psize <= limitOne { calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.TRANSB, ldC, ldA, ldB, B.Rows(), 0, C.Cols(), 0, C.Rows(), vpLen, nB, mB) return nil } // here we have more than one worker available worker := func(cstart, cend, rstart, rend int, ready chan int) { calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.TRANSB, ldC, ldA, ldB, B.Rows(), cstart, cend, rstart, rend, vpLen, nB, mB) ready <- 1 } colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker) scheduleWork(colworks, rowworks, C.Cols(), C.Rows(), worker) //scheduleWork(colworks, rowworks, worker) return nil }
// Generic matrix-matrix multpily. (blas.GEMM). Calculates // C = beta*C + alpha*A*B (default) // C = beta*C + alpha*A.T*B flags&TRANSA // C = beta*C + alpha*A*B.T flags&TRANSB // C = beta*C + alpha*A.T*B.T flags&(TRANSA|TRANSB) // // C is M*N, A is M*P or P*M if flags&TRANSA. B is P*N or N*P if flags&TRANSB. // func Mult(C, A, B *matrix.FloatMatrix, alpha, beta float64, flags Flags) error { var ok, empty bool // error checking must take in account flag values! ar, ac := A.Size() br, bc := B.Size() cr, cc := C.Size() switch flags & (TRANSA | TRANSB) { case TRANSA | TRANSB: empty = ac == 0 || br == 0 ok = cr == ac && cc == br && ar == bc case TRANSA: empty = ac == 0 || bc == 0 ok = cr == ac && cc == bc && ar == br case TRANSB: empty = ar == 0 || br == 0 ok = cr == ar && cc == br && ac == bc default: empty = ar == 0 || bc == 0 ok = cr == ar && cc == bc && ac == br } if empty { return nil } if !ok { return errors.New("Mult: size mismatch") } psize := int64(C.NumElements()) * int64(A.Cols()) Ar := A.FloatArray() ldA := A.LeadingIndex() Br := B.FloatArray() ldB := B.LeadingIndex() Cr := C.FloatArray() ldC := C.LeadingIndex() // matrix A, B common dimension P := A.Cols() if flags&TRANSA != 0 { P = A.Rows() } if nWorker <= 1 || psize <= limitOne { calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.Flags(flags), ldC, ldA, ldB, P, 0, C.Cols(), 0, C.Rows(), vpLen, nB, mB) return nil } // here we have more than one worker available worker := func(cstart, cend, rstart, rend int, ready chan int) { calgo.DMult(Cr, Ar, Br, alpha, beta, calgo.Flags(flags), ldC, ldA, ldB, P, cstart, cend, rstart, rend, vpLen, nB, mB) ready <- 1 } colworks, rowworks := divideWork(C.Rows(), C.Cols(), nWorker) scheduleWork(colworks, rowworks, C.Cols(), C.Rows(), worker) return nil }