func TestForeignKey(t *testing.T) { t1 := pqt.NewTable("left") c11 := pqt.NewColumn("id", pqt.TypeSerialBig()) c12 := pqt.NewColumn("name", pqt.TypeText()) t1.AddColumn(c11) t1.AddColumn(c12) t2 := pqt.NewTable("right") c21 := pqt.NewColumn("id", pqt.TypeSerialBig()) c22 := pqt.NewColumn("name", pqt.TypeText()) t2.AddColumn(c21) t2.AddColumn(c22) cstr := pqt.ForeignKey(t1, pqt.Columns{c11, c12}, pqt.Columns{c21, c22}) if cstr.Type != pqt.ConstraintTypeForeignKey { t.Errorf("wrong type, expected %s but got %s", pqt.ConstraintTypeForeignKey, cstr.Type) } if len(cstr.Columns) != 2 { t.Errorf("wrong number of columns, expected %d but got %d", 2, len(cstr.Columns)) } if len(cstr.ReferenceColumns) != 2 { t.Errorf("wrong number of columns, expected %d but got %d", 2, len(cstr.ReferenceColumns)) } if !reflect.DeepEqual(cstr.Table, t1) { t.Errorf("table does not match, expected %v but got %v", t1, cstr.Table) } if !reflect.DeepEqual(cstr.ReferenceTable, t2) { t.Errorf("reference table does not match, expected %v but got %v", t2, cstr.ReferenceTable) } }
func schema(sn string) *pqt.Schema { title := pqt.NewColumn("title", pqt.TypeText(), pqt.WithNotNull(), pqt.WithUnique()) lead := pqt.NewColumn("lead", pqt.TypeText()) news := pqt.NewTable("news", pqt.WithTableIfNotExists()). AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())). AddColumn(title). AddColumn(lead). AddColumn(pqt.NewColumn("continue", pqt.TypeBool(), pqt.WithNotNull(), pqt.WithDefault("false"))). AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())). AddUnique(title, lead) comment := pqt.NewTable("comment", pqt.WithTableIfNotExists()). AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig())). AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())). AddColumn(pqt.NewColumn( "news_title", pqt.TypeText(), pqt.WithNotNull(), pqt.WithReference(title, pqt.WithBidirectional(), pqt.WithOwnerName("comments_by_news_title"), pqt.WithInversedName("news_by_title")), )) category := pqt.NewTable("category", pqt.WithTableIfNotExists()). AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())). AddColumn(pqt.NewColumn("name", pqt.TypeText(), pqt.WithNotNull())). AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())). AddRelationship( pqt.OneToMany( pqt.SelfReference(), pqt.WithBidirectional(), pqt.WithInversedName("child_category"), pqt.WithOwnerName("parent_category"), pqt.WithColumnName("parent_id"), ), ) pkg := pqt.NewTable("package", pqt.WithTableIfNotExists()). AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())). AddColumn(pqt.NewColumn("break", pqt.TypeText())). AddRelationship(pqt.ManyToOne( category, pqt.WithBidirectional(), )) timestampable(news) timestampable(comment) timestampable(category) timestampable(pkg) comment.AddRelationship(pqt.ManyToOne(news, pqt.WithBidirectional(), pqt.WithInversedName("news_by_id")), pqt.WithNotNull()) pqt.ManyToMany(category, news, pqt.WithBidirectional()) return pqt.NewSchema(sn, pqt.WithSchemaIfNotExists()). 
AddTable(category). AddTable(pkg). AddTable(news). AddTable(comment) }
func generateBaseType(t pqt.Type, m int32) string { switch t { case pqt.TypeText(): return chooseType("string", "*ntypes.String", "*qtypes.String", m) case pqt.TypeBool(): return chooseType("bool", "*ntypes.Bool", "*ntypes.Bool", m) case pqt.TypeIntegerSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeInteger(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeIntegerBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeSerial(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeSerialSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeSerialBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): return chooseType("time.Time", "*time.Time", "*qtypes.Timestamp", m) case pqt.TypeReal(): return chooseType("float32", "*ntypes.Float32", "*ntypes.Float32", m) case pqt.TypeDoublePrecision(): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case pqt.TypeBytea(), pqt.TypeJSON(), pqt.TypeJSONB(): return "[]byte" case pqt.TypeUUID(): return "uuid.UUID" default: gt := t.String() switch { case strings.HasPrefix(gt, "SMALLINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "INTEGER["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "BIGINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "DOUBLE PRECISION["): return chooseType("pqt.ArrayFloat64", "pqt.ArrayFloat64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "TEXT["): return "pqt.ArrayString" case strings.HasPrefix(gt, "DECIMAL"), strings.HasPrefix(gt, "NUMERIC"): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "VARCHAR"): return chooseType("string", "*ntypes.String", "*qtypes.String", m) default: 
return "interface{}" } } }
func TestNewColumn(t *testing.T) { collate := "UTF-7" check := "username = '******'" r := pqt.NewColumn("username", pqt.TypeText()) c := pqt.NewColumn( "user_username", pqt.TypeText(), pqt.WithCollate(collate), pqt.WithCheck(check), pqt.WithDefault("janusz"), pqt.WithUnique(), pqt.WithTypeMapping(pqtgo.BuiltinType(types.Byte)), pqt.WithNotNull(), pqt.WithPrimaryKey(), pqt.WithReference(r), ) if c.Type.String() != pqt.TypeText().String() { t.Errorf("wrong column type, expected %s but got %s", pqt.TypeText().String(), c.Type.String()) } if c.Collate != collate { t.Errorf("wrong column collate, expected %s but got %s", collate, c.Collate) } if c.Check != check { t.Errorf("wrong column check, expected %s but got %s", check, c.Check) } if d, ok := c.Default[pqt.EventInsert]; ok && d != "janusz" { t.Errorf("wrong column default, expected %s but got %s", "janusz", d) } if !c.Unique { t.Error("wrong column unique, expected true but got false") } if !c.NotNull { t.Error("wrong column not null, expected true but got false") } if !c.PrimaryKey { t.Error("wrong column primary key, expected true but got false") } if c.Reference != r { t.Errorf("wrong column reference, expected %p but got %p", r, c.Reference) } constraints := c.Constraints() if len(constraints) != 3 { t.Errorf("wrong number of constraints, expected 3 but got %d", len(constraints)) } var hasPK, hasFK, hasCH bool for _, constraint := range constraints { switch constraint.Type { case pqt.ConstraintTypePrimaryKey: hasPK = true case pqt.ConstraintTypeForeignKey: hasFK = true case pqt.ConstraintTypeCheck: hasCH = true } } if !hasPK { t.Errorf("mising primary key constraint") } if !hasFK { t.Errorf("mising foreign key constraint") } if !hasCH { t.Errorf("mising check constraint") } }
// TestGenerator_Generate checks the DDL emitted by pqtsql.Generator for
// single-table schemas. Each entry in success pairs a table built in place
// with the exact SQL expected from Generate; on mismatch the test reports the
// expected and actual output.
//
// Cases:
//   - a temporary table attached to schema "schema" with three columns;
//   - an IF NOT EXISTS table exercising primary key, unique, foreign key
//     (rel_id references column id of the throwaway "related_table", which
//     exists only to give that column an owning table), a multi-column check,
//     a DECIMAL(10,1) type and a NOW() default.
//
// The expected output lists columns in alphabetical order, not in the order
// they are added.
func TestGenerator_Generate(t *testing.T) { success := []struct { expected string given *pqt.Table }{ { expected: `-- do not modify, generated by pqt CREATE TEMPORARY TABLE schema.user ( created_at TIMESTAMPTZ, password TEXT, username TEXT NOT NULL ); `, given: func() *pqt.Table { return pqt.NewTable("user", pqt.WithTemporary()). SetSchema(pqt.NewSchema("schema")). AddColumn(&pqt.Column{Name: "username", Type: pqt.TypeText(), NotNull: true}). AddColumn(&pqt.Column{Name: "password", Type: pqt.TypeText()}). AddColumn(&pqt.Column{Name: "created_at", Type: pqt.TypeTimestampTZ()}) }(), }, { expected: `-- do not modify, generated by pqt CREATE TABLE IF NOT EXISTS table_name ( created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, created_by INTEGER NOT NULL, enabled BOOL, end_at TIMESTAMPTZ NOT NULL, id SERIAL, name TEXT, price DECIMAL(10,1), rel_id INTEGER, slug TEXT NOT NULL, start_at TIMESTAMPTZ NOT NULL, updated_at TIMESTAMPTZ, updated_by INTEGER, CONSTRAINT "public.table_name_id_pkey" PRIMARY KEY (id), CONSTRAINT "public.table_name_name_key" UNIQUE (name), CONSTRAINT "public.table_name_slug_key" UNIQUE (slug), CONSTRAINT "public.table_name_rel_id_fkey" FOREIGN KEY (rel_id) REFERENCES related_table (id), CONSTRAINT "public.table_name_start_at_end_at_check" CHECK ((start_at IS NULL AND end_at IS NULL) OR start_at < end_at) ); `, given: func() *pqt.Table { id := pqt.Column{Name: "id", Type: pqt.TypeSerial()} startAt := &pqt.Column{Name: "start_at", Type: pqt.TypeTimestampTZ(), NotNull: true} endAt := &pqt.Column{Name: "end_at", Type: pqt.TypeTimestampTZ(), NotNull: true} _ = pqt.NewTable("related_table"). AddColumn(&id) return pqt.NewTable("table_name", pqt.WithTableIfNotExists()). AddColumn(&pqt.Column{Name: "id", Type: pqt.TypeSerial(), PrimaryKey: true}). AddColumn(&pqt.Column{Name: "rel_id", Type: pqt.TypeInteger(), Reference: &id}). AddColumn(&pqt.Column{Name: "name", Type: pqt.TypeText(), Unique: true}). AddColumn(&pqt.Column{Name: "enabled", Type: pqt.TypeBool()}). 
AddColumn(&pqt.Column{Name: "price", Type: pqt.TypeDecimal(10, 1)}). AddColumn(startAt). AddColumn(endAt). AddColumn(pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))). AddColumn(&pqt.Column{Name: "created_by", Type: pqt.TypeInteger(), NotNull: true}). AddColumn(&pqt.Column{Name: "updated_at", Type: pqt.TypeTimestampTZ()}). AddColumn(&pqt.Column{Name: "updated_by", Type: pqt.TypeInteger()}). AddColumn(&pqt.Column{Name: "slug", Type: pqt.TypeText(), NotNull: true, Unique: true}). AddCheck("(start_at IS NULL AND end_at IS NULL) OR start_at < end_at", startAt, endAt) }(), }, } for i, data := range success { q, err := pqtsql.NewGenerator().Generate(&pqt.Schema{ Tables: []*pqt.Table{data.given}, }) if err != nil { t.Errorf("unexpected error for schema #%d : %s", i, err.Error()) continue } if string(q) != data.expected { t.Errorf("wrong query, expected:\n'%s'\nbut got:\n'%s'", data.expected, q) } } }
// TestGenerator_Generate checks the Go source emitted by pqtgo.Generator.
// Each named case pairs a schema and a configured generator with the
// expected output. The comparison is delegated to assertGoCode (defined
// elsewhere), which receives the case name as a hint.
//
// Cases:
//   - "basic": empty schema "text" with an extra ntypes import;
//   - "custom-package": empty schema with SetPackage("example") and an
//     extra "fmt" import;
//   - "simple table": schema "text" with table "first" (id BIGSERIAL,
//     name TEXT). The expected output contains the generated constants,
//     entity, iterator, criteria, patch and repository-base code for that
//     table.
//
// A generator error is reported per case and that case is skipped.
func TestGenerator_Generate(t *testing.T) { cases := map[string]struct { schema *pqt.Schema generator *pqtgo.Generator expected string }{ "basic": { schema: pqt.NewSchema("text"), generator: pqtgo.NewGenerator(). AddImport("github.com/piotrkowalczuk/ntypes"), expected: `package main import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" "github.com/piotrkowalczuk/ntypes" ) `, }, "custom-package": { schema: pqt.NewSchema("text"), generator: pqtgo.NewGenerator(). SetPackage("example"). AddImport("fmt"), expected: `package example import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" "fmt" ) `, }, "simple table": { schema: pqt.NewSchema("text").AddTable( pqt.NewTable("first").AddColumn( pqt.NewColumn("id", pqt.TypeSerialBig()), ).AddColumn( pqt.NewColumn("name", pqt.TypeText()), ), ), generator: pqtgo.NewGenerator(). SetPackage("custom"), expected: `package custom import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" ) const ( tableFirst = "text.first" tableFirstColumnId = "id" tableFirstColumnName = "name" ) var ( tableFirstColumns = []string{ tableFirstColumnId, tableFirstColumnName, }) type firstEntity struct{ // id ... id *ntypes.Int64 // name ... name *ntypes.String } func (e *firstEntity) prop(cn string) (interface{}, bool) { switch cn { case tableFirstColumnId: return &e.id, true case tableFirstColumnName: return &e.name, true default: return nil, false } } func (e *firstEntity) props(cns ...string) ([]interface{}, error) { res := make([]interface{}, 0, len(cns)) for _, cn := range cns { if prop, ok := e.prop(cn); ok { res = append(res, prop) } else { return nil, fmt.Errorf("unexpected column provided: %s", cn) } } return res, nil } // firstIterator is not thread safe. 
type firstIterator struct { rows *sql.Rows cols []string } func (i *firstIterator) Next() bool { return i.rows.Next() } func (i *firstIterator) Close() error { return i.rows.Close() } func (i *firstIterator) Err() error { return i.rows.Err() } // Columns is wrapper around sql.Rows.Columns method, that also cache outpu inside iterator. func (i *firstIterator) Columns() ([]string, error) { if i.cols == nil { cols, err := i.rows.Columns() if err != nil { return nil, err } i.cols = cols } return i.cols, nil } // Ent is wrapper around first method that makes iterator more generic. func (i *firstIterator) Ent() (interface{}, error) { return i.First() } func (i *firstIterator) First() (*firstEntity, error) { var ent firstEntity cols, err := i.rows.Columns() if err != nil { return nil, err } props, err := ent.props(cols...) if err != nil { return nil, err } if err := i.rows.Scan(props...); err != nil { return nil, err } return &ent, nil } type firstCriteria struct { offset, limit int64 sort map[string]bool id *qtypes.Int64 name *qtypes.String } func (c *firstCriteria) WriteComposition(sel string, com *pqtgo.Composer, opt *pqtgo.CompositionOpts) (err error) { if err = pqtgo.WriteCompositionQueryInt64(c.id, tableFirstColumnId, com, pqtgo.And); err != nil { return } if err = pqtgo.WriteCompositionQueryString(c.name, tableFirstColumnName, com, pqtgo.And); err != nil { return } if len(c.sort) > 0 { i:=0 com.WriteString(" ORDER BY ") for cn, asc := range c.sort { for _, tcn := range tableFirstColumns { if cn == tcn { if i > 0 { com.WriteString(", ") } com.WriteString(cn) if !asc { com.WriteString(" DESC ") } i++ break } } } } if c.offset > 0 { if _, err = com.WriteString(" OFFSET "); err != nil { return } if err = com.WritePlaceholder(); err != nil { return } if _, err = com.WriteString(" "); err != nil { return } com.Add(c.offset) } if c.limit > 0 { if _, err = com.WriteString(" LIMIT "); err != nil { return } if err = com.WritePlaceholder(); err != nil { return } if _, err = 
com.WriteString(" "); err != nil { return } com.Add(c.limit) } return } type firstPatch struct { id *ntypes.Int64 name *ntypes.String } type firstRepositoryBase struct { table string columns []string db *sql.DB dbg bool log log.Logger } func scanFirstRows(rows *sql.Rows) ([]*firstEntity, error) { var ( entities []*firstEntity err error ) for rows.Next() { var ent firstEntity err = rows.Scan( &ent.id, &ent.name, ) if err != nil { return nil, err } entities = append(entities, &ent) } if rows.Err() != nil { return nil, rows.Err() } return entities, nil } func (r *firstRepositoryBase) count(c *firstCriteria) (int64, error) { com := pqtgo.NewComposer(2) buf := bytes.NewBufferString("SELECT COUNT(*) FROM ") buf.WriteString(r.table) if err := c.WriteComposition("", com, pqtgo.And); err != nil { return 0, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Count"); err != nil { return 0, err } } var count int64 if err := r.db.QueryRow(buf.String(), com.Args()...).Scan(&count); err != nil { return 0, err } return count, nil } func (r *firstRepositoryBase) find(c *firstCriteria) ([]*firstEntity, error) { com := pqtgo.NewComposer(1) buf := bytes.NewBufferString("SELECT ") buf.WriteString(strings.Join(r.columns, ", ")) buf.WriteString(" FROM ") buf.WriteString(r.table) buf.WriteString(" ") if err := c.WriteComposition("", com, pqtgo.And); err != nil { return nil, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Find"); err != nil { return nil, err } } rows, err := r.db.Query(buf.String(), com.Args()...) 
if err != nil { return nil, err } defer rows.Close() return scanFirstRows(rows) } func (r *firstRepositoryBase) findIter(c *firstCriteria) (*firstIterator, error) { com := pqtgo.NewComposer(1) buf := bytes.NewBufferString("SELECT ") buf.WriteString(strings.Join(r.columns, ", ")) buf.WriteString(" FROM ") buf.WriteString(r.table) buf.WriteString(" ") if err := c.WriteComposition("", com, pqtgo.And); err != nil { return nil, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Find"); err != nil { return nil, err } } rows, err := r.db.Query(buf.String(), com.Args()...) if err != nil { return nil, err } return &firstIterator{rows: rows}, nil } func (r *firstRepositoryBase) insert(e *firstEntity) (*firstEntity, error) { insert := pqcomp.New(0, 2) insert.AddExpr(tableFirstColumnName, "", e.name) b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() != 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Insert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( &e.id, &e.name, ) if err != nil { return nil, err } return e, nil } func (r *firstRepositoryBase) upsert(e *firstEntity, p *firstPatch, inf ...string) (*firstEntity, error) { insert := pqcomp.New(0, 2) update := insert.Compose(2) insert.AddExpr(tableFirstColumnName, "", e.name) if len(inf) > 0 { update.AddExpr(tableFirstColumnName, "=", p.name) } b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() > 0 { b.WriteString(" (") for insert.Next() { 
if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") } b.WriteString(" ON CONFLICT ") if len(inf) > 0 && update.Len() > 0 { b.WriteString(" (") for j, i := range inf { if j != 0 { b.WriteString(", ") } b.WriteString(i) } b.WriteString(") ") b.WriteString(" DO UPDATE SET ") for update.Next() { if !update.First() { b.WriteString(", ") } b.WriteString(update.Key()) b.WriteString(" ") b.WriteString(update.Oper()) b.WriteString(" ") b.WriteString(update.PlaceHolder()) } } else { b.WriteString(" DO NOTHING ") } if insert.Len() > 0 { if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Upsert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( &e.id, &e.name, ) if err != nil { return nil, err } return e, nil } `, }, } for hint, c := range cases { b, err := c.generator.Generate(c.schema) if err != nil { t.Errorf("%s: unexpected error: %s", hint, err.Error()) continue } assertGoCode(t, c.expected, string(b), hint) } }