예제 #1
0
파일: convert.go 프로젝트: nengwang/tidb
// Convert converts val into the Go value used for the column type described by target.
//
// A nil val yields (nil, nil). Otherwise the conversion is selected by target.Tp:
//   - TypeFloat, TypeDouble: val is converted with ToFloat64 and returned as float32 or
//     float64; it is passed through TruncateFloat only when both target.Flen and
//     target.Decimal are specified.
//   - blob and string types: val is converted with ToString, cut with truncateStr using
//     target.Flen, and returned as []byte when target.Charset is charset.CharsetBin.
//   - TypeDuration, TypeTimestamp, TypeDatetime, TypeDate: val is converted or parsed using
//     target.Decimal as the fsp argument, or mysql.DefaultFsp when target.Decimal is
//     unspecified.
//   - integer types: delegated to convertToUint or convertToInt depending on the unsigned flag.
//   - TypeBit: the convertToUint result, checked against the maximum for the bit width.
//   - TypeDecimal, TypeNewDecimal: the ToDecimal result, rounded to target.Decimal when
//     it is specified.
//   - TypeYear: an int64 year.
//
// Most conversion failures are reported through invConv; errors from the truncation and
// time helpers are wrapped with errors.Trace or returned as is. Convert panics when
// target.Tp is none of the types listed above.
func Convert(val interface{}, target *FieldType) (v interface{}, err error) { //NTYPE
	tp := target.Tp
	if val == nil {
		return nil, nil
	}
	switch tp { // TODO: implement mysql types convert when "CAST() AS" syntax are supported.
	case mysql.TypeFloat:
		x, err := ToFloat64(val)
		if err != nil {
			return invConv(val, tp)
		}
		// Truncate only for the float(M, D) form, i.e. when both the length and the
		// decimal count are given. Without this guard the UnspecifiedLength sentinel
		// would be handed to TruncateFloat as the decimal count.
		// NOTE(review): presumably TruncateFloat expects a real digit count there — confirm.
		if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength {
			x, err = TruncateFloat(x, target.Flen, target.Decimal)
			if err != nil {
				return nil, errors.Trace(err)
			}
		}
		return float32(x), nil
	case mysql.TypeDouble:
		x, err := ToFloat64(val)
		if err != nil {
			return invConv(val, tp)
		}
		// Same rule as TypeFloat: only the double(M, D) form is truncated.
		if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength {
			x, err = TruncateFloat(x, target.Flen, target.Decimal)
			if err != nil {
				return nil, errors.Trace(err)
			}
		}
		return float64(x), nil
	case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob,
		mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
		x, err := ToString(val)
		if err != nil {
			return invConv(val, tp)
		}
		// TODO: consider target.Charset/Collate
		x = truncateStr(x, target.Flen)
		// Binary charset values are handed back as raw bytes instead of a string.
		if target.Charset == charset.CharsetBin {
			return []byte(x), nil
		}
		return x, nil
	case mysql.TypeDuration:
		// Use the default fsp unless the target specifies its own.
		fsp := mysql.DefaultFsp
		if target.Decimal != UnspecifiedLength {
			fsp = target.Decimal
		}
		switch x := val.(type) {
		case mysql.Duration:
			return x.RoundFrac(fsp)
		case mysql.Time:
			t, err := x.ConvertToDuration()
			if err != nil {
				return nil, errors.Trace(err)
			}

			return t.RoundFrac(fsp)
		case string:
			return mysql.ParseDuration(x, fsp)
		default:
			return invConv(val, tp)
		}
	case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
		// Use the default fsp unless the target specifies its own.
		fsp := mysql.DefaultFsp
		if target.Decimal != UnspecifiedLength {
			fsp = target.Decimal
		}
		switch x := val.(type) {
		case mysql.Time:
			t, err := x.Convert(tp)
			if err != nil {
				return nil, errors.Trace(err)
			}
			return t.RoundFrac(fsp)
		case mysql.Duration:
			t, err := x.ConvertToTime(tp)
			if err != nil {
				return nil, errors.Trace(err)
			}
			return t.RoundFrac(fsp)
		case string:
			return mysql.ParseTime(x, tp, fsp)
		case int64:
			return mysql.ParseTimeFromNum(x, tp, fsp)
		default:
			return invConv(val, tp)
		}
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
		unsigned := mysql.HasUnsignedFlag(target.Flag)
		if unsigned {
			return convertToUint(val, target)
		}
		return convertToInt(val, target)
	case mysql.TypeBit:
		x, err := convertToUint(val, target)
		if err != nil {
			return x, errors.Trace(err)
		}

		// check bit boundary, if bit has n width, the boundary is
		// in [0, (1 << n) - 1]
		width := target.Flen
		if width == 0 {
			width = mysql.MinBitWidth
		} else if width == mysql.UnspecifiedBitWidth {
			width = mysql.MaxBitWidth
		}

		maxValue := uint64(1)<<uint64(width) - 1

		// On overflow the clamped maximum is returned together with the error.
		if x > maxValue {
			return maxValue, overflow(val, tp)
		}
		return x, nil
	case mysql.TypeDecimal, mysql.TypeNewDecimal:
		x, err := ToDecimal(val)
		if err != nil {
			return invConv(val, tp)
		}
		if target.Decimal != UnspecifiedLength {
			x = x.Round(int32(target.Decimal))
		}
		// TODO: check Flen
		return x, nil
	case mysql.TypeYear:
		var (
			intVal int64
			err    error
		)
		switch x := val.(type) {
		case string:
			intVal, err = StrToInt(x)
		case mysql.Time:
			return int64(x.Year()), nil
		case mysql.Duration:
			// A duration carries no date part, so the current year is used.
			return int64(time.Now().Year()), nil
		default:
			intVal, err = ToInt64(x)
		}
		if err != nil {
			return invConv(val, tp)
		}
		// Numeric and string inputs are passed through mysql.AdjustYear before being returned.
		y, err := mysql.AdjustYear(int(intVal))
		if err != nil {
			return invConv(val, tp)
		}
		return int64(y), nil
	default:
		panic("should never happen")
	}
}
예제 #2
0
파일: column.go 프로젝트: romanticode/tidb
// castIntegerValue narrows val to the Go integer type that matches the column type c.Tp,
// clamping it to that type's range when it does not fit.
//
// errCode carries an overflow hint that is consulted alongside val: errCodeOverflowUpper
// forces the upper bound and errCodeOverflowLower forces the lower bound (TypeYear only
// honours the upper hint). errCodeOverflowMaxInt64 disables the "negative value" check for
// unsigned TypeLong, TypeLonglong and TypeBit.
// NOTE(review): presumably val then holds the bit pattern of a value above math.MaxInt64 —
// confirm with the caller.
//
// Whenever the value is clamped, err is set to types.Overflow(val, c.Tp) and casted holds
// the bound that was applied. Column types not listed below leave casted nil and err nil.
func (c *Col) castIntegerValue(val int64, errCode int) (casted interface{}, err error) {
	var (
		isUnsigned = mysql.HasUnsignedFlag(c.Flag)
		tooHigh    = errCode == errCodeOverflowUpper
		tooLow     = errCode == errCodeOverflowLower
		// negative is the lower-bound test used by the wide unsigned types.
		negative = val < 0 && errCode != errCodeOverflowMaxInt64
		clamped  bool
	)

	switch c.Tp {
	case mysql.TypeTiny:
		switch {
		case isUnsigned && (val > math.MaxUint8 || tooHigh):
			clamped = true
			casted = uint8(math.MaxUint8)
		case isUnsigned && (val < 0 || tooLow):
			clamped = true
			casted = uint8(0)
		case isUnsigned:
			casted = uint8(val)
		case val > math.MaxInt8 || tooHigh:
			clamped = true
			casted = int8(math.MaxInt8)
		case val < math.MinInt8 || tooLow:
			clamped = true
			casted = int8(math.MinInt8)
		default:
			casted = int8(val)
		}
	case mysql.TypeShort:
		switch {
		case isUnsigned && (val > math.MaxUint16 || tooHigh):
			clamped = true
			casted = uint16(math.MaxUint16)
		case isUnsigned && (val < 0 || tooLow):
			clamped = true
			casted = uint16(0)
		case isUnsigned:
			casted = uint16(val)
		case val > math.MaxInt16 || tooHigh:
			clamped = true
			casted = int16(math.MaxInt16)
		case val < math.MinInt16 || tooLow:
			clamped = true
			casted = int16(math.MinInt16)
		default:
			casted = int16(val)
		}
	case mysql.TypeYear:
		// The lower-bound hint is not consulted for years; only val decides that side.
		switch {
		case val > int64(mysql.MaxYear) || tooHigh:
			clamped = true
			casted = mysql.MaxYear
		case val < int64(mysql.MinYear):
			clamped = true
			casted = mysql.MinYear
		default:
			casted, _ = mysql.AdjustYear(int(val))
		}
	case mysql.TypeInt24:
		// The unsigned path tests the lower bound before the upper one.
		switch {
		case isUnsigned && (val < 0 || tooLow):
			clamped = true
			casted = uint32(0)
		case isUnsigned && (val > 1<<24-1 || tooHigh):
			clamped = true
			casted = uint32(1<<24 - 1)
		case isUnsigned:
			casted = uint32(val)
		case val > 1<<23-1 || tooHigh:
			clamped = true
			casted = int32(1<<23 - 1)
		case val < -1<<23 || tooLow:
			clamped = true
			casted = int32(-1 << 23)
		default:
			casted = int32(val)
		}
	case mysql.TypeLong:
		switch {
		case isUnsigned && (val > math.MaxUint32 || tooHigh):
			clamped = true
			casted = uint32(math.MaxUint32)
		case isUnsigned && (negative || tooLow):
			clamped = true
			casted = uint32(0)
		case isUnsigned:
			casted = uint32(val)
		case val > math.MaxInt32 || tooHigh:
			clamped = true
			casted = int32(math.MaxInt32)
		case val < math.MinInt32 || tooLow:
			clamped = true
			casted = int32(math.MinInt32)
		default:
			casted = int32(val)
		}
	case mysql.TypeLonglong, mysql.TypeBit:
		// TypeLonglong overflow has already been handled by normalizeInteger, so only the
		// hints are checked on the upper side. TypeBit is always converted as uint64,
		// whatever the unsigned flag says.
		asUint := isUnsigned || c.Tp == mysql.TypeBit
		switch {
		case asUint && tooHigh:
			clamped = true
			casted = uint64(math.MaxUint64)
		case asUint && (negative || tooLow):
			clamped = true
			casted = uint64(0)
		case asUint:
			casted = uint64(val)
		case tooHigh:
			clamped = true
			casted = int64(math.MaxInt64)
		case tooLow:
			clamped = true
			casted = int64(math.MinInt64)
		default:
			casted = int64(val)
		}
	}

	if clamped {
		err = types.Overflow(val, c.Tp)
	}
	return casted, err
}
예제 #3
0
파일: convert.go 프로젝트: szctop/tidb
// Convert converts val into the Go value used for the column type described by target.
//
// A nil val yields (nil, nil). Otherwise the conversion is selected by target.Tp:
//   - TypeFloat, TypeDouble: val is converted with ToFloat64 and returned as float32 or
//     float64; it is passed through TruncateFloat only when both target.Flen and
//     target.Decimal are specified.
//   - TypeString, TypeBlob: val is converted with ToString, cut with truncateStr using
//     target.Flen, and returned as []byte when target.Charset is charset.CharsetBin.
//   - TypeDuration, TypeTimestamp, TypeDatetime, TypeDate: val is converted or parsed using
//     target.Decimal as the fsp argument, or mysql.DefaultFsp when target.Decimal is
//     unspecified.
//   - TypeLonglong: the ToInt64 result, reinterpreted as uint64 when the unsigned flag is set.
//   - TypeNewDecimal: the ToDecimal result, rounded to target.Decimal when it is specified.
//   - TypeYear: an int16 year.
//
// Most conversion failures are reported through InvConv; errors from the truncation and
// time helpers are wrapped with errors.Trace or returned as is. Convert panics when
// target.Tp is none of the types listed above.
func Convert(val interface{}, target *FieldType) (v interface{}, err error) { //NTYPE
	tp := target.Tp
	if val == nil {
		return nil, nil
	}
	switch tp { // TODO: implement mysql types convert when "CAST() AS" syntax are supported.
	case mysql.TypeFloat:
		x, err := ToFloat64(val)
		if err != nil {
			return InvConv(val, tp)
		}
		// Truncate only for the float(M, D) form, i.e. when both the length and the
		// decimal count are given. Without this guard the UnspecifiedLength sentinel
		// would be handed to TruncateFloat as the decimal count.
		// NOTE(review): presumably TruncateFloat expects a real digit count there — confirm.
		if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength {
			x, err = TruncateFloat(x, target.Flen, target.Decimal)
			if err != nil {
				return nil, errors.Trace(err)
			}
		}
		return float32(x), nil
	case mysql.TypeDouble:
		x, err := ToFloat64(val)
		if err != nil {
			return InvConv(val, tp)
		}
		// Same rule as TypeFloat: only the double(M, D) form is truncated.
		if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength {
			x, err = TruncateFloat(x, target.Flen, target.Decimal)
			if err != nil {
				return nil, errors.Trace(err)
			}
		}
		return float64(x), nil
	case mysql.TypeString:
		x, err := ToString(val)
		if err != nil {
			return InvConv(val, tp)
		}
		// TODO: consider target.Charset/Collate
		x = truncateStr(x, target.Flen)
		// Binary charset values are handed back as raw bytes instead of a string.
		if target.Charset == charset.CharsetBin {
			return []byte(x), nil
		}
		return x, nil
	case mysql.TypeBlob:
		x, err := ToString(val)
		if err != nil {
			return InvConv(val, tp)
		}
		x = truncateStr(x, target.Flen)
		// Binary charset values are handed back as raw bytes instead of a string.
		if target.Charset == charset.CharsetBin {
			return []byte(x), nil
		}
		return x, nil
	case mysql.TypeDuration:
		// Use the default fsp unless the target specifies its own.
		fsp := mysql.DefaultFsp
		if target.Decimal != UnspecifiedLength {
			fsp = target.Decimal
		}
		switch x := val.(type) {
		case mysql.Duration:
			return x.RoundFrac(fsp)
		case mysql.Time:
			t, err := x.ConvertToDuration()
			if err != nil {
				return nil, errors.Trace(err)
			}

			return t.RoundFrac(fsp)
		case string:
			return mysql.ParseDuration(x, fsp)
		default:
			return InvConv(val, tp)
		}
	case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
		// Use the default fsp unless the target specifies its own.
		fsp := mysql.DefaultFsp
		if target.Decimal != UnspecifiedLength {
			fsp = target.Decimal
		}
		switch x := val.(type) {
		case mysql.Time:
			t, err := x.Convert(tp)
			if err != nil {
				return nil, errors.Trace(err)
			}
			return t.RoundFrac(fsp)
		case mysql.Duration:
			t, err := x.ConvertToTime(tp)
			if err != nil {
				return nil, errors.Trace(err)
			}
			return t.RoundFrac(fsp)
		case string:
			return mysql.ParseTime(x, tp, fsp)
		default:
			return InvConv(val, tp)
		}
	case mysql.TypeLonglong:
		x, err := ToInt64(val)
		if err != nil {
			return InvConv(val, tp)
		}
		// TODO: We should first convert to uint64 then check unsigned flag.
		if mysql.HasUnsignedFlag(target.Flag) {
			return uint64(x), nil
		}
		return x, nil
	case mysql.TypeNewDecimal:
		x, err := ToDecimal(val)
		if err != nil {
			return InvConv(val, tp)
		}
		if target.Decimal != UnspecifiedLength {
			x = x.Round(int32(target.Decimal))
		}
		// TODO: check Flen
		return x, nil
	case mysql.TypeYear:
		var (
			intVal int64
			err    error
		)
		switch x := val.(type) {
		case string:
			intVal, err = StrToInt(x)
		case mysql.Time:
			return int16(x.Year()), nil
		case mysql.Duration:
			// A duration carries no date part, so the current year is used.
			return int16(time.Now().Year()), nil
		default:
			intVal, err = ToInt64(x)
		}
		if err != nil {
			return InvConv(val, tp)
		}
		// Numeric and string inputs are passed through mysql.AdjustYear before being returned.
		y, err := mysql.AdjustYear(int(intVal))
		if err != nil {
			return InvConv(val, tp)
		}
		return int16(y), nil
	default:
		panic("should never happen")
	}
}
예제 #4
0
// Convert turns val into the Go value used for the column type described by target.
//
// A nil val yields (nil, nil). Otherwise target.Tp picks the conversion:
//   - TypeFloat, TypeDouble: float32 or float64, truncated with TruncateFloat only when
//     both target.Flen and target.Decimal are specified.
//   - blob and string types: a string cut with truncateStr, or []byte when
//     target.Charset is charset.CharsetBin.
//   - TypeDuration, TypeTimestamp, TypeDatetime, TypeDate: converted or parsed using
//     target.Decimal as the fsp argument, or mysql.DefaultFsp when it is unspecified.
//   - integer types: delegated to convertToUint or convertToInt by the unsigned flag.
//   - TypeBit: a mysql.Bit, or the maximum value plus an overflow error.
//   - TypeDecimal, TypeNewDecimal: the ToDecimal result, rounded when target.Decimal is set.
//   - TypeYear: an int64 year.
//   - TypeEnum, TypeSet: a mysql.Enum or mysql.Set looked up in target.Elems by name
//     (string or []byte input) or by number (any other input).
//
// Most failures are reported through invConv; some helper errors are wrapped with
// errors.Trace or returned as is. Convert panics for any other target.Tp.
func Convert(val interface{}, target *FieldType) (v interface{}, err error) {
	tp := target.Tp
	if val == nil {
		return nil, nil
	}
	switch tp { // TODO: implement mysql types convert when "CAST() AS" syntax are supported.
	case mysql.TypeFloat, mysql.TypeDouble:
		f, convErr := ToFloat64(val)
		if convErr != nil {
			return invConv(val, tp)
		}
		// For float and double we only truncate for the float(M, D) format.
		// If no D is set, the value is handled like a plain float whether M is set or not.
		if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength {
			if f, convErr = TruncateFloat(f, target.Flen, target.Decimal); convErr != nil {
				return nil, errors.Trace(convErr)
			}
		}
		if tp == mysql.TypeFloat {
			return float32(f), nil
		}
		return float64(f), nil

	case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob,
		mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
		s, convErr := ToString(val)
		if convErr != nil {
			return invConv(val, tp)
		}
		// TODO: consider target.Charset/Collate
		s = truncateStr(s, target.Flen)
		if target.Charset != charset.CharsetBin {
			return s, nil
		}
		// Binary charset values are handed back as raw bytes.
		return []byte(s), nil

	case mysql.TypeDuration:
		// Fall back to the default fsp when the target does not specify one.
		fsp := target.Decimal
		if fsp == UnspecifiedLength {
			fsp = mysql.DefaultFsp
		}
		switch x := val.(type) {
		case mysql.Duration:
			return x.RoundFrac(fsp)
		case mysql.Time:
			d, convErr := x.ConvertToDuration()
			if convErr != nil {
				return nil, errors.Trace(convErr)
			}
			return d.RoundFrac(fsp)
		case string:
			return mysql.ParseDuration(x, fsp)
		}
		return invConv(val, tp)

	case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
		// Fall back to the default fsp when the target does not specify one.
		fsp := target.Decimal
		if fsp == UnspecifiedLength {
			fsp = mysql.DefaultFsp
		}
		switch x := val.(type) {
		case mysql.Time:
			t, convErr := x.Convert(tp)
			if convErr != nil {
				return nil, errors.Trace(convErr)
			}
			return t.RoundFrac(fsp)
		case mysql.Duration:
			t, convErr := x.ConvertToTime(tp)
			if convErr != nil {
				return nil, errors.Trace(convErr)
			}
			return t.RoundFrac(fsp)
		case string:
			return mysql.ParseTime(x, tp, fsp)
		case int64:
			return mysql.ParseTimeFromNum(x, tp, fsp)
		}
		return invConv(val, tp)

	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
		if mysql.HasUnsignedFlag(target.Flag) {
			return convertToUint(val, target)
		}
		return convertToInt(val, target)

	case mysql.TypeBit:
		u, convErr := convertToUint(val, target)
		if convErr != nil {
			return u, errors.Trace(convErr)
		}

		// A bit column of width n holds values in [0, (1 << n) - 1].
		width := target.Flen
		if width == mysql.UnspecifiedBitWidth || width == 0 {
			width = mysql.MinBitWidth
		}
		limit := uint64(1)<<uint64(width) - 1

		if u <= limit {
			return mysql.Bit{Value: u, Width: width}, nil
		}
		// Out of range: hand back the largest representable value with the error.
		return limit, overflow(val, tp)

	case mysql.TypeDecimal, mysql.TypeNewDecimal:
		dec, convErr := ToDecimal(val)
		if convErr != nil {
			return invConv(val, tp)
		}
		if target.Decimal != UnspecifiedLength {
			dec = dec.Round(int32(target.Decimal))
		}
		// TODO: check Flen
		return dec, nil

	case mysql.TypeYear:
		var (
			year    int64
			yearErr error
		)
		switch x := val.(type) {
		case mysql.Time:
			return int64(x.Year()), nil
		case mysql.Duration:
			// A duration carries no date part, so the current year is used.
			return int64(time.Now().Year()), nil
		case string:
			year, yearErr = StrToInt(x)
		default:
			year, yearErr = ToInt64(val)
		}
		if yearErr != nil {
			return invConv(val, tp)
		}
		adjusted, yearErr := mysql.AdjustYear(int(year))
		if yearErr != nil {
			return invConv(val, tp)
		}
		return int64(adjusted), nil

	case mysql.TypeEnum:
		var (
			enum    mysql.Enum
			enumErr error
		)
		switch x := val.(type) {
		case []byte:
			enum, enumErr = mysql.ParseEnumName(target.Elems, string(x))
		case string:
			enum, enumErr = mysql.ParseEnumName(target.Elems, x)
		default:
			// Anything that is not textual is treated as the element number.
			idx, numErr := ToUint64(val)
			if numErr != nil {
				return nil, errors.Trace(numErr)
			}
			enum, enumErr = mysql.ParseEnumValue(target.Elems, idx)
		}
		if enumErr != nil {
			return invConv(val, tp)
		}
		return enum, nil

	case mysql.TypeSet:
		var (
			set    mysql.Set
			setErr error
		)
		switch x := val.(type) {
		case []byte:
			set, setErr = mysql.ParseSetName(target.Elems, string(x))
		case string:
			set, setErr = mysql.ParseSetName(target.Elems, x)
		default:
			// Anything that is not textual is treated as the numeric set value.
			bits, numErr := ToUint64(val)
			if numErr != nil {
				return nil, errors.Trace(numErr)
			}
			set, setErr = mysql.ParseSetValue(target.Elems, bits)
		}
		if setErr != nil {
			return invConv(val, tp)
		}
		return set, nil

	default:
		panic("should never happen")
	}
}