func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) if types.IsNil(a) || types.IsNil(b) { // for <=>, if a and b are both nil, return true. // if a or b is nil, return false. if o.Op == opcode.NullEQ { if types.IsNil(a) || types.IsNil(b) { o.SetValue(oneI64) } else { o.SetValue(zeroI64) } } else { o.SetValue(nil) } return true } n, err := types.Compare(a, b) if err != nil { e.err = errors.Trace(err) return false } r, err := getCompResult(o.Op, n) if err != nil { e.err = errors.Trace(err) return false } if r { o.SetValue(oneI64) } else { o.SetValue(zeroI64) } return true }
// Less implements sort.Interface Less interface. func (t *orderByTable) Less(i, j int) bool { for index, asc := range t.Ascs { v1 := t.Rows[i].Key[index] v2 := t.Rows[j].Key[index] ret, err := types.Compare(v1, v2) if err != nil { // we just have to log this error and skip it. // TODO: record this error and handle it out later. log.Errorf("compare %v %v err %v", v1, v2, err) } if !asc { ret = -ret } if ret < 0 { return true } else if ret > 0 { return false } } return false }
func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []interface{}) error { row, err := t.Row(ctx, handle) if err != nil { return errors.Trace(err) } result := 0 isReplace := false touched := make(map[int]bool, len(row)) for i, val := range row { result, err = types.Compare(val, replaceRow[i]) if err != nil { return errors.Trace(err) } if result != 0 { touched[i] = true isReplace = true } } if isReplace { variable.GetSessionVars(ctx).AddAffectedRows(1) if err = t.UpdateRecord(ctx, handle, row, replaceRow, touched); err != nil { return errors.Trace(err) } } return nil }
func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { var target interface{} = true if v.Value != nil { target = v.Value.GetValue() } if target != nil { for _, val := range v.WhenClauses { cmp, err := types.Compare(target, val.Expr.GetValue()) if err != nil { e.err = errors.Trace(err) return false } if cmp == 0 { v.SetValue(val.Result.GetValue()) return true } } } if v.ElseClause != nil { v.SetValue(v.ElseClause.GetValue()) } else { v.SetValue(nil) } return true }
func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv interface{}, result []interface{}) (interface{}, error) { hasNull := false for _, v := range result { if v == nil { hasNull = true continue } comRes, err := types.Compare(lv, v) if err != nil { return nil, errors.Trace(err) } res, err := getCompResult(cs.Op, comRes) if err != nil { return nil, errors.Trace(err) } if res { return true, nil } } if hasNull { // If no matched but we get null, return null. // Like `insert t (c) values (1),(2),(null)`, then // `select 0 > any (select c from t)`, returns null. return nil, nil } return false, nil }
func (e *Evaluator) checkInList(not bool, in interface{}, list []interface{}) (interface{}, error) { hasNull := false for _, v := range list { if types.IsNil(v) { hasNull = true continue } r, err := types.Compare(in, v) if err != nil { return nil, errors.Trace(err) } if r == 0 { return !not, nil } } if hasNull { // if no matched but we got null in In, return null // e.g 1 in (null, 2, 3) returns null return nil, nil } return not, nil }
// comparison function that takes minNotNullVal and maxVal into account. func indexCompare(a interface{}, b interface{}) int { if a == nil && b == nil { return 0 } else if b == nil { return 1 } else if b == nil { return -1 } // a and b both not nil if a == minNotNullVal && b == minNotNullVal { return 0 } else if b == minNotNullVal { return 1 } else if a == minNotNullVal { return -1 } // a and b both not min value if a == maxVal && b == maxVal { return 0 } else if a == maxVal { return 1 } else if b == maxVal { return -1 } n, err := types.Compare(a, b) if err != nil { // Old compare panics if err, so here we do the same thing now. // TODO: return err instead of panic. panic(fmt.Sprintf("should never happend %v", err)) } return n }
// operator: >=, >, <=, <, !=, <>, = <=>, etc. // see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { //TODO: support <=> later a, b, err := o.get2(ctx, args) if err != nil { return nil, err } if a == nil || b == nil { // TODO: for <=>, if a and b are both nil, return true return nil, nil } n, err := types.Compare(a, b) if err != nil { return nil, o.traceErr(err) } switch o.Op { case opcode.LT: return n < 0, nil case opcode.LE: return n <= 0, nil case opcode.GE: return n >= 0, nil case opcode.GT: return n > 0, nil case opcode.EQ: return n == 0, nil case opcode.NE: return n != 0, nil default: return nil, o.errorf("invalid op %v in comparision operation", o.Op) } }
func (n *PatternIn) checkInList(in interface{}, list []interface{}) (interface{}, error) { hasNull := false for _, v := range list { if types.IsNil(v) { hasNull = true continue } r, err := types.Compare(in, v) if err != nil { return nil, err } if r == 0 { return !n.Not, nil } } if hasNull { // if no matched but we got null in In, return null // e.g 1 in (null, 2, 3) returns null return nil, nil } return n.Not, nil }
// operator: >=, >, <=, <, !=, <>, = <=>, etc. // see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { //TODO: support <=> later a, b, err := o.get2(ctx, args) if err != nil { return nil, err } if a == nil || b == nil { // for <=>, if a and b are both nil, return true. // if a or b is nil, return false. if o.Op == opcode.NullEQ { return a == b, nil } return nil, nil } n, err := types.Compare(a, b) if err != nil { return nil, o.traceErr(err) } r, err := getCompResult(o.Op, n) if err != nil { return nil, o.errorf(err.Error()) } return r, nil }
func (n *AggregateFuncExpr) updateMaxMin(max bool) error { ctx := n.GetContext() if len(n.Args) != 1 { return errors.New("Wrong number of args for AggFuncFirstRow") } v := n.Args[0].GetValue() if !ctx.evaluated { ctx.Value = v ctx.evaluated = true return nil } c, err := types.Compare(ctx.Value, v) if err != nil { return errors.Trace(err) } if max { if c == -1 { ctx.Value = v } } else { if c == 1 { ctx.Value = v } } return nil }
func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { lhs := n.Expr.GetValue() if types.IsNil(lhs) { n.SetValue(nil) return true } hasNull := false for _, v := range n.List { if types.IsNil(v.GetValue()) { hasNull = true continue } r, err := types.Compare(n.Expr.GetValue(), v.GetValue()) if err != nil { e.err = errors.Trace(err) return false } if r == 0 { n.SetValue(boolToInt64(!n.Not)) return true } } if hasNull { // if no matched but we got null in In, return null // e.g 1 in (null, 2, 3) returns null n.SetValue(nil) return true } n.SetValue(boolToInt64(n.Not)) return true }
// comparison function that takes minNotNullVal and maxVal into account. func indexCompare(a interface{}, b interface{}) int { if a == nil && b == nil { return 0 } else if b == nil { return 1 } else if b == nil { return -1 } // a and b both not nil if a == minNotNullVal && b == minNotNullVal { return 0 } else if b == minNotNullVal { return 1 } else if a == minNotNullVal { return -1 } // a and b both not min value if a == maxVal && b == maxVal { return 0 } else if a == maxVal { return 1 } else if b == maxVal { return -1 } return types.Compare(a, b) }
// comparison function that takes minNotNullVal and maxVal into account. func indexColumnCompare(a interface{}, b interface{}) (int, error) { if a == nil && b == nil { return 0, nil } else if b == nil { return 1, nil } else if a == nil { return -1, nil } // a and b both not nil if a == plan.MinNotNullVal && b == plan.MinNotNullVal { return 0, nil } else if b == plan.MinNotNullVal { return 1, nil } else if a == plan.MinNotNullVal { return -1, nil } // a and b both not min value if a == plan.MaxVal && b == plan.MaxVal { return 0, nil } else if a == plan.MaxVal { return 1, nil } else if b == plan.MaxVal { return -1, nil } n, err := types.Compare(a, b) if err != nil { return 0, errors.Trace(err) } return n, nil }
// Update implements AggregationFunction interface. func (mmf *maxMinFunction) Update(row []types.Datum, groupKey []byte, ectx context.Context) error { ctx := mmf.getContext(groupKey) if len(mmf.Args) != 1 { return errors.New("Wrong number of args for AggFuncMaxMin") } a := mmf.Args[0] value, err := a.Eval(row, ectx) if err != nil { return errors.Trace(err) } if !ctx.Evaluated { ctx.Value = value.GetValue() } if value.GetValue() == nil { return nil } var c int c, err = types.Compare(ctx.Value, value.GetValue()) if err != nil { return errors.Trace(err) } if (mmf.isMax && c == -1) || (!mmf.isMax && c == 1) { ctx.Value = value.GetValue() } ctx.Evaluated = true return nil }
func (r *rangeBuilder) buildTableRanges(rangePoints []rangePoint) []TableRange { tableRanges := make([]TableRange, 0, len(rangePoints)/2) for i := 0; i < len(rangePoints); i += 2 { startPoint := rangePoints[i] if startPoint.value == nil || startPoint.value == MinNotNullVal { startPoint.value = math.MinInt64 } startInt, err := types.ToInt64(startPoint.value) if err != nil { r.err = errors.Trace(err) return tableRanges } cmp, err := types.Compare(startInt, startPoint.value) if err != nil { r.err = errors.Trace(err) return tableRanges } if cmp < 0 || (cmp == 0 && startPoint.excl) { startInt++ } endPoint := rangePoints[i+1] if endPoint.value == nil { endPoint.value = math.MinInt64 } else if endPoint.value == MaxVal { endPoint.value = math.MaxInt64 } endInt, err := types.ToInt64(endPoint.value) if err != nil { r.err = errors.Trace(err) return tableRanges } cmp, err = types.Compare(endInt, endPoint.value) if err != nil { r.err = errors.Trace(err) return tableRanges } if cmp > 0 || (cmp == 0 && endPoint.excl) { endInt-- } if startInt > endInt { continue } tableRanges = append(tableRanges, TableRange{LowVal: startInt, HighVal: endInt}) } return tableRanges }
func (r *indexPlan) isPointLookup(span *indexSpan) bool { equalOp := span.lowVal == span.highVal && !span.lowExclude && !span.highExclude if !equalOp || !r.unique || span.lowVal == nil { return false } n, err := types.Compare(span.seekVal, span.lowVal) if err != nil { return false } return n == 0 }
func (r *indexPlan) isPointLookup(span *indexSpan) bool { if span.lowExclude || span.highExclude || span.lowVal == nil || !r.unique { return false } n, err := types.Compare(span.seekVal, span.highVal) if err != nil { return false } // 'seekVal==highVal' means that 'seekVal==lowVal && lowVal==highVal' return n == 0 }
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_nullif func builtinNullIf(args []interface{}, m map[interface{}]interface{}) (interface{}, error) { // nullif(expr1, expr2) // returns null if expr1 = expr2 is true, otherwise returns expr1 v1 := args[0] v2 := args[1] if types.IsNil(v1) || types.IsNil(v2) { return v1, nil } if n, err := types.Compare(v1, v2); err != nil || n == 0 { return nil, err } return v1, nil }
// Less implements sort.Interface Less interface. func (t *orderByTable) Less(i, j int) bool { for index, asc := range t.Ascs { v1 := t.Rows[i].Key[index] v2 := t.Rows[j].Key[index] ret := types.Compare(v1, v2) if !asc { ret = -ret } if ret < 0 { return true } else if ret > 0 { return false } } return false }
func (r *rangePointSorter) Less(i, j int) bool { a := r.points[i] b := r.points[j] if a.value == nil && b.value == nil { return r.equalValueLess(a, b) } else if b.value == nil { return false } else if a.value == nil { return true } // a and b both not nil if a.value == MinNotNullVal && b.value == MinNotNullVal { return r.equalValueLess(a, b) } else if b.value == MinNotNullVal { return false } else if a.value == MinNotNullVal { return true } // a and b both not min value if a.value == MaxVal && b.value == MaxVal { return r.equalValueLess(a, b) } else if a.value == MaxVal { return false } else if b.value == MaxVal { return true } n, err := types.Compare(a.value, b.value) if err != nil { r.err = err return true } if n == 0 { return r.equalValueLess(a, b) } return n < 0 }
// Less implements sort.Interface Less interface. func (e *SortExec) Less(i, j int) bool { for index, by := range e.ByItems { v1 := e.Rows[i].key[index] v2 := e.Rows[j].key[index] ret, err := types.Compare(v1, v2) if err != nil { e.err = err return true } if by.Desc { ret = -ret } if ret < 0 { return true } else if ret > 0 { return false } } return false }
func builtinMin(args []interface{}, ctx map[interface{}]interface{}) (v interface{}, err error) { if _, ok := ctx[ExprEvalArgAggEmpty]; ok { return } fn := ctx[ExprEvalFn] if _, ok := ctx[ExprAggDone]; ok { if v, ok = ctx[fn]; ok { return } return nil, nil } min := ctx[fn] y := args[0] if y == nil { return } // Notice: for min, `nil > non nil` if min == nil { min = y } else { n, err := types.Compare(min, y) if err != nil { return nil, errors.Trace(err) } if n > 0 { min = y } } ctx[fn] = min return }
func builtinMax(args []interface{}, ctx map[interface{}]interface{}) (v interface{}, err error) { if _, ok := ctx[ExprEvalArgAggEmpty]; ok { return } fn := ctx[ExprEvalFn] if _, ok := ctx[ExprAggDone]; ok { if v, ok = ctx[fn]; ok { return } return nil, nil } max := ctx[fn] y := args[0] if types.IsNil(y) { return } // Notice: for max, `nil < non nil` if types.IsNil(max) { max = y } else { n, err := types.Compare(max, y) if err != nil { return nil, errors.Trace(err) } if n < 0 { max = y } } ctx[fn] = max return }
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, updateColumns map[int]*expression.Assignment, evalMap map[interface{}]interface{}, offset int, onDuplicateUpdate bool) error { if err := t.LockRow(ctx, h); err != nil { return errors.Trace(err) } cols := t.Cols() oldData := data newData := make([]interface{}, len(cols)) touched := make(map[int]bool, len(cols)) copy(newData, oldData) assignExists := false var newHandle interface{} for i, asgn := range updateColumns { if i < offset || i >= offset+len(cols) { // The assign expression is for another table, not this. continue } val, err := asgn.Expr.Eval(ctx, evalMap) if err != nil { return errors.Trace(err) } colIndex := i - offset col := cols[colIndex] if col.IsPKHandleColumn(t.Meta()) { newHandle = val } touched[colIndex] = true newData[colIndex] = val assignExists = true } // If no assign list for this table, no need to update. if !assignExists { return nil } // Check whether new value is valid. if err := column.CastValues(ctx, newData, cols); err != nil { return errors.Trace(err) } if err := column.CheckNotNull(cols, newData); err != nil { return errors.Trace(err) } // If row is not changed, we should do nothing. rowChanged := false for i := range oldData { if !touched[i] { continue } n, err := types.Compare(newData[i], oldData[i]) if err != nil { return errors.Trace(err) } if n != 0 { rowChanged = true break } } if !rowChanged { // See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil } var err error if newHandle != nil { err = t.RemoveRecord(ctx, h, oldData) if err != nil { return errors.Trace(err) } _, err = t.AddRecord(ctx, newData) } else { // Update record to new value and update index. err = t.UpdateRecord(ctx, h, oldData, newData, touched) } if err != nil { return errors.Trace(err) } // Record affected rows. if !onDuplicateUpdate { variable.GetSessionVars(ctx).AddAffectedRows(1) } else { variable.GetSessionVars(ctx).AddAffectedRows(2) } return nil }
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, updateColumns map[int]expression.Assignment, m map[interface{}]interface{}, offset int, onDuplicateUpdate bool) error { if err := t.LockRow(ctx, h, true); err != nil { return errors.Trace(err) } oldData := make([]interface{}, len(t.Cols())) touched := make([]bool, len(t.Cols())) copy(oldData, data) cols := t.Cols() assignExists := false for i, asgn := range updateColumns { if i < offset || i >= offset+len(cols) { // The assign expression is for another table, not this. continue } val, err := asgn.Expr.Eval(ctx, m) if err != nil { return err } colIndex := i - offset touched[colIndex] = true data[colIndex] = val assignExists = true } // no assign list for this table, no need to update. if !assignExists { return nil } // Check whether new value is valid. if err := column.CastValues(ctx, data, t.Cols()); err != nil { return err } if err := column.CheckNotNull(t.Cols(), data); err != nil { return err } // If row is not changed, we should do nothing. rowChanged := false for i, d := range data { if !touched[i] { continue } od := oldData[i] n, err := types.Compare(d, od) if err != nil { return errors.Trace(err) } if n != 0 { rowChanged = true break } } if !rowChanged { // See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil } // Update record to new value and update index. err := t.UpdateRecord(ctx, h, oldData, data, touched) if err != nil { return errors.Trace(err) } // Record affected rows. if !onDuplicateUpdate { variable.GetSessionVars(ctx).AddAffectedRows(1) } else { variable.GetSessionVars(ctx).AddAffectedRows(2) } return nil }
func (s *testEvaluatorSuite) TestUnaryOp(c *C) { tbl := []struct { arg interface{} op opcode.Op result interface{} }{ // test NOT. {1, opcode.Not, int64(0)}, {0, opcode.Not, int64(1)}, {nil, opcode.Not, nil}, {mysql.Hex{Value: 0}, opcode.Not, int64(1)}, {mysql.Bit{Value: 0, Width: 1}, opcode.Not, int64(1)}, {mysql.Enum{Name: "a", Value: 1}, opcode.Not, int64(0)}, {mysql.Set{Name: "a", Value: 1}, opcode.Not, int64(0)}, // test BitNeg. {nil, opcode.BitNeg, nil}, {-1, opcode.BitNeg, uint64(0)}, // test Plus. {nil, opcode.Plus, nil}, {float64(1.0), opcode.Plus, float64(1.0)}, {int(1), opcode.Plus, int(1)}, {int64(1), opcode.Plus, int64(1)}, {uint64(1), opcode.Plus, uint64(1)}, {"1.0", opcode.Plus, "1.0"}, {[]byte("1.0"), opcode.Plus, []byte("1.0")}, {mysql.Hex{Value: 1}, opcode.Plus, mysql.Hex{Value: 1}}, {mysql.Bit{Value: 1, Width: 1}, opcode.Plus, mysql.Bit{Value: 1, Width: 1}}, {true, opcode.Plus, int64(1)}, {false, opcode.Plus, int64(0)}, {mysql.Enum{Name: "a", Value: 1}, opcode.Plus, mysql.Enum{Name: "a", Value: 1}}, {mysql.Set{Name: "a", Value: 1}, opcode.Plus, mysql.Set{Name: "a", Value: 1}}, // test Minus. {nil, opcode.Minus, nil}, {float64(1.0), opcode.Minus, float64(-1.0)}, {int(1), opcode.Minus, int(-1)}, {int64(1), opcode.Minus, int64(-1)}, {uint64(1), opcode.Minus, -int64(1)}, {"1.0", opcode.Minus, -1.0}, {[]byte("1.0"), opcode.Minus, -1.0}, {mysql.Hex{Value: 1}, opcode.Minus, -1.0}, {mysql.Bit{Value: 1, Width: 1}, opcode.Minus, -1.0}, {true, opcode.Minus, int64(-1)}, {false, opcode.Minus, int64(0)}, {mysql.Enum{Name: "a", Value: 1}, opcode.Minus, -1.0}, {mysql.Set{Name: "a", Value: 1}, opcode.Minus, -1.0}, } ctx := mock.NewContext() expr := &ast.UnaryOperationExpr{} for _, t := range tbl { expr.Op = t.op expr.V = ast.NewValueExpr(t.arg) result, err := Eval(ctx, expr) c.Assert(err, IsNil) c.Assert(result, DeepEquals, t.result) } tbl = []struct { arg interface{} op opcode.Op result interface{} }{ {mysql.NewDecimalFromInt(1, 0), opcode.Plus, mysql.NewDecimalFromInt(1, 0)}, {mysql.Duration{Duration: time.Duration(838*3600 + 59*60 + 59), Fsp: mysql.DefaultFsp}, opcode.Plus, mysql.Duration{Duration: time.Duration(838*3600 + 59*60 + 59), Fsp: mysql.DefaultFsp}}, {mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}, opcode.Plus, mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}}, {mysql.NewDecimalFromInt(1, 0), opcode.Minus, mysql.NewDecimalFromInt(-1, 0)}, {mysql.ZeroDuration, opcode.Minus, mysql.NewDecimalFromInt(0, 0)}, {mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}, opcode.Minus, mysql.NewDecimalFromInt(-20091110230000, 0)}, } for _, t := range tbl { expr := &ast.UnaryOperationExpr{Op: t.op, V: ast.NewValueExpr(t.arg)} result, err := Eval(ctx, expr) c.Assert(err, IsNil) ret, err := types.Compare(result, t.result) c.Assert(err, IsNil) c.Assert(ret, Equals, 0) } }
func (s *testUnaryOperationSuite) TestUnaryOp(c *C) { tbl := []struct { arg interface{} op opcode.Op result interface{} }{ // test NOT. {1, opcode.Not, int64(0)}, {0, opcode.Not, int64(1)}, {nil, opcode.Not, nil}, {mysql.Hex{Value: 0}, opcode.Not, int64(1)}, {mysql.Bit{Value: 0, Width: 1}, opcode.Not, int64(1)}, // test BitNeg. {nil, opcode.BitNeg, nil}, {-1, opcode.BitNeg, uint64(0)}, // test Plus. {nil, opcode.Plus, nil}, {float32(1.0), opcode.Plus, float32(1.0)}, {float64(1.0), opcode.Plus, float64(1.0)}, {int(1), opcode.Plus, int(1)}, {int64(1), opcode.Plus, int64(1)}, {uint64(1), opcode.Plus, uint64(1)}, {"1.0", opcode.Plus, "1.0"}, {[]byte("1.0"), opcode.Plus, []byte("1.0")}, {mysql.Hex{Value: 1}, opcode.Plus, mysql.Hex{Value: 1}}, {mysql.Bit{Value: 1, Width: 1}, opcode.Plus, mysql.Bit{Value: 1, Width: 1}}, {true, opcode.Plus, int64(1)}, {false, opcode.Plus, int64(0)}, // test Minus. {nil, opcode.Minus, nil}, {float32(1.0), opcode.Minus, float32(-1.0)}, {float64(1.0), opcode.Minus, float64(-1.0)}, {int(1), opcode.Minus, int(-1)}, {int64(1), opcode.Minus, int64(-1)}, {uint(1), opcode.Minus, -int64(1)}, {uint8(1), opcode.Minus, -int64(1)}, {uint16(1), opcode.Minus, -int64(1)}, {uint32(1), opcode.Minus, -int64(1)}, {uint64(1), opcode.Minus, -int64(1)}, {"1.0", opcode.Minus, -1.0}, {[]byte("1.0"), opcode.Minus, -1.0}, {mysql.Hex{Value: 1}, opcode.Minus, -1.0}, {mysql.Bit{Value: 1, Width: 1}, opcode.Minus, -1.0}, {true, opcode.Minus, int64(-1)}, {false, opcode.Minus, int64(0)}, } for _, t := range tbl { expr := NewUnaryOperation(t.op, Value{t.arg}) exprc := expr.Clone() result, err := exprc.Eval(nil, nil) c.Assert(err, IsNil) c.Assert(result, DeepEquals, t.result) } tbl = []struct { arg interface{} op opcode.Op result interface{} }{ {mysql.NewDecimalFromInt(1, 0), opcode.Plus, mysql.NewDecimalFromInt(1, 0)}, {mysql.Duration{Duration: time.Duration(838*3600 + 59*60 + 59), Fsp: mysql.DefaultFsp}, opcode.Plus, mysql.Duration{Duration: time.Duration(838*3600 + 59*60 + 59), Fsp: mysql.DefaultFsp}}, {mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}, opcode.Plus, mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}}, {mysql.NewDecimalFromInt(1, 0), opcode.Minus, mysql.NewDecimalFromInt(-1, 0)}, {mysql.ZeroDuration, opcode.Minus, mysql.NewDecimalFromInt(0, 0)}, {mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}, opcode.Minus, mysql.NewDecimalFromInt(-20091110230000, 0)}, } for _, t := range tbl { expr := NewUnaryOperation(t.op, Value{t.arg}) exprc := expr.Clone() result, err := exprc.Eval(nil, nil) c.Assert(err, IsNil) ret, err := types.Compare(result, t.result) c.Assert(err, IsNil) c.Assert(ret, Equals, 0) } // test String(). strTbl := []struct { expr expression.Expression op opcode.Op isStatic bool }{ {NewBinaryOperation(opcode.EQ, Value{1}, Value{1}), opcode.Plus, true}, {NewUnaryOperation(opcode.Not, Value{1}), opcode.Not, true}, {Value{1}, opcode.Not, true}, {Value{1}, opcode.Plus, true}, {&PExpr{Value{1}}, opcode.Not, true}, {NewBinaryOperation(opcode.EQ, Value{1}, Value{1}), opcode.Not, true}, {NewBinaryOperation(opcode.NE, Value{1}, Value{1}), opcode.Not, true}, {NewBinaryOperation(opcode.GT, Value{1}, Value{1}), opcode.Not, true}, {NewBinaryOperation(opcode.GE, Value{1}, Value{1}), opcode.Not, true}, {NewBinaryOperation(opcode.LT, Value{1}, Value{1}), opcode.Not, true}, {NewBinaryOperation(opcode.LE, Value{1}, Value{1}), opcode.Not, true}, } for _, t := range strTbl { expr := NewUnaryOperation(t.op, t.expr) c.Assert(expr.IsStatic(), Equals, t.isStatic) str := expr.String() c.Assert(len(str), Greater, 0) } // test error. errTbl := []struct { arg interface{} op opcode.Op }{ {mockExpr{}, opcode.Not}, {mockExpr{}, opcode.BitNeg}, {mockExpr{}, opcode.Plus}, {mockExpr{}, opcode.Minus}, {mockExpr{}, opcode.EQ}, } // test error clone expr := NewUnaryOperation(opcode.Not, mockExpr{err: errors.New("must error")}) c.Assert(expr.Clone(), NotNil) for _, t := range errTbl { expr := NewUnaryOperation(t.op, Value{t.arg}) _, err := expr.Eval(nil, nil) c.Assert(err, NotNil) } }
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, tcols []*column.Col, assignList []expressions.Assignment, insertData []interface{}, args map[interface{}]interface{}) error { if err := t.LockRow(ctx, h, true); err != nil { return errors.Trace(err) } oldData := make([]interface{}, len(t.Cols())) touched := make([]bool, len(t.Cols())) copy(oldData, data) // Generate new values m := args if m == nil { m = make(map[interface{}]interface{}, len(t.Cols())) // Set parameter for evaluating expression. for _, col := range t.Cols() { m[col.Name.L] = data[col.Offset] } } if insertData != nil { m[expressions.ExprEvalValuesFunc] = func(name string) (interface{}, error) { return getInsertValue(name, t.Cols(), insertData) } } for i, asgn := range assignList { val, err := asgn.Expr.Eval(ctx, m) if err != nil { return err } colIndex := tcols[i].Offset touched[colIndex] = true data[colIndex] = val } // Check whether new value is valid. if err := column.CastValues(ctx, data, t.Cols()); err != nil { return err } if err := column.CheckNotNull(t.Cols(), data); err != nil { return err } // If row is not changed, we should do nothing. rowChanged := false for i, d := range data { if !touched[i] { continue } od := oldData[i] n, err := types.Compare(d, od) if err != nil { return errors.Trace(err) } if n != 0 { rowChanged = true break } } if !rowChanged { // See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil } // Update record to new value and update index. err := t.UpdateRecord(ctx, h, oldData, data, touched) if err != nil { return errors.Trace(err) } // Record affected rows. if len(insertData) == 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } else { variable.GetSessionVars(ctx).AddAffectedRows(2) } return nil }