Esempio n. 1
0
// MulVec computes the matrix-vector product a * b and stores the result in
// the receiver.
//
// MulVec panics with ErrShape if the number of columns in a does not equal
// the length of b.
func (v *Vector) MulVec(a Matrix, b *Vector) {
	rows, cols := a.Dims()
	if bLen := b.Len(); cols != bLen {
		panic(ErrShape)
	}

	// NOTE(review): untranspose is defined elsewhere; it appears to return the
	// underlying matrix together with a flag reporting whether a transpose
	// was removed. The flag selects blas.Trans below — confirm against its
	// definition.
	a, transposed := untranspose(a)
	baseRows, baseCols := a.Dims()

	// The receiver is sized from the dimensions of a as it was passed in,
	// not from the unwrapped matrix.
	v.reuseAs(rows)

	// When the receiver is also one of the operands, compute into the value
	// handed back by isolatedWorkspace and run its restore func on return.
	switch {
	case v == a:
		var restore func()
		v, restore = v.isolatedWorkspace(a.(*Vector))
		defer restore()
	case v == b:
		var restore func()
		v, restore = v.isolatedWorkspace(b)
		defer restore()
	}

	// BLAS transpose flag shared by the triangular and general cases.
	op := blas.NoTrans
	if transposed {
		op = blas.Trans
	}

	switch m := a.(type) {
	case *Vector:
		switch {
		case m.Len() == 1:
			// a holds a single element: scale every element of b by it.
			scale := m.At(0, 0)
			for i := 0; i < b.Len(); i++ {
				v.mat.Data[i*v.mat.Inc] = scale * b.mat.Data[i*b.mat.Inc]
			}
		case b.Len() == 1:
			// b holds a single element: scale every element of a by it.
			scale := b.At(0, 0)
			for i := 0; i < m.Len(); i++ {
				v.mat.Data[i*v.mat.Inc] = scale * m.mat.Data[i*m.mat.Inc]
			}
		default:
			// Otherwise the product is the inner product of a and b,
			// stored as the single element of the receiver.
			var dot float64
			for i := 0; i < cols; i++ {
				dot += m.At(i, 0) * b.At(i, 0)
			}
			v.SetVec(0, dot)
		}

	case RawSymmetricer:
		// The transpose flag is not consulted for symmetric input.
		blas64.Symv(1, m.RawSymmetric(), b.mat, 0, v.mat)

	case RawTriangular:
		// Trmv works in place, so seed the receiver with b first.
		v.CopyVec(b)
		blas64.Trmv(op, m.RawTriangular(), v.mat)

	case RawMatrixer:
		blas64.Gemv(op, 1, m.RawMatrix(), b.mat, 0, v.mat)

	case Vectorer:
		if transposed {
			// Each output element is a column of the unwrapped matrix
			// dotted with b.
			buf := make([]float64, baseRows)
			for j := 0; j < baseCols; j++ {
				x := blas64.Vector{Inc: 1, Data: m.Col(buf, j)}
				v.mat.Data[j*v.mat.Inc] = blas64.Dot(baseRows, x, b.mat)
			}
		} else {
			// Each output element is a row of the matrix dotted with b.
			buf := make([]float64, baseCols)
			for i := 0; i < baseRows; i++ {
				x := blas64.Vector{Inc: 1, Data: m.Row(buf, i)}
				v.mat.Data[i*v.mat.Inc] = blas64.Dot(baseCols, x, b.mat)
			}
		}

	default:
		// Generic fallback using only At: gather one column (or row) into a
		// scratch buffer, then accumulate its product with b.
		if transposed {
			scratch := make([]float64, baseRows)
			for j := 0; j < baseCols; j++ {
				for i := 0; i < len(scratch); i++ {
					scratch[i] = m.At(i, j)
				}
				var acc float64
				for i := range scratch {
					acc += scratch[i] * b.mat.Data[i*b.mat.Inc]
				}
				v.mat.Data[j*v.mat.Inc] = acc
			}
		} else {
			scratch := make([]float64, baseCols)
			for i := 0; i < baseRows; i++ {
				for j := 0; j < len(scratch); j++ {
					scratch[j] = m.At(i, j)
				}
				var acc float64
				for j := range scratch {
					acc += scratch[j] * b.mat.Data[j*b.mat.Inc]
				}
				v.mat.Data[i*v.mat.Inc] = acc
			}
		}
	}
}
Esempio n. 2
0
// MulVec computes a * b when trans is false and a^T * b when trans is true,
// storing the result in the receiver.
//
// MulVec panics with ErrShape if the inner dimension of the product (the
// columns of a, or its rows when trans is true) does not equal the length of
// b. It also panics with ErrShape if the receiver is distinct from both
// operands, has non-zero length, and that length differs from the length of
// the result.
//
// The result is built in a local Vector which then overwrites the receiver.
// When the receiver has zero length, or is the same value as a or b, the
// storage for the result is obtained from use.
func (v *Vector) MulVec(a Matrix, trans bool, b *Vector) {
	rows, cols := a.Dims()

	// inner is the dimension of a that must agree with len(b); outLen is
	// the length of the product.
	inner, outLen := cols, rows
	if trans {
		inner, outLen = rows, cols
	}
	if bLen := b.Len(); inner != bLen {
		panic(ErrShape)
	}

	// Start from the receiver only when it is not one of the operands, so
	// that the inputs are not overwritten while they are still being read.
	var w Vector
	if !(v == a || v == b) {
		w = *v
	}
	switch {
	case w.n == 0:
		w.mat.Data = use(w.mat.Data, outLen)
		w.mat.Inc = 1
		w.n = outLen
	case w.n != outLen:
		panic(ErrShape)
	}

	// BLAS transpose flag shared by the triangular and general cases.
	op := blas.NoTrans
	if trans {
		op = blas.Trans
	}

	switch m := a.(type) {
	case RawSymmetricer:
		// The transpose flag is not consulted for symmetric input.
		blas64.Symv(1, m.RawSymmetric(), b.mat, 0, w.mat)

	case RawTriangular:
		// Trmv works in place, so seed the result with b first.
		w.CopyVec(b)
		blas64.Trmv(op, m.RawTriangular(), w.mat)

	case RawMatrixer:
		blas64.Gemv(op, 1, m.RawMatrix(), b.mat, 0, w.mat)

	case Vectorer:
		if trans {
			// Each output element is a column of a dotted with b.
			buf := make([]float64, rows)
			for j := 0; j < cols; j++ {
				x := blas64.Vector{Inc: 1, Data: m.Col(buf, j)}
				w.mat.Data[j*w.mat.Inc] = blas64.Dot(rows, x, b.mat)
			}
		} else {
			// Each output element is a row of a dotted with b.
			buf := make([]float64, cols)
			for i := 0; i < rows; i++ {
				x := blas64.Vector{Inc: 1, Data: m.Row(buf, i)}
				w.mat.Data[i*w.mat.Inc] = blas64.Dot(cols, x, b.mat)
			}
		}

	default:
		// Generic fallback using only At: gather one column (or row) into a
		// scratch buffer, then accumulate its product with b.
		if trans {
			scratch := make([]float64, rows)
			for j := 0; j < cols; j++ {
				for i := 0; i < len(scratch); i++ {
					scratch[i] = m.At(i, j)
				}
				var acc float64
				for i := range scratch {
					acc += scratch[i] * b.mat.Data[i*b.mat.Inc]
				}
				w.mat.Data[j*w.mat.Inc] = acc
			}
		} else {
			scratch := make([]float64, cols)
			for i := 0; i < rows; i++ {
				for j := 0; j < len(scratch); j++ {
					scratch[j] = m.At(i, j)
				}
				var acc float64
				for j := range scratch {
					acc += scratch[j] * b.mat.Data[j*b.mat.Inc]
				}
				w.mat.Data[i*w.mat.Inc] = acc
			}
		}
	}

	*v = w
}