func builtinAvg(args []interface{}, ctx map[interface{}]interface{}) (v interface{}, err error) { // avg use decimal for integer and decimal type, use float for others // see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html type avg struct { sum interface{} n uint64 decimalResult bool } if _, ok := ctx[ExprEvalArgAggEmpty]; ok { return } fn := ctx[ExprEvalFn] distinct := getDistinct(ctx, fn) if _, ok := ctx[ExprAggDone]; ok { distinct.clear() data, ok := ctx[fn].(avg) if !ok { return } switch x := data.sum.(type) { case float64: return float64(x) / float64(data.n), nil case mysql.Decimal: return x.Div(mysql.NewDecimalFromUint(data.n, 0)), nil } panic("should not happend") } data, _ := ctx[fn].(avg) y := args[0] if y == nil { return } ok, err := distinct.isDistinct(args...) if err != nil || !ok { // if err or not distinct, return return nil, err } if data.sum == nil { data.n = 0 } data.sum, err = calculateSum(data.sum, y) if err != nil { return nil, errors.Errorf("eval AVG aggregate err: %v", err) } data.n++ ctx[fn] = data return }
func (s *testTypeEtcSuite) TestCoerce(c *C) { checkCoerce(c, uint64(3), int16(4)) checkCoerce(c, uint64(0xffffffffffffffff), float64(2.3)) checkCoerce(c, float64(1.3), uint64(0xffffffffffffffff)) checkCoerce(c, int64(11), float64(4.313)) checkCoerce(c, uint(2), uint16(52)) checkCoerce(c, uint8(8), true) checkCoerce(c, uint32(62), int8(8)) checkCoerce(c, mysql.NewDecimalFromInt(1, 0), false) checkCoerce(c, float32(3.4), mysql.NewDecimalFromUint(1, 0)) checkCoerce(c, int32(43), 3.235) }
// CastValue casts a value based on column's type. func (c *Col) CastValue(ctx context.Context, val interface{}) (casted interface{}, err error) { if val == nil { return } switch c.Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear, mysql.TypeBit: intVal, errCode := c.normalizeIntegerValue(val) if errCode == errCodeType { casted = intVal err = c.TypeError(val) return } return c.castIntegerValue(intVal, errCode) case mysql.TypeFloat, mysql.TypeDouble: return c.castFloatValue(val) case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: switch v := val.(type) { case int64: casted, err = mysql.ParseTimeFromNum(v, c.Tp, c.Decimal) if err != nil { err = newParseColError(err, c) } case string: casted, err = mysql.ParseTime(v, c.Tp, c.Decimal) if err != nil { err = newParseColError(err, c) } case mysql.Time: var t mysql.Time t, err = v.Convert(c.Tp) if err != nil { err = newParseColError(err, c) return } casted, err = t.RoundFrac(c.Decimal) if err != nil { err = newParseColError(err, c) } default: err = c.TypeError(val) } case mysql.TypeDuration: switch v := val.(type) { case string: casted, err = mysql.ParseDuration(v, c.Decimal) if err != nil { err = newParseColError(err, c) } case mysql.Time: var t mysql.Duration t, err = v.ConvertToDuration() if err != nil { err = newParseColError(err, c) return } casted, err = t.RoundFrac(c.Decimal) if err != nil { err = newParseColError(err, c) } case mysql.Duration: casted, err = v.RoundFrac(c.Decimal) default: err = c.TypeError(val) } case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString: strV := "" switch v := val.(type) { case mysql.Time: strV = v.String() case mysql.Duration: strV = v.String() case []byte: if c.Charset == charset.CharsetBin { casted = v return } strV = string(v) default: strV = fmt.Sprintf("%v", val) } if (c.Flen != types.UnspecifiedLength) && (len(strV) > c.Flen) { strV = strV[:c.Flen] } casted = strV case mysql.TypeDecimal, mysql.TypeNewDecimal: switch v := val.(type) { case string: casted, err = mysql.ParseDecimal(v) if err != nil { err = newParseColError(err, c) } case int8: casted = mysql.NewDecimalFromInt(int64(v), 0) case int16: casted = mysql.NewDecimalFromInt(int64(v), 0) case int32: casted = mysql.NewDecimalFromInt(int64(v), 0) case int64: casted = mysql.NewDecimalFromInt(int64(v), 0) case int: casted = mysql.NewDecimalFromInt(int64(v), 0) case uint8: casted = mysql.NewDecimalFromUint(uint64(v), 0) case uint16: casted = mysql.NewDecimalFromUint(uint64(v), 0) case uint32: casted = mysql.NewDecimalFromUint(uint64(v), 0) case uint64: casted = mysql.NewDecimalFromUint(uint64(v), 0) case uint: casted = mysql.NewDecimalFromUint(uint64(v), 0) case float32: casted = mysql.NewDecimalFromFloat(float64(v)) case float64: casted = mysql.NewDecimalFromFloat(float64(v)) case mysql.Decimal: casted = v } default: err = c.TypeError(val) } return }