示例#1
0
// TestEvalCoalesce verifies that a Coalesce expression evaluates to its first
// non-null argument, and to null when every argument is null.
func (s *testEvalSuite) TestEvalCoalesce(c *C) {
	// The evaluator is given a single-column row; the expressions below only
	// use constant arguments.
	ev := &Evaluator{
		Row: map[int64]types.Datum{
			int64(1): types.NewIntDatum(100),
		},
	}

	null := types.Datum{}
	null.SetNull()
	notNull := types.NewStringDatum("not-null")

	tests := []struct {
		expr *tipb.Expr
		want types.Datum
	}{
		// Every argument is null.
		{
			expr: buildExpr(tipb.ExprType_Coalesce, null, null, null),
			want: null,
		},
		// A single non-null argument surrounded by nulls.
		{
			expr: buildExpr(tipb.ExprType_Coalesce, null, notNull, null),
			want: notNull,
		},
		// Two non-null arguments: the earlier one is expected.
		{
			expr: buildExpr(tipb.ExprType_Coalesce, null, notNull, types.NewStringDatum("not-null-2"), null),
			want: notNull,
		},
	}

	for _, tt := range tests {
		got, err := ev.Eval(tt.expr)
		c.Assert(err, IsNil)
		c.Assert(got.Kind(), Equals, tt.want.Kind())

		order, err := got.CompareDatum(ev.sc, tt.want)
		c.Assert(err, IsNil)
		c.Assert(order, Equals, 0)
	}
}
示例#2
0
// checkAnyResult compares lv against each datum in result using the operator
// cs.Op and reports whether any comparison holds.
//
// The returned datum is:
//   - the int64 form of true as soon as one non-null element matches,
//   - left unset (null) when nothing matched but a null element was seen,
//   - the int64 form of false otherwise.
//
// Any comparison or operator error is traced and returned immediately.
func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv types.Datum, result []types.Datum) (d types.Datum, err error) {
	sawNull := false
	for _, candidate := range result {
		if candidate.IsNull() {
			// Nulls never match, but they change the "no match" answer.
			sawNull = true
			continue
		}

		order, cmpErr := lv.CompareDatum(candidate)
		if cmpErr != nil {
			return d, errors.Trace(cmpErr)
		}

		matched, opErr := getCompResult(cs.Op, order)
		if opErr != nil {
			return d, errors.Trace(opErr)
		}
		if matched {
			d.SetInt64(boolToInt64(true))
			return d, nil
		}
	}

	// If nothing matched but a null was seen, the answer is null.
	// Like `insert t (c) values (1),(2),(null)`, then
	// `select 0 > any (select c from t)`, returns null.
	if !sawNull {
		d.SetInt64(boolToInt64(false))
	}
	return d, nil
}
示例#3
0
文件: session.go 项目: jmptrader/tidb
// SetSystemVar sets a system variable.
//
// The key is lower-cased before use. A null value is only accepted for
// characterSetResults, in which case the entry is deleted; for any other key
// errCantSetToNull is returned. Otherwise the value is converted to a string
// and stored, after updating StrictSQLMode (for sqlMode) or the snapshot
// timestamp (for TiDBSnapshot).
func (s *SessionVars) SetSystemVar(key string, value types.Datum) error {
	key = strings.ToLower(key)

	if value.IsNull() {
		if key == characterSetResults {
			delete(s.systems, key)
			return nil
		}
		return errCantSetToNull
	}

	sVal, err := value.ToString()
	if err != nil {
		return errors.Trace(err)
	}

	switch key {
	case sqlMode:
		// The mode string is stored upper-cased; either strict flag enables strict mode.
		sVal = strings.ToUpper(sVal)
		s.StrictSQLMode = strings.Contains(sVal, "STRICT_TRANS_TABLES") ||
			strings.Contains(sVal, "STRICT_ALL_TABLES")
	case TiDBSnapshot:
		if err = s.setSnapshotTS(sVal); err != nil {
			return errors.Trace(err)
		}
	}

	s.systems[key] = sVal
	return nil
}
示例#4
0
文件: varsutil.go 项目: pingcap/tidb
// SetSystemVar sets system variable and updates SessionVars states.
//
// The name is lower-cased before use. A null value is only accepted for
// variable.CharacterSetResults, whose entry is then deleted; any other name
// yields variable.ErrCantSetToNull. Otherwise the value is converted to a
// string, the matching session state (strict SQL mode, snapshot timestamp,
// autocommit status flag, skip-constraint-check) is refreshed, and the string
// is stored in vars.Systems.
func SetSystemVar(vars *variable.SessionVars, name string, value types.Datum) error {
	name = strings.ToLower(name)

	if value.IsNull() {
		if name == variable.CharacterSetResults {
			delete(vars.Systems, name)
			return nil
		}
		return variable.ErrCantSetToNull
	}

	text, err := value.ToString()
	if err != nil {
		return errors.Trace(err)
	}

	switch name {
	case variable.SQLModeVar:
		// The mode string is stored upper-cased; either strict flag enables strict mode.
		text = strings.ToUpper(text)
		vars.StrictSQLMode = strings.Contains(text, "STRICT_TRANS_TABLES") ||
			strings.Contains(text, "STRICT_ALL_TABLES")
	case variable.TiDBSnapshot:
		if err = setSnapshotTS(vars, text); err != nil {
			return errors.Trace(err)
		}
	case variable.AutocommitVar:
		// "ON" (any case) and "1" both enable autocommit.
		enabled := strings.EqualFold(text, "ON") || text == "1"
		vars.SetStatusFlag(mysql.ServerStatusAutocommit, enabled)
	case variable.TiDBSkipConstraintCheck:
		vars.SkipConstraintCheck = text == "1"
	}

	vars.Systems[name] = text
	return nil
}
示例#5
0
// builtinDateFormat renders args[0] according to the format string in args[1].
//
// The format is scanned byte by byte: "%%" emits a literal '%', "%<c>" is
// expanded through convertDateFormat, and every other byte is copied as-is.
// A '%' that ends the format string is dropped. If any specifier expands to
// null, the whole result is a null datum.
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-format
func builtinDateFormat(args []types.Datum, ctx context.Context) (types.Datum, error) {
	var out []byte
	afterPercent := false

	// TODO: Some invalid format like 2000-00-01(the month is 0) will return null.
	for _, ch := range []byte(args[1].GetString()) {
		switch {
		case !afterPercent:
			if ch == '%' {
				afterPercent = true
			} else {
				out = append(out, ch)
			}
		case ch == '%':
			// Escaped percent sign.
			out = append(out, ch)
			afterPercent = false
		default:
			piece, err := convertDateFormat(ctx, args[0], ch)
			if err != nil {
				return types.Datum{}, errors.Trace(err)
			}
			if piece.IsNull() {
				return types.Datum{}, nil
			}
			out = append(out, piece.GetString()...)
			afterPercent = false
		}
	}

	var d types.Datum
	d.SetString(string(out))
	return d, nil
}
示例#6
0
文件: eval.go 项目: XuHuaiyu/tidb
// evalUint decodes val with codec.DecodeUint and returns the value as a
// uint64 datum. A decode failure is reported as ErrInvalid carrying the raw
// bytes in hex; the datum returned alongside the error is the zero Datum.
func (e *Evaluator) evalUint(val []byte) (types.Datum, error) {
	_, decoded, err := codec.DecodeUint(val)
	if err != nil {
		return types.Datum{}, ErrInvalid.Gen("invalid uint % x", val)
	}

	var d types.Datum
	d.SetUint64(decoded)
	return d, nil
}
示例#7
0
文件: session.go 项目: jmptrader/tidb
// GetSystemVar gets a system variable.
//
// The lookup is done on the lower-cased key. When the variable is not present
// the zero Datum is returned.
func (s *SessionVars) GetSystemVar(key string) types.Datum {
	var d types.Datum
	if sVal, found := s.systems[strings.ToLower(key)]; found {
		d.SetString(sVal)
	}
	return d
}
示例#8
0
// updateFirst records val as the aggregate value the first time it is called
// for the current context; later calls leave the stored value untouched.
// It always returns nil.
func (n *finalAggregater) updateFirst(val types.Datum) error {
	if ctx := n.getContext(); !ctx.Evaluated {
		ctx.Value = val.GetValue()
		ctx.Evaluated = true
	}
	return nil
}
示例#9
0
文件: eval.go 项目: XuHuaiyu/tidb
// evalDecimal decodes val with codec.DecodeDecimal and returns the value as a
// MySQL decimal datum. A decode failure is reported as ErrInvalid carrying the
// raw bytes in hex; the datum returned alongside the error is the zero Datum.
func (e *Evaluator) evalDecimal(val []byte) (types.Datum, error) {
	_, decoded, err := codec.DecodeDecimal(val)
	if err != nil {
		return types.Datum{}, ErrInvalid.Gen("invalid decimal % x", val)
	}

	var d types.Datum
	d.SetMysqlDecimal(decoded)
	return d, nil
}
示例#10
0
文件: eval.go 项目: XuHuaiyu/tidb
// evalDuration decodes val with codec.DecodeInt, interprets the integer as a
// time.Duration and returns it as a MySQL duration datum with Fsp set to
// mysql.MaxFsp.
//
// On a decode failure it returns ErrInvalid carrying the raw bytes in hex,
// matching the other eval* decoders; the integer result is not meaningful
// when decoding fails, so it is not included in the message.
func (e *Evaluator) evalDuration(val []byte) (types.Datum, error) {
	var d types.Datum
	_, i, err := codec.DecodeInt(val)
	if err != nil {
		return d, ErrInvalid.Gen("invalid duration % x", val)
	}
	d.SetMysqlDuration(mysql.Duration{Duration: time.Duration(i), Fsp: mysql.MaxFsp})
	return d, nil
}
示例#11
0
文件: range.go 项目: jmptrader/tidb
// buildFormBinOp converts a binary comparison between a column and a constant
// into the list of range points that describe the matching values.
//
// It returns nil when the constant is null or the operator is not one of
// EQ, NE, LT, LE, GT, GE.
func (r *rangeBuilder) buildFormBinOp(expr *expression.ScalarFunction) []rangePoint {
	// This has been checked that the binary operation is comparison operation, and one of
	// the operand is column name expression.
	op := expr.FuncName.L
	var value types.Datum
	if constant, ok := expr.Args[0].(*expression.Constant); ok {
		// The constant is the left operand: mirror the operator so it reads
		// as "column <op> constant".
		value = constant.Value
		switch op {
		case ast.GE:
			op = ast.LE
		case ast.GT:
			op = ast.LT
		case ast.LT:
			op = ast.GT
		case ast.LE:
			op = ast.GE
		}
	} else {
		value = expr.Args[1].(*expression.Constant).Value
	}

	if value.IsNull() {
		return nil
	}

	switch op {
	case ast.EQ:
		return []rangePoint{
			{value: value, start: true},
			{value: value},
		}
	case ast.NE:
		// Two ranges: everything below the value and everything above it.
		return []rangePoint{
			{value: types.MinNotNullDatum(), start: true},
			{value: value, excl: true},
			{value: value, start: true, excl: true},
			{value: types.MaxValueDatum()},
		}
	case ast.LT:
		return []rangePoint{
			{value: types.MinNotNullDatum(), start: true},
			{value: value, excl: true},
		}
	case ast.LE:
		return []rangePoint{
			{value: types.MinNotNullDatum(), start: true},
			{value: value},
		}
	case ast.GT:
		return []rangePoint{
			{value: value, start: true, excl: true},
			{value: types.MaxValueDatum()},
		}
	case ast.GE:
		return []rangePoint{
			{value: value, start: true},
			{value: types.MaxValueDatum()},
		}
	default:
		return nil
	}
}
示例#12
0
// convertToGoTime converts datum d to a Go time.Time. A datum that is not
// already of kind KindMysqlTime is first converted with convertToTime using
// mysql.TypeDatetime. Errors from either step are traced and returned.
// TODO: This is used for timediff(). After we finish time refactor, we should abandon this function.
func convertToGoTime(sc *variable.StatementContext, d types.Datum) (t time.Time, err error) {
	if d.Kind() != types.KindMysqlTime {
		if d, err = convertToTime(sc, d, mysql.TypeDatetime); err != nil {
			return t, errors.Trace(err)
		}
	}

	mysqlTime := d.GetMysqlTime()
	t, err = mysqlTime.Time.GoTime()
	return t, errors.Trace(err)
}
示例#13
0
// CastValue casts a value based on column type.
//
// A conversion error is returned (traced) only when the session is in strict
// SQL mode; otherwise the error is logged as a warning and the converted
// value is returned with a nil error.
func CastValue(ctx context.Context, val types.Datum, col *Column) (casted types.Datum, err error) {
	casted, err = val.ConvertTo(&col.FieldType)
	switch {
	case err == nil:
		return casted, nil
	case variable.GetSessionVars(ctx).StrictSQLMode:
		return casted, errors.Trace(err)
	}

	// TODO: add warnings.
	log.Warnf("cast value error %v", err)
	return casted, nil
}
示例#14
0
// checkFsp converts arg to an integer fractional-seconds precision and
// validates that it lies in [0, mysql.MaxFsp].
//
// It returns the precision as an int, or an error when arg cannot be
// converted to int64 or the value is out of range.
func checkFsp(arg types.Datum) (int, error) {
	fsp, err := arg.ToInt64()
	if err != nil {
		return 0, errors.Trace(err)
	}
	// Range-check on the int64 itself: narrowing to int first could wrap a
	// large value into the valid range on platforms where int is 32 bits.
	if fsp > int64(mysql.MaxFsp) {
		return 0, errors.Errorf("Too big precision %d specified. Maximum is 6.", fsp)
	}
	if fsp < 0 {
		return 0, errors.Errorf("Invalid negative %d specified, must in [0, 6].", fsp)
	}
	return int(fsp), nil
}
示例#15
0
// TestFromUnixTime checks builtinFromUnixTime against the value produced by
// time.Unix for integer and fractional timestamps, both without a format
// argument (compared as a time string) and with one (compared against
// builtinDateFormat). It also checks that a negative timestamp and one above
// math.MaxInt32 both produce a null datum.
func (s *testEvaluatorSuite) TestFromUnixTime(c *C) {
	defer testleak.AfterTest(c)()

	tbl := []struct {
		isDecimal      bool
		integralPart   int64
		fractionalPart int64
		decimal        float64
		format         string
		// ansLen is how many leading characters of the Go time string are
		// compared: 19 appears to cover "YYYY-MM-DD hh:mm:ss" and 26 adds the
		// microsecond fraction.
		ansLen int
	}{
		{false, 1451606400, 0, 0, "", 19},
		{true, 1451606400, 123456000, 1451606400.123456, "", 26},
		{true, 1451606400, 999999000, 1451606400.999999, "", 26},
		{true, 1451606400, 999999900, 1451606400.9999999, "", 19},
		{false, 1451606400, 0, 0, "%Y %D %M %h:%i:%s %x", 19},
		{true, 1451606400, 123456000, 1451606400.123456, "%Y %D %M %h:%i:%s %x", 26},
		{true, 1451606400, 999999000, 1451606400.999999, "%Y %D %M %h:%i:%s %x", 26},
		{true, 1451606400, 999999900, 1451606400.9999999, "%Y %D %M %h:%i:%s %x", 19},
	}
	for _, t := range tbl {
		var timestamp types.Datum
		if !t.isDecimal {
			timestamp.SetInt64(t.integralPart)
		} else {
			timestamp.SetFloat64(t.decimal)
		}
		// result of from_unixtime() is dependent on specific time zone.
		unixTime := time.Unix(t.integralPart, t.fractionalPart).Round(time.Microsecond).String()[:t.ansLen]
		if len(t.format) == 0 {
			v, err := builtinFromUnixTime([]types.Datum{timestamp}, s.ctx)
			c.Assert(err, IsNil)
			ans := v.GetMysqlTime()
			c.Assert(ans.String(), Equals, unixTime)
		} else {
			format := types.NewStringDatum(t.format)
			v, err := builtinFromUnixTime([]types.Datum{timestamp, format}, s.ctx)
			c.Assert(err, IsNil)
			result, err := builtinDateFormat([]types.Datum{types.NewStringDatum(unixTime), format}, s.ctx)
			c.Assert(err, IsNil)
			c.Assert(v.GetString(), Equals, result.GetString())
		}
	}

	v, err := builtinFromUnixTime([]types.Datum{types.NewIntDatum(-12345)}, s.ctx)
	c.Assert(err, IsNil)
	c.Assert(v.Kind(), Equals, types.KindNull)

	// Capture the result of this call; previously it was discarded and the
	// assertion below re-checked the datum from the negative-timestamp call.
	v, err = builtinFromUnixTime([]types.Datum{types.NewIntDatum(math.MaxInt32 + 1)}, s.ctx)
	c.Assert(err, IsNil)
	c.Assert(v.Kind(), Equals, types.KindNull)
}
示例#16
0
文件: eval.go 项目: XuHuaiyu/tidb
// evalFloat decodes val with codec.DecodeFloat and returns it as a float32
// datum when f32 is true, or as a float64 datum otherwise. A decode failure
// is reported as ErrInvalid carrying the raw bytes in hex; the datum returned
// alongside the error is the zero Datum.
func (e *Evaluator) evalFloat(val []byte, f32 bool) (types.Datum, error) {
	_, decoded, err := codec.DecodeFloat(val)
	if err != nil {
		return types.Datum{}, ErrInvalid.Gen("invalid float % x", val)
	}

	var d types.Datum
	if !f32 {
		d.SetFloat64(decoded)
		return d, nil
	}
	d.SetFloat32(float32(decoded))
	return d, nil
}
示例#17
0
// updateSum folds a partial sum val, which covers count rows, into the
// current aggregate context. Null partial sums are ignored. An error from
// types.CalculateSum is traced and returned, in which case the row count is
// not updated.
func (n *finalAggregater) updateSum(val types.Datum, count uint64) error {
	ctx := n.getContext()
	if val.IsNull() {
		return nil
	}

	total, err := types.CalculateSum(ctx.Value, val.GetValue())
	// The stored value is overwritten with whatever CalculateSum returned,
	// even when it also reported an error.
	ctx.Value = total
	if err != nil {
		return errors.Trace(err)
	}

	ctx.Count += int64(count)
	return nil
}
示例#18
0
文件: varsutil.go 项目: pingcap/tidb
// GetSystemVar gets a system variable.
//
// The lookup is done on the lower-cased key in s.Systems. When the key is
// absent the zero Datum is returned, except for
// variable.TiDBSkipConstraintCheck, which falls back to the value registered
// in variable.SysVars.
func GetSystemVar(s *variable.SessionVars, key string) types.Datum {
	var d types.Datum
	key = strings.ToLower(key)

	if sVal, found := s.Systems[key]; found {
		d.SetString(sVal)
		return d
	}

	// TiDBSkipConstraintCheck is a session scope vars. We do not store it in the global table.
	if key == variable.TiDBSkipConstraintCheck {
		d.SetString(variable.SysVars[variable.TiDBSkipConstraintCheck].Value)
	}
	return d
}
示例#19
0
// getRowByHandle assembles the tipb.Row for the record identified by handle.
//
// row.Handle holds the encoded handle. row.Data is the concatenation of one
// encoded value per table column, in column order:
//   - the primary-key handle column reuses the handle (re-encoded as unsigned
//     when the column carries the unsigned flag);
//   - a column already loaded for the where clause is taken from the evaluator row;
//   - any other column is read from the MVCC store at the request's start
//     timestamp. A missing value is encoded as nil, unless the column has the
//     not-null flag, in which case kv.ErrNotExist is returned.
func (h *rpcHandler) getRowByHandle(ctx *selectContext, handle int64) (*tipb.Row, error) {
	tid := ctx.sel.TableInfo.GetTableId()
	columns := ctx.sel.TableInfo.Columns
	row := new(tipb.Row)

	var handleDatum types.Datum
	handleDatum.SetInt64(handle)
	var err error
	row.Handle, err = codec.EncodeValue(nil, handleDatum)
	if err != nil {
		return nil, errors.Trace(err)
	}

	for _, col := range columns {
		if col.GetPkHandle() {
			if !mysql.HasUnsignedFlag(uint(col.GetFlag())) {
				// Signed handle column: the encoded handle is reused as-is.
				row.Data = append(row.Data, row.Handle...)
				continue
			}
			row.Data, err = codec.EncodeValue(row.Data, types.NewUintDatum(uint64(handle)))
			if err != nil {
				return nil, errors.Trace(err)
			}
			continue
		}

		colID := col.GetColumnId()
		if ctx.whereColumns[colID] != nil {
			// The column is saved in evaluator, use it directly.
			row.Data, err = codec.EncodeValue(row.Data, ctx.eval.Row[colID])
			if err != nil {
				return nil, errors.Trace(err)
			}
			continue
		}

		key := tablecodec.EncodeColumnKey(tid, handle, colID)
		data, getErr := h.mvccStore.Get(key, ctx.sel.GetStartTs())
		if getErr != nil {
			return nil, errors.Trace(getErr)
		}
		if data != nil {
			row.Data = append(row.Data, data...)
			continue
		}
		if mysql.HasNotNullFlag(uint(col.GetFlag())) {
			return nil, errors.Trace(kv.ErrNotExist)
		}
		row.Data = append(row.Data, codec.NilFlag)
	}
	return row, nil
}
示例#20
0
// builtinStrToDate parses the string in args[0] using the format string in
// args[1]. It returns a MySQL time datum on success and a null datum when
// the string does not match the format; the error is always nil.
// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_str-to-date
func builtinStrToDate(args []types.Datum, _ context.Context) (types.Datum, error) {
	input := args[0].GetString()
	layout := args[1].GetString()

	var (
		result types.Datum
		parsed types.Time
	)
	if parsed.StrToDate(input, layout) {
		result.SetMysqlTime(parsed)
	} else {
		result.SetNull()
	}
	return result, nil
}
示例#21
0
文件: subquery.go 项目: XuHuaiyu/tidb
// EvalRows builds and runs the subquery plan and returns up to rowCount rows.
//
// A row with one column is returned as that column's datum; a row with
// any other number of columns is wrapped in a single row datum. rowCount is
// only decremented while positive, so a negative rowCount reads until the
// executor is exhausted and a rowCount of zero reads nothing.
//
// When the executor exposes no result fields it is drained and (nil, nil) is
// returned. The executor is closed before returning.
func (sq *subquery) EvalRows(ctx context.Context, rowCount int) ([]types.Datum, error) {
	b := newExecutorBuilder(ctx, sq.is)
	plan.Refine(sq.plan)
	e := b.build(sq.plan)
	if b.err != nil {
		return nil, errors.Trace(b.err)
	}
	defer e.Close()

	if len(e.Fields()) == 0 {
		// No result fields means no Recordset.
		for {
			row, err := e.Next()
			if err != nil {
				return nil, errors.Trace(err)
			}
			if row == nil {
				return nil, nil
			}
		}
	}

	var rows []types.Datum
	for remaining := rowCount; remaining != 0; {
		row, err := e.Next()
		if err != nil {
			return rows, errors.Trace(err)
		}
		if row == nil {
			break
		}

		if len(row.Data) != 1 {
			var d types.Datum
			d.SetRow(row.Data)
			rows = append(rows, d)
		} else {
			rows = append(rows, row.Data[0])
		}

		if remaining > 0 {
			remaining--
		}
	}
	return rows, nil
}
示例#22
0
// abbrDayOfMonth returns the day of month of arg with its English ordinal
// suffix, e.g. "1st", "2nd", "3rd", "4th", "21st".
//
// A null arg, or an error from builtinDayOfMonth, yields the zero Datum
// together with the traced error.
func abbrDayOfMonth(arg types.Datum, ctx context.Context) (types.Datum, error) {
	day, err := builtinDayOfMonth([]types.Datum{arg}, ctx)
	if err != nil || arg.IsNull() {
		return types.Datum{}, errors.Trace(err)
	}

	n := day.GetInt64()
	// "th" covers every day not listed below, including 11, 12 and 13.
	suffix := "th"
	switch n {
	case 1, 21, 31:
		suffix = "st"
	case 2, 22:
		suffix = "nd"
	case 3, 23:
		suffix = "rd"
	}
	return types.NewStringDatum(fmt.Sprintf("%d%s", n, suffix)), nil
}
示例#23
0
// convertToDuration converts arg to a MySQL duration datum with fractional
// seconds precision fsp.
//
// It returns a null datum with a nil error when the conversion itself yields
// null. On any failure the returned datum is null and the error describes the
// problem: either the traced conversion error, or a "need duration type"
// error naming the type the conversion actually produced.
func convertToDuration(sc *variable.StatementContext, arg types.Datum, fsp int) (d types.Datum, err error) {
	f := types.NewFieldType(mysql.TypeDuration)
	f.Decimal = fsp

	d, err = arg.ConvertTo(sc, f)
	if err != nil {
		d.SetNull()
		return d, errors.Trace(err)
	}

	if d.IsNull() {
		return d, nil
	}

	if d.Kind() != types.KindMysqlDuration {
		// Build the error before nulling d so that it reports the value the
		// conversion produced rather than the value left after SetNull.
		err = errors.Errorf("need duration type, but got %T", d.GetValue())
		d.SetNull()
		return d, err
	}
	return d, nil
}
示例#24
0
// convertToTime converts arg to a MySQL time datum of field type tp, using
// types.MaxFsp as the fractional seconds precision.
//
// It returns a null datum with a nil error when the conversion itself yields
// null. On any failure the returned datum is null and the error describes the
// problem: either the traced conversion error, or a "need time type" error
// naming the type the conversion actually produced.
func convertToTime(sc *variable.StatementContext, arg types.Datum, tp byte) (d types.Datum, err error) {
	f := types.NewFieldType(tp)
	f.Decimal = types.MaxFsp

	d, err = arg.ConvertTo(sc, f)
	if err != nil {
		d.SetNull()
		return d, errors.Trace(err)
	}

	if d.IsNull() {
		return d, nil
	}

	if d.Kind() != types.KindMysqlTime {
		// Build the error before nulling d so that it reports the value the
		// conversion produced rather than the value left after SetNull.
		err = errors.Errorf("need time type, but got %T", d.GetValue())
		d.SetNull()
		return d, err
	}
	return d, nil
}
示例#25
0
// convertToTime converts arg to a MySQL time datum of field type tp, using
// mysql.MaxFsp as the fractional seconds precision.
//
// A null conversion result is returned as-is with a nil error. On failure the
// returned datum is null and the error is either the traced conversion error
// or a "need time type" error naming the type the conversion produced.
func convertToTime(arg types.Datum, tp byte) (d types.Datum, err error) {
	ft := types.NewFieldType(tp)
	ft.Decimal = mysql.MaxFsp

	if d, err = arg.ConvertTo(ft); err != nil {
		d.SetNull()
		return d, errors.Trace(err)
	}

	switch d.Kind() {
	case types.KindNull, types.KindMysqlTime:
		return d, nil
	}

	// Capture the offending type before the datum is reset to null.
	err = errors.Errorf("need time type, but got %T", d.GetValue())
	d.SetNull()
	return d, err
}
示例#26
0
// computeDiv evaluates a / b for the division operator.
//
// When a is a float64 the division is done in float64; otherwise both
// operands are converted to decimals and divided as decimals. Division by
// zero returns the unset (null) datum with a nil error. A conversion error
// from either operand is traced and returned.
func computeDiv(a, b types.Datum) (d types.Datum, err error) {
	// MySQL support integer division Div and division operator /
	// we use opcode.Div for division operator and will use another for integer division later.
	// for division operator, we will use float64 for calculation.
	switch a.Kind() {
	case types.KindFloat64:
		y, err1 := b.ToFloat64()
		if err1 != nil {
			return d, errors.Trace(err1)
		}

		if y == 0 {
			// division by zero return null
			return d, nil
		}

		x := a.GetFloat64()
		d.SetFloat64(x / y)
		return d, nil
	default:
		// the scale of the result is the scale of the first operand plus
		// the value of the div_precision_increment system variable (which is 4 by default)
		// we will use 4 here
		xa, err1 := a.ToDecimal()
		// Check err1, not the named return err: err is still nil here, so
		// testing it would silently ignore a failed conversion of a.
		if err1 != nil {
			return d, errors.Trace(err1)
		}

		xb, err1 := b.ToDecimal()
		if err1 != nil {
			return d, errors.Trace(err1)
		}
		if f, _ := xb.Float64(); f == 0 {
			// division by zero return null
			return d, nil
		}

		d.SetMysqlDecimal(xa.Div(xb))
		return d, nil
	}
}
示例#27
0
// parseDayInterval converts value to a day count.
//
// For string datums, "true" and "false" (case-insensitive) map to 1 and 0;
// any other string is first reduced to the part matched by the package-level
// regexp reg and then converted. Every other kind is converted directly with
// ToInt64.
func parseDayInterval(value types.Datum) (int64, error) {
	if value.Kind() == types.KindString {
		raw := value.GetString()
		switch strings.ToLower(raw) {
		case "false":
			return 0, nil
		case "true":
			return 1, nil
		}
		value.SetString(reg.FindString(raw))
	}
	return value.ToInt64()
}
示例#28
0
// updateMaxMin merges val into the running maximum (max == true) or minimum
// (max == false) kept in the current aggregate context. Null inputs are
// ignored, and the first non-null input is adopted unconditionally. A
// comparison error is traced and returned.
func (n *finalAggregater) updateMaxMin(val types.Datum, max bool) error {
	ctx := n.getContext()
	if val.IsNull() {
		return nil
	}
	if ctx.Value.IsNull() {
		ctx.Value = val
		return nil
	}

	order, err := ctx.Value.CompareDatum(val)
	if err != nil {
		return errors.Trace(err)
	}

	// Replace the stored value when it is smaller than val while tracking
	// the maximum, or larger than val while tracking the minimum.
	if (max && order == -1) || (!max && order == 1) {
		ctx.Value = val
	}
	return nil
}
示例#29
0
// testFrac asserts that the decimal v survives an EncodeDecimal /
// DecodeDecimal round trip: the decoded datum must compare equal to the
// original and render the same decimal string.
func testFrac(c *C, v *mysql.MyDecimal) {
	var original types.Datum
	original.SetMysqlDecimal(v)

	encoded := EncodeDecimal([]byte{}, original)
	_, decoded, err := DecodeDecimal(encoded)
	c.Assert(err, IsNil)

	order, err := original.CompareDatum(decoded)
	c.Assert(err, IsNil)
	c.Assert(order, Equals, 0)

	wantText := decoded.GetMysqlDecimal().String()
	c.Assert(original.GetMysqlDecimal().String(), Equals, wantText)
}
示例#30
0
// convertToDuration converts arg to a MySQL duration datum with fractional
// seconds precision fsp.
//
// A null conversion result is returned as null with a nil error. On failure
// the returned datum is null and the error is either the traced conversion
// error or a "need duration type" error naming the type the conversion
// produced.
func convertToDuration(arg types.Datum, fsp int) (d types.Datum, err error) {
	ft := types.NewFieldType(mysql.TypeDuration)
	ft.Decimal = fsp

	if d, err = arg.ConvertTo(ft); err != nil {
		d.SetNull()
		return d, errors.Trace(err)
	}

	switch d.Kind() {
	case types.KindNull:
		d.SetNull()
		return d, nil
	case types.KindMysqlDuration:
		return d, nil
	}

	// Capture the offending type before the datum is reset to null.
	err = errors.Errorf("need duration type, but got %T", d.GetValue())
	d.SetNull()
	return d, err
}