// Operator: DIV / - % MOD + * // See https://dev.mysql.com/doc/refman/5.7/en/arithmetic-functions.html#operator_divide func (o *BinaryOperation) evalArithmeticOp(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { a, b, err := o.get2(ctx, args) if err != nil { return nil, err } a = types.RawData(a) b = types.RawData(b) if a == nil || b == nil { return nil, nil } if a, b, err = o.coerceArithmetic2(a, b); err != nil { return nil, o.traceErr(err) } switch o.Op { case opcode.Plus: return o.evalPlus(a, b) case opcode.Minus: return o.evalMinus(a, b) case opcode.Mul: return o.evalMul(a, b) case opcode.Div: return o.evalDiv(a, b) case opcode.Mod: return o.evalMod(a, b) case opcode.IntDiv: return o.evalIntDiv(a, b) default: return nil, o.errorf("invalid op %v in arithmetic operation", o.Op) } }
func (a *recordsetAdapter) Next() (*oplan.Row, error) { row, err := a.executor.Next() if err != nil || row == nil { return nil, errors.Trace(err) } oRow := &oplan.Row{ Data: make([]interface{}, len(row.Data)), RowKeys: make([]*oplan.RowKeyEntry, 0, len(row.RowKeys)), } for i, v := range row.Data { d := types.RawData(v) switch v := d.(type) { case bool: // Convert bool field to int if v { oRow.Data[i] = uint8(1) } else { oRow.Data[i] = uint8(0) } default: oRow.Data[i] = d } } for _, v := range row.RowKeys { oldRowKey := &oplan.RowKeyEntry{ Key: v.Key, Tbl: v.Tbl, } oRow.RowKeys = append(oRow.RowKeys, oldRowKey) } return oRow, nil }
// NewValueExpr creates a ValueExpr with value, and sets default field type. func NewValueExpr(value interface{}) *ValueExpr { ve := &ValueExpr{} ve.Data = types.RawData(value) // TODO: make it more precise. switch value.(type) { case nil: ve.Type = types.NewFieldType(mysql.TypeNull) case bool, int64: ve.Type = types.NewFieldType(mysql.TypeLonglong) case uint64: ve.Type = types.NewFieldType(mysql.TypeLonglong) ve.Type.Flag |= mysql.UnsignedFlag case string, UnquoteString: ve.Type = types.NewFieldType(mysql.TypeVarchar) ve.Type.Charset = mysql.DefaultCharset ve.Type.Collate = mysql.DefaultCollationName case float64: ve.Type = types.NewFieldType(mysql.TypeDouble) case []byte: ve.Type = types.NewFieldType(mysql.TypeBlob) ve.Type.Charset = "binary" ve.Type.Collate = "binary" case mysql.Bit: ve.Type = types.NewFieldType(mysql.TypeBit) case mysql.Hex: ve.Type = types.NewFieldType(mysql.TypeVarchar) ve.Type.Charset = "binary" ve.Type.Collate = "binary" case *types.DataItem: ve.Type = value.(*types.DataItem).Type default: panic(fmt.Sprintf("illegal literal value type:%T", value)) } return ve }
// Eval is a helper function evaluates expression v and do a panic if evaluating error. func Eval(v Expression, ctx context.Context, env map[interface{}]interface{}) (y interface{}) { var err error y, err = v.Eval(ctx, env) if err != nil { panic(err) // panic ok here } y = types.RawData(y) return }
// NewValueExpr creates a ValueExpr with value, and sets default field type. func NewValueExpr(value interface{}) *ValueExpr { ve := &ValueExpr{} ve.Data = types.RawData(value) if _, ok := value.(UnquoteString); ok { ve.Type = types.NewFieldType(mysql.TypeVarchar) ve.Type.Charset = mysql.DefaultCharset ve.Type.Collate = mysql.DefaultCollationName return ve } ve.Type = types.DefaultTypeForValue(value) return ve }
func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { a, err := coerceArithmetic(types.RawData(o.L.GetValue())) if err != nil { e.err = errors.Trace(err) return false } b, err := coerceArithmetic(types.RawData(o.R.GetValue())) if err != nil { e.err = errors.Trace(err) return false } a, b = types.Coerce(a, b) if a == nil || b == nil { o.SetValue(nil) return true } var result interface{} switch o.Op { case opcode.Plus: result, e.err = computePlus(a, b) case opcode.Minus: result, e.err = computeMinus(a, b) case opcode.Mul: result, e.err = computeMul(a, b) case opcode.Div: result, e.err = computeDiv(a, b) case opcode.Mod: result, e.err = computeMod(a, b) case opcode.IntDiv: result, e.err = computeIntDiv(a, b) default: e.err = ErrInvalidOperation.Gen("invalid op %v in arithmetic operation", o.Op) return false } o.SetValue(result) return e.err == nil }
func calculateSum(sum interface{}, v interface{}) (interface{}, error) { // for avg and sum calculation // avg and sum use decimal for integer and decimal type, use float for others // see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html var ( data interface{} err error ) v = types.RawData(v) switch y := v.(type) { case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64: data, err = mysql.ConvertToDecimal(v) case mysql.Decimal: data = y case nil: data = nil default: data, err = types.ToFloat64(v) } if err != nil { return nil, err } if data == nil { return sum, nil } data = types.RawData(data) switch x := sum.(type) { case nil: return data, nil case float64: return x + data.(float64), nil case mysql.Decimal: return x.Add(data.(mysql.Decimal)), nil default: return nil, errors.Errorf("invalid value %v(%T) for aggregate", x, x) } }
// FastEval evaluates Value and static +/- Unary expression and returns its value. func FastEval(v interface{}) interface{} { v = types.RawData(v) switch x := v.(type) { case Value: return x.Val case int64, uint64: return v case *UnaryOperation: if x.Op != opcode.Plus && x.Op != opcode.Minus { return nil } if !x.IsStatic() { return nil } m := map[interface{}]interface{}{} return Eval(x, nil, m) default: return nil } }
// Eval implements the Expression Eval interface. func (f *FunctionCast) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { value, err := f.Expr.Eval(ctx, args) if err != nil { return nil, errors.Trace(err) } value = types.RawData(value) d := &types.DataItem{Type: f.Tp} // Casting nil to any type returns null if value == nil { d.Data = nil return d, nil } d.Data, err = types.Cast(value, f.Tp) if err != nil { return nil, errors.Trace(err) } return d, nil }
func getDefaultValue(c *ConstraintOpt, tp byte, fsp int) (interface{}, error) { if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { value, err := expression.GetTimeValue(nil, c.Evalue, tp, fsp) if err != nil { return nil, errors.Trace(err) } // Value is nil means `default null`. if value == nil { return nil, nil } // If value is mysql.Time, convert it to string. if vv, ok := value.(mysql.Time); ok { return vv.String(), nil } return value, nil } v := expression.FastEval(c.Evalue) return types.RawData(v), nil }
func builtinAbs(args []interface{}, ctx map[interface{}]interface{}) (v interface{}, err error) { d := types.RawData(args[0]) switch x := d.(type) { case nil: return nil, nil case uint, uint8, uint16, uint32, uint64: return x, nil case int, int8, int16, int32, int64: // we don't need to handle error here, it must be success v, _ := types.ToInt64(d) if v >= 0 { return x, nil } // TODO: handle overflow if x is MinInt64 return -v, nil default: // we will try to convert other types to float // TODO: if time has no precision, it will be a integer f, err := types.ToFloat64(d) return math.Abs(f), errors.Trace(err) } }
func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { defer func() { if er := recover(); er != nil { e.err = errors.Errorf("%v", er) } }() a := u.V.GetValue() a = types.RawData(a) if a == nil { u.SetValue(nil) return true } switch op := u.Op; op { case opcode.Not: n, err := types.ToBool(a) if err != nil { e.err = errors.Trace(err) } else if n == 0 { u.SetValue(int64(1)) } else { u.SetValue(int64(0)) } case opcode.BitNeg: // for bit operation, we will use int64 first, then return uint64 n, err := types.ToInt64(a) if err != nil { e.err = errors.Trace(err) return false } u.SetValue(uint64(^n)) case opcode.Plus: switch x := a.(type) { case bool: u.SetValue(boolToInt64(x)) case float32: u.SetValue(+x) case float64: u.SetValue(+x) case int: u.SetValue(+x) case int8: u.SetValue(+x) case int16: u.SetValue(+x) case int32: u.SetValue(+x) case int64: u.SetValue(+x) case uint: u.SetValue(+x) case uint8: u.SetValue(+x) case uint16: u.SetValue(+x) case uint32: u.SetValue(+x) case uint64: u.SetValue(+x) case mysql.Duration: u.SetValue(x) case mysql.Time: u.SetValue(x) case string: u.SetValue(x) case mysql.Decimal: u.SetValue(x) case []byte: u.SetValue(x) case mysql.Hex: u.SetValue(x) case mysql.Bit: u.SetValue(x) case mysql.Enum: u.SetValue(x) case mysql.Set: u.SetValue(x) default: e.err = ErrInvalidOperation return false } case opcode.Minus: switch x := a.(type) { case bool: if x { u.SetValue(int64(-1)) } else { u.SetValue(int64(0)) } case float32: u.SetValue(-x) case float64: u.SetValue(-x) case int: u.SetValue(-x) case int8: u.SetValue(-x) case int16: u.SetValue(-x) case int32: u.SetValue(-x) case int64: u.SetValue(-x) case uint: u.SetValue(-int64(x)) case uint8: u.SetValue(-int64(x)) case uint16: u.SetValue(-int64(x)) case uint32: u.SetValue(-int64(x)) case uint64: // TODO: check overflow and do more test for unsigned type u.SetValue(-int64(x)) case mysql.Duration: u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber())) case mysql.Time: u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber())) case string: f, err := types.StrToFloat(x) e.err = errors.Trace(err) u.SetValue(-f) case mysql.Decimal: f, _ := x.Float64() u.SetValue(mysql.NewDecimalFromFloat(-f)) case []byte: f, err := types.StrToFloat(string(x)) e.err = errors.Trace(err) u.SetValue(-f) case mysql.Hex: u.SetValue(-x.ToNumber()) case mysql.Bit: u.SetValue(-x.ToNumber()) case mysql.Enum: u.SetValue(-x.ToNumber()) case mysql.Set: u.SetValue(-x.ToNumber()) default: e.err = ErrInvalidOperation return false } default: e.err = ErrInvalidOperation return false } return true }
func getTimeValue(ctx context.Context, v interface{}, tp byte, fsp int) (interface{}, error) { value := mysql.Time{ Type: tp, Fsp: fsp, } defaultTime, err := getSystemTimestamp(ctx) if err != nil { return nil, errors.Trace(err) } switch x := v.(type) { case string: if x == CurrentTimestamp { value.Time = defaultTime } else if x == ZeroTimestamp { value, _ = mysql.ParseTimeFromNum(0, tp, fsp) } else { value, err = mysql.ParseTime(x, tp, fsp) if err != nil { return nil, errors.Trace(err) } } case Value: x.Val = types.RawData(x.Val) switch xval := x.Val.(type) { case string: value, err = mysql.ParseTime(xval, tp, fsp) if err != nil { return nil, errors.Trace(err) } case int64: value, err = mysql.ParseTimeFromNum(int64(xval), tp, fsp) if err != nil { return nil, errors.Trace(err) } case nil: return nil, nil default: return nil, errors.Trace(errDefaultValue) } case *Ident: if x.Equal(CurrentTimeExpr) { return CurrentTimestamp, nil } return nil, errors.Trace(errDefaultValue) case *UnaryOperation: // support some expression, like `-1` m := map[interface{}]interface{}{} v := Eval(x, nil, m) ft := types.NewFieldType(mysql.TypeLonglong) xval, err := types.Convert(v, ft) if err != nil { return nil, errors.Trace(err) } value, err = mysql.ParseTimeFromNum(xval.(int64), tp, fsp) if err != nil { return nil, errors.Trace(err) } default: return nil, nil } return value, nil }
// Eval implements the Expression Eval interface. func (u *UnaryOperation) Eval(ctx context.Context, args map[interface{}]interface{}) (r interface{}, err error) { defer func() { if e := recover(); e != nil { r, err = nil, errors.Errorf("%v", e) } }() switch op := u.Op; op { case opcode.Not: a := Eval(u.V, ctx, args) a = types.RawData(a) if a == nil { return } n, err := types.ToBool(a) if err != nil { return types.UndOp(a, op) } else if n == 0 { return int64(1), nil } return int64(0), nil case opcode.BitNeg: a := Eval(u.V, ctx, args) a = types.RawData(a) if a == nil { return } // for bit operation, we will use int64 first, then return uint64 n, err := types.ToInt64(a) if err != nil { return types.UndOp(a, op) } return uint64(^n), nil case opcode.Plus: a := Eval(u.V, ctx, args) a = types.RawData(a) if a == nil { return } switch x := a.(type) { case nil: return nil, nil case bool: if x { return int64(1), nil } return int64(0), nil case float32: return +x, nil case float64: return +x, nil case int: return +x, nil case int8: return +x, nil case int16: return +x, nil case int32: return +x, nil case int64: return +x, nil case uint: return +x, nil case uint8: return +x, nil case uint16: return +x, nil case uint32: return +x, nil case uint64: return +x, nil case mysql.Duration: return x, nil case mysql.Time: return x, nil case string: return x, nil case mysql.Decimal: return x, nil case []byte: return x, nil case mysql.Hex: return x, nil case mysql.Bit: return x, nil case mysql.Enum: return x, nil case mysql.Set: return x, nil default: return types.UndOp(a, op) } case opcode.Minus: a := Eval(u.V, ctx, args) a = types.RawData(a) if a == nil { return } switch x := a.(type) { case nil: return nil, nil case bool: if x { return int64(-1), nil } return int64(0), nil case float32: return -x, nil case float64: return -x, nil case int: return -x, nil case int8: return -x, nil case int16: return -x, nil case int32: return -x, nil case int64: return -x, nil case uint: return -int64(x), nil case uint8: return -int64(x), nil case uint16: return -int64(x), nil case uint32: return -int64(x), nil case uint64: // TODO: check overflow and do more test for unsigned type return -int64(x), nil case mysql.Duration: return mysql.ZeroDecimal.Sub(x.ToNumber()), nil case mysql.Time: return mysql.ZeroDecimal.Sub(x.ToNumber()), nil case string: f, err := types.StrToFloat(x) return -f, err case mysql.Decimal: f, _ := x.Float64() return mysql.NewDecimalFromFloat(-f), nil case []byte: f, err := types.StrToFloat(string(x)) return -f, err case mysql.Hex: return -x.ToNumber(), nil case mysql.Bit: return -x.ToNumber(), nil case mysql.Enum: return -x.ToNumber(), nil case mysql.Set: return -x.ToNumber(), nil default: return types.UndOp(a, op) } default: panic("should never happen") } }