Exemplo n.º 1
0
// Mul takes the matrix product of a and b, placing the result in the receiver.
//
// Mul panics with ErrShape if the number of columns of a does not equal the
// number of rows of b. The receiver may be the same value as a or b; the
// product is then computed in a temporary workspace and copied into the
// receiver when Mul returns.
//
// See the Muler interface for more information.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()

	if ac != br {
		panic(ErrShape)
	}

	m.reuseAs(ar, bc)

	// w is the matrix the product is written to. Writing straight into m is
	// only safe when m is neither operand; otherwise elements of an input
	// would be overwritten while they are still being read.
	var w *Dense
	if m != a && m != b {
		w = m
	} else {
		w = getWorkspace(ar, bc, false)
		defer func() {
			m.Copy(w)
			putWorkspace(w)
		}()
	}

	// Fast path: both operands expose their raw storage, so the whole
	// product w = 1*a*b + 0*w is delegated to a single Gemm call.
	if a, ok := a.(RawMatrixer); ok {
		if b, ok := b.(RawMatrixer); ok {
			amat, bmat := a.RawMatrix(), b.RawMatrix()
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, amat, bmat, 0, w.mat)
			return
		}
	}

	// Both operands can hand out whole rows and columns: each element of the
	// product is the dot product of a row of a with a column of b.
	if a, ok := a.(Vectorer); ok {
		if b, ok := b.(Vectorer); ok {
			row := make([]float64, ac)
			col := make([]float64, br)
			for r := 0; r < ar; r++ {
				dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
				// The row of a is the same for every column of b, so it is
				// fetched once per output row rather than once per element.
				aRow := blas64.Vector{Inc: 1, Data: a.Row(row, r)}
				for c := 0; c < bc; c++ {
					dataTmp[c] = blas64.Dot(ac,
						aRow,
						blas64.Vector{Inc: 1, Data: b.Col(col, c)},
					)
				}
			}
			return
		}
	}

	// Generic fallback using only the Matrix interface. The current row of a
	// is cached in row so that a.At is called once per element of a per row.
	row := make([]float64, ac)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			w.mat.Data[r*w.mat.Stride+c] = v
		}
	}
}
Exemplo n.º 2
0
// TestMul checks Dense.Mul against a reference product computed directly with
// blas64.Gemm. Operands are supplied as *Dense, as basicVectorer and as
// basicMatrix values, and the receiver is either empty or already populated
// with unrelated data. It also checks that mismatched inner dimensions panic.
//
// TODO: Need to add tests where one is overwritten.
func TestMul(t *testing.T) {
	for _, test := range []struct {
		ar     int
		ac     int
		br     int
		bc     int
		Panics bool
	}{
		{ar: 5, ac: 5, br: 5, bc: 5, Panics: false},
		{ar: 10, ac: 5, br: 5, bc: 3, Panics: false},
		{ar: 10, ac: 5, br: 5, bc: 8, Panics: false},
		{ar: 8, ac: 10, br: 10, bc: 3, Panics: false},
		{ar: 8, ac: 3, br: 3, bc: 10, Panics: false},
		{ar: 5, ac: 8, br: 8, bc: 10, Panics: false},
		{ar: 5, ac: 12, br: 12, bc: 8, Panics: false},
		{ar: 5, ac: 7, br: 8, bc: 10, Panics: true},
	} {
		ar := test.ar
		ac := test.ac
		br := test.br
		bc := test.bc

		// Generate random matrices
		avec := make([]float64, ar*ac)
		randomSlice(avec)
		a := NewDense(ar, ac, avec)

		bvec := make([]float64, br*bc)
		randomSlice(bvec)

		b := NewDense(br, bc, bvec)

		// Check that it panics if it is supposed to
		if test.Panics {
			c := NewDense(0, 0, nil)
			fn := func() {
				c.Mul(a, b)
			}
			pan, _ := panics(fn)
			if !pan {
				t.Errorf("Mul did not panic with dimension mismatch")
			}
			continue
		}

		cvec := make([]float64, ar*bc)

		// Get correct matrix multiply answer from blas64.Gemm
		blas64.Gemm(blas.NoTrans, blas.NoTrans,
			1, a.mat, b.mat,
			0, blas64.General{Rows: ar, Cols: bc, Stride: bc, Data: cvec},
		)

		// Snapshot the inputs and the expected product. cvec is reused below
		// as the backing store of c and is repeatedly overwritten, so the
		// expected values must live in their own slices.
		avecCopy := append([]float64{}, avec...)
		bvecCopy := append([]float64{}, bvec...)
		cvecCopy := append([]float64{}, cvec...)

		acomp := matComp{r: ar, c: ac, data: avecCopy}
		bcomp := matComp{r: br, c: bc, data: bvecCopy}
		ccomp := matComp{r: ar, c: bc, data: cvecCopy}

		// Do normal multiply with empty dense
		d := NewDense(0, 0, nil)

		testMul(t, a, b, d, acomp, bcomp, ccomp, false, "zero receiver")

		// Normal multiply with existing receiver. c is backed by cvec, so
		// randomising cvec fills the receiver with unrelated data first.
		c := NewDense(ar, bc, cvec)
		randomSlice(cvec)
		testMul(t, a, b, c, acomp, bcomp, ccomp, false, "existing receiver")

		// Test with vectorers
		avm := (*basicVectorer)(a)
		bvm := (*basicVectorer)(b)
		d.Reset()
		testMul(t, avm, b, d, acomp, bcomp, ccomp, true, "a vectoror with zero receiver")
		d.Reset()
		testMul(t, a, bvm, d, acomp, bcomp, ccomp, true, "b vectoror with zero receiver")
		d.Reset()
		testMul(t, avm, bvm, d, acomp, bcomp, ccomp, true, "both vectoror with zero receiver")
		randomSlice(cvec)
		testMul(t, avm, b, c, acomp, bcomp, ccomp, true, "a vectoror with existing receiver")
		randomSlice(cvec)
		testMul(t, a, bvm, c, acomp, bcomp, ccomp, true, "b vectoror with existing receiver")
		randomSlice(cvec)
		testMul(t, avm, bvm, c, acomp, bcomp, ccomp, true, "both vectoror with existing receiver")

		// Cast a and b as basic matrices
		am := (*basicMatrix)(a)
		bm := (*basicMatrix)(b)
		d.Reset()
		testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is zero")
		d.Reset()
		testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is zero")
		d.Reset()
		testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is zero")
		// The "receiver is full" cases must use c: it is the receiver whose
		// backing slice (cvec) has just been filled with random data.
		randomSlice(cvec)
		testMul(t, am, b, c, acomp, bcomp, ccomp, true, "a is basic, receiver is full")
		randomSlice(cvec)
		testMul(t, a, bm, c, acomp, bcomp, ccomp, true, "b is basic, receiver is full")
		randomSlice(cvec)
		testMul(t, am, bm, c, acomp, bcomp, ccomp, true, "both basic, receiver is full")
	}
}
Exemplo n.º 3
0
// MulTrans takes the matrix product of a and b, optionally transposing each,
// and placing the result in the receiver.
//
// If aTrans is true a is used transposed, and likewise bTrans for b. MulTrans
// panics with ErrShape if, after accounting for the requested transposes, the
// number of columns of a does not equal the number of rows of b. The receiver
// may be the same value as a or b; the product is then computed in a
// temporary workspace and copied into the receiver when MulTrans returns.
//
// See the MulTranser interface for more information.
func (m *Dense) MulTrans(a Matrix, aTrans bool, b Matrix, bTrans bool) {
	// From here on ar, ac, br and bc are the dimensions of the operands as
	// they take part in the product, i.e. after any requested transpose.
	ar, ac := a.Dims()
	if aTrans {
		ar, ac = ac, ar
	}

	br, bc := b.Dims()
	if bTrans {
		br, bc = bc, br
	}

	if ac != br {
		panic(ErrShape)
	}

	m.reuseAs(ar, bc)

	// w is the matrix the product is written to. Writing straight into m is
	// only safe when m is neither operand; otherwise elements of an input
	// would be overwritten while they are still being read.
	var w *Dense
	if m != a && m != b {
		w = m
	} else {
		w = getWorkspace(ar, bc, false)
		defer func() {
			m.Copy(w)
			putWorkspace(w)
		}()
	}

	// Fast path: both operands expose their raw storage, so the product is
	// delegated to BLAS.
	if a, ok := a.(RawMatrixer); ok {
		if b, ok := b.(RawMatrixer); ok {
			var aOp, bOp blas.Transpose = blas.NoTrans, blas.NoTrans
			if aTrans {
				aOp = blas.Trans
			}
			if bTrans {
				bOp = blas.Trans
			}

			amat := a.RawMatrix()
			if a == b && aTrans != bTrans {
				// The same matrix multiplied by its own transpose gives a
				// symmetric result: Syrk fills in the upper triangle only.
				blas64.Syrk(aOp, 1, amat, 0, blas64.Symmetric{N: w.mat.Rows, Stride: w.mat.Stride, Data: w.mat.Data, Uplo: blas.Upper})

				// Fill lower matrix with result.
				// TODO(kortschak): Investigate whether using blas64.Copy improves the performance of this significantly.
				for i := 0; i < w.mat.Rows; i++ {
					for j := i + 1; j < w.mat.Cols; j++ {
						w.set(j, i, w.at(i, j))
					}
				}
			} else {
				bmat := b.RawMatrix()
				blas64.Gemm(aOp, bOp, 1, amat, bmat, 0, w.mat)
			}
			return
		}
	}

	// Both operands can hand out whole rows and columns: each element of the
	// product is a dot product of one vector from a with one vector from b.
	if a, ok := a.(Vectorer); ok {
		if b, ok := b.(Vectorer); ok {
			// A row of the transposed a is a column of a, and a column of the
			// transposed b is a row of b. Selecting the accessors once lets a
			// single loop serve all four transpose combinations.
			// TODO(jonlawlor): determine if (b*a)' is more efficient when
			// only a is transposed.
			aVec, bVec := a.Row, b.Col
			if aTrans {
				aVec = a.Col
			}
			if bTrans {
				bVec = b.Row
			}

			row := make([]float64, ac)
			col := make([]float64, br)
			for r := 0; r < ar; r++ {
				dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
				// The vector taken from a is the same for the whole output
				// row, so it is fetched once per row rather than per element.
				lhs := blas64.Vector{Inc: 1, Data: aVec(row, r)}
				for c := 0; c < bc; c++ {
					dataTmp[c] = blas64.Dot(ac,
						lhs,
						blas64.Vector{Inc: 1, Data: bVec(col, c)},
					)
				}
			}
			return
		}
	}

	// Generic fallback using only the Matrix interface. There is one loop
	// nest per transpose combination; they differ only in the index order
	// passed to At. row caches the current row of (possibly transposed) a.
	row := make([]float64, ac)
	if aTrans {
		if bTrans {
			// aᵀ * bᵀ
			for r := 0; r < ar; r++ {
				dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
				for i := range row {
					row[i] = a.At(i, r)
				}
				for c := 0; c < bc; c++ {
					var v float64
					for i, e := range row {
						v += e * b.At(c, i)
					}
					dataTmp[c] = v
				}
			}
			return
		}

		// aᵀ * b
		for r := 0; r < ar; r++ {
			dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
			for i := range row {
				row[i] = a.At(i, r)
			}
			for c := 0; c < bc; c++ {
				var v float64
				for i, e := range row {
					v += e * b.At(i, c)
				}
				dataTmp[c] = v
			}
		}
		return
	}
	if bTrans {
		// a * bᵀ
		for r := 0; r < ar; r++ {
			dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
			for i := range row {
				row[i] = a.At(r, i)
			}
			for c := 0; c < bc; c++ {
				var v float64
				for i, e := range row {
					v += e * b.At(c, i)
				}
				dataTmp[c] = v
			}
		}
		return
	}
	// a * b
	for r := 0; r < ar; r++ {
		dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			dataTmp[c] = v
		}
	}
}