func checkGbtrs(ind *linalg.IndexOpts, A, B matrix.Matrix, ipiv []int32) error { if ind.Kl < 0 { return errors.New("Gbtrs: invalid kl") } if ind.N < 0 { ind.N = A.Rows() } if ind.Nrhs < 0 { ind.Nrhs = A.Cols() } if ind.N == 0 || ind.Nrhs == 0 { return nil } if ind.Ku < 0 { ind.Ku = A.Rows() - 2*ind.Kl - 1 } if ind.Ku < 0 { return errors.New("Gbtrs: invalid ku") } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.LDa < 2*ind.Kl+ind.Ku+1 { return errors.New("Gbtrs: lda") } if ind.OffsetA < 0 { return errors.New("Gbtrs: offsetA") } sizeA := A.NumElements() if sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+2*ind.Kl+ind.Ku+1 { return errors.New("Gbtrs: sizeA") } if ind.LDb == 0 { ind.LDb = max(1, B.Rows()) } if ind.OffsetB < 0 { return errors.New("Gbtrs: offsetB") } sizeB := B.NumElements() if sizeB < ind.OffsetB+(ind.Nrhs-1)*ind.LDb+ind.N { return errors.New("Gbtrs: sizeB") } if ipiv != nil && len(ipiv) < ind.N { return errors.New("Gbtrs: size ipiv") } return nil }
func check_level3_func(ind *linalg.IndexOpts, fn funcNum, A, B, C matrix.Matrix, pars *linalg.Parameters) (err error) { switch fn { case fgemm: if ind.M < 0 { if pars.TransA == linalg.PNoTrans { ind.M = A.Rows() } else { ind.M = A.Cols() } } if ind.N < 0 { if pars.TransB == linalg.PNoTrans { ind.N = A.Cols() } else { ind.N = A.Rows() } } if ind.K < 0 { if pars.TransA == linalg.PNoTrans { ind.K = A.Cols() } else { ind.K = A.Rows() } if pars.TransB == linalg.PNoTrans && ind.K != B.Rows() || pars.TransB != linalg.PNoTrans && ind.K != B.Cols() { return errors.New("dimensions of A and B do not match") } } if ind.OffsetB < 0 { return errors.New("offsetB illegal, <0") } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.K > 0 { if (pars.TransA == linalg.PNoTrans && ind.LDa < max(1, ind.M)) || (pars.TransA != linalg.PNoTrans && ind.LDa < max(1, ind.K)) { return errors.New("inconsistent ldA") } sizeA := A.NumElements() if (pars.TransA == linalg.PNoTrans && sizeA < ind.OffsetA+(ind.K-1)*ind.LDa+ind.M) || (pars.TransA != linalg.PNoTrans && sizeA < ind.OffsetA+(ind.M-1)*ind.LDa+ind.K) { return errors.New("sizeA") } } // B matrix if B != nil { if ind.OffsetB < 0 { return errors.New("offsetB illegal, <0") } if ind.LDb == 0 { ind.LDb = max(1, B.Rows()) } if ind.K > 0 { if (pars.TransB == linalg.PNoTrans && ind.LDb < max(1, ind.K)) || (pars.TransB != linalg.PNoTrans && ind.LDb < max(1, ind.N)) { return errors.New("inconsistent ldB") } sizeB := B.NumElements() if (pars.TransB == linalg.PNoTrans && sizeB < ind.OffsetB+(ind.N-1)*ind.LDb+ind.K) || (pars.TransB != linalg.PNoTrans && sizeB < ind.OffsetB+(ind.K-1)*ind.LDb+ind.N) { return errors.New("sizeB") } } } // C matrix if C != nil { if ind.OffsetC < 0 { return errors.New("offsetC illegal, <0") } if ind.LDc == 0 { ind.LDb = max(1, C.Rows()) } if ind.LDc < max(1, ind.M) { return errors.New("inconsistent ldC") } sizeC := C.NumElements() if sizeC < ind.OffsetC+(ind.N-1)*ind.LDc+ind.M { return errors.New("sizeC") } } case fsymm, ftrmm, ftrsm: if ind.M < 0 { ind.M = B.Rows() if pars.Side == linalg.PLeft && (ind.M != A.Rows() || ind.M != A.Cols()) { return errors.New("dimensions of A and B do not match") } } if ind.N < 0 { ind.N = B.Cols() if pars.Side == linalg.PRight && (ind.N != A.Rows() || ind.N != A.Cols()) { return errors.New("dimensions of A and B do not match") } } if ind.M == 0 || ind.N == 0 { return } // check A if ind.OffsetB < 0 { return errors.New("offsetB illegal, <0") } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if pars.Side == linalg.PLeft && ind.LDa < max(1, ind.M) || ind.LDa < max(1, ind.N) { return errors.New("ldA") } sizeA := A.NumElements() if (pars.Side == linalg.PLeft && sizeA < ind.OffsetA+(ind.M-1)*ind.LDa+ind.M) || (pars.Side == linalg.PRight && sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+ind.N) { return errors.New("sizeA") } if B != nil { if ind.OffsetB < 0 { return errors.New("offsetB illegal, <0") } if ind.LDb == 0 { ind.LDb = max(1, B.Rows()) } if ind.LDb < max(1, ind.M) { return errors.New("ldB") } sizeB := B.NumElements() if sizeB < ind.OffsetB+(ind.N-1)*ind.LDb+ind.M { return errors.New("sizeB") } } if C != nil { if ind.OffsetC < 0 { return errors.New("offsetC illegal, <0") } if ind.LDc == 0 { ind.LDc = max(1, C.Rows()) } if ind.LDc < max(1, ind.M) { return errors.New("ldC") } sizeC := C.NumElements() if sizeC < ind.OffsetC+(ind.N-1)*ind.LDc+ind.M { return errors.New("sizeC") } } case fsyrk, fsyr2k: if ind.N < 0 { if pars.Trans == linalg.PNoTrans { ind.N = B.Rows() } else { ind.N = B.Cols() } } if ind.K < 0 { if pars.Trans == linalg.PNoTrans { ind.K = A.Cols() } else { ind.K = A.Rows() } } if ind.N == 0 { return } if ind.LDa == 0 { ind.LDa = max(1, A.Rows()) } if ind.K > 0 { if (pars.Trans == linalg.PNoTrans && ind.LDa < max(1, ind.N)) || (pars.Trans != linalg.PNoTrans && ind.LDa < max(1, ind.K)) { return errors.New("inconsistent ldA") } sizeA := A.NumElements() if (pars.Trans == linalg.PNoTrans && sizeA < ind.OffsetA+(ind.K-1)*ind.LDa+ind.N) || (pars.TransA != linalg.PNoTrans && sizeA < ind.OffsetA+(ind.N-1)*ind.LDa+ind.K) { return errors.New("sizeA") } } if B != nil { if ind.OffsetB < 0 { return errors.New("offsetB illegal, <0") } if ind.LDb == 0 { ind.LDb = max(1, B.Rows()) } if ind.K > 0 { if (pars.Trans == linalg.PNoTrans && ind.LDb < max(1, ind.N)) || (pars.Trans != linalg.PNoTrans && ind.LDb < max(1, ind.K)) { return errors.New("ldB") } sizeB := B.NumElements() if (pars.Trans == linalg.PNoTrans && sizeB < ind.OffsetB+(ind.K-1)*ind.LDb+ind.N) || (pars.Trans != linalg.PNoTrans && sizeB < ind.OffsetB+(ind.N-1)*ind.LDb+ind.K) { return errors.New("sizeB") } } } if C != nil { if ind.OffsetC < 0 { return errors.New("offsetC illegal, <0") } if ind.LDc == 0 { ind.LDc = max(1, C.Rows()) } if ind.LDc < max(1, ind.N) { return errors.New("ldC") } sizeC := C.NumElements() if sizeC < ind.OffsetC+(ind.N-1)*ind.LDc+ind.N { return errors.New("sizeC") } } } err = nil return }