예제 #1
0
파일: matrix.go 프로젝트: yonglehou/matrix
// Solve returns a matrix x that satisfies a*x = b, where a is m×n and b is
// m×bc; the returned x is n×bc.
//
// If a is square, x is computed from an LU factorization of a, and a nil
// matrix and ErrSingular are returned if a is exactly singular (its LU
// factorization has a zero pivot). If a is not square, x is computed with
// LAPACK Gels: when m > n, x minimizes ||a*x - b||_2; when m < n, x is the
// minimum-norm solution of a*x = b. A nil matrix and ErrSingular are returned
// if Gels reports failure (a does not have full rank).
//
// Solve panics if a and b have different numbers of rows.
func Solve(a, b Matrix) (x *Dense, err error) {
	m, n := a.Dims()
	br, bc := b.Dims()
	if m != br {
		panic("rowMismatch")
	}
	switch {
	case m == n:
		var lu LU
		lu.Factorize(a)
		// Test the pivots (the diagonal of the packed LU factor) for exact
		// zeros instead of testing lu.Det() == 0. The determinant can underflow
		// to zero in float64 even for a well-conditioned matrix: 0.1*I of order
		// 400 has determinant 1e-400, which is below the smallest float64, so
		// it would be misreported as singular. A zero pivot is exactly what
		// would make Getrs divide by zero.
		for i := 0; i < n; i++ {
			if lu.lu.At(i, i) == 0 {
				return nil, ErrSingular
			}
		}
		// Getrs overwrites its right-hand side with the solution, so solve
		// in a copy of b rather than in b itself.
		x = DenseCopyOf(b)
		lapack64.Getrs(blas.NoTrans, lu.lu.mat, x.mat, lu.pivot)
		return x, nil
	default:
		// TODO(btracey): Employ special cases to avoid the copy where possible.
		// Gels modifies its matrix argument, so work on a copy of a.
		aCopy := DenseCopyOf(a)

		// Gels reads b from, and writes the solution into, the same
		// max(m, n)×bc buffer: b occupies the top m rows on entry and the
		// solution is taken from the top n rows afterwards.
		x = NewDense(max(m, n), bc, nil)
		x.Copy(b)

		// Workspace query (lwork = -1): the optimal size is returned in work[0].
		work := make([]float64, 1)
		lapack64.Gels(blas.NoTrans, aCopy.mat, x.mat, work, -1)
		work = make([]float64, int(work[0]))
		if ok := lapack64.Gels(blas.NoTrans, aCopy.mat, x.mat, work, len(work)); !ok {
			return nil, ErrSingular
		}
		return x.View(0, 0, n, bc).(*Dense), nil
	}
}
예제 #2
0
파일: solve.go 프로젝트: jmptrader/matrix
// Solve finds a solution X of a*X = b and stores it in the receiver. When a is
// ar×ac and b is ar×bc, X is ac×bc. If a is singular a Condition error is
// returned; please see the documentation for Condition for more information.
//
// The problem solved depends on the shape of a (a itself is always used, never
// its transpose):
//  1. If ar == ac, X is the solution of a*X = b, computed from an LU
//     factorization of a. If a and b are the same Matrix value, X is set to
//     the identity without factorizing a.
//  2. If ar > ac, X minimizes ||a*X - b||_2 (least squares).
//  3. If ar < ac, X is the minimum-norm solution of a*X = b.
// Cases 2 and 3 are solved with LAPACK Gels (QR/LQ factorization).
//
// Solve panics with ErrShape if a and b do not have the same number of rows.
func (m *Dense) Solve(a, b Matrix) error {
	ar, ac := a.Dims()
	br, bc := b.Dims()
	if ar != br {
		panic(ErrShape)
	}
	m.reuseAs(ac, bc)
	// TODO(btracey): Add a test for the condition number of A.
	// TODO(btracey): Add special cases for TriDense, SymDense, etc.
	switch {
	case ar == ac:
		// NOTE(review): a == b is an interface comparison; it panics at run
		// time if a and b share a dynamic type that is not comparable (e.g. a
		// struct value containing a slice). Confirm no such Matrix is passed.
		if a == b {
			// a*X = a is solved by X = I.
			for i := 0; i < ar; i++ {
				row := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac]
				zero(row)
				row[i] = 1
			}
			return nil
		}
		// Solve using an LU decomposition.
		var lu LU
		lu.Factorize(a)
		// Test the pivots (the diagonal of the packed LU factor) for exact
		// zeros instead of testing lu.Det() == 0. The determinant can underflow
		// to zero in float64 even for a well-conditioned matrix: 0.1*I of order
		// 400 has determinant 1e-400, which is below the smallest float64, so
		// it would be misreported as singular. A zero pivot is exactly what
		// would make Getrs divide by zero.
		for i := 0; i < ar; i++ {
			if lu.lu.At(i, i) == 0 {
				return Condition(math.Inf(1))
			}
		}
		bMat, bTrans := untranspose(b)
		if m == bMat && bTrans {
			// b is a transposed view of the receiver, so m.Copy(b) would read
			// m while overwriting it; work in a separate workspace instead.
			// NOTE(review): restore presumably copies the workspace back into
			// the original receiver — confirm in isolatedWorkspace.
			var restore func()
			m, restore = m.isolatedWorkspace(bMat)
			defer restore()
		}
		if m != bMat {
			m.Copy(b)
		}
		// Getrs overwrites m (which holds b here) with the solution.
		lapack64.Getrs(blas.NoTrans, lu.lu.mat, m.mat, lu.pivot)
		return nil
	default:
		// Solve using QR/LQ.

		// Copy a since the corresponding LAPACK argument is modified during
		// the call.
		var aCopy Dense
		aCopy.Clone(a)

		// Gels reads b from, and writes the solution into, the same
		// max(ar, ac)×bc buffer; the solution is the top ac rows afterwards.
		x := getWorkspace(max(ar, ac), bc, false)
		defer putWorkspace(x)
		x.Copy(b)

		// Workspace query (lwork = -1): the optimal size is returned in work[0].
		work := make([]float64, 1)
		lapack64.Gels(blas.NoTrans, aCopy.mat, x.mat, work, -1)
		work = make([]float64, int(work[0]))
		ok := lapack64.Gels(blas.NoTrans, aCopy.mat, x.mat, work, len(work))
		if !ok {
			return Condition(math.Inf(1))
		}
		// m is ac×bc, so this copies the top ac rows of x.
		m.Copy(x)
		return nil
	}
}