// Mul takes the matrix product of a and b, placing the result in the receiver. // // See the Muler interface for more information. func (m *Dense) Mul(a, b Matrix) { ar, ac := a.Dims() br, bc := b.Dims() if ac != br { panic(ErrShape) } m.reuseAs(ar, bc) var w *Dense if m != a && m != b { w = m } else { w = getWorkspace(ar, bc, false) defer func() { m.Copy(w) putWorkspace(w) }() } if a, ok := a.(RawMatrixer); ok { if b, ok := b.(RawMatrixer); ok { amat, bmat := a.RawMatrix(), b.RawMatrix() blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, amat, bmat, 0, w.mat) return } } if a, ok := a.(Vectorer); ok { if b, ok := b.(Vectorer); ok { row := make([]float64, ac) col := make([]float64, br) for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, blas64.Vector{Inc: 1, Data: b.Col(col, c)}, ) } } return } } row := make([]float64, ac) for r := 0; r < ar; r++ { for i := range row { row[i] = a.At(r, i) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(i, c) } w.mat.Data[r*w.mat.Stride+c] = v } } }
// TODO: Need to add tests where one is overwritten. func TestMul(t *testing.T) { for _, test := range []struct { ar int ac int br int bc int Panics bool }{ { ar: 5, ac: 5, br: 5, bc: 5, Panics: false, }, { ar: 10, ac: 5, br: 5, bc: 3, Panics: false, }, { ar: 10, ac: 5, br: 5, bc: 8, Panics: false, }, { ar: 8, ac: 10, br: 10, bc: 3, Panics: false, }, { ar: 8, ac: 3, br: 3, bc: 10, Panics: false, }, { ar: 5, ac: 8, br: 8, bc: 10, Panics: false, }, { ar: 5, ac: 12, br: 12, bc: 8, Panics: false, }, { ar: 5, ac: 7, br: 8, bc: 10, Panics: true, }, } { ar := test.ar ac := test.ac br := test.br bc := test.bc // Generate random matrices avec := make([]float64, ar*ac) randomSlice(avec) a := NewDense(ar, ac, avec) bvec := make([]float64, br*bc) randomSlice(bvec) b := NewDense(br, bc, bvec) // Check that it panics if it is supposed to if test.Panics { c := NewDense(0, 0, nil) fn := func() { c.Mul(a, b) } pan, _ := panics(fn) if !pan { t.Errorf("Mul did not panic with dimension mismatch") } continue } cvec := make([]float64, ar*bc) // Get correct matrix multiply answer from blas64.Gemm blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a.mat, b.mat, 0, blas64.General{Rows: ar, Cols: bc, Stride: bc, Data: cvec}, ) avecCopy := append([]float64{}, avec...) bvecCopy := append([]float64{}, bvec...) cvecCopy := append([]float64{}, cvec...) acomp := matComp{r: ar, c: ac, data: avecCopy} bcomp := matComp{r: br, c: bc, data: bvecCopy} ccomp := matComp{r: ar, c: bc, data: cvecCopy} // Do normal multiply with empty dense d := NewDense(0, 0, nil) testMul(t, a, b, d, acomp, bcomp, ccomp, false, "zero receiver") // Normal multiply with existing receiver c := NewDense(ar, bc, cvec) randomSlice(cvec) testMul(t, a, b, c, acomp, bcomp, ccomp, false, "existing receiver") // Test with vectorers avm := (*basicVectorer)(a) bvm := (*basicVectorer)(b) d.Reset() testMul(t, avm, b, d, acomp, bcomp, ccomp, true, "a vectoror with zero receiver") d.Reset() testMul(t, a, bvm, d, acomp, bcomp, ccomp, true, "b vectoror with zero receiver") d.Reset() testMul(t, avm, bvm, d, acomp, bcomp, ccomp, true, "both vectoror with zero receiver") randomSlice(cvec) testMul(t, avm, b, c, acomp, bcomp, ccomp, true, "a vectoror with existing receiver") randomSlice(cvec) testMul(t, a, bvm, c, acomp, bcomp, ccomp, true, "b vectoror with existing receiver") randomSlice(cvec) testMul(t, avm, bvm, c, acomp, bcomp, ccomp, true, "both vectoror with existing receiver") // Cast a as a basic matrix am := (*basicMatrix)(a) bm := (*basicMatrix)(b) d.Reset() testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is zero") d.Reset() testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is zero") d.Reset() testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is zero") randomSlice(cvec) testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is full") randomSlice(cvec) testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is full") randomSlice(cvec) testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is full") } }
// MulTrans takes the matrix product of a and b, optionally transposing each, // and placing the result in the receiver. // // See the MulTranser interface for more information. func (m *Dense) MulTrans(a Matrix, aTrans bool, b Matrix, bTrans bool) { ar, ac := a.Dims() if aTrans { ar, ac = ac, ar } br, bc := b.Dims() if bTrans { br, bc = bc, br } if ac != br { panic(ErrShape) } m.reuseAs(ar, bc) var w *Dense if m != a && m != b { w = m } else { w = getWorkspace(ar, bc, false) defer func() { m.Copy(w) putWorkspace(w) }() } if a, ok := a.(RawMatrixer); ok { if b, ok := b.(RawMatrixer); ok { amat := a.RawMatrix() if a == b && aTrans != bTrans { var op blas.Transpose if aTrans { op = blas.Trans } else { op = blas.NoTrans } blas64.Syrk(op, 1, amat, 0, blas64.Symmetric{N: w.mat.Rows, Stride: w.mat.Stride, Data: w.mat.Data, Uplo: blas.Upper}) // Fill lower matrix with result. // TODO(kortschak): Investigate whether using blas64.Copy improves the performance of this significantly. for i := 0; i < w.mat.Rows; i++ { for j := i + 1; j < w.mat.Cols; j++ { w.set(j, i, w.at(i, j)) } } } else { var aOp, bOp blas.Transpose if aTrans { aOp = blas.Trans } else { aOp = blas.NoTrans } if bTrans { bOp = blas.Trans } else { bOp = blas.NoTrans } bmat := b.RawMatrix() blas64.Gemm(aOp, bOp, 1, amat, bmat, 0, w.mat) } return } } if a, ok := a.(Vectorer); ok { if b, ok := b.(Vectorer); ok { row := make([]float64, ac) col := make([]float64, br) if aTrans { if bTrans { for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Col(row, r)}, blas64.Vector{Inc: 1, Data: b.Row(col, c)}, ) } } return } // TODO(jonlawlor): determine if (b*a)' is more efficient for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Col(row, r)}, blas64.Vector{Inc: 1, Data: b.Col(col, c)}, ) } } return } if bTrans { for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, blas64.Vector{Inc: 1, Data: b.Row(col, c)}, ) } } return } for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, blas64.Vector{Inc: 1, Data: b.Col(col, c)}, ) } } return } } row := make([]float64, ac) if aTrans { if bTrans { for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for i := range row { row[i] = a.At(i, r) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(c, i) } dataTmp[c] = v } } return } for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for i := range row { row[i] = a.At(i, r) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(i, c) } dataTmp[c] = v } } return } if bTrans { for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for i := range row { row[i] = a.At(r, i) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(c, i) } dataTmp[c] = v } } return } for r := 0; r < ar; r++ { dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc] for i := range row { row[i] = a.At(r, i) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(i, c) } dataTmp[c] = v } } }