コード例 #1
0
ファイル: lq.go プロジェクト: gidden/cloudlus
// applyQTo overwrites x in place by applying, one after another, the vectors
// stored in the rows of f.LQ (row k from its diagonal entry onward). Each step
// performs the rank-one update sub -= hh * (hhᵀ * sub) on rows k..m-1 of x,
// where hh is the stored vector for step k.
//
// When trans is false the steps run for k = 0, 1, ..., nh-1; when trans is
// true they run in the reverse order.
// NOTE(review): the original comment reads "replaces x with Q.x"; presumably
// the reversed order corresponds to the transposed factor — confirm against
// the factorization code.
//
// applyQTo panics with ErrShape if the number of rows of x differs from the
// number of columns of f.LQ.
func (f LQFactor) applyQTo(x *Dense, trans bool) {
	nh, nc := f.LQ.Dims()
	m, n := x.Dims()
	if m != nc {
		panic(ErrShape)
	}

	// Scratch space for hhᵀ * sub, reused by every step.
	proj := make([]float64, n)

	// applyStep applies the k-th stored vector to rows k..m-1 of x.
	applyStep := func(k int) {
		hh := f.LQ.RawRowView(k)[k:]
		sub := x.View(k, 0, m-k, n).(*Dense)

		// proj = subᵀ * hh
		blas64.Gemv(blas.Trans,
			1, sub.mat, blas64.Vector{Inc: 1, Data: hh},
			0, blas64.Vector{Inc: 1, Data: proj},
		)

		// Row k+j of x loses hh[j] times the projection.
		for j := 0; j < m-k; j++ {
			target := x.RawRowView(k + j)
			blas64.Axpy(n, -hh[j],
				blas64.Vector{Inc: 1, Data: proj},
				blas64.Vector{Inc: 1, Data: target},
			)
		}
	}

	if trans {
		for k := nh - 1; k >= 0; k-- {
			applyStep(k)
		}
		return
	}
	for k := 0; k < nh; k++ {
		applyStep(k)
	}
}
コード例 #2
0
ファイル: vector.go プロジェクト: gidden/cloudlus
// MulVec computes a * b when trans is false and aᵀ * b when trans is true,
// and stores the result in the receiver.
//
// MulVec panics with ErrShape if the length of b does not equal the number of
// columns of a (the number of rows of a when trans is true), or if the
// receiver has non-zero length that differs from the length of the result.
// A receiver of length zero is sized to hold the result.
func (v *Vector) MulVec(a Matrix, trans bool, b *Vector) {
	ar, ac := a.Dims()

	// inLen is the length b must have; outLen is the length of the product.
	inLen, outLen := ac, ar
	if trans {
		inLen, outLen = ar, ac
	}
	if b.Len() != inLen {
		panic(ErrShape)
	}

	// The product is built in w and assigned to *v at the end. When the
	// receiver is itself one of the operands, w is left as the zero Vector
	// so that its storage comes from use instead of an operand's data.
	var w Vector
	if v != a && v != b {
		w = *v
	}
	switch {
	case w.n == 0:
		w.mat.Data = use(w.mat.Data, outLen)
		w.mat.Inc = 1
		w.n = outLen
	case w.n != outLen:
		panic(ErrShape)
	}

	op := blas.NoTrans
	if trans {
		op = blas.Trans
	}

	switch m := a.(type) {
	case RawSymmetricer:
		// trans is not consulted for a symmetric matrix.
		blas64.Symv(1, m.RawSymmetric(), b.mat, 0, w.mat)

	case RawTriangular:
		// Trmv works in place, so seed w with b first.
		w.CopyVec(b)
		blas64.Trmv(op, m.RawTriangular(), w.mat)

	case RawMatrixer:
		blas64.Gemv(op, 1, m.RawMatrix(), b.mat, 0, w.mat)

	case Vectorer:
		// Element i of the result is the dot product of b with column i
		// (trans) or row i (no trans) of a.
		buf := make([]float64, inLen)
		for i := 0; i < outLen; i++ {
			var line []float64
			if trans {
				line = m.Col(buf, i)
			} else {
				line = m.Row(buf, i)
			}
			w.mat.Data[i*w.mat.Inc] = blas64.Dot(inLen,
				blas64.Vector{Inc: 1, Data: line},
				b.mat,
			)
		}

	default:
		// Generic fallback: gather one column/row through At, then
		// accumulate its dot product with b.
		buf := make([]float64, inLen)
		for i := 0; i < outLen; i++ {
			for k := range buf {
				if trans {
					buf[k] = a.At(k, i)
				} else {
					buf[k] = a.At(i, k)
				}
			}
			var sum float64
			for k, e := range buf {
				sum += e * b.mat.Data[k*b.mat.Inc]
			}
			w.mat.Data[i*w.mat.Inc] = sum
		}
	}

	*v = w
}