func (s *testEvalSuite) TestEvalCoalesce(c *C) { colID := int64(1) row := make(map[int64]types.Datum) row[colID] = types.NewIntDatum(100) xevaluator := &Evaluator{Row: row} nullDatum := types.Datum{} nullDatum.SetNull() notNullDatum := types.NewStringDatum("not-null") cases := []struct { expr *tipb.Expr result types.Datum }{ { expr: buildExpr(tipb.ExprType_Coalesce, nullDatum, nullDatum, nullDatum), result: nullDatum, }, { expr: buildExpr(tipb.ExprType_Coalesce, nullDatum, notNullDatum, nullDatum), result: notNullDatum, }, { expr: buildExpr(tipb.ExprType_Coalesce, nullDatum, notNullDatum, types.NewStringDatum("not-null-2"), nullDatum), result: notNullDatum, }, } for _, ca := range cases { result, err := xevaluator.Eval(ca.expr) c.Assert(err, IsNil) c.Assert(result.Kind(), Equals, ca.result.Kind()) cmp, err := result.CompareDatum(xevaluator.sc, ca.result) c.Assert(err, IsNil) c.Assert(cmp, Equals, 0) } }
func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv types.Datum, result []types.Datum) (d types.Datum, err error) { hasNull := false for _, v := range result { if v.IsNull() { hasNull = true continue } comRes, err1 := lv.CompareDatum(v) if err1 != nil { return d, errors.Trace(err1) } res, err1 := getCompResult(cs.Op, comRes) if err1 != nil { return d, errors.Trace(err1) } if res { d.SetInt64(boolToInt64(true)) return d, nil } } if hasNull { // If no matched but we get null, return null. // Like `insert t (c) values (1),(2),(null)`, then // `select 0 > any (select c from t)`, returns null. return d, nil } d.SetInt64(boolToInt64(false)) return d, nil }
// SetSystemVar sets a system variable. func (s *SessionVars) SetSystemVar(key string, value types.Datum) error { key = strings.ToLower(key) if value.IsNull() { if key != characterSetResults { return errCantSetToNull } delete(s.systems, key) return nil } sVal, err := value.ToString() if err != nil { return errors.Trace(err) } if key == sqlMode { sVal = strings.ToUpper(sVal) if strings.Contains(sVal, "STRICT_TRANS_TABLES") || strings.Contains(sVal, "STRICT_ALL_TABLES") { s.StrictSQLMode = true } else { s.StrictSQLMode = false } } else if key == TiDBSnapshot { err = s.setSnapshotTS(sVal) if err != nil { return errors.Trace(err) } } s.systems[key] = sVal return nil }
// SetSystemVar sets system variable and updates SessionVars states. func SetSystemVar(vars *variable.SessionVars, name string, value types.Datum) error { name = strings.ToLower(name) if value.IsNull() { if name != variable.CharacterSetResults { return variable.ErrCantSetToNull } delete(vars.Systems, name) return nil } sVal, err := value.ToString() if err != nil { return errors.Trace(err) } switch name { case variable.SQLModeVar: sVal = strings.ToUpper(sVal) if strings.Contains(sVal, "STRICT_TRANS_TABLES") || strings.Contains(sVal, "STRICT_ALL_TABLES") { vars.StrictSQLMode = true } else { vars.StrictSQLMode = false } case variable.TiDBSnapshot: err = setSnapshotTS(vars, sVal) if err != nil { return errors.Trace(err) } case variable.AutocommitVar: isAutocommit := strings.EqualFold(sVal, "ON") || sVal == "1" vars.SetStatusFlag(mysql.ServerStatusAutocommit, isAutocommit) case variable.TiDBSkipConstraintCheck: vars.SkipConstraintCheck = (sVal == "1") } vars.Systems[name] = sVal return nil }
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-format func builtinDateFormat(args []types.Datum, ctx context.Context) (types.Datum, error) { var ( isPercent bool ret []byte d types.Datum ) // TODO: Some invalid format like 2000-00-01(the month is 0) will return null. for _, b := range []byte(args[1].GetString()) { if isPercent { if b == '%' { ret = append(ret, b) } else { str, err := convertDateFormat(ctx, args[0], b) if err != nil { return types.Datum{}, errors.Trace(err) } if str.IsNull() { return types.Datum{}, nil } ret = append(ret, str.GetString()...) } isPercent = false continue } if b == '%' { isPercent = true } else { ret = append(ret, b) } } d.SetString(string(ret)) return d, nil }
func (e *Evaluator) evalUint(val []byte) (types.Datum, error) { var d types.Datum _, u, err := codec.DecodeUint(val) if err != nil { return d, ErrInvalid.Gen("invalid uint % x", val) } d.SetUint64(u) return d, nil }
// GetSystemVar gets a system variable. func (s *SessionVars) GetSystemVar(key string) types.Datum { var d types.Datum key = strings.ToLower(key) sVal, ok := s.systems[key] if ok { d.SetString(sVal) } return d }
func (n *finalAggregater) updateFirst(val types.Datum) error { ctx := n.getContext() if ctx.Evaluated { return nil } ctx.Value = val.GetValue() ctx.Evaluated = true return nil }
func (e *Evaluator) evalDecimal(val []byte) (types.Datum, error) { var d types.Datum _, dec, err := codec.DecodeDecimal(val) if err != nil { return d, ErrInvalid.Gen("invalid decimal % x", val) } d.SetMysqlDecimal(dec) return d, nil }
func (e *Evaluator) evalDuration(val []byte) (types.Datum, error) { var d types.Datum _, i, err := codec.DecodeInt(val) if err != nil { return d, ErrInvalid.Gen("invalid duration %d", i) } d.SetMysqlDuration(mysql.Duration{Duration: time.Duration(i), Fsp: mysql.MaxFsp}) return d, nil }
func (r *rangeBuilder) buildFormBinOp(expr *expression.ScalarFunction) []rangePoint { // This has been checked that the binary operation is comparison operation, and one of // the operand is column name expression. var value types.Datum var op string if v, ok := expr.Args[0].(*expression.Constant); ok { value = v.Value switch expr.FuncName.L { case ast.GE: op = ast.LE case ast.GT: op = ast.LT case ast.LT: op = ast.GT case ast.LE: op = ast.GE default: op = expr.FuncName.L } } else { value = expr.Args[1].(*expression.Constant).Value op = expr.FuncName.L } if value.IsNull() { return nil } switch op { case ast.EQ: startPoint := rangePoint{value: value, start: true} endPoint := rangePoint{value: value} return []rangePoint{startPoint, endPoint} case ast.NE: startPoint1 := rangePoint{value: types.MinNotNullDatum(), start: true} endPoint1 := rangePoint{value: value, excl: true} startPoint2 := rangePoint{value: value, start: true, excl: true} endPoint2 := rangePoint{value: types.MaxValueDatum()} return []rangePoint{startPoint1, endPoint1, startPoint2, endPoint2} case ast.LT: startPoint := rangePoint{value: types.MinNotNullDatum(), start: true} endPoint := rangePoint{value: value, excl: true} return []rangePoint{startPoint, endPoint} case ast.LE: startPoint := rangePoint{value: types.MinNotNullDatum(), start: true} endPoint := rangePoint{value: value} return []rangePoint{startPoint, endPoint} case ast.GT: startPoint := rangePoint{value: value, start: true, excl: true} endPoint := rangePoint{value: types.MaxValueDatum()} return []rangePoint{startPoint, endPoint} case ast.GE: startPoint := rangePoint{value: value, start: true} endPoint := rangePoint{value: types.MaxValueDatum()} return []rangePoint{startPoint, endPoint} } return nil }
// Convert datum to gotime. // TODO: This is used for timediff(). After we finish time refactor, we should abandan this function. func convertToGoTime(sc *variable.StatementContext, d types.Datum) (t time.Time, err error) { if d.Kind() != types.KindMysqlTime { d, err = convertToTime(sc, d, mysql.TypeDatetime) if err != nil { return t, errors.Trace(err) } } t, err = d.GetMysqlTime().Time.GoTime() return t, errors.Trace(err) }
// CastValue casts a value based on column type. func CastValue(ctx context.Context, val types.Datum, col *Column) (casted types.Datum, err error) { casted, err = val.ConvertTo(&col.FieldType) if err != nil { if variable.GetSessionVars(ctx).StrictSQLMode { return casted, errors.Trace(err) } // TODO: add warnings. log.Warnf("cast value error %v", err) } return casted, nil }
func checkFsp(arg types.Datum) (int, error) { fsp, err := arg.ToInt64() if err != nil { return 0, errors.Trace(err) } if int(fsp) > mysql.MaxFsp { return 0, errors.Errorf("Too big precision %d specified. Maximum is 6.", fsp) } else if fsp < 0 { return 0, errors.Errorf("Invalid negative %d specified, must in [0, 6].", fsp) } return int(fsp), nil }
func (s *testEvaluatorSuite) TestFromUnixTime(c *C) { defer testleak.AfterTest(c)() tbl := []struct { isDecimal bool integralPart int64 fractionalPart int64 decimal float64 format string ansLen int }{ {false, 1451606400, 0, 0, "", 19}, {true, 1451606400, 123456000, 1451606400.123456, "", 26}, {true, 1451606400, 999999000, 1451606400.999999, "", 26}, {true, 1451606400, 999999900, 1451606400.9999999, "", 19}, {false, 1451606400, 0, 0, "%Y %D %M %h:%i:%s %x", 19}, {true, 1451606400, 123456000, 1451606400.123456, "%Y %D %M %h:%i:%s %x", 26}, {true, 1451606400, 999999000, 1451606400.999999, "%Y %D %M %h:%i:%s %x", 26}, {true, 1451606400, 999999900, 1451606400.9999999, "%Y %D %M %h:%i:%s %x", 19}, } for _, t := range tbl { var timestamp types.Datum if !t.isDecimal { timestamp.SetInt64(t.integralPart) } else { timestamp.SetFloat64(t.decimal) } // result of from_unixtime() is dependent on specific time zone. unixTime := time.Unix(t.integralPart, t.fractionalPart).Round(time.Microsecond).String()[:t.ansLen] if len(t.format) == 0 { v, err := builtinFromUnixTime([]types.Datum{timestamp}, s.ctx) c.Assert(err, IsNil) ans := v.GetMysqlTime() c.Assert(ans.String(), Equals, unixTime) } else { format := types.NewStringDatum(t.format) v, err := builtinFromUnixTime([]types.Datum{timestamp, format}, s.ctx) c.Assert(err, IsNil) result, err := builtinDateFormat([]types.Datum{types.NewStringDatum(unixTime), format}, s.ctx) c.Assert(err, IsNil) c.Assert(v.GetString(), Equals, result.GetString()) } } v, err := builtinFromUnixTime([]types.Datum{types.NewIntDatum(-12345)}, s.ctx) c.Assert(err, IsNil) c.Assert(v.Kind(), Equals, types.KindNull) _, err = builtinFromUnixTime([]types.Datum{types.NewIntDatum(math.MaxInt32 + 1)}, s.ctx) c.Assert(err, IsNil) c.Assert(v.Kind(), Equals, types.KindNull) }
func (e *Evaluator) evalFloat(val []byte, f32 bool) (types.Datum, error) { var d types.Datum _, f, err := codec.DecodeFloat(val) if err != nil { return d, ErrInvalid.Gen("invalid float % x", val) } if f32 { d.SetFloat32(float32(f)) } else { d.SetFloat64(f) } return d, nil }
func (n *finalAggregater) updateSum(val types.Datum, count uint64) error { ctx := n.getContext() if val.IsNull() { return nil } var err error ctx.Value, err = types.CalculateSum(ctx.Value, val.GetValue()) if err != nil { return errors.Trace(err) } ctx.Count += int64(count) return nil }
// GetSystemVar gets a system variable. func GetSystemVar(s *variable.SessionVars, key string) types.Datum { var d types.Datum key = strings.ToLower(key) sVal, ok := s.Systems[key] if ok { d.SetString(sVal) } else { // TiDBSkipConstraintCheck is a session scope vars. We do not store it in the global table. if key == variable.TiDBSkipConstraintCheck { d.SetString(variable.SysVars[variable.TiDBSkipConstraintCheck].Value) } } return d }
func (h *rpcHandler) getRowByHandle(ctx *selectContext, handle int64) (*tipb.Row, error) { tid := ctx.sel.TableInfo.GetTableId() columns := ctx.sel.TableInfo.Columns row := new(tipb.Row) var d types.Datum d.SetInt64(handle) var err error row.Handle, err = codec.EncodeValue(nil, d) if err != nil { return nil, errors.Trace(err) } for _, col := range columns { if col.GetPkHandle() { if mysql.HasUnsignedFlag(uint(col.GetFlag())) { row.Data, err = codec.EncodeValue(row.Data, types.NewUintDatum(uint64(handle))) if err != nil { return nil, errors.Trace(err) } } else { row.Data = append(row.Data, row.Handle...) } } else { colID := col.GetColumnId() if ctx.whereColumns[colID] != nil { // The column is saved in evaluator, use it directly. datum := ctx.eval.Row[colID] row.Data, err = codec.EncodeValue(row.Data, datum) if err != nil { return nil, errors.Trace(err) } } else { key := tablecodec.EncodeColumnKey(tid, handle, colID) data, err1 := h.mvccStore.Get(key, ctx.sel.GetStartTs()) if err1 != nil { return nil, errors.Trace(err1) } if data == nil { if mysql.HasNotNullFlag(uint(col.GetFlag())) { return nil, errors.Trace(kv.ErrNotExist) } row.Data = append(row.Data, codec.NilFlag) } else { row.Data = append(row.Data, data...) } } } } return row, nil }
// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_str-to-date func builtinStrToDate(args []types.Datum, _ context.Context) (types.Datum, error) { date := args[0].GetString() format := args[1].GetString() var ( d types.Datum t types.Time ) succ := t.StrToDate(date, format) if !succ { d.SetNull() return d, nil } d.SetMysqlTime(t) return d, nil }
func (sq *subquery) EvalRows(ctx context.Context, rowCount int) ([]types.Datum, error) { b := newExecutorBuilder(ctx, sq.is) plan.Refine(sq.plan) e := b.build(sq.plan) if b.err != nil { return nil, errors.Trace(b.err) } defer e.Close() if len(e.Fields()) == 0 { // No result fields means no Recordset. for { row, err := e.Next() if err != nil { return nil, errors.Trace(err) } if row == nil { return nil, nil } } } var ( err error row *Row rows []types.Datum ) for rowCount != 0 { row, err = e.Next() if err != nil { return rows, errors.Trace(err) } if row == nil { break } if len(row.Data) == 1 { rows = append(rows, row.Data[0]) } else { var d types.Datum d.SetRow(row.Data) rows = append(rows, d) } if rowCount > 0 { rowCount-- } } return rows, nil }
func abbrDayOfMonth(arg types.Datum, ctx context.Context) (types.Datum, error) { day, err := builtinDayOfMonth([]types.Datum{arg}, ctx) if err != nil || arg.IsNull() { return types.Datum{}, errors.Trace(err) } var str string switch day.GetInt64() { case 1, 21, 31: str = "st" case 2, 22: str = "nd" case 3, 23: str = "rd" default: str = "th" } d := types.NewStringDatum(fmt.Sprintf("%d%s", day.GetInt64(), str)) return d, nil }
func convertToDuration(sc *variable.StatementContext, arg types.Datum, fsp int) (d types.Datum, err error) { f := types.NewFieldType(mysql.TypeDuration) f.Decimal = fsp d, err = arg.ConvertTo(sc, f) if err != nil { d.SetNull() return d, errors.Trace(err) } if d.IsNull() { return d, nil } if d.Kind() != types.KindMysqlDuration { d.SetNull() return d, errors.Errorf("need duration type, but got %T", d.GetValue()) } return d, nil }
func convertToTime(sc *variable.StatementContext, arg types.Datum, tp byte) (d types.Datum, err error) { f := types.NewFieldType(tp) f.Decimal = types.MaxFsp d, err = arg.ConvertTo(sc, f) if err != nil { d.SetNull() return d, errors.Trace(err) } if d.IsNull() { return d, nil } if d.Kind() != types.KindMysqlTime { d.SetNull() return d, errors.Errorf("need time type, but got %T", d.GetValue()) } return d, nil }
func convertToTime(arg types.Datum, tp byte) (d types.Datum, err error) { f := types.NewFieldType(tp) f.Decimal = mysql.MaxFsp d, err = arg.ConvertTo(f) if err != nil { d.SetNull() return d, errors.Trace(err) } if d.Kind() == types.KindNull { return d, nil } if d.Kind() != types.KindMysqlTime { err = errors.Errorf("need time type, but got %T", d.GetValue()) d.SetNull() return d, err } return d, nil }
func computeDiv(a, b types.Datum) (d types.Datum, err error) { // MySQL support integer divison Div and division operator / // we use opcode.Div for division operator and will use another for integer division later. // for division operator, we will use float64 for calculation. switch a.Kind() { case types.KindFloat64: y, err1 := b.ToFloat64() if err1 != nil { return d, errors.Trace(err1) } if y == 0 { return d, nil } x := a.GetFloat64() d.SetFloat64(x / y) return d, nil default: // the scale of the result is the scale of the first operand plus // the value of the div_precision_increment system variable (which is 4 by default) // we will use 4 here xa, err1 := a.ToDecimal() if err != nil { return d, errors.Trace(err1) } xb, err1 := b.ToDecimal() if err1 != nil { return d, errors.Trace(err1) } if f, _ := xb.Float64(); f == 0 { // division by zero return null return d, nil } d.SetMysqlDecimal(xa.Div(xb)) return d, nil } }
func parseDayInterval(value types.Datum) (int64, error) { switch value.Kind() { case types.KindString: vs := value.GetString() s := strings.ToLower(vs) if s == "false" { return 0, nil } else if s == "true" { return 1, nil } value.SetString(reg.FindString(vs)) } return value.ToInt64() }
func (n *finalAggregater) updateMaxMin(val types.Datum, max bool) error { ctx := n.getContext() if val.IsNull() { return nil } if ctx.Value.IsNull() { ctx.Value = val return nil } c, err := ctx.Value.CompareDatum(val) if err != nil { return errors.Trace(err) } if max { if c == -1 { ctx.Value = val } } else if c == 1 { ctx.Value = val } return nil }
func testFrac(c *C, v *mysql.MyDecimal) { var d1 types.Datum d1.SetMysqlDecimal(v) b := EncodeDecimal([]byte{}, d1) _, d2, err := DecodeDecimal(b) c.Assert(err, IsNil) cmp, err := d1.CompareDatum(d2) c.Assert(err, IsNil) c.Assert(cmp, Equals, 0) c.Assert(d1.GetMysqlDecimal().String(), Equals, d2.GetMysqlDecimal().String()) }
func convertToDuration(arg types.Datum, fsp int) (d types.Datum, err error) { f := types.NewFieldType(mysql.TypeDuration) f.Decimal = fsp d, err = arg.ConvertTo(f) if err != nil { d.SetNull() return d, errors.Trace(err) } if d.Kind() == types.KindNull { d.SetNull() return d, nil } if d.Kind() != types.KindMysqlDuration { err = errors.Errorf("need duration type, but got %T", d.GetValue()) d.SetNull() return d, err } return d, nil }