示例#1
0
文件: insert.go 项目: huozhe3136/tidb
// getRow evaluates every expression in list and assembles one full-width table row.
//
// The value produced by list[i] is stored at the offset of cols[i]; the set of
// offsets that received a value is handed to initDefaultValues (presumably so the
// remaining columns can be populated — confirm against that helper). The row is
// then passed through column.CastValues and column.CheckNotNull.
//
// m is the expression evaluation map and is mutated: before each expression is
// evaluated, expression.ExprEvalDefaultName is set to the original name of the
// target column. The session's last insert id is reset to 0 after evaluation and
// before initDefaultValues runs.
//
// Any failure is returned wrapped with errors.Trace and a nil row.
func (s *InsertValues) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) {
	row := make([]interface{}, len(t.Cols()))
	assigned := make(map[int]struct{}, len(list))
	for idx := range list {
		target := cols[idx]
		// Lets "insert into t values (default)" know which column DEFAULT refers to.
		m[expression.ExprEvalDefaultName] = target.Name.O

		value, evalErr := list[idx].Eval(ctx, m)
		if evalErr != nil {
			return nil, errors.Trace(evalErr)
		}
		row[target.Offset] = value
		assigned[target.Offset] = struct{}{}
	}

	// Reset the last insert id before default values are initialised.
	variable.GetSessionVars(ctx).SetLastInsertID(0)

	if err := s.initDefaultValues(ctx, t, row, assigned); err != nil {
		return nil, errors.Trace(err)
	}
	if err := column.CastValues(ctx, row, cols); err != nil {
		return nil, errors.Trace(err)
	}
	if err := column.CheckNotNull(t.Cols(), row); err != nil {
		return nil, errors.Trace(err)
	}

	return row, nil
}
示例#2
0
文件: insert.go 项目: huozhe3136/tidb
// execSelect implements `insert table select ... from ...`.
//
// It plans s.Sel, verifies that the number of result fields equals len(cols),
// and then works in two phases:
//
//  1. Every source row is read and expanded into a full-width table row
//     (initDefaultValues, column.CastValues, column.CheckNotNull), cloned with
//     types.Clone and buffered together with the session's LastInsertID as
//     observed right after the row was prepared.
//  2. The buffered rows are added to t one by one; before each AddRecord the
//     session's last insert id is restored to the value recorded for that row.
//
// The returned Recordset is always nil. Errors are wrapped with errors.Trace.
func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) {
	src, err := s.Sel.Plan(ctx)
	if err != nil {
		return nil, errors.Trace(err)
	}
	defer src.Close()

	if len(src.GetFields()) != len(cols) {
		return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(src.GetFields()))
	}

	var (
		pending   [][]interface{}
		insertIDs []uint64
	)
	for {
		srcRow, nextErr := src.Next(ctx)
		if nextErr != nil {
			return nil, errors.Trace(nextErr)
		}
		if srcRow == nil {
			// Source exhausted.
			break
		}

		full := make([]interface{}, len(t.Cols()))
		assigned := make(map[int]struct{}, len(cols))
		for idx := range srcRow.Data {
			pos := cols[idx].Offset
			full[pos] = srcRow.Data[idx]
			assigned[pos] = struct{}{}
		}

		if initErr := s.initDefaultValues(ctx, t, full, assigned); initErr != nil {
			return nil, errors.Trace(initErr)
		}
		if castErr := column.CastValues(ctx, full, cols); castErr != nil {
			return nil, errors.Trace(castErr)
		}
		if nullErr := column.CheckNotNull(t.Cols(), full); nullErr != nil {
			return nil, errors.Trace(nullErr)
		}

		cloned, cloneErr := types.Clone(full)
		if cloneErr != nil {
			return nil, errors.Trace(cloneErr)
		}

		pending = append(pending, cloned.([]interface{}))
		insertIDs = append(insertIDs, variable.GetSessionVars(ctx).LastInsertID)
	}

	// Second phase: write the buffered rows only after the source has been fully read.
	for idx := range pending {
		variable.GetSessionVars(ctx).SetLastInsertID(insertIDs[idx])
		if _, addErr := t.AddRecord(ctx, pending[idx]); addErr != nil {
			return nil, errors.Trace(addErr)
		}
	}
	return nil, nil
}
示例#3
0
文件: insert.go 项目: ninefive/tidb
// execSelect implements `insert table select ... from ...`.
//
// The select plan is iterated with Do; each source row is expanded into a
// full-width table row (initDefaultValues, column.CastValues,
// column.CheckNotNull), cloned with types.Clone and buffered along with the
// session's LastInsertID observed at that point. Only after the iteration has
// finished are the buffered rows added to t, restoring the recorded last insert
// id before each AddRecord.
//
// The returned Recordset is always nil. Errors are wrapped with errors.Trace.
func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (_ rset.Recordset, err error) {
	src, err := s.Sel.Plan(ctx)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if len(src.GetFields()) != len(cols) {
		return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(src.GetFields()))
	}

	var (
		pending   [][]interface{}
		insertIDs []uint64
	)
	// collect prepares one source row and buffers it; returning true asks Do for more rows.
	collect := func(_ interface{}, data []interface{}) (bool, error) {
		full := make([]interface{}, len(t.Cols()))
		assigned := make(map[int]struct{}, len(cols))
		for idx := range data {
			pos := cols[idx].Offset
			full[pos] = data[idx]
			assigned[pos] = struct{}{}
		}

		if initErr := s.initDefaultValues(ctx, t, t.Cols(), full, assigned); initErr != nil {
			return false, errors.Trace(initErr)
		}
		if castErr := column.CastValues(ctx, full, cols); castErr != nil {
			return false, errors.Trace(castErr)
		}
		if nullErr := column.CheckNotNull(t.Cols(), full); nullErr != nil {
			return false, errors.Trace(nullErr)
		}

		cloned, cloneErr := types.Clone(full)
		if cloneErr != nil {
			return false, errors.Trace(cloneErr)
		}

		pending = append(pending, cloned.([]interface{}))
		insertIDs = append(insertIDs, variable.GetSessionVars(ctx).LastInsertID)
		return true, nil
	}
	if err = src.Do(ctx, collect); err != nil {
		return nil, errors.Trace(err)
	}

	for idx := range pending {
		variable.GetSessionVars(ctx).SetLastInsertID(insertIDs[idx])

		if _, err = t.AddRecord(ctx, pending[idx]); err != nil {
			return nil, errors.Trace(err)
		}
	}

	return nil, nil
}
示例#4
0
// fillRowData expands vals into a row as wide as e.Table.Cols().
//
// vals[i] is placed at the offset of cols[i]; the offsets that received a value
// are passed to initDefaultValues. The row is then run through column.CastValues
// and column.CheckNotNull. On any failure a nil row and the error wrapped with
// errors.Trace are returned.
func (e *InsertValues) fillRowData(cols []*column.Col, vals []types.Datum) ([]types.Datum, error) {
	filled := make([]types.Datum, len(e.Table.Cols()))
	assigned := make(map[int]struct{}, len(vals))
	for idx := range vals {
		pos := cols[idx].Offset
		filled[pos] = vals[idx]
		assigned[pos] = struct{}{}
	}

	if err := e.initDefaultValues(filled, assigned); err != nil {
		return nil, errors.Trace(err)
	}
	if err := column.CastValues(e.ctx, filled, cols); err != nil {
		return nil, errors.Trace(err)
	}
	if err := column.CheckNotNull(e.Table.Cols(), filled); err != nil {
		return nil, errors.Trace(err)
	}
	return filled, nil
}
示例#5
0
文件: insert.go 项目: lovedboy/tidb
// fillRowData expands vals into a row as wide as t.Cols().
//
// vals[i] is placed at the offset of cols[i]; the offsets that received a value
// are passed to initDefaultValues. The row is then run through column.CastValues
// and column.CheckNotNull. On any failure a nil row and the error wrapped with
// errors.Trace are returned.
func (s *InsertValues) fillRowData(ctx context.Context, t table.Table, cols []*column.Col, vals []interface{}) ([]interface{}, error) {
	filled := make([]interface{}, len(t.Cols()))
	assigned := make(map[int]struct{}, len(vals))
	for idx := range vals {
		pos := cols[idx].Offset
		filled[pos] = vals[idx]
		assigned[pos] = struct{}{}
	}

	if err := s.initDefaultValues(ctx, t, filled, assigned); err != nil {
		return nil, errors.Trace(err)
	}
	if err := column.CastValues(ctx, filled, cols); err != nil {
		return nil, errors.Trace(err)
	}
	if err := column.CheckNotNull(t.Cols(), filled); err != nil {
		return nil, errors.Trace(err)
	}
	return filled, nil
}
示例#6
0
文件: update.go 项目: Brian110/tidb
// updateRecord applies the assignments in updateColumns to the row with handle h
// of table t.
//
// Parameters:
//   - data: the current row values; it is modified in place with the new values.
//   - updateColumns: assignments keyed by column position. Only keys within
//     [offset, offset+len(t.Cols())) belong to this table; key-offset is the
//     column index. Other keys are skipped.
//   - m: the argument map passed to each assignment expression's Eval.
//   - onDuplicateUpdate: when true, a changed row is counted as 2 affected rows
//     instead of 1.
//
// The row is locked first. If no assignment targets this table the function
// returns without doing anything else. If none of the touched columns compares
// different from its old value, the record is not written; one affected row is
// still counted when the client set the ClientFoundRows capability.
//
// Every error is returned wrapped with errors.Trace.
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table,
	updateColumns map[int]expression.Assignment, m map[interface{}]interface{},
	offset int, onDuplicateUpdate bool) error {
	if err := t.LockRow(ctx, h, true); err != nil {
		return errors.Trace(err)
	}

	// Snapshot the old values before data is overwritten below.
	oldData := make([]interface{}, len(t.Cols()))
	touched := make([]bool, len(t.Cols()))
	copy(oldData, data)

	cols := t.Cols()

	assignExists := false
	for i, asgn := range updateColumns {
		if i < offset || i >= offset+len(cols) {
			// The assign expression is for another table, not this.
			continue
		}
		val, err := asgn.Expr.Eval(ctx, m)
		if err != nil {
			return errors.Trace(err)
		}
		colIndex := i - offset
		touched[colIndex] = true
		data[colIndex] = val
		assignExists = true
	}

	// No assign list for this table, no need to update.
	if !assignExists {
		return nil
	}

	// Check whether new value is valid.
	if err := column.CastValues(ctx, data, t.Cols()); err != nil {
		return errors.Trace(err)
	}

	if err := column.CheckNotNull(t.Cols(), data); err != nil {
		return errors.Trace(err)
	}

	// If row is not changed, we should do nothing.
	rowChanged := false
	for i, d := range data {
		if !touched[i] {
			continue
		}
		n, err := types.Compare(d, oldData[i])
		if err != nil {
			return errors.Trace(err)
		}

		if n != 0 {
			rowChanged = true
			break
		}
	}
	if !rowChanged {
		// See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html  CLIENT_FOUND_ROWS
		if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 {
			variable.GetSessionVars(ctx).AddAffectedRows(1)
		}
		return nil
	}

	// Update record to new value and update index.
	if err := t.UpdateRecord(ctx, h, oldData, data, touched); err != nil {
		return errors.Trace(err)
	}

	// Record affected rows.
	if onDuplicateUpdate {
		variable.GetSessionVars(ctx).AddAffectedRows(2)
	} else {
		variable.GetSessionVars(ctx).AddAffectedRows(1)
	}
	return nil
}
示例#7
0
文件: update.go 项目: hulunbier/tidb
// updateRecord applies assignList to the row with handle h of table t.
//
// Parameters:
//   - data: the current row values; it is modified in place with the new values.
//   - tcols: tcols[i] is the column targeted by assignList[i]; its Offset selects
//     the slot in data that receives the evaluated value.
//   - insertData: when non-nil, a function is registered in the evaluation map
//     under expressions.ExprEvalValuesFunc that resolves a name through
//     getInsertValue against insertData. A non-empty insertData also makes a
//     changed row count as 2 affected rows instead of 1.
//   - args: the evaluation map. When nil, a fresh map is built from the
//     lower-case column names to the current values in data. When non-nil it is
//     used as is, and is mutated if insertData is non-nil.
//
// The row is locked first. If none of the touched columns compares different
// from its old value, the record is not written; one affected row is still
// counted when the client set the ClientFoundRows capability.
//
// Every error is returned wrapped with errors.Trace.
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, tcols []*column.Col, assignList []expressions.Assignment, insertData []interface{}, args map[interface{}]interface{}) error {
	if err := t.LockRow(ctx, h, true); err != nil {
		return errors.Trace(err)
	}

	// Snapshot the old values before data is overwritten below.
	oldData := make([]interface{}, len(t.Cols()))
	touched := make([]bool, len(t.Cols()))
	copy(oldData, data)

	// Generate new values
	m := args
	if m == nil {
		m = make(map[interface{}]interface{}, len(t.Cols()))
		// Set parameter for evaluating expression.
		for _, col := range t.Cols() {
			m[col.Name.L] = data[col.Offset]
		}
	}
	if insertData != nil {
		m[expressions.ExprEvalValuesFunc] = func(name string) (interface{}, error) {
			return getInsertValue(name, t.Cols(), insertData)
		}
	}

	for i, asgn := range assignList {
		val, err := asgn.Expr.Eval(ctx, m)
		if err != nil {
			return errors.Trace(err)
		}
		colIndex := tcols[i].Offset
		touched[colIndex] = true
		data[colIndex] = val
	}

	// Check whether new value is valid.
	if err := column.CastValues(ctx, data, t.Cols()); err != nil {
		return errors.Trace(err)
	}

	if err := column.CheckNotNull(t.Cols(), data); err != nil {
		return errors.Trace(err)
	}

	// If row is not changed, we should do nothing.
	rowChanged := false
	for i, d := range data {
		if !touched[i] {
			continue
		}
		n, err := types.Compare(d, oldData[i])
		if err != nil {
			return errors.Trace(err)
		}

		if n != 0 {
			rowChanged = true
			break
		}
	}
	if !rowChanged {
		// See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html  CLIENT_FOUND_ROWS
		if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 {
			variable.GetSessionVars(ctx).AddAffectedRows(1)
		}
		return nil
	}

	// Update record to new value and update index.
	if err := t.UpdateRecord(ctx, h, oldData, data, touched); err != nil {
		return errors.Trace(err)
	}

	// Record affected rows.
	if len(insertData) == 0 {
		variable.GetSessionVars(ctx).AddAffectedRows(1)
	} else {
		variable.GetSessionVars(ctx).AddAffectedRows(2)
	}
	return nil
}
示例#8
0
文件: update.go 项目: lovedboy/tidb
// updateRecord applies the assignments in updateColumns to the row with handle h
// of table t.
//
// data holds the current row values and is left unmodified; new values are
// built in a copy. updateColumns is keyed by column position: only keys within
// [offset, offset+len(t.Cols())) belong to this table, and key-offset is the
// column index. evalMap is passed to each assignment expression's Eval.
//
// After locking the row the function:
//   - returns immediately if no assignment targets this table;
//   - validates the new row with column.CastValues and column.CheckNotNull;
//   - skips the write when no touched column compares different from its old
//     value (one affected row is still counted if the client set ClientFoundRows);
//   - otherwise, if a non-nil value was assigned to the PK-handle column, removes
//     the old record and adds the new one; in every other case it updates the
//     record in place.
//
// A changed row counts as 2 affected rows when onDuplicateUpdate is true and as
// 1 otherwise. Every error is returned wrapped with errors.Trace.
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table,
	updateColumns map[int]*expression.Assignment, evalMap map[interface{}]interface{},
	offset int, onDuplicateUpdate bool) error {
	if lockErr := t.LockRow(ctx, h); lockErr != nil {
		return errors.Trace(lockErr)
	}

	tblCols := t.Cols()
	before := data
	after := make([]interface{}, len(tblCols))
	dirty := make(map[int]bool, len(tblCols))
	copy(after, before)

	var (
		hasAssignment bool
		pkValue       interface{}
	)
	for pos, assign := range updateColumns {
		inThisTable := pos >= offset && pos < offset+len(tblCols)
		if !inThisTable {
			// The assign expression is for another table, not this.
			continue
		}

		value, evalErr := assign.Expr.Eval(ctx, evalMap)
		if evalErr != nil {
			return errors.Trace(evalErr)
		}

		idx := pos - offset
		if tblCols[idx].IsPKHandleColumn(t.Meta()) {
			pkValue = value
		}

		dirty[idx] = true
		after[idx] = value
		hasAssignment = true
	}

	// Nothing in the assign list targets this table.
	if !hasAssignment {
		return nil
	}

	// Validate the new row.
	if castErr := column.CastValues(ctx, after, tblCols); castErr != nil {
		return errors.Trace(castErr)
	}
	if nullErr := column.CheckNotNull(tblCols, after); nullErr != nil {
		return errors.Trace(nullErr)
	}

	// Look for the first touched column whose value actually differs.
	changed := false
	for idx := range before {
		if !dirty[idx] {
			continue
		}

		cmp, cmpErr := types.Compare(after[idx], before[idx])
		if cmpErr != nil {
			return errors.Trace(cmpErr)
		}
		if cmp != 0 {
			changed = true
			break
		}
	}

	sessVars := variable.GetSessionVars(ctx)
	if !changed {
		// See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html  CLIENT_FOUND_ROWS
		if sessVars.ClientCapability&mysql.ClientFoundRows > 0 {
			sessVars.AddAffectedRows(1)
		}
		return nil
	}

	if pkValue != nil {
		// The handle column was assigned: replace the record instead of updating it.
		if rmErr := t.RemoveRecord(ctx, h, before); rmErr != nil {
			return errors.Trace(rmErr)
		}
		if _, addErr := t.AddRecord(ctx, after); addErr != nil {
			return errors.Trace(addErr)
		}
	} else if updErr := t.UpdateRecord(ctx, h, before, after, dirty); updErr != nil {
		// Update record to new value and update index.
		return errors.Trace(updErr)
	}

	// Record affected rows.
	if onDuplicateUpdate {
		sessVars.AddAffectedRows(2)
	} else {
		sessVars.AddAffectedRows(1)
	}

	return nil
}
示例#9
0
文件: insert.go 项目: hulunbier/tidb
// Exec implements the stmt.Statement Exec interface.
//
// It looks up the target table via getTable and the insert columns via
// s.getColumns, then handles the statement forms in this order:
//
//   - `insert ... select ...`: delegated entirely to s.execSelect.
//   - `insert ... set x=y ...`: the Setlist expressions are collected into one
//     value list that is appended to s.Lists, then processed like VALUES.
//   - `insert ... values (...), (...)`: each list is evaluated into a row,
//     defaults are initialised, values are cast and NOT NULL is checked, and the
//     row is added to the table.
//
// When AddRecord fails with an error equal to kv.ErrKeyExists and s.OnDuplicate
// is non-empty, the row at the handle returned by AddRecord is read and updated
// through updateRecord instead.
//
// The returned Recordset is always nil on the VALUES/SET paths. Errors are
// wrapped with errors.Trace.
func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
	t, err := getTable(ctx, s.TableIdent)
	if err != nil {
		return nil, errors.Trace(err)
	}

	tableCols := t.Cols()
	cols, err := s.getColumns(tableCols)
	if err != nil {
		return nil, errors.Trace(err)
	}

	// Process `insert ... (select ..) `
	if s.Sel != nil {
		return s.execSelect(t, cols, ctx)
	}

	// Process `insert ... set x=y...`
	if len(s.Setlist) > 0 {
		if len(s.Lists) > 0 {
			return nil, errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent)
		}

		var l []expression.Expression
		for _, v := range s.Setlist {
			l = append(l, v.Expr)
		}
		// NOTE(review): this mutates the statement itself; a second Exec on the
		// same statement would then hit the "set type should not use values"
		// error above — confirm statements are not re-executed.
		s.Lists = append(s.Lists, l)
	}

	// Build the evaluation map: lower-case column name -> default value, for
	// every column for which getDefaultValue reports ok.
	m := map[interface{}]interface{}{}
	for _, v := range tableCols {
		var (
			value interface{}
			ok    bool
		)
		value, ok, err = getDefaultValue(ctx, v)
		// NOTE(review): err is only inspected when ok is true; an error returned
		// together with ok == false is silently dropped — confirm that is intended.
		if ok {
			if err != nil {
				return nil, errors.Trace(err)
			}

			m[v.Name.L] = value
		}
	}

	// NOTE(review): indexes s.Lists[0] without a length check; this assumes at
	// least one value list is present at this point — verify against the parser.
	insertValueCount := len(s.Lists[0])
	for i, list := range s.Lists {
		r := make([]interface{}, len(tableCols))
		valueCount := len(list)

		if insertValueCount != valueCount {
			// "insert into t values (), ()" is valid.
			// "insert into t values (), (1)" is not valid.
			// "insert into t values (1), ()" is not valid.
			// "insert into t values (1,2), (1)" is not valid.
			// So the value count must be same for all insert list.
			return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1)
		}

		if valueCount == 0 && len(s.ColNames) > 0 {
			// "insert into t (c1) values ()" is not valid.
			return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0)
		} else if valueCount > 0 && valueCount != len(cols) {
			return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount)
		}

		// Clear last insert id.
		variable.GetSessionVars(ctx).SetLastInsertID(0)

		// marked records the offsets of the columns that received an explicit value.
		marked := make(map[int]struct{}, len(list))
		// Note: this inner i shadows the row index i of the enclosing loop.
		for i, expr := range list {
			// For "insert into t values (default)" Default Eval.
			m[expressions.ExprEvalDefaultName] = cols[i].Name.O

			val, evalErr := expr.Eval(ctx, m)
			if evalErr != nil {
				return nil, errors.Trace(evalErr)
			}
			r[cols[i].Offset] = val
			marked[cols[i].Offset] = struct{}{}
		}

		if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil {
			return nil, errors.Trace(err)
		}

		if err = column.CastValues(ctx, r, cols); err != nil {
			return nil, errors.Trace(err)
		}
		if err = column.CheckNotNull(tableCols, r); err != nil {
			return nil, errors.Trace(err)
		}

		// Notes: incompatible with mysql
		// MySQL will set last insert id to the first row, as follows:
		// `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))`
		// `insert t (c1) values(1),(2),(3);`
		// Last insert id will be 1, not 3.
		//
		// From here on err is a new loop-local variable that shadows the named
		// result; every return below is explicit, so the result is unaffected.
		h, err := t.AddRecord(ctx, r)
		if err == nil {
			continue
		}
		// Only a duplicate-key error combined with an ON DUPLICATE clause is
		// recoverable; anything else aborts the statement.
		if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) {
			return nil, errors.Trace(err)
		}
		// On duplicate key Update the duplicate row.
		// Evaluate the updated value.
		// TODO: report rows affected and last insert id.
		toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate, false)
		if err != nil {
			return nil, errors.Trace(err)
		}
		// h is the handle AddRecord returned alongside the error; presumably it
		// identifies the existing conflicting row — confirm against table.Table.
		data, err := t.Row(ctx, h)
		if err != nil {
			return nil, errors.Trace(err)
		}
		err = updateRecord(ctx, h, data, t, toUpdateColumns, s.OnDuplicate, r, nil)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	return nil, nil
}
示例#10
0
// updateRecord writes newData over the row with handle h of table t.
//
// Parameters:
//   - oldData, newData: the row before and after the update, indexed by this
//     table's column index (newData is cast against t.Cols() and compared with
//     oldData element by element).
//   - updateColumns: assignments keyed by column position. Nil entries and keys
//     outside [offset, offset+len(t.Cols())) are skipped; key-offset is the
//     column index in this table.
//   - onDuplicateUpdate: when true, a changed row is counted as 2 affected rows
//     instead of 1.
//
// After locking the row the function:
//   - rebases the table's auto ID for every assigned auto-increment column and
//     rejects a NULL value for such a column;
//   - returns immediately if no assignment targets this table;
//   - validates newData with column.CastValues and column.CheckNotNull;
//   - skips the write when no touched column compares different from its old
//     value (one affected row is still counted if the client set ClientFoundRows);
//   - otherwise, if a non-NULL value was assigned to the PK-handle column, removes
//     the old record and adds the new one; in every other case it updates the
//     record in place.
//
// Every error is returned wrapped with errors.Trace.
func updateRecord(ctx context.Context, h int64, oldData, newData []types.Datum, updateColumns map[int]*ast.Assignment, t table.Table, offset int, onDuplicateUpdate bool) error {
	if err := t.LockRow(ctx, h, false); err != nil {
		return errors.Trace(err)
	}

	cols := t.Cols()
	touched := make(map[int]bool, len(cols))

	assignExists := false
	var newHandle types.Datum
	for i, asgn := range updateColumns {
		if asgn == nil {
			continue
		}
		if i < offset || i >= offset+len(cols) {
			// The assign expression is for another table, not this.
			continue
		}

		// newData is indexed by this table's column index, like touched and the
		// comparison loop below, so the offset must be removed from the key
		// before indexing it. Using the raw key reads the wrong element (or
		// panics) whenever offset > 0.
		colIndex := i - offset
		col := cols[colIndex]
		newVal := newData[colIndex]
		if col.IsPKHandleColumn(t.Meta()) {
			newHandle = newVal
		}
		if mysql.HasAutoIncrementFlag(col.Flag) {
			if newVal.Kind() == types.KindNull {
				return errors.Errorf("Column '%v' cannot be null", col.Name.O)
			}
			val, err := newVal.ToInt64()
			if err != nil {
				return errors.Trace(err)
			}
			// NOTE(review): any result of RebaseAutoID is not inspected here —
			// confirm against the table.Table interface whether it can fail.
			t.RebaseAutoID(val, true)
		}

		touched[colIndex] = true
		assignExists = true
	}

	// If no assign list for this table, no need to update.
	if !assignExists {
		return nil
	}

	// Check whether new value is valid.
	if err := column.CastValues(ctx, newData, cols); err != nil {
		return errors.Trace(err)
	}

	if err := column.CheckNotNull(cols, newData); err != nil {
		return errors.Trace(err)
	}

	// If row is not changed, we should do nothing.
	rowChanged := false
	for i := range oldData {
		if !touched[i] {
			continue
		}

		n, err := newData[i].CompareDatum(oldData[i])
		if err != nil {
			return errors.Trace(err)
		}
		if n != 0 {
			rowChanged = true
			break
		}
	}
	if !rowChanged {
		// See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html  CLIENT_FOUND_ROWS
		if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 {
			variable.GetSessionVars(ctx).AddAffectedRows(1)
		}
		return nil
	}

	var err error
	if newHandle.Kind() != types.KindNull {
		// The handle column was assigned: replace the record instead of updating it.
		err = t.RemoveRecord(ctx, h, oldData)
		if err != nil {
			return errors.Trace(err)
		}
		_, err = t.AddRecord(ctx, newData)
	} else {
		// Update record to new value and update index.
		err = t.UpdateRecord(ctx, h, oldData, newData, touched)
	}
	if err != nil {
		return errors.Trace(err)
	}

	// Record affected rows.
	if !onDuplicateUpdate {
		variable.GetSessionVars(ctx).AddAffectedRows(1)
	} else {
		variable.GetSessionVars(ctx).AddAffectedRows(2)
	}
	return nil
}