コード例 #1
0
ファイル: groupby.go プロジェクト: H0bby/tidb
// calculateSum adds v to the running aggregate sum and returns the new sum.
//
// It is shared by SUM and AVG. Integer and mysql.Decimal inputs are
// accumulated as mysql.Decimal; every other input is converted with
// types.ToFloat64 and accumulated as float64
// (see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html).
//
// sum is the current accumulator: nil before the first value, otherwise a
// float64 or mysql.Decimal returned by an earlier call. When the accumulator
// and the converted value differ (one decimal, one float), the result is
// promoted to float64 rather than panicking on a failed type assertion.
// Any other accumulator type is reported as an error.
func calculateSum(sum interface{}, v interface{}) (interface{}, error) {
	var (
		data interface{}
		err  error
	)

	switch y := v.(type) {
	case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
		data, err = mysql.ConvertToDecimal(v)
	case mysql.Decimal:
		data = y
	default:
		data, err = types.ToFloat64(v)
	}
	if err != nil {
		return nil, err
	}

	// At this point data is either a mysql.Decimal or a float64.
	switch x := sum.(type) {
	case nil:
		return data, nil
	case float64:
		if f, ok := data.(float64); ok {
			return x + f, nil
		}
		// Earlier values were floats but this one is a Decimal: add it as a float.
		f, convErr := types.ToFloat64(data)
		if convErr != nil {
			return nil, convErr
		}
		return x + f, nil
	case mysql.Decimal:
		if d, ok := data.(mysql.Decimal); ok {
			return x.Add(d), nil
		}
		// A float joins a decimal sum: promote the accumulator to float64.
		acc, convErr := types.ToFloat64(x)
		if convErr != nil {
			return nil, convErr
		}
		f, convErr := types.ToFloat64(data)
		if convErr != nil {
			return nil, convErr
		}
		return acc + f, nil
	default:
		return nil, errors.Errorf("invalid value %v(%T) for aggregate", x, x)
	}
}
コード例 #2
0
ファイル: etc.go プロジェクト: 52Jolynn/tidb
// Coerce converts a and b to a common numeric representation.
//
// Both operands are first normalized by convergeType, which also records
// whether either of them is a Decimal or a Float:
//   - if a Decimal was seen, each result is converted with
//     mysql.ConvertToDecimal (a value that fails to convert is kept as-is);
//   - otherwise, if a Float was seen, int64/uint64 become float64 and
//     Hex/Bit/Enum/Set are replaced by their ToNumber() value;
//   - any other value is returned unchanged.
func Coerce(a, b interface{}) (x, y interface{}) {
	var sawDecimal, sawFloat bool
	x = convergeType(a, &sawDecimal, &sawFloat)
	y = convergeType(b, &sawDecimal, &sawFloat)

	// asDecimal is best-effort: a conversion error leaves the operand untouched.
	asDecimal := func(val interface{}) interface{} {
		if d, err := mysql.ConvertToDecimal(val); err == nil {
			return d
		}
		return val
	}

	asFloat := func(val interface{}) interface{} {
		switch n := val.(type) {
		case int64:
			return float64(n)
		case uint64:
			return float64(n)
		case mysql.Hex:
			return n.ToNumber()
		case mysql.Bit:
			return n.ToNumber()
		case mysql.Enum:
			return n.ToNumber()
		case mysql.Set:
			return n.ToNumber()
		}
		return val
	}

	switch {
	case sawDecimal:
		x, y = asDecimal(x), asDecimal(y)
	case sawFloat:
		x, y = asFloat(x), asFloat(y)
	}
	return x, y
}
コード例 #3
0
ファイル: convert.go プロジェクト: H0bby/tidb
// ToDecimal converts value to a mysql.Decimal.
//
// mysql.Time and mysql.Duration use their own ToNumber conversion, byte
// slices are converted via their string form, and booleans become 1 (true)
// or 0 (false). Every other value is handed to mysql.ConvertToDecimal.
func ToDecimal(value interface{}) (mysql.Decimal, error) {
	switch v := value.(type) {
	case mysql.Time:
		return v.ToNumber(), nil
	case mysql.Duration:
		return v.ToNumber(), nil
	case []byte:
		return mysql.ConvertToDecimal(string(v))
	case bool:
		n := 0
		if v {
			n = 1
		}
		return mysql.ConvertToDecimal(n)
	}
	return mysql.ConvertToDecimal(value)
}
コード例 #4
0
ファイル: functions_test.go プロジェクト: anywhy/tidb
// TestAggFuncSum feeds the values (x, NULL, x) into a SUM aggregate and
// checks the resulting decimal: with DISTINCT the duplicate is counted once
// (1, NULL, 1 -> 1), without DISTINCT it is not (2, NULL, 2 -> 4).
func (ts *testFunctionsSuite) TestAggFuncSum(c *C) {
	args := make([]ExprNode, 1)
	// sum with distinct
	agg := &AggregateFuncExpr{
		Args:     args,
		F:        AggFuncSum,
		Distinct: true,
	}
	agg.CurrentGroup = "xx"
	expr := NewValueExpr(1)
	expr1 := NewValueExpr(nil)
	expr2 := NewValueExpr(1)
	exprs := []ExprNode{expr, expr1, expr2}
	for _, e := range exprs {
		// agg shares the args slice, so overwriting args[0] feeds the next value.
		args[0] = e
		agg.Update()
	}
	ctx := agg.GetContext()
	expect, err := mysql.ConvertToDecimal(1)
	c.Assert(err, IsNil)
	v, ok := ctx.Value.(mysql.Decimal)
	c.Assert(ok, IsTrue)
	c.Assert(v.Equals(expect), IsTrue)

	// sum without distinct
	agg = &AggregateFuncExpr{
		Args: args,
		F:    AggFuncSum,
	}
	agg.CurrentGroup = "xx"
	expr = NewValueExpr(2)
	expr1 = NewValueExpr(nil)
	expr2 = NewValueExpr(2)
	exprs = []ExprNode{expr, expr1, expr2}
	for _, e := range exprs {
		args[0] = e
		agg.Update()
	}
	ctx = agg.GetContext()
	expect, err = mysql.ConvertToDecimal(4)
	c.Assert(err, IsNil)
	v, ok = ctx.Value.(mysql.Decimal)
	c.Assert(ok, IsTrue)
	c.Assert(v.Equals(expect), IsTrue)
}
コード例 #5
0
ファイル: evaluator_test.go プロジェクト: mumubusu/tidb
// TestAggFuncAvg checks that AVG over an empty group evaluates to nil and
// that AVG over the values 2 and 4 evaluates to the decimal 3.
func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) {
	ctx := mock.NewContext()
	avg := &ast.AggregateFuncExpr{
		F: ast.AggFuncAvg,
	}
	avg.CurrentGroup = "emptyGroup"
	result, err := Eval(ctx, avg)
	c.Assert(err, IsNil)
	// Empty group should return nil.
	c.Assert(result, IsNil)

	avg.Args = []ast.ExprNode{ast.NewValueExpr(2)}
	avg.Update()
	avg.Args = []ast.ExprNode{ast.NewValueExpr(4)}
	avg.Update()

	result, err = Eval(ctx, avg)
	c.Assert(err, IsNil)
	expect, err := mysql.ConvertToDecimal(3)
	c.Assert(err, IsNil)
	v, ok := result.(mysql.Decimal)
	c.Assert(ok, IsTrue)
	c.Assert(v.Equals(expect), IsTrue)
}
コード例 #6
0
ファイル: evaluator_test.go プロジェクト: XuHuaiyu/tidb
// TestAggFuncAvg checks that AVG over an empty group evaluates to a NULL
// datum and that AVG over the values 2 and 4 evaluates to the decimal 3.
func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) {
	defer testleak.AfterTest(c)()
	ctx := mock.NewContext()
	avg := &ast.AggregateFuncExpr{
		F: ast.AggFuncAvg,
	}
	avg.CurrentGroup = []byte("emptyGroup")
	ast.SetFlag(avg)
	result, err := Eval(ctx, avg)
	c.Assert(err, IsNil)
	// Empty group should return nil.
	c.Assert(result.Kind(), Equals, types.KindNull)

	avg.Args = []ast.ExprNode{ast.NewValueExpr(2)}
	avg.Update()
	avg.Args = []ast.ExprNode{ast.NewValueExpr(4)}
	avg.Update()

	result, err = Eval(ctx, avg)
	c.Assert(err, IsNil)
	expect, err := mysql.ConvertToDecimal(3)
	c.Assert(err, IsNil)
	c.Assert(result.Kind(), Equals, types.KindMysqlDecimal)
	c.Assert(result.GetMysqlDecimal().Equals(expect), IsTrue)
}
コード例 #7
0
ファイル: codec_test.go プロジェクト: yzl11/vessel
// TestDecimal checks two properties of decimal key encoding:
//  1. every textual decimal in the round-trip set survives
//     EncodeKey followed by DecodeKey unchanged;
//  2. for each pair built with mysql.ConvertToDecimal, bytes.Compare on the
//     encoded keys yields the expected ordering.
func (s *testCodecSuite) TestDecimal(c *C) {
	roundTrip := []string{
		"1234.00",
		"1234",
		"12.34",
		"12.340",
		"0.1234",
		"0.0",
		"0",
		"-0.0",
		"-0.0000",
		"-1234.00",
		"-1234",
		"-12.34",
		"-12.340",
		"-0.1234",
	}

	for _, text := range roundTrip {
		want, err := mysql.ParseDecimal(text)
		c.Assert(err, IsNil)
		key, err := EncodeKey(want)
		c.Assert(err, IsNil)
		decoded, err := DecodeKey(key)
		c.Assert(err, IsNil)
		c.Assert(decoded, HasLen, 1)
		got, ok := decoded[0].(mysql.Decimal)
		c.Assert(ok, IsTrue)
		c.Assert(got.Equals(want), IsTrue)
	}

	// encodeAsKey converts arg to a decimal and returns its encoded key.
	encodeAsKey := func(arg interface{}) []byte {
		dec, err := mysql.ConvertToDecimal(arg)
		c.Assert(err, IsNil)
		key, err := EncodeKey(dec)
		c.Assert(err, IsNil)
		return key
	}

	orderCases := []struct {
		lhs, rhs interface{}
		want     int
	}{
		// Test for float type decimal.
		{"1234", "123400", -1},
		{"12340", "123400", -1},
		{"1234", "1234.5", -1},
		{"1234", "1234.0000", 0},
		{"1234", "12.34", 1},
		{"12.34", "12.35", -1},
		{"0.1234", "12.3400", -1},
		{"0.1234", "0.1235", -1},
		{"0.123400", "12.34", -1},
		{"12.34000", "12.34", 0},
		{"0.01234", "0.01235", -1},
		{"0.1234", "0", 1},
		{"0.0000", "0", 0},
		{"0.0001", "0", 1},
		{"0.0001", "0.0000", 1},
		{"0", "-0.0000", 0},
		{"-0.0001", "0", -1},
		{"-0.1234", "0", -1},
		{"-0.1234", "0.1234", -1},
		{"-1.234", "-12.34", 1},
		{"-0.1234", "-12.34", 1},
		{"-12.34", "1234", -1},
		{"-12.34", "-12.35", 1},
		{"-0.01234", "-0.01235", 1},
		{"-1234", "-123400", 1},
		{"-12340", "-123400", 1},

		// Test for int type decimal.
		{int64(-1), int64(1), -1},
		{int64(math.MaxInt64), int64(math.MinInt64), 1},
		{int64(math.MaxInt64), int64(math.MaxInt32), 1},
		{int64(math.MinInt32), int64(math.MaxInt16), -1},
		{int64(math.MinInt64), int64(math.MaxInt8), -1},
		{int64(0), int64(math.MaxInt8), -1},
		{int64(math.MinInt8), int64(0), -1},
		{int64(math.MinInt16), int64(math.MaxInt16), -1},
		{int64(1), int64(-1), 1},
		{int64(1), int64(0), 1},
		{int64(-1), int64(0), -1},
		{int64(0), int64(0), 0},
		{int64(math.MaxInt16), int64(math.MaxInt16), 0},

		// Test for uint type decimal.
		{uint64(0), uint64(0), 0},
		{uint64(1), uint64(0), 1},
		{uint64(0), uint64(1), -1},
		{uint64(math.MaxInt8), uint64(math.MaxInt16), -1},
		{uint64(math.MaxUint32), uint64(math.MaxInt32), 1},
		{uint64(math.MaxUint8), uint64(math.MaxInt8), 1},
		{uint64(math.MaxUint16), uint64(math.MaxInt32), -1},
		{uint64(math.MaxUint64), uint64(math.MaxInt64), 1},
		{uint64(math.MaxInt64), uint64(math.MaxUint32), 1},
		{uint64(math.MaxUint64), uint64(0), 1},
		{uint64(0), uint64(math.MaxUint64), -1},
	}

	for _, oc := range orderCases {
		lhsKey := encodeAsKey(oc.lhs)
		rhsKey := encodeAsKey(oc.rhs)
		c.Assert(bytes.Compare(lhsKey, rhsKey), Equals, oc.want)
	}
}
コード例 #8
0
ファイル: compare.go プロジェクト: H0bby/tidb
// Compare returns an integer comparing the interface a with b.
//
//	a > b -> 1
//	a = b -> 0
//	a < b -> -1
//
// The operands are first passed through coerceCompare. Rows ([]interface{})
// are compared by compareRow, and nil sorts before any non-nil value. All
// other operands are compared by dispatching on their concrete type pair; a
// pair with no matching case yields an error.
func Compare(a, b interface{}) (int, error) {
	var coerceErr error
	a, b, coerceErr = coerceCompare(a, b)
	if coerceErr != nil {
		return 0, errors.Trace(coerceErr)
	}

	if va, ok := a.([]interface{}); ok {
		// we guarantee in coerceCompare that a and b are both []interface{}
		vb := b.([]interface{})
		return compareRow(va, vb)
	}

	// Check nil first: nil is always less than any non-nil value.
	switch {
	case a == nil && b == nil:
		return 0, nil
	case a == nil:
		return -1, nil
	case b == nil:
		return 1, nil
	}

	// TODO: support compare time type with other int, float, decimal types.
	switch x := a.(type) {
	case float64:
		switch y := b.(type) {
		case float64:
			return CompareFloat64(x, y), nil
		case string:
			return compareFloatString(x, y)
		}
	case int64:
		switch y := b.(type) {
		case int64:
			return CompareInt64(x, y), nil
		case uint64:
			return CompareInteger(x, y), nil
		case string:
			return compareFloatString(float64(x), y)
		case mysql.Hex:
			return CompareFloat64(float64(x), y.ToNumber()), nil
		case mysql.Bit:
			return CompareFloat64(float64(x), y.ToNumber()), nil
		case mysql.Enum:
			return CompareFloat64(float64(x), y.ToNumber()), nil
		case mysql.Set:
			return CompareFloat64(float64(x), y.ToNumber()), nil
		}
	case uint64:
		switch y := b.(type) {
		case uint64:
			return CompareUint64(x, y), nil
		case int64:
			// CompareInteger takes (int64, uint64), so swap and negate.
			return -CompareInteger(y, x), nil
		case string:
			return compareFloatString(float64(x), y)
		case mysql.Hex:
			return CompareFloat64(float64(x), y.ToNumber()), nil
		case mysql.Bit:
			return CompareFloat64(float64(x), y.ToNumber()), nil
		case mysql.Enum:
			return CompareFloat64(float64(x), y.ToNumber()), nil
		case mysql.Set:
			return CompareFloat64(float64(x), y.ToNumber()), nil
		}
	case mysql.Decimal:
		switch y := b.(type) {
		case mysql.Decimal:
			return x.Cmp(y), nil
		case string:
			f, err := mysql.ConvertToDecimal(y)
			if err != nil {
				return 0, errors.Trace(err)
			}
			return x.Cmp(f), nil
		}
	case string:
		switch y := b.(type) {
		case string:
			return CompareString(x, y), nil
		case int64:
			return compareStringFloat(x, float64(y))
		case uint64:
			return compareStringFloat(x, float64(y))
		case float64:
			return compareStringFloat(x, y)
		case mysql.Decimal:
			f, err := mysql.ConvertToDecimal(x)
			if err != nil {
				return 0, errors.Trace(err)
			}
			return f.Cmp(y), nil
		case mysql.Time:
			n, err := y.CompareString(x)
			return -n, errors.Trace(err)
		case mysql.Duration:
			n, err := y.CompareString(x)
			return -n, errors.Trace(err)
		case mysql.Hex:
			return CompareString(x, y.ToString()), nil
		case mysql.Bit:
			return CompareString(x, y.ToString()), nil
		case mysql.Enum:
			return CompareString(x, y.String()), nil
		case mysql.Set:
			return CompareString(x, y.String()), nil
		}
	case mysql.Time:
		switch y := b.(type) {
		case mysql.Time:
			return x.Compare(y), nil
		case string:
			return x.CompareString(y)
		}
	case mysql.Duration:
		switch y := b.(type) {
		case mysql.Duration:
			return x.Compare(y), nil
		case string:
			return x.CompareString(y)
		}
	case mysql.Hex:
		switch y := b.(type) {
		case int64:
			return CompareFloat64(x.ToNumber(), float64(y)), nil
		case uint64:
			return CompareFloat64(x.ToNumber(), float64(y)), nil
		case mysql.Bit:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case string:
			return CompareString(x.ToString(), y), nil
		case mysql.Enum:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case mysql.Set:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case mysql.Hex:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		}
	case mysql.Bit:
		switch y := b.(type) {
		case int64:
			return CompareFloat64(x.ToNumber(), float64(y)), nil
		case uint64:
			return CompareFloat64(x.ToNumber(), float64(y)), nil
		case mysql.Hex:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case string:
			return CompareString(x.ToString(), y), nil
		case mysql.Enum:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case mysql.Set:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case mysql.Bit:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		}
	case mysql.Enum:
		switch y := b.(type) {
		case int64:
			return CompareFloat64(x.ToNumber(), float64(y)), nil
		case uint64:
			return CompareFloat64(x.ToNumber(), float64(y)), nil
		case mysql.Hex:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case string:
			return CompareString(x.String(), y), nil
		case mysql.Bit:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case mysql.Set:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case mysql.Enum:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		}
	case mysql.Set:
		switch y := b.(type) {
		case int64:
			return CompareFloat64(x.ToNumber(), float64(y)), nil
		case uint64:
			return CompareFloat64(x.ToNumber(), float64(y)), nil
		case mysql.Hex:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case string:
			return CompareString(x.String(), y), nil
		case mysql.Bit:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case mysql.Enum:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		case mysql.Set:
			return CompareFloat64(x.ToNumber(), y.ToNumber()), nil
		}
	}

	return 0, errors.Errorf("invalid compare type %T cmp %T", a, b)
}
コード例 #9
0
ファイル: codec_test.go プロジェクト: anywhy/tidb
// TestDecimal checks three properties of decimal encoding:
//  1. every textual decimal in the round-trip set survives
//     EncodeKey followed by Decode unchanged;
//  2. for each pair built with mysql.ConvertToDecimal, bytes.Compare on the
//     encoded keys yields the expected ordering;
//  3. EncodeDecimal output for an ascending list of floats is non-decreasing
//     under bytes.Compare.
func (s *testCodecSuite) TestDecimal(c *C) {
	defer testleak.AfterTest(c)()
	roundTrip := []string{
		"1234.00",
		"1234",
		"12.34",
		"12.340",
		"0.1234",
		"0.0",
		"0",
		"-0.0",
		"-0.0000",
		"-1234.00",
		"-1234",
		"-12.34",
		"-12.340",
		"-0.1234",
	}

	for _, text := range roundTrip {
		want, err := mysql.ParseDecimal(text)
		c.Assert(err, IsNil)
		key, err := EncodeKey(nil, types.NewDatum(want))
		c.Assert(err, IsNil)
		decoded, err := Decode(key)
		c.Assert(err, IsNil)
		c.Assert(decoded, HasLen, 1)
		got := decoded[0].GetMysqlDecimal()
		c.Assert(got.Equals(want), IsTrue)
	}

	// encodeAsKey converts arg to a decimal datum and returns its encoded key.
	encodeAsKey := func(arg interface{}) []byte {
		dec, err := mysql.ConvertToDecimal(arg)
		c.Assert(err, IsNil)
		key, err := EncodeKey(nil, types.NewDatum(dec))
		c.Assert(err, IsNil)
		return key
	}

	orderCases := []struct {
		lhs, rhs interface{}
		want     int
	}{
		// Test for float type decimal.
		{"1234", "123400", -1},
		{"12340", "123400", -1},
		{"1234", "1234.5", -1},
		{"1234", "1234.0000", 0},
		{"1234", "12.34", 1},
		{"12.34", "12.35", -1},
		{"0.12", "0.1234", -1},
		{"0.1234", "12.3400", -1},
		{"0.1234", "0.1235", -1},
		{"0.123400", "12.34", -1},
		{"12.34000", "12.34", 0},
		{"0.01234", "0.01235", -1},
		{"0.1234", "0", 1},
		{"0.0000", "0", 0},
		{"0.0001", "0", 1},
		{"0.0001", "0.0000", 1},
		{"0", "-0.0000", 0},
		{"-0.0001", "0", -1},
		{"-0.1234", "0", -1},
		{"-0.1234", "-0.12", -1},
		{"-0.12", "-0.1234", 1},
		{"-0.12", "-0.1200", 0},
		{"-0.1234", "0.1234", -1},
		{"-1.234", "-12.34", 1},
		{"-0.1234", "-12.34", 1},
		{"-12.34", "1234", -1},
		{"-12.34", "-12.35", 1},
		{"-0.01234", "-0.01235", 1},
		{"-1234", "-123400", 1},
		{"-12340", "-123400", 1},

		// Test for int type decimal.
		{int64(-1), int64(1), -1},
		{int64(math.MaxInt64), int64(math.MinInt64), 1},
		{int64(math.MaxInt64), int64(math.MaxInt32), 1},
		{int64(math.MinInt32), int64(math.MaxInt16), -1},
		{int64(math.MinInt64), int64(math.MaxInt8), -1},
		{int64(0), int64(math.MaxInt8), -1},
		{int64(math.MinInt8), int64(0), -1},
		{int64(math.MinInt16), int64(math.MaxInt16), -1},
		{int64(1), int64(-1), 1},
		{int64(1), int64(0), 1},
		{int64(-1), int64(0), -1},
		{int64(0), int64(0), 0},
		{int64(math.MaxInt16), int64(math.MaxInt16), 0},

		// Test for uint type decimal.
		{uint64(0), uint64(0), 0},
		{uint64(1), uint64(0), 1},
		{uint64(0), uint64(1), -1},
		{uint64(math.MaxInt8), uint64(math.MaxInt16), -1},
		{uint64(math.MaxUint32), uint64(math.MaxInt32), 1},
		{uint64(math.MaxUint8), uint64(math.MaxInt8), 1},
		{uint64(math.MaxUint16), uint64(math.MaxInt32), -1},
		{uint64(math.MaxUint64), uint64(math.MaxInt64), 1},
		{uint64(math.MaxInt64), uint64(math.MaxUint32), 1},
		{uint64(math.MaxUint64), uint64(0), 1},
		{uint64(0), uint64(math.MaxUint64), -1},
	}

	for _, oc := range orderCases {
		lhsKey := encodeAsKey(oc.lhs)
		rhsKey := encodeAsKey(oc.rhs)
		c.Assert(bytes.Compare(lhsKey, rhsKey), Equals, oc.want)
	}

	// The floats below are listed in ascending order, so each encoding must
	// compare <= the one that follows it.
	ascending := []float64{-123.45, -123.40, -23.45, -1.43, -0.93, -0.4333, -0.068,
		-0.0099, 0, 0.001, 0.0012, 0.12, 1.2, 1.23, 123.3, 2424.242424}
	encoded := make([][]byte, 0, len(ascending))
	for _, f := range ascending {
		encoded = append(encoded, EncodeDecimal(nil, mysql.NewDecimalFromFloat(f)))
	}
	for i := 1; i < len(encoded); i++ {
		c.Assert(bytes.Compare(encoded[i-1], encoded[i]), LessEqual, 0)
	}
}