Пример #1
0
// evalArithmeticOp evaluates the arithmetic operators DIV, /, -, %, MOD, + and *.
//
// Both operands are obtained through get2 and unwrapped with types.RawData.
// If either operand is nil the result is nil with no error. Otherwise the
// operands are coerced with coerceArithmetic2 and dispatched to the evaluator
// matching o.Op; an opcode outside the arithmetic set produces an error.
//
// See https://dev.mysql.com/doc/refman/5.7/en/arithmetic-functions.html#operator_divide
func (o *BinaryOperation) evalArithmeticOp(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
	lhs, rhs, err := o.get2(ctx, args)
	if err != nil {
		return nil, err
	}

	lhs, rhs = types.RawData(lhs), types.RawData(rhs)
	if lhs == nil || rhs == nil {
		// A nil operand makes the whole expression nil.
		return nil, nil
	}

	lhs, rhs, err = o.coerceArithmetic2(lhs, rhs)
	if err != nil {
		return nil, o.traceErr(err)
	}

	switch o.Op {
	case opcode.Plus:
		return o.evalPlus(lhs, rhs)
	case opcode.Minus:
		return o.evalMinus(lhs, rhs)
	case opcode.Mul:
		return o.evalMul(lhs, rhs)
	case opcode.Div:
		return o.evalDiv(lhs, rhs)
	case opcode.Mod:
		return o.evalMod(lhs, rhs)
	case opcode.IntDiv:
		return o.evalIntDiv(lhs, rhs)
	}

	// Reached only when o.Op is not one of the arithmetic opcodes above.
	return nil, o.errorf("invalid op %v in arithmetic operation", o.Op)
}
Пример #2
0
// Next pulls the next row from the wrapped executor and converts it into an
// *oplan.Row.
//
// When the executor reports an error or yields a nil row, Next returns a nil
// row together with errors.Trace of the executor's error. Every data value is
// unwrapped with types.RawData, and bool values are stored as uint8 1 or 0.
// Row keys are copied into freshly allocated oplan.RowKeyEntry values.
func (a *recordsetAdapter) Next() (*oplan.Row, error) {
	row, err := a.executor.Next()
	if err != nil || row == nil {
		return nil, errors.Trace(err)
	}

	converted := &oplan.Row{
		Data:    make([]interface{}, len(row.Data)),
		RowKeys: make([]*oplan.RowKeyEntry, len(row.RowKeys)),
	}

	for i, field := range row.Data {
		raw := types.RawData(field)
		flag, isBool := raw.(bool)
		switch {
		case !isBool:
			converted.Data[i] = raw
		case flag:
			// Convert bool field to int
			converted.Data[i] = uint8(1)
		default:
			converted.Data[i] = uint8(0)
		}
	}

	for i, key := range row.RowKeys {
		converted.RowKeys[i] = &oplan.RowKeyEntry{
			Key: key.Key,
			Tbl: key.Tbl,
		}
	}

	return converted, nil
}
Пример #3
0
// NewValueExpr creates a ValueExpr that stores the raw form of value (via
// types.RawData) and a default field type chosen from value's dynamic type.
//
// It panics if value is not one of the literal types handled below.
func NewValueExpr(value interface{}) *ValueExpr {
	ve := &ValueExpr{}
	ve.Data = types.RawData(value)
	// TODO: make it more precise.
	switch lit := value.(type) {
	case nil:
		ve.Type = types.NewFieldType(mysql.TypeNull)
	case bool, int64:
		ve.Type = types.NewFieldType(mysql.TypeLonglong)
	case uint64:
		ft := types.NewFieldType(mysql.TypeLonglong)
		ft.Flag |= mysql.UnsignedFlag
		ve.Type = ft
	case string, UnquoteString:
		ft := types.NewFieldType(mysql.TypeVarchar)
		ft.Charset = mysql.DefaultCharset
		ft.Collate = mysql.DefaultCollationName
		ve.Type = ft
	case float64:
		ve.Type = types.NewFieldType(mysql.TypeDouble)
	case []byte:
		ft := types.NewFieldType(mysql.TypeBlob)
		ft.Charset = "binary"
		ft.Collate = "binary"
		ve.Type = ft
	case mysql.Bit:
		ve.Type = types.NewFieldType(mysql.TypeBit)
	case mysql.Hex:
		ft := types.NewFieldType(mysql.TypeVarchar)
		ft.Charset = "binary"
		ft.Collate = "binary"
		ve.Type = ft
	case *types.DataItem:
		// A DataItem carries its own field type; reuse it.
		ve.Type = lit.Type
	default:
		panic(fmt.Sprintf("illegal literal value type:%T", value))
	}
	return ve
}
Пример #4
0
// Eval is a helper that evaluates expression v with the given ctx and env and
// returns the result unwrapped by types.RawData. It panics with the
// evaluation error if v.Eval fails.
func Eval(v Expression, ctx context.Context, env map[interface{}]interface{}) (y interface{}) {
	result, err := v.Eval(ctx, env)
	if err != nil {
		panic(err) // panic ok here
	}
	return types.RawData(result)
}
Пример #5
0
// NewValueExpr creates a ValueExpr that stores the raw form of value (via
// types.RawData) and a default field type.
//
// An UnquoteString gets a varchar type with the default charset and
// collation; every other value gets types.DefaultTypeForValue(value).
func NewValueExpr(value interface{}) *ValueExpr {
	ve := &ValueExpr{}
	ve.Data = types.RawData(value)

	switch value.(type) {
	case UnquoteString:
		ft := types.NewFieldType(mysql.TypeVarchar)
		ft.Charset = mysql.DefaultCharset
		ft.Collate = mysql.DefaultCollationName
		ve.Type = ft
	default:
		ve.Type = types.DefaultTypeForValue(value)
	}
	return ve
}
Пример #6
0
// handleArithmeticOp evaluates the binary arithmetic expression o from the
// values already set on its left and right operands and stores the result on
// o with SetValue.
//
// Each operand is unwrapped with types.RawData and passed through
// coerceArithmetic, then both are aligned with types.Coerce. If either
// operand is nil afterwards the result is nil. It returns true on success; on
// failure the error is recorded in e.err and false is returned.
func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool {
	lhs, err := coerceArithmetic(types.RawData(o.L.GetValue()))
	if err != nil {
		e.err = errors.Trace(err)
		return false
	}

	rhs, err := coerceArithmetic(types.RawData(o.R.GetValue()))
	if err != nil {
		e.err = errors.Trace(err)
		return false
	}

	lhs, rhs = types.Coerce(lhs, rhs)
	if lhs == nil || rhs == nil {
		o.SetValue(nil)
		return true
	}

	var result interface{}
	switch o.Op {
	case opcode.IntDiv:
		result, e.err = computeIntDiv(lhs, rhs)
	case opcode.Mod:
		result, e.err = computeMod(lhs, rhs)
	case opcode.Div:
		result, e.err = computeDiv(lhs, rhs)
	case opcode.Mul:
		result, e.err = computeMul(lhs, rhs)
	case opcode.Minus:
		result, e.err = computeMinus(lhs, rhs)
	case opcode.Plus:
		result, e.err = computePlus(lhs, rhs)
	default:
		e.err = ErrInvalidOperation.Gen("invalid op %v in arithmetic operation", o.Op)
		return false
	}

	// The value is stored even when the computation reported an error; the
	// return value tells the caller whether it can be trusted.
	o.SetValue(result)
	if e.err != nil {
		return false
	}
	return true
}
Пример #7
0
// calculateSum adds v to the running aggregate sum and returns the new sum.
// It is used for avg and sum calculation.
//
// avg and sum use decimal for integer and decimal types and float for all
// others, see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
//
// v is unwrapped with types.RawData first. A nil v leaves sum unchanged. A
// nil sum is replaced by the converted value. An error is returned when v
// cannot be converted, when sum is neither nil, float64 nor mysql.Decimal, or
// when the converted value and sum are of different numeric kinds (float64
// vs. mysql.Decimal) and therefore cannot be added.
func calculateSum(sum interface{}, v interface{}) (interface{}, error) {
	var (
		data interface{}
		err  error
	)

	v = types.RawData(v)
	switch y := v.(type) {
	case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
		data, err = mysql.ConvertToDecimal(v)
	case mysql.Decimal:
		data = y
	case nil:
		data = nil
	default:
		data, err = types.ToFloat64(v)
	}

	if err != nil {
		return nil, err
	}
	if data == nil {
		return sum, nil
	}
	data = types.RawData(data)
	switch x := sum.(type) {
	case nil:
		return data, nil
	case float64:
		// Checked assertion: a Decimal value must not panic a float sum.
		f, ok := data.(float64)
		if !ok {
			return nil, errors.Errorf("invalid value %v(%T) for float64 aggregate", data, data)
		}
		return x + f, nil
	case mysql.Decimal:
		// Checked assertion: a float value must not panic a Decimal sum.
		d, ok := data.(mysql.Decimal)
		if !ok {
			return nil, errors.Errorf("invalid value %v(%T) for decimal aggregate", data, data)
		}
		return x.Add(d), nil
	default:
		return nil, errors.Errorf("invalid value %v(%T) for aggregate", x, x)
	}
}
Пример #8
0
// FastEval evaluates a Value or a static unary +/- expression and returns
// its value.
//
// v is unwrapped with types.RawData first. A Value yields its Val, an int64
// or uint64 is returned as is, and a *UnaryOperation is evaluated only when
// its operator is + or - and IsStatic reports true. Anything else yields nil.
func FastEval(v interface{}) interface{} {
	raw := types.RawData(v)
	switch x := raw.(type) {
	case Value:
		return x.Val
	case int64, uint64:
		return x
	case *UnaryOperation:
		isSign := x.Op == opcode.Plus || x.Op == opcode.Minus
		if !isSign || !x.IsStatic() {
			return nil
		}
		// Evaluate with an empty environment and no context.
		return Eval(x, nil, make(map[interface{}]interface{}))
	}
	return nil
}
Пример #9
0
// Eval implements the Expression Eval interface.
//
// It evaluates the wrapped expression, unwraps the result with
// types.RawData and casts it to f.Tp with types.Cast. The result is returned
// as a *types.DataItem whose Type is f.Tp. Errors from evaluation or casting
// are returned wrapped by errors.Trace.
func (f *FunctionCast) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
	raw, err := f.Expr.Eval(ctx, args)
	if err != nil {
		return nil, errors.Trace(err)
	}

	item := &types.DataItem{Type: f.Tp}
	raw = types.RawData(raw)
	if raw == nil {
		// Casting nil to any type returns null; item.Data keeps its zero value.
		return item, nil
	}

	casted, err := types.Cast(raw, f.Tp)
	if err != nil {
		return nil, errors.Trace(err)
	}
	item.Data = casted

	return item, nil
}
Пример #10
0
// getDefaultValue computes the default value described by constraint c for a
// column of type tp with fractional-seconds precision fsp.
//
// For timestamp and datetime columns the value comes from
// expression.GetTimeValue: nil means `default null`, a mysql.Time is
// returned as its string form, and any other value is returned unchanged.
// For every other column type the constraint expression is evaluated with
// expression.FastEval and unwrapped with types.RawData.
func getDefaultValue(c *ConstraintOpt, tp byte, fsp int) (interface{}, error) {
	if tp != mysql.TypeTimestamp && tp != mysql.TypeDatetime {
		return types.RawData(expression.FastEval(c.Evalue)), nil
	}

	timeVal, err := expression.GetTimeValue(nil, c.Evalue, tp, fsp)
	if err != nil {
		return nil, errors.Trace(err)
	}

	switch t := timeVal.(type) {
	case nil:
		// Value is nil means `default null`.
		return nil, nil
	case mysql.Time:
		// If value is mysql.Time, convert it to string.
		return t.String(), nil
	}
	return timeVal, nil
}
Пример #11
0
// builtinAbs returns the absolute value of args[0].
//
// The argument is unwrapped with types.RawData. nil yields nil, unsigned
// integers and non-negative signed integers are returned unchanged, and
// negative signed integers are returned negated as int64. Any other type is
// converted with types.ToFloat64 and passed through math.Abs.
//
// An error is returned when the argument is the minimum int64 value, whose
// absolute value does not fit in an int64, or when the float conversion
// fails. ctx is not used.
func builtinAbs(args []interface{}, ctx map[interface{}]interface{}) (v interface{}, err error) {
	d := types.RawData(args[0])
	switch x := d.(type) {
	case nil:
		return nil, nil
	case uint, uint8, uint16, uint32, uint64:
		return x, nil
	case int, int8, int16, int32, int64:
		// we don't need to handle error here, it must be success
		n, _ := types.ToInt64(d)
		if n >= 0 {
			return x, nil
		}
		if n == math.MinInt64 {
			// -MinInt64 overflows back to MinInt64; report it instead of
			// returning a negative "absolute value".
			return nil, errors.Errorf("BIGINT value is out of range in 'abs(%d)'", n)
		}
		return -n, nil
	default:
		// we will try to convert other types to float
		// TODO: if time has no precision, it will be a integer
		f, err := types.ToFloat64(d)
		return math.Abs(f), errors.Trace(err)
	}
}
Пример #12
0
// unaryOperation evaluates the unary expression u (NOT, bitwise negation,
// unary plus and unary minus) from the value already set on its operand and
// stores the result on u with SetValue.
//
// The operand is unwrapped with types.RawData; a nil operand produces a nil
// result. It returns true on success. On failure the error is recorded in
// e.err and false is returned. A panic raised during evaluation is recovered
// and recorded in e.err as well, in which case the function returns false.
func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
	defer func() {
		if er := recover(); er != nil {
			e.err = errors.Errorf("%v", er)
		}
	}()
	a := u.V.GetValue()
	a = types.RawData(a)
	if a == nil {
		u.SetValue(nil)
		return true
	}
	switch op := u.Op; op {
	case opcode.Not:
		n, err := types.ToBool(a)
		if err != nil {
			e.err = errors.Trace(err)
			return false
		}
		if n == 0 {
			u.SetValue(int64(1))
		} else {
			u.SetValue(int64(0))
		}
	case opcode.BitNeg:
		// for bit operation, we will use int64 first, then return uint64
		n, err := types.ToInt64(a)
		if err != nil {
			e.err = errors.Trace(err)
			return false
		}
		u.SetValue(uint64(^n))
	case opcode.Plus:
		switch x := a.(type) {
		case bool:
			u.SetValue(boolToInt64(x))
		case float32, float64,
			int, int8, int16, int32, int64,
			uint, uint8, uint16, uint32, uint64,
			mysql.Duration, mysql.Time, string, mysql.Decimal, []byte,
			mysql.Hex, mysql.Bit, mysql.Enum, mysql.Set:
			// Unary plus leaves every supported operand unchanged.
			u.SetValue(x)
		default:
			e.err = ErrInvalidOperation
			return false
		}
	case opcode.Minus:
		switch x := a.(type) {
		case bool:
			if x {
				u.SetValue(int64(-1))
			} else {
				u.SetValue(int64(0))
			}
		case float32:
			u.SetValue(-x)
		case float64:
			u.SetValue(-x)
		case int:
			u.SetValue(-x)
		case int8:
			u.SetValue(-x)
		case int16:
			u.SetValue(-x)
		case int32:
			u.SetValue(-x)
		case int64:
			u.SetValue(-x)
		case uint:
			u.SetValue(-int64(x))
		case uint8:
			u.SetValue(-int64(x))
		case uint16:
			u.SetValue(-int64(x))
		case uint32:
			u.SetValue(-int64(x))
		case uint64:
			// TODO: check overflow and do more test for unsigned type
			u.SetValue(-int64(x))
		case mysql.Duration:
			u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber()))
		case mysql.Time:
			u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber()))
		case string:
			f, err := types.StrToFloat(x)
			if err != nil {
				e.err = errors.Trace(err)
				return false
			}
			u.SetValue(-f)
		case mysql.Decimal:
			// NOTE(review): negating through float64 may lose precision for
			// decimals that float64 cannot represent exactly; confirm
			// whether an exact decimal negation is available.
			f, _ := x.Float64()
			u.SetValue(mysql.NewDecimalFromFloat(-f))
		case []byte:
			f, err := types.StrToFloat(string(x))
			if err != nil {
				e.err = errors.Trace(err)
				return false
			}
			u.SetValue(-f)
		case mysql.Hex:
			u.SetValue(-x.ToNumber())
		case mysql.Bit:
			u.SetValue(-x.ToNumber())
		case mysql.Enum:
			u.SetValue(-x.ToNumber())
		case mysql.Set:
			u.SetValue(-x.ToNumber())
		default:
			e.err = ErrInvalidOperation
			return false
		}
	default:
		e.err = ErrInvalidOperation
		return false
	}

	return true
}
Пример #13
0
// getTimeValue converts v into a time value of type tp with
// fractional-seconds precision fsp.
//
// Supported inputs:
//   - string: CurrentTimestamp yields the system timestamp obtained from ctx,
//     ZeroTimestamp yields the time parsed from the number 0, any other
//     string is parsed with mysql.ParseTime.
//   - Value: a string or int64 payload is parsed; a nil payload returns
//     (nil, nil); any other payload returns errDefaultValue.
//   - *Ident: CurrentTimeExpr returns the CurrentTimestamp string; any other
//     identifier returns errDefaultValue.
//   - *UnaryOperation: the expression (e.g. `-1`) is evaluated, converted to
//     an integer and parsed as a numeric time.
//
// Any other input type returns (nil, nil). Errors are wrapped with
// errors.Trace.
func getTimeValue(ctx context.Context, v interface{}, tp byte, fsp int) (interface{}, error) {
	value := mysql.Time{
		Type: tp,
		Fsp:  fsp,
	}

	defaultTime, err := getSystemTimestamp(ctx)
	if err != nil {
		return nil, errors.Trace(err)
	}

	switch x := v.(type) {
	case string:
		if x == CurrentTimestamp {
			value.Time = defaultTime
		} else if x == ZeroTimestamp {
			// NOTE(review): the error is deliberately discarded here;
			// presumably parsing 0 cannot fail — confirm.
			value, _ = mysql.ParseTimeFromNum(0, tp, fsp)
		} else {
			value, err = mysql.ParseTime(x, tp, fsp)
			if err != nil {
				return nil, errors.Trace(err)
			}
		}
	case Value:
		switch xval := types.RawData(x.Val).(type) {
		case string:
			value, err = mysql.ParseTime(xval, tp, fsp)
			if err != nil {
				return nil, errors.Trace(err)
			}
		case int64:
			value, err = mysql.ParseTimeFromNum(xval, tp, fsp)
			if err != nil {
				return nil, errors.Trace(err)
			}
		case nil:
			return nil, nil
		default:
			return nil, errors.Trace(errDefaultValue)
		}
	case *Ident:
		if x.Equal(CurrentTimeExpr) {
			return CurrentTimestamp, nil
		}

		return nil, errors.Trace(errDefaultValue)
	case *UnaryOperation:
		// support some expression, like `-1`
		//
		// Call x.Eval directly rather than the panicking Eval helper so an
		// evaluation failure is returned as an error instead of crashing.
		m := map[interface{}]interface{}{}
		ev, err := x.Eval(nil, m)
		if err != nil {
			return nil, errors.Trace(err)
		}
		ft := types.NewFieldType(mysql.TypeLonglong)
		xval, err := types.Convert(types.RawData(ev), ft)
		if err != nil {
			return nil, errors.Trace(err)
		}

		// Guard the assertion: a non-int64 result (e.g. nil) is not a
		// usable default value.
		num, ok := xval.(int64)
		if !ok {
			return nil, errors.Trace(errDefaultValue)
		}

		value, err = mysql.ParseTimeFromNum(num, tp, fsp)
		if err != nil {
			return nil, errors.Trace(err)
		}
	default:
		return nil, nil
	}

	return value, nil
}
Пример #14
0
// Eval implements the Expression Eval interface.
//
// It evaluates the operand u.V and applies the unary operator u.Op (NOT,
// bitwise negation, unary plus or unary minus). A nil operand yields
// (nil, nil). Operand types the operator does not support are reported
// through types.UndOp.
//
// The operand is evaluated with the panicking Eval helper; any panic raised
// while evaluating, including the one for an unknown opcode, is recovered
// and returned as an error.
func (u *UnaryOperation) Eval(ctx context.Context, args map[interface{}]interface{}) (r interface{}, err error) {
	defer func() {
		if e := recover(); e != nil {
			r, err = nil, errors.Errorf("%v", e)
		}
	}()

	// operand evaluates u.V and unwraps the result. It is invoked inside each
	// case so that an unknown opcode still panics without evaluating u.V.
	operand := func() interface{} {
		return types.RawData(Eval(u.V, ctx, args))
	}

	switch op := u.Op; op {
	case opcode.Not:
		a := operand()
		if a == nil {
			return nil, nil
		}
		n, convErr := types.ToBool(a)
		if convErr != nil {
			return types.UndOp(a, op)
		}
		if n == 0 {
			return int64(1), nil
		}
		return int64(0), nil
	case opcode.BitNeg:
		a := operand()
		if a == nil {
			return nil, nil
		}
		// for bit operation, we will use int64 first, then return uint64
		n, convErr := types.ToInt64(a)
		if convErr != nil {
			return types.UndOp(a, op)
		}

		return uint64(^n), nil
	case opcode.Plus:
		a := operand()
		if a == nil {
			return nil, nil
		}
		switch x := a.(type) {
		case bool:
			if x {
				return int64(1), nil
			}
			return int64(0), nil
		case float32, float64,
			int, int8, int16, int32, int64,
			uint, uint8, uint16, uint32, uint64,
			mysql.Duration, mysql.Time, string, mysql.Decimal, []byte,
			mysql.Hex, mysql.Bit, mysql.Enum, mysql.Set:
			// Unary plus leaves every supported operand unchanged.
			return x, nil
		default:
			return types.UndOp(a, op)
		}
	case opcode.Minus:
		a := operand()
		if a == nil {
			return nil, nil
		}

		switch x := a.(type) {
		case bool:
			if x {
				return int64(-1), nil
			}
			return int64(0), nil
		case float32:
			return -x, nil
		case float64:
			return -x, nil
		case int:
			return -x, nil
		case int8:
			return -x, nil
		case int16:
			return -x, nil
		case int32:
			return -x, nil
		case int64:
			return -x, nil
		case uint:
			return -int64(x), nil
		case uint8:
			return -int64(x), nil
		case uint16:
			return -int64(x), nil
		case uint32:
			return -int64(x), nil
		case uint64:
			// TODO: check overflow and do more test for unsigned type
			return -int64(x), nil
		case mysql.Duration:
			return mysql.ZeroDecimal.Sub(x.ToNumber()), nil
		case mysql.Time:
			return mysql.ZeroDecimal.Sub(x.ToNumber()), nil
		case string:
			f, convErr := types.StrToFloat(x)
			return -f, convErr
		case mysql.Decimal:
			// NOTE(review): negating through float64 may lose precision for
			// decimals that float64 cannot represent exactly; confirm
			// whether an exact decimal negation is available.
			f, _ := x.Float64()
			return mysql.NewDecimalFromFloat(-f), nil
		case []byte:
			f, convErr := types.StrToFloat(string(x))
			return -f, convErr
		case mysql.Hex:
			return -x.ToNumber(), nil
		case mysql.Bit:
			return -x.ToNumber(), nil
		case mysql.Enum:
			return -x.ToNumber(), nil
		case mysql.Set:
			return -x.ToNumber(), nil
		default:
			return types.UndOp(a, op)
		}
	default:
		panic("should never happen")
	}
}