// The return type of a CASE expression is the compatible aggregated type of all return values, // but also depends on the context in which it is used. // If used in a string context, the result is returned as a string. // If used in a numeric context, the result is returned as a decimal, real, or integer value. func (v *typeInferrer) handleCaseExpr(x *ast.CaseExpr) { var currType *types.FieldType for _, w := range x.WhenClauses { t := w.Result.GetType() if currType == nil { currType = t continue } mtp := types.MergeFieldType(currType.Tp, t.Tp) if mtp == t.Tp && mtp != currType.Tp { currType.Charset = t.Charset currType.Collate = t.Collate } currType.Tp = mtp } if x.ElseClause != nil { t := x.ElseClause.GetType() if currType == nil { currType = t } else { mtp := types.MergeFieldType(currType.Tp, t.Tp) if mtp == t.Tp && mtp != currType.Tp { currType.Charset = t.Charset currType.Collate = t.Collate } currType.Tp = mtp } } x.SetType(currType) // TODO: We need a better way to set charset/collation x.Type.Charset, x.Type.Collate = types.DefaultCharsetForType(x.Type.Tp) }
func (d *ddl) setCharsetCollationFlenDecimal(tp *types.FieldType) { if len(tp.Charset) == 0 { switch tp.Tp { case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: tp.Charset, tp.Collate = getDefaultCharsetAndCollate() default: tp.Charset = charset.CharsetBin tp.Collate = charset.CharsetBin } } // If flen is not assigned, assigned it by type. if tp.Flen == types.UnspecifiedLength { tp.Flen = mysql.GetDefaultFieldLength(tp.Tp) } if tp.Decimal == types.UnspecifiedLength { tp.Decimal = mysql.GetDefaultDecimal(tp.Tp) } }
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { var ( tp *types.FieldType chs = charset.CharsetBin ) switch x.FnName.L { case "abs", "ifnull", "nullif": tp = x.Args[0].GetType() case "pow", "power", "rand": tp = types.NewFieldType(mysql.TypeDouble) case "curdate", "current_date", "date": tp = types.NewFieldType(mysql.TypeDate) case "curtime", "current_time": tp = types.NewFieldType(mysql.TypeDuration) tp.Decimal = v.getFsp(x) case "current_timestamp": tp = types.NewFieldType(mysql.TypeDatetime) case "microsecond", "second", "minute", "hour", "day", "week", "month", "year", "dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek", "found_rows", "length": tp = types.NewFieldType(mysql.TypeLonglong) case "now", "sysdate": tp = types.NewFieldType(mysql.TypeDatetime) tp.Decimal = v.getFsp(x) case "dayname", "version", "database", "user", "current_user", "concat", "concat_ws", "left", "lower", "repeat", "replace", "upper": tp = types.NewFieldType(mysql.TypeVarString) chs = v.defaultCharset case "connection_id": tp = types.NewFieldType(mysql.TypeLonglong) tp.Flag |= mysql.UnsignedFlag case "if": // TODO: fix this // See: https://dev.mysql.com/doc/refman/5.5/en/control-flow-functions.html#function_if // The default return type of IF() (which may matter when it is stored into a temporary table) is calculated as follows. // Expression Return Value // expr2 or expr3 returns a string string // expr2 or expr3 returns a floating-point value floating-point // expr2 or expr3 returns an integer integer tp = x.Args[1].GetType() default: tp = types.NewFieldType(mysql.TypeUnspecified) } // If charset is unspecified. if len(tp.Charset) == 0 { tp.Charset = chs cln := charset.CollationBin if chs != charset.CharsetBin { var err error cln, err = charset.GetDefaultCollation(chs) if err != nil { v.err = err } } tp.Collate = cln } x.SetType(tp) }
// ProtoColumnsToFieldTypes converts tipb column info slice to FieldTyps slice. func ProtoColumnsToFieldTypes(pColumns []*tipb.ColumnInfo) []*types.FieldType { fields := make([]*types.FieldType, len(pColumns)) for i, v := range pColumns { field := new(types.FieldType) field.Tp = byte(v.GetTp()) field.Collate = mysql.Collations[byte(v.GetCollation())] field.Decimal = int(v.GetDecimal()) field.Flen = int(v.GetColumnLen()) field.Flag = uint(v.GetFlag()) field.Elems = v.GetElems() fields[i] = field } return fields }
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { var ( tp *types.FieldType chs = charset.CharsetBin ) switch x.FnName.L { case "abs", "ifnull", "nullif": tp = x.Args[0].GetType() case "pow", "power", "rand": tp = types.NewFieldType(mysql.TypeDouble) case "curdate", "current_date", "date": tp = types.NewFieldType(mysql.TypeDate) case "curtime", "current_time": tp = types.NewFieldType(mysql.TypeDuration) tp.Decimal = v.getFsp(x) case "current_timestamp": tp = types.NewFieldType(mysql.TypeDatetime) case "microsecond", "second", "minute", "hour", "day", "week", "month", "year", "dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek", "found_rows", "length": tp = types.NewFieldType(mysql.TypeLonglong) case "now", "sysdate": tp = types.NewFieldType(mysql.TypeDatetime) tp.Decimal = v.getFsp(x) case "dayname", "version", "database", "user", "current_user", "concat", "concat_ws", "left", "lower", "repeat", "replace", "upper": tp = types.NewFieldType(mysql.TypeVarString) chs = v.defaultCharset case "connection_id": tp = types.NewFieldType(mysql.TypeLonglong) tp.Flag |= mysql.UnsignedFlag case "if": tp = x.Args[1].GetType() default: tp = types.NewFieldType(mysql.TypeUnspecified) } // If charset is unspecified. if len(tp.Charset) == 0 { tp.Charset = chs cln := charset.CollationBin if chs != charset.CharsetBin { var err error cln, err = charset.GetDefaultCollation(chs) if err != nil { v.err = err } } tp.Collate = cln } x.SetType(tp) }
func (e *Evaluator) funcDateArith(v *ast.FuncDateArithExpr) bool { // health check for date and interval nodeDate := v.Date.GetValue() if types.IsNil(nodeDate) { v.SetValue(nil) return true } nodeInterval := v.Interval.GetValue() if types.IsNil(nodeInterval) { v.SetValue(nil) return true } // parse date fieldType := mysql.TypeDate var resultField *types.FieldType switch x := nodeDate.(type) { case mysql.Time: if (x.Type == mysql.TypeDatetime) || (x.Type == mysql.TypeTimestamp) { fieldType = mysql.TypeDatetime } case string: if !mysql.IsDateFormat(x) { fieldType = mysql.TypeDatetime } case int64: if t, err := mysql.ParseTimeFromInt64(x); err == nil { if (t.Type == mysql.TypeDatetime) || (t.Type == mysql.TypeTimestamp) { fieldType = mysql.TypeDatetime } } } if mysql.IsClockUnit(v.Unit) { fieldType = mysql.TypeDatetime } resultField = types.NewFieldType(fieldType) resultField.Decimal = mysql.MaxFsp value, err := types.Convert(nodeDate, resultField) if err != nil { e.err = ErrInvalidOperation.Gen("DateArith invalid args, need date but get %T", nodeDate) return false } if types.IsNil(value) { e.err = ErrInvalidOperation.Gen("DateArith invalid args, need date but get %v", value) return false } result, ok := value.(mysql.Time) if !ok { e.err = ErrInvalidOperation.Gen("DateArith need time type, but got %T", value) return false } // parse interval var interval string if strings.ToLower(v.Unit) == "day" { day, err2 := parseDayInterval(nodeInterval) if err2 != nil { e.err = ErrInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeInterval) return false } interval = fmt.Sprintf("%d", day) } else { interval = fmt.Sprintf("%v", nodeInterval) } year, month, day, duration, err := mysql.ExtractTimeValue(v.Unit, interval) if err != nil { e.err = errors.Trace(err) return false } if v.Op == ast.DateSub { year, month, day, duration = -year, -month, -day, -duration } result.Time = result.Time.Add(duration) result.Time = result.Time.AddDate(int(year), int(month), int(day)) if result.Time.Nanosecond() == 0 { result.Fsp = 0 } v.SetValue(result) return true }
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { var ( tp *types.FieldType chs = charset.CharsetBin ) switch x.FnName.L { case "abs", "ifnull", "nullif": tp = x.Args[0].GetType() // TODO: We should cover all types. if x.FnName.L == "abs" && tp.Tp == mysql.TypeDatetime { tp = types.NewFieldType(mysql.TypeDouble) } case "greatest": for _, arg := range x.Args { InferType(v.sc, arg) } if len(x.Args) > 0 { tp = x.Args[0].GetType() for i := 1; i < len(x.Args); i++ { mergeArithType(tp.Tp, x.Args[i].GetType().Tp) } } case "ceil", "ceiling": t := x.Args[0].GetType().Tp if t == mysql.TypeNull || t == mysql.TypeFloat || t == mysql.TypeDouble || t == mysql.TypeVarchar || t == mysql.TypeTinyBlob || t == mysql.TypeMediumBlob || t == mysql.TypeLongBlob || t == mysql.TypeBlob || t == mysql.TypeVarString || t == mysql.TypeString { tp = types.NewFieldType(mysql.TypeDouble) } else { tp = types.NewFieldType(mysql.TypeLonglong) } case "ln", "log", "log2", "log10": tp = types.NewFieldType(mysql.TypeDouble) case "pow", "power", "rand": tp = types.NewFieldType(mysql.TypeDouble) case "curdate", "current_date", "date": tp = types.NewFieldType(mysql.TypeDate) case "curtime", "current_time", "timediff": tp = types.NewFieldType(mysql.TypeDuration) tp.Decimal = v.getFsp(x) case "current_timestamp", "date_arith": tp = types.NewFieldType(mysql.TypeDatetime) case "microsecond", "second", "minute", "hour", "day", "week", "month", "year", "dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek", "found_rows", "length", "extract", "locate": tp = types.NewFieldType(mysql.TypeLonglong) case "now", "sysdate": tp = types.NewFieldType(mysql.TypeDatetime) tp.Decimal = v.getFsp(x) case "from_unixtime": if len(x.Args) == 1 { tp = types.NewFieldType(mysql.TypeDatetime) } else { tp = types.NewFieldType(mysql.TypeVarString) chs = v.defaultCharset } case "str_to_date": tp = types.NewFieldType(mysql.TypeDatetime) case "dayname", "version", "database", "user", "current_user", "schema", "concat", "concat_ws", "left", "lcase", "lower", "repeat", "replace", "ucase", "upper", "convert", "substring", "substring_index", "trim", "ltrim", "rtrim", "reverse", "hex", "unhex", "date_format": tp = types.NewFieldType(mysql.TypeVarString) chs = v.defaultCharset case "strcmp", "isnull": tp = types.NewFieldType(mysql.TypeLonglong) case "connection_id": tp = types.NewFieldType(mysql.TypeLonglong) tp.Flag |= mysql.UnsignedFlag case "if": // TODO: fix this // See https://dev.mysql.com/doc/refman/5.5/en/control-flow-functions.html#function_if // The default return type of IF() (which may matter when it is stored into a temporary table) is calculated as follows. // Expression Return Value // expr2 or expr3 returns a string string // expr2 or expr3 returns a floating-point value floating-point // expr2 or expr3 returns an integer integer tp = x.Args[1].GetType() case "get_lock", "release_lock": tp = types.NewFieldType(mysql.TypeLonglong) default: tp = types.NewFieldType(mysql.TypeUnspecified) } // If charset is unspecified. if len(tp.Charset) == 0 { tp.Charset = chs cln := charset.CollationBin if chs != charset.CharsetBin { var err error cln, err = charset.GetDefaultCollation(chs) if err != nil { v.err = err } } tp.Collate = cln } x.SetType(tp) }
func builtinDateArith(args []types.Datum, ctx context.Context) (d types.Datum, err error) { // Op is used for distinguishing date_add and date_sub. // args[0] -> Op // args[1] -> Date // args[2] -> DateArithInterval // health check for date and interval if args[1].Kind() == types.KindNull { d.SetNull() return d, nil } nodeDate := args[1] nodeInterval := args[2].GetInterface().(ast.DateArithInterval) nodeIntervalIntervalDatum := nodeInterval.Interval.GetDatum() if nodeIntervalIntervalDatum.Kind() == types.KindNull { d.SetNull() return d, nil } // parse date fieldType := mysql.TypeDate var resultField *types.FieldType switch nodeDate.Kind() { case types.KindMysqlTime: x := nodeDate.GetMysqlTime() if (x.Type == mysql.TypeDatetime) || (x.Type == mysql.TypeTimestamp) { fieldType = mysql.TypeDatetime } case types.KindString: x := nodeDate.GetString() if !mysql.IsDateFormat(x) { fieldType = mysql.TypeDatetime } case types.KindInt64: x := nodeDate.GetInt64() if t, err1 := mysql.ParseTimeFromInt64(x); err1 == nil { if (t.Type == mysql.TypeDatetime) || (t.Type == mysql.TypeTimestamp) { fieldType = mysql.TypeDatetime } } } if mysql.IsClockUnit(nodeInterval.Unit) { fieldType = mysql.TypeDatetime } resultField = types.NewFieldType(fieldType) resultField.Decimal = mysql.MaxFsp value, err := nodeDate.ConvertTo(resultField) if err != nil { d.SetNull() return d, ErrInvalidOperation.Gen("DateArith invalid args, need date but get %T", nodeDate) } if value.Kind() == types.KindNull { d.SetNull() return d, ErrInvalidOperation.Gen("DateArith invalid args, need date but get %v", value.GetValue()) } if value.Kind() != types.KindMysqlTime { d.SetNull() return d, ErrInvalidOperation.Gen("DateArith need time type, but got %T", value.GetValue()) } result := value.GetMysqlTime() // parse interval var interval string if strings.ToLower(nodeInterval.Unit) == "day" { day, err1 := parseDayInterval(*nodeIntervalIntervalDatum) if err1 != nil { d.SetNull() return d, ErrInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeIntervalIntervalDatum.GetString()) } interval = fmt.Sprintf("%d", day) } else { if nodeIntervalIntervalDatum.Kind() == types.KindString { interval = fmt.Sprintf("%v", nodeIntervalIntervalDatum.GetString()) } else { ii, err1 := nodeIntervalIntervalDatum.ToInt64() if err1 != nil { d.SetNull() return d, errors.Trace(err1) } interval = fmt.Sprintf("%v", ii) } } year, month, day, duration, err := mysql.ExtractTimeValue(nodeInterval.Unit, interval) if err != nil { d.SetNull() return d, errors.Trace(err) } op := args[0].GetInterface().(ast.DateArithType) if op == ast.DateSub { year, month, day, duration = -year, -month, -day, -duration } result.Time = result.Time.Add(duration) result.Time = result.Time.AddDate(int(year), int(month), int(day)) if result.Time.Nanosecond() == 0 { result.Fsp = 0 } d.SetMysqlTime(result) return d, nil }