예제 #1
0
파일: syevd.go 프로젝트: jvlmdr/linalg
// checkSyevd validates and completes the index options for a Syevd call on
// the matrix A, with results written to W.
//
// A negative ind.N is replaced by the row count of A, which must then equal
// its column count. A zero ind.LDa is replaced by A's leading index, and the
// column stride used for the size check becomes A's row count. When the
// order is zero nothing else is checked and nil is returned; otherwise the
// first violated constraint on A or W is reported through onError.
func checkSyevd(ind *linalg.IndexOpts, A, W matrix.Matrix) error {
	// Stride of A's columns for the storage-size check: the caller's LDa
	// unless LDa is defaulted below.
	colStride := ind.LDa
	if ind.N < 0 {
		rows := A.Rows()
		ind.N = rows
		if rows != A.Cols() {
			return onError("Syevd: A not square")
		}
	}
	n := ind.N
	if n == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		colStride = max(1, A.Rows())
	}
	switch {
	case ind.LDa < max(1, n):
		return onError("Syevd: lda")
	case ind.OffsetA < 0:
		return onError("Syevd: offsetA")
	case A.NumElements() < ind.OffsetA+(n-1)*colStride+n:
		return onError("Syevd: sizeA")
	case ind.OffsetW < 0:
		return onError("Syevd: offsetW")
	case W.NumElements() < ind.OffsetW+n:
		return onError("Syevd: sizeW")
	}
	return nil
}
예제 #2
0
파일: sytrf.go 프로젝트: jvlmdr/linalg
// checkSytrf validates and completes the index options for a Sytrf call on
// the matrix A with pivot indices in ipiv.
//
// A negative ind.N is replaced by the row count of A, which must then be
// square. A zero ind.LDa is replaced by A's leading index, and the column
// stride used for the size check becomes A's row count. N == 0 returns nil
// early. ipiv may be nil; when non-nil it must hold at least N entries.
// The first violated constraint is reported through onError.
func checkSytrf(ind *linalg.IndexOpts, A matrix.Matrix, ipiv []int32) error {
	// Stride of A's columns for the storage-size check: the caller's LDa
	// unless LDa is defaulted below.
	arows := ind.LDa
	if ind.N < 0 {
		ind.N = A.Rows()
		if ind.N != A.Cols() {
			// Prefixed like every other error reported by this check.
			return onError("Sytrf: A not square")
		}
	}
	if ind.N == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < max(1, ind.N) {
		return onError("Sytrf: lda")
	}
	if ind.OffsetA < 0 {
		return onError("Sytrf: offsetA")
	}
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
		return onError("Sytrf: sizeA")
	}
	if ipiv != nil && len(ipiv) < ind.N {
		return onError("Sytrf: size ipiv")
	}
	return nil
}
예제 #3
0
파일: potrf.go 프로젝트: jvlmdr/linalg
// checkPotrf validates and completes the index options for a Potrf call on
// the matrix A.
//
// A negative ind.N is replaced by the row count of A, which must then equal
// its column count. A zero ind.LDa is replaced by A's leading index, and the
// stride used for the size check becomes A's row count. A zero order needs
// no further checks; otherwise the leading dimension, offset and storage
// size of A are verified and the first violation is reported via onError.
func checkPotrf(ind *linalg.IndexOpts, A matrix.Matrix) error {
	stride := ind.LDa
	if ind.N < 0 {
		if ind.N = A.Rows(); ind.N != A.Cols() {
			return onError("Potrf: not square")
		}
	}
	order := ind.N
	if order == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa, stride = max(1, A.LeadingIndex()), max(1, A.Rows())
	}
	switch {
	case ind.LDa < max(1, order):
		return onError("Potrf: lda")
	case ind.OffsetA < 0:
		return onError("Potrf: offsetA")
	case A.NumElements() < ind.OffsetA+(order-1)*stride+order:
		return onError("Potrf: sizeA")
	default:
		return nil
	}
}
예제 #4
0
파일: gels.go 프로젝트: jvlmdr/linalg
/*
 Solves overdetermined or underdetermined real or complex linear systems.

 PURPOSE

 Solves op(A)*X = B, where A is m by n and op(A) is selected by the
 trans option, using LAPACK dgels (float) or zgels (complex).  Per
 the LAPACK documentation this yields a least-squares solution when
 the system is overdetermined and a minimum-norm solution when it is
 underdetermined; on exit B is overwritten with the solution X and A
 with details of its factorization.

 ARGUMENTS.
  A         float or complex matrix
  B         float or complex matrix.  Must have the same type as A.

 OPTIONS:
  trans     transpose option, passed to LAPACK as TRANS
  m         nonnegative integer.  If negative, the default value is used.
  n         nonnegative integer.  If negative, the default value is used.
  nrhs      nonnegative integer.  If negative, the default value is used.
  ldA       positive integer.  ldA >= max(1,m).  If zero, the default value is used.
  ldB       positive integer.  ldB >= max(1,m,n).  If zero, the default value is used.
  offsetA   nonnegative integer
  offsetB   nonnegative integer
*/
func Gels(A, B matrix.Matrix, opts ...linalg.Option) error {
	pars, err := linalg.GetParameters(opts...)
	if err != nil {
		return err
	}
	ind := linalg.GetIndexOpts(opts...)
	// Column strides used by the storage-size checks: the caller's LDa/LDb
	// unless they are defaulted below.
	arows := ind.LDa
	brows := ind.LDb
	if ind.M < 0 {
		ind.M = A.Rows()
	}
	if ind.N < 0 {
		ind.N = A.Cols()
	}
	if ind.Nrhs < 0 {
		ind.Nrhs = B.Cols()
	}
	if ind.M == 0 || ind.N == 0 || ind.Nrhs == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < max(1, ind.M) {
		return onError("Gels: ldA")
	}
	// B holds the right-hand sides on entry and the solutions on exit, so
	// each of its columns must have room for max(m, n) entries.
	bneed := max(ind.M, ind.N)
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		brows = max(1, B.Rows())
	}
	if ind.LDb < max(1, bneed) {
		return onError("Gels: ldB")
	}
	if ind.OffsetA < 0 {
		return onError("Gels: offsetA")
	}
	if A.NumElements() < ind.OffsetA+(ind.N-1)*arows+ind.M {
		return onError("Gels: sizeA")
	}
	if ind.OffsetB < 0 {
		return onError("Gels: offsetB")
	}
	if B.NumElements() < ind.OffsetB+(ind.Nrhs-1)*brows+bneed {
		return onError("Gels: sizeB")
	}
	if !matrix.EqualTypes(A, B) {
		return onError("Gels: arguments not of same type")
	}
	info := -1
	trans := linalg.ParamString(pars.Trans)
	switch A.(type) {
	case *matrix.FloatMatrix:
		Aa := A.(*matrix.FloatMatrix).FloatArray()
		Ba := B.(*matrix.FloatMatrix).FloatArray()
		info = dgels(trans, ind.M, ind.N, ind.Nrhs, Aa[ind.OffsetA:], ind.LDa,
			Ba[ind.OffsetB:], ind.LDb)
	case *matrix.ComplexMatrix:
		Aa := A.(*matrix.ComplexMatrix).ComplexArray()
		Ba := B.(*matrix.ComplexMatrix).ComplexArray()
		info = zgels(trans, ind.M, ind.N, ind.Nrhs, Aa[ind.OffsetA:], ind.LDa,
			Ba[ind.OffsetB:], ind.LDb)
	}
	if info != 0 {
		return onError(fmt.Sprintf("Gels: lapack error: %d", info))
	}
	return nil
}
예제 #5
0
파일: gbtrs.go 프로젝트: jvlmdr/linalg
// checkGbtrs validates and completes the index options for a Gbtrs call that
// uses the banded factorization held in A (kl subdiagonals, ku
// superdiagonals) to solve for the nrhs columns of B.
//
// A is in band storage: each of its N columns holds 2*kl+ku+1 entries (see
// the ldA and Ku checks). Unset options therefore default to N = columns of
// A, Ku = rows of A - 2*kl - 1 and Nrhs = columns of B. A zero LDa or LDb
// defaults to the leading index of A or B. N == 0 or Nrhs == 0 returns nil
// early. ipiv may be nil; when non-nil it must hold at least N entries. The
// first violated constraint is reported through onError.
func checkGbtrs(ind *linalg.IndexOpts, A, B matrix.Matrix, ipiv []int32) error {
	// Column strides used by the storage-size checks: the caller's LDa/LDb
	// unless they are defaulted below.
	arows := ind.LDa
	brows := ind.LDb
	if ind.Kl < 0 {
		return onError("Gbtrs: invalid kl")
	}
	if ind.N < 0 {
		// One column of band storage per matrix column; the row count is the
		// band height, not the matrix order.
		ind.N = A.Cols()
	}
	if ind.Nrhs < 0 {
		// The right-hand sides are the columns of B.
		ind.Nrhs = B.Cols()
	}
	if ind.N == 0 || ind.Nrhs == 0 {
		return nil
	}
	if ind.Ku < 0 {
		ind.Ku = A.Rows() - 2*ind.Kl - 1
	}
	if ind.Ku < 0 {
		return onError("Gbtrs: invalid ku")
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < 2*ind.Kl+ind.Ku+1 {
		return onError("Gbtrs: lda")
	}
	if ind.OffsetA < 0 {
		return onError("Gbtrs: offsetA")
	}
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+2*ind.Kl+ind.Ku+1 {
		return onError("Gbtrs: sizeA")
	}
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		brows = max(1, B.Rows())
	}
	if ind.LDb < max(1, ind.N) {
		return onError("Gbtrs: ldb")
	}
	if ind.OffsetB < 0 {
		return onError("Gbtrs: offsetB")
	}
	sizeB := B.NumElements()
	if sizeB < ind.OffsetB+(ind.Nrhs-1)*brows+ind.N {
		return onError("Gbtrs: sizeB")
	}
	if ipiv != nil && len(ipiv) < ind.N {
		return onError("Gbtrs: size ipiv")
	}
	return nil
}
예제 #6
0
파일: geqrf.go 프로젝트: jvlmdr/linalg
/*
 QR factorization.

 PURPOSE

 QR factorization of an m by n real or complex matrix A:

  A = Q*R = [Q1 Q2] * [R1; 0] if m >= n
  A = Q*R = Q * [R1 R2]       if m <= n,

 where Q is m by m and orthogonal/unitary and R is m by n with R1
 upper triangular.  On exit, R is stored in the upper triangular
 part of A.  Q is stored as a product of k=min(m,n) elementary
 reflectors.  The parameters of the reflectors are stored in the
 first k entries of tau and in the lower triangular part of the
 first k columns of A.

 Only float matrices are currently implemented; a complex A returns
 an error.

 ARGUMENTS
  A         float or complex matrix
  tau       float or complex  matrix of length at least min(m,n).  Must
            have the same type as A.

 OPTIONS
  m         integer.  If negative, the default value is used.
  n         integer.  If negative, the default value is used.
  ldA       nonnegative integer.  ldA >= max(1,m).  If zero, the
            default value is used.
  offsetA   nonnegative integer

*/
func Geqrf(A, tau matrix.Matrix, opts ...linalg.Option) error {
	ind := linalg.GetIndexOpts(opts...)
	// Column stride used by the storage-size check: the caller's LDa unless
	// it is defaulted below.
	arows := ind.LDa
	if ind.N < 0 {
		ind.N = A.Cols()
	}
	if ind.M < 0 {
		ind.M = A.Rows()
	}
	if ind.N == 0 || ind.M == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < max(1, ind.M) {
		return onError("Geqrf: ldA")
	}
	if ind.OffsetA < 0 {
		return onError("Geqrf: offsetA")
	}
	// The m by n operand spans n columns at stride arows; its last element
	// sits at OffsetA+(n-1)*arows+m-1.
	if A.NumElements() < ind.OffsetA+(ind.N-1)*arows+ind.M {
		return onError("Geqrf: sizeA")
	}
	if tau.NumElements() < min(ind.M, ind.N) {
		return onError("Geqrf: sizeTau")
	}
	if !matrix.EqualTypes(A, tau) {
		return onError("Geqrf: arguments not of same type")
	}
	info := -1
	switch A.(type) {
	case *matrix.FloatMatrix:
		Aa := A.(*matrix.FloatMatrix).FloatArray()
		taua := tau.(*matrix.FloatMatrix).FloatArray()
		info = dgeqrf(ind.M, ind.N, Aa[ind.OffsetA:], ind.LDa, taua)
	case *matrix.ComplexMatrix:
		return onError("Geqrf: complex not yet implemented")
	}
	if info != 0 {
		return onError(fmt.Sprintf("Geqrf lapack error: %d", info))
	}
	return nil
}
예제 #7
0
파일: posv.go 프로젝트: jvlmdr/linalg
// checkPosv validates and completes the index options for a Posv call with
// coefficient matrix A and right-hand sides B.
//
// A negative ind.N defaults to the row count of A and a negative ind.Nrhs to
// the column count of B. If either is zero there is nothing to check. A zero
// LDa or LDb is replaced by the matrix's leading index, and the stride used
// for its size check becomes the matrix's row count. The first violated
// constraint is reported through onError.
func checkPosv(ind *linalg.IndexOpts, A, B matrix.Matrix) error {
	aStride, bStride := ind.LDa, ind.LDb
	if ind.N < 0 {
		ind.N = A.Rows()
	}
	if ind.Nrhs < 0 {
		ind.Nrhs = B.Cols()
	}
	n, nrhs := ind.N, ind.Nrhs
	if n == 0 || nrhs == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		aStride = max(1, A.Rows())
	}
	if ind.LDa < max(1, n) {
		return onError("Posv: lda")
	}
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		bStride = max(1, B.Rows())
	}
	switch {
	case ind.LDb < max(1, n):
		return onError("Posv: ldb")
	case ind.OffsetA < 0:
		return onError("Posv: offsetA")
	case A.NumElements() < ind.OffsetA+(n-1)*aStride+n:
		return onError("Posv: sizeA")
	case ind.OffsetB < 0:
		return onError("Posv: offsetB")
	case B.NumElements() < ind.OffsetB+(nrhs-1)*bStride+n:
		return onError("Posv: sizeB")
	}
	return nil
}
예제 #8
0
파일: getrf.go 프로젝트: jvlmdr/linalg
/*
 LU factorization of a general real or complex m by n matrix.

 PURPOSE

 On exit, A is replaced with L, U in the factorization P*A = L*U
 and ipiv contains the permutation:
 P = P_min{m,n} * ... * P2 * P1 where Pi interchanges rows i and
 ipiv[i] of A (using the Fortran convention, i.e., the first row
 is numbered 1).

 Only float matrices are currently implemented; a complex A returns
 an error.

 ARGUMENTS
  A         float or complex matrix
  ipiv      int vector of length at least min(m,n).  Required: it
            receives the pivot indices.

 OPTIONS
  m         nonnegative integer.  If negative, the default value is used.
  n         nonnegative integer.  If negative, the default value is used.
  ldA       positive integer.  ldA >= max(1,m).  If zero, the default
            value is used.
  offsetA   nonnegative integer

*/
func Getrf(A matrix.Matrix, ipiv []int32, opts ...linalg.Option) error {
	ind := linalg.GetIndexOpts(opts...)
	// Column stride used by the storage-size check: the caller's LDa unless
	// it is defaulted below.
	arows := ind.LDa
	if ind.M < 0 {
		ind.M = A.Rows()
	}
	if ind.N < 0 {
		ind.N = A.Cols()
	}
	if ind.N == 0 || ind.M == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < max(1, ind.M) {
		return onError("Getrf: lda")
	}
	if ind.OffsetA < 0 {
		return onError("Getrf: offsetA")
	}
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.M {
		return onError("Getrf: sizeA")
	}
	// ipiv is the output pivot array, so it must exist; a nil slice has
	// length zero and is rejected here instead of reaching LAPACK.
	if len(ipiv) < min(ind.N, ind.M) {
		return onError("Getrf: size ipiv")
	}
	info := -1
	switch A.(type) {
	case *matrix.FloatMatrix:
		Aa := A.(*matrix.FloatMatrix).FloatArray()
		info = dgetrf(ind.M, ind.N, Aa[ind.OffsetA:], ind.LDa, ipiv)
	case *matrix.ComplexMatrix:
		return onError("Getrf: complex not yet implemented")
	}
	if info != 0 {
		return onError("Getrf call error")
	}
	return nil
}
예제 #9
0
파일: getri.go 프로젝트: jvlmdr/linalg
/*
 Inverse of a real or complex matrix.

 PURPOSE

 Computes the inverse of real or complex matrix of order n.  On
 entry, A and ipiv contain the LU factorization, as returned by
 gesv() or getrf().  On exit A is replaced by the inverse.

 Only float matrices are currently implemented; a complex A returns
 an error.

 ARGUMENTS
  A         float or complex matrix
  ipiv      int vector of length at least n.  Required: it holds the
            pivot indices of the factorization.

 OPTIONS
  n         nonnegative integer.  If negative, the default value is used.
  ldA       positive integer.  ldA >= max(1,n).  If zero, the default
            value is used.
  offsetA   nonnegative integer;
*/
func Getri(A matrix.Matrix, ipiv []int32, opts ...linalg.Option) error {
	ind := linalg.GetIndexOpts(opts...)
	// Column stride used by the storage-size check: the caller's LDa unless
	// it is defaulted below.
	arows := ind.LDa
	if ind.N < 0 {
		ind.N = A.Cols()
	}
	if ind.N == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	// Enforce the documented ldA >= max(1,n) bound before calling LAPACK.
	if ind.LDa < max(1, ind.N) {
		return onError("Getri: ldA")
	}
	if ind.OffsetA < 0 {
		return onError("Getri: offset")
	}
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
		return onError("Getri: sizeA")
	}
	// The pivots are an input to the inversion, so ipiv must be present;
	// a nil slice has length zero and is rejected here.
	if len(ipiv) < ind.N {
		return onError("Getri: size ipiv")
	}
	info := -1
	switch A.(type) {
	case *matrix.FloatMatrix:
		Aa := A.(*matrix.FloatMatrix).FloatArray()
		info = dgetri(ind.N, Aa[ind.OffsetA:], ind.LDa, ipiv)
	case *matrix.ComplexMatrix:
		return onError("Getri: complex not yet implemented")
	}
	if info != 0 {
		return onError(fmt.Sprintf("Getri lapack error: %d", info))
	}
	return nil
}
예제 #10
0
파일: gesvd.go 프로젝트: jvlmdr/linalg
// checkGesvd validates and completes the index options for a Gesvd call on
// the m by n matrix A, with singular values written to S and, depending on
// pars.Jobu / pars.Jobvt, vectors written to U and Vt.
//
// Negative M/N default to A's row/column count; M == 0 or N == 0 returns nil
// early. Jobu and Jobvt may not both be PJobO. When Jobu (Jobvt) is PJobAll
// or PJobS, U (Vt) must be non-nil and its leading dimension large enough.
// Otherwise LDu (LDvt) defaults to 1. A zero LDa defaults to A's leading
// index. The first violated constraint is reported through onError.
func checkGesvd(ind *linalg.IndexOpts, pars *linalg.Parameters, A, S, U, Vt matrix.Matrix) error {
	// Column stride used by the storage-size check: the caller's LDa unless
	// it is defaulted below.
	arows := ind.LDa
	if ind.M < 0 {
		ind.M = A.Rows()
	}
	if ind.N < 0 {
		ind.N = A.Cols()
	}
	if ind.M == 0 || ind.N == 0 {
		return nil
	}
	if pars.Jobu == linalg.PJobO && pars.Jobvt == linalg.PJobO {
		return onError("Gesvd: jobu and jobvt cannot both have value PJobO")
	}
	if pars.Jobu == linalg.PJobAll || pars.Jobu == linalg.PJobS {
		if U == nil {
			return onError("Gesvd: missing matrix U")
		}
		if ind.LDu == 0 {
			ind.LDu = max(1, U.LeadingIndex())
		}
		if ind.LDu < max(1, ind.M) {
			return onError("Gesvd: ldU")
		}
	} else {
		if ind.LDu == 0 {
			ind.LDu = 1
		}
		if ind.LDu < 1 {
			return onError("Gesvd: ldU")
		}
	}
	if pars.Jobvt == linalg.PJobAll || pars.Jobvt == linalg.PJobS {
		if Vt == nil {
			return onError("Gesvd: missing matrix Vt")
		}
		if ind.LDvt == 0 {
			ind.LDvt = max(1, Vt.LeadingIndex())
		}
		// All of V^T needs n rows; only the leading min(m,n) rows otherwise.
		if pars.Jobvt == linalg.PJobAll && ind.LDvt < max(1, ind.N) {
			return onError("Gesvd: ldVt")
		} else if pars.Jobvt != linalg.PJobAll && ind.LDvt < max(1, min(ind.M, ind.N)) {
			return onError("Gesvd: ldVt")
		}
	} else {
		if ind.LDvt == 0 {
			ind.LDvt = 1
		}
		if ind.LDvt < 1 {
			return onError("Gesvd: ldVt")
		}
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < max(1, ind.M) {
		return onError("Gesvd: ldA")
	}
	if ind.OffsetA < 0 {
		return onError("Gesvd: offsetA")
	}
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.M {
		return onError("Gesvd: sizeA")
	}

	if ind.OffsetS < 0 {
		return onError("Gesvd: offsetS")
	}
	sizeS := S.NumElements()
	if sizeS < ind.OffsetS+min(ind.M, ind.N) {
		return onError("Gesvd: sizeS")
	}

	// Storage-size checks for U and Vt are currently disabled; the draft is
	// kept below for reference.
	/*
		if U != nil {
			if ind.OffsetU < 0 {
				return onError("Gesvd: OffsetU")
			}
			sizeU := U.NumElements()
			if pars.Jobu == linalg.PJobAll && sizeU < ind.LDu*(ind.M-1) {
				return onError("Gesvd: sizeU")
			} else if pars.Jobu == linalg.PJobS && sizeU < ind.LDu*(min(ind.M,ind.N)-1) {
				return onError("Gesvd: sizeU")
			}
		}

		if Vt != nil {
			if ind.OffsetVt < 0 {
				return onError("Gesvd: OffsetVt")
			}
			sizeVt := Vt.NumElements()
			if pars.Jobvt == linalg.PJobAll && sizeVt <  ind.N {
				return onError("Gesvd: sizeVt")
			} else if pars.Jobvt == linalg.PJobS && sizeVt < min(ind.M, ind.N) {
				return onError("Gesvd: sizeVt")
			}
		}
	*/
	return nil
}
예제 #11
0
파일: gesv.go 프로젝트: jvlmdr/linalg
/*
 Solves a general real or complex set of linear equations.

 PURPOSE

 Solves A*X=B with A n by n real or complex.

 If ipiv is provided, then on exit A is overwritten with the details
 of the LU factorization, and ipiv contains the permutation matrix.
 If ipiv is not provided, then gesv() does not return the
 factorization and does not modify A.  On exit B is replaced with
 the solution X.

 ARGUMENTS.
  A         float or complex matrix
  B         float or complex matrix.  Must have the same type as A.
  ipiv      int vector of length at least n

 OPTIONS:
  n         nonnegative integer.  If negative, the default value is used.
  nrhs      nonnegative integer.  If negative, the default value is used.
  ldA       positive integer.  ldA >= max(1,n).  If zero, the default value is used.
  ldB       positive integer.  ldB >= max(1,n).  If zero, the default value is used.
  offsetA   nonnegative integer
  offsetB   nonnegative integer;
*/
func Gesv(A, B matrix.Matrix, ipiv []int32, opts ...linalg.Option) error {
	ind := linalg.GetIndexOpts(opts...)
	// Column strides used by the storage-size checks: the caller's LDa/LDb
	// unless they are defaulted below.
	arows := ind.LDa
	brows := ind.LDb
	if ind.N < 0 {
		ind.N = A.Rows()
		if ind.N != A.Cols() {
			return onError("Gesv: A not square")
		}
	}
	if ind.Nrhs < 0 {
		ind.Nrhs = B.Cols()
	}
	if ind.N == 0 || ind.Nrhs == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < max(1, ind.N) {
		return onError("Gesv: ldA")
	}
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		brows = max(1, B.Rows())
	}
	if ind.LDb < max(1, ind.N) {
		return onError("Gesv: ldB")
	}
	if ind.OffsetA < 0 {
		return onError("Gesv: offsetA")
	}
	if ind.OffsetB < 0 {
		return onError("Gesv: offsetB")
	}
	// These size checks guarantee that the slices handed to LAPACK below
	// cover the whole n by n (resp. n by nrhs) operand.
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
		return onError("Gesv: sizeA")
	}
	sizeB := B.NumElements()
	if sizeB < ind.OffsetB+(ind.Nrhs-1)*brows+ind.N {
		return onError("Gesv: sizeB")
	}
	if ipiv != nil && len(ipiv) < ind.N {
		return onError("Gesv: size ipiv")
	}
	if !matrix.EqualTypes(A, B) {
		return onError("Gesv: arguments not of same type")
	}
	info := -1
	if ipiv == nil {
		ipiv = make([]int32, ind.N)
		// Do not overwrite A.
		A = A.MakeCopy()
	}
	switch A.(type) {
	case *matrix.FloatMatrix:
		Aa := A.(*matrix.FloatMatrix).FloatArray()
		Ba := B.(*matrix.FloatMatrix).FloatArray()
		info = dgesv(ind.N, ind.Nrhs, Aa[ind.OffsetA:], ind.LDa, ipiv,
			Ba[ind.OffsetB:], ind.LDb)
	case *matrix.ComplexMatrix:
		Aa := A.(*matrix.ComplexMatrix).ComplexArray()
		Ba := B.(*matrix.ComplexMatrix).ComplexArray()
		info = zgesv(ind.N, ind.Nrhs, Aa[ind.OffsetA:], ind.LDa, ipiv,
			Ba[ind.OffsetB:], ind.LDb)
	}
	if info != 0 {
		return onError(fmt.Sprintf("Gesv: lapack error: %d", info))
	}
	return nil
}
예제 #12
0
파일: gbtrs.go 프로젝트: jvlmdr/linalg
/*
 Solves a real or complex set of linear equations with a banded
 coefficient matrix, given the LU factorization computed by gbtrf()
 or gbsv().

 PURPOSE

 Solves linear equations
  A*X = B,   if trans is PNoTrans
  A^T*X = B, if trans is PTrans
  A^H*X = B, if trans is PConjTrans

 On entry, A and ipiv contain the LU factorization of an n by n
 band matrix A as computed by Gbtrf() or Gbsv().  A is in band
 storage: one column per matrix column, each holding 2*kl+ku+1
 entries.  On exit B is replaced by the solution X.

 Only float matrices are currently implemented; a complex A returns
 an error.

 ARGUMENTS
  A         float or complex matrix
  B         float or complex  matrix.  Must have the same type as A.
  ipiv      int vector
  kl        nonnegative integer

 OPTIONS
  trans     PNoTrans, PTrans or PConjTrans
  n         nonnegative integer.  If negative, the number of columns of A is used.
  ku        nonnegative integer.  If negative, rows(A)-2*kl-1 is used.
  nrhs      nonnegative integer.  If negative, the number of columns of B is used.
  ldA       positive integer, ldA >= 2*kl+ku+1. If zero, the  default value is used.
  ldB       positive integer, ldB >= max(1,n). If zero, the default value is used.
  offsetA   nonnegative integer
  offsetB   nonnegative integer;
*/
func Gbtrs(A, B matrix.Matrix, ipiv []int32, KL int, opts ...linalg.Option) error {
	pars, err := linalg.GetParameters(opts...)
	if err != nil {
		return err
	}
	ind := linalg.GetIndexOpts(opts...)
	ind.Kl = KL
	// Column strides used by the storage-size checks: the caller's LDa/LDb
	// unless they are defaulted below.
	arows := ind.LDa
	brows := ind.LDb
	if ind.Kl < 0 {
		return onError("Gbtrs: invalid kl")
	}
	if ind.N < 0 {
		// Band storage keeps one column per matrix column; A's row count is
		// the band height, not the matrix order.
		ind.N = A.Cols()
	}
	if ind.Nrhs < 0 {
		// The right-hand sides are the columns of B.
		ind.Nrhs = B.Cols()
	}
	if ind.N == 0 || ind.Nrhs == 0 {
		return nil
	}
	if ind.Ku < 0 {
		ind.Ku = A.Rows() - 2*ind.Kl - 1
	}
	if ind.Ku < 0 {
		return onError("Gbtrs: invalid ku")
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < 2*ind.Kl+ind.Ku+1 {
		return onError("Gbtrs: ldA")
	}
	if ind.OffsetA < 0 {
		return onError("Gbtrs: offsetA")
	}
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+2*ind.Kl+ind.Ku+1 {
		return onError("Gbtrs: sizeA")
	}
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		brows = max(1, B.Rows())
	}
	if ind.LDb < max(1, ind.N) {
		return onError("Gbtrs: ldB")
	}
	if ind.OffsetB < 0 {
		return onError("Gbtrs: offsetB")
	}
	sizeB := B.NumElements()
	if sizeB < ind.OffsetB+(ind.Nrhs-1)*brows+ind.N {
		return onError("Gbtrs: sizeB")
	}
	if ipiv != nil && len(ipiv) < ind.N {
		return onError("Gbtrs: size ipiv")
	}

	if !matrix.EqualTypes(A, B) {
		return onError("Gbtrs: arguments not of same type")
	}
	info := -1
	switch A.(type) {
	case *matrix.FloatMatrix:
		Aa := A.(*matrix.FloatMatrix).FloatArray()
		Ba := B.(*matrix.FloatMatrix).FloatArray()
		trans := linalg.ParamString(pars.Trans)
		info = dgbtrs(trans, ind.N, ind.Kl, ind.Ku, ind.Nrhs,
			Aa[ind.OffsetA:], ind.LDa, ipiv, Ba[ind.OffsetB:], ind.LDb)
	case *matrix.ComplexMatrix:
		return onError("Gbtrs: complex not yet implemented")
	}
	if info != 0 {
		return onError(fmt.Sprintf("Gbtrs lapack error: %d", info))
	}
	return nil
}
예제 #13
0
파일: gttrs.go 프로젝트: jvlmdr/linalg
/*
 Solves a real or complex tridiagonal set of linear equations,
 given the LU factorization computed by gttrf().

 PURPOSE
  solves A*X=B,   if trans is PNoTrans
  solves A^T*X=B, if trans is PTrans
  solves A^H*X=B, if trans is PConjTrans

 On entry, DL, D, DU, DU2 and ipiv contain the LU factorization of
 an n by n tridiagonal matrix A as computed by gttrf().  On exit B
 is replaced by the solution X.

 Note: the exported name is spelled Gtrrs; its error messages refer
 to it as Gttrs.  Only float matrices are currently implemented.

 ARGUMENTS.
  DL        float or complex matrix
  D         float or complex matrix.  Must have the same type as dl.
  DU        float or complex matrix.  Must have the same type as dl.
  DU2       float or complex matrix.  Must have the same type as dl.
  B         float or complex matrix.  Must have the same type as dl.
  ipiv      int vector of length at least n

 OPTIONS
  trans     PNoTrans, PTrans, PConjTrans
  n         nonnegative integer.  If negative, the default value is used.
  nrhs      nonnegative integer.  If negative, the default value is used.
  ldB       positive integer, ldB >= max(1,n). If zero, the default value is used.
  offsetdl  nonnegative integer
  offsetd   nonnegative integer
  offsetdu  nonnegative integer
  offsetB   nonnegative integer

*/
func Gtrrs(DL, D, DU, DU2, B matrix.Matrix, ipiv []int32, opts ...linalg.Option) error {
	pars, err := linalg.GetParameters(opts...)
	if err != nil {
		return err
	}
	ind := linalg.GetIndexOpts(opts...)
	// Column stride used by the storage-size check on B: the caller's LDb
	// unless it is defaulted below.
	brows := ind.LDb
	if ind.OffsetD < 0 {
		return onError("Gttrs: offset D")
	}
	if ind.N < 0 {
		ind.N = D.NumElements() - ind.OffsetD
	}
	if ind.N < 0 {
		return onError("Gttrs: size D")
	}
	if ind.N == 0 {
		return nil
	}
	// An explicitly supplied N may exceed what D holds past its offset.
	if D.NumElements() < ind.OffsetD+ind.N {
		return onError("Gttrs: size D")
	}
	if ind.OffsetDL < 0 {
		return onError("Gttrs: offset DL")
	}
	sizeDL := DL.NumElements()
	if sizeDL < ind.OffsetDL+ind.N-1 {
		return onError("Gttrs: sizeDL")
	}
	if ind.OffsetDU < 0 {
		return onError("Gttrs: offset DU")
	}
	sizeDU := DU.NumElements()
	if sizeDU < ind.OffsetDU+ind.N-1 {
		return onError("Gttrs: sizeDU")
	}
	sizeDU2 := DU2.NumElements()
	if sizeDU2 < ind.N-2 {
		return onError("Gttrs: sizeDU2")
	}
	if ind.Nrhs < 0 {
		ind.Nrhs = B.Cols()
	}
	if ind.Nrhs == 0 {
		return nil
	}
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		brows = max(1, B.Rows())
	}
	if ind.LDb < max(1, ind.N) {
		return onError("Gttrs: ldB")
	}
	if ind.OffsetB < 0 {
		return onError("Gttrs: offset B")
	}
	sizeB := B.NumElements()
	if sizeB < ind.OffsetB+(ind.Nrhs-1)*brows+ind.N {
		return onError("Gttrs: sizeB")
	}
	if len(ipiv) < ind.N {
		return onError("Gttrs: size ipiv")
	}
	if !matrix.EqualTypes(DL, D, DU, DU2, B) {
		return onError("Gttrs: matrix types")
	}
	var info int = -1
	switch DL.(type) {
	case *matrix.FloatMatrix:
		DLa := DL.(*matrix.FloatMatrix).FloatArray()
		Da := D.(*matrix.FloatMatrix).FloatArray()
		DUa := DU.(*matrix.FloatMatrix).FloatArray()
		DU2a := DU2.(*matrix.FloatMatrix).FloatArray()
		Ba := B.(*matrix.FloatMatrix).FloatArray()
		trans := linalg.ParamString(pars.Trans)
		info = dgttrs(trans, ind.N, ind.Nrhs,
			DLa[ind.OffsetDL:], Da[ind.OffsetD:], DUa[ind.OffsetDU:], DU2a,
			ipiv, Ba[ind.OffsetB:], ind.LDb)
	case *matrix.ComplexMatrix:
		return onError("Gttrs: complex valued not yet implemented")
	}
	if info != 0 {
		return onError(fmt.Sprintf("Gttrs lapack error: %d", info))
	}
	return nil
}
예제 #14
0
파일: indexcheck.go 프로젝트: jvlmdr/linalg
// check_level2_func validates and completes the index options for the BLAS
// level-2 routine selected by fn, which operates on matrix A and vectors X
// and Y (Y is allowed to be nil for the banded/symmetric cases that test it).
//
// Unset dimensions are filled in from A: negative M, N, Ku or K and a zero
// LDa take defaults derived from A's shape. When LDa is defaulted, the
// column stride used by the size checks becomes A's row count. The first
// violated constraint is reported through onError; nil means the options
// are usable. The packed-storage routines (fspr, fdspr2, ftpsv, fspmv,
// ftpmv) are only subject to the increment checks.
func check_level2_func(ind *linalg.IndexOpts, fn funcNum, X, Y, A matrix.Matrix, pars *linalg.Parameters) error {
	if ind.IncX <= 0 {
		return onError("incX")
	}
	if ind.IncY <= 0 {
		return onError("incY")
	}

	sizeA := A.NumElements()
	// Column stride used by the storage-size checks: the caller's LDa unless
	// it is defaulted in one of the cases below.
	arows := ind.LDa
	switch fn {
	case fgemv: // general matrix
		if ind.M < 0 {
			ind.M = A.Rows()
		}
		if ind.N < 0 {
			ind.N = A.Cols()
		}
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		// Each column of the m by n operand needs m rows of storage.
		if ind.N > 0 && ind.M > 0 && ind.LDa < max(1, ind.M) {
			return onError("ldA")
		}
		if ind.OffsetA < 0 {
			return onError("offsetA")
		}
		if ind.N > 0 && ind.M > 0 &&
			sizeA < ind.OffsetA+(ind.N-1)*arows+ind.M {
			return onError("sizeA")
		}
		if ind.OffsetX < 0 {
			return onError("offsetX")
		}
		if ind.OffsetY < 0 {
			return onError("offsetY")
		}
		// Without transpose x has n entries and y has m; with it, swapped.
		xlen, ylen := ind.N, ind.M
		if pars.Trans != linalg.PNoTrans {
			xlen, ylen = ind.M, ind.N
		}
		sizeX := X.NumElements()
		sizeY := Y.NumElements()
		if xlen > 0 && sizeX < ind.OffsetX+(xlen-1)*abs(ind.IncX)+1 {
			return onError("sizeX")
		}
		if ylen > 0 && sizeY < ind.OffsetY+(ylen-1)*abs(ind.IncY)+1 {
			return onError("sizeY")
		}
	case fger:
		if ind.M < 0 {
			ind.M = A.Rows()
		}
		if ind.N < 0 {
			ind.N = A.Cols()
		}
		// After defaulting, M and N are non-negative, so both are positive
		// past this point.
		if ind.M == 0 || ind.N == 0 {
			return nil
		}
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		if ind.LDa < max(1, ind.M) {
			return onError("ldA")
		}
		if ind.OffsetA < 0 {
			return onError("offsetA")
		}
		if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.M {
			return onError("sizeA")
		}
		if ind.OffsetX < 0 {
			return onError("offsetX")
		}
		if ind.OffsetY < 0 {
			return onError("offsetY")
		}
		sizeX := X.NumElements()
		if sizeX < ind.OffsetX+(ind.M-1)*abs(ind.IncX)+1 {
			return onError("sizeX")
		}
		sizeY := Y.NumElements()
		if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 {
			return onError("sizeY")
		}
	case fgbmv: // general banded
		if ind.M < 0 {
			ind.M = A.Rows()
		}
		if ind.N < 0 {
			ind.N = A.Cols()
		}
		if ind.Kl < 0 {
			return onError("kl")
		}
		if ind.Ku < 0 {
			ind.Ku = A.Rows() - 1 - ind.Kl
		}
		if ind.Ku < 0 {
			return onError("ku")
		}
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		if ind.LDa < ind.Kl+ind.Ku+1 {
			return onError("ldA")
		}
		if ind.OffsetA < 0 {
			return onError("offsetA")
		}
		if ind.N > 0 && ind.M > 0 &&
			sizeA < ind.OffsetA+(ind.N-1)*arows+ind.Kl+ind.Ku+1 {
			return onError("sizeA")
		}
		if ind.OffsetX < 0 {
			return onError("offsetX")
		}
		if ind.OffsetY < 0 {
			return onError("offsetY")
		}
		// Guard each vector check by that vector's own length (the
		// m-length one must be guarded by M, not N).
		xlen, ylen := ind.N, ind.M
		if pars.Trans != linalg.PNoTrans {
			xlen, ylen = ind.M, ind.N
		}
		sizeX := X.NumElements()
		sizeY := Y.NumElements()
		if xlen > 0 && sizeX < ind.OffsetX+(xlen-1)*abs(ind.IncX)+1 {
			return onError("sizeX")
		}
		if ylen > 0 && sizeY < ind.OffsetY+(ylen-1)*abs(ind.IncY)+1 {
			return onError("sizeY")
		}
	case ftrmv, ftrsv:
		// ftrmv = triangular
		// ftrsv = triangular solve
		if ind.N < 0 {
			if A.Rows() != A.Cols() {
				return onError("A not square")
			}
			ind.N = A.Rows()
		}
		if ind.N > 0 {
			if ind.LDa == 0 {
				ind.LDa = max(1, A.LeadingIndex())
				arows = max(1, A.Rows())
			}
			if ind.LDa < max(1, ind.N) {
				return onError("ldA")
			}
			if ind.OffsetA < 0 {
				return onError("offsetA")
			}
			if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
				return onError("sizeA")
			}
			if ind.OffsetX < 0 {
				return onError("offsetX")
			}
			sizeX := X.NumElements()
			if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
		}
	case ftbmv, ftbsv, fsbmv:
		// ftbmv = triangular banded
		// ftbsv = triangular banded solve
		// fsbmv = symmetric banded product
		if ind.N < 0 {
			ind.N = A.Rows()
		}
		if ind.N > 0 {
			if ind.K < 0 {
				ind.K = max(0, A.Rows()-1)
			}
			if ind.LDa == 0 {
				ind.LDa = max(1, A.LeadingIndex())
				arows = max(1, A.Rows())
			}
			if ind.LDa < ind.K+1 {
				return onError("ldA")
			}
			if ind.OffsetA < 0 {
				return onError("offsetA")
			}
			if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.K+1 {
				return onError("sizeA")
			}
			if ind.OffsetX < 0 {
				return onError("offsetX")
			}
			sizeX := X.NumElements()
			if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			if Y != nil {
				if ind.OffsetY < 0 {
					return onError("offsetY")
				}
				sizeY := Y.NumElements()
				if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 {
					return onError("sizeY")
				}
			}
		}
	case fsymv, fsyr, fsyr2:
		// fsymv = symmetric product
		// fsyr = symmetric rank update
		// fsyr2 = symmetric rank-2 update
		if ind.N < 0 {
			if A.Rows() != A.Cols() {
				return onError("A not square")
			}
			ind.N = A.Rows()
		}
		if ind.N > 0 {
			if ind.LDa == 0 {
				ind.LDa = max(1, A.LeadingIndex())
				arows = max(1, A.Rows())
			}
			if ind.LDa < max(1, ind.N) {
				return onError("ldA")
			}
			if ind.OffsetA < 0 {
				return onError("offsetA")
			}
			if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
				return onError("sizeA")
			}
			if ind.OffsetX < 0 {
				return onError("offsetX")
			}
			sizeX := X.NumElements()
			if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			if Y != nil {
				if ind.OffsetY < 0 {
					return onError("offsetY")
				}
				sizeY := Y.NumElements()
				if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 {
					return onError("sizeY")
				}
			}
		}
	case fspr, fdspr2, ftpsv, fspmv, ftpmv:
		// fspr   = symmetric packed rank update
		// fdspr2 = symmetric packed rank-2 update
		// ftpsv  = triangular packed solve
		// fspmv  = symmetric packed product
		// ftpmv  = triangular packed
		// TODO(review): no dimension/size checks are performed for packed
		// storage yet.
	}
	return nil
}
예제 #15
0
파일: indexcheck.go 프로젝트: jvlmdr/linalg
// check_level3_func validates the dimension, leading-dimension and offset
// options in ind for the BLAS level-3 routine fn (gemm, symm, trmm, trsm,
// syrk or syr2k) before the routine is called, filling in defaults in place:
//
//   - a negative ind.M, ind.N or ind.K is replaced by the size implied by
//     the shapes of A, B and C and the TransA/TransB/Trans/Side flags in pars;
//   - a zero ind.LDa, ind.LDb or ind.LDc is replaced by max(1, LeadingIndex())
//     of the corresponding matrix.
//
// It returns nil when the options are consistent, or when a result
// dimension is zero so that there is nothing to do; otherwise it returns the
// onError value describing the first problem found. Values of fn not listed
// in the switch are accepted without any checks.
func check_level3_func(ind *linalg.IndexOpts, fn funcNum, A, B, C matrix.Matrix,
	pars *linalg.Parameters) (err error) {

	// Column strides used by the buffer-size checks. An explicit leading
	// dimension is used as given; when one is defaulted below, the stride is
	// taken from the matrix row count instead of LeadingIndex().
	arows := ind.LDa
	brows := ind.LDb
	crows := ind.LDc

	switch fn {
	case fgemm:
		// C (M x N) := op(A) (M x K) * op(B) (K x N).
		if ind.M < 0 {
			if pars.TransA == linalg.PNoTrans {
				ind.M = A.Rows()
			} else {
				ind.M = A.Cols()
			}
		}
		if ind.N < 0 {
			if pars.TransB == linalg.PNoTrans {
				ind.N = B.Cols()
			} else {
				ind.N = B.Rows()
			}
		}
		if ind.M == 0 || ind.N == 0 {
			return nil
		}
		if ind.K < 0 {
			if pars.TransA == linalg.PNoTrans {
				ind.K = A.Cols()
			} else {
				ind.K = A.Rows()
			}
			// The inner dimension must agree between op(A) and op(B).
			if pars.TransB == linalg.PNoTrans && ind.K != B.Rows() ||
				pars.TransB != linalg.PNoTrans && ind.K != B.Cols() {
				return onError("dimensions of A and B do not match")
			}
		}
		// A matrix
		if ind.OffsetA < 0 {
			return onError("offsetA illegal, <0")
		}
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		if ind.K > 0 {
			if (pars.TransA == linalg.PNoTrans && ind.LDa < max(1, ind.M)) ||
				(pars.TransA != linalg.PNoTrans && ind.LDa < max(1, ind.K)) {
				return onError("inconsistent ldA")
			}
			sizeA := A.NumElements()
			if (pars.TransA == linalg.PNoTrans &&
				sizeA < ind.OffsetA+(ind.K-1)*arows+ind.M) ||
				(pars.TransA != linalg.PNoTrans &&
					sizeA < ind.OffsetA+(ind.M-1)*arows+ind.K) {
				return onError("sizeA")
			}
		}
		// B matrix
		if ind.OffsetB < 0 {
			return onError("offsetB illegal, <0")
		}
		if ind.LDb == 0 {
			ind.LDb = max(1, B.LeadingIndex())
			brows = max(1, B.Rows())
		}
		if ind.K > 0 {
			if (pars.TransB == linalg.PNoTrans && ind.LDb < max(1, ind.K)) ||
				(pars.TransB != linalg.PNoTrans && ind.LDb < max(1, ind.N)) {
				return onError("inconsistent ldB")
			}
			sizeB := B.NumElements()
			if (pars.TransB == linalg.PNoTrans &&
				sizeB < ind.OffsetB+(ind.N-1)*brows+ind.K) ||
				(pars.TransB != linalg.PNoTrans &&
					sizeB < ind.OffsetB+(ind.K-1)*brows+ind.N) {
				return onError("sizeB")
			}
		}
		// C matrix
		if ind.OffsetC < 0 {
			return onError("offsetC illegal, <0")
		}
		if ind.LDc == 0 {
			ind.LDc = max(1, C.LeadingIndex())
			crows = max(1, C.Rows())
		}
		if ind.LDc < max(1, ind.M) {
			return onError("inconsistent ldC")
		}
		sizeC := C.NumElements()
		if sizeC < ind.OffsetC+(ind.N-1)*crows+ind.M {
			return onError("sizeC")
		}

	case fsymm, ftrmm, ftrsm:
		// B (and C, when given) are M x N; A is square, M x M when applied
		// from the left and N x N when applied from the right.
		if ind.M < 0 {
			ind.M = B.Rows()
			if pars.Side == linalg.PLeft && (ind.M != A.Rows() || ind.M != A.Cols()) {
				return onError("dimensions of A and B do not match")
			}
		}
		if ind.N < 0 {
			ind.N = B.Cols()
			if pars.Side == linalg.PRight && (ind.N != A.Rows() || ind.N != A.Cols()) {
				return onError("dimensions of A and B do not match")
			}
		}
		if ind.M == 0 || ind.N == 0 {
			return nil
		}
		// A matrix
		if ind.OffsetA < 0 {
			return onError("offsetA illegal, <0")
		}
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		// The ldA bound depends on which side A is applied from; the two
		// alternatives are parenthesized explicitly (&& binds tighter than ||).
		if (pars.Side == linalg.PLeft && ind.LDa < max(1, ind.M)) ||
			(pars.Side == linalg.PRight && ind.LDa < max(1, ind.N)) {
			return onError("ldA")
		}
		sizeA := A.NumElements()
		if (pars.Side == linalg.PLeft && sizeA < ind.OffsetA+(ind.M-1)*arows+ind.M) ||
			(pars.Side == linalg.PRight && sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N) {
			return onError("sizeA")
		}

		// B matrix
		if B != nil {
			if ind.OffsetB < 0 {
				return onError("offsetB illegal, <0")
			}
			if ind.LDb == 0 {
				ind.LDb = max(1, B.LeadingIndex())
				brows = max(1, B.Rows())
			}
			if ind.LDb < max(1, ind.M) {
				return onError("ldB")
			}
			sizeB := B.NumElements()
			if sizeB < ind.OffsetB+(ind.N-1)*brows+ind.M {
				return onError("sizeB")
			}
		}

		// C matrix (only present for some of these routines)
		if C != nil {
			if ind.OffsetC < 0 {
				return onError("offsetC illegal, <0")
			}
			if ind.LDc == 0 {
				ind.LDc = max(1, C.LeadingIndex())
				crows = max(1, C.Rows())
			}
			if ind.LDc < max(1, ind.M) {
				return onError("ldC")
			}
			sizeC := C.NumElements()
			if sizeC < ind.OffsetC+(ind.N-1)*crows+ind.M {
				return onError("sizeC")
			}
		}

	case fsyrk:
		// C is N x N; A is N x K when pars.Trans is PNoTrans, K x N otherwise.
		if ind.N < 0 {
			if pars.Trans == linalg.PNoTrans {
				ind.N = A.Rows()
			} else {
				ind.N = A.Cols()
			}
		}
		if ind.K < 0 {
			if pars.Trans == linalg.PNoTrans {
				ind.K = A.Cols()
			} else {
				ind.K = A.Rows()
			}
		}
		if ind.N == 0 {
			return nil
		}
		// A matrix
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		if ind.OffsetA < 0 {
			return onError("offsetA")
		}
		if ind.K > 0 {
			if (pars.Trans == linalg.PNoTrans && ind.LDa < max(1, ind.N)) ||
				(pars.Trans != linalg.PNoTrans && ind.LDa < max(1, ind.K)) {
				return onError("inconsistent ldA")
			}
			// Both alternatives test pars.Trans (not pars.TransA), the flag
			// that syrk actually uses for A.
			sizeA := A.NumElements()
			if (pars.Trans == linalg.PNoTrans &&
				sizeA < ind.OffsetA+(ind.K-1)*arows+ind.N) ||
				(pars.Trans != linalg.PNoTrans &&
					sizeA < ind.OffsetA+(ind.N-1)*arows+ind.K) {
				return onError("sizeA")
			}
		}

		// C matrix
		if ind.OffsetC < 0 {
			return onError("offsetC illegal, <0")
		}
		if ind.LDc == 0 {
			ind.LDc = max(1, C.LeadingIndex())
			crows = max(1, C.Rows())
		}
		if ind.LDc < max(1, ind.N) {
			return onError("ldC")
		}
		sizeC := C.NumElements()
		if sizeC < ind.OffsetC+(ind.N-1)*crows+ind.N {
			return onError("sizeC")
		}

	case fsyr2k:
		// C is N x N; A and B are both N x K when pars.Trans is PNoTrans,
		// K x N otherwise.
		if ind.N < 0 {
			if pars.Trans == linalg.PNoTrans {
				ind.N = A.Rows()
				if ind.N != B.Rows() {
					return onError("dimensions of A and B do not match")
				}
			} else {
				ind.N = A.Cols()
				if ind.N != B.Cols() {
					return onError("dimensions of A and B do not match")
				}
			}
		}
		if ind.N == 0 {
			return nil
		}
		if ind.K < 0 {
			if pars.Trans == linalg.PNoTrans {
				ind.K = A.Cols()
				if ind.K != B.Cols() {
					return onError("dimensions of A and B do not match")
				}
			} else {
				ind.K = A.Rows()
				if ind.K != B.Rows() {
					return onError("dimensions of A and B do not match")
				}
			}
		}
		// A matrix
		if ind.OffsetA < 0 {
			return onError("offsetA illegal, <0")
		}
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		if ind.K > 0 {
			if (pars.Trans == linalg.PNoTrans && ind.LDa < max(1, ind.N)) ||
				(pars.Trans != linalg.PNoTrans && ind.LDa < max(1, ind.K)) {
				return onError("inconsistent ldA")
			}
			sizeA := A.NumElements()
			if (pars.Trans == linalg.PNoTrans &&
				sizeA < ind.OffsetA+(ind.K-1)*arows+ind.N) ||
				(pars.Trans != linalg.PNoTrans &&
					sizeA < ind.OffsetA+(ind.N-1)*arows+ind.K) {
				return onError("sizeA")
			}
		}
		// B matrix
		if ind.OffsetB < 0 {
			return onError("offsetB illegal, <0")
		}
		if ind.LDb == 0 {
			ind.LDb = max(1, B.LeadingIndex())
			brows = max(1, B.Rows())
		}
		if ind.K > 0 {
			if (pars.Trans == linalg.PNoTrans && ind.LDb < max(1, ind.N)) ||
				(pars.Trans != linalg.PNoTrans && ind.LDb < max(1, ind.K)) {
				return onError("ldB")
			}
			sizeB := B.NumElements()
			if (pars.Trans == linalg.PNoTrans &&
				sizeB < ind.OffsetB+(ind.K-1)*brows+ind.N) ||
				(pars.Trans != linalg.PNoTrans &&
					sizeB < ind.OffsetB+(ind.N-1)*brows+ind.K) {
				return onError("sizeB")
			}
		}
		// C matrix
		if ind.OffsetC < 0 {
			return onError("offsetC illegal, <0")
		}
		if ind.LDc == 0 {
			ind.LDc = max(1, C.LeadingIndex())
			crows = max(1, C.Rows())
		}
		if ind.LDc < max(1, ind.N) {
			return onError("ldC")
		}
		sizeC := C.NumElements()
		if sizeC < ind.OffsetC+(ind.N-1)*crows+ind.N {
			return onError("sizeC")
		}
	}
	return nil
}
예제 #16
0
파일: sytrs.go 프로젝트: jvlmdr/linalg
/*
 Solves a real or complex symmetric set of linear equations,
 given the LDL^T factorization computed by sytrf() or sysv().

 PURPOSE
 Solves
  A*X = B

 where A is real or complex symmetric and n by n,
 and B is n by nrhs.  On entry, A and ipiv contain the
 factorization of A as returned by Sytrf() or Sysv().  On exit, B is
 replaced by the solution.

 ARGUMENTS
  A         float or complex matrix
  B         float or complex matrix.  Must have the same type as A.
  ipiv      int vector of length at least n holding the pivot indices
            returned by Sytrf() or Sysv().  Required; nil is rejected.

 OPTIONS
  uplo      PLower or PUpper
  n         nonnegative integer.  If negative, the default value is used.
  nrhs      nonnegative integer.  If negative, the default value is used.
  ldA       nonnegative integer.  ldA >= max(1,n).  If zero, the default
            value is used.
  ldB       nonnegative integer.  ldB >= max(1,n).  If zero, the
            default value is used.
  offsetA   nonnegative integer
  offsetB   nonnegative integer

 Only float matrices are currently supported; complex matrices return an
 error.
*/
func Sytrs(A, B matrix.Matrix, ipiv []int32, opts ...linalg.Option) error {
	pars, err := linalg.GetParameters(opts...)
	if err != nil {
		return err
	}
	ind := linalg.GetIndexOpts(opts...)
	// Column strides for the size checks: the explicit ld value when one is
	// given, otherwise the matrix row count (set when ld is defaulted below).
	arows := ind.LDa
	brows := ind.LDb
	if ind.N < 0 {
		ind.N = A.Rows()
		if ind.N != A.Cols() {
			return onError("Sytrs: A not square")
		}
	}
	if ind.Nrhs < 0 {
		ind.Nrhs = B.Cols()
	}
	if ind.N == 0 || ind.Nrhs == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDa < max(1, ind.N) {
		return onError("Sytrs: ldA")
	}
	if ind.LDb == 0 {
		ind.LDb = max(1, B.LeadingIndex())
		brows = max(1, B.Rows())
	}
	if ind.LDb < max(1, ind.N) {
		return onError("Sytrs: ldB")
	}
	if ind.OffsetA < 0 {
		return onError("Sytrs: offsetA")
	}
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
		return onError("Sytrs: sizeA")
	}
	if ind.OffsetB < 0 {
		return onError("Sytrs: offsetB")
	}
	sizeB := B.NumElements()
	if sizeB < ind.OffsetB+(ind.Nrhs-1)*brows+ind.N {
		return onError("Sytrs: sizeB")
	}
	// dsytrs always reads n pivot indices, so ipiv is mandatory.  A nil
	// ipiv has length 0 and is rejected here instead of reaching LAPACK.
	if len(ipiv) < ind.N {
		return onError("Sytrs: size ipiv")
	}
	if !matrix.EqualTypes(A, B) {
		return onError("Sytrs: arguments not of same type")
	}
	info := -1
	switch Am := A.(type) {
	case *matrix.FloatMatrix:
		Aa := Am.FloatArray()
		// B has the same concrete type as A (checked above).
		Ba := B.(*matrix.FloatMatrix).FloatArray()
		uplo := linalg.ParamString(pars.Uplo)
		info = dsytrs(uplo, ind.N, ind.Nrhs, Aa[ind.OffsetA:], ind.LDa, ipiv,
			Ba[ind.OffsetB:], ind.LDb)
	case *matrix.ComplexMatrix:
		return onError("Sytrs: complex not yet implemented")
	}
	if info != 0 {
		return onError(fmt.Sprintf("Sytrs lapack error: %d", info))
	}
	return nil
}
예제 #17
0
파일: ormqr.go 프로젝트: jvlmdr/linalg
/*
 Product with a real orthogonal matrix.

 PURPOSE

 Computes
  C := Q*C   if side = PLeft  and trans = PNoTrans
  C := Q^T*C if side = PLeft  and trans = PTrans
  C := C*Q   if side = PRight and trans = PNoTrans
  C := C*Q^T if side = PRight and trans = PTrans

 C is m by n and Q is a square orthogonal matrix computed by geqrf.

 Q is defined as a product of k elementary reflectors, stored as
 the first k columns of A and the first k entries of tau.

 ARGUMENTS
  A         float matrix
  tau       float matrix of length at least k
  C         float matrix

 OPTIONS
  side      PLeft or PRight
  trans     PNoTrans or PTrans
  m         integer.  If negative, the default value is used.
  n         integer.  If negative, the default value is used.
  k         integer.  k <= m if side = PLeft and k <= n if side = PRight.
            If negative, the default value (the length of tau) is used.
  ldA       nonnegative integer.  ldA >= max(1,m) if side = PLeft
            and ldA >= max(1,n) if side = PRight.  If zero, the
            default value is used.
  ldC       nonnegative integer.  ldC >= max(1,m).  If zero, the
            default value is used.
  offsetA   nonnegative integer
  offsetC   nonnegative integer

 Only float matrices are currently supported; complex matrices return an
 error.
*/
func Ormqr(A, tau, C matrix.Matrix, opts ...linalg.Option) error {
	pars, err := linalg.GetParameters(opts...)
	if err != nil {
		return err
	}
	ind := linalg.GetIndexOpts(opts...)
	// Column strides for the size checks: the explicit ld value when one is
	// given, otherwise the matrix row count (set when ld is defaulted below).
	arows := ind.LDa
	crows := ind.LDc
	if ind.N < 0 {
		ind.N = C.Cols()
	}
	if ind.M < 0 {
		ind.M = C.Rows()
	}
	if ind.K < 0 {
		ind.K = tau.NumElements()
	}
	if ind.N == 0 || ind.M == 0 || ind.K == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	if ind.LDc == 0 {
		ind.LDc = max(1, C.LeadingIndex())
		crows = max(1, C.Rows())
	}
	// Q acts on the rows of C (order m) from the left and on its columns
	// (order n) from the right; k and ldA are bounded by that order.
	switch pars.Side {
	case linalg.PLeft:
		if ind.K > ind.M {
			return onError("Ormqr: K")
		}
		if ind.LDa < max(1, ind.M) {
			return onError("Ormqr: ldA")
		}
	case linalg.PRight:
		if ind.K > ind.N {
			return onError("Ormqr: K")
		}
		if ind.LDa < max(1, ind.N) {
			return onError("Ormqr: ldA")
		}
	}
	if ind.LDc < max(1, ind.M) {
		return onError("Ormqr: ldC")
	}
	if ind.OffsetA < 0 {
		return onError("Ormqr: offsetA")
	}
	if A.NumElements() < ind.OffsetA+ind.K*arows {
		return onError("Ormqr: sizeA")
	}
	if ind.OffsetC < 0 {
		return onError("Ormqr: offsetC")
	}
	if C.NumElements() < ind.OffsetC+(ind.N-1)*crows+ind.M {
		return onError("Ormqr: sizeC")
	}
	if tau.NumElements() < ind.K {
		return onError("Ormqr: sizeTau")
	}
	if !matrix.EqualTypes(A, C, tau) {
		return onError("Ormqr: arguments not of same type")
	}
	info := -1
	side := linalg.ParamString(pars.Side)
	trans := linalg.ParamString(pars.Trans)
	switch Am := A.(type) {
	case *matrix.FloatMatrix:
		Aa := Am.FloatArray()
		// C and tau have the same concrete type as A (checked above).
		Ca := C.(*matrix.FloatMatrix).FloatArray()
		taua := tau.(*matrix.FloatMatrix).FloatArray()
		info = dormqr(side, trans, ind.M, ind.N, ind.K, Aa[ind.OffsetA:], ind.LDa,
			taua, Ca[ind.OffsetC:], ind.LDc)
	case *matrix.ComplexMatrix:
		return onError("Ormqr: complex not implemented yet")
	}
	if info != 0 {
		return onError(fmt.Sprintf("Ormqr: lapack error %d", info))
	}
	return nil
}
예제 #18
0
파일: syevx.go 프로젝트: jvlmdr/linalg
// SyevxFloat computes selected eigenvalues and, optionally, eigenvectors of
// the real symmetric n by n matrix A by calling LAPACK dsyevx.
//
// Eigenvalues are written to W starting at ind.OffsetW.  When pars.Jobz is
// linalg.PJobValue, Z is required and dsyevx writes into its storage starting
// at ind.OffsetZ (the eigenvectors, per LAPACK); otherwise Z is ignored.
//
// pars.Range selects which eigenvalues are wanted: with linalg.PRangeValue,
// vlimit must hold [vl, vu] with vl < vu; with linalg.PRangeInt, ilimit must
// hold [il, iu] with 1 <= il <= iu <= n.
//
// A, W and (when used) Z must be *matrix.FloatMatrix values.
//
// NOTE(review): abstol is accepted but is not currently passed to dsyevx —
// confirm against dsyevx's signature.
func SyevxFloat(A, W, Z matrix.Matrix, abstol float64, vlimit []float64, ilimit []int, opts ...linalg.Option) error {
	var vl, vu float64
	var il, iu int

	pars, err := linalg.GetParameters(opts...)
	if err != nil {
		return err
	}
	ind := linalg.GetIndexOpts(opts...)
	// Column stride for the size check: the explicit ldA when given,
	// otherwise A's row count (set when ldA is defaulted below).
	arows := ind.LDa
	if ind.N < 0 {
		ind.N = A.Rows()
		if ind.N != A.Cols() {
			return onError("Syevx: A not square")
		}
	}
	// Check indexes
	if ind.N == 0 {
		return nil
	}
	if ind.LDa == 0 {
		ind.LDa = max(1, A.LeadingIndex())
		arows = max(1, A.Rows())
	}
	// ldA is bounded by the order n actually used, not by A's full size.
	if ind.LDa < max(1, ind.N) {
		return onError("Syevx: lda")
	}
	switch pars.Range {
	case linalg.PRangeValue:
		if vlimit == nil {
			return onError("Syevx: vlimit is nil")
		}
		if len(vlimit) < 2 {
			return onError("Syevx: vlimit must hold [vl, vu]")
		}
		vl, vu = vlimit[0], vlimit[1]
		if vl >= vu {
			return onError("Syevx: must be: vl < vu")
		}
	case linalg.PRangeInt:
		if ilimit == nil {
			return onError("Syevx: ilimit is nil")
		}
		if len(ilimit) < 2 {
			return onError("Syevx: ilimit must hold [il, iu]")
		}
		il, iu = ilimit[0], ilimit[1]
		if il < 1 || il > iu || iu > ind.N {
			return onError("Syevx: must be:1 <= il <= iu <= N")
		}
	}
	if pars.Jobz == linalg.PJobValue {
		if Z == nil {
			return onError("Syevx: Z is nil")
		}
		if ind.LDz == 0 {
			ind.LDz = max(1, Z.LeadingIndex())
		}
		if ind.LDz < max(1, ind.N) {
			return onError("Syevx: ldz")
		}
	} else {
		// Z is not referenced; LAPACK still requires ldz >= 1.
		if ind.LDz == 0 {
			ind.LDz = 1
		}
		if ind.LDz < 1 {
			return onError("Syevx: ldz")
		}
	}
	if ind.OffsetA < 0 {
		return onError("Syevx: OffsetA")
	}
	sizeA := A.NumElements()
	if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
		return onError("Syevx: sizeA")
	}
	if ind.OffsetW < 0 {
		return onError("Syevx: OffsetW")
	}
	sizeW := W.NumElements()
	if sizeW < ind.OffsetW+ind.N {
		return onError("Syevx: sizeW")
	}
	if pars.Jobz == linalg.PJobValue {
		if ind.OffsetZ < 0 {
			return onError("Syevx: OffsetZ")
		}
		zrows := max(1, Z.Rows())
		minZ := ind.OffsetZ + (ind.N-1)*zrows + ind.N
		if pars.Range == linalg.PRangeInt {
			// Only iu-il+1 columns are needed for an index range.
			minZ = ind.OffsetZ + (iu-il)*zrows + ind.N
		}
		if Z.NumElements() < minZ {
			return onError("Syevx: sizeZ")
		}
	}

	Af, okA := A.(*matrix.FloatMatrix)
	Wf, okW := W.(*matrix.FloatMatrix)
	if !okA || !okW {
		return onError("Syevx: A and W must be float matrices")
	}
	var Za []float64
	if pars.Jobz == linalg.PJobValue {
		Zf, ok := Z.(*matrix.FloatMatrix)
		if !ok {
			return onError("Syevx: Z must be a float matrix")
		}
		// Honour the validated offset, as is done for A and W.
		Za = Zf.FloatArray()[ind.OffsetZ:]
	}
	Aa := Af.FloatArray()
	Wa := Wf.FloatArray()
	jobz := linalg.ParamString(pars.Jobz)
	rnge := linalg.ParamString(pars.Range)
	uplo := linalg.ParamString(pars.Uplo)

	info := dsyevx(jobz, rnge, uplo, ind.N, Aa[ind.OffsetA:], ind.LDa,
		vl, vu, il, iu, ind.M, Wa[ind.OffsetW:], Za, ind.LDz)
	if info != 0 {
		return onError(fmt.Sprintf("Syevx: call failed %d", info))
	}
	return nil
}