func calculateSum(sum interface{}, v interface{}) (interface{}, error) { // for avg and sum calculation // avg and sum use decimal for integer and decimal type, use float for others // see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html var ( data interface{} err error ) switch y := v.(type) { case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64: data, err = mysql.ConvertToDecimal(v) case mysql.Decimal: data = y default: data, err = types.ToFloat64(v) } if err != nil { return nil, err } switch x := sum.(type) { case nil: return data, nil case float64: return x + data.(float64), nil case mysql.Decimal: return x.Add(data.(mysql.Decimal)), nil default: return nil, errors.Errorf("invalid value %v(%T) for aggregate", x, x) } }
// Coerce changes type. // If a or b is Decimal, changes the both to Decimal. // Else if a or b is Float, changes the both to Float. func Coerce(a, b interface{}) (x, y interface{}) { var hasDecimal bool var hasFloat bool x = convergeType(a, &hasDecimal, &hasFloat) y = convergeType(b, &hasDecimal, &hasFloat) if hasDecimal { d, err := mysql.ConvertToDecimal(x) if err == nil { x = d } d, err = mysql.ConvertToDecimal(y) if err == nil { y = d } } else if hasFloat { switch v := x.(type) { case int64: x = float64(v) case uint64: x = float64(v) case mysql.Hex: x = v.ToNumber() case mysql.Bit: x = v.ToNumber() case mysql.Enum: x = v.ToNumber() case mysql.Set: x = v.ToNumber() } switch v := y.(type) { case int64: y = float64(v) case uint64: y = float64(v) case mysql.Hex: y = v.ToNumber() case mysql.Bit: y = v.ToNumber() case mysql.Enum: y = v.ToNumber() case mysql.Set: y = v.ToNumber() } } return }
// ToDecimal converts an interface to a Decimal. func ToDecimal(value interface{}) (mysql.Decimal, error) { switch v := value.(type) { case bool: if v { return mysql.ConvertToDecimal(1) } return mysql.ConvertToDecimal(0) case []byte: return mysql.ConvertToDecimal(string(v)) case mysql.Time: return v.ToNumber(), nil case mysql.Duration: return v.ToNumber(), nil default: return mysql.ConvertToDecimal(value) } }
func (ts *testFunctionsSuite) TestAggFuncSum(c *C) { args := make([]ExprNode, 1) // sum with distinct agg := &AggregateFuncExpr{ Args: args, F: AggFuncSum, Distinct: true, } agg.CurrentGroup = "xx" expr := NewValueExpr(1) expr1 := NewValueExpr(nil) expr2 := NewValueExpr(1) exprs := []ExprNode{expr, expr1, expr2} for _, e := range exprs { args[0] = e agg.Update() } ctx := agg.GetContext() expect, _ := mysql.ConvertToDecimal(1) v, ok := ctx.Value.(mysql.Decimal) c.Assert(ok, IsTrue) c.Assert(v.Equals(expect), IsTrue) // sum without distinct agg = &AggregateFuncExpr{ Args: args, F: AggFuncSum, } agg.CurrentGroup = "xx" expr = NewValueExpr(2) expr1 = NewValueExpr(nil) expr2 = NewValueExpr(2) exprs = []ExprNode{expr, expr1, expr2} for _, e := range exprs { args[0] = e agg.Update() } ctx = agg.GetContext() expect, _ = mysql.ConvertToDecimal(4) v, ok = ctx.Value.(mysql.Decimal) c.Assert(ok, IsTrue) c.Assert(v.Equals(expect), IsTrue) }
func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) { ctx := mock.NewContext() avg := &ast.AggregateFuncExpr{ F: ast.AggFuncAvg, } avg.CurrentGroup = "emptyGroup" result, err := Eval(ctx, avg) c.Assert(err, IsNil) // Empty group should return nil. c.Assert(result, IsNil) avg.Args = []ast.ExprNode{ast.NewValueExpr(2)} avg.Update() avg.Args = []ast.ExprNode{ast.NewValueExpr(4)} avg.Update() result, err = Eval(ctx, avg) c.Assert(err, IsNil) expect, _ := mysql.ConvertToDecimal(3) v, ok := result.(mysql.Decimal) c.Assert(ok, IsTrue) c.Assert(v.Equals(expect), IsTrue) }
func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) { defer testleak.AfterTest(c)() ctx := mock.NewContext() avg := &ast.AggregateFuncExpr{ F: ast.AggFuncAvg, } avg.CurrentGroup = []byte("emptyGroup") ast.SetFlag(avg) result, err := Eval(ctx, avg) c.Assert(err, IsNil) // Empty group should return nil. c.Assert(result.Kind(), Equals, types.KindNull) avg.Args = []ast.ExprNode{ast.NewValueExpr(2)} avg.Update() avg.Args = []ast.ExprNode{ast.NewValueExpr(4)} avg.Update() result, err = Eval(ctx, avg) c.Assert(err, IsNil) expect, _ := mysql.ConvertToDecimal(3) c.Assert(result.Kind(), Equals, types.KindMysqlDecimal) c.Assert(result.GetMysqlDecimal().Equals(expect), IsTrue) }
func (s *testCodecSuite) TestDecimal(c *C) { tbl := []string{ "1234.00", "1234", "12.34", "12.340", "0.1234", "0.0", "0", "-0.0", "-0.0000", "-1234.00", "-1234", "-12.34", "-12.340", "-0.1234"} for _, t := range tbl { m, err := mysql.ParseDecimal(t) c.Assert(err, IsNil) b, err := EncodeKey(m) c.Assert(err, IsNil) v, err := DecodeKey(b) c.Assert(err, IsNil) c.Assert(v, HasLen, 1) vv, ok := v[0].(mysql.Decimal) c.Assert(ok, IsTrue) c.Assert(vv.Equals(m), IsTrue) } tblCmp := []struct { Arg1 interface{} Arg2 interface{} Ret int }{ // Test for float type decimal. {"1234", "123400", -1}, {"12340", "123400", -1}, {"1234", "1234.5", -1}, {"1234", "1234.0000", 0}, {"1234", "12.34", 1}, {"12.34", "12.35", -1}, {"0.1234", "12.3400", -1}, {"0.1234", "0.1235", -1}, {"0.123400", "12.34", -1}, {"12.34000", "12.34", 0}, {"0.01234", "0.01235", -1}, {"0.1234", "0", 1}, {"0.0000", "0", 0}, {"0.0001", "0", 1}, {"0.0001", "0.0000", 1}, {"0", "-0.0000", 0}, {"-0.0001", "0", -1}, {"-0.1234", "0", -1}, {"-0.1234", "0.1234", -1}, {"-1.234", "-12.34", 1}, {"-0.1234", "-12.34", 1}, {"-12.34", "1234", -1}, {"-12.34", "-12.35", 1}, {"-0.01234", "-0.01235", 1}, {"-1234", "-123400", 1}, {"-12340", "-123400", 1}, // Test for int type decimal. {int64(-1), int64(1), -1}, {int64(math.MaxInt64), int64(math.MinInt64), 1}, {int64(math.MaxInt64), int64(math.MaxInt32), 1}, {int64(math.MinInt32), int64(math.MaxInt16), -1}, {int64(math.MinInt64), int64(math.MaxInt8), -1}, {int64(0), int64(math.MaxInt8), -1}, {int64(math.MinInt8), int64(0), -1}, {int64(math.MinInt16), int64(math.MaxInt16), -1}, {int64(1), int64(-1), 1}, {int64(1), int64(0), 1}, {int64(-1), int64(0), -1}, {int64(0), int64(0), 0}, {int64(math.MaxInt16), int64(math.MaxInt16), 0}, // Test for uint type decimal. {uint64(0), uint64(0), 0}, {uint64(1), uint64(0), 1}, {uint64(0), uint64(1), -1}, {uint64(math.MaxInt8), uint64(math.MaxInt16), -1}, {uint64(math.MaxUint32), uint64(math.MaxInt32), 1}, {uint64(math.MaxUint8), uint64(math.MaxInt8), 1}, {uint64(math.MaxUint16), uint64(math.MaxInt32), -1}, {uint64(math.MaxUint64), uint64(math.MaxInt64), 1}, {uint64(math.MaxInt64), uint64(math.MaxUint32), 1}, {uint64(math.MaxUint64), uint64(0), 1}, {uint64(0), uint64(math.MaxUint64), -1}, } for _, t := range tblCmp { m1, err := mysql.ConvertToDecimal(t.Arg1) c.Assert(err, IsNil) m2, err := mysql.ConvertToDecimal(t.Arg2) c.Assert(err, IsNil) b1, err := EncodeKey(m1) c.Assert(err, IsNil) b2, err := EncodeKey(m2) c.Assert(err, IsNil) ret := bytes.Compare(b1, b2) c.Assert(ret, Equals, t.Ret) } }
// Compare returns an integer comparing the interface a with b. // a > b -> 1 // a = b -> 0 // a < b -> -1 func Compare(a, b interface{}) (int, error) { var coerceErr error a, b, coerceErr = coerceCompare(a, b) if coerceErr != nil { return 0, errors.Trace(coerceErr) } if va, ok := a.([]interface{}); ok { // we guarantee in coerceCompare that a and b are both []interface{} vb := b.([]interface{}) return compareRow(va, vb) } if a == nil || b == nil { // Check ni first, nil is always less than none nil value. if a == nil && b != nil { return -1, nil } else if a != nil && b == nil { return 1, nil } else { // here a and b are all nil return 0, nil } } // TODO: support compare time type with other int, float, decimal types. switch x := a.(type) { case float64: switch y := b.(type) { case float64: return CompareFloat64(x, y), nil case string: return compareFloatString(x, y) } case int64: switch y := b.(type) { case int64: return CompareInt64(x, y), nil case uint64: return CompareInteger(x, y), nil case string: return compareFloatString(float64(x), y) case mysql.Hex: return CompareFloat64(float64(x), y.ToNumber()), nil case mysql.Bit: return CompareFloat64(float64(x), y.ToNumber()), nil case mysql.Enum: return CompareFloat64(float64(x), y.ToNumber()), nil case mysql.Set: return CompareFloat64(float64(x), y.ToNumber()), nil } case uint64: switch y := b.(type) { case uint64: return CompareUint64(x, y), nil case int64: return -CompareInteger(y, x), nil case string: return compareFloatString(float64(x), y) case mysql.Hex: return CompareFloat64(float64(x), y.ToNumber()), nil case mysql.Bit: return CompareFloat64(float64(x), y.ToNumber()), nil case mysql.Enum: return CompareFloat64(float64(x), y.ToNumber()), nil case mysql.Set: return CompareFloat64(float64(x), y.ToNumber()), nil } case mysql.Decimal: switch y := b.(type) { case mysql.Decimal: return x.Cmp(y), nil case string: f, err := mysql.ConvertToDecimal(y) if err != nil { return 0, errors.Trace(err) } return x.Cmp(f), nil } case string: switch y := b.(type) { case string: return CompareString(x, y), nil case int64: return compareStringFloat(x, float64(y)) case uint64: return compareStringFloat(x, float64(y)) case float64: return compareStringFloat(x, y) case mysql.Decimal: f, err := mysql.ConvertToDecimal(x) if err != nil { return 0, errors.Trace(err) } return f.Cmp(y), nil case mysql.Time: n, err := y.CompareString(x) return -n, errors.Trace(err) case mysql.Duration: n, err := y.CompareString(x) return -n, errors.Trace(err) case mysql.Hex: return CompareString(x, y.ToString()), nil case mysql.Bit: return CompareString(x, y.ToString()), nil case mysql.Enum: return CompareString(x, y.String()), nil case mysql.Set: return CompareString(x, y.String()), nil } case mysql.Time: switch y := b.(type) { case mysql.Time: return x.Compare(y), nil case string: return x.CompareString(y) } case mysql.Duration: switch y := b.(type) { case mysql.Duration: return x.Compare(y), nil case string: return x.CompareString(y) } case mysql.Hex: switch y := b.(type) { case int64: return CompareFloat64(x.ToNumber(), float64(y)), nil case uint64: return CompareFloat64(x.ToNumber(), float64(y)), nil case mysql.Bit: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case string: return CompareString(x.ToString(), y), nil case mysql.Enum: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case mysql.Set: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case mysql.Hex: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil } case mysql.Bit: switch y := b.(type) { case int64: return CompareFloat64(x.ToNumber(), float64(y)), nil case uint64: return CompareFloat64(x.ToNumber(), float64(y)), nil case mysql.Hex: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case string: return CompareString(x.ToString(), y), nil case mysql.Enum: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case mysql.Set: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case mysql.Bit: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil } case mysql.Enum: switch y := b.(type) { case int64: return CompareFloat64(x.ToNumber(), float64(y)), nil case uint64: return CompareFloat64(x.ToNumber(), float64(y)), nil case mysql.Hex: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case string: return CompareString(x.String(), y), nil case mysql.Bit: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case mysql.Set: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case mysql.Enum: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil } case mysql.Set: switch y := b.(type) { case int64: return CompareFloat64(x.ToNumber(), float64(y)), nil case uint64: return CompareFloat64(x.ToNumber(), float64(y)), nil case mysql.Hex: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case string: return CompareString(x.String(), y), nil case mysql.Bit: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case mysql.Enum: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case mysql.Set: return CompareFloat64(x.ToNumber(), y.ToNumber()), nil } } return 0, errors.Errorf("invalid comapre type %T cmp %T", a, b) }
func (s *testCodecSuite) TestDecimal(c *C) { defer testleak.AfterTest(c)() tbl := []string{ "1234.00", "1234", "12.34", "12.340", "0.1234", "0.0", "0", "-0.0", "-0.0000", "-1234.00", "-1234", "-12.34", "-12.340", "-0.1234"} for _, t := range tbl { m, err := mysql.ParseDecimal(t) c.Assert(err, IsNil) b, err := EncodeKey(nil, types.NewDatum(m)) c.Assert(err, IsNil) v, err := Decode(b) c.Assert(err, IsNil) c.Assert(v, HasLen, 1) vv := v[0].GetMysqlDecimal() c.Assert(vv.Equals(m), IsTrue) } tblCmp := []struct { Arg1 interface{} Arg2 interface{} Ret int }{ // Test for float type decimal. {"1234", "123400", -1}, {"12340", "123400", -1}, {"1234", "1234.5", -1}, {"1234", "1234.0000", 0}, {"1234", "12.34", 1}, {"12.34", "12.35", -1}, {"0.12", "0.1234", -1}, {"0.1234", "12.3400", -1}, {"0.1234", "0.1235", -1}, {"0.123400", "12.34", -1}, {"12.34000", "12.34", 0}, {"0.01234", "0.01235", -1}, {"0.1234", "0", 1}, {"0.0000", "0", 0}, {"0.0001", "0", 1}, {"0.0001", "0.0000", 1}, {"0", "-0.0000", 0}, {"-0.0001", "0", -1}, {"-0.1234", "0", -1}, {"-0.1234", "-0.12", -1}, {"-0.12", "-0.1234", 1}, {"-0.12", "-0.1200", 0}, {"-0.1234", "0.1234", -1}, {"-1.234", "-12.34", 1}, {"-0.1234", "-12.34", 1}, {"-12.34", "1234", -1}, {"-12.34", "-12.35", 1}, {"-0.01234", "-0.01235", 1}, {"-1234", "-123400", 1}, {"-12340", "-123400", 1}, // Test for int type decimal. {int64(-1), int64(1), -1}, {int64(math.MaxInt64), int64(math.MinInt64), 1}, {int64(math.MaxInt64), int64(math.MaxInt32), 1}, {int64(math.MinInt32), int64(math.MaxInt16), -1}, {int64(math.MinInt64), int64(math.MaxInt8), -1}, {int64(0), int64(math.MaxInt8), -1}, {int64(math.MinInt8), int64(0), -1}, {int64(math.MinInt16), int64(math.MaxInt16), -1}, {int64(1), int64(-1), 1}, {int64(1), int64(0), 1}, {int64(-1), int64(0), -1}, {int64(0), int64(0), 0}, {int64(math.MaxInt16), int64(math.MaxInt16), 0}, // Test for uint type decimal. {uint64(0), uint64(0), 0}, {uint64(1), uint64(0), 1}, {uint64(0), uint64(1), -1}, {uint64(math.MaxInt8), uint64(math.MaxInt16), -1}, {uint64(math.MaxUint32), uint64(math.MaxInt32), 1}, {uint64(math.MaxUint8), uint64(math.MaxInt8), 1}, {uint64(math.MaxUint16), uint64(math.MaxInt32), -1}, {uint64(math.MaxUint64), uint64(math.MaxInt64), 1}, {uint64(math.MaxInt64), uint64(math.MaxUint32), 1}, {uint64(math.MaxUint64), uint64(0), 1}, {uint64(0), uint64(math.MaxUint64), -1}, } for _, t := range tblCmp { m1, err := mysql.ConvertToDecimal(t.Arg1) c.Assert(err, IsNil) m2, err := mysql.ConvertToDecimal(t.Arg2) c.Assert(err, IsNil) b1, err := EncodeKey(nil, types.NewDatum(m1)) c.Assert(err, IsNil) b2, err := EncodeKey(nil, types.NewDatum(m2)) c.Assert(err, IsNil) ret := bytes.Compare(b1, b2) c.Assert(ret, Equals, t.Ret) } floats := []float64{-123.45, -123.40, -23.45, -1.43, -0.93, -0.4333, -0.068, -0.0099, 0, 0.001, 0.0012, 0.12, 1.2, 1.23, 123.3, 2424.242424} var decs [][]byte for i := range floats { dec := mysql.NewDecimalFromFloat(floats[i]) decs = append(decs, EncodeDecimal(nil, dec)) } for i := 0; i < len(decs)-1; i++ { cmp := bytes.Compare(decs[i], decs[i+1]) c.Assert(cmp, LessEqual, 0) } }