// onDuplicateUpdate applies the ON DUPLICATE KEY UPDATE assignments to the
// stored row with handle h after inserting `row` hit a duplicate key.
//
// row  - the values that failed to insert, one datum per column.
// h    - handle of the conflicting row; its current values are read with Table.Row.
// cols - assignments keyed by column offset; columns without an entry keep
//        the value from `row`.
//
// NOTE(review): the result fields below are loaded from `row` (the inserted
// values), not from `data` (the stored row). So a plain column reference in an
// assignment such as `c = c + 1` appears to read the inserted value, while
// MySQL reads the existing row. Confirm whether this is intended.
func (e *InsertExec) onDuplicateUpdate(row []types.Datum, h int64, cols map[int]*ast.Assignment) error {
	// On duplicate key update the duplicate row.
	// Evaluate the updated value.
	// TODO: report rows affected and last insert id.
	data, err := e.Table.Row(e.ctx, h)
	if err != nil {
		return errors.Trace(err)
	}
	// For evaluating ValuesExpr: expose the inserted values through the result fields.
	// http://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
	for i, rf := range e.fields {
		rf.Expr.SetValue(row[i].GetValue())
	}
	// Evaluate assignment. newData is sized from the stored row; unassigned
	// columns take the value from the inserted row.
	newData := make([]types.Datum, len(data))
	for i, c := range row {
		asgn, ok := cols[i]
		if !ok {
			newData[i] = c
			continue
		}
		val, err1 := evaluator.Eval(e.ctx, asgn.Expr)
		if err1 != nil {
			return errors.Trace(err1)
		}
		newData[i] = val
	}
	// Persist the change; cols is passed through to updateRecord unchanged.
	if err = updateRecord(e.ctx, h, data, newData, cols, e.Table, 0, true); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// Next implements Executor Next interface. func (e *SelectFieldsExec) Next() (*Row, error) { var rowKeys []*RowKeyEntry if e.Src != nil { srcRow, err := e.Src.Next() if err != nil { return nil, errors.Trace(err) } if srcRow == nil { return nil, nil } rowKeys = srcRow.RowKeys } else { // If Src is nil, only one row should be returned. if e.executed { return nil, nil } } e.executed = true row := &Row{ RowKeys: rowKeys, Data: make([]types.Datum, len(e.ResultFields)), } for i, field := range e.ResultFields { val, err := evaluator.Eval(e.ctx, field.Expr) if err != nil { return nil, errors.Trace(err) } row.Data[i] = val } return row, nil }
func getDefaultValue(ctx context.Context, c *ast.ColumnOption, tp byte, fsp int) (interface{}, error) { if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { vd, err := evaluator.GetTimeValue(ctx, c.Expr, tp, fsp) value := vd.GetValue() if err != nil { return nil, errors.Trace(err) } // Value is nil means `default null`. if value == nil { return nil, nil } // If value is types.Time, convert it to string. if vv, ok := value.(types.Time); ok { return vv.String(), nil } return value, nil } v, err := evaluator.Eval(ctx, c.Expr) if err != nil { return nil, errors.Trace(err) } if v.IsNull() { return nil, nil } return v.ToString() }
// Next implements Executor Next interface. func (e *SortExec) Next() (*Row, error) { if !e.fetched { for { srcRow, err := e.Src.Next() if err != nil { return nil, errors.Trace(err) } if srcRow == nil { break } orderRow := &orderByRow{ row: srcRow, key: make([]interface{}, len(e.ByItems)), } for i, byItem := range e.ByItems { orderRow.key[i], err = evaluator.Eval(e.ctx, byItem.Expr) if err != nil { return nil, errors.Trace(err) } } e.Rows = append(e.Rows, orderRow) } sort.Sort(e) e.fetched = true } if e.err != nil { return nil, errors.Trace(e.err) } if e.Idx >= len(e.Rows) { return nil, nil } row := e.Rows[e.Idx].row e.Idx++ return row, nil }
func (e *InsertValues) getRow(cols []*table.Column, list []ast.ExprNode, defaultVals map[string]types.Datum) ([]types.Datum, error) { vals := make([]types.Datum, len(list)) var err error for i, expr := range list { if d, ok := expr.(*ast.DefaultExpr); ok { cn := d.Name if cn == nil { vals[i] = defaultVals[cols[i].Name.L] continue } var found bool vals[i], found = defaultVals[cn.Name.L] if !found { return nil, errors.Errorf("default column not found - %s", cn.Name.O) } } else { var val types.Datum val, err = evaluator.Eval(e.ctx, expr) vals[i] = val if err != nil { return nil, errors.Trace(err) } } } return e.fillRowData(cols, vals, false) }
func (e *UpdateExec) fetchRows() error { for { row, err := e.SelectExec.Next() if err != nil { return errors.Trace(err) } if row == nil { return nil } data := make([]types.Datum, len(e.SelectExec.Fields())) newData := make([]types.Datum, len(e.SelectExec.Fields())) for i, f := range e.SelectExec.Fields() { data[i] = types.NewDatum(f.Expr.GetValue()) newData[i] = data[i] if e.OrderedList[i] != nil { val, err := evaluator.Eval(e.ctx, e.OrderedList[i].Expr) if err != nil { return errors.Trace(err) } newData[i] = val } } row.Data = data e.rows = append(e.rows, row) e.newRowsData = append(e.newRowsData, newData) } }
func (e *SimpleExec) executeDo(s *ast.DoStmt) error { for _, expr := range s.Exprs { _, err := evaluator.Eval(e.ctx, expr) if err != nil { return errors.Trace(err) } } return nil }
// Next implements Executor Next interface. func (e *SortExec) Next() (*Row, error) { if !e.fetched { offset := -1 totalCount := -1 if e.Limit != nil { offset = int(e.Limit.Offset) totalCount = offset + int(e.Limit.Count) } for { srcRow, err := e.Src.Next() if err != nil { return nil, errors.Trace(err) } if srcRow == nil { break } orderRow := &orderByRow{ row: srcRow, key: make([]types.Datum, len(e.ByItems)), } for i, byItem := range e.ByItems { orderRow.key[i], err = evaluator.Eval(e.ctx, byItem.Expr) if err != nil { return nil, errors.Trace(err) } } e.Rows = append(e.Rows, orderRow) if totalCount != -1 && e.Len() >= totalCount+SortBufferSize { sort.Sort(e) e.Rows = e.Rows[:totalCount] } } sort.Sort(e) if offset >= 0 && offset < e.Len() { if totalCount > e.Len() { e.Rows = e.Rows[offset:] } else { e.Rows = e.Rows[offset:totalCount] } } else if offset != -1 { e.Rows = e.Rows[:0] } e.fetched = true } if e.err != nil { return nil, errors.Trace(e.err) } if e.Idx >= len(e.Rows) { return nil, nil } row := e.Rows[e.Idx].row e.Idx++ return row, nil }
// Build builds a prepared statement into an executor. func (e *ExecuteExec) Build() error { vars := variable.GetSessionVars(e.Ctx) if e.Name != "" { e.ID = vars.PreparedStmtNameToID[e.Name] } v := vars.PreparedStmts[e.ID] if v == nil { return ErrStmtNotFound } prepared := v.(*Prepared) if len(prepared.Params) != len(e.UsingVars) { return ErrWrongParamCount } for i, usingVar := range e.UsingVars { val, err := evaluator.Eval(e.Ctx, usingVar) if err != nil { return errors.Trace(err) } prepared.Params[i].SetDatum(val) } ast.ResetEvaluatedFlag(prepared.Stmt) if prepared.SchemaVersion != e.IS.SchemaMetaVersion() { // If the schema version has changed we need to prepare it again, // if this time it failed, the real reason for the error is schema changed. err := plan.PrepareStmt(e.IS, e.Ctx, prepared.Stmt) if err != nil { return ErrSchemaChanged.Gen("Schema change casued error: %s", err.Error()) } prepared.SchemaVersion = e.IS.SchemaMetaVersion() } sb := &subqueryBuilder{is: e.IS} p, err := plan.Optimize(e.Ctx, prepared.Stmt, sb, e.IS) if err != nil { return errors.Trace(err) } b := newExecutorBuilder(e.Ctx, e.IS) stmtExec := b.build(p) if b.err != nil { return errors.Trace(b.err) } e.StmtExec = stmtExec e.Stmt = prepared.Stmt return nil }
func (e *AggregateExec) getGroupKey() ([]byte, error) { if len(e.GroupByItems) == 0 { return singleGroup, nil } vals := make([]types.Datum, 0, len(e.GroupByItems)) for _, item := range e.GroupByItems { v, err := evaluator.Eval(e.ctx, item.Expr) if err != nil { return nil, errors.Trace(err) } vals = append(vals, v) } bs, err := codec.EncodeValue([]byte{}, vals...) if err != nil { return nil, errors.Trace(err) } return bs, nil }
func (e *SimpleExec) getVarValue(v *ast.VariableAssignment, sysVar *variable.SysVar, globalVars variable.GlobalVarAccessor) (value types.Datum, err error) { switch v.Value.(type) { case *ast.DefaultExpr: // To set a SESSION variable to the GLOBAL value or a GLOBAL value // to the compiled-in MySQL default value, use the DEFAULT keyword. // See http://dev.mysql.com/doc/refman/5.7/en/set-statement.html if sysVar != nil { value = types.NewStringDatum(sysVar.Value) } else { s, err1 := globalVars.GetGlobalSysVar(e.ctx, strings.ToLower(v.Name)) if err1 != nil { return value, errors.Trace(err1) } value = types.NewStringDatum(s) } default: value, err = evaluator.Eval(e.ctx, v.Value) } return value, errors.Trace(err) }
// onDuplicateUpdate updates the duplicate row. // TODO: Report rows affected and last insert id. func (e *InsertExec) onDuplicateUpdate(row []types.Datum, h int64, cols map[int]*ast.Assignment) error { data, err := e.Table.Row(e.ctx, h) if err != nil { return errors.Trace(err) } // for evaluating ColumnNameExpr for i, rf := range e.fields { rf.Expr.SetValue(data[i].GetValue()) } // for evaluating ValuesExpr // See http://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values e.ctx.GetSessionVars().CurrInsertValues = row // evaluate assignment newData := make([]types.Datum, len(data)) for i, c := range row { asgn, ok := cols[i] if !ok { newData[i] = c continue } val, err1 := evaluator.Eval(e.ctx, asgn.Expr) if err1 != nil { return errors.Trace(err1) } newData[i] = val } assignFlag := make([]bool, len(e.Table.Cols())) for i, asgn := range cols { if asgn != nil { assignFlag[i] = true } else { assignFlag[i] = false } } if err = updateRecord(e.ctx, h, data, newData, assignFlag, e.Table, 0, true); err != nil { return errors.Trace(err) } return nil }
func (r *preEvaluator) Leave(in ast.Node) (ast.Node, bool) { if expr, ok := in.(ast.ExprNode); ok { if _, ok = expr.(*ast.ValueExpr); ok { return in, true } else if ast.IsPreEvaluable(expr) { val, err := evaluator.Eval(r.ctx, expr) if err != nil { r.err = err return in, false } if ast.IsConstant(expr) { // The expression is constant, rewrite the expression to value expression. valExpr := &ast.ValueExpr{} valExpr.SetText(expr.Text()) valExpr.SetType(expr.GetType()) valExpr.SetDatum(val) return valExpr, true } expr.SetDatum(val) } } return in, true }
// Fetch a single row from src and update each aggregate function. // If the first return value is false, it means there is no more data from src. func (e *AggregateExec) innerNext() (bool, error) { if e.Src != nil { srcRow, err := e.Src.Next() if err != nil { return false, errors.Trace(err) } if srcRow == nil { return false, nil } } else { // If Src is nil, only one row should be returned. if e.executed { return false, nil } } e.executed = true groupKey, err := e.getGroupKey() if err != nil { return false, errors.Trace(err) } if _, ok := e.groupMap[string(groupKey)]; !ok { e.groupMap[string(groupKey)] = true e.groups = append(e.groups, groupKey) } for _, af := range e.AggFuncs { for _, arg := range af.Args { _, err := evaluator.Eval(e.ctx, arg) if err != nil { return false, errors.Trace(err) } } af.CurrentGroup = groupKey af.Update() } return true, nil }
// TestCount runs an AggregateExec without GROUP BY over three rows and checks
// firstrow(c1) and count(c2). COUNT must skip the NULL in the last row. After
// Close, both aggregates must report their reset values.
//
// NOTE(review): val is compared against plain Go values (int64, nil). This
// assumes evaluator.Eval returns interface{} in this version; other code in this
// file treats its result as types.Datum. Confirm.
func (s *testAggFuncSuite) TestCount(c *C) {
	// Compose aggregate exec for "select c1, count(c2) from t";
	// c1  c2
	// 1   1
	// 2   1
	// 3   nil
	c1 := ast.NewValueExpr(0)
	rf1 := &ast.ResultField{Expr: c1}
	col1 := &ast.ColumnNameExpr{Refer: rf1}
	fc1 := &ast.AggregateFuncExpr{
		F:    ast.AggFuncFirstRow,
		Args: []ast.ExprNode{col1},
	}
	c2 := ast.NewValueExpr(0)
	rf2 := &ast.ResultField{Expr: c2}
	col2 := &ast.ColumnNameExpr{Refer: rf2}
	fc2 := &ast.AggregateFuncExpr{
		F:    ast.AggFuncCount,
		Args: []ast.ExprNode{col2},
	}
	row1 := types.MakeDatums(1, 1)
	row2 := types.MakeDatums(2, 1)
	row3 := types.MakeDatums(3, nil)
	data := []([]types.Datum){row1, row2, row3}

	rows := make([]*Row, 0, 3)
	for _, d := range data {
		rows = append(rows, &Row{Data: d})
	}
	src := &mockExec{
		rows:   rows,
		fields: []*ast.ResultField{rf1, rf2},
	}
	agg := &AggregateExec{
		AggFuncs: []*ast.AggregateFuncExpr{fc1, fc2},
		Src:      src,
	}
	// Drain the executor; without GROUP BY a single result row is expected.
	var (
		row *Row
		cnt int
	)
	for {
		r, err := agg.Next()
		c.Assert(err, IsNil)
		if r == nil {
			break
		}
		row = r
		cnt++
	}
	c.Assert(cnt, Equals, 1)
	c.Assert(row, NotNil)
	ctx := mock.NewContext()
	// firstrow(c1) is the first input row's c1.
	val, err := evaluator.Eval(ctx, fc1)
	c.Assert(err, IsNil)
	c.Assert(val, Equals, int64(1))
	// count(c2) ignores the NULL in row3.
	val, err = evaluator.Eval(ctx, fc2)
	c.Assert(err, IsNil)
	c.Assert(val, Equals, int64(2))
	// After Close the aggregate state is reset.
	agg.Close()
	val, err = evaluator.Eval(ctx, fc1)
	c.Assert(err, IsNil)
	c.Assert(val, IsNil)
	val, err = evaluator.Eval(ctx, fc2)
	c.Assert(err, IsNil)
	c.Assert(val, Equals, int64(0))
}
// columnDefToCol converts ColumnDef to Col and TableConstraints.
//
// It builds a table.Column at position offset from colDef and applies each
// column option in order: NULL/NOT NULL, AUTO_INCREMENT, inline key/index
// options, DEFAULT, ON UPDATE, COMMENT and FULLTEXT. It returns the column
// together with the constraints generated by the inline key/index options.
// Errors come from an invalid DEFAULT, a non-current-time ON UPDATE, COMMENT
// evaluation, or checkDefaultValue.
func columnDefToCol(ctx context.Context, offset int, colDef *ast.ColumnDef) (*table.Column, []*ast.Constraint, error) {
	constraints := []*ast.Constraint{}
	col := &table.Column{
		Offset:    offset,
		Name:      colDef.Name.Name,
		FieldType: *colDef.Tp,
	}

	// Check and set TimestampFlag and OnUpdateNowFlag.
	// A TIMESTAMP column starts out NOT NULL with OnUpdateNowFlag set; the
	// NULL and DEFAULT options below call removeOnUpdateNowFlag (presumably
	// clearing that flag again).
	if col.Tp == mysql.TypeTimestamp {
		col.Flag |= mysql.TimestampFlag
		col.Flag |= mysql.OnUpdateNowFlag
		col.Flag |= mysql.NotNullFlag
	}

	setOnUpdateNow := false
	hasDefaultValue := false
	if colDef.Options != nil {
		// Single-column key spec shared by every constraint generated from an
		// inline option.
		keys := []*ast.IndexColName{
			{
				Column: colDef.Name,
				Length: colDef.Tp.Flen,
			},
		}
		for _, v := range colDef.Options {
			switch v.Tp {
			case ast.ColumnOptionNotNull:
				col.Flag |= mysql.NotNullFlag
			case ast.ColumnOptionNull:
				col.Flag &= ^uint(mysql.NotNullFlag)
				removeOnUpdateNowFlag(col)
			case ast.ColumnOptionAutoIncrement:
				col.Flag |= mysql.AutoIncrementFlag
			case ast.ColumnOptionPrimaryKey:
				// The primary key constraint is left unnamed.
				constraint := &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: keys}
				constraints = append(constraints, constraint)
				col.Flag |= mysql.PriKeyFlag
			case ast.ColumnOptionUniq:
				constraint := &ast.Constraint{Tp: ast.ConstraintUniq, Name: colDef.Name.Name.O, Keys: keys}
				constraints = append(constraints, constraint)
				col.Flag |= mysql.UniqueKeyFlag
			case ast.ColumnOptionIndex:
				constraint := &ast.Constraint{Tp: ast.ConstraintIndex, Name: colDef.Name.Name.O, Keys: keys}
				constraints = append(constraints, constraint)
			case ast.ColumnOptionUniqIndex:
				constraint := &ast.Constraint{Tp: ast.ConstraintUniqIndex, Name: colDef.Name.Name.O, Keys: keys}
				constraints = append(constraints, constraint)
				col.Flag |= mysql.UniqueKeyFlag
			case ast.ColumnOptionKey:
				constraint := &ast.Constraint{Tp: ast.ConstraintKey, Name: colDef.Name.Name.O, Keys: keys}
				constraints = append(constraints, constraint)
			case ast.ColumnOptionUniqKey:
				constraint := &ast.Constraint{Tp: ast.ConstraintUniqKey, Name: colDef.Name.Name.O, Keys: keys}
				constraints = append(constraints, constraint)
				col.Flag |= mysql.UniqueKeyFlag
			case ast.ColumnOptionDefaultValue:
				// NOTE(review): an invalid default is reported with
				// ErrColumnBadNull; confirm a dedicated error code isn't intended.
				value, err := getDefaultValue(ctx, v, colDef.Tp.Tp, colDef.Tp.Decimal)
				if err != nil {
					return nil, nil, ErrColumnBadNull.Gen("invalid default value - %s", err)
				}
				col.DefaultValue = value
				hasDefaultValue = true
				removeOnUpdateNowFlag(col)
			case ast.ColumnOptionOnUpdate:
				// Only current-time expressions are accepted for ON UPDATE.
				if !evaluator.IsCurrentTimeExpr(v.Expr) {
					return nil, nil, ErrInvalidOnUpdate.Gen("invalid ON UPDATE for - %s", col.Name)
				}

				col.Flag |= mysql.OnUpdateNowFlag
				setOnUpdateNow = true
			case ast.ColumnOptionComment:
				value, err := evaluator.Eval(ctx, v.Expr)
				if err != nil {
					return nil, nil, errors.Trace(err)
				}
				col.Comment, err = value.ToString()
				if err != nil {
					return nil, nil, errors.Trace(err)
				}
			case ast.ColumnOptionFulltext:
				// Do nothing.
			}
		}
	}

	setTimestampDefaultValue(col, hasDefaultValue, setOnUpdateNow)

	// Set `NoDefaultValueFlag` if this field doesn't have a default value and
	// it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field.
	setNoDefaultValueFlag(col, hasDefaultValue)

	err := checkDefaultValue(col, hasDefaultValue)
	if err != nil {
		return nil, nil, errors.Trace(err)
	}
	// Columns using the binary charset get BinaryFlag.
	if col.Charset == charset.CharsetBin {
		col.Flag |= mysql.BinaryFlag
	}
	return col, constraints, nil
}
func (e *SimpleExec) executeSet(s *ast.SetStmt) error { sessionVars := variable.GetSessionVars(e.ctx) globalVars := variable.GetGlobalVarAccessor(e.ctx) for _, v := range s.Variables { // Variable is case insensitive, we use lower case. if v.Name == ast.SetNames { // This is set charset stmt. cs := v.Value.GetValue().(string) var co string if v.ExtendValue != nil { co = v.ExtendValue.GetValue().(string) } err := e.setCharset(cs, co) if err != nil { return errors.Trace(err) } continue } name := strings.ToLower(v.Name) if !v.IsSystem { // Set user variable. value, err := evaluator.Eval(e.ctx, v.Value) if err != nil { return errors.Trace(err) } if value.IsNull() { delete(sessionVars.Users, name) } else { svalue, err1 := value.ToString() if err1 != nil { return errors.Trace(err1) } sessionVars.Users[name] = fmt.Sprintf("%v", svalue) } continue } // Set system variable sysVar := variable.GetSysVar(name) if sysVar == nil { return variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name) } if sysVar.Scope == variable.ScopeNone { return errors.Errorf("Variable '%s' is a read only variable", name) } if v.IsGlobal { // Set global scope system variable. if sysVar.Scope&variable.ScopeGlobal == 0 { return errors.Errorf("Variable '%s' is a SESSION variable and can't be used with SET GLOBAL", name) } value, err := evaluator.Eval(e.ctx, v.Value) if err != nil { return errors.Trace(err) } if value.IsNull() { value.SetString("") } svalue, err := value.ToString() if err != nil { return errors.Trace(err) } err = globalVars.SetGlobalSysVar(e.ctx, name, svalue) if err != nil { return errors.Trace(err) } } else { // Set session scope system variable. 
if sysVar.Scope&variable.ScopeSession == 0 { return errors.Errorf("Variable '%s' is a GLOBAL variable and should be set with SET GLOBAL", name) } value, err := evaluator.Eval(e.ctx, v.Value) if err != nil { return errors.Trace(err) } err = sessionVars.SetSystemVar(name, value) if err != nil { return errors.Trace(err) } } } return nil }
// TestXAPIAvg feeds per-region partial AVG results (group key, count, sum) to
// an XAggregateExec and checks the merged average of each group. It also checks
// that the executor returns nil after the last group.
func (s *testAggFuncSuite) TestXAPIAvg(c *C) {
	defer testleak.AfterTest(c)()
	// Compose aggregate exec for "select avg(c2) from t groupby c1";
	//
	// Data in region1:
	// c1  c2
	// 1   11
	// 2   21
	// 1   1
	// 3   2
	//
	// Partial aggregate result for region1:
	// groupkey  cnt  sum
	// 1         2    12
	// 2         1    21
	// 3         1    2
	//
	// Data in region2:
	// 1   nil
	// 1   3
	// 3   31
	//
	// Partial aggregate result for region2:
	// groupkey  cnt  sum
	// 1         1    3
	// 3         1    31
	//
	// Expected final aggregate result:
	// avg(c2)
	// 5
	// 21
	// 16.5000
	c1 := ast.NewValueExpr([]byte{0})
	rf1 := &ast.ResultField{Expr: c1}
	c2 := ast.NewValueExpr(0)
	rf2 := &ast.ResultField{Expr: c2}
	c3 := ast.NewValueExpr(0)
	rf3 := &ast.ResultField{Expr: c3}
	col2 := &ast.ColumnNameExpr{Refer: rf2}
	fc := &ast.AggregateFuncExpr{
		F:    ast.AggFuncAvg,
		Args: []ast.ExprNode{col2},
	}
	// Return row:
	// GroupKey, Count, Sum
	// Partial result from region1
	row1 := types.MakeDatums([]byte{1}, 2, 12)
	row2 := types.MakeDatums([]byte{2}, 1, 21)
	row3 := types.MakeDatums([]byte{3}, 1, 2)
	// Partial result from region2
	row4 := types.MakeDatums([]byte{1}, 1, 3)
	row5 := types.MakeDatums([]byte{3}, 1, 31)

	data := []([]types.Datum){row1, row2, row3, row4, row5}
	rows := make([]*Row, 0, 5)
	for _, d := range data {
		rows = append(rows, &Row{Data: d})
	}
	src := &mockExec{
		rows:   rows,
		fields: []*ast.ResultField{rf1, rf2, rf3}, // groupby, cnt, sum
	}
	agg := &XAggregateExec{
		AggFuncs: []*ast.AggregateFuncExpr{fc},
		Src:      src,
	}
	ast.SetFlag(fc)
	// First row: 5
	row, err := agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	ctx := mock.NewContext()
	val, err := evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDecimalDatum(mysql.NewDecimalFromInt(int64(5), 0)))
	// Second row: 21
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	val, err = evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDecimalDatum(mysql.NewDecimalFromInt(int64(21), 0)))
	// Third row: 16.5000
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	val, err = evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	d := mysql.NewDecimalFromFloat(float64(16.5))
	d.SetFracDigits(4) // For div operator, default frac is 4.
	c.Assert(val, testutil.DatumEquals, types.NewDecimalDatum(d))
	// Fourth row: nil (all groups consumed)
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, IsNil)
	// Close executor
	err = agg.Close()
	c.Assert(err, IsNil)
}
// TestXAPIMaxMin feeds per-region partial MAX/MIN results to an XAggregateExec
// and checks the merged max(c2)/min(c2) of each group. This includes a group
// whose only partial result is NULL. It also checks that the executor returns
// nil after the last group.
func (s *testAggFuncSuite) TestXAPIMaxMin(c *C) {
	defer testleak.AfterTest(c)()
	// Compose aggregate exec for "select max(c2), min(c2) from t groupby c1";
	//
	// Data in region1:
	// c1  c2
	// 1   11
	// 2   21
	// 1   1
	// 3   2
	//
	// Partial aggregate result for region1:
	// groupkey  max(c2)  min(c2)
	// 1         11       1
	// 2         21       21
	// 3         2        2
	//
	// Data in region2:
	// 1   nil
	// 1   3
	// 3   31
	// 4   nil
	//
	// Partial aggregate result for region2:
	// groupkey  max(c2)  min(c2)
	// 1         3        3
	// 3         31       31
	// 4         nil      nil
	//
	// Expected final aggregate result:
	// max(c2)  min(c2)
	// 11       1
	// 21       21
	// 31       2
	// nil      nil
	c1 := ast.NewValueExpr([]byte{0})
	rf1 := &ast.ResultField{Expr: c1}
	c2 := ast.NewValueExpr(0)
	rf2 := &ast.ResultField{Expr: c2}
	c3 := ast.NewValueExpr(0)
	rf3 := &ast.ResultField{Expr: c3}
	col2 := &ast.ColumnNameExpr{Refer: rf2}
	fc := &ast.AggregateFuncExpr{
		F:    ast.AggFuncMax,
		Args: []ast.ExprNode{col2},
	}
	fc1 := &ast.AggregateFuncExpr{
		F:    ast.AggFuncMin,
		Args: []ast.ExprNode{col2},
	}
	ast.SetFlag(fc)
	ast.SetFlag(fc1)
	// Return row:
	// GroupKey, max(c2), min(c2)
	// Partial result from region1
	row1 := types.MakeDatums([]byte{1}, int64(11), int64(1))
	row2 := types.MakeDatums([]byte{2}, int64(21), int64(21))
	row3 := types.MakeDatums([]byte{3}, int64(2), int64(2))
	// Partial result from region2
	row4 := types.MakeDatums([]byte{1}, int64(3), int64(3))
	row5 := types.MakeDatums([]byte{3}, int64(31), int64(31))
	row6 := types.MakeDatums([]byte{4}, nil, nil)

	data := []([]types.Datum){row1, row2, row3, row4, row5, row6}
	rows := make([]*Row, 0, 6)
	for _, d := range data {
		rows = append(rows, &Row{Data: d})
	}
	src := &mockExec{
		rows:   rows,
		fields: []*ast.ResultField{rf1, rf2, rf3}, // group, max(c2), min(c2)
	}
	agg := &XAggregateExec{
		AggFuncs: []*ast.AggregateFuncExpr{fc, fc1},
		Src:      src,
	}
	// NOTE(review): fc was already flagged above; this second call looks redundant.
	ast.SetFlag(fc)
	// First row: 11, 1
	row, err := agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	ctx := mock.NewContext()
	val, err := evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(11)))
	val, err = evaluator.Eval(ctx, fc1)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(1)))

	// Second row: 21, 21
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	val, err = evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(21)))
	val, err = evaluator.Eval(ctx, fc1)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(21)))

	// Third row: 31, 2
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	val, err = evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(31)))
	val, err = evaluator.Eval(ctx, fc1)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(2)))

	// Fourth row: nil, nil (group 4 only has NULL partials)
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	val, err = evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(nil))
	val, err = evaluator.Eval(ctx, fc1)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(nil))

	// Fifth row: nil (all groups consumed)
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, IsNil)
	// Close executor
	err = agg.Close()
	c.Assert(err, IsNil)
}
// TestXAPIFirstRow feeds per-region partial FIRSTROW results to an
// XAggregateExec and checks that each group keeps the first partial value
// received. It also checks that the aggregate value is reset after Close.
func (s *testAggFuncSuite) TestXAPIFirstRow(c *C) {
	defer testleak.AfterTest(c)()
	// Compose aggregate exec for "select c2 from t groupby c1";
	// c1  c2
	// 1   11   // region1
	// 2   21   // region1
	// 1   1    // region1
	// 1   nil  // region2
	// 1   3    // region2
	// 3   31   // region2
	//
	// Expected result:
	// c2
	// 11
	// 21
	// 31
	c1 := ast.NewValueExpr([]byte{0})
	rf1 := &ast.ResultField{Expr: c1}
	c2 := ast.NewValueExpr(0)
	rf2 := &ast.ResultField{Expr: c2}
	col2 := &ast.ColumnNameExpr{Refer: rf2}
	fc := &ast.AggregateFuncExpr{
		F:    ast.AggFuncFirstRow,
		Args: []ast.ExprNode{col2},
	}
	// Return row:
	// GroupKey, first c2 value of that group within the region
	// Partial result from region1
	row1 := types.MakeDatums([]byte{1}, 11)
	row2 := types.MakeDatums([]byte{2}, 21)
	// Partial result from region2
	row3 := types.MakeDatums([]byte{1}, nil)
	row4 := types.MakeDatums([]byte{3}, 31)

	data := []([]types.Datum){row1, row2, row3, row4}
	// NOTE(review): capacity 3 for 4 rows; append grows it, so this is harmless.
	rows := make([]*Row, 0, 3)
	for _, d := range data {
		rows = append(rows, &Row{Data: d})
	}
	src := &mockExec{
		rows:   rows,
		fields: []*ast.ResultField{rf1, rf2},
	}
	agg := &XAggregateExec{
		AggFuncs: []*ast.AggregateFuncExpr{fc},
		Src:      src,
	}
	ast.SetFlag(fc)
	// First Row: 11 (region1's value wins over region2's nil for group 1)
	row, err := agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	ctx := mock.NewContext()
	val, err := evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(11)))
	// Second Row: 21
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	val, err = evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(21)))
	// Third Row: 31
	row, err = agg.Next()
	c.Assert(err, IsNil)
	c.Assert(row, NotNil)
	val, err = evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(int64(31)))

	agg.Close()
	// After clear up, fc's value should be default.
	val, err = evaluator.Eval(ctx, fc)
	c.Assert(err, IsNil)
	c.Assert(val, testutil.DatumEquals, types.NewDatum(nil))
}
func (e *SimpleExec) executeSet(s *ast.SetStmt) error { sessionVars := variable.GetSessionVars(e.ctx) globalVars := variable.GetGlobalVarAccessor(e.ctx) for _, v := range s.Variables { // Variable is case insensitive, we use lower case. name := strings.ToLower(v.Name) if !v.IsSystem { // Set user variable. value, err := evaluator.Eval(e.ctx, v.Value) if err != nil { return errors.Trace(err) } if value == nil { delete(sessionVars.Users, name) } else { sessionVars.Users[name] = fmt.Sprintf("%v", value) } continue } // Set system variable sysVar := variable.GetSysVar(name) if sysVar == nil { return variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name) } if sysVar.Scope == variable.ScopeNone { return errors.Errorf("Variable '%s' is a read only variable", name) } if v.IsGlobal { // Set global scope system variable. if sysVar.Scope&variable.ScopeGlobal == 0 { return errors.Errorf("Variable '%s' is a SESSION variable and can't be used with SET GLOBAL", name) } value, err := evaluator.Eval(e.ctx, v.Value) if err != nil { return errors.Trace(err) } if value == nil { value = "" } svalue, err := types.ToString(value) if err != nil { return errors.Trace(err) } err = globalVars.SetGlobalSysVar(e.ctx, name, svalue) if err != nil { return errors.Trace(err) } } else { // Set session scope system variable. if sysVar.Scope&variable.ScopeSession == 0 { return errors.Errorf("Variable '%s' is a GLOBAL variable and should be set with SET GLOBAL", name) } if value, err := evaluator.Eval(e.ctx, v.Value); err != nil { return errors.Trace(err) } else if value == nil { sessionVars.Systems[name] = "" } else { sessionVars.Systems[name] = fmt.Sprintf("%v", value) } } } return nil }