// Mul takes the matrix product of a and b, placing the result in the receiver. // // See the Muler interface for more information. func (m *Dense) Mul(a, b Matrix) { ar, ac := a.Dims() br, bc := b.Dims() if ac != br { panic(ErrShape) } m.reuseAs(ar, bc) var w *Dense if m != a && m != b { w = m } else { w = getWorkspace(ar, bc, false) defer func() { m.Copy(w) putWorkspace(w) }() } if a, ok := a.(RawMatrixer); ok { if b, ok := b.(RawMatrixer); ok { amat, bmat := a.RawMatrix(), b.RawMatrix() blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, amat, bmat, 0, w.mat) return } } if a, ok := a.(Vectorer); ok { if b, ok := b.(Vectorer); ok { row := make([]float64, ac) col := make([]float64, br) for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, blas64.Vector{Inc: 1, Data: b.Col(col, c)}, ) } } return } } row := make([]float64, ac) for r := 0; r < ar; r++ { for i := range row { row[i] = a.At(r, i) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(i, c) } w.mat.Data[r*w.mat.Stride+c] = v } } }
// MulVec computes a * b if trans == false and a^T * b if trans == true. The // result is stored into the receiver. MulVec panics if the number of columns in // a does not equal the number of rows in b. func (v *Vector) MulVec(a Matrix, trans bool, b *Vector) { ar, ac := a.Dims() br := b.Len() if trans { if ar != br { panic(ErrShape) } } else { if ac != br { panic(ErrShape) } } var w Vector if v != a && v != b { w = *v } if w.n == 0 { if trans { w.mat.Data = use(w.mat.Data, ac) } else { w.mat.Data = use(w.mat.Data, ar) } w.mat.Inc = 1 w.n = ar if trans { w.n = ac } } else { if trans { if ac != w.n { panic(ErrShape) } } else { if ar != w.n { panic(ErrShape) } } } switch a := a.(type) { case RawSymmetricer: amat := a.RawSymmetric() blas64.Symv(1, amat, b.mat, 0, w.mat) case RawTriangular: w.CopyVec(b) amat := a.RawTriangular() ta := blas.NoTrans if trans { ta = blas.Trans } blas64.Trmv(ta, amat, w.mat) case RawMatrixer: amat := a.RawMatrix() t := blas.NoTrans if trans { t = blas.Trans } blas64.Gemv(t, 1, amat, b.mat, 0, w.mat) case Vectorer: if trans { col := make([]float64, ar) for c := 0; c < ac; c++ { w.mat.Data[c*w.mat.Inc] = blas64.Dot(ar, blas64.Vector{Inc: 1, Data: a.Col(col, c)}, b.mat, ) } } else { row := make([]float64, ac) for r := 0; r < ar; r++ { w.mat.Data[r*w.mat.Inc] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, b.mat, ) } } default: if trans { col := make([]float64, ar) for c := 0; c < ac; c++ { for i := range col { col[i] = a.At(i, c) } var f float64 for i, e := range col { f += e * b.mat.Data[i*b.mat.Inc] } w.mat.Data[c*w.mat.Inc] = f } } else { row := make([]float64, ac) for r := 0; r < ar; r++ { for i := range row { row[i] = a.At(r, i) } var f float64 for i, e := range row { f += e * b.mat.Data[i*b.mat.Inc] } w.mat.Data[r*w.mat.Inc] = f } } } *v = w }
// MulTrans takes the matrix product of a and b, optionally transposing each, // and placing the result in the receiver. // // See the MulTranser interface for more information. func (m *Dense) MulTrans(a Matrix, aTrans bool, b Matrix, bTrans bool) { ar, ac := a.Dims() if aTrans { ar, ac = ac, ar } br, bc := b.Dims() if bTrans { br, bc = bc, br } if ac != br { panic(ErrShape) } m.reuseAs(ar, bc) var w *Dense if m != a && m != b { w = m } else { w = getWorkspace(ar, bc, false) defer func() { m.Copy(w) putWorkspace(w) }() } if a, ok := a.(RawMatrixer); ok { if b, ok := b.(RawMatrixer); ok { amat := a.RawMatrix() if a == b && aTrans != bTrans { var op blas.Transpose if aTrans { op = blas.Trans } else { op = blas.NoTrans } blas64.Syrk(op, 1, amat, 0, blas64.Symmetric{N: w.mat.Rows, Stride: w.mat.Stride, Data: w.mat.Data, Uplo: blas.Upper}) // Fill lower matrix with result. // TODO(kortschak): Investigate whether using blas64.Copy improves the performance of this significantly. for i := 0; i < w.mat.Rows; i++ { for j := i + 1; j < w.mat.Cols; j++ { w.set(j, i, w.at(i, j)) } } } else { var aOp, bOp blas.Transpose if aTrans { aOp = blas.Trans } else { aOp = blas.NoTrans } if bTrans { bOp = blas.Trans } else { bOp = blas.NoTrans } bmat := b.RawMatrix() blas64.Gemm(aOp, bOp, 1, amat, bmat, 0, w.mat) } return } } if a, ok := a.(Vectorer); ok { if b, ok := b.(Vectorer); ok { row := make([]float64, ac) col := make([]float64, br) if aTrans { if bTrans { for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Col(row, r)}, blas64.Vector{Inc: 1, Data: b.Row(col, c)}, ) } } return } // TODO(jonlawlor): determine if (b*a)' is more efficient for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Col(row, r)}, blas64.Vector{Inc: 1, Data: b.Col(col, c)}, ) } } return } if bTrans { for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, blas64.Vector{Inc: 1, Data: b.Row(col, c)}, ) } } return } for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, blas64.Vector{Inc: 1, Data: b.Col(col, c)}, ) } } return } } row := make([]float64, ac) if aTrans { if bTrans { for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for i := range row { row[i] = a.At(i, r) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(c, i) } dataTmp[c] = v } } return } for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for i := range row { row[i] = a.At(i, r) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(i, c) } dataTmp[c] = v } } return } if bTrans { for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for i := range row { row[i] = a.At(r, i) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(c, i) } dataTmp[c] = v } } return } for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for i := range row { row[i] = a.At(r, i) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(i, c) } dataTmp[c] = v } } }