Exemplo n.º 1
0
// Transact runs txFun inside a database transaction.
//
// A transaction is begun on s.db and txFun is called with a new Shard built
// from s.DBName and the transaction handle. If txFun returns a non-nil error
// the transaction is rolled back and that error is returned (or, if the
// rollback itself fails, the rollback error wrapped together with the
// original error). If txFun returns nil the transaction is committed.
//
// If txFun panics, a rollback is attempted and a new panic is raised that
// carries the original panic value and the rollback result.
func (s *Shard) Transact(txFun TxFunc) errs.Err {
	conn, stdErr := s.db.Begin()
	if stdErr != nil {
		return errs.Wrap(stdErr, errs.Info{"Description": "Could not open transaction"})
	}
	defer func() {
		if panicErr := recover(); panicErr != nil {
			rbErr := conn.Rollback()
			panic(errs.New(errs.Info{
				"Description": "Panic during sql transcation",
				"PanicErr":    panicErr,
				"RollbackErr": rbErr,
			}))
		}
	}()

	if err := txFun(&Shard{s.DBName, nil, conn}); err != nil {
		if rbErr := conn.Rollback(); rbErr != nil {
			return errs.Wrap(rbErr, errs.Info{"Description": "Transact rollback error", "TransactionError": err})
		}
		// The rollback succeeded, but the transaction still failed: the caller
		// must see the error instead of a nil that looks like a commit.
		return err
	}

	if stdErr = conn.Commit(); stdErr != nil {
		return errs.Wrap(stdErr, errs.Info{"Description": "Could not commit transaction"})
	}
	return nil
}
Exemplo n.º 2
0
// queryOne runs query with args and scans the single column of the single
// resulting row into out.
//
// found reports whether a row was scanned. It is false with a nil error when
// the query produced no rows. An error is returned when the query fails, the
// scan fails, more than one row comes back, or rows.Err reports a failure.
func (s *Shard) queryOne(query string, args []interface{}, out interface{}) (found bool, err errs.Err) {
	rows, err := s.Query(query, args...)
	if err != nil {
		return false, err
	}
	defer rows.Close()

	if rows.Next() {
		if scanErr := rows.Scan(out); scanErr != nil {
			return false, errs.Wrap(scanErr, errInfo("queryOne rows.Scan error", query, args))
		}
		// A second row means the query was not specific enough.
		if rows.Next() {
			return false, errs.New(errInfo("queryOne query returned too many rows", query, args))
		}
		found = true
	}

	if iterErr := rows.Err(); iterErr != nil {
		return found, errs.Wrap(iterErr, errInfo("queryOne rows.Err", query, args))
	}
	return found, nil
}
Exemplo n.º 3
0
// scanOne runs query with args and fills the struct referenced by output from
// the single resulting row.
//
// output must be a pointer to a pointer (the pointed-to struct is filled by
// structFromRow); anything else panics with scanOneTypeError. If the inner
// pointer is nil a new value is allocated and stored through output,
// otherwise the existing value is filled in place.
//
// found is true only when exactly one row was read without error. It is false
// with a nil error when the query produced no rows. An error is returned when
// the query fails, the columns cannot be read, the row cannot be mapped, more
// than one row comes back, or rows.Err reports a failure.
//
// The required parameter is not used by this function.
func (s *Shard) scanOne(output interface{}, query string, required bool, args ...interface{}) (found bool, err errs.Err) {
	// Check types. Kind() of an invalid Value is reflect.Invalid, so the
	// IsValid check is kept only to make the nil-output case explicit.
	var outputReflectionPtr = reflect.ValueOf(output)
	if !outputReflectionPtr.IsValid() {
		panic(scanOneTypeError)
	}
	if outputReflectionPtr.Kind() != reflect.Ptr {
		panic(scanOneTypeError)
	}
	var outputReflection = outputReflectionPtr.Elem()
	if outputReflection.Kind() != reflect.Ptr {
		panic(scanOneTypeError)
	}

	// Query DB
	rows, err := s.Query(query, args...)
	if err != nil {
		return
	}
	defer rows.Close()

	// Reflect onto struct
	columns, stdErr := rows.Columns()
	if stdErr != nil {
		err = errs.Wrap(stdErr, errInfo("rows.Columns() error", query, args))
		return
	}
	if !rows.Next() {
		// Next returns false both for an empty result set and for a failed
		// iteration; only rows.Err can tell the two apart.
		if stdErr = rows.Err(); stdErr != nil {
			err = errs.Wrap(stdErr, errInfo("scanOne rows.Err() error", query, args))
		}
		return
	}

	var vStruct reflect.Value
	if outputReflection.IsNil() {
		// Caller passed a nil inner pointer: allocate the target and hand it
		// back through output.
		structPtrVal := reflect.New(outputReflection.Type().Elem())
		outputReflection.Set(structPtrVal)
		vStruct = structPtrVal.Elem()
	} else {
		vStruct = outputReflection.Elem()
	}

	err = structFromRow(vStruct, columns, rows, query, args)
	if err != nil {
		return
	}

	if rows.Next() {
		err = errs.New(errInfo("scanOne got multiple rows", query, args))
		return
	}

	stdErr = rows.Err()
	if stdErr != nil {
		err = errs.Wrap(stdErr, errInfo("scanOne rows.Err() error", query, args))
		return
	}

	found = true
	return
}
Exemplo n.º 4
0
// SelectOne fills output from the single row returned by query.
//
// It delegates to scanOne and additionally treats "no rows" as an error, so a
// nil return means exactly one row was read into output.
func (s *Shard) SelectOne(output interface{}, query string, args ...interface{}) (err errs.Err) {
	gotRow, scanErr := s.scanOne(output, query, true, args...)
	switch {
	case scanErr != nil:
		return scanErr
	case !gotRow:
		return errs.New(errInfo("scanOne got no rows", query, args))
	}
	return nil
}
Exemplo n.º 5
0
// UpdateNum executes query via Update and verifies that exactly num rows were
// affected.
//
// It returns the error from Update if there is one, and otherwise an error
// carrying the expected and actual row counts when they differ.
func (s *Shard) UpdateNum(num int64, query string, args ...interface{}) (err errs.Err) {
	affected, updateErr := s.Update(query, args...)
	if updateErr != nil {
		return updateErr
	}
	if affected == num {
		return nil
	}
	counts := errs.Info{"ExpectedRows": num, "AffectedRows": affected}
	return errs.New(errInfo("UpdateNum affected unexpected number of rows", query, args, counts))
}
Exemplo n.º 6
0
// SelectUint runs query and returns the single value of its single row as a
// uint.
//
// An error is returned when queryOne fails or when the query yields no rows.
func (s *Shard) SelectUint(query string, args ...interface{}) (num uint, err errs.Err) {
	var gotRow bool
	if gotRow, err = s.queryOne(query, args, &num); err != nil {
		return num, err
	}
	if !gotRow {
		err = errs.New(errInfo("Query returned no rows", query, args))
	}
	return num, err
}
Exemplo n.º 7
0
// SelectString runs query and returns the single value of its single row as a
// string.
//
// The value is scanned through sql.NullString and its String field is
// returned, so the Valid flag is not inspected. An error is returned when
// queryOne fails or when the query yields no rows.
func (s *Shard) SelectString(query string, args ...interface{}) (str string, err errs.Err) {
	var scanned sql.NullString
	gotRow, err := s.queryOne(query, args, &scanned)
	if err != nil {
		return "", err
	}
	if !gotRow {
		return "", errs.New(errInfo("Query returned no rows", query, args))
	}
	return scanned.String, nil
}
Exemplo n.º 8
0
// Uid returns a random, URL-safe base64 string that is exactly numChars
// characters long.
//
// numChars must be a non-negative multiple of 4: every 3 random bytes encode
// to exactly 4 base64 characters, so a multiple of 4 can be produced without
// any '=' padding. The random bytes are read from rand.Reader; a read failure
// is returned wrapped.
func Uid(numChars int) (uid string, err errs.Err) {
	if numChars < 0 {
		// Without this guard a negative multiple of 4 would reach make() and panic.
		err = errs.New(nil, "uid length must not be negative")
		return
	}
	if numChars%4 != 0 {
		err = errs.New(nil, "uid length must be a multiple of 4")
		return
	}
	// 3 bytes of entropy per 4 output characters.
	buf := make([]byte, numChars/4*3)
	_, stdErr := io.ReadFull(rand.Reader, buf)
	if stdErr != nil {
		err = errs.Wrap(stdErr, nil)
		return
	}

	uid = base64.URLEncoding.EncodeToString(buf)
	return
}
Exemplo n.º 9
0
// scanColumnValue converts the raw bytes of one column into reflectVal.
//
// Supported destination kinds are string, the signed and unsigned integer
// kinds, bool, and byte slices. A nil *value leaves reflectVal untouched.
// Integers are parsed in base 10 using the destination's own bit size, so a
// value that does not fit is reported as an error instead of being silently
// truncated. Any other destination kind (including slices whose element kind
// is not uint8) yields an error naming the column and the kind.
//
// column, query and args are only used to build error information.
func scanColumnValue(column string, reflectVal reflect.Value, value *sql.RawBytes, query string, args []interface{}) errs.Err {
	if *value == nil {
		return nil // Leave struct field empty
	}
	// sql.RawBytes refers to driver-owned memory that is only valid until the
	// next Next/Scan/Close call. Converting to string copies it once, and
	// everything stored below is derived from that copy.
	text := string(*value)
	kind := reflectVal.Kind()
	switch kind {
	case reflect.String:
		reflectVal.SetString(text)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		uintVal, stdErr := strconv.ParseUint(text, 10, reflectVal.Type().Bits())
		if stdErr != nil {
			return errs.Wrap(stdErr, errInfo("strconv.ParseUint error", query, args, errs.Info{"Bytes": []byte(text)}))
		}
		reflectVal.SetUint(uintVal)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		intVal, stdErr := strconv.ParseInt(text, 10, reflectVal.Type().Bits())
		if stdErr != nil {
			return errs.Wrap(stdErr, errInfo("strconv.ParseInt error", query, args, errs.Info{"Bytes": []byte(text)}))
		}
		reflectVal.SetInt(intVal)
	case reflect.Bool:
		boolVal, stdErr := strconv.ParseBool(text)
		if stdErr != nil {
			return errs.Wrap(stdErr, errInfo("strconv.ParseBool error", query, args, errs.Info{"Bytes": []byte(text)}))
		}
		reflectVal.SetBool(boolVal)
	case reflect.Slice:
		// SetBytes panics unless the slice's element kind is uint8.
		if reflectVal.Type().Elem().Kind() != reflect.Uint8 {
			return errs.New(errInfo("Bad row value for column "+column+": "+kind.String(), query, args))
		}
		reflectVal.SetBytes([]byte(text))
	default:
		return errs.New(errInfo("Bad row value for column "+column+": "+kind.String(), query, args))
	}
	return nil
}
Exemplo n.º 10
0
// Select runs query with args and appends one element per result row to the
// slice that output points to.
//
// output must be a pointer to an empty slice; otherwise an error is returned
// before the query runs. The slice is replaced with a fresh zero-length slice
// and then filled:
//   - if the element type is a pointer to a struct, a new struct is allocated
//     per row and populated by structFromRow;
//   - otherwise the query must select exactly one column, and each row's
//     value is converted with scanColumnValue.
//
// On error the slice may already contain the rows appended before the
// failure. Errors reported by rows.Err after iteration are returned wrapped.
func (s *Shard) Select(output interface{}, query string, args ...interface{}) errs.Err {
	// Check types
	var outputPtr = reflect.ValueOf(output)
	if outputPtr.Kind() != reflect.Ptr {
		return errs.New(errInfo("Select expects a pointer to a slice of items", query, args))
	}
	var outputReflection = reflect.Indirect(outputPtr)
	if outputReflection.Kind() != reflect.Slice {
		return errs.New(errInfo("Select expects items to be a slice", query, args))
	}
	if outputReflection.Len() != 0 {
		return errs.New(errInfo("Select expects items to be empty", query, args))
	}
	outputReflection.Set(reflect.MakeSlice(outputReflection.Type(), 0, 0))

	// Query DB
	rows, err := s.Query(query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()
	columns, stdErr := rows.Columns()
	if stdErr != nil {
		return errs.Wrap(stdErr, errInfo("Select rows.Columns error", query, args))
	}

	valType := outputReflection.Type().Elem()
	isStruct := (valType.Kind() == reflect.Ptr && valType.Elem().Kind() == reflect.Struct)
	if isStruct {
		// Reflect onto structs
		for rows.Next() {
			structPtrVal := reflect.New(valType.Elem())
			outputItemStructVal := structPtrVal.Elem()
			err = structFromRow(outputItemStructVal, columns, rows, query, args)
			if err != nil {
				return err
			}
			outputReflection.Set(reflect.Append(outputReflection, structPtrVal))
		}
	} else {
		if len(columns) != 1 {
			return errs.New(errInfo("Select expected single column in select statement for slice of non-struct values", query, args))
		}
		for rows.Next() {
			rawBytes := &sql.RawBytes{}
			stdErr = rows.Scan(rawBytes)
			if stdErr != nil {
				return errs.Wrap(stdErr, errInfo("Select rows.Scan error", query, args))
			}
			outputValue := reflect.New(valType).Elem()
			err = scanColumnValue(columns[0], outputValue, rawBytes, query, args)
			if err != nil {
				return err
			}
			outputReflection.Set(reflect.Append(outputReflection, outputValue))
		}
	}

	// rows.Next returning false can mean either "no more rows" or "iteration
	// failed"; the result of rows.Err itself is what must be checked here.
	stdErr = rows.Err()
	if stdErr != nil {
		return errs.Wrap(stdErr, errInfo("Select rows.Err() error", query, args))
	}
	return nil
}