コード例 #1
0
ファイル: indexcheck.go プロジェクト: sguzwf/algorithm
func check_level2_func(ind *linalg.IndexOpts, fn funcNum, X, Y, A matrix.Matrix, pars *linalg.Parameters) error {
	if ind.IncX <= 0 {
		return onError("incX")
	}
	if ind.IncY <= 0 {
		return onError("incY")
	}

	sizeA := A.NumElements()
	arows := ind.LDa
	switch fn {
	case fgemv: // general matrix
		if ind.M < 0 {
			ind.M = A.Rows()
		}
		if ind.N < 0 {
			ind.N = A.Cols()
		}
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		if ind.OffsetA < 0 {
			return onError("offsetA")
		}
		if ind.N > 0 && ind.M > 0 &&
			sizeA < ind.OffsetA+(ind.N-1)*arows+ind.M {
			return onError("sizeA")
		}
		if ind.OffsetX < 0 {
			return onError("offsetX")
		}
		if ind.OffsetY < 0 {
			return onError("offsetY")
		}
		sizeX := X.NumElements()
		sizeY := Y.NumElements()
		if pars.Trans == linalg.PNoTrans {
			if ind.N > 0 && sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			if ind.M > 0 && sizeY < ind.OffsetY+(ind.M-1)*abs(ind.IncY)+1 {
				return onError("sizeY")
			}
		} else {
			if ind.M > 0 && sizeX < ind.OffsetX+(ind.M-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			if ind.N > 0 && sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 {
				return onError("sizeY")
			}
		}
	case fger:
		if ind.M < 0 {
			ind.M = A.Rows()
		}
		if ind.N < 0 {
			ind.N = A.Cols()
		}
		if ind.M == 0 || ind.N == 0 {
			return nil
		}
		if ind.M > 0 && ind.N > 0 {
			if ind.LDa == 0 {
				ind.LDa = max(1, A.LeadingIndex())
				arows = max(1, A.Rows())
			}
			if ind.LDa < max(1, ind.M) {
				return onError("ldA")
			}
			if ind.OffsetA < 0 {
				return onError("offsetA")
			}
			if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.M {
				return onError("sizeA")
			}
			if ind.OffsetX < 0 {
				return onError("offsetX")
			}
			if ind.OffsetY < 0 {
				return onError("offsetY")
			}
			sizeX := X.NumElements()
			if sizeX < ind.OffsetX+(ind.M-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			sizeY := Y.NumElements()
			if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 {
				return onError("sizeY")
			}
		}
	case fgbmv: // general banded
		if ind.M < 0 {
			ind.M = A.Rows()
		}
		if ind.N < 0 {
			ind.N = A.Cols()
		}
		if ind.Kl < 0 {
			return onError("kl")
		}
		if ind.Ku < 0 {
			ind.Ku = A.Rows() - 1 - ind.Kl
		}
		if ind.Ku < 0 {
			return onError("ku")
		}
		if ind.LDa == 0 {
			ind.LDa = max(1, A.LeadingIndex())
			arows = max(1, A.Rows())
		}
		if ind.LDa < ind.Kl+ind.Ku+1 {
			return onError("ldA")
		}
		if ind.OffsetA < 0 {
			return onError("offsetA")
		}
		sizeA := A.NumElements()
		if ind.N > 0 && ind.M > 0 &&
			sizeA < ind.OffsetA+(ind.N-1)*arows+ind.Kl+ind.Ku+1 {
			return onError("sizeA")
		}
		if ind.OffsetX < 0 {
			return onError("offsetX")
		}
		if ind.OffsetY < 0 {
			return onError("offsetY")
		}
		sizeX := X.NumElements()
		sizeY := Y.NumElements()
		if pars.Trans == linalg.PNoTrans {
			if ind.N > 0 && sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			if ind.N > 0 && sizeY < ind.OffsetY+(ind.M-1)*abs(ind.IncY)+1 {
				return onError("sizeY")
			}
		} else {
			if ind.N > 0 && sizeX < ind.OffsetX+(ind.M-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			if ind.N > 0 && sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 {
				return onError("sizeY")
			}
		}
	case ftrmv, ftrsv:
		// ftrmv = triangular
		// ftrsv = triangular solve
		if ind.N < 0 {
			if A.Rows() != A.Cols() {
				return onError("A not square")
			}
			ind.N = A.Rows()
		}
		if ind.N > 0 {
			if ind.LDa == 0 {
				ind.LDa = max(1, A.LeadingIndex())
				arows = max(1, A.Rows())
			}
			if ind.LDa < max(1, ind.N) {
				return onError("ldA")
			}
			if ind.OffsetA < 0 {
				return onError("offsetA")
			}
			sizeA := A.NumElements()
			if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
				return onError("sizeA")
			}
			sizeX := X.NumElements()
			if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
		}
	case ftbmv, ftbsv, fsbmv:
		// ftbmv = triangular banded
		// ftbsv = triangular banded solve
		// fsbmv = symmetric banded product
		arows := ind.LDa
		if ind.N < 0 {
			ind.N = A.Rows()
		}
		if ind.N > 0 {
			if ind.K < 0 {
				ind.K = max(0, A.Rows()-1)
			}
			if ind.LDa == 0 {
				ind.LDa = max(1, A.LeadingIndex())
				arows = max(1, A.Rows())
			}
			if ind.LDa < ind.K+1 {
				return onError("ldA")
			}
			if ind.OffsetA < 0 {
				return onError("offsetA")
			}
			sizeA := A.NumElements()
			if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.K+1 {
				return onError("sizeA")
			}
			sizeX := X.NumElements()
			if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			if Y != nil {
				sizeY := Y.NumElements()
				if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 {
					return onError("sizeY")
				}
			}
		}
	case fsymv, fsyr, fsyr2:
		// fsymv = symmetric product
		// fsyr = symmetric rank update
		// fsyr2 = symmetric rank-2 update
		if ind.N < 0 {
			if A.Rows() != A.Cols() {
				return onError("A not square")
			}
			ind.N = A.Rows()
		}
		arows := ind.LDa
		if ind.N > 0 {
			if ind.LDa == 0 {
				ind.LDa = max(1, A.LeadingIndex())
				arows = max(1, A.Rows())
			}
			if ind.LDa < max(1, ind.N) {
				return onError("ldA")
			}
			if ind.OffsetA < 0 {
				return onError("offsetA")
			}
			sizeA := A.NumElements()
			if sizeA < ind.OffsetA+(ind.N-1)*arows+ind.N {
				return onError("sizeA")
			}
			if ind.OffsetX < 0 {
				return onError("offsetX")
			}
			sizeX := X.NumElements()
			if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 {
				return onError("sizeX")
			}
			if Y != nil {
				if ind.OffsetY < 0 {
					return onError("offsetY")
				}
				sizeY := Y.NumElements()
				if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 {
					return onError("sizeY")
				}
			}
		}
	case fspr, fdspr2, ftpsv, fspmv, ftpmv:
		// ftpsv = triangular packed solve
		// fspmv = symmetric packed product
		// ftpmv = triangular packed
	}
	return nil
}