func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum) ([]types.Datum, error) { row := make([]types.Datum, len(e.Table.Cols())) marked := make(map[int]struct{}, len(vals)) for i, v := range vals { offset := cols[i].Offset row[offset] = v marked[offset] = struct{}{} } err := e.initDefaultValues(row, marked) if err != nil { return nil, errors.Trace(err) } if err = table.CastValues(e.ctx, row, cols); err != nil { return nil, errors.Trace(err) } if err = table.CheckNotNull(e.Table.Cols(), row); err != nil { return nil, errors.Trace(err) } return row, nil }
func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struct{}, ignoreErr bool) error { var defaultValueCols []*table.Column sc := e.ctx.GetSessionVars().StmtCtx for i, c := range e.Table.Cols() { // It's used for retry. if mysql.HasAutoIncrementFlag(c.Flag) && row[i].IsNull() && e.ctx.GetSessionVars().RetryInfo.Retrying { id, err := e.ctx.GetSessionVars().RetryInfo.GetCurrAutoIncrementID() if err != nil { return errors.Trace(err) } row[i].SetInt64(id) } if !row[i].IsNull() { // Column value isn't nil and column isn't auto-increment, continue. if !mysql.HasAutoIncrementFlag(c.Flag) { continue } val, err := row[i].ToInt64(sc) if filterErr(errors.Trace(err), ignoreErr) != nil { return errors.Trace(err) } row[i].SetInt64(val) if val != 0 { e.Table.RebaseAutoID(val, true) continue } } // If the nil value is evaluated in insert list, we will use nil except auto increment column. if _, ok := marked[i]; ok && !mysql.HasAutoIncrementFlag(c.Flag) && !mysql.HasTimestampFlag(c.Flag) { continue } if mysql.HasAutoIncrementFlag(c.Flag) { recordID, err := e.Table.AllocAutoID() if err != nil { return errors.Trace(err) } row[i].SetInt64(recordID) // It's compatible with mysql. So it sets last insert id to the first row. if e.currRow == 0 { e.lastInsertID = uint64(recordID) } // It's used for retry. if !e.ctx.GetSessionVars().RetryInfo.Retrying { e.ctx.GetSessionVars().RetryInfo.AddAutoIncrementID(recordID) } } else { var err error row[i], _, err = table.GetColDefaultValue(e.ctx, c.ToInfo()) if filterErr(err, ignoreErr) != nil { return errors.Trace(err) } } defaultValueCols = append(defaultValueCols, c) } if err := table.CastValues(e.ctx, row, defaultValueCols, ignoreErr); err != nil { return errors.Trace(err) } return nil }
func updateRecord(ctx context.Context, h int64, oldData, newData []types.Datum, assignFlag []bool, t table.Table, offset int, onDuplicateUpdate bool) error { cols := t.Cols() touched := make(map[int]bool, len(cols)) assignExists := false sc := ctx.GetSessionVars().StmtCtx var newHandle types.Datum for i, hasSetExpr := range assignFlag { if !hasSetExpr { if onDuplicateUpdate { newData[i] = oldData[i] } continue } if i < offset || i >= offset+len(cols) { // The assign expression is for another table, not this. continue } colIndex := i - offset col := cols[colIndex] if col.IsPKHandleColumn(t.Meta()) { newHandle = newData[i] } if mysql.HasAutoIncrementFlag(col.Flag) { if newData[i].IsNull() { return errors.Errorf("Column '%v' cannot be null", col.Name.O) } val, err := newData[i].ToInt64(sc) if err != nil { return errors.Trace(err) } t.RebaseAutoID(val, true) } touched[colIndex] = true assignExists = true } // If no assign list for this table, no need to update. if !assignExists { return nil } // Check whether new value is valid. if err := table.CastValues(ctx, newData, cols, false); err != nil { return errors.Trace(err) } if err := table.CheckNotNull(cols, newData); err != nil { return errors.Trace(err) } // If row is not changed, we should do nothing. rowChanged := false for i := range oldData { if !touched[i] { continue } n, err := newData[i].CompareDatum(sc, oldData[i]) if err != nil { return errors.Trace(err) } if n != 0 { rowChanged = true break } } if !rowChanged { // See https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if ctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 { sc.AddAffectedRows(1) } return nil } var err error if !newHandle.IsNull() { err = t.RemoveRecord(ctx, h, oldData) if err != nil { return errors.Trace(err) } _, err = t.AddRecord(ctx, newData) } else { // Update record to new value and update index. err = t.UpdateRecord(ctx, h, oldData, newData, touched) } if err != nil { return errors.Trace(err) } dirtyDB := getDirtyDB(ctx) tid := t.Meta().ID dirtyDB.deleteRow(tid, h) dirtyDB.addRow(tid, h, newData) // Record affected rows. if !onDuplicateUpdate { sc.AddAffectedRows(1) } else { sc.AddAffectedRows(2) } return nil }