예제 #1
0
// unaryOperation assigns the inferred result type of the unary expression x
// to x.Type, depending on the operator:
//   - Plus reuses the field type reported by the operand;
//   - Not produces a TypeLonglong field type;
//   - BitNeg produces a TypeLonglong field type with UnsignedFlag set;
//   - Minus produces TypeLonglong, changed to TypeDouble when the operand is a
//     string or floating-point type, or to TypeNewDecimal when the operand is
//     TypeNewDecimal.
//
// Any other operator leaves x.Type untouched.
func (v *typeInferrer) unaryOperation(x *ast.UnaryOperationExpr) {
	switch x.Op {
	case opcode.Plus:
		// Unary plus keeps whatever type the operand already has.
		x.Type = x.V.GetType()
	case opcode.Not:
		x.Type = types.NewFieldType(mysql.TypeLonglong)
	case opcode.BitNeg:
		unsignedLong := types.NewFieldType(mysql.TypeLonglong)
		unsignedLong.Flag |= mysql.UnsignedFlag
		x.Type = unsignedLong
	case opcode.Minus:
		// Start from the integer result type and widen it only when the
		// operand's type is known and calls for it.
		negated := types.NewFieldType(mysql.TypeLonglong)
		x.Type = negated
		if operand := x.V.GetType(); operand != nil {
			switch operand.Tp {
			case mysql.TypeNewDecimal:
				negated.Tp = mysql.TypeNewDecimal
			case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat:
				negated.Tp = mysql.TypeDouble
			}
		}
	}
}
예제 #2
0
// unaryOperation evaluates the unary expression u using the datum already
// stored on its operand u.V, and writes the result back into u.
//
// A NULL operand produces NULL for every operator. Otherwise:
//   - Not stores 1 when the operand converts to a zero boolean, 0 otherwise;
//   - BitNeg converts the operand to int64 and stores its bitwise complement
//     as a uint64;
//   - Plus stores the operand unchanged;
//   - Minus stores the negated operand: integers as int64, floats as floats,
//     durations, times and decimals as a decimal, and strings, bytes, hex,
//     bit, enum and set values as a float64.
//
// It returns true on success. On any failure (conversion error, unsupported
// operand kind or operator) the cause is recorded in e.err and false is
// returned, with no result stored in u. A panic raised while evaluating is
// recovered and recorded in e.err; the function then returns false.
func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
	defer func() {
		if er := recover(); er != nil {
			e.err = errors.Errorf("%v", er)
		}
	}()
	aDatum := u.V.GetDatum()
	if aDatum.IsNull() {
		u.SetNull()
		return true
	}
	switch u.Op {
	case opcode.Not:
		n, err := aDatum.ToBool()
		if err != nil {
			// Report the failure like every other arm does instead of
			// signalling success with e.err set.
			e.err = errors.Trace(err)
			return false
		}
		if n == 0 {
			u.SetInt64(1)
		} else {
			u.SetInt64(0)
		}
	case opcode.BitNeg:
		// for bit operation, we will use int64 first, then return uint64
		n, err := aDatum.ToInt64()
		if err != nil {
			e.err = errors.Trace(err)
			return false
		}
		u.SetUint64(uint64(^n))
	case opcode.Plus:
		switch aDatum.Kind() {
		case types.KindInt64,
			types.KindUint64,
			types.KindFloat64,
			types.KindFloat32,
			types.KindMysqlDuration,
			types.KindMysqlTime,
			types.KindString,
			types.KindMysqlDecimal,
			types.KindBytes,
			types.KindMysqlHex,
			types.KindMysqlBit,
			types.KindMysqlEnum,
			types.KindMysqlSet:
			u.SetDatum(*aDatum)
		default:
			e.err = ErrInvalidOperation
			return false
		}
	case opcode.Minus:
		switch aDatum.Kind() {
		case types.KindInt64:
			u.SetInt64(-aDatum.GetInt64())
		case types.KindUint64:
			// TODO: check overflow. The uint64 -> int64 conversion wraps for
			// values above math.MaxInt64, so such operands negate incorrectly.
			u.SetInt64(-int64(aDatum.GetUint64()))
		case types.KindFloat64:
			u.SetFloat64(-aDatum.GetFloat64())
		case types.KindFloat32:
			u.SetFloat32(-aDatum.GetFloat32())
		case types.KindMysqlDuration:
			// Negate by computing 0 - value in decimal arithmetic.
			dec := aDatum.GetMysqlDuration().ToNumber()
			var zero, to mysql.MyDecimal
			mysql.DecimalSub(&zero, dec, &to)
			u.SetMysqlDecimal(&to)
		case types.KindMysqlTime:
			dec := aDatum.GetMysqlTime().ToNumber()
			var zero, to mysql.MyDecimal
			mysql.DecimalSub(&zero, dec, &to)
			u.SetMysqlDecimal(&to)
		case types.KindString, types.KindBytes:
			f, err := types.StrToFloat(aDatum.GetString())
			if err != nil {
				// Do not store a result or report success for a string
				// that could not be converted to a number.
				e.err = errors.Trace(err)
				return false
			}
			u.SetFloat64(-f)
		case types.KindMysqlDecimal:
			dec := aDatum.GetMysqlDecimal()
			var zero, to mysql.MyDecimal
			mysql.DecimalSub(&zero, dec, &to)
			u.SetMysqlDecimal(&to)
		case types.KindMysqlHex:
			u.SetFloat64(-aDatum.GetMysqlHex().ToNumber())
		case types.KindMysqlBit:
			u.SetFloat64(-aDatum.GetMysqlBit().ToNumber())
		case types.KindMysqlEnum:
			u.SetFloat64(-aDatum.GetMysqlEnum().ToNumber())
		case types.KindMysqlSet:
			u.SetFloat64(-aDatum.GetMysqlSet().ToNumber())
		default:
			e.err = ErrInvalidOperation
			return false
		}
	default:
		e.err = ErrInvalidOperation
		return false
	}

	return true
}
예제 #3
0
파일: evaluator.go 프로젝트: mrtoms/tidb
// unaryOperation evaluates the unary expression u using the value already
// stored on its operand u.V (unwrapped with types.RawData), and writes the
// result back into u.
//
// A nil operand produces nil for every operator. Otherwise:
//   - Not stores int64(1) when the operand converts to a zero boolean,
//     int64(0) otherwise;
//   - BitNeg converts the operand to int64 and stores its bitwise complement
//     as a uint64;
//   - Plus stores the operand unchanged, except that a bool is converted with
//     boolToInt64;
//   - Minus stores the negated operand: signed integers and floats keep their
//     type, unsigned integers become int64, durations, times and decimals
//     produce a decimal, strings and byte slices are parsed as floats, and
//     hex, bit, enum and set values are negated through ToNumber.
//
// It returns true on success. On any failure (conversion error, unsupported
// operand type or operator) the cause is recorded in e.err and false is
// returned, with no result stored in u. A panic raised while evaluating is
// recovered and recorded in e.err; the function then returns false.
func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
	defer func() {
		if er := recover(); er != nil {
			e.err = errors.Errorf("%v", er)
		}
	}()
	a := u.V.GetValue()
	a = types.RawData(a)
	if a == nil {
		u.SetValue(nil)
		return true
	}
	switch u.Op {
	case opcode.Not:
		n, err := types.ToBool(a)
		if err != nil {
			// Report the failure like every other arm does instead of
			// signalling success with e.err set.
			e.err = errors.Trace(err)
			return false
		}
		if n == 0 {
			u.SetValue(int64(1))
		} else {
			u.SetValue(int64(0))
		}
	case opcode.BitNeg:
		// for bit operation, we will use int64 first, then return uint64
		n, err := types.ToInt64(a)
		if err != nil {
			e.err = errors.Trace(err)
			return false
		}
		u.SetValue(uint64(^n))
	case opcode.Plus:
		switch x := a.(type) {
		case bool:
			u.SetValue(boolToInt64(x))
		case float32, float64,
			int, int8, int16, int32, int64,
			uint, uint8, uint16, uint32, uint64,
			mysql.Duration, mysql.Time, string, mysql.Decimal, []byte,
			mysql.Hex, mysql.Bit, mysql.Enum, mysql.Set:
			// Unary plus is the identity for all of these, so the operand
			// is stored as is, keeping its dynamic type.
			u.SetValue(a)
		default:
			e.err = ErrInvalidOperation
			return false
		}
	case opcode.Minus:
		switch x := a.(type) {
		case bool:
			if x {
				u.SetValue(int64(-1))
			} else {
				u.SetValue(int64(0))
			}
		case float32:
			u.SetValue(-x)
		case float64:
			u.SetValue(-x)
		case int:
			u.SetValue(-x)
		case int8:
			u.SetValue(-x)
		case int16:
			u.SetValue(-x)
		case int32:
			u.SetValue(-x)
		case int64:
			u.SetValue(-x)
		case uint:
			u.SetValue(-int64(x))
		case uint8:
			u.SetValue(-int64(x))
		case uint16:
			u.SetValue(-int64(x))
		case uint32:
			u.SetValue(-int64(x))
		case uint64:
			// TODO: check overflow and do more test for unsigned type.
			// The conversion wraps for values above math.MaxInt64.
			u.SetValue(-int64(x))
		case mysql.Duration:
			u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber()))
		case mysql.Time:
			u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber()))
		case string:
			f, err := types.StrToFloat(x)
			if err != nil {
				// Do not store a result or report success for a string
				// that could not be converted to a number.
				e.err = errors.Trace(err)
				return false
			}
			u.SetValue(-f)
		case mysql.Decimal:
			// Subtract from zero, as the Duration and Time arms do, so the
			// value never passes through float64 and keeps full precision.
			u.SetValue(mysql.ZeroDecimal.Sub(x))
		case []byte:
			f, err := types.StrToFloat(string(x))
			if err != nil {
				e.err = errors.Trace(err)
				return false
			}
			u.SetValue(-f)
		case mysql.Hex:
			u.SetValue(-x.ToNumber())
		case mysql.Bit:
			u.SetValue(-x.ToNumber())
		case mysql.Enum:
			u.SetValue(-x.ToNumber())
		case mysql.Set:
			u.SetValue(-x.ToNumber())
		default:
			e.err = ErrInvalidOperation
			return false
		}
	default:
		e.err = ErrInvalidOperation
		return false
	}

	return true
}