func TestForeignKey(t *testing.T) { t1 := pqt.NewTable("left") c11 := pqt.NewColumn("id", pqt.TypeSerialBig()) c12 := pqt.NewColumn("name", pqt.TypeText()) t1.AddColumn(c11) t1.AddColumn(c12) t2 := pqt.NewTable("right") c21 := pqt.NewColumn("id", pqt.TypeSerialBig()) c22 := pqt.NewColumn("name", pqt.TypeText()) t2.AddColumn(c21) t2.AddColumn(c22) cstr := pqt.ForeignKey(t1, pqt.Columns{c11, c12}, pqt.Columns{c21, c22}) if cstr.Type != pqt.ConstraintTypeForeignKey { t.Errorf("wrong type, expected %s but got %s", pqt.ConstraintTypeForeignKey, cstr.Type) } if len(cstr.Columns) != 2 { t.Errorf("wrong number of columns, expected %d but got %d", 2, len(cstr.Columns)) } if len(cstr.ReferenceColumns) != 2 { t.Errorf("wrong number of columns, expected %d but got %d", 2, len(cstr.ReferenceColumns)) } if !reflect.DeepEqual(cstr.Table, t1) { t.Errorf("table does not match, expected %v but got %v", t1, cstr.Table) } if !reflect.DeepEqual(cstr.ReferenceTable, t2) { t.Errorf("reference table does not match, expected %v but got %v", t2, cstr.ReferenceTable) } }
// schema builds a fixture schema named sn containing four tables: category,
// package, news and comment (added to the schema in that order). Every table
// and the schema itself are created with their "if not exists" option.
func schema(sn string) *pqt.Schema {
	// title and lead are kept in variables because they are reused below:
	// both take part in the news unique constraint and title is the target
	// of the comment.news_title reference.
	title := pqt.NewColumn("title", pqt.TypeText(), pqt.WithNotNull(), pqt.WithUnique())
	lead := pqt.NewColumn("lead", pqt.TypeText())

	news := pqt.NewTable("news", pqt.WithTableIfNotExists()).
		AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())).
		AddColumn(title).
		AddColumn(lead).
		AddColumn(pqt.NewColumn("continue", pqt.TypeBool(), pqt.WithNotNull(), pqt.WithDefault("false"))).
		AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())).
		AddUnique(title, lead)

	// comment.id is declared without the primary key option, unlike the id
	// columns of the other tables.
	comment := pqt.NewTable("comment", pqt.WithTableIfNotExists()).
		AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig())).
		AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())).
		AddColumn(pqt.NewColumn(
			"news_title",
			pqt.TypeText(),
			pqt.WithNotNull(),
			pqt.WithReference(title, pqt.WithBidirectional(), pqt.WithOwnerName("comments_by_news_title"), pqt.WithInversedName("news_by_title")),
		))

	// category references itself through a one-to-many relationship that is
	// configured with the column name "parent_id".
	category := pqt.NewTable("category", pqt.WithTableIfNotExists()).
		AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())).
		AddColumn(pqt.NewColumn("name", pqt.TypeText(), pqt.WithNotNull())).
		AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())).
		AddRelationship(
			pqt.OneToMany(
				pqt.SelfReference(),
				pqt.WithBidirectional(),
				pqt.WithInversedName("child_category"),
				pqt.WithOwnerName("parent_category"),
				pqt.WithColumnName("parent_id"),
			),
		)

	pkg := pqt.NewTable("package", pqt.WithTableIfNotExists()).
		AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())).
		AddColumn(pqt.NewColumn("break", pqt.TypeText())).
		AddRelationship(pqt.ManyToOne(
			category,
			pqt.WithBidirectional(),
		))

	// timestampable is defined outside this block.
	// NOTE(review): presumably it adds timestamp columns to the table — confirm.
	timestampable(news)
	timestampable(comment)
	timestampable(category)
	timestampable(pkg)

	// Second link from comment to news, in addition to the news_title
	// reference declared above.
	comment.AddRelationship(pqt.ManyToOne(news, pqt.WithBidirectional(), pqt.WithInversedName("news_by_id")), pqt.WithNotNull())

	// The result of ManyToMany is discarded.
	// NOTE(review): presumably the call registers the relationship on the
	// tables it receives as a side effect — confirm against pqt.
	pqt.ManyToMany(category, news, pqt.WithBidirectional())

	return pqt.NewSchema(sn, pqt.WithSchemaIfNotExists()).
		AddTable(category).
		AddTable(pkg).
		AddTable(news).
		AddTable(comment)
}
func TestTable_AddColumn(t *testing.T) { c0 := pqt.NewColumn("c0", pqt.TypeSerialBig(), pqt.WithPrimaryKey()) c1 := &pqt.Column{Name: "c1"} c2 := &pqt.Column{Name: "c2"} c3 := &pqt.Column{Name: "c3"} tbl := pqt.NewTable("test"). AddColumn(c0). AddColumn(c1). AddColumn(c2). AddColumn(c3). AddColumn(pqt.NewColumn("c4", pqt.TypeIntegerBig(), pqt.WithReference(c0))). AddRelationship(pqt.ManyToOne(pqt.SelfReference())) if len(tbl.Columns) != 6 { t.Errorf("wrong number of colums, expected %d but got %d", 6, len(tbl.Columns)) } if len(tbl.OwnedRelationships) != 2 { // Reference is not a relationship t.Errorf("wrong number of owned relationships, expected %d but got %d", 2, len(tbl.OwnedRelationships)) } for i, c := range tbl.Columns { if c.Name == "" { t.Errorf("column #%d table name is empty", i) } if c.Table == nil { t.Errorf("column #%d table nil pointer", i) } } }
func generateBaseType(t pqt.Type, m int32) string { switch t { case pqt.TypeText(): return chooseType("string", "*ntypes.String", "*qtypes.String", m) case pqt.TypeBool(): return chooseType("bool", "*ntypes.Bool", "*ntypes.Bool", m) case pqt.TypeIntegerSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeInteger(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeIntegerBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeSerial(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeSerialSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeSerialBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): return chooseType("time.Time", "*time.Time", "*qtypes.Timestamp", m) case pqt.TypeReal(): return chooseType("float32", "*ntypes.Float32", "*ntypes.Float32", m) case pqt.TypeDoublePrecision(): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case pqt.TypeBytea(), pqt.TypeJSON(), pqt.TypeJSONB(): return "[]byte" case pqt.TypeUUID(): return "uuid.UUID" default: gt := t.String() switch { case strings.HasPrefix(gt, "SMALLINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "INTEGER["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "BIGINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "DOUBLE PRECISION["): return chooseType("pqt.ArrayFloat64", "pqt.ArrayFloat64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "TEXT["): return "pqt.ArrayString" case strings.HasPrefix(gt, "DECIMAL"), strings.HasPrefix(gt, "NUMERIC"): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "VARCHAR"): return chooseType("string", "*ntypes.String", "*qtypes.String", m) default: 
return "interface{}" } } }
func (g *Generator) generateRepositoryUpsert(code *bytes.Buffer, table *pqt.Table) { if g.ver < 9.5 { return } entityName := g.name(table.Name) fmt.Fprintf(code, `func (r *%sRepositoryBase) %s(e *%sEntity, p *%sPatch, inf ...string) (*%sEntity, error) {`, entityName, g.name("Upsert"), entityName, entityName, entityName, ) fmt.Fprintf(code, ` insert := pqcomp.New(0, %d) update := insert.Compose(%d) `, len(table.Columns), len(table.Columns)) InsertLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue InsertLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(code, ` if e.%s != nil { insert.AddExpr(%s, "", e.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf(code, `insert.AddExpr(%s, "", e.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } fmt.Fprintln(code, "") } } fmt.Fprintln(code, "if len(inf) > 0 {") UpdateLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue UpdateLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(code, ` if p.%s != nil { update.AddExpr(%s, "=", p.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf(code, `update.AddExpr(%s, "=", p.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name)) } fmt.Fprintln(code, "") } } fmt.Fprintln(code, "}") fmt.Fprint(code, ` b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() > 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") } b.WriteString(" ON CONFLICT ") if len(inf) > 0 && update.Len() > 
0 { b.WriteString(" (") for j, i := range inf { if j != 0 { b.WriteString(", ") } b.WriteString(i) } b.WriteString(") ") b.WriteString(" DO UPDATE SET ") for update.Next() { if !update.First() { b.WriteString(", ") } b.WriteString(update.Key()) b.WriteString(" ") b.WriteString(update.Oper()) b.WriteString(" ") b.WriteString(update.PlaceHolder()) } } else { b.WriteString(" DO NOTHING ") } if insert.Len() > 0 { if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Upsert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( `) for _, c := range table.Columns { fmt.Fprintf(code, "&e.%s,\n", g.propertyName(c.Name)) } fmt.Fprint(code, `) if err != nil { return nil, err } return e, nil } `) }
func (g *Generator) generateRepositoryInsert(w io.Writer, table *pqt.Table) { entityName := g.name(table.Name) fmt.Fprintf(w, `func (r *%sRepositoryBase) %s(e *%sEntity) (*%sEntity, error) {`, entityName, g.name("Insert"), entityName, entityName) fmt.Fprintf(w, ` insert := pqcomp.New(0, %d) `, len(table.Columns)) ColumnsLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue ColumnsLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(w, ` if e.%s != nil { insert.AddExpr(%s, "", e.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf( w, `insert.AddExpr(%s, "", e.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } fmt.Fprintln(w, "") } } fmt.Fprint(w, ` b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() != 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Insert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( `) for _, c := range table.Columns { fmt.Fprintf(w, "&e.%s,\n", g.propertyName(c.Name)) } fmt.Fprint(w, `) if err != nil { return nil, err } return e, nil } `) }
// TestGenerator_Generate runs Generator.Generate over a set of named cases and
// compares the produced source with a golden string using assertGoCode
// (defined outside this block). Each case pairs a schema with a configured
// generator and the expected output.
func TestGenerator_Generate(t *testing.T) {
	cases := map[string]struct {
		schema    *pqt.Schema
		generator *pqtgo.Generator
		expected  string
	}{
		// Empty schema, default package, one additional import.
		"basic": {
			schema: pqt.NewSchema("text"),
			generator: pqtgo.NewGenerator().
				AddImport("github.com/piotrkowalczuk/ntypes"),
			expected: `package main import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" "github.com/piotrkowalczuk/ntypes" ) `,
		},
		// Empty schema, custom package name, one additional import.
		"custom-package": {
			schema: pqt.NewSchema("text"),
			generator: pqtgo.NewGenerator().
				SetPackage("example").
				AddImport("fmt"),
			expected: `package example import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" "fmt" ) `,
		},
		// Single table with a serial id and a text name. The golden output
		// covers constants, entity, iterator, criteria, patch and the
		// repository methods.
		"simple table": {
			schema: pqt.NewSchema("text").AddTable(
				pqt.NewTable("first").AddColumn(
					pqt.NewColumn("id", pqt.TypeSerialBig()),
				).AddColumn(
					pqt.NewColumn("name", pqt.TypeText()),
				),
			),
			generator: pqtgo.NewGenerator().
				SetPackage("custom"),
			expected: `package custom import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" ) const ( tableFirst = "text.first" tableFirstColumnId = "id" tableFirstColumnName = "name" ) var ( tableFirstColumns = []string{ tableFirstColumnId, tableFirstColumnName, }) type firstEntity struct{ // id ... id *ntypes.Int64 // name ... name *ntypes.String } func (e *firstEntity) prop(cn string) (interface{}, bool) { switch cn { case tableFirstColumnId: return &e.id, true case tableFirstColumnName: return &e.name, true default: return nil, false } } func (e *firstEntity) props(cns ...string) ([]interface{}, error) { res := make([]interface{}, 0, len(cns)) for _, cn := range cns { if prop, ok := e.prop(cn); ok { res = append(res, prop) } else { return nil, fmt.Errorf("unexpected column provided: %s", cn) } } return res, nil } // firstIterator is not thread safe. 
type firstIterator struct { rows *sql.Rows cols []string } func (i *firstIterator) Next() bool { return i.rows.Next() } func (i *firstIterator) Close() error { return i.rows.Close() } func (i *firstIterator) Err() error { return i.rows.Err() } // Columns is wrapper around sql.Rows.Columns method, that also cache outpu inside iterator. func (i *firstIterator) Columns() ([]string, error) { if i.cols == nil { cols, err := i.rows.Columns() if err != nil { return nil, err } i.cols = cols } return i.cols, nil } // Ent is wrapper around first method that makes iterator more generic. func (i *firstIterator) Ent() (interface{}, error) { return i.First() } func (i *firstIterator) First() (*firstEntity, error) { var ent firstEntity cols, err := i.rows.Columns() if err != nil { return nil, err } props, err := ent.props(cols...) if err != nil { return nil, err } if err := i.rows.Scan(props...); err != nil { return nil, err } return &ent, nil } type firstCriteria struct { offset, limit int64 sort map[string]bool id *qtypes.Int64 name *qtypes.String } func (c *firstCriteria) WriteComposition(sel string, com *pqtgo.Composer, opt *pqtgo.CompositionOpts) (err error) { if err = pqtgo.WriteCompositionQueryInt64(c.id, tableFirstColumnId, com, pqtgo.And); err != nil { return } if err = pqtgo.WriteCompositionQueryString(c.name, tableFirstColumnName, com, pqtgo.And); err != nil { return } if len(c.sort) > 0 { i:=0 com.WriteString(" ORDER BY ") for cn, asc := range c.sort { for _, tcn := range tableFirstColumns { if cn == tcn { if i > 0 { com.WriteString(", ") } com.WriteString(cn) if !asc { com.WriteString(" DESC ") } i++ break } } } } if c.offset > 0 { if _, err = com.WriteString(" OFFSET "); err != nil { return } if err = com.WritePlaceholder(); err != nil { return } if _, err = com.WriteString(" "); err != nil { return } com.Add(c.offset) } if c.limit > 0 { if _, err = com.WriteString(" LIMIT "); err != nil { return } if err = com.WritePlaceholder(); err != nil { return } if _, err = 
com.WriteString(" "); err != nil { return } com.Add(c.limit) } return } type firstPatch struct { id *ntypes.Int64 name *ntypes.String } type firstRepositoryBase struct { table string columns []string db *sql.DB dbg bool log log.Logger } func scanFirstRows(rows *sql.Rows) ([]*firstEntity, error) { var ( entities []*firstEntity err error ) for rows.Next() { var ent firstEntity err = rows.Scan( &ent.id, &ent.name, ) if err != nil { return nil, err } entities = append(entities, &ent) } if rows.Err() != nil { return nil, rows.Err() } return entities, nil } func (r *firstRepositoryBase) count(c *firstCriteria) (int64, error) { com := pqtgo.NewComposer(2) buf := bytes.NewBufferString("SELECT COUNT(*) FROM ") buf.WriteString(r.table) if err := c.WriteComposition("", com, pqtgo.And); err != nil { return 0, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Count"); err != nil { return 0, err } } var count int64 if err := r.db.QueryRow(buf.String(), com.Args()...).Scan(&count); err != nil { return 0, err } return count, nil } func (r *firstRepositoryBase) find(c *firstCriteria) ([]*firstEntity, error) { com := pqtgo.NewComposer(1) buf := bytes.NewBufferString("SELECT ") buf.WriteString(strings.Join(r.columns, ", ")) buf.WriteString(" FROM ") buf.WriteString(r.table) buf.WriteString(" ") if err := c.WriteComposition("", com, pqtgo.And); err != nil { return nil, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Find"); err != nil { return nil, err } } rows, err := r.db.Query(buf.String(), com.Args()...) 
if err != nil { return nil, err } defer rows.Close() return scanFirstRows(rows) } func (r *firstRepositoryBase) findIter(c *firstCriteria) (*firstIterator, error) { com := pqtgo.NewComposer(1) buf := bytes.NewBufferString("SELECT ") buf.WriteString(strings.Join(r.columns, ", ")) buf.WriteString(" FROM ") buf.WriteString(r.table) buf.WriteString(" ") if err := c.WriteComposition("", com, pqtgo.And); err != nil { return nil, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Find"); err != nil { return nil, err } } rows, err := r.db.Query(buf.String(), com.Args()...) if err != nil { return nil, err } return &firstIterator{rows: rows}, nil } func (r *firstRepositoryBase) insert(e *firstEntity) (*firstEntity, error) { insert := pqcomp.New(0, 2) insert.AddExpr(tableFirstColumnName, "", e.name) b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() != 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Insert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( &e.id, &e.name, ) if err != nil { return nil, err } return e, nil } func (r *firstRepositoryBase) upsert(e *firstEntity, p *firstPatch, inf ...string) (*firstEntity, error) { insert := pqcomp.New(0, 2) update := insert.Compose(2) insert.AddExpr(tableFirstColumnName, "", e.name) if len(inf) > 0 { update.AddExpr(tableFirstColumnName, "=", p.name) } b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() > 0 { b.WriteString(" (") for insert.Next() { 
if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") } b.WriteString(" ON CONFLICT ") if len(inf) > 0 && update.Len() > 0 { b.WriteString(" (") for j, i := range inf { if j != 0 { b.WriteString(", ") } b.WriteString(i) } b.WriteString(") ") b.WriteString(" DO UPDATE SET ") for update.Next() { if !update.First() { b.WriteString(", ") } b.WriteString(update.Key()) b.WriteString(" ") b.WriteString(update.Oper()) b.WriteString(" ") b.WriteString(update.PlaceHolder()) } } else { b.WriteString(" DO NOTHING ") } if insert.Len() > 0 { if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Upsert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( &e.id, &e.name, ) if err != nil { return nil, err } return e, nil } `,
		},
	}

	// cases is a map, so the order in which the cases run is not fixed.
	for hint, c := range cases {
		b, err := c.generator.Generate(c.schema)
		if err != nil {
			// Report the failure and move on so the remaining cases still run.
			t.Errorf("%s: unexpected error: %s", hint, err.Error())
			continue
		}
		assertGoCode(t, c.expected, string(b), hint)
	}
}