func (g *validation) generateTests(msg *generator.Descriptor, field *pb.FieldDescriptorProto, fieldName string, idx int, patterns map[string]string) { if limbo.IsRequiredProperty(field) { g.generateRequiredTest(msg, field, fieldName) } if n, ok := limbo.GetMinItems(field); ok { g.generateMinItemsTest(msg, field, fieldName, int(n)) } if n, ok := limbo.GetMaxItems(field); ok { g.generateMaxItemsTest(msg, field, fieldName, int(n)) } if pattern, ok := limbo.GetPattern(field); ok { patternVar := fmt.Sprintf("valPattern_%s_%d", msg.GetName(), idx) patterns[patternVar] = pattern g.generatePatternTest(msg, field, fieldName, pattern, patternVar) } if n, ok := limbo.GetMinLength(field); ok { g.generateMinLengthTest(msg, field, fieldName, int(n)) } if n, ok := limbo.GetMaxLength(field); ok { g.generateMaxLengthTest(msg, field, fieldName, int(n)) } if field.GetType() == pb.FieldDescriptorProto_TYPE_MESSAGE { g.generateSubMessageTest(msg, field, fieldName) } }
func (g *svcauth) setMessage(inputType *generator.Descriptor, path, input, output string, inputIsNullable bool) { var ( goPath string outputIsNullable bool ) goPath = input if inputIsNullable { g.P(`if `, goPath, `== nil {`) g.P(goPath, `= &`, g.gen.TypeName(inputType), `{}`) g.P(`}`) } for path != "" { // split path part := path idx := strings.IndexByte(path, '.') if idx >= 0 { part = path[:idx] path = path[idx+1:] } else { path = "" } // Get Field field := inputType.GetFieldDescriptor(part) if field == nil { g.gen.Fail("unknown field", part, "in message", inputType.GetName()) } if !field.IsMessage() { g.gen.Fail("expected a message") } // Append code fieldGoName := g.gen.GetFieldName(inputType, field) goPath += "." + fieldGoName inputType = g.messages[strings.TrimPrefix(field.GetTypeName(), ".")] if gogoproto.IsNullable(field) && path != "" { g.P(`if `, goPath, `== nil {`) g.P(goPath, `= &`, g.gen.TypeName(inputType), `{}`) g.P(`}`) } if gogoproto.IsNullable(field) { outputIsNullable = true } else { outputIsNullable = false } } if outputIsNullable { g.P(goPath, ` = `, output) } else { g.P(goPath, ` = &`, output) } }
func (g *gensql) generateStmt(file *generator.FileDescriptor, message *generator.Descriptor) { model := limbo.GetModel(message) g.P(`type `, message.Name, `StmtBuilder interface {`) g.P(`Prepare(scanner string, query string) `, message.Name, `Stmt`) g.P(`PrepareExecer(query string) `, message.Name, `Execer`) g.P(`Err() error`) g.P(`}`) g.P(`type `, message.Name, `Stmt interface {`) g.P(`QueryRow(ctx `, g.contextPkg.Use(), `.Context, args ... interface{}) (`, message.Name, `Row)`) g.P(`Query(ctx `, g.contextPkg.Use(), `.Context, args ... interface{}) (`, message.Name, `Rows, error)`) g.P(`SelectSlice(ctx `, g.contextPkg.Use(), `.Context, dst []*`, message.Name, `, args ... interface{}) ([]*`, message.Name, `, error)`) g.P(`SelectMessageSlice(ctx `, g.contextPkg.Use(), `.Context, dst []interface{}, args ... interface{}) ([]interface{}, error)`) g.P(`ForTx(tx *`, g.sqlPkg.Use(), `.Tx) `, message.Name, `Stmt`) g.P(`}`) g.P(`type `, message.Name, `Execer interface {`) g.P(`Exec(ctx `, g.contextPkg.Use(), `.Context, args ... interface{}) (`, g.sqlPkg.Use(), `.Result, error)`) g.P(`ForTx(tx *`, g.sqlPkg.Use(), `.Tx) `, message.Name, `Execer`) g.P(`}`) g.P(`type `, message.Name, `Row interface {`) g.P(`Scan(out *`, message.Name, `) error`) g.P(`}`) g.P(`type `, message.Name, `Rows interface {`) g.P(`Close() error`) g.P(`Next() bool`) g.P(`Err() error`) g.P(`Scan(out *`, message.Name, `) error`) g.P(`}`) g.P(`type `, unexport(*message.Name), `StmtBuilder struct {`) g.P(`db *`, g.sqlPkg.Use(), `.DB`) g.P(`err error`) g.P(`}`) g.P(`type `, unexport(*message.Name), `Stmt struct {`) g.P(`stmt *`, g.sqlPkg.Use(), `.Stmt`) g.P(`scanner func(func(...interface{}) error, *`, message.Name, `) error`) g.P(`query string`) g.P(`}`) g.P(`type `, unexport(*message.Name), `Execer struct {`) g.P(`stmt *`, g.sqlPkg.Use(), `.Stmt`) g.P(`query string`) g.P(`}`) g.P(`type `, unexport(*message.Name), `Row struct {`) g.P(`row *`, g.sqlPkg.Use(), `.Row`) g.P(`span *`, g.tracePkg.Use(), `.Span`) g.P(`scanner func(func(...interface{}) error, *`, message.Name, `) error`) g.P(`}`) g.P(`type `, unexport(*message.Name), `Rows struct {`) g.P(`*`, g.sqlPkg.Use(), `.Rows`) g.P(`span *`, g.tracePkg.Use(), `.Span`) g.P(`scanner func(func(...interface{}) error, *`, message.Name, `) error`) g.P(`}`) g.P(`func New`, message.Name, `StmtBuilder(db *`, g.sqlPkg.Use(), `.DB) `, message.Name, `StmtBuilder {`) g.P(`return &`, unexport(*message.Name), `StmtBuilder{db: db}`) g.P(`}`) g.P(`func (b *`, unexport(*message.Name), `StmtBuilder) Prepare(scanner string, query string) (`, message.Name, `Stmt) {`) g.P(`if b.err != nil { return nil }`) g.P(`var scannerFunc func(func(...interface{}) error, *`, message.Name, `) error`) g.P(`switch scanner {`) for _, scanner := range model.Scanner { scannerFuncName := `scan_` + message.GetName() if scanner.Name != "" { scannerFuncName += `_` + scanner.Name } g.P(`case `, strconv.Quote(scanner.Name), `:`) g.P(`query = `, scannerFuncName, `SQL + " " + query`) g.P(`scannerFunc = `, scannerFuncName) } g.P(`default:`) g.P(`if b.err == nil { b.err = fmt.Errorf("unknown scanner: %s", scanner) }`) g.P(`}`) g.P(`query = `, g.runtimePkg.Use(), `.CleanSQL(query)`) g.P(`stmt, err := b.db.Prepare(query)`) g.P(`if err != nil { if b.err == nil { b.err = err } }`) g.P(`return &`, unexport(*message.Name), `Stmt{stmt: stmt, query:query, scanner: scannerFunc}`) g.P(`}`) g.P(`func (b *`, unexport(*message.Name), `StmtBuilder) PrepareExecer( query string) (`, message.Name, `Execer) {`) g.P(`if b.err != nil { return nil }`) g.P(`query = `, g.runtimePkg.Use(), `.CleanSQL(query)`) g.P(`stmt, err := b.db.Prepare(query)`) g.P(`if err != nil { if b.err == nil { b.err = err } }`) g.P(`return &`, unexport(*message.Name), `Execer{stmt: stmt, query: query}`) g.P(`}`) g.P(`func (b *`, unexport(*message.Name), `StmtBuilder) Err() (error) {`) g.P(`return b.err`) g.P(`}`) g.P(`func (s *`, unexport(*message.Name), `Stmt) QueryRow(ctx `, g.contextPkg.Use(), `.Context, args ... interface{}) (`, message.Name, `Row) {`) g.P(`span, _ := `, g.tracePkg.Use(), `.New(ctx, "QueryRow("+s.query+")")`) g.P(`row := s.stmt.QueryRow(args...)`) g.P(`return &`, unexport(*message.Name), `Row{row: row, scanner: s.scanner, span: span}`) g.P(`}`) g.P(`func (s *`, unexport(*message.Name), `Stmt) Query(ctx `, g.contextPkg.Use(), `.Context, args ... interface{}) (`, message.Name, `Rows, error) {`) g.P(`span, _ := `, g.tracePkg.Use(), `.New(ctx, "Query("+s.query+")")`) g.P(`rows, err := s.stmt.Query(args...)`) g.P(`if err != nil {`) g.P(`span.Error(err)`) g.P(`span.Close()`) g.P(`return nil, err`) g.P(`}`) g.P(`return &`, unexport(*message.Name), `Rows{Rows: rows, scanner: s.scanner, span: span}, nil`) g.P(`}`) g.P(`func (s *`, unexport(*message.Name), `Stmt) SelectSlice(ctx `, g.contextPkg.Use(), `.Context, dst []*`, message.Name, `, args ... interface{}) ([]*`, message.Name, `, error) {`) g.P(`rows, err := s.Query(ctx, args...)`) g.P(`if err != nil { return nil, err }`) g.P(`defer rows.Close()`) g.P(`for rows.Next() {`) g.P(`var x = &`, message.Name, `{}`) g.P(`err := rows.Scan(x)`) g.P(`if err != nil { return nil, err }`) g.P(`dst = append(dst, x)`) g.P(`}`) g.P(`err = rows.Err()`) g.P(`if err != nil { return nil, err }`) g.P(`return dst, nil`) g.P(`}`) g.P(`func (s *`, unexport(*message.Name), `Stmt) SelectMessageSlice(ctx `, g.contextPkg.Use(), `.Context, dst []interface{}, args ... interface{}) ([]interface{}, error) {`) g.P(`rows, err := s.Query(ctx, args...)`) g.P(`if err != nil { return nil, err }`) g.P(`defer rows.Close()`) g.P(`for rows.Next() {`) g.P(`var x = &`, message.Name, `{}`) g.P(`err := rows.Scan(x)`) g.P(`if err != nil { return nil, err }`) g.P(`dst = append(dst, x)`) g.P(`}`) g.P(`err = rows.Err()`) g.P(`if err != nil { return nil, err }`) g.P(`return dst, nil`) g.P(`}`) g.P(`func (s *`, unexport(*message.Name), `Stmt) ForTx(tx *`, g.sqlPkg.Use(), `.Tx) `, message.Name, `Stmt {`) g.P(`return &`, unexport(*message.Name), `Stmt{stmt: tx.Stmt(s.stmt), scanner: s.scanner, query: s.query}`) g.P(`}`) g.P(`func (s *`, unexport(*message.Name), `Execer) Exec(ctx `, g.contextPkg.Use(), `.Context,args ... interface{}) (`, g.sqlPkg.Use(), `.Result, error) {`) g.P(`span, _ := `, g.tracePkg.Use(), `.New(ctx, "Exec("+s.query+")")`) g.P(`defer span.Close()`) g.P(`res, err := s.stmt.Exec(args...)`) g.P(`if err != nil { span.Error(err) }`) g.P(`return res, err`) g.P(`}`) g.P(`func (s *`, unexport(*message.Name), `Execer) ForTx(tx *`, g.sqlPkg.Use(), `.Tx) `, message.Name, `Execer {`) g.P(`return &`, unexport(*message.Name), `Execer{stmt: tx.Stmt(s.stmt), query: s.query}`) g.P(`}`) g.P(`func (r *`, unexport(*message.Name), `Row) Scan(out *`, message.Name, `) error {`) g.P(`defer r.span.Close()`) g.P(`err := r.scanner(r.row.Scan, out)`) g.P(`if err != nil { r.span.Error(err) }`) g.P(`return err`) g.P(`}`) g.P(`func (r *`, unexport(*message.Name), `Rows) Scan(out *`, message.Name, `) error {`) g.P(`err := r.scanner(r.Rows.Scan, out)`) g.P(`if err != nil { r.span.Error(err) }`) g.P(`return err`) g.P(`}`) g.P(`func (r *`, unexport(*message.Name), `Rows) Close() error {`) g.P(`defer r.span.Close()`) g.P(`err := r.Rows.Close()`) g.P(`if err != nil { r.span.Error(err) }`) g.P(`return err`) g.P(`}`) }
func (g *gensql) generateScanner(file *generator.FileDescriptor, message *generator.Descriptor, scanner *limbo.ScannerDescriptor) { scannerFuncName := `scan_` + message.GetName() if scanner.Name != "" { scannerFuncName += `_` + scanner.Name } joins := map[string]int{} g.P(``) g.P(`const `, scannerFuncName, `SQL = `, strconv.Quote(g.generateQueryPrefix(message, scanner))) g.P(`func `, scannerFuncName, `(scanFunc func(...interface{})error, dst *`, message.Name, `) error {`) g.P(`var (`) for i, column := range scanner.Column { m := g.models[column.MessageType] field := m.GetFieldDescriptor(lastField(column.FieldName)) if field.IsRepeated() { g.P(`b`, i, ` []byte`) continue } switch field.GetType() { case pb.FieldDescriptorProto_TYPE_BOOL: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullBool`) case pb.FieldDescriptorProto_TYPE_DOUBLE: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullFloat64`) case pb.FieldDescriptorProto_TYPE_FLOAT: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullFloat64`) case pb.FieldDescriptorProto_TYPE_FIXED32, pb.FieldDescriptorProto_TYPE_UINT32: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullInt64`) case pb.FieldDescriptorProto_TYPE_FIXED64, pb.FieldDescriptorProto_TYPE_UINT64: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullInt64`) case pb.FieldDescriptorProto_TYPE_SFIXED32, pb.FieldDescriptorProto_TYPE_INT32, pb.FieldDescriptorProto_TYPE_SINT32: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullInt64`) case pb.FieldDescriptorProto_TYPE_SFIXED64, pb.FieldDescriptorProto_TYPE_INT64, pb.FieldDescriptorProto_TYPE_SINT64: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullInt64`) case pb.FieldDescriptorProto_TYPE_BYTES: g.P(`b`, i, ` []byte`) case pb.FieldDescriptorProto_TYPE_STRING: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullString`) case pb.FieldDescriptorProto_TYPE_ENUM: g.P(`b`, i, ` `, g.sqlPkg.Use(), `.NullInt64`) case pb.FieldDescriptorProto_TYPE_MESSAGE: if field.GetTypeName() == ".google.protobuf.Timestamp" { g.P(`b`, i, ` `, g.mysqlPkg.Use(), `.NullTime`) } else if limbo.IsGoSQLValuer(g.objectNamed(field.GetTypeName()).(*generator.Descriptor)) { g.P(`b`, i, ` Null`, g.typeName(field.GetTypeName())) } else { g.P(`b`, i, ` []byte`) g.P(`m`, i, ` `, g.typeName(field.GetTypeName())) } default: panic("unsuported type: " + field.GetType().String()) } } for i, join := range scanner.Join { m := g.models[join.MessageType] field := m.GetFieldDescriptor(lastField(join.FieldName)) joins[join.FieldName] = i g.P(`j`, i, ` `, g.typeName(field.GetTypeName())) g.P(`j`, i, `Valid bool`) } g.P(`)`) g.P(`err := scanFunc(`) for i := range scanner.Column { g.P(`&b`, i, `,`) } g.P(`)`) g.P(`if err!=nil { return err }`) for i, column := range scanner.Column { var ( m = g.models[column.MessageType] field = m.GetFieldDescriptor(lastField(column.FieldName)) valid string dst = "dst" ) if column.JoinedWith != "" { dst = fmt.Sprintf("j%d", joins[column.JoinedWith]) } if field.IsRepeated() { valid = fmt.Sprintf(`b%d != nil`, i) } else { switch field.GetType() { case pb.FieldDescriptorProto_TYPE_MESSAGE: if field.GetTypeName() == ".google.protobuf.Timestamp" { valid = fmt.Sprintf(`b%d.Valid`, i) } else if limbo.IsGoSQLValuer(g.objectNamed(field.GetTypeName()).(*generator.Descriptor)) { valid = fmt.Sprintf(`b%d.Valid`, i) } else { valid = fmt.Sprintf(`b%d != nil`, i) } case pb.FieldDescriptorProto_TYPE_BYTES: valid = fmt.Sprintf(`b%d != nil`, i) default: valid = fmt.Sprintf(`b%d.Valid`, i) } } fieldName := g.gen.GetFieldName(m, field) g.P(`if `, valid, ` {`) if column.JoinedWith != "" { g.P(dst, `Valid = true`) } if field.IsRepeated() { g.P(`if err:= `, g.jsonPkg.Use(), `.Unmarshal(b`, i, `, &`, dst, `.`, fieldName, `); err!=nil {`) g.P(`return err`) g.P(`}`) } else { switch field.GetType() { case pb.FieldDescriptorProto_TYPE_BOOL: g.P(dst, `.`, fieldName, ` = b`, i, `.Bool`) case pb.FieldDescriptorProto_TYPE_DOUBLE: g.P(dst, `.`, fieldName, ` = float32(b`, i, `.Float64)`) case pb.FieldDescriptorProto_TYPE_FLOAT: g.P(dst, `.`, fieldName, ` = float32(b`, i, `.Float64)`) case pb.FieldDescriptorProto_TYPE_FIXED32, pb.FieldDescriptorProto_TYPE_UINT32: g.P(dst, `.`, fieldName, ` = uint32(b`, i, `.Int64)`) case pb.FieldDescriptorProto_TYPE_FIXED64, pb.FieldDescriptorProto_TYPE_UINT64: g.P(dst, `.`, fieldName, ` = uint64(b`, i, `.Int64)`) case pb.FieldDescriptorProto_TYPE_SFIXED32, pb.FieldDescriptorProto_TYPE_INT32, pb.FieldDescriptorProto_TYPE_SINT32: g.P(dst, `.`, fieldName, ` = int32(b`, i, `.Int64)`) case pb.FieldDescriptorProto_TYPE_SFIXED64, pb.FieldDescriptorProto_TYPE_INT64, pb.FieldDescriptorProto_TYPE_SINT64: g.P(dst, `.`, fieldName, ` = int64(b`, i, `.Int64)`) case pb.FieldDescriptorProto_TYPE_ENUM: g.P(dst, `.`, fieldName, ` = `, g.typeName(field.GetTypeName()), `(b`, i, `.Int64)`) case pb.FieldDescriptorProto_TYPE_BYTES: g.P(dst, `.`, fieldName, ` = b`, i, ``) case pb.FieldDescriptorProto_TYPE_STRING: g.P(dst, `.`, fieldName, ` = b`, i, `.String`) case pb.FieldDescriptorProto_TYPE_MESSAGE: if field.GetTypeName() == ".google.protobuf.Timestamp" { if gogoproto.IsNullable(field) { g.P(dst, `.`, fieldName, ` = &b`, i, `.Time`) } else { g.P(dst, `.`, fieldName, ` = b`, i, `.Time`) } } else if limbo.IsGoSQLValuer(g.objectNamed(field.GetTypeName()).(*generator.Descriptor)) { if gogoproto.IsNullable(field) { g.P(dst, `.`, fieldName, ` = &b`, i, `.`, g.typeName(field.GetTypeName())) } else { g.P(dst, `.`, fieldName, ` = b`, i, `.`, g.typeName(field.GetTypeName())) } } else { g.P(`err := m`, i, `.Unmarshal(b`, i, `)`) g.P(`if err!=nil { return err }`) if gogoproto.IsNullable(field) { g.P(dst, `.`, fieldName, ` = &m`, i, ``) } else { g.P(dst, `.`, fieldName, ` = m`, i, ``) } } } } g.P(`}`) } for i := len(scanner.Join) - 1; i >= 0; i-- { var ( join = scanner.Join[i] m = g.models[join.MessageType] field = m.GetFieldDescriptor(lastField(join.FieldName)) dst = "dst" ) if join.JoinedWith != "" { dst = fmt.Sprintf("j%d", joins[join.JoinedWith]) } fieldName := g.gen.GetFieldName(m, field) g.P(`if j`, i, `Valid {`) if gogoproto.IsNullable(field) { g.P(dst, `.`, fieldName, ` = &j`, i, ``) } else { g.P(dst, `.`, fieldName, ` = j`, i, ``) } g.P(`}`) } g.P(`return nil`) g.P(`}`) g.P(``) }
func (g *gensql) populateMessage(file *generator.FileDescriptor, msg *generator.Descriptor) { model := limbo.GetModel(msg) model.MessageType = "." + file.GetPackage() + "." + msg.GetName() { // default scanner var found bool for _, scanner := range model.Scanner { if scanner.Name == "" { found = true break } } if !found { model.Scanner = append(model.Scanner, &limbo.ScannerDescriptor{Fields: "*"}) } } for _, scanner := range model.Scanner { scanner.MessageType = "." + file.GetPackage() + "." + msg.GetName() } for _, field := range msg.GetField() { if column := limbo.GetColumn(field); column != nil { column.MessageType = "." + file.GetPackage() + "." + msg.GetName() column.FieldName = field.GetName() if column.Name == "" { column.Name = field.GetName() } model.Column = append(model.Column, column) } if join := limbo.GetJoin(field); join != nil { if field.GetType() != pb.FieldDescriptorProto_TYPE_MESSAGE { g.gen.Fail(field.GetName(), "in", msg.GetName(), "must be a message") } join.MessageType = "." + file.GetPackage() + "." + msg.GetName() join.FieldName = field.GetName() join.ForeignMessageType = field.GetTypeName() if join.Name == "" { join.Name = field.GetName() } if join.Key == "" { join.Key = field.GetName() + "_id" } if join.ForeignKey == "" { join.ForeignKey = "id" } model.Join = append(model.Join, join) } } sort.Sort(limbo.SortedColumnDescriptors(model.Column)) sort.Sort(limbo.SortedJoinDescriptors(model.Join)) sort.Sort(limbo.SortedScannerDescriptors(model.Scanner)) }
func (g *svcauth) getMessage(inputType *generator.Descriptor, path, input, output string, inputIsNullable bool) { var ( checks []string goPath string isNullable = inputIsNullable ) if path == "." { g.P(output, ` = `, input) return } goPath = input if inputIsNullable { checks = append(checks, input+" != nil") } for path != "" { // split path part := path idx := strings.IndexByte(path, '.') if idx >= 0 { part = path[:idx] path = path[idx+1:] } else { path = "" } // Get Field field := inputType.GetFieldDescriptor(part) if field == nil { g.gen.Fail("unknown field", strconv.Quote(part), "in message", inputType.GetName()) } if !field.IsMessage() { g.gen.Fail("expected a message") } // Append code fieldGoName := g.gen.GetFieldName(inputType, field) goPath += "." + fieldGoName if gogoproto.IsNullable(field) { checks = append(checks, goPath+" != nil") isNullable = true } else { isNullable = false } inputType = g.messages[strings.TrimPrefix(field.GetTypeName(), ".")] } if len(checks) > 0 { g.P(`if `, strings.Join(checks, " && "), `{`) if isNullable { g.P(output, ` = `, goPath) } else { g.P(output, ` = &`, goPath) } g.P(`}`) } else { if isNullable { g.P(output, ` = `, goPath) } else { g.P(output, ` = &`, goPath) } } }