// SymRankK performs a symmetric rank-k update to the matrix a and stores the // result into the receiver. If a is zero, see SymOuterK. // s = a + alpha * x * x' func (s *SymDense) SymRankK(a Symmetric, alpha float64, x Matrix) { n := a.Symmetric() r, _ := x.Dims() if r != n { panic(matrix.ErrShape) } xMat, aTrans := untranspose(x) var g blas64.General if rm, ok := xMat.(RawMatrixer); ok { g = rm.RawMatrix() } else { g = DenseCopyOf(x).mat aTrans = false } if a != s { if rs, ok := a.(RawSymmetricer); ok { s.checkOverlap(rs.RawSymmetric()) } s.reuseAs(n) s.CopySym(a) } t := blas.NoTrans if aTrans { t = blas.Trans } blas64.Syrk(t, alpha, g, 1, s.mat) }
// MulTrans takes the matrix product of a and b, optionally transposing each, // and placing the result in the receiver. // // See the MulTranser interface for more information. func (m *Dense) MulTrans(a Matrix, aTrans bool, b Matrix, bTrans bool) { ar, ac := a.Dims() if aTrans { ar, ac = ac, ar } br, bc := b.Dims() if bTrans { br, bc = bc, br } if ac != br { panic(ErrShape) } m.reuseAs(ar, bc) var w *Dense if m != a && m != b { w = m } else { w = getWorkspace(ar, bc, false) defer func() { m.Copy(w) putWorkspace(w) }() } if a, ok := a.(RawMatrixer); ok { if b, ok := b.(RawMatrixer); ok { amat := a.RawMatrix() if a == b && aTrans != bTrans { var op blas.Transpose if aTrans { op = blas.Trans } else { op = blas.NoTrans } blas64.Syrk(op, 1, amat, 0, blas64.Symmetric{N: w.Mat.Rows, Stride: w.Mat.Stride, Data: w.Mat.Data, Uplo: blas.Upper}) // Fill lower matrix with result. // TODO(kortschak): Investigate whether using blas64.Copy improves the performance of this significantly. for i := 0; i < w.Mat.Rows; i++ { for j := i + 1; j < w.Mat.Cols; j++ { w.set(j, i, w.at(i, j)) } } } else { var aOp, bOp blas.Transpose if aTrans { aOp = blas.Trans } else { aOp = blas.NoTrans } if bTrans { bOp = blas.Trans } else { bOp = blas.NoTrans } bmat := b.RawMatrix() blas64.Gemm(aOp, bOp, 1, amat, bmat, 0, w.Mat) } return } } if a, ok := a.(Vectorer); ok { if b, ok := b.(Vectorer); ok { row := make([]float64, ac) col := make([]float64, br) if aTrans { if bTrans { for r := 0; r < ar; r++ { dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Col(row, r)}, blas64.Vector{Inc: 1, Data: b.Row(col, c)}, ) } } return } // TODO(jonlawlor): determine if (b*a)' is more efficient for r := 0; r < ar; r++ { dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Col(row, r)}, blas64.Vector{Inc: 1, Data: b.Col(col, c)}, ) } } return } if bTrans { for r := 0; r < ar; r++ { dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, blas64.Vector{Inc: 1, Data: b.Row(col, c)}, ) } } return } for r := 0; r < ar; r++ { dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc] for c := 0; c < bc; c++ { dataTmp[c] = blas64.Dot(ac, blas64.Vector{Inc: 1, Data: a.Row(row, r)}, blas64.Vector{Inc: 1, Data: b.Col(col, c)}, ) } } return } } row := make([]float64, ac) if aTrans { if bTrans { for r := 0; r < ar; r++ { dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc] for i := range row { row[i] = a.At(i, r) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(c, i) } dataTmp[c] = v } } return } for r := 0; r < ar; r++ { dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc] for i := range row { row[i] = a.At(i, r) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(i, c) } dataTmp[c] = v } } return } if bTrans { for r := 0; r < ar; r++ { dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc] for i := range row { row[i] = a.At(r, i) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(c, i) } dataTmp[c] = v } } return } for r := 0; r < ar; r++ { dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc] for i := range row { row[i] = a.At(r, i) } for c := 0; c < bc; c++ { var v float64 for i, e := range row { v += e * b.At(i, c) } dataTmp[c] = v } } }