func (v *typeInferrer) unaryOperation(x *ast.UnaryOperationExpr) { switch x.Op { case opcode.Not: x.Type = types.NewFieldType(mysql.TypeLonglong) case opcode.BitNeg: x.Type = types.NewFieldType(mysql.TypeLonglong) x.Type.Flag |= mysql.UnsignedFlag case opcode.Plus: x.Type = x.V.GetType() case opcode.Minus: x.Type = types.NewFieldType(mysql.TypeLonglong) if x.V.GetType() != nil { switch x.V.GetType().Tp { case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat: x.Type.Tp = mysql.TypeDouble case mysql.TypeNewDecimal: x.Type.Tp = mysql.TypeNewDecimal } } } }
func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { defer func() { if er := recover(); er != nil { e.err = errors.Errorf("%v", er) } }() aDatum := u.V.GetDatum() if aDatum.IsNull() { u.SetNull() return true } switch op := u.Op; op { case opcode.Not: n, err := aDatum.ToBool() if err != nil { e.err = errors.Trace(err) } else if n == 0 { u.SetInt64(1) } else { u.SetInt64(0) } case opcode.BitNeg: // for bit operation, we will use int64 first, then return uint64 n, err := aDatum.ToInt64() if err != nil { e.err = errors.Trace(err) return false } u.SetUint64(uint64(^n)) case opcode.Plus: switch aDatum.Kind() { case types.KindInt64, types.KindUint64, types.KindFloat64, types.KindFloat32, types.KindMysqlDuration, types.KindMysqlTime, types.KindString, types.KindMysqlDecimal, types.KindBytes, types.KindMysqlHex, types.KindMysqlBit, types.KindMysqlEnum, types.KindMysqlSet: u.SetDatum(*aDatum) default: e.err = ErrInvalidOperation return false } case opcode.Minus: switch aDatum.Kind() { case types.KindInt64: u.SetInt64(-aDatum.GetInt64()) case types.KindUint64: u.SetInt64(-int64(aDatum.GetUint64())) case types.KindFloat64: u.SetFloat64(-aDatum.GetFloat64()) case types.KindFloat32: u.SetFloat32(-aDatum.GetFloat32()) case types.KindMysqlDuration: var to = new(mysql.MyDecimal) var zero mysql.MyDecimal mysql.DecimalSub(&zero, aDatum.GetMysqlDuration().ToNumber(), to) u.SetMysqlDecimal(to) case types.KindMysqlTime: dec := aDatum.GetMysqlTime().ToNumber() var zero, to mysql.MyDecimal mysql.DecimalSub(&zero, dec, &to) u.SetMysqlDecimal(&to) case types.KindString, types.KindBytes: f, err := types.StrToFloat(aDatum.GetString()) e.err = errors.Trace(err) u.SetFloat64(-f) case types.KindMysqlDecimal: dec := aDatum.GetMysqlDecimal() var zero, to mysql.MyDecimal mysql.DecimalSub(&zero, dec, &to) u.SetMysqlDecimal(&to) case types.KindMysqlHex: u.SetFloat64(-aDatum.GetMysqlHex().ToNumber()) case types.KindMysqlBit: u.SetFloat64(-aDatum.GetMysqlBit().ToNumber()) case types.KindMysqlEnum: u.SetFloat64(-aDatum.GetMysqlEnum().ToNumber()) case types.KindMysqlSet: u.SetFloat64(-aDatum.GetMysqlSet().ToNumber()) default: e.err = ErrInvalidOperation return false } default: e.err = ErrInvalidOperation return false } return true }
func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { defer func() { if er := recover(); er != nil { e.err = errors.Errorf("%v", er) } }() a := u.V.GetValue() a = types.RawData(a) if a == nil { u.SetValue(nil) return true } switch op := u.Op; op { case opcode.Not: n, err := types.ToBool(a) if err != nil { e.err = errors.Trace(err) } else if n == 0 { u.SetValue(int64(1)) } else { u.SetValue(int64(0)) } case opcode.BitNeg: // for bit operation, we will use int64 first, then return uint64 n, err := types.ToInt64(a) if err != nil { e.err = errors.Trace(err) return false } u.SetValue(uint64(^n)) case opcode.Plus: switch x := a.(type) { case bool: u.SetValue(boolToInt64(x)) case float32: u.SetValue(+x) case float64: u.SetValue(+x) case int: u.SetValue(+x) case int8: u.SetValue(+x) case int16: u.SetValue(+x) case int32: u.SetValue(+x) case int64: u.SetValue(+x) case uint: u.SetValue(+x) case uint8: u.SetValue(+x) case uint16: u.SetValue(+x) case uint32: u.SetValue(+x) case uint64: u.SetValue(+x) case mysql.Duration: u.SetValue(x) case mysql.Time: u.SetValue(x) case string: u.SetValue(x) case mysql.Decimal: u.SetValue(x) case []byte: u.SetValue(x) case mysql.Hex: u.SetValue(x) case mysql.Bit: u.SetValue(x) case mysql.Enum: u.SetValue(x) case mysql.Set: u.SetValue(x) default: e.err = ErrInvalidOperation return false } case opcode.Minus: switch x := a.(type) { case bool: if x { u.SetValue(int64(-1)) } else { u.SetValue(int64(0)) } case float32: u.SetValue(-x) case float64: u.SetValue(-x) case int: u.SetValue(-x) case int8: u.SetValue(-x) case int16: u.SetValue(-x) case int32: u.SetValue(-x) case int64: u.SetValue(-x) case uint: u.SetValue(-int64(x)) case uint8: u.SetValue(-int64(x)) case uint16: u.SetValue(-int64(x)) case uint32: u.SetValue(-int64(x)) case uint64: // TODO: check overflow and do more test for unsigned type u.SetValue(-int64(x)) case mysql.Duration: u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber())) case mysql.Time: u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber())) case string: f, err := types.StrToFloat(x) e.err = errors.Trace(err) u.SetValue(-f) case mysql.Decimal: f, _ := x.Float64() u.SetValue(mysql.NewDecimalFromFloat(-f)) case []byte: f, err := types.StrToFloat(string(x)) e.err = errors.Trace(err) u.SetValue(-f) case mysql.Hex: u.SetValue(-x.ToNumber()) case mysql.Bit: u.SetValue(-x.ToNumber()) case mysql.Enum: u.SetValue(-x.ToNumber()) case mysql.Set: u.SetValue(-x.ToNumber()) default: e.err = ErrInvalidOperation return false } default: e.err = ErrInvalidOperation return false } return true }