示例#1
0
文件: groupby.go 项目: nengwang/tidb
// builtinAvg implements the AVG aggregate as a two-phase evaluator driven by
// markers stored in ctx.
//
// While rows are being fed, each call folds args[0] into a running sum kept
// in ctx under the key ctx[ExprEvalFn] and returns (nil, nil). Once ctx
// contains ExprAggDone, the call returns sum/n for the accumulated state, or
// (nil, nil) when nothing was accumulated. If ctx contains
// ExprEvalArgAggEmpty the call returns (nil, nil) immediately.
func builtinAvg(args []interface{}, ctx map[interface{}]interface{}) (v interface{}, err error) {
	// Per the original note: avg uses decimal for integer and decimal types
	// and float for others, see
	// https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
	type avg struct {
		sum           interface{}
		n             uint64
		decimalResult bool
	}

	if _, empty := ctx[ExprEvalArgAggEmpty]; empty {
		return nil, nil
	}

	key := ctx[ExprEvalFn]
	seen := getDistinct(ctx, key)

	if _, done := ctx[ExprAggDone]; done {
		// Final phase: drop the distinct bookkeeping and emit the quotient.
		seen.clear()

		acc, found := ctx[key].(avg)
		if !found {
			// No row ever contributed to the aggregate.
			return nil, nil
		}

		if f, isFloat := acc.sum.(float64); isFloat {
			return float64(f) / float64(acc.n), nil
		}
		if d, isDecimal := acc.sum.(mysql.Decimal); isDecimal {
			return d.Div(mysql.NewDecimalFromUint(acc.n, 0)), nil
		}

		panic("should not happend")
	}

	// Accumulation phase. A missing entry yields the zero-valued accumulator.
	acc, _ := ctx[key].(avg)
	operand := args[0]
	if operand == nil {
		// Nil inputs do not take part in the average.
		return nil, nil
	}

	fresh, err := seen.isDistinct(args...)
	if err != nil {
		return nil, err
	}
	if !fresh {
		// isDistinct reported false for these arguments; skip the row.
		return nil, nil
	}

	if acc.sum == nil {
		acc.n = 0
	}

	acc.sum, err = calculateSum(acc.sum, operand)
	if err != nil {
		return nil, errors.Errorf("eval AVG aggregate err: %v", err)
	}

	acc.n++
	ctx[key] = acc
	return nil, nil
}
示例#2
0
文件: etc_test.go 项目: nengwang/tidb
// TestCoerce feeds checkCoerce a series of operand pairs whose Go types
// differ (signed/unsigned integers, floats, booleans and mysql decimals).
func (s *testTypeEtcSuite) TestCoerce(c *C) {
	// Each entry is one (left, right) operand pair, checked in order.
	pairs := [][2]interface{}{
		{uint64(3), int16(4)},
		{uint64(0xffffffffffffffff), float64(2.3)},
		{float64(1.3), uint64(0xffffffffffffffff)},
		{int64(11), float64(4.313)},
		{uint(2), uint16(52)},
		{uint8(8), true},
		{uint32(62), int8(8)},
		{mysql.NewDecimalFromInt(1, 0), false},
		{float32(3.4), mysql.NewDecimalFromUint(1, 0)},
		{int32(43), 3.235},
	}
	for _, p := range pairs {
		checkCoerce(c, p[0], p[1])
	}
}
示例#3
0
// CastValue casts a value based on column's type.
//
// A nil val yields (nil, nil). Otherwise the conversion is selected by c.Tp:
//   - integer-like types (including year and bit) are normalized with
//     normalizeIntegerValue and finished by castIntegerValue;
//   - float and double are delegated to castFloatValue;
//   - date, datetime and timestamp accept int64, string and mysql.Time;
//   - duration accepts string, mysql.Time and mysql.Duration;
//   - blob and string types render the value as a string and cut it to
//     c.Flen bytes when a length is specified; a []byte is returned as-is
//     when the column charset is charset.CharsetBin;
//   - decimal types accept strings, Go integers and floats, and
//     mysql.Decimal.
//
// A value whose Go type is not handled for the column type yields
// c.TypeError(val). Parse and conversion failures detected in this function
// are wrapped with newParseColError.
func (c *Col) CastValue(ctx context.Context, val interface{}) (casted interface{}, err error) {
	if val == nil {
		return
	}
	switch c.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear, mysql.TypeBit:
		intVal, errCode := c.normalizeIntegerValue(val)
		if errCode == errCodeType {
			// The normalized value is still handed back alongside the type
			// error, exactly as before.
			casted = intVal
			err = c.TypeError(val)
			return
		}
		return c.castIntegerValue(intVal, errCode)
	case mysql.TypeFloat, mysql.TypeDouble:
		return c.castFloatValue(val)
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		switch v := val.(type) {
		case int64:
			casted, err = mysql.ParseTimeFromNum(v, c.Tp, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case string:
			casted, err = mysql.ParseTime(v, c.Tp, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Time:
			// Convert to the column's time type first, then round the
			// fractional part to c.Decimal.
			var t mysql.Time
			t, err = v.Convert(c.Tp)
			if err != nil {
				err = newParseColError(err, c)
				return
			}
			casted, err = t.RoundFrac(c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		default:
			err = c.TypeError(val)
		}
	case mysql.TypeDuration:
		switch v := val.(type) {
		case string:
			casted, err = mysql.ParseDuration(v, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Time:
			var t mysql.Duration
			t, err = v.ConvertToDuration()
			if err != nil {
				err = newParseColError(err, c)
				return
			}
			casted, err = t.RoundFrac(c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Duration:
			casted, err = v.RoundFrac(c.Decimal)
			if err != nil {
				// Wrap like the sibling branches so every failure in this
				// case reports through newParseColError.
				err = newParseColError(err, c)
			}
		default:
			err = c.TypeError(val)
		}
	case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
		strV := ""
		switch v := val.(type) {
		case mysql.Time:
			strV = v.String()
		case mysql.Duration:
			strV = v.String()
		case []byte:
			if c.Charset == charset.CharsetBin {
				// Binary columns keep the raw bytes; no length limit is
				// applied on this path.
				casted = v
				return
			}
			strV = string(v)
		default:
			strV = fmt.Sprintf("%v", val)
		}
		// NOTE(review): the cut below counts bytes, so it can split a
		// multi-byte UTF-8 sequence; confirm whether Flen is meant to be a
		// byte or a character limit before changing it.
		if (c.Flen != types.UnspecifiedLength) && (len(strV) > c.Flen) {
			strV = strV[:c.Flen]
		}
		casted = strV
	case mysql.TypeDecimal, mysql.TypeNewDecimal:
		switch v := val.(type) {
		case string:
			casted, err = mysql.ParseDecimal(v)
			if err != nil {
				err = newParseColError(err, c)
			}
		case int8:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int16:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int32:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int64:
			casted = mysql.NewDecimalFromInt(v, 0)
		case int:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case uint8:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint16:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint32:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint64:
			casted = mysql.NewDecimalFromUint(v, 0)
		case uint:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case float32:
			casted = mysql.NewDecimalFromFloat(float64(v))
		case float64:
			casted = mysql.NewDecimalFromFloat(v)
		case mysql.Decimal:
			casted = v
		default:
			// Previously an unsupported Go type fell through and produced
			// (nil, nil), silently turning the value into NULL. Report it
			// the same way the other column types do.
			err = c.TypeError(val)
		}
	default:
		err = c.TypeError(val)
	}
	return
}