func TestTable_AddRelationship_oneToOneUnidirectional(t *testing.T) { user := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) userDetail := pqt.NewTable("user_detail").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())). AddRelationship(pqt.OneToOne( user, pqt.WithInversedName("user"), pqt.WithOwnerName("details"), )) if len(user.InversedRelationships) != 0 { t.Fatalf("user should have 0 relationship, but has %d", len(user.InversedRelationships)) } if len(userDetail.OwnedRelationships) != 1 { t.Fatalf("user_detail should have 1 relationship, but has %d", len(userDetail.OwnedRelationships)) } if userDetail.OwnedRelationships[0].InversedName != "user" { t.Errorf("user_detail relationship to user should be mapped by user") } if userDetail.OwnedRelationships[0].InversedTable != user { t.Errorf("user_detail relationship to user should be mapped by user table") } if userDetail.OwnedRelationships[0].Type != pqt.RelationshipTypeOneToOne { t.Errorf("user_detail relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToOne, userDetail.OwnedRelationships[0].Type) } }
func TestWithColumnName(t *testing.T) { icn := "author" t1 := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) t2 := pqt.NewTable("comment") t2.AddRelationship(pqt.OneToOne(t1, pqt.WithColumnName(icn))) if len(t1.InversedRelationships) != 0 { t.Fatalf("user table should have exactly 0 relationship, got %d", len(t1.InversedRelationships)) } if len(t2.OwnedRelationships) != 1 { t.Fatalf("comment table should have exactly 1 relationship, got %d", len(t2.OwnedRelationships)) } var exists bool for _, c := range t2.Columns { if c.Name == icn { exists = true break } } if !exists { t.Errorf("comment table should have column with name %s", icn) } }
func TestForeignKey(t *testing.T) { t1 := pqt.NewTable("left") c11 := pqt.NewColumn("id", pqt.TypeSerialBig()) c12 := pqt.NewColumn("name", pqt.TypeText()) t1.AddColumn(c11) t1.AddColumn(c12) t2 := pqt.NewTable("right") c21 := pqt.NewColumn("id", pqt.TypeSerialBig()) c22 := pqt.NewColumn("name", pqt.TypeText()) t2.AddColumn(c21) t2.AddColumn(c22) cstr := pqt.ForeignKey(t1, pqt.Columns{c11, c12}, pqt.Columns{c21, c22}) if cstr.Type != pqt.ConstraintTypeForeignKey { t.Errorf("wrong type, expected %s but got %s", pqt.ConstraintTypeForeignKey, cstr.Type) } if len(cstr.Columns) != 2 { t.Errorf("wrong number of columns, expected %d but got %d", 2, len(cstr.Columns)) } if len(cstr.ReferenceColumns) != 2 { t.Errorf("wrong number of columns, expected %d but got %d", 2, len(cstr.ReferenceColumns)) } if !reflect.DeepEqual(cstr.Table, t1) { t.Errorf("table does not match, expected %v but got %v", t1, cstr.Table) } if !reflect.DeepEqual(cstr.ReferenceTable, t2) { t.Errorf("reference table does not match, expected %v but got %v", t2, cstr.ReferenceTable) } }
func schema(sn string) *pqt.Schema { title := pqt.NewColumn("title", pqt.TypeText(), pqt.WithNotNull(), pqt.WithUnique()) lead := pqt.NewColumn("lead", pqt.TypeText()) news := pqt.NewTable("news", pqt.WithTableIfNotExists()). AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())). AddColumn(title). AddColumn(lead). AddColumn(pqt.NewColumn("continue", pqt.TypeBool(), pqt.WithNotNull(), pqt.WithDefault("false"))). AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())). AddUnique(title, lead) comment := pqt.NewTable("comment", pqt.WithTableIfNotExists()). AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig())). AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())). AddColumn(pqt.NewColumn( "news_title", pqt.TypeText(), pqt.WithNotNull(), pqt.WithReference(title, pqt.WithBidirectional(), pqt.WithOwnerName("comments_by_news_title"), pqt.WithInversedName("news_by_title")), )) category := pqt.NewTable("category", pqt.WithTableIfNotExists()). AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())). AddColumn(pqt.NewColumn("name", pqt.TypeText(), pqt.WithNotNull())). AddColumn(pqt.NewColumn("content", pqt.TypeText(), pqt.WithNotNull())). AddRelationship( pqt.OneToMany( pqt.SelfReference(), pqt.WithBidirectional(), pqt.WithInversedName("child_category"), pqt.WithOwnerName("parent_category"), pqt.WithColumnName("parent_id"), ), ) pkg := pqt.NewTable("package", pqt.WithTableIfNotExists()). AddColumn(pqt.NewColumn("id", pqt.TypeSerialBig(), pqt.WithPrimaryKey())). AddColumn(pqt.NewColumn("break", pqt.TypeText())). AddRelationship(pqt.ManyToOne( category, pqt.WithBidirectional(), )) timestampable(news) timestampable(comment) timestampable(category) timestampable(pkg) comment.AddRelationship(pqt.ManyToOne(news, pqt.WithBidirectional(), pqt.WithInversedName("news_by_id")), pqt.WithNotNull()) pqt.ManyToMany(category, news, pqt.WithBidirectional()) return pqt.NewSchema(sn, pqt.WithSchemaIfNotExists()). 
AddTable(category). AddTable(pkg). AddTable(news). AddTable(comment) }
func TestTable_AddColumn(t *testing.T) { c0 := pqt.NewColumn("c0", pqt.TypeSerialBig(), pqt.WithPrimaryKey()) c1 := &pqt.Column{Name: "c1"} c2 := &pqt.Column{Name: "c2"} c3 := &pqt.Column{Name: "c3"} tbl := pqt.NewTable("test"). AddColumn(c0). AddColumn(c1). AddColumn(c2). AddColumn(c3). AddColumn(pqt.NewColumn("c4", pqt.TypeIntegerBig(), pqt.WithReference(c0))). AddRelationship(pqt.ManyToOne(pqt.SelfReference())) if len(tbl.Columns) != 6 { t.Errorf("wrong number of colums, expected %d but got %d", 6, len(tbl.Columns)) } if len(tbl.OwnedRelationships) != 2 { // Reference is not a relationship t.Errorf("wrong number of owned relationships, expected %d but got %d", 2, len(tbl.OwnedRelationships)) } for i, c := range tbl.Columns { if c.Name == "" { t.Errorf("column #%d table name is empty", i) } if c.Table == nil { t.Errorf("column #%d table nil pointer", i) } } }
func TestTable_AddRelationship_oneToOneSelfReferencing(t *testing.T) { user := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) user.AddRelationship(pqt.OneToOne( pqt.SelfReference(), pqt.WithInversedName("child"), pqt.WithOwnerName("parent"), )) if len(user.OwnedRelationships) != 1 { t.Fatalf("user should have 1 owned relationship, but has %d", len(user.OwnedRelationships)) } if user.OwnedRelationships[0].OwnerName != "parent" { t.Errorf("user relationship to user should be mapped by parent") } if user.OwnedRelationships[0].OwnerTable != user { t.Errorf("user relationship to user should be mapped by user table") } if user.OwnedRelationships[0].Type != pqt.RelationshipTypeOneToOne { t.Errorf("user relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToOne, user.OwnedRelationships[0].Type) } if len(user.InversedRelationships) != 0 { t.Fatalf("user should have 0 inversed relationship, but has %d", len(user.InversedRelationships)) } }
func TestTable_AddRelationship_oneToMany(t *testing.T) { user := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) comment := pqt.NewTable("comment").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) user.AddRelationship(pqt.OneToMany( comment, pqt.WithBidirectional(), pqt.WithInversedName("author"), pqt.WithOwnerName("comments"), )) if len(user.InversedRelationships) != 1 { t.Fatalf("user should have 1 inversed relationship, but has %d", len(user.InversedRelationships)) } if user.InversedRelationships[0].OwnerName != "comments" { t.Errorf("user inversed relationship to comment should be mapped by comments") } if user.InversedRelationships[0].OwnerTable != comment { t.Errorf("user inversed relationship to comment should be mapped by comment table") } if user.InversedRelationships[0].Type != pqt.RelationshipTypeOneToMany { t.Errorf("user inversed relationship to comment should be one to many") } if len(comment.OwnedRelationships) != 1 { t.Fatalf("comment should have 1 owned relationship, but has %d", len(comment.OwnedRelationships)) } if comment.OwnedRelationships[0].InversedName != "author" { t.Errorf("comment relationship to user should be mapped by author") } if comment.OwnedRelationships[0].InversedTable != user { t.Errorf("comment relationship to user should be mapped by user table") } if comment.OwnedRelationships[0].Type != pqt.RelationshipTypeOneToMany { t.Errorf("comment relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToMany, comment.OwnedRelationships[0].Type) } }
func TestConstraint_Name(t *testing.T) { id := pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey()) success := map[string]*pqt.Constraint{ "public.user_id_pkey": pqt.PrimaryKey(pqt.NewTable("user"), id), "custom_schema.user_id_pkey": pqt.PrimaryKey(func() *pqt.Table { t := pqt.NewTable("user") s := pqt.NewSchema("custom_schema") s.AddTable(t) return t }(), id), "<missing table>": pqt.Check(nil, "a > b", id), "public.news_key": pqt.Unique(pqt.NewTable("news")), } for expected, given := range success { got := given.Name() if got != expected { t.Errorf("wrong name, expected %s got %s", expected, got) } } }
func TestNewTable(t *testing.T) { tbl := pqt.NewTable("test", pqt.WithTableIfNotExists(), pqt.WithTableSpace("table_space"), pqt.WithTemporary()) if !tbl.IfNotExists { t.Errorf("table should have field if not exists set to true") } if !tbl.Temporary { t.Errorf("table should have field temporary set to true") } if tbl.TableSpace != "table_space" { t.Errorf("table should have field table space set to table_space") } }
// TestGenerator_Generate feeds each table from the success list, wrapped in a
// pqt.Schema, to pqtsql.NewGenerator().Generate and compares the produced
// query with the expected SQL text using exact string equality.
//
// The first case is a temporary table attached to a schema named "schema";
// the second is a table created with pqt.WithTableIfNotExists that carries
// primary key, unique, reference and check constraints. In both cases the
// expected SQL lists the columns in a different order than they are added.
// A generator error is reported and the case is skipped.
func TestGenerator_Generate(t *testing.T) { success := []struct { expected string given *pqt.Table }{ { expected: `-- do not modify, generated by pqt CREATE TEMPORARY TABLE schema.user ( created_at TIMESTAMPTZ, password TEXT, username TEXT NOT NULL ); `, given: func() *pqt.Table { return pqt.NewTable("user", pqt.WithTemporary()). SetSchema(pqt.NewSchema("schema")). AddColumn(&pqt.Column{Name: "username", Type: pqt.TypeText(), NotNull: true}). AddColumn(&pqt.Column{Name: "password", Type: pqt.TypeText()}). AddColumn(&pqt.Column{Name: "created_at", Type: pqt.TypeTimestampTZ()}) }(), }, { expected: `-- do not modify, generated by pqt CREATE TABLE IF NOT EXISTS table_name ( created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, created_by INTEGER NOT NULL, enabled BOOL, end_at TIMESTAMPTZ NOT NULL, id SERIAL, name TEXT, price DECIMAL(10,1), rel_id INTEGER, slug TEXT NOT NULL, start_at TIMESTAMPTZ NOT NULL, updated_at TIMESTAMPTZ, updated_by INTEGER, CONSTRAINT "public.table_name_id_pkey" PRIMARY KEY (id), CONSTRAINT "public.table_name_name_key" UNIQUE (name), CONSTRAINT "public.table_name_slug_key" UNIQUE (slug), CONSTRAINT "public.table_name_rel_id_fkey" FOREIGN KEY (rel_id) REFERENCES related_table (id), CONSTRAINT "public.table_name_start_at_end_at_check" CHECK ((start_at IS NULL AND end_at IS NULL) OR start_at < end_at) ); `, given: func() *pqt.Table { id := pqt.Column{Name: "id", Type: pqt.TypeSerial()} startAt := &pqt.Column{Name: "start_at", Type: pqt.TypeTimestampTZ(), NotNull: true} endAt := &pqt.Column{Name: "end_at", Type: pqt.TypeTimestampTZ(), NotNull: true} _ = pqt.NewTable("related_table"). AddColumn(&id) return pqt.NewTable("table_name", pqt.WithTableIfNotExists()). AddColumn(&pqt.Column{Name: "id", Type: pqt.TypeSerial(), PrimaryKey: true}). AddColumn(&pqt.Column{Name: "rel_id", Type: pqt.TypeInteger(), Reference: &id}). AddColumn(&pqt.Column{Name: "name", Type: pqt.TypeText(), Unique: true}). AddColumn(&pqt.Column{Name: "enabled", Type: pqt.TypeBool()}). 
AddColumn(&pqt.Column{Name: "price", Type: pqt.TypeDecimal(10, 1)}). AddColumn(startAt). AddColumn(endAt). AddColumn(pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))). AddColumn(&pqt.Column{Name: "created_by", Type: pqt.TypeInteger(), NotNull: true}). AddColumn(&pqt.Column{Name: "updated_at", Type: pqt.TypeTimestampTZ()}). AddColumn(&pqt.Column{Name: "updated_by", Type: pqt.TypeInteger()}). AddColumn(&pqt.Column{Name: "slug", Type: pqt.TypeText(), NotNull: true, Unique: true}). AddCheck("(start_at IS NULL AND end_at IS NULL) OR start_at < end_at", startAt, endAt) }(), }, } for i, data := range success { q, err := pqtsql.NewGenerator().Generate(&pqt.Schema{ Tables: []*pqt.Table{data.given}, }) if err != nil { t.Errorf("unexpected error for schema #%d : %s", i, err.Error()) continue } if string(q) != data.expected { t.Errorf("wrong query, expected:\n'%s'\nbut got:\n'%s'", data.expected, q) } } }
// TestGenerator_Generate runs a pqtgo generator against a schema for each
// named case ("basic", "custom-package" and "simple table") and hands the
// expected Go source, the generated bytes converted to a string, and the case
// name to assertGoCode. A generator error is reported with the case name and
// the case is skipped.
//
// The expected values are raw string literals holding the Go code the
// generator should emit. assertGoCode is defined outside this view, so how
// the two sources are compared is not visible here.
func TestGenerator_Generate(t *testing.T) { cases := map[string]struct { schema *pqt.Schema generator *pqtgo.Generator expected string }{ "basic": { schema: pqt.NewSchema("text"), generator: pqtgo.NewGenerator(). AddImport("github.com/piotrkowalczuk/ntypes"), expected: `package main import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" "github.com/piotrkowalczuk/ntypes" ) `, }, "custom-package": { schema: pqt.NewSchema("text"), generator: pqtgo.NewGenerator(). SetPackage("example"). AddImport("fmt"), expected: `package example import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" "fmt" ) `, }, "simple table": { schema: pqt.NewSchema("text").AddTable( pqt.NewTable("first").AddColumn( pqt.NewColumn("id", pqt.TypeSerialBig()), ).AddColumn( pqt.NewColumn("name", pqt.TypeText()), ), ), generator: pqtgo.NewGenerator(). SetPackage("custom"), expected: `package custom import ( "github.com/go-kit/kit/log" "github.com/m4rw3r/uuid" ) const ( tableFirst = "text.first" tableFirstColumnId = "id" tableFirstColumnName = "name" ) var ( tableFirstColumns = []string{ tableFirstColumnId, tableFirstColumnName, }) type firstEntity struct{ // id ... id *ntypes.Int64 // name ... name *ntypes.String } func (e *firstEntity) prop(cn string) (interface{}, bool) { switch cn { case tableFirstColumnId: return &e.id, true case tableFirstColumnName: return &e.name, true default: return nil, false } } func (e *firstEntity) props(cns ...string) ([]interface{}, error) { res := make([]interface{}, 0, len(cns)) for _, cn := range cns { if prop, ok := e.prop(cn); ok { res = append(res, prop) } else { return nil, fmt.Errorf("unexpected column provided: %s", cn) } } return res, nil } // firstIterator is not thread safe. 
type firstIterator struct { rows *sql.Rows cols []string } func (i *firstIterator) Next() bool { return i.rows.Next() } func (i *firstIterator) Close() error { return i.rows.Close() } func (i *firstIterator) Err() error { return i.rows.Err() } // Columns is wrapper around sql.Rows.Columns method, that also cache outpu inside iterator. func (i *firstIterator) Columns() ([]string, error) { if i.cols == nil { cols, err := i.rows.Columns() if err != nil { return nil, err } i.cols = cols } return i.cols, nil } // Ent is wrapper around first method that makes iterator more generic. func (i *firstIterator) Ent() (interface{}, error) { return i.First() } func (i *firstIterator) First() (*firstEntity, error) { var ent firstEntity cols, err := i.rows.Columns() if err != nil { return nil, err } props, err := ent.props(cols...) if err != nil { return nil, err } if err := i.rows.Scan(props...); err != nil { return nil, err } return &ent, nil } type firstCriteria struct { offset, limit int64 sort map[string]bool id *qtypes.Int64 name *qtypes.String } func (c *firstCriteria) WriteComposition(sel string, com *pqtgo.Composer, opt *pqtgo.CompositionOpts) (err error) { if err = pqtgo.WriteCompositionQueryInt64(c.id, tableFirstColumnId, com, pqtgo.And); err != nil { return } if err = pqtgo.WriteCompositionQueryString(c.name, tableFirstColumnName, com, pqtgo.And); err != nil { return } if len(c.sort) > 0 { i:=0 com.WriteString(" ORDER BY ") for cn, asc := range c.sort { for _, tcn := range tableFirstColumns { if cn == tcn { if i > 0 { com.WriteString(", ") } com.WriteString(cn) if !asc { com.WriteString(" DESC ") } i++ break } } } } if c.offset > 0 { if _, err = com.WriteString(" OFFSET "); err != nil { return } if err = com.WritePlaceholder(); err != nil { return } if _, err = com.WriteString(" "); err != nil { return } com.Add(c.offset) } if c.limit > 0 { if _, err = com.WriteString(" LIMIT "); err != nil { return } if err = com.WritePlaceholder(); err != nil { return } if _, err = 
com.WriteString(" "); err != nil { return } com.Add(c.limit) } return } type firstPatch struct { id *ntypes.Int64 name *ntypes.String } type firstRepositoryBase struct { table string columns []string db *sql.DB dbg bool log log.Logger } func scanFirstRows(rows *sql.Rows) ([]*firstEntity, error) { var ( entities []*firstEntity err error ) for rows.Next() { var ent firstEntity err = rows.Scan( &ent.id, &ent.name, ) if err != nil { return nil, err } entities = append(entities, &ent) } if rows.Err() != nil { return nil, rows.Err() } return entities, nil } func (r *firstRepositoryBase) count(c *firstCriteria) (int64, error) { com := pqtgo.NewComposer(2) buf := bytes.NewBufferString("SELECT COUNT(*) FROM ") buf.WriteString(r.table) if err := c.WriteComposition("", com, pqtgo.And); err != nil { return 0, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Count"); err != nil { return 0, err } } var count int64 if err := r.db.QueryRow(buf.String(), com.Args()...).Scan(&count); err != nil { return 0, err } return count, nil } func (r *firstRepositoryBase) find(c *firstCriteria) ([]*firstEntity, error) { com := pqtgo.NewComposer(1) buf := bytes.NewBufferString("SELECT ") buf.WriteString(strings.Join(r.columns, ", ")) buf.WriteString(" FROM ") buf.WriteString(r.table) buf.WriteString(" ") if err := c.WriteComposition("", com, pqtgo.And); err != nil { return nil, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Find"); err != nil { return nil, err } } rows, err := r.db.Query(buf.String(), com.Args()...) 
if err != nil { return nil, err } defer rows.Close() return scanFirstRows(rows) } func (r *firstRepositoryBase) findIter(c *firstCriteria) (*firstIterator, error) { com := pqtgo.NewComposer(1) buf := bytes.NewBufferString("SELECT ") buf.WriteString(strings.Join(r.columns, ", ")) buf.WriteString(" FROM ") buf.WriteString(r.table) buf.WriteString(" ") if err := c.WriteComposition("", com, pqtgo.And); err != nil { return nil, err } if com.Dirty { buf.WriteString(" WHERE ") } if com.Len() > 0 { buf.ReadFrom(com) } if r.dbg { if err := r.log.Log("msg", buf.String(), "function", "Find"); err != nil { return nil, err } } rows, err := r.db.Query(buf.String(), com.Args()...) if err != nil { return nil, err } return &firstIterator{rows: rows}, nil } func (r *firstRepositoryBase) insert(e *firstEntity) (*firstEntity, error) { insert := pqcomp.New(0, 2) insert.AddExpr(tableFirstColumnName, "", e.name) b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() != 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Insert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( &e.id, &e.name, ) if err != nil { return nil, err } return e, nil } func (r *firstRepositoryBase) upsert(e *firstEntity, p *firstPatch, inf ...string) (*firstEntity, error) { insert := pqcomp.New(0, 2) update := insert.Compose(2) insert.AddExpr(tableFirstColumnName, "", e.name) if len(inf) > 0 { update.AddExpr(tableFirstColumnName, "=", p.name) } b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() > 0 { b.WriteString(" (") for insert.Next() { 
if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") } b.WriteString(" ON CONFLICT ") if len(inf) > 0 && update.Len() > 0 { b.WriteString(" (") for j, i := range inf { if j != 0 { b.WriteString(", ") } b.WriteString(i) } b.WriteString(") ") b.WriteString(" DO UPDATE SET ") for update.Next() { if !update.First() { b.WriteString(", ") } b.WriteString(update.Key()) b.WriteString(" ") b.WriteString(update.Oper()) b.WriteString(" ") b.WriteString(update.PlaceHolder()) } } else { b.WriteString(" DO NOTHING ") } if insert.Len() > 0 { if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Upsert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( &e.id, &e.name, ) if err != nil { return nil, err } return e, nil } `, }, } for hint, c := range cases { b, err := c.generator.Generate(c.schema) if err != nil { t.Errorf("%s: unexpected error: %s", hint, err.Error()) continue } assertGoCode(t, c.expected, string(b), hint) } }