func (s *InsertValues) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) { r := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(list)) for i, expr := range list { // For "insert into t values (default)" Default Eval. m[expression.ExprEvalDefaultName] = cols[i].Name.O val, err := expr.Eval(ctx, m) if err != nil { return nil, errors.Trace(err) } r[cols[i].Offset] = val marked[cols[i].Offset] = struct{}{} } // Clear last insert id. variable.GetSessionVars(ctx).SetLastInsertID(0) err := s.initDefaultValues(ctx, t, r, marked) if err != nil { return nil, errors.Trace(err) } if err = column.CastValues(ctx, r, cols); err != nil { return nil, errors.Trace(err) } if err = column.CheckNotNull(t.Cols(), r); err != nil { return nil, errors.Trace(err) } return r, nil }
// execExecSelect implements `insert table select ... from ...`. func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) { r, err := s.Sel.Plan(ctx) if err != nil { return nil, errors.Trace(err) } defer r.Close() if len(r.GetFields()) != len(cols) { return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields())) } var bufRecords [][]interface{} var lastInsertIds []uint64 for { var row *plan.Row row, err = r.Next(ctx) if err != nil { return nil, errors.Trace(err) } if row == nil { break } data0 := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(cols)) for i, d := range row.Data { data0[cols[i].Offset] = d marked[cols[i].Offset] = struct{}{} } if err = s.initDefaultValues(ctx, t, data0, marked); err != nil { return nil, errors.Trace(err) } if err = column.CastValues(ctx, data0, cols); err != nil { return nil, errors.Trace(err) } if err = column.CheckNotNull(t.Cols(), data0); err != nil { return nil, errors.Trace(err) } var v interface{} v, err = types.Clone(data0) if err != nil { return nil, errors.Trace(err) } bufRecords = append(bufRecords, v.([]interface{})) lastInsertIds = append(lastInsertIds, variable.GetSessionVars(ctx).LastInsertID) } for i, r := range bufRecords { variable.GetSessionVars(ctx).SetLastInsertID(lastInsertIds[i]) if _, err = t.AddRecord(ctx, r); err != nil { return nil, errors.Trace(err) } } return nil, nil }
// execExecSelect implements `insert table select ... from ...`. func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (_ rset.Recordset, err error) { r, err := s.Sel.Plan(ctx) if err != nil { return nil, errors.Trace(err) } else if len(r.GetFields()) != len(cols) { return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields())) } var bufRecords [][]interface{} var lastInsertIds []uint64 err = r.Do(ctx, func(_ interface{}, data []interface{}) (more bool, err error) { data0 := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(cols)) for i, d := range data { data0[cols[i].Offset] = d marked[cols[i].Offset] = struct{}{} } if err = s.initDefaultValues(ctx, t, t.Cols(), data0, marked); err != nil { return false, errors.Trace(err) } if err = column.CastValues(ctx, data0, cols); err != nil { return false, errors.Trace(err) } if err = column.CheckNotNull(t.Cols(), data0); err != nil { return false, errors.Trace(err) } v, err := types.Clone(data0) if err != nil { return false, errors.Trace(err) } bufRecords = append(bufRecords, v.([]interface{})) lastInsertIds = append(lastInsertIds, variable.GetSessionVars(ctx).LastInsertID) return true, nil }) if err != nil { return nil, errors.Trace(err) } for i, r := range bufRecords { variable.GetSessionVars(ctx).SetLastInsertID(lastInsertIds[i]) if _, err = t.AddRecord(ctx, r); err != nil { return nil, errors.Trace(err) } } return nil, nil }
func (e *InsertValues) fillRowData(cols []*column.Col, vals []types.Datum) ([]types.Datum, error) { row := make([]types.Datum, len(e.Table.Cols())) marked := make(map[int]struct{}, len(vals)) for i, v := range vals { offset := cols[i].Offset row[offset] = v marked[offset] = struct{}{} } err := e.initDefaultValues(row, marked) if err != nil { return nil, errors.Trace(err) } if err = column.CastValues(e.ctx, row, cols); err != nil { return nil, errors.Trace(err) } if err = column.CheckNotNull(e.Table.Cols(), row); err != nil { return nil, errors.Trace(err) } return row, nil }
func (s *InsertValues) fillRowData(ctx context.Context, t table.Table, cols []*column.Col, vals []interface{}) ([]interface{}, error) { row := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(vals)) for i, v := range vals { offset := cols[i].Offset row[offset] = v marked[offset] = struct{}{} } err := s.initDefaultValues(ctx, t, row, marked) if err != nil { return nil, errors.Trace(err) } if err = column.CastValues(ctx, row, cols); err != nil { return nil, errors.Trace(err) } if err = column.CheckNotNull(t.Cols(), row); err != nil { return nil, errors.Trace(err) } return row, nil }
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, updateColumns map[int]expression.Assignment, m map[interface{}]interface{}, offset int, onDuplicateUpdate bool) error { if err := t.LockRow(ctx, h, true); err != nil { return errors.Trace(err) } oldData := make([]interface{}, len(t.Cols())) touched := make([]bool, len(t.Cols())) copy(oldData, data) cols := t.Cols() assignExists := false for i, asgn := range updateColumns { if i < offset || i >= offset+len(cols) { // The assign expression is for another table, not this. continue } val, err := asgn.Expr.Eval(ctx, m) if err != nil { return err } colIndex := i - offset touched[colIndex] = true data[colIndex] = val assignExists = true } // no assign list for this table, no need to update. if !assignExists { return nil } // Check whether new value is valid. if err := column.CastValues(ctx, data, t.Cols()); err != nil { return err } if err := column.CheckNotNull(t.Cols(), data); err != nil { return err } // If row is not changed, we should do nothing. rowChanged := false for i, d := range data { if !touched[i] { continue } od := oldData[i] n, err := types.Compare(d, od) if err != nil { return errors.Trace(err) } if n != 0 { rowChanged = true break } } if !rowChanged { // See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil } // Update record to new value and update index. err := t.UpdateRecord(ctx, h, oldData, data, touched) if err != nil { return errors.Trace(err) } // Record affected rows. if !onDuplicateUpdate { variable.GetSessionVars(ctx).AddAffectedRows(1) } else { variable.GetSessionVars(ctx).AddAffectedRows(2) } return nil }
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, tcols []*column.Col, assignList []expressions.Assignment, insertData []interface{}, args map[interface{}]interface{}) error { if err := t.LockRow(ctx, h, true); err != nil { return errors.Trace(err) } oldData := make([]interface{}, len(t.Cols())) touched := make([]bool, len(t.Cols())) copy(oldData, data) // Generate new values m := args if m == nil { m = make(map[interface{}]interface{}, len(t.Cols())) // Set parameter for evaluating expression. for _, col := range t.Cols() { m[col.Name.L] = data[col.Offset] } } if insertData != nil { m[expressions.ExprEvalValuesFunc] = func(name string) (interface{}, error) { return getInsertValue(name, t.Cols(), insertData) } } for i, asgn := range assignList { val, err := asgn.Expr.Eval(ctx, m) if err != nil { return err } colIndex := tcols[i].Offset touched[colIndex] = true data[colIndex] = val } // Check whether new value is valid. if err := column.CastValues(ctx, data, t.Cols()); err != nil { return err } if err := column.CheckNotNull(t.Cols(), data); err != nil { return err } // If row is not changed, we should do nothing. rowChanged := false for i, d := range data { if !touched[i] { continue } od := oldData[i] n, err := types.Compare(d, od) if err != nil { return errors.Trace(err) } if n != 0 { rowChanged = true break } } if !rowChanged { // See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil } // Update record to new value and update index. err := t.UpdateRecord(ctx, h, oldData, data, touched) if err != nil { return errors.Trace(err) } // Record affected rows. if len(insertData) == 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } else { variable.GetSessionVars(ctx).AddAffectedRows(2) } return nil }
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, updateColumns map[int]*expression.Assignment, evalMap map[interface{}]interface{}, offset int, onDuplicateUpdate bool) error { if err := t.LockRow(ctx, h); err != nil { return errors.Trace(err) } cols := t.Cols() oldData := data newData := make([]interface{}, len(cols)) touched := make(map[int]bool, len(cols)) copy(newData, oldData) assignExists := false var newHandle interface{} for i, asgn := range updateColumns { if i < offset || i >= offset+len(cols) { // The assign expression is for another table, not this. continue } val, err := asgn.Expr.Eval(ctx, evalMap) if err != nil { return errors.Trace(err) } colIndex := i - offset col := cols[colIndex] if col.IsPKHandleColumn(t.Meta()) { newHandle = val } touched[colIndex] = true newData[colIndex] = val assignExists = true } // If no assign list for this table, no need to update. if !assignExists { return nil } // Check whether new value is valid. if err := column.CastValues(ctx, newData, cols); err != nil { return errors.Trace(err) } if err := column.CheckNotNull(cols, newData); err != nil { return errors.Trace(err) } // If row is not changed, we should do nothing. rowChanged := false for i := range oldData { if !touched[i] { continue } n, err := types.Compare(newData[i], oldData[i]) if err != nil { return errors.Trace(err) } if n != 0 { rowChanged = true break } } if !rowChanged { // See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil } var err error if newHandle != nil { err = t.RemoveRecord(ctx, h, oldData) if err != nil { return errors.Trace(err) } _, err = t.AddRecord(ctx, newData) } else { // Update record to new value and update index. err = t.UpdateRecord(ctx, h, oldData, newData, touched) } if err != nil { return errors.Trace(err) } // Record affected rows. if !onDuplicateUpdate { variable.GetSessionVars(ctx).AddAffectedRows(1) } else { variable.GetSessionVars(ctx).AddAffectedRows(2) } return nil }
// Exec implements the stmt.Statement Exec interface. func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { t, err := getTable(ctx, s.TableIdent) if err != nil { return nil, errors.Trace(err) } tableCols := t.Cols() cols, err := s.getColumns(tableCols) if err != nil { return nil, errors.Trace(err) } // Process `insert ... (select ..) ` if s.Sel != nil { return s.execSelect(t, cols, ctx) } // Process `insert ... set x=y...` if len(s.Setlist) > 0 { if len(s.Lists) > 0 { return nil, errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) } var l []expression.Expression for _, v := range s.Setlist { l = append(l, v.Expr) } s.Lists = append(s.Lists, l) } m := map[interface{}]interface{}{} for _, v := range tableCols { var ( value interface{} ok bool ) value, ok, err = getDefaultValue(ctx, v) if ok { if err != nil { return nil, errors.Trace(err) } m[v.Name.L] = value } } insertValueCount := len(s.Lists[0]) for i, list := range s.Lists { r := make([]interface{}, len(tableCols)) valueCount := len(list) if insertValueCount != valueCount { // "insert into t values (), ()" is valid. // "insert into t values (), (1)" is not valid. // "insert into t values (1), ()" is not valid. // "insert into t values (1,2), (1)" is not valid. // So the value count must be same for all insert list. return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1) } if valueCount == 0 && len(s.ColNames) > 0 { // "insert into t (c1) values ()" is not valid. return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0) } else if valueCount > 0 && valueCount != len(cols) { return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount) } // Clear last insert id. variable.GetSessionVars(ctx).SetLastInsertID(0) marked := make(map[int]struct{}, len(list)) for i, expr := range list { // For "insert into t values (default)" Default Eval. m[expressions.ExprEvalDefaultName] = cols[i].Name.O val, evalErr := expr.Eval(ctx, m) if evalErr != nil { return nil, errors.Trace(evalErr) } r[cols[i].Offset] = val marked[cols[i].Offset] = struct{}{} } if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil { return nil, errors.Trace(err) } if err = column.CastValues(ctx, r, cols); err != nil { return nil, errors.Trace(err) } if err = column.CheckNotNull(tableCols, r); err != nil { return nil, errors.Trace(err) } // Notes: incompatible with mysql // MySQL will set last insert id to the first row, as follows: // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` // `insert t (c1) values(1),(2),(3);` // Last insert id will be 1, not 3. h, err := t.AddRecord(ctx, r) if err == nil { continue } if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } // On duplicate key Update the duplicate row. // Evaluate the updated value. // TODO: report rows affected and last insert id. toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate, false) if err != nil { return nil, errors.Trace(err) } data, err := t.Row(ctx, h) if err != nil { return nil, errors.Trace(err) } err = updateRecord(ctx, h, data, t, toUpdateColumns, s.OnDuplicate, r, nil) if err != nil { return nil, errors.Trace(err) } } return nil, nil }
func updateRecord(ctx context.Context, h int64, oldData, newData []types.Datum, updateColumns map[int]*ast.Assignment, t table.Table, offset int, onDuplicateUpdate bool) error { if err := t.LockRow(ctx, h, false); err != nil { return errors.Trace(err) } cols := t.Cols() touched := make(map[int]bool, len(cols)) assignExists := false var newHandle types.Datum for i, asgn := range updateColumns { if asgn == nil { continue } if i < offset || i >= offset+len(cols) { // The assign expression is for another table, not this. continue } colIndex := i - offset col := cols[colIndex] if col.IsPKHandleColumn(t.Meta()) { newHandle = newData[i] } if mysql.HasAutoIncrementFlag(col.Flag) { if newData[i].Kind() == types.KindNull { return errors.Errorf("Column '%v' cannot be null", col.Name.O) } val, err := newData[i].ToInt64() if err != nil { return errors.Trace(err) } t.RebaseAutoID(val, true) } touched[colIndex] = true assignExists = true } // If no assign list for this table, no need to update. if !assignExists { return nil } // Check whether new value is valid. if err := column.CastValues(ctx, newData, cols); err != nil { return errors.Trace(err) } if err := column.CheckNotNull(cols, newData); err != nil { return errors.Trace(err) } // If row is not changed, we should do nothing. rowChanged := false for i := range oldData { if !touched[i] { continue } n, err := newData[i].CompareDatum(oldData[i]) if err != nil { return errors.Trace(err) } if n != 0 { rowChanged = true break } } if !rowChanged { // See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 { variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil } var err error if newHandle.Kind() != types.KindNull { err = t.RemoveRecord(ctx, h, oldData) if err != nil { return errors.Trace(err) } _, err = t.AddRecord(ctx, newData) } else { // Update record to new value and update index. err = t.UpdateRecord(ctx, h, oldData, newData, touched) } if err != nil { return errors.Trace(err) } // Record affected rows. if !onDuplicateUpdate { variable.GetSessionVars(ctx).AddAffectedRows(1) } else { variable.GetSessionVars(ctx).AddAffectedRows(2) } return nil }