Example #1
0
// SymRankK performs a symmetric rank-k update to the matrix a and stores the
// result into the receiver. If a is zero, see SymOuterK.
//  s = a + alpha * x * x'
//
// The number of rows of x must equal the order of a; otherwise SymRankK
// panics with matrix.ErrShape. When the receiver is not a, a is first copied
// into the receiver and the rank-k update is then applied to the receiver in
// place via BLAS Syrk.
func (s *SymDense) SymRankK(a Symmetric, alpha float64, x Matrix) {
	n := a.Symmetric()
	if xRows, _ := x.Dims(); xRows != n {
		panic(matrix.ErrShape)
	}

	// Strip a possible transpose wrapper so the underlying storage can be
	// handed to BLAS directly, remembering whether x was transposed.
	inner, trans := untranspose(x)
	var xRaw blas64.General
	if rm, ok := inner.(RawMatrixer); !ok {
		// No raw storage is available: densify x itself. The copy already
		// has x's orientation, so no transpose must be applied by Syrk.
		xRaw = DenseCopyOf(x).mat
		trans = false
	} else {
		xRaw = rm.RawMatrix()
	}

	if a != s {
		// When a exposes its raw storage, checkOverlap is consulted before
		// the receiver is (re)shaped and overwritten with a copy of a —
		// presumably to reject partial aliasing; confirm in checkOverlap.
		if raw, ok := a.(RawSymmetricer); ok {
			s.checkOverlap(raw.RawSymmetric())
		}
		s.reuseAs(n)
		s.CopySym(a)
	}

	// Syrk computes alpha*X*X' + 1*S for NoTrans and alpha*X'*X + 1*S for
	// Trans; the latter matches x when x was a transposed view of xRaw.
	op := blas.NoTrans
	if trans {
		op = blas.Trans
	}
	blas64.Syrk(op, alpha, xRaw, 1, s.mat)
}
Example #2
0
// MulTrans takes the matrix product of a and b, optionally transposing each,
// and placing the result in the receiver.
//
// See the MulTranser interface for more information.
//
// MulTrans panics with ErrShape if the inner dimensions of op(a) and op(b)
// do not agree, where op(x) is x or x' depending on the corresponding flag.
// When the receiver is a or b, the product is accumulated in a workspace and
// copied into the receiver on return.
//
// The computation strategy is chosen from the operands' capabilities:
//  - both RawMatrixers: BLAS Syrk when a and b are the same matrix and
//    exactly one of them is transposed, BLAS Gemm otherwise;
//  - both Vectorers: each element is a dot product of a row of op(a) with a
//    column of op(b), both extracted through Row/Col;
//  - otherwise: element-wise access through At.
func (m *Dense) MulTrans(a Matrix, aTrans bool, b Matrix, bTrans bool) {
	ar, ac := a.Dims()
	if aTrans {
		ar, ac = ac, ar
	}

	br, bc := b.Dims()
	if bTrans {
		br, bc = bc, br
	}

	if ac != br {
		panic(ErrShape)
	}

	m.reuseAs(ar, bc)
	var w *Dense
	if m != a && m != b {
		w = m
	} else {
		// The receiver is one of the operands, so its contents must not be
		// overwritten while they are still being read; compute into scratch
		// space and copy back once the product is complete.
		w = getWorkspace(ar, bc, false)
		defer func() {
			m.Copy(w)
			putWorkspace(w)
		}()
	}

	if a, ok := a.(RawMatrixer); ok {
		if b, ok := b.(RawMatrixer); ok {
			aOp, bOp := blas.NoTrans, blas.NoTrans
			if aTrans {
				aOp = blas.Trans
			}
			if bTrans {
				bOp = blas.Trans
			}
			amat := a.RawMatrix()
			if a == b && aTrans != bTrans {
				// The product is a*a' or a'*a, which is symmetric: let Syrk
				// fill the upper triangle (NoTrans: A*A', Trans: A'*A).
				blas64.Syrk(aOp, 1, amat, 0, blas64.Symmetric{N: w.Mat.Rows, Stride: w.Mat.Stride, Data: w.Mat.Data, Uplo: blas.Upper})

				// Fill lower matrix with result.
				// TODO(kortschak): Investigate whether using blas64.Copy improves the performance of this significantly.
				for i := 0; i < w.Mat.Rows; i++ {
					for j := i + 1; j < w.Mat.Cols; j++ {
						w.set(j, i, w.at(i, j))
					}
				}
				return
			}
			blas64.Gemm(aOp, bOp, 1, amat, b.RawMatrix(), 0, w.Mat)
			return
		}
	}

	if a, ok := a.(Vectorer); ok {
		if b, ok := b.(Vectorer); ok {
			// A row of op(a) is a column of a when aTrans is set, and a
			// column of op(b) is a row of b when bTrans is set.
			aVec := a.Row
			if aTrans {
				aVec = a.Col
			}
			bVec := b.Col
			if bTrans {
				bVec = b.Row
			}
			// TODO(jonlawlor): determine if (b*a)' is more efficient when
			// only a is transposed.
			row := make([]float64, ac)
			col := make([]float64, br)
			for r := 0; r < ar; r++ {
				dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc]
				// The r-th row of op(a) does not depend on c, so extract it
				// once per output row instead of once per output element.
				aRow := blas64.Vector{Inc: 1, Data: aVec(row, r)}
				for c := range dataTmp {
					dataTmp[c] = blas64.Dot(ac, aRow, blas64.Vector{Inc: 1, Data: bVec(col, c)})
				}
			}
			return
		}
	}

	// Generic fallback: gather each row of op(a) once, then accumulate its
	// dot product with every column of op(b) element by element.
	row := make([]float64, ac)
	for r := 0; r < ar; r++ {
		dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc]
		for i := range row {
			if aTrans {
				row[i] = a.At(i, r)
			} else {
				row[i] = a.At(r, i)
			}
		}
		for c := range dataTmp {
			var v float64
			for i, e := range row {
				if bTrans {
					v += e * b.At(c, i)
				} else {
					v += e * b.At(i, c)
				}
			}
			dataTmp[c] = v
		}
	}
}