예제 #1
0
// datumExpr encodes the constant d as a tipb expression node.
//
// The datum kind selects the expression type. Integer, float and duration
// values are encoded with the codec package. Strings and bytes are used as
// their raw bytes. Decimals go through codec.EncodeDecimal. Any kind without
// a dedicated case (including NULL) produces an ExprType_Null node with no
// value.
func datumExpr(d types.Datum) *tipb.Expr {
	var (
		tp  tipb.ExprType
		val []byte
	)
	switch d.Kind() {
	case types.KindInt64:
		tp, val = tipb.ExprType_Int64, codec.EncodeInt(nil, d.GetInt64())
	case types.KindUint64:
		tp, val = tipb.ExprType_Uint64, codec.EncodeUint(nil, d.GetUint64())
	case types.KindString:
		tp, val = tipb.ExprType_String, d.GetBytes()
	case types.KindBytes:
		tp, val = tipb.ExprType_Bytes, d.GetBytes()
	case types.KindFloat32:
		// Float32 datums are read and encoded through their float64 value.
		tp, val = tipb.ExprType_Float32, codec.EncodeFloat(nil, d.GetFloat64())
	case types.KindFloat64:
		tp, val = tipb.ExprType_Float64, codec.EncodeFloat(nil, d.GetFloat64())
	case types.KindMysqlDuration:
		// Durations are encoded as their underlying Duration converted to int64.
		tp, val = tipb.ExprType_MysqlDuration, codec.EncodeInt(nil, int64(d.GetMysqlDuration().Duration))
	case types.KindMysqlDecimal:
		tp, val = tipb.ExprType_MysqlDecimal, codec.EncodeDecimal(nil, d)
	default:
		tp = tipb.ExprType_Null
	}
	return &tipb.Expr{Tp: tp, Val: val}
}
예제 #2
0
파일: tables.go 프로젝트: anywhy/tidb
// flatten converts a datum of a MySQL-specific kind into a primitive datum:
//   - time values (datetime, timestamp, date) are marshaled into bytes;
//   - durations become int64, and hex literals become their int64 value;
//   - decimals become their string form;
//   - enum, set and bit values become their uint64 value.
//
// Any other kind is returned unchanged. An error is returned only when
// marshaling a time value fails.
func flatten(data types.Datum) (types.Datum, error) {
	kind := data.Kind()
	if kind == types.KindMysqlTime {
		// for mysql datetime, timestamp and date type
		marshaled, err := data.GetMysqlTime().Marshal()
		if err != nil {
			return types.NewDatum(nil), errors.Trace(err)
		}
		return types.NewDatum(marshaled), nil
	}

	switch kind {
	case types.KindMysqlDuration:
		// for mysql time type
		data.SetInt64(int64(data.GetMysqlDuration().Duration))
	case types.KindMysqlDecimal:
		data.SetString(data.GetMysqlDecimal().String())
	case types.KindMysqlEnum:
		data.SetUint64(data.GetMysqlEnum().Value)
	case types.KindMysqlSet:
		data.SetUint64(data.GetMysqlSet().Value)
	case types.KindMysqlBit:
		data.SetUint64(data.GetMysqlBit().Value)
	case types.KindMysqlHex:
		data.SetInt64(data.GetMysqlHex().Value)
	}
	return data, nil
}
예제 #3
0
파일: util.go 프로젝트: pingcap/tidb
// dumpTextValue renders value as text bytes, presumably for the MySQL text
// result protocol (confirm against callers).
//
// Formatting per kind:
//   - Integers are written in base 10.
//   - Floats use the shortest representation that round-trips at their bit size.
//   - Strings and bytes are returned as is.
//   - MySQL-specific kinds use their String/ToString forms.
//
// mysqlType is currently unused; formatting depends only on value.Kind().
// An unsupported kind yields errInvalidType, and the message includes the
// datum kind.
func dumpTextValue(mysqlType uint8, value types.Datum) ([]byte, error) {
	switch value.Kind() {
	case types.KindInt64:
		return strconv.AppendInt(nil, value.GetInt64(), 10), nil
	case types.KindUint64:
		return strconv.AppendUint(nil, value.GetUint64(), 10), nil
	case types.KindFloat32:
		return strconv.AppendFloat(nil, value.GetFloat64(), 'f', -1, 32), nil
	case types.KindFloat64:
		return strconv.AppendFloat(nil, value.GetFloat64(), 'f', -1, 64), nil
	case types.KindString, types.KindBytes:
		return value.GetBytes(), nil
	case types.KindMysqlTime:
		return hack.Slice(value.GetMysqlTime().String()), nil
	case types.KindMysqlDuration:
		return hack.Slice(value.GetMysqlDuration().String()), nil
	case types.KindMysqlDecimal:
		return hack.Slice(value.GetMysqlDecimal().String()), nil
	case types.KindMysqlEnum:
		return hack.Slice(value.GetMysqlEnum().String()), nil
	case types.KindMysqlSet:
		return hack.Slice(value.GetMysqlSet().String()), nil
	case types.KindMysqlBit:
		return hack.Slice(value.GetMysqlBit().ToString()), nil
	case types.KindMysqlHex:
		return hack.Slice(value.GetMysqlHex().ToString()), nil
	default:
		// %T alone always prints the Datum struct type; include the kind so
		// the error identifies which value kind is unsupported.
		return nil, errInvalidType.Gen("invalid type %T (kind %d)", value, value.Kind())
	}
}
예제 #4
0
파일: tablecodec.go 프로젝트: pingcap/tidb
// flatten converts a datum of a MySQL-specific kind into a primitive datum:
//   - time values (datetime, timestamp, date) become their packed uint64 form;
//   - durations and hex literals become int64;
//   - enum, set and bit values become uint64.
//
// Other kinds are returned unchanged. The only error source is packing a
// time value.
func flatten(data types.Datum) (types.Datum, error) {
	switch data.Kind() {
	case types.KindMysqlTime:
		// for mysql datetime, timestamp and date type
		packed, err := data.GetMysqlTime().ToPackedUint()
		return types.NewUintDatum(packed), errors.Trace(err)
	case types.KindMysqlDuration:
		// for mysql time type
		data.SetInt64(int64(data.GetMysqlDuration().Duration))
	case types.KindMysqlHex:
		data.SetInt64(data.GetMysqlHex().Value)
	case types.KindMysqlEnum:
		data.SetUint64(data.GetMysqlEnum().Value)
	case types.KindMysqlSet:
		data.SetUint64(data.GetMysqlSet().Value)
	case types.KindMysqlBit:
		data.SetUint64(data.GetMysqlBit().Value)
	}
	return data, nil
}
예제 #5
0
// convertToGoTime converts d into a Go time.Time. A datum that is not
// already a MySQL time is first converted to a datetime with convertToTime.
// TODO: This is used for timediff(). After we finish time refactor, we should abandon this function.
func convertToGoTime(sc *variable.StatementContext, d types.Datum) (t time.Time, err error) {
	src := d
	if src.Kind() != types.KindMysqlTime {
		if src, err = convertToTime(sc, d, mysql.TypeDatetime); err != nil {
			return t, errors.Trace(err)
		}
	}
	t, err = src.GetMysqlTime().Time.GoTime()
	return t, errors.Trace(err)
}
예제 #6
0
파일: tablecodec.go 프로젝트: astaxie/tidb
// unflatten converts a raw datum to a column datum.
//
// The conversion is chosen by ft.Tp:
//   - Floats are narrowed to float32.
//   - Dates, datetimes and timestamps are unmarshaled from bytes, with
//     ft.Decimal used as Fsp.
//   - Durations are rebuilt from int64.
//   - Decimals are parsed from their string form.
//   - Enum and set values are resolved against ft.Elems.
//   - Bits use ft.Flen as their width.
//
// NULL datums and all other types are returned unchanged. On a parse error
// the input datum is returned together with the traced error.
func unflatten(datum types.Datum, ft *types.FieldType) (types.Datum, error) {
	if datum.Kind() == types.KindNull {
		return datum, nil
	}
	switch ft.Tp {
	case mysql.TypeFloat:
		datum.SetFloat32(float32(datum.GetFloat64()))
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeYear, mysql.TypeInt24,
		mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeTinyBlob,
		mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob, mysql.TypeVarchar,
		mysql.TypeString:
		// Already in column form; nothing to convert.
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		var tm mysql.Time
		tm.Type = ft.Tp
		tm.Fsp = ft.Decimal
		if err := tm.Unmarshal(datum.GetBytes()); err != nil {
			return datum, errors.Trace(err)
		}
		datum.SetValue(tm)
	case mysql.TypeDuration:
		datum.SetValue(mysql.Duration{Duration: time.Duration(datum.GetInt64())})
	case mysql.TypeNewDecimal:
		dec, err := mysql.ParseDecimal(datum.GetString())
		if err != nil {
			return datum, errors.Trace(err)
		}
		datum.SetValue(dec)
	case mysql.TypeEnum:
		enum, err := mysql.ParseEnumValue(ft.Elems, datum.GetUint64())
		if err != nil {
			return datum, errors.Trace(err)
		}
		datum.SetValue(enum)
	case mysql.TypeSet:
		set, err := mysql.ParseSetValue(ft.Elems, datum.GetUint64())
		if err != nil {
			return datum, errors.Trace(err)
		}
		datum.SetValue(set)
	case mysql.TypeBit:
		datum.SetValue(mysql.Bit{Value: datum.GetUint64(), Width: ft.Flen})
	}
	return datum, nil
}
예제 #7
0
// parseDayInterval converts an interval operand into an int64.
//
// For string operands, "false" and "true" map to 0 and 1; the comparison is
// made after lowercasing the string. Any other string is replaced by the
// first match of the reg pattern (defined outside this function) before the
// integer conversion. Non-string operands are converted directly with
// ToInt64.
func parseDayInterval(value types.Datum) (int64, error) {
	if value.Kind() == types.KindString {
		raw := value.GetString()
		switch strings.ToLower(raw) {
		case "false":
			return 0, nil
		case "true":
			return 1, nil
		}
		value.SetString(reg.FindString(raw))
	}
	return value.ToInt64()
}
예제 #8
0
// coerceArithmetic converts a into a numeric datum usable as an arithmetic
// operand:
//   - strings and bytes are parsed as float64 (MySQL semantics);
//   - time and duration values become int64 when Fsp is 0, otherwise a
//     decimal;
//   - hex, bit, enum and set values become float64 through ToNumber.
//
// Any other datum is returned as is. An error is returned only when a
// string cannot be parsed as a float.
func coerceArithmetic(a types.Datum) (d types.Datum, err error) {
	switch a.Kind() {
	case types.KindString, types.KindBytes:
		// MySQL will convert string to float for arithmetic operation
		f, parseErr := types.StrToFloat(a.GetString())
		if parseErr != nil {
			return d, errors.Trace(parseErr)
		}
		d.SetFloat64(f)
	case types.KindMysqlTime:
		// if time has no precision, return int64
		tm := a.GetMysqlTime()
		if num := tm.ToNumber(); tm.Fsp == 0 {
			d.SetInt64(num.IntPart())
		} else {
			d.SetMysqlDecimal(num)
		}
	case types.KindMysqlDuration:
		// if duration has no precision, return int64
		dur := a.GetMysqlDuration()
		if num := dur.ToNumber(); dur.Fsp == 0 {
			d.SetInt64(num.IntPart())
		} else {
			d.SetMysqlDecimal(num)
		}
	case types.KindMysqlHex:
		d.SetFloat64(a.GetMysqlHex().ToNumber())
	case types.KindMysqlBit:
		d.SetFloat64(a.GetMysqlBit().ToNumber())
	case types.KindMysqlEnum:
		d.SetFloat64(a.GetMysqlEnum().ToNumber())
	case types.KindMysqlSet:
		d.SetFloat64(a.GetMysqlSet().ToNumber())
	default:
		return a, nil
	}
	return d, nil
}
예제 #9
0
// computeMul multiplies a by b. Both operands must share a supported kind
// combination:
//   - int64 and uint64 in any mix; a mixed pair yields uint64;
//   - float64 with float64;
//   - decimal with decimal.
//
// Integer overflow errors from the types helpers are returned. Any other
// combination is reported as an invalid operation via types.InvOp2.
func computeMul(a, b types.Datum) (d types.Datum, err error) {
	switch ka, kb := a.Kind(), b.Kind(); {
	case ka == types.KindInt64 && kb == types.KindInt64:
		prod, mulErr := types.MulInt64(a.GetInt64(), b.GetInt64())
		d.SetInt64(prod)
		return d, errors.Trace(mulErr)
	case ka == types.KindInt64 && kb == types.KindUint64:
		// MulInteger takes the unsigned operand first.
		prod, mulErr := types.MulInteger(b.GetUint64(), a.GetInt64())
		d.SetUint64(prod)
		return d, errors.Trace(mulErr)
	case ka == types.KindUint64 && kb == types.KindInt64:
		prod, mulErr := types.MulInteger(a.GetUint64(), b.GetInt64())
		d.SetUint64(prod)
		return d, errors.Trace(mulErr)
	case ka == types.KindUint64 && kb == types.KindUint64:
		prod, mulErr := types.MulUint64(a.GetUint64(), b.GetUint64())
		d.SetUint64(prod)
		return d, errors.Trace(mulErr)
	case ka == types.KindFloat64 && kb == types.KindFloat64:
		d.SetFloat64(a.GetFloat64() * b.GetFloat64())
		return d, nil
	case ka == types.KindMysqlDecimal && kb == types.KindMysqlDecimal:
		d.SetMysqlDecimal(a.GetMysqlDecimal().Mul(b.GetMysqlDecimal()))
		return d, nil
	}

	// Unsupported kind combination: report the invalid operation.
	_, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Mul)
	return d, errors.Trace(err)
}
예제 #10
0
// computeMinus subtracts b from a. Both operands must share a supported kind
// combination:
//   - int64 and uint64 in any mix; a mixed pair yields uint64;
//   - float64 with float64;
//   - decimal with decimal.
//
// Integer overflow errors from the types helpers are returned. Any other
// combination is reported as an invalid operation via types.InvOp2.
func computeMinus(a, b types.Datum) (d types.Datum, err error) {
	switch ka, kb := a.Kind(), b.Kind(); {
	case ka == types.KindInt64 && kb == types.KindInt64:
		diff, subErr := types.SubInt64(a.GetInt64(), b.GetInt64())
		d.SetInt64(diff)
		return d, errors.Trace(subErr)
	case ka == types.KindInt64 && kb == types.KindUint64:
		diff, subErr := types.SubIntWithUint(a.GetInt64(), b.GetUint64())
		d.SetUint64(diff)
		return d, errors.Trace(subErr)
	case ka == types.KindUint64 && kb == types.KindInt64:
		diff, subErr := types.SubUintWithInt(a.GetUint64(), b.GetInt64())
		d.SetUint64(diff)
		return d, errors.Trace(subErr)
	case ka == types.KindUint64 && kb == types.KindUint64:
		diff, subErr := types.SubUint64(a.GetUint64(), b.GetUint64())
		d.SetUint64(diff)
		return d, errors.Trace(subErr)
	case ka == types.KindFloat64 && kb == types.KindFloat64:
		d.SetFloat64(a.GetFloat64() - b.GetFloat64())
		return d, nil
	case ka == types.KindMysqlDecimal && kb == types.KindMysqlDecimal:
		d.SetMysqlDecimal(a.GetMysqlDecimal().Sub(b.GetMysqlDecimal()))
		return d, nil
	}

	// Unsupported kind combination: report the invalid operation.
	_, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Minus)
	return d, errors.Trace(err)
}
예제 #11
0
// computeDiv evaluates the division operator `/` for a and b.
//
// MySQL supports integer division (DIV) and the division operator (/).
// opcode.Div is used for the division operator; integer division is handled
// separately (see computeIntDiv).
//
// When a is a float64, b is converted to float64 and the quotient is a
// float64. Otherwise both operands are converted to decimals and divided as
// decimals. Division by zero returns the zero-value datum d (NULL) without
// error. Conversion failures of either operand are returned as errors.
func computeDiv(a, b types.Datum) (d types.Datum, err error) {
	switch a.Kind() {
	case types.KindFloat64:
		y, err1 := b.ToFloat64()
		if err1 != nil {
			return d, errors.Trace(err1)
		}

		if y == 0 {
			// division by zero returns null
			return d, nil
		}

		x := a.GetFloat64()
		d.SetFloat64(x / y)
		return d, nil
	default:
		// MySQL sets the result scale to the scale of the first operand plus
		// div_precision_increment (4 by default).
		// NOTE(review): that scale is not applied here; the result scale is
		// whatever Decimal.Div produces.
		xa, err1 := a.ToDecimal()
		// Check err1, the error just returned; the named result err is always
		// nil at this point, so checking it would ignore a failed conversion.
		if err1 != nil {
			return d, errors.Trace(err1)
		}

		xb, err1 := b.ToDecimal()
		if err1 != nil {
			return d, errors.Trace(err1)
		}
		if f, _ := xb.Float64(); f == 0 {
			// division by zero returns null
			return d, nil
		}

		d.SetMysqlDecimal(xa.Div(xb))
		return d, nil
	}
}
예제 #12
0
// datumToPBExpr encodes the constant d as a tipb expression.
//
// It returns nil in two cases: d's kind has no encoding here, or client does
// not support the resulting expression type for select requests.
// Strings and bytes are used as their raw bytes; integers, floats,
// durations and decimals are encoded with the codec package.
func (b *executorBuilder) datumToPBExpr(client kv.Client, d types.Datum) *tipb.Expr {
	var (
		tp  tipb.ExprType
		val []byte
	)
	switch d.Kind() {
	case types.KindNull:
		tp = tipb.ExprType_Null
	case types.KindInt64:
		tp, val = tipb.ExprType_Int64, codec.EncodeInt(nil, d.GetInt64())
	case types.KindUint64:
		tp, val = tipb.ExprType_Uint64, codec.EncodeUint(nil, d.GetUint64())
	case types.KindString:
		tp, val = tipb.ExprType_String, d.GetBytes()
	case types.KindBytes:
		tp, val = tipb.ExprType_Bytes, d.GetBytes()
	case types.KindFloat32:
		tp, val = tipb.ExprType_Float32, codec.EncodeFloat(nil, d.GetFloat64())
	case types.KindFloat64:
		tp, val = tipb.ExprType_Float64, codec.EncodeFloat(nil, d.GetFloat64())
	case types.KindMysqlDuration:
		tp, val = tipb.ExprType_MysqlDuration, codec.EncodeInt(nil, int64(d.GetMysqlDuration().Duration))
	case types.KindMysqlDecimal:
		tp, val = tipb.ExprType_MysqlDecimal, codec.EncodeDecimal(nil, d.GetMysqlDecimal())
	default:
		return nil
	}
	if client.SupportRequestType(kv.ReqTypeSelect, int64(tp)) {
		return &tipb.Expr{Tp: tp.Enum(), Val: val}
	}
	return nil
}
예제 #13
0
파일: range.go 프로젝트: astaxie/tidb
// buildFromBinop converts a binary operation into range points.
//
// OR and AND combine the ranges built from both operands by union and
// intersection. For a comparison between a column and a constant, it returns
// the points bounding the constant. When the constant is on the left, the
// comparison is mirrored (for example `1 < c` is treated as `c > 1`). A NULL
// constant, or an operator without a range form, yields nil.
func (r *rangeBuilder) buildFromBinop(x *ast.BinaryOperationExpr) []rangePoint {
	switch x.Op {
	case opcode.OrOr:
		return r.union(r.build(x.L), r.build(x.R))
	case opcode.AndAnd:
		return r.intersection(r.build(x.L), r.build(x.R))
	}

	// This has been checked that the binary operation is comparison operation, and one of
	// the operand is column name expression.
	var (
		value types.Datum
		op    = x.Op
	)
	if _, constOnLeft := x.L.(*ast.ValueExpr); constOnLeft {
		value = types.NewDatum(x.L.GetValue())
		// Mirror the comparison so it reads as `column op value`.
		switch x.Op {
		case opcode.GE:
			op = opcode.LE
		case opcode.GT:
			op = opcode.LT
		case opcode.LT:
			op = opcode.GT
		case opcode.LE:
			op = opcode.GE
		}
	} else {
		value = types.NewDatum(x.R.GetValue())
	}

	if value.Kind() == types.KindNull {
		return nil
	}

	var lo, hi rangePoint
	switch op {
	case opcode.EQ:
		lo = rangePoint{value: value, start: true}
		hi = rangePoint{value: value}
	case opcode.NE:
		// Two disjoint ranges: [min not-null, value) and (value, max value].
		return []rangePoint{
			{value: types.MinNotNullDatum(), start: true},
			{value: value, excl: true},
			{value: value, start: true, excl: true},
			{value: types.MaxValueDatum()},
		}
	case opcode.LT:
		lo = rangePoint{value: types.MinNotNullDatum(), start: true}
		hi = rangePoint{value: value, excl: true}
	case opcode.LE:
		lo = rangePoint{value: types.MinNotNullDatum(), start: true}
		hi = rangePoint{value: value}
	case opcode.GT:
		lo = rangePoint{value: value, start: true, excl: true}
		hi = rangePoint{value: types.MaxValueDatum()}
	case opcode.GE:
		lo = rangePoint{value: value, start: true}
		hi = rangePoint{value: types.MaxValueDatum()}
	default:
		return nil
	}
	return []rangePoint{lo, hi}
}
예제 #14
0
파일: tablecodec.go 프로젝트: XuHuaiyu/tidb
// Unflatten converts a raw datum to a column datum.
//
// The conversion is chosen by ft.Tp:
//   - Floats are narrowed to float32.
//   - Dates, datetimes and timestamps are rebuilt from their packed uint64
//     form, with ft.Decimal used as Fsp.
//   - Durations are rebuilt from int64.
//   - Enum and set values are resolved against ft.Elems.
//   - Bits use ft.Flen as their width.
//   - Decimals are truncated to ft.Decimal digits when ft.Decimal >= 0. A
//     datum that is already a decimal is truncated in place; otherwise it is
//     parsed from its string form first.
//
// NULL datums and all other types are returned unchanged. inIndex is
// currently unused. On a parse error the input datum is returned together
// with the traced error.
func Unflatten(datum types.Datum, ft *types.FieldType, inIndex bool) (types.Datum, error) {
	if datum.IsNull() {
		return datum, nil
	}
	switch ft.Tp {
	case mysql.TypeFloat:
		datum.SetFloat32(float32(datum.GetFloat64()))
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeYear, mysql.TypeInt24,
		mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeTinyBlob,
		mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob, mysql.TypeVarchar,
		mysql.TypeString:
		// Already in column form; nothing to convert.
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		var tm mysql.Time
		tm.Type = ft.Tp
		tm.Fsp = ft.Decimal
		if err := tm.FromPackedUint(datum.GetUint64()); err != nil {
			return datum, errors.Trace(err)
		}
		datum.SetMysqlTime(tm)
	case mysql.TypeDuration:
		datum.SetValue(mysql.Duration{Duration: time.Duration(datum.GetInt64())})
	case mysql.TypeNewDecimal:
		if datum.Kind() != types.KindMysqlDecimal {
			dec, err := mysql.ParseDecimal(datum.GetString())
			if err != nil {
				return datum, errors.Trace(err)
			}
			if ft.Decimal >= 0 {
				dec = dec.Truncate(int32(ft.Decimal))
			}
			datum.SetValue(dec)
		} else if ft.Decimal >= 0 {
			datum.SetMysqlDecimal(datum.GetMysqlDecimal().Truncate(int32(ft.Decimal)))
		}
	case mysql.TypeEnum:
		enum, err := mysql.ParseEnumValue(ft.Elems, datum.GetUint64())
		if err != nil {
			return datum, errors.Trace(err)
		}
		datum.SetValue(enum)
	case mysql.TypeSet:
		set, err := mysql.ParseSetValue(ft.Elems, datum.GetUint64())
		if err != nil {
			return datum, errors.Trace(err)
		}
		datum.SetValue(set)
	case mysql.TypeBit:
		datum.SetValue(mysql.Bit{Value: datum.GetUint64(), Width: ft.Flen})
	}
	return datum, nil
}
예제 #15
0
// updateRecord locks the row with handle h in table t and updates it from
// oldData to newData.
//
// updateColumns maps positions in newData to their assignments. Nil entries
// are ignored. Only positions in [offset, offset+len(t.Cols())) belong to
// this table; the others target another table. Every assigned column of this
// table is recorded in touched, keyed by column index.
//   - Assigning an auto-increment column rejects NULL and rebases the
//     table's auto ID to the new value.
//   - Assigning a non-NULL value to the PK handle column replaces the row
//     with RemoveRecord + AddRecord; otherwise t.UpdateRecord is used.
//
// Nothing is written when no column of this table is assigned, or when no
// touched column changed value. In the second case one affected row is
// still counted if the client set CLIENT_FOUND_ROWS. A written update counts
// one affected row, or two when onDuplicateUpdate is true.
func updateRecord(ctx context.Context, h int64, oldData, newData []types.Datum, updateColumns map[int]*ast.Assignment, t table.Table, offset int, onDuplicateUpdate bool) error {
	if err := t.LockRow(ctx, h, false); err != nil {
		return errors.Trace(err)
	}

	cols := t.Cols()
	touched := make(map[int]bool, len(cols))

	assignExists := false
	var newHandle types.Datum
	for i, asgn := range updateColumns {
		if asgn == nil {
			continue
		}
		if i < offset || i >= offset+len(cols) {
			// The assign expression is for another table, not this.
			continue
		}

		colIndex := i - offset
		col := cols[colIndex]
		if col.IsPKHandleColumn(t.Meta()) {
			newHandle = newData[i]
		}
		if mysql.HasAutoIncrementFlag(col.Flag) {
			if newData[i].Kind() == types.KindNull {
				return errors.Errorf("Column '%v' cannot be null", col.Name.O)
			}
			val, err := newData[i].ToInt64()
			if err != nil {
				return errors.Trace(err)
			}
			// Propagate a rebase failure instead of silently ignoring it.
			if err := t.RebaseAutoID(val, true); err != nil {
				return errors.Trace(err)
			}
		}

		touched[colIndex] = true
		assignExists = true
	}

	// If no assign list for this table, no need to update.
	if !assignExists {
		return nil
	}

	// Check whether new value is valid.
	if err := column.CastValues(ctx, newData, cols); err != nil {
		return errors.Trace(err)
	}

	if err := column.CheckNotNull(cols, newData); err != nil {
		return errors.Trace(err)
	}

	// If row is not changed, we should do nothing.
	rowChanged := false
	for i := range oldData {
		if !touched[i] {
			continue
		}

		n, err := newData[i].CompareDatum(oldData[i])
		if err != nil {
			return errors.Trace(err)
		}
		if n != 0 {
			rowChanged = true
			break
		}
	}
	if !rowChanged {
		// See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html  CLIENT_FOUND_ROWS
		if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 {
			variable.GetSessionVars(ctx).AddAffectedRows(1)
		}
		return nil
	}

	var err error
	if newHandle.Kind() != types.KindNull {
		// The handle changed: replace the row rather than updating in place.
		err = t.RemoveRecord(ctx, h, oldData)
		if err != nil {
			return errors.Trace(err)
		}
		_, err = t.AddRecord(ctx, newData)
	} else {
		// Update record to new value and update index.
		err = t.UpdateRecord(ctx, h, oldData, newData, touched)
	}
	if err != nil {
		return errors.Trace(err)
	}

	// Record affected rows.
	if !onDuplicateUpdate {
		variable.GetSessionVars(ctx).AddAffectedRows(1)
	} else {
		variable.GetSessionVars(ctx).AddAffectedRows(2)
	}
	return nil
}
예제 #16
0
파일: column.go 프로젝트: astaxie/tidb
// CheckNotNull returns an error when data is NULL and the column has the
// NOT NULL flag. Otherwise it returns nil.
func (c *Col) CheckNotNull(data types.Datum) error {
	if !mysql.HasNotNullFlag(c.Flag) || data.Kind() != types.KindNull {
		return nil
	}
	return errors.Errorf("Column %s can't be null.", c.Name)
}
예제 #17
0
// computeIntDiv evaluates integer division (DIV) of a by b.
//
// When both operands are int64/uint64, the quotient comes from the types
// helpers (DivInt64, DivIntWithUint, DivUintWithInt) or from plain uint64
// division. A mixed signed/unsigned pair yields a uint64 result. Any other
// combination converts both operands to decimals, divides them, and keeps
// the integer part as int64. A zero divisor returns the zero-value datum d
// (NULL) without error.
func computeIntDiv(a, b types.Datum) (d types.Datum, err error) {
	switch ka, kb := a.Kind(), b.Kind(); {
	case ka == types.KindInt64 && kb == types.KindInt64:
		dividend, divisor := a.GetInt64(), b.GetInt64()
		if divisor == 0 {
			return d, nil
		}
		q, divErr := types.DivInt64(dividend, divisor)
		d.SetInt64(q)
		return d, errors.Trace(divErr)
	case ka == types.KindInt64 && kb == types.KindUint64:
		dividend, divisor := a.GetInt64(), b.GetUint64()
		if divisor == 0 {
			return d, nil
		}
		q, divErr := types.DivIntWithUint(dividend, divisor)
		d.SetUint64(q)
		return d, errors.Trace(divErr)
	case ka == types.KindUint64 && kb == types.KindInt64:
		dividend, divisor := a.GetUint64(), b.GetInt64()
		if divisor == 0 {
			return d, nil
		}
		q, divErr := types.DivUintWithInt(dividend, divisor)
		d.SetUint64(q)
		return d, errors.Trace(divErr)
	case ka == types.KindUint64 && kb == types.KindUint64:
		dividend, divisor := a.GetUint64(), b.GetUint64()
		if divisor == 0 {
			return d, nil
		}
		d.SetUint64(dividend / divisor)
		return d, nil
	}

	// At least one operand is not an integer: divide as decimals.
	dx, err := a.ToDecimal()
	if err != nil {
		return d, errors.Trace(err)
	}
	dy, err := b.ToDecimal()
	if err != nil {
		return d, errors.Trace(err)
	}
	if f, _ := dy.Float64(); f == 0 {
		return d, nil
	}
	d.SetInt64(dx.Div(dy).IntPart())
	return d, nil
}
예제 #18
0
// computeMod evaluates the modulo of a by b.
//
// Integer operands are handled natively; the result has the sign of a. When
// a and b differ in signedness, the result takes a's integer type. Float64
// pairs use math.Mod. Decimal pairs are converted to float64 and use
// math.Mod, so the result is a float64.
// NOTE(review): converting decimals to float64 can lose precision.
//
// A zero divisor returns the zero-value datum d (NULL) without error. Any
// other kind combination is reported as an invalid operation via
// types.InvOp2.
func computeMod(a, b types.Datum) (d types.Datum, err error) {
	switch ka, kb := a.Kind(), b.Kind(); {
	case ka == types.KindInt64 && kb == types.KindInt64:
		x, y := a.GetInt64(), b.GetInt64()
		if y == 0 {
			return d, nil
		}
		d.SetInt64(x % y)
		return d, nil
	case ka == types.KindInt64 && kb == types.KindUint64:
		x, y := a.GetInt64(), b.GetUint64()
		if y == 0 {
			return d, nil
		}
		// first is int64, return int64; work on |x| in uint64 space.
		if x < 0 {
			d.SetInt64(-int64(uint64(-x) % y))
		} else {
			d.SetInt64(int64(uint64(x) % y))
		}
		return d, nil
	case ka == types.KindUint64 && kb == types.KindInt64:
		x, y := a.GetUint64(), b.GetInt64()
		if y == 0 {
			return d, nil
		}
		// first is uint64, return uint64; only |y| matters.
		if y < 0 {
			d.SetUint64(x % uint64(-y))
		} else {
			d.SetUint64(x % uint64(y))
		}
		return d, nil
	case ka == types.KindUint64 && kb == types.KindUint64:
		x, y := a.GetUint64(), b.GetUint64()
		if y == 0 {
			return d, nil
		}
		d.SetUint64(x % y)
		return d, nil
	case ka == types.KindFloat64 && kb == types.KindFloat64:
		x, y := a.GetFloat64(), b.GetFloat64()
		if y == 0 {
			return d, nil
		}
		d.SetFloat64(math.Mod(x, y))
		return d, nil
	case ka == types.KindMysqlDecimal && kb == types.KindMysqlDecimal:
		xf, _ := a.GetMysqlDecimal().Float64()
		yf, _ := b.GetMysqlDecimal().Float64()
		if yf == 0 {
			return d, nil
		}
		d.SetFloat64(math.Mod(xf, yf))
		return d, nil
	}

	// Unsupported kind combination: report the invalid operation.
	_, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Mod)
	return d, errors.Trace(err)
}