コード例 #1
0
ファイル: executor_write.go プロジェクト: XuHuaiyu/tidb
// fillRowData assembles a complete row for e.Table from the values of one
// INSERT row.
//
// vals[k] is stored at the table offset of cols[k], and that offset is recorded
// as explicitly supplied so initDefaultValues can fill in everything else.
// cols must therefore have at least len(vals) entries; a shorter cols panics
// with an index out of range. The supplied columns are then cast to their
// column types and the whole row is checked against the table's NOT NULL
// constraints. On any failure a nil row and the traced error are returned.
func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum) ([]types.Datum, error) {
	fullRow := make([]types.Datum, len(e.Table.Cols()))
	explicit := make(map[int]struct{}, len(vals))
	for idx := range vals {
		pos := cols[idx].Offset
		fullRow[pos] = vals[idx]
		explicit[pos] = struct{}{}
	}

	if err := e.initDefaultValues(fullRow, explicit); err != nil {
		return nil, errors.Trace(err)
	}
	if err := table.CastValues(e.ctx, fullRow, cols); err != nil {
		return nil, errors.Trace(err)
	}
	if err := table.CheckNotNull(e.Table.Cols(), fullRow); err != nil {
		return nil, errors.Trace(err)
	}
	return fullRow, nil
}
コード例 #2
0
ファイル: executor_write.go プロジェクト: pingcap/tidb
// initDefaultValues completes row in place for every column of e.Table that
// still needs a value, then casts the values it filled to their column types.
//
// marked holds the offsets of columns that were given explicitly in the
// insert list. ignoreErr is forwarded to filterErr and table.CastValues;
// presumably it makes conversion/default-value errors non-fatal for
// INSERT IGNORE — TODO confirm against filterErr.
//
// Per column, in table order:
//   - While the session is retrying, a NULL auto-increment column reuses the
//     auto-increment ID recorded by the original attempt.
//   - A non-NULL column without the auto-increment flag is left as is.
//   - A non-NULL auto-increment column is converted to int64. A non-zero value
//     is kept and the table's auto ID is rebased on it; a zero value falls
//     through and receives a newly allocated ID.
//   - A NULL column listed in marked keeps its NULL unless it is an
//     auto-increment or timestamp column.
//   - An auto-increment column receives a newly allocated ID. The first row's
//     ID becomes e.lastInsertID, and outside of a retry the ID is recorded in
//     RetryInfo so a retry can reuse it.
//   - Any other column receives its column default value.
//
// Errors from the retry bookkeeping, ID allocation, auto ID rebasing and
// casting are returned traced; conversion and default-value errors are only
// returned when filterErr does not drop them.
func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struct{}, ignoreErr bool) error {
	var defaultValueCols []*table.Column
	sc := e.ctx.GetSessionVars().StmtCtx
	for i, c := range e.Table.Cols() {
		// On retry, replay the auto-increment ID chosen by the first attempt
		// instead of allocating a new one.
		if mysql.HasAutoIncrementFlag(c.Flag) && row[i].IsNull() &&
			e.ctx.GetSessionVars().RetryInfo.Retrying {
			id, err := e.ctx.GetSessionVars().RetryInfo.GetCurrAutoIncrementID()
			if err != nil {
				return errors.Trace(err)
			}
			row[i].SetInt64(id)
		}
		if !row[i].IsNull() {
			// Column value isn't nil and column isn't auto-increment, continue.
			if !mysql.HasAutoIncrementFlag(c.Flag) {
				continue
			}
			val, err := row[i].ToInt64(sc)
			if filterErr(errors.Trace(err), ignoreErr) != nil {
				return errors.Trace(err)
			}
			row[i].SetInt64(val)
			if val != 0 {
				// Keep the explicit ID and rebase the table's auto ID on it
				// (presumably so later allocations do not reuse it). The
				// rebase can fail, so its error must not be dropped.
				if err = e.Table.RebaseAutoID(val, true); err != nil {
					return errors.Trace(err)
				}
				continue
			}
			// An explicit 0 is treated like NULL: a new ID is allocated below.
		}

		// If the nil value is evaluated in insert list, we will use nil except auto increment column.
		if _, ok := marked[i]; ok && !mysql.HasAutoIncrementFlag(c.Flag) && !mysql.HasTimestampFlag(c.Flag) {
			continue
		}

		if mysql.HasAutoIncrementFlag(c.Flag) {
			recordID, err := e.Table.AllocAutoID()
			if err != nil {
				return errors.Trace(err)
			}
			row[i].SetInt64(recordID)
			// It's compatible with mysql. So it sets last insert id to the first row.
			if e.currRow == 0 {
				e.lastInsertID = uint64(recordID)
			}
			// Remember the allocated ID so a retry of this statement can replay it.
			if !e.ctx.GetSessionVars().RetryInfo.Retrying {
				e.ctx.GetSessionVars().RetryInfo.AddAutoIncrementID(recordID)
			}
		} else {
			var err error
			row[i], _, err = table.GetColDefaultValue(e.ctx, c.ToInfo())
			if filterErr(err, ignoreErr) != nil {
				return errors.Trace(err)
			}
		}

		defaultValueCols = append(defaultValueCols, c)
	}
	// Only the values filled in above need casting here.
	if err := table.CastValues(e.ctx, row, defaultValueCols, ignoreErr); err != nil {
		return errors.Trace(err)
	}

	return nil
}
コード例 #3
0
ファイル: executor_write.go プロジェクト: pingcap/tidb
// updateRecord writes newData over the row of table t identified by handle h.
//
// assignFlag marks which positions carry a SET (or ON DUPLICATE KEY UPDATE)
// assignment. Positions outside [offset, offset+len(t.Cols())) belong to other
// tables of a multi-table UPDATE and are skipped. oldData and newData are
// treated as rows of t indexed by column index (they are compared
// column-by-column via touched and handed whole to CastValues, CheckNotNull
// and t.UpdateRecord), so an assignment position i is translated to
// colIndex = i - offset before newData is read.
//
// When onDuplicateUpdate is true, unassigned positions are first copied from
// oldData into newData. This copy uses the raw position i, which presumably
// relies on offset being 0 for ON DUPLICATE KEY UPDATE — TODO confirm.
//
// Assigned auto-increment columns must not be NULL, and the table's auto ID is
// rebased on their new value. If nothing in this table is assigned, or no
// assigned column actually changes, the row is not written. In the unchanged
// case one affected row is still reported when the client set
// CLIENT_FOUND_ROWS. Changing the PK handle column removes the old record and
// adds a new one; otherwise the record and its indexes are updated in place.
// A written row counts as 1 affected row, or 2 under ON DUPLICATE KEY UPDATE.
func updateRecord(ctx context.Context, h int64, oldData, newData []types.Datum, assignFlag []bool, t table.Table, offset int, onDuplicateUpdate bool) error {
	cols := t.Cols()
	touched := make(map[int]bool, len(cols))
	assignExists := false
	sc := ctx.GetSessionVars().StmtCtx
	var newHandle types.Datum
	for i, hasSetExpr := range assignFlag {
		if !hasSetExpr {
			if onDuplicateUpdate {
				// Columns not assigned by the ON DUPLICATE KEY UPDATE clause
				// keep their existing value.
				newData[i] = oldData[i]
			}
			continue
		}
		if i < offset || i >= offset+len(cols) {
			// The assign expression is for another table, not this.
			continue
		}

		// newData is indexed by this table's column index (as in the compare
		// loop and CastValues below), not by the absolute assignment position;
		// reading newData[i] would pick the wrong column whenever offset > 0.
		colIndex := i - offset
		col := cols[colIndex]
		if col.IsPKHandleColumn(t.Meta()) {
			newHandle = newData[colIndex]
		}
		if mysql.HasAutoIncrementFlag(col.Flag) {
			if newData[colIndex].IsNull() {
				return errors.Errorf("Column '%v' cannot be null", col.Name.O)
			}
			val, err := newData[colIndex].ToInt64(sc)
			if err != nil {
				return errors.Trace(err)
			}
			// The rebase can fail; do not silently drop its error.
			if err = t.RebaseAutoID(val, true); err != nil {
				return errors.Trace(err)
			}
		}

		touched[colIndex] = true
		assignExists = true
	}

	// If no assign list for this table, no need to update.
	if !assignExists {
		return nil
	}

	// Check whether new value is valid.
	if err := table.CastValues(ctx, newData, cols, false); err != nil {
		return errors.Trace(err)
	}

	if err := table.CheckNotNull(cols, newData); err != nil {
		return errors.Trace(err)
	}

	// If row is not changed, we should do nothing.
	rowChanged := false
	for i := range oldData {
		if !touched[i] {
			continue
		}

		n, err := newData[i].CompareDatum(sc, oldData[i])
		if err != nil {
			return errors.Trace(err)
		}
		if n != 0 {
			rowChanged = true
			break
		}
	}
	if !rowChanged {
		// See https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html  CLIENT_FOUND_ROWS
		if ctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 {
			sc.AddAffectedRows(1)
		}
		return nil
	}

	var err error
	if !newHandle.IsNull() {
		// The handle column was assigned, so the record key may change:
		// remove the old record and insert the new one.
		err = t.RemoveRecord(ctx, h, oldData)
		if err != nil {
			return errors.Trace(err)
		}
		_, err = t.AddRecord(ctx, newData)
	} else {
		// Update record to new value and update index.
		err = t.UpdateRecord(ctx, h, oldData, newData, touched)
	}
	if err != nil {
		return errors.Trace(err)
	}
	// Replace the row for handle h with newData in the dirty-table record
	// returned by getDirtyDB.
	dirtyDB := getDirtyDB(ctx)
	tid := t.Meta().ID
	dirtyDB.deleteRow(tid, h)
	dirtyDB.addRow(tid, h, newData)

	// Record affected rows: an updated row counts twice under
	// ON DUPLICATE KEY UPDATE.
	if !onDuplicateUpdate {
		sc.AddAffectedRows(1)
	} else {
		sc.AddAffectedRows(2)
	}
	return nil
}