func checkGbtrs(ind *linalg.IndexOpts, A, B matrix.Matrix, ipiv []int32) error { if ind.Kl < 0 { return errors.New("Gbtrs: invalid kl") } if ind.N < 0 { ind.N = A.Rows() } if ind.Nrhs < 0 { ind.Nrhs = A.Cols() } if ind.N == 0 || ind.Nrhs == 0 { return nil } if ind.Ku < 0 { ind.Ku = A.Rows() - 2*ind.Kl - 1 } if ind.Ku < 0 { return errors.New("Gbtrs: invalid ku") } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.LDa < 2*ind.Kl+ind.Ku+1 { return errors.New("Gbtrs: lda") } if ind.OffsetA < 0 { return errors.New("Gbtrs: offsetA") } sizeA := A.NumElements() if sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+2*ind.Kl+ind.Ku+1 { return errors.New("Gbtrs: sizeA") } if ind.LDb == 0 { ind.LDb = max(1, B.Rows()) } if ind.OffsetB < 0 { return errors.New("Gbtrs: offsetB") } sizeB := B.NumElements() if sizeB < ind.OffsetB+(ind.Nrhs-1)*ind.LDb+ind.N { return errors.New("Gbtrs: sizeB") } if ipiv != nil && len(ipiv) < ind.N { return errors.New("Gbtrs: size ipiv") } return nil }
func checkGbtrf(ind *linalg.IndexOpts, A matrix.Matrix, ipiv []int32) error { if ind.M < 0 { return errors.New("Gbtrf: illegal m") } if ind.Kl < 0 { return errors.New("GBtrf: illegal kl") } if ind.N < 0 { ind.N = A.Rows() } if ind.M == 0 || ind.N == 0 { return nil } if ind.Ku < 0 { ind.Ku = A.Rows() - 2*ind.Kl - 1 } if ind.Ku < 0 { return errors.New("Gbtrf: invalid ku") } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.LDa < 2*ind.Kl+ind.Ku+1 { return errors.New("Gbtrf: lda") } if ind.OffsetA < 0 { return errors.New("Gbtrf: offsetA") } sizeA := A.NumElements() if sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+2*ind.Kl+ind.Ku+1 { return errors.New("Gbtrf: sizeA") } if ipiv != nil && len(ipiv) < min(ind.N, ind.M) { return errors.New("Gbtrf: size ipiv") } return nil }
func check_level2_func(ind *linalg.IndexOpts, fn funcNum, X, Y, A matrix.Matrix, pars *linalg.Parameters) error { if ind.IncX <= 0 { return errors.New("incX") } if ind.IncY <= 0 { return errors.New("incY") } sizeA := A.NumElements() switch fn { case fgemv: // general matrix if ind.M < 0 { ind.M = A.Rows() } if ind.N < 0 { ind.N = A.Cols() } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.OffsetA < 0 { return errors.New("offsetA") } if ind.N > 0 && ind.M > 0 && sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+ind.M { return errors.New("sizeA") } if ind.OffsetX < 0 { return errors.New("offsetX") } if ind.OffsetY < 0 { return errors.New("offsetY") } sizeX := X.NumElements() sizeY := Y.NumElements() if pars.Trans == linalg.PNoTrans { if ind.N > 0 && sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 { return errors.New("sizeX") } if ind.M > 0 && sizeY < ind.OffsetY+(ind.M-1)*abs(ind.IncY)+1 { return errors.New("sizeY") } } else { if ind.M > 0 && sizeX < ind.OffsetX+(ind.M-1)*abs(ind.IncX)+1 { return errors.New("sizeX") } if ind.N > 0 && sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 { return errors.New("sizeY") } } case fger: if ind.M < 0 { ind.M = A.Rows() } if ind.N < 0 { ind.N = A.Cols() } if ind.M == 0 || ind.N == 0 { return nil } if ind.M > 0 && ind.N > 0 { if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.LDa < max(1, ind.M) { return errors.New("ldA") } if ind.OffsetA < 0 { return errors.New("offsetA") } if sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+ind.M { return errors.New("sizeA") } if ind.OffsetX < 0 { return errors.New("offsetX") } if ind.OffsetY < 0 { return errors.New("offsetY") } sizeX := X.NumElements() if sizeX < ind.OffsetX+(ind.M-1)*abs(ind.IncX)+1 { return errors.New("sizeX") } sizeY := Y.NumElements() if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 { return errors.New("sizeY") } } case fgbmv: // general banded if ind.M < 0 { ind.M = A.Rows() } if ind.N < 0 { ind.N = A.Cols() } if ind.Kl < 0 { return errors.New("kl") } if ind.Ku < 0 { ind.Ku = A.Rows() - 1 - ind.Kl } if ind.Ku < 0 { return errors.New("ku") } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.LDa < ind.Kl+ind.Ku+1 { return errors.New("ldA") } if ind.OffsetA < 0 { return errors.New("offsetA") } sizeA := A.NumElements() if ind.N > 0 && ind.M > 0 && sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+ind.Kl+ind.Ku+1 { return errors.New("sizeA") } if ind.OffsetX < 0 { return errors.New("offsetX") } if ind.OffsetY < 0 { return errors.New("offsetY") } sizeX := X.NumElements() sizeY := Y.NumElements() if pars.Trans == linalg.PNoTrans { if ind.N > 0 && sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 { return errors.New("sizeX") } if ind.N > 0 && sizeY < ind.OffsetY+(ind.M-1)*abs(ind.IncY)+1 { return errors.New("sizeY") } } else { if ind.N > 0 && sizeX < ind.OffsetX+(ind.M-1)*abs(ind.IncX)+1 { return errors.New("sizeX") } if ind.N > 0 && sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 { return errors.New("sizeY") } } case ftrmv, ftrsv: // ftrmv = triangular // ftrsv = triangular solve if ind.N < 0 { if A.Rows() != A.Cols() { return errors.New("A not square") } ind.N = A.Rows() } if ind.N > 0 { if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.LDa < max(1, ind.N) { return errors.New("ldA") } if ind.OffsetA < 0 { return errors.New("offsetA") } sizeA := A.NumElements() if sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+ind.N { return errors.New("sizeA") } sizeX := X.NumElements() if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 { return errors.New("sizeX") } } case ftbmv, ftbsv, fsbmv: // ftbmv = triangular banded // ftbsv = triangular banded solve // fsbmv = symmetric banded product if ind.N < 0 { ind.N = A.Rows() } if ind.N > 0 { if ind.K < 0 { ind.K = max(0, A.Rows()-1) } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.LDa < ind.K+1 { return errors.New("ldA") } if ind.OffsetA < 0 { return errors.New("offsetA") } sizeA := A.NumElements() if sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+ind.K+1 { return errors.New("sizeA") } sizeX := X.NumElements() if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 { return errors.New("sizeX") } if Y != nil { sizeY := Y.NumElements() if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 { return errors.New("sizeY") } } } case fsymv, fsyr, fsyr2: // fsymv = symmetric product // fsyr = symmetric rank update // fsyr2 = symmetric rank-2 update if ind.N < 0 { if A.Rows() != A.Cols() { return errors.New("A not square") } ind.N = A.Rows() } if ind.N > 0 { if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.LDa < max(1, ind.N) { return errors.New("ldA") } if ind.OffsetA < 0 { return errors.New("offsetA") } sizeA := A.NumElements() if sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+ind.N { return errors.New("sizeA") } if ind.OffsetX < 0 { return errors.New("offsetX") } sizeX := X.NumElements() if sizeX < ind.OffsetX+(ind.N-1)*abs(ind.IncX)+1 { return errors.New("sizeX") } if Y != nil { if ind.OffsetY < 0 { return errors.New("offsetY") } sizeY := Y.NumElements() if sizeY < ind.OffsetY+(ind.N-1)*abs(ind.IncY)+1 { return errors.New("sizeY") } } } case fspr, fdspr2, ftpsv, fspmv, ftpmv: // ftpsv = triangular packed solve // fspmv = symmetric packed product // ftpmv = triangular packed } return nil }