示例#1
0
文件: orm_funcs.go 项目: d0ngw/go
// createUpdateColumnsFunc builds the "update selected columns" function for
// the model described by modelInfo.
//
// The returned function executes "UPDATE <table> SET <columns> <condition>":
//   - columns is a caller-supplied SET clause (for example "a=?,b=?"); an
//     empty value causes a panic.
//   - condition is an optional caller-supplied suffix (for example
//     " WHERE id = ?").
//   - params are handed to exec as the statement arguments.
//
// On success it returns the number of affected rows reported by the driver.
func createUpdateColumnsFunc(modelInfo *modelMeta) EntityUCFunc {
	return func(executor interface{}, entity EntityInterface, columns string, condition string, params []interface{}) (int64, error) {
		checkEntity(modelInfo, entity, executor)
		if columns == "" {
			panic(NewDBError(nil, "Can't update empty columns"))
		}

		// An empty condition appends nothing, leaving the trailing space only.
		stmt := fmt.Sprintf("UPDATE %s SET %s %s", entity.TableName(), columns, condition)
		c.Debugf("updateSql:%v", stmt)

		result, execErr := exec(executor, stmt, params)
		if execErr != nil {
			return 0, execErr
		}

		affected, countErr := result.RowsAffected()
		if countErr != nil {
			return 0, countErr
		}
		c.Debugf("Updated rows:%v", affected)
		return affected, nil
	}
}
示例#2
0
文件: oper.go 项目: d0ngw/go
// decrTransDepth leaves one level of (possibly nested) transaction and logs
// the new depth.
//
// It panics with the error built by NewDBError when called more often than
// incrTransDepth, i.e. on an unbalanced commit/rollback. The depth is
// validated before it is modified, so a panic that is later recovered does
// not leave transDepth negative.
func (op *DBOper) decrTransDepth() {
	if op.transDepth <= 0 {
		panic(NewDBError(nil, "Too many invoke commit or rollback"))
	}
	op.transDepth--
	c.Debugf("op.tranDepth:%v", op.transDepth)
}
示例#3
0
文件: orm_funcs.go 项目: d0ngw/go
// createInsertFunc builds the insert function for the model described by
// modelInfo.
//
// The column list and the "?" placeholder list are computed once, from the
// fields accepted by exceptIdPred. NOTE(review): presumably these are all
// non-primary-key fields; confirm against exceptIdPred. Column names are
// back-quoted, as in createQueryFunc, so columns named after SQL reserved
// words still work.
//
// The returned function:
//   - inserts the entity's values for those fields;
//   - when the primary key is marked pkAuto, stores LastInsertId back into
//     the entity's primary-key field. Both signed and unsigned integer key
//     fields are supported; other kinds make reflect panic, as before.
func createInsertFunc(modelInfo *modelMeta) EntityCFunc {
	insertFields := fun.Filter(exceptIdPred, modelInfo.fields).([]*modelField)
	columns := strings.Join(fun.Map(func(field *modelField) string {
		return "`" + field.column + "`"
	}, insertFields).([]string), ",")
	params := strings.Join(toSlice("?", len(insertFields)), ",")

	return func(executor interface{}, entity EntityInterface) error {
		ind := checkEntity(modelInfo, entity, executor)

		paramValues := make([]interface{}, 0, len(insertFields))
		for _, field := range insertFields {
			paramValues = append(paramValues, ind.Field(field.index).Interface())
		}

		insertSql := fmt.Sprintf("INSERT INTO %s (%s) VALUES(%s)", entity.TableName(), columns, params)
		c.Debugf("insertSql:%v", insertSql)

		rs, err := exec(executor, insertSql, paramValues)
		if err != nil {
			return err
		}

		if !modelInfo.pkField.pkAuto {
			return nil
		}
		id, err := rs.LastInsertId()
		if err != nil {
			return err
		}

		// reflect.Value.SetInt panics on unsigned kinds, so route those
		// through SetUint instead.
		pk := ind.Field(modelInfo.pkField.index)
		switch pk.Kind() {
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			pk.SetUint(uint64(id))
		default:
			pk.SetInt(id)
		}
		return nil
	}
}
示例#4
0
文件: oper.go 项目: d0ngw/go
// finishTrans ends the current transaction level.
//
// The steps are:
//  1. Validate the state with checkTransStatus; any error is returned as is.
//  2. Drop one nesting level. If an outer level is still open, return nil.
//  3. At the outermost level, mark the transaction as done and finalize
//     op.tx: Rollback when rollbackOnly is set, Commit otherwise.
//
// op.close() runs after the commit/rollback regardless of its outcome.
func (op *DBOper) finishTrans() error {
	if err := op.checkTransStatus(); err != nil {
		return err
	}

	op.decrTransDepth()
	// Nested levels are no-ops; only the outermost level finalizes the tx.
	if op.transDepth > 0 {
		return nil
	}

	defer op.close()
	op.txDone = true
	if !op.rollbackOnly {
		c.Debugf("Commit")
		return op.tx.Commit()
	}
	c.Debugf("Rollback")
	return op.tx.Rollback()
}
示例#5
0
文件: orm_funcs.go 项目: d0ngw/go
// query runs execSql with args on executor and returns the resulting rows.
//
// executor must be a *sql.Tx or a *sql.DB; any other value, including nil,
// panics with a DBError naming the offending type. On success the caller owns
// the returned *sql.Rows and must close it.
func query(executor interface{}, execSql string, args []interface{}) (rows *sql.Rows, err error) {
	c.Debugf("Exec sql %s with %T", execSql, executor)
	switch e := executor.(type) {
	case *sql.Tx:
		return e.Query(execSql, args...)
	case *sql.DB:
		return e.Query(execSql, args...)
	default:
		panic(NewDBErrorf(nil, "Not a valid executor:%T", executor))
	}
}
示例#6
0
文件: orm_funcs.go 项目: d0ngw/go
// createDelFunc builds the delete function for the model described by
// modelInfo.
//
// The returned function executes "DELETE FROM <table> <condition>":
//   - condition is an optional caller-supplied suffix (for example
//     " WHERE id = ?"). With an empty condition the statement has no WHERE
//     clause.
//   - params are handed to exec as the statement arguments.
//
// On success it returns the number of affected rows reported by the driver.
func createDelFunc(modelInfo *modelMeta) EntityDFunc {
	return func(executor interface{}, entity EntityInterface, condition string, params []interface{}) (int64, error) {
		checkEntity(modelInfo, entity, executor)

		stmt := fmt.Sprintf("DELETE FROM %s %s", entity.TableName(), condition)
		c.Debugf("delSql:%v", stmt)

		result, err := exec(executor, stmt, params)
		if err != nil {
			return 0, err
		}

		deleted, err := result.RowsAffected()
		if err != nil {
			return 0, err
		}
		return deleted, nil
	}
}
示例#7
0
文件: orm_funcs.go 项目: d0ngw/go
// createQueryFunc builds the query function for the model described by
// modelInfo.
//
// The back-quoted column list is computed once from modelInfo.fields. The
// returned function executes "SELECT <columns> FROM <table> <condition>":
//   - condition is an optional caller-supplied suffix (for example
//     " WHERE id = ?").
//   - params are passed as the statement arguments.
//
// Each row is scanned into a freshly allocated value of ind's type, the
// struct value checkEntity resolves for entity. The result holds pointers to
// those values as EntityInterface. It is empty (not nil) when no row matches.
// Errors from executing the query, scanning a row, or iterating the result
// set are returned with a nil slice.
func createQueryFunc(modelInfo *modelMeta) EntityQFunc {
	columns := strings.Join(fun.Map(func(field *modelField) string {
		return "`" + field.column + "`"
	}, modelInfo.fields).([]string), ",")

	return func(executor interface{}, entity EntityInterface, condition string, params []interface{}) ([]EntityInterface, error) {
		ind := checkEntity(modelInfo, entity, executor)
		querySql := fmt.Sprintf("SELECT %s FROM %s ", columns, entity.TableName())
		if len(condition) > 0 {
			querySql += condition
		}
		c.Debugf("querySql:%v", querySql)

		rows, err := query(executor, querySql, params)
		if err != nil {
			return nil, err
		}
		defer rows.Close()

		entityType := ind.Type()
		rt := make([]EntityInterface, 0, 10)
		for rows.Next() {
			ptrValue := reflect.New(entityType)
			elem := ptrValue.Elem()
			// Scan destinations: addresses of the new value's mapped fields, in
			// the same order as the SELECT column list.
			dest := make([]interface{}, 0, len(modelInfo.fields))
			for _, field := range modelInfo.fields {
				dest = append(dest, elem.Field(field.index).Addr().Interface())
			}
			if err := rows.Scan(dest...); err != nil {
				return nil, err
			}
			rt = append(rt, ptrValue.Interface().(EntityInterface))
		}
		// rows.Next returns false both at the end of the result set and on
		// failure; only rows.Err tells them apart. Without this check a broken
		// read would be reported as a successful, truncated result.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return rt, nil
	}
}
示例#8
0
文件: orm_funcs.go 项目: d0ngw/go
// createUpdateFunc builds the "update whole entity by primary key" function
// for the model described by modelInfo.
//
// The SET clause is computed once, from the fields accepted by exceptIdPred.
// NOTE(review): presumably these are all non-primary-key fields; confirm.
// Column names, including the primary-key column in the WHERE clause, are
// back-quoted as in createQueryFunc, so reserved-word column names work.
//
// The returned function binds the field values in order, followed by the
// entity's primary-key value. It reports true only when exactly one row was
// affected. A driver error yields (false, err).
// NOTE(review): MySQL counts only changed rows by default, so an update that
// writes identical values may report false; confirm the driver settings.
func createUpdateFunc(modelInfo *modelMeta) EntityUFunc {
	updateFields := fun.Filter(exceptIdPred, modelInfo.fields).([]*modelField)
	setClause := strings.Join(fun.Map(func(field *modelField) string {
		return "`" + field.column + "`=?"
	}, updateFields).([]string), ",")

	return func(executor interface{}, entity EntityInterface) (bool, error) {
		ind := checkEntity(modelInfo, entity, executor)
		pk := modelInfo.pkField

		// SET values in field order, then the primary key for the WHERE clause.
		paramValues := make([]interface{}, 0, len(updateFields)+1)
		for _, field := range updateFields {
			paramValues = append(paramValues, ind.Field(field.index).Interface())
		}
		paramValues = append(paramValues, ind.Field(pk.index).Interface())

		updateSql := fmt.Sprintf("UPDATE %s SET %s where `%s` = ?", entity.TableName(), setClause, pk.column)
		c.Debugf("updateSql:%v", updateSql)

		rs, err := exec(executor, updateSql, paramValues)
		if err != nil {
			return false, err
		}

		rows, err := rs.RowsAffected()
		if err != nil {
			return false, err
		}
		return rows == 1, nil
	}
}
示例#9
0
文件: orm.go 项目: d0ngw/go
// RegModel registers model's metadata in reg.
//
// Requirements on model, each violation of which panics with a DBError:
//   - it must be non-nil and a pointer to a struct;
//   - every struct field needs a `column` tag;
//   - exactly one field must be tagged `pk:"y"` (compared case-insensitively).
//
// A field's pkAuto flag is true unless it is tagged `pkAuto:"n"`.
//
// The stored metadata carries the field list plus pre-built insert, update,
// update-columns, query, get-by-id, delete and delete-by-id functions.
//
// Registration uses the receiver's own lock and cache, so separate modelReg
// instances stay independent. A second model with the same full name (from
// getFullModelName) returns a *DBError and leaves the existing entry in place.
func (reg *modelReg) RegModel(model EntityInterface) error {
	if model == nil {
		panic(NewDBError(nil, "Invalid model"))
	}

	val, ind, typ := extract(model)
	fullName := getFullModelName(typ)

	if val.Kind() != reflect.Ptr {
		panic(NewDBErrorf(nil, "Expect ptr ,but it's %s,type:%s", val.Kind(), typ))
	}
	if ind.Kind() != reflect.Struct {
		// Report the kind that was actually checked.
		panic(NewDBErrorf(nil, "Expect struct ,but it's %s,type:%s", ind.Kind(), typ))
	}

	fieldCount := ind.NumField()
	fields := make([]*modelField, 0, fieldCount)
	mInfo := &modelMeta{name: fullName, table: model.TableName(), modelType: typ}
	var pkField *modelField

	for i := 0; i < fieldCount; i++ {
		structField := typ.Field(i)
		sfTag := structField.Tag
		column := sfTag.Get("column")
		pk := strings.ToLower(sfTag.Get("pk"))
		pkAuto := strings.ToLower(sfTag.Get("pkAuto"))
		if len(column) == 0 {
			panic(NewDBErrorf(nil, "Can't find the column tag for %s.%s", typ, structField.Name))
		}

		mField := &modelField{
			name:        structField.Name,
			column:      column,
			pk:          pk == "y",
			pkAuto:      pkAuto != "n",
			index:       i,
			structField: structField,
		}

		if mField.pk {
			if pkField != nil {
				panic(NewDBErrorf(nil, "Duplicate pk column for %s.%s and %s ", typ, pkField.name, mField.name))
			}
			pkField = mField
		}
		fields = append(fields, mField)
	}

	if pkField == nil {
		panic(NewDBErrorf(nil, "Can't find pk column for %s", typ))
	}
	mInfo.pkField = pkField
	c.Debugf("Register Model:%s,fields:%s,pkFiled:%+v", fullName, fields, pkField)

	mInfo.fields = fields
	mInfo.insertFunc = createInsertFunc(mInfo)
	mInfo.updateFunc = createUpdateFunc(mInfo)
	mInfo.updateColumnsFunc = createUpdateColumnsFunc(mInfo)
	mInfo.queryFunc = createQueryFunc(mInfo)
	// getFunc returns the entity whose primary key equals id. If there is no
	// match (or, defensively, more than one) it returns (nil, nil).
	mInfo.getFunc = func(executor interface{}, entity EntityInterface, id int64) (EntityInterface, error) {
		l, err := mInfo.queryFunc(executor, entity, " WHERE "+mInfo.pkField.column+" = ?", []interface{}{id})
		if err != nil || len(l) != 1 {
			return nil, err
		}
		return l[0], nil
	}
	mInfo.delFunc = createDelFunc(mInfo)
	// delEFunc deletes the row whose primary key equals id. It reports true
	// only when exactly one row was removed.
	mInfo.delEFunc = func(executor interface{}, entity EntityInterface, id int64) (bool, error) {
		n, err := mInfo.delFunc(executor, entity, " WHERE "+mInfo.pkField.column+" = ?", []interface{}{id})
		if err != nil {
			return false, err
		}
		return n == 1, nil
	}

	reg.lock.Lock()
	defer reg.lock.Unlock()
	if _, exist := reg.cache[fullName]; exist {
		return &DBError{"Duplicate mode name:" + fullName, nil}
	}
	reg.cache[fullName] = mInfo
	return nil
}
示例#10
0
文件: oper.go 项目: d0ngw/go
// incrTransDepth enters one more level of (possibly nested) transaction and
// logs the resulting depth.
func (op *DBOper) incrTransDepth() {
	op.transDepth++
	c.Debugf("op.tranDepth:%v", op.transDepth)
}