コード例 #1
0
ファイル: util_db.go プロジェクト: fishedee/fishgo
// UpdateBatch updates many rows of one table with a single UPDATE statement.
//
// rowsSlicePtr is a slice, or a pointer to a slice, whose elements (or the
// values they point to) are mapped to a table via this.autoMapType.
// indexColName names the column used to match rows. Every other mapped column
// is a candidate for update.
//
// For each row, a field whose value this.isZero reports as zero is treated as
// "not supplied" and that column keeps its current value for that row. A
// column appears in the SET clause only if at least one row supplies a value
// for it. The generated statement has the shape (identifiers are wrapped with
// this.Engine.QuoteStr()):
//
//	UPDATE t SET c1 = CASE idx WHEN ? THEN ? ... ELSE c1 END , ...
//	WHERE idx IN ( ? , ? , ... )
//
// It returns the affected-row count reported by the driver. It returns an
// error if:
//   - the argument is not a slice, or the slice is empty;
//   - the index column is not mapped;
//   - a slice element is a nil pointer;
//   - no column has any value to update;
//   - value extraction/conversion or execution fails.
func (this *databaseImplement) UpdateBatch(rowsSlicePtr interface{}, indexColName string) (int64, error) {
	sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
	if sliceValue.Kind() != reflect.Slice {
		return 0, errors.New("needs a pointer to a slice")
	}
	size := sliceValue.Len()
	if size == 0 {
		return 0, errors.New("update rows is empty")
	}

	bean := sliceValue.Index(0).Interface()
	elementValue := this.rValue(bean)
	table := this.autoMapType(elementValue)

	// Split the mapped columns into the match (index) column and the
	// candidate update columns.
	tableCols := table.Columns()
	cols := make([]*core.Column, 0, len(tableCols))
	var indexCol *core.Column
	for _, col := range tableCols {
		if col.Name == indexColName {
			indexCol = col
		} else {
			cols = append(cols, col)
		}
	}
	if indexCol == nil {
		return 0, errors.New("counld not found index col " + indexColName)
	}

	// Extract per-row values. rows[r][c] == nil means row r supplies no value
	// for cols[c]. updateCols[c] records whether any row supplies one.
	rows := make([][]interface{}, 0, size)
	indexRow := make([]interface{}, 0, size)
	updateCols := make([]bool, len(cols))
	for i := 0; i < size; i++ {
		vv := reflect.Indirect(sliceValue.Index(i))
		if !vv.IsValid() {
			// A nil pointer element would otherwise panic inside ValueOfV.
			return 0, fmt.Errorf("update row %d is nil", i)
		}

		singleRow := make([]interface{}, len(cols))
		for colIndex, col := range cols {
			ptrFieldValue, err := col.ValueOfV(&vv)
			if err != nil {
				return 0, err
			}
			fieldValue := *ptrFieldValue
			if this.isZero(fieldValue.Interface()) {
				continue
			}
			arg, err := this.value2Interface(fieldValue)
			if err != nil {
				return 0, err
			}
			singleRow[colIndex] = arg
			// Only count the column when it will produce a WHEN branch;
			// otherwise the CASE could end up with no WHEN at all.
			if arg != nil {
				updateCols[colIndex] = true
			}
		}
		rows = append(rows, singleRow)

		ptrFieldValue, err := indexCol.ValueOfV(&vv)
		if err != nil {
			return 0, err
		}
		indexValue, err := this.value2Interface(*ptrFieldValue)
		if err != nil {
			return 0, err
		}
		indexRow = append(indexRow, indexValue)
	}

	// Without at least one column to set, the statement would be
	// "UPDATE t SET  WHERE ...", which is invalid SQL. This also covers the
	// case where the table has no columns besides the index column.
	hasUpdateCol := false
	for _, ok := range updateCols {
		if ok {
			hasUpdateCol = true
			break
		}
	}
	if !hasUpdateCol {
		return 0, errors.New("update cols is empty! " + fmt.Sprintf("%v", rowsSlicePtr))
	}

	// Build the SQL.
	quote := this.Engine.QuoteStr()
	quotedIndexCol := quote + indexCol.Name + quote
	sqlArgs := make([]interface{}, 0)
	query := "UPDATE " + table.Name + " SET "
	isFirstUpdateCol := true
	for colIndex, col := range cols {
		if !updateCols[colIndex] {
			continue
		}
		if !isFirstUpdateCol {
			query += " , "
		}
		isFirstUpdateCol = false

		quotedCol := quote + col.Name + quote
		query += quotedCol + " = CASE " + quotedIndexCol
		for rowIndex, row := range rows {
			if row[colIndex] == nil {
				continue
			}
			query += " WHEN ? THEN ? "
			sqlArgs = append(sqlArgs, indexRow[rowIndex], row[colIndex])
		}
		// A CASE with no matching WHEN and no ELSE yields NULL. Rows that
		// supplied no value for this column must keep their current value
		// instead of being nulled out.
		query += " ELSE " + quotedCol + " END "
	}
	query += " WHERE " + quotedIndexCol + " IN ( "
	for rowIndex, indexValue := range indexRow {
		if rowIndex != 0 {
			query += " , "
		}
		query += " ? "
		sqlArgs = append(sqlArgs, indexValue)
	}
	query += " ) "

	// Execute the SQL.
	res, err := this.Exec(query, sqlArgs...)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}