예제 #1
0
// handleCaseExpr infers the result type of a CASE expression and stores it on x.
//
// The return type of a CASE expression is the compatible aggregated type of all return values,
// but also depends on the context in which it is used.
// If used in a string context, the result is returned as a string.
// If used in a numeric context, the result is returned as a decimal, real, or integer value.
//
// The WHEN results are merged in order, followed by the ELSE result when there
// is one. Merging happens in a private copy, so the FieldType objects returned
// by the branch expressions are never modified. Branches whose type is nil are
// skipped; when no branch has a type, x is left untouched.
func (v *typeInferrer) handleCaseExpr(x *ast.CaseExpr) {
	var currType *types.FieldType
	// merge folds the type of one branch into the running result.
	merge := func(t *types.FieldType) {
		if t == nil {
			return
		}
		if currType == nil {
			// Copy rather than alias: currType is written to below and those
			// writes must not show up in the branch expression's own type.
			first := *t
			currType = &first
			return
		}
		mtp := types.MergeFieldType(currType.Tp, t.Tp)
		if mtp == t.Tp && mtp != currType.Tp {
			// The merged type is the incoming one, so take its charset too.
			currType.Charset = t.Charset
			currType.Collate = t.Collate
		}
		currType.Tp = mtp
	}

	for _, w := range x.WhenClauses {
		merge(w.Result.GetType())
	}
	if x.ElseClause != nil {
		merge(x.ElseClause.GetType())
	}
	if currType == nil {
		// Nothing to infer from; avoid dereferencing a nil type below.
		return
	}
	x.SetType(currType)
	// TODO: We need a better way to set charset/collation
	x.Type.Charset, x.Type.Collate = types.DefaultCharsetForType(x.Type.Tp)
}
예제 #2
0
파일: ddl.go 프로젝트: pingcap/tidb
// setCharsetCollationFlenDecimal completes a column type in place.
//
// When tp has no charset, the string and blob types listed below receive the
// pair returned by getDefaultCharsetAndCollate, and every other type receives
// the binary charset. An unspecified length or decimal is then replaced by the
// default for tp.Tp.
func (d *ddl) setCharsetCollationFlenDecimal(tp *types.FieldType) {
	if len(tp.Charset) == 0 {
		usesDefaultCharset := false
		switch tp.Tp {
		case mysql.TypeString,
			mysql.TypeVarchar,
			mysql.TypeVarString,
			mysql.TypeBlob,
			mysql.TypeTinyBlob,
			mysql.TypeMediumBlob,
			mysql.TypeLongBlob:
			usesDefaultCharset = true
		}
		if usesDefaultCharset {
			tp.Charset, tp.Collate = getDefaultCharsetAndCollate()
		} else {
			// NOTE(review): the collation is assigned the charset constant;
			// presumably a dedicated collation constant was meant - confirm.
			tp.Charset, tp.Collate = charset.CharsetBin, charset.CharsetBin
		}
	}

	// Fall back to the per-type defaults for anything the caller left open.
	if tp.Flen == types.UnspecifiedLength {
		tp.Flen = mysql.GetDefaultFieldLength(tp.Tp)
	}
	if tp.Decimal == types.UnspecifiedLength {
		tp.Decimal = mysql.GetDefaultDecimal(tp.Tp)
	}
}
예제 #3
0
// handleFuncCallExpr picks the result type of a function call from its name
// (x.FnName.L) and stores it on the node with SetType.
//
// Names that are not listed get mysql.TypeUnspecified. If the chosen type has
// no charset, it receives the binary charset and collation, except for the
// functions typed as TypeVarString, which receive the inferrer's default
// charset and that charset's default collation.
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
	var resultType *types.FieldType
	// Charset used when resultType does not carry one.
	fallback := charset.CharsetBin

	switch x.FnName.L {
	// Type taken over from an argument: the *types.FieldType returned by the
	// argument's GetType is used as is.
	case "abs", "ifnull", "nullif":
		resultType = x.Args[0].GetType()
	case "if":
		// TODO: fix this
		// See: https://dev.mysql.com/doc/refman/5.5/en/control-flow-functions.html#function_if
		// The default return type of IF() (which may matter when it is stored into a temporary table) is calculated as follows.
		// Expression	Return Value
		// expr2 or expr3 returns a string	string
		// expr2 or expr3 returns a floating-point value	floating-point
		// expr2 or expr3 returns an integer	integer
		resultType = x.Args[1].GetType()

	// Numeric results.
	case "pow", "power", "rand":
		resultType = types.NewFieldType(mysql.TypeDouble)
	case "microsecond", "second", "minute", "hour", "day", "week", "month", "year",
		"dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek",
		"found_rows", "length":
		resultType = types.NewFieldType(mysql.TypeLonglong)
	case "connection_id":
		resultType = types.NewFieldType(mysql.TypeLonglong)
		resultType.Flag |= mysql.UnsignedFlag

	// Temporal results; some carry the fractional-seconds precision of the call.
	case "curdate", "current_date", "date":
		resultType = types.NewFieldType(mysql.TypeDate)
	case "curtime", "current_time":
		resultType = types.NewFieldType(mysql.TypeDuration)
		resultType.Decimal = v.getFsp(x)
	case "current_timestamp":
		resultType = types.NewFieldType(mysql.TypeDatetime)
	case "now", "sysdate":
		resultType = types.NewFieldType(mysql.TypeDatetime)
		resultType.Decimal = v.getFsp(x)

	// String results use the inferrer's default charset.
	case "dayname", "version", "database", "user", "current_user",
		"concat", "concat_ws", "left", "lower", "repeat", "replace", "upper":
		resultType = types.NewFieldType(mysql.TypeVarString)
		fallback = v.defaultCharset

	default:
		resultType = types.NewFieldType(mysql.TypeUnspecified)
	}

	// Fill in charset and collation only when the type has no charset yet.
	if len(resultType.Charset) == 0 {
		resultType.Charset = fallback
		if fallback == charset.CharsetBin {
			resultType.Collate = charset.CollationBin
		} else {
			collation, err := charset.GetDefaultCollation(fallback)
			if err != nil {
				v.err = err
			}
			resultType.Collate = collation
		}
	}
	x.SetType(resultType)
}
예제 #4
0
파일: distsql.go 프로젝트: pingcap/tidb
// ProtoColumnsToFieldTypes converts a tipb column info slice to a FieldType slice.
// The result holds one newly allocated FieldType per input column, in input
// order. Only Tp, Collate, Decimal, Flen, Flag and Elems are populated; the
// collation name is looked up in mysql.Collations by the column's collation id.
func ProtoColumnsToFieldTypes(pColumns []*tipb.ColumnInfo) []*types.FieldType {
	fields := make([]*types.FieldType, 0, len(pColumns))
	for _, col := range pColumns {
		fields = append(fields, &types.FieldType{
			Tp:      byte(col.GetTp()),
			Collate: mysql.Collations[byte(col.GetCollation())],
			Decimal: int(col.GetDecimal()),
			Flen:    int(col.GetColumnLen()),
			Flag:    uint(col.GetFlag()),
			Elems:   col.GetElems(),
		})
	}
	return fields
}
예제 #5
0
// handleFuncCallExpr sets the type of a function call node, chosen by the
// function name in x.FnName.L. Unknown names are typed mysql.TypeUnspecified.
// A type that ends up without a charset is given one: the inferrer's default
// charset for the TypeVarString functions, the binary charset otherwise.
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
	var inferred *types.FieldType
	cs := charset.CharsetBin

	switch x.FnName.L {
	case "dayname", "version", "database", "user", "current_user",
		"concat", "concat_ws", "left", "lower", "repeat", "replace", "upper":
		inferred = types.NewFieldType(mysql.TypeVarString)
		cs = v.defaultCharset
	case "microsecond", "second", "minute", "hour", "day", "week", "month", "year",
		"dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek",
		"found_rows", "length":
		inferred = types.NewFieldType(mysql.TypeLonglong)
	case "connection_id":
		inferred = types.NewFieldType(mysql.TypeLonglong)
		inferred.Flag |= mysql.UnsignedFlag
	case "pow", "power", "rand":
		inferred = types.NewFieldType(mysql.TypeDouble)
	case "curdate", "current_date", "date":
		inferred = types.NewFieldType(mysql.TypeDate)
	case "current_timestamp":
		inferred = types.NewFieldType(mysql.TypeDatetime)
	case "now", "sysdate":
		// Datetime carrying the fractional-seconds precision of the call.
		inferred = types.NewFieldType(mysql.TypeDatetime)
		inferred.Decimal = v.getFsp(x)
	case "curtime", "current_time":
		inferred = types.NewFieldType(mysql.TypeDuration)
		inferred.Decimal = v.getFsp(x)
	case "abs", "ifnull", "nullif":
		// Reuses whatever the first argument's GetType returns.
		inferred = x.Args[0].GetType()
	case "if":
		// Typed after the second argument only; the third is not consulted.
		inferred = x.Args[1].GetType()
	default:
		inferred = types.NewFieldType(mysql.TypeUnspecified)
	}

	// Only a type without a charset is touched from here on.
	if len(inferred.Charset) == 0 {
		inferred.Charset = cs
		collation := charset.CollationBin
		if cs != charset.CharsetBin {
			var err error
			if collation, err = charset.GetDefaultCollation(cs); err != nil {
				v.err = err
			}
		}
		inferred.Collate = collation
	}
	x.SetType(inferred)
}
예제 #6
0
파일: evaluator.go 프로젝트: mrtoms/tidb
// funcDateArith evaluates a date arithmetic expression and stores the outcome
// on v with SetValue.
//
// It returns true when a value was stored; a nil date or a nil interval stores
// nil. On failure it returns false and records the reason in e.err. When
// v.Op is ast.DateSub the interval is subtracted, otherwise it is added.
func (e *Evaluator) funcDateArith(v *ast.FuncDateArithExpr) bool {
	// A nil operand makes the whole expression nil.
	dateVal := v.Date.GetValue()
	if types.IsNil(dateVal) {
		v.SetValue(nil)
		return true
	}
	intervalVal := v.Interval.GetValue()
	if types.IsNil(intervalVal) {
		v.SetValue(nil)
		return true
	}

	// Choose between a date and a datetime result: the result is a datetime
	// when the operand already is one (or is a timestamp), when a string
	// operand is not in date format, or when the unit is a clock unit.
	targetTp := mysql.TypeDate
	switch operand := dateVal.(type) {
	case mysql.Time:
		if operand.Type == mysql.TypeDatetime || operand.Type == mysql.TypeTimestamp {
			targetTp = mysql.TypeDatetime
		}
	case string:
		if !mysql.IsDateFormat(operand) {
			targetTp = mysql.TypeDatetime
		}
	case int64:
		parsed, parseErr := mysql.ParseTimeFromInt64(operand)
		if parseErr == nil && (parsed.Type == mysql.TypeDatetime || parsed.Type == mysql.TypeTimestamp) {
			targetTp = mysql.TypeDatetime
		}
	}
	if mysql.IsClockUnit(v.Unit) {
		targetTp = mysql.TypeDatetime
	}

	// Convert the operand to the chosen time type with maximum precision.
	target := types.NewFieldType(targetTp)
	target.Decimal = mysql.MaxFsp
	converted, err := types.Convert(dateVal, target)
	if err != nil {
		e.err = ErrInvalidOperation.Gen("DateArith invalid args, need date but get %T", dateVal)
		return false
	}
	if types.IsNil(converted) {
		e.err = ErrInvalidOperation.Gen("DateArith invalid args, need date but get %v", converted)
		return false
	}
	result, isTime := converted.(mysql.Time)
	if !isTime {
		e.err = ErrInvalidOperation.Gen("DateArith need time type, but got %T", converted)
		return false
	}

	// Render the interval as text; the "day" unit goes through parseDayInterval.
	var intervalText string
	if strings.ToLower(v.Unit) != "day" {
		intervalText = fmt.Sprintf("%v", intervalVal)
	} else {
		days, dayErr := parseDayInterval(intervalVal)
		if dayErr != nil {
			e.err = ErrInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", intervalVal)
			return false
		}
		intervalText = fmt.Sprintf("%d", days)
	}

	years, months, days, span, err := mysql.ExtractTimeValue(v.Unit, intervalText)
	if err != nil {
		e.err = errors.Trace(err)
		return false
	}
	if v.Op == ast.DateSub {
		years, months, days, span = -years, -months, -days, -span
	}

	// Apply the duration part first, then the calendar part.
	result.Time = result.Time.Add(span)
	result.Time = result.Time.AddDate(int(years), int(months), int(days))
	if result.Time.Nanosecond() == 0 {
		// No fractional part left, so drop the fractional precision.
		result.Fsp = 0
	}
	v.SetValue(result)
	return true
}
예제 #7
0
파일: typeinferer.go 프로젝트: pingcap/tidb
// handleFuncCallExpr sets the type of a function call node, chosen by the
// function name in x.FnName.L.
//
// Names that are not listed are typed mysql.TypeUnspecified, and so is
// "greatest" when it is called without arguments. A type that ends up without
// a charset is given one: the inferrer's default charset (with its default
// collation) for the string-typed functions, the binary charset otherwise.
// An error from the collation lookup is recorded in v.err.
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
	var (
		tp  *types.FieldType
		chs = charset.CharsetBin
	)
	switch x.FnName.L {
	case "abs", "ifnull", "nullif":
		tp = x.Args[0].GetType()
		// TODO: We should cover all types.
		if x.FnName.L == "abs" && tp.Tp == mysql.TypeDatetime {
			tp = types.NewFieldType(mysql.TypeDouble)
		}
	case "greatest":
		for _, arg := range x.Args {
			InferType(v.sc, arg)
		}
		if len(x.Args) > 0 {
			tp = x.Args[0].GetType()
			for i := 1; i < len(x.Args); i++ {
				// NOTE(review): the result of mergeArithType is discarded, so the
				// inferred type is effectively that of the first argument.
				// Confirm the intended merge rule before assigning it to tp.
				mergeArithType(tp.Tp, x.Args[i].GetType().Tp)
			}
		}
	case "ceil", "ceiling":
		t := x.Args[0].GetType().Tp
		if t == mysql.TypeNull || t == mysql.TypeFloat || t == mysql.TypeDouble || t == mysql.TypeVarchar ||
			t == mysql.TypeTinyBlob || t == mysql.TypeMediumBlob || t == mysql.TypeLongBlob ||
			t == mysql.TypeBlob || t == mysql.TypeVarString || t == mysql.TypeString {
			tp = types.NewFieldType(mysql.TypeDouble)
		} else {
			tp = types.NewFieldType(mysql.TypeLonglong)
		}
	case "ln", "log", "log2", "log10":
		tp = types.NewFieldType(mysql.TypeDouble)
	case "pow", "power", "rand":
		tp = types.NewFieldType(mysql.TypeDouble)
	case "curdate", "current_date", "date":
		tp = types.NewFieldType(mysql.TypeDate)
	case "curtime", "current_time", "timediff":
		tp = types.NewFieldType(mysql.TypeDuration)
		tp.Decimal = v.getFsp(x)
	case "current_timestamp", "date_arith":
		tp = types.NewFieldType(mysql.TypeDatetime)
	case "microsecond", "second", "minute", "hour", "day", "week", "month", "year",
		"dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek",
		"found_rows", "length", "extract", "locate":
		tp = types.NewFieldType(mysql.TypeLonglong)
	case "now", "sysdate":
		tp = types.NewFieldType(mysql.TypeDatetime)
		tp.Decimal = v.getFsp(x)
	case "from_unixtime":
		// One argument yields a datetime; any other argument count is typed as
		// a string.
		if len(x.Args) == 1 {
			tp = types.NewFieldType(mysql.TypeDatetime)
		} else {
			tp = types.NewFieldType(mysql.TypeVarString)
			chs = v.defaultCharset
		}
	case "str_to_date":
		tp = types.NewFieldType(mysql.TypeDatetime)
	case "dayname", "version", "database", "user", "current_user", "schema",
		"concat", "concat_ws", "left", "lcase", "lower", "repeat",
		"replace", "ucase", "upper", "convert", "substring",
		"substring_index", "trim", "ltrim", "rtrim", "reverse", "hex", "unhex", "date_format":
		tp = types.NewFieldType(mysql.TypeVarString)
		chs = v.defaultCharset
	case "strcmp", "isnull":
		tp = types.NewFieldType(mysql.TypeLonglong)
	case "connection_id":
		tp = types.NewFieldType(mysql.TypeLonglong)
		tp.Flag |= mysql.UnsignedFlag
	case "if":
		// TODO: fix this
		// See https://dev.mysql.com/doc/refman/5.5/en/control-flow-functions.html#function_if
		// The default return type of IF() (which may matter when it is stored into a temporary table) is calculated as follows.
		// Expression	Return Value
		// expr2 or expr3 returns a string	string
		// expr2 or expr3 returns a floating-point value	floating-point
		// expr2 or expr3 returns an integer	integer
		tp = x.Args[1].GetType()
	case "get_lock", "release_lock":
		tp = types.NewFieldType(mysql.TypeLonglong)
	default:
		tp = types.NewFieldType(mysql.TypeUnspecified)
	}
	if tp == nil {
		// Reached when "greatest" has no arguments: without this fallback the
		// charset check below would dereference a nil type.
		tp = types.NewFieldType(mysql.TypeUnspecified)
	}
	// If charset is unspecified.
	if len(tp.Charset) == 0 {
		tp.Charset = chs
		cln := charset.CollationBin
		if chs != charset.CharsetBin {
			var err error
			cln, err = charset.GetDefaultCollation(chs)
			if err != nil {
				v.err = err
			}
		}
		tp.Collate = cln
	}
	x.SetType(tp)
}
예제 #8
0
// builtinDateArith evaluates date_add and date_sub.
//
// Arguments:
//
//	args[0] -> Op, an ast.DateArithType used for distinguishing date_add and date_sub.
//	args[1] -> Date
//	args[2] -> DateArithInterval, carrying the unit and the interval expression.
//
// A NULL date or a NULL interval yields a NULL datum and no error. On any
// failure the returned datum is NULL and err describes the problem. ctx is not
// used.
func builtinDateArith(args []types.Datum, ctx context.Context) (d types.Datum, err error) {
	// health check for date and interval
	if args[1].Kind() == types.KindNull {
		d.SetNull()
		return d, nil
	}
	nodeDate := args[1]
	nodeInterval := args[2].GetInterface().(ast.DateArithInterval)
	intervalDatum := nodeInterval.Interval.GetDatum()
	if intervalDatum.Kind() == types.KindNull {
		d.SetNull()
		return d, nil
	}
	// parse date
	// The result is a datetime when the operand already is one (or is a
	// timestamp), when a string operand is not in date format, or when the
	// unit is a clock unit; otherwise it is a date.
	fieldType := mysql.TypeDate
	var resultField *types.FieldType
	switch nodeDate.Kind() {
	case types.KindMysqlTime:
		x := nodeDate.GetMysqlTime()
		if (x.Type == mysql.TypeDatetime) || (x.Type == mysql.TypeTimestamp) {
			fieldType = mysql.TypeDatetime
		}
	case types.KindString:
		x := nodeDate.GetString()
		if !mysql.IsDateFormat(x) {
			fieldType = mysql.TypeDatetime
		}
	case types.KindInt64:
		x := nodeDate.GetInt64()
		if t, err1 := mysql.ParseTimeFromInt64(x); err1 == nil {
			if (t.Type == mysql.TypeDatetime) || (t.Type == mysql.TypeTimestamp) {
				fieldType = mysql.TypeDatetime
			}
		}
	}
	if mysql.IsClockUnit(nodeInterval.Unit) {
		fieldType = mysql.TypeDatetime
	}
	resultField = types.NewFieldType(fieldType)
	resultField.Decimal = mysql.MaxFsp
	value, err := nodeDate.ConvertTo(resultField)
	if err != nil {
		d.SetNull()
		// Report the type of the wrapped value; %T on the datum itself would
		// always print the datum's own type.
		return d, ErrInvalidOperation.Gen("DateArith invalid args, need date but get %T", nodeDate.GetValue())
	}
	if value.Kind() == types.KindNull {
		d.SetNull()
		return d, ErrInvalidOperation.Gen("DateArith invalid args, need date but get %v", value.GetValue())
	}
	if value.Kind() != types.KindMysqlTime {
		d.SetNull()
		return d, ErrInvalidOperation.Gen("DateArith need time type, but got %T", value.GetValue())
	}
	result := value.GetMysqlTime()
	// parse interval
	var interval string
	if strings.ToLower(nodeInterval.Unit) == "day" {
		day, err1 := parseDayInterval(*intervalDatum)
		if err1 != nil {
			d.SetNull()
			// Report the type of the wrapped value; passing GetString() here
			// would make %T print "string" for every operand.
			return d, ErrInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", intervalDatum.GetValue())
		}
		interval = fmt.Sprintf("%d", day)
	} else {
		if intervalDatum.Kind() == types.KindString {
			interval = fmt.Sprintf("%v", intervalDatum.GetString())
		} else {
			// Non-string intervals are reduced to an integer first.
			ii, err1 := intervalDatum.ToInt64()
			if err1 != nil {
				d.SetNull()
				return d, errors.Trace(err1)
			}
			interval = fmt.Sprintf("%v", ii)
		}
	}
	year, month, day, duration, err := mysql.ExtractTimeValue(nodeInterval.Unit, interval)
	if err != nil {
		d.SetNull()
		return d, errors.Trace(err)
	}
	op := args[0].GetInterface().(ast.DateArithType)
	if op == ast.DateSub {
		year, month, day, duration = -year, -month, -day, -duration
	}
	// Apply the duration part first, then the calendar part.
	result.Time = result.Time.Add(duration)
	result.Time = result.Time.AddDate(int(year), int(month), int(day))
	if result.Time.Nanosecond() == 0 {
		// No fractional part left, so drop the fractional precision.
		result.Fsp = 0
	}
	d.SetMysqlTime(result)
	return d, nil
}