func TestTable_AddRelationship_oneToOneUnidirectional(t *testing.T) { user := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) userDetail := pqt.NewTable("user_detail").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())). AddRelationship(pqt.OneToOne( user, pqt.WithInversedName("user"), pqt.WithOwnerName("details"), )) if len(user.InversedRelationships) != 0 { t.Fatalf("user should have 0 relationship, but has %d", len(user.InversedRelationships)) } if len(userDetail.OwnedRelationships) != 1 { t.Fatalf("user_detail should have 1 relationship, but has %d", len(userDetail.OwnedRelationships)) } if userDetail.OwnedRelationships[0].InversedName != "user" { t.Errorf("user_detail relationship to user should be mapped by user") } if userDetail.OwnedRelationships[0].InversedTable != user { t.Errorf("user_detail relationship to user should be mapped by user table") } if userDetail.OwnedRelationships[0].Type != pqt.RelationshipTypeOneToOne { t.Errorf("user_detail relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToOne, userDetail.OwnedRelationships[0].Type) } }
func TestWithColumnName(t *testing.T) { icn := "author" t1 := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) t2 := pqt.NewTable("comment") t2.AddRelationship(pqt.OneToOne(t1, pqt.WithColumnName(icn))) if len(t1.InversedRelationships) != 0 { t.Fatalf("user table should have exactly 0 relationship, got %d", len(t1.InversedRelationships)) } if len(t2.OwnedRelationships) != 1 { t.Fatalf("comment table should have exactly 1 relationship, got %d", len(t2.OwnedRelationships)) } var exists bool for _, c := range t2.Columns { if c.Name == icn { exists = true break } } if !exists { t.Errorf("comment table should have column with name %s", icn) } }
func TestTable_AddRelationship_oneToOneSelfReferencing(t *testing.T) { user := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) user.AddRelationship(pqt.OneToOne( pqt.SelfReference(), pqt.WithInversedName("child"), pqt.WithOwnerName("parent"), )) if len(user.OwnedRelationships) != 1 { t.Fatalf("user should have 1 owned relationship, but has %d", len(user.OwnedRelationships)) } if user.OwnedRelationships[0].OwnerName != "parent" { t.Errorf("user relationship to user should be mapped by parent") } if user.OwnedRelationships[0].OwnerTable != user { t.Errorf("user relationship to user should be mapped by user table") } if user.OwnedRelationships[0].Type != pqt.RelationshipTypeOneToOne { t.Errorf("user relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToOne, user.OwnedRelationships[0].Type) } if len(user.InversedRelationships) != 0 { t.Fatalf("user should have 0 inversed relationship, but has %d", len(user.InversedRelationships)) } }
func TestTable_AddRelationship_oneToMany(t *testing.T) { user := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) comment := pqt.NewTable("comment").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())) user.AddRelationship(pqt.OneToMany( comment, pqt.WithBidirectional(), pqt.WithInversedName("author"), pqt.WithOwnerName("comments"), )) if len(user.InversedRelationships) != 1 { t.Fatalf("user should have 1 inversed relationship, but has %d", len(user.InversedRelationships)) } if user.InversedRelationships[0].OwnerName != "comments" { t.Errorf("user inversed relationship to comment should be mapped by comments") } if user.InversedRelationships[0].OwnerTable != comment { t.Errorf("user inversed relationship to comment should be mapped by comment table") } if user.InversedRelationships[0].Type != pqt.RelationshipTypeOneToMany { t.Errorf("user inversed relationship to comment should be one to many") } if len(comment.OwnedRelationships) != 1 { t.Fatalf("comment should have 1 owned relationship, but has %d", len(comment.OwnedRelationships)) } if comment.OwnedRelationships[0].InversedName != "author" { t.Errorf("comment relationship to user should be mapped by author") } if comment.OwnedRelationships[0].InversedTable != user { t.Errorf("comment relationship to user should be mapped by user table") } if comment.OwnedRelationships[0].Type != pqt.RelationshipTypeOneToMany { t.Errorf("comment relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToMany, comment.OwnedRelationships[0].Type) } }
func generateBaseType(t pqt.Type, m int32) string { switch t { case pqt.TypeText(): return chooseType("string", "*ntypes.String", "*qtypes.String", m) case pqt.TypeBool(): return chooseType("bool", "*ntypes.Bool", "*ntypes.Bool", m) case pqt.TypeIntegerSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeInteger(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeIntegerBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeSerial(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeSerialSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeSerialBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): return chooseType("time.Time", "*time.Time", "*qtypes.Timestamp", m) case pqt.TypeReal(): return chooseType("float32", "*ntypes.Float32", "*ntypes.Float32", m) case pqt.TypeDoublePrecision(): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case pqt.TypeBytea(), pqt.TypeJSON(), pqt.TypeJSONB(): return "[]byte" case pqt.TypeUUID(): return "uuid.UUID" default: gt := t.String() switch { case strings.HasPrefix(gt, "SMALLINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "INTEGER["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "BIGINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "DOUBLE PRECISION["): return chooseType("pqt.ArrayFloat64", "pqt.ArrayFloat64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "TEXT["): return "pqt.ArrayString" case strings.HasPrefix(gt, "DECIMAL"), strings.HasPrefix(gt, "NUMERIC"): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "VARCHAR"): return chooseType("string", "*ntypes.String", "*qtypes.String", m) default: 
return "interface{}" } } }
func TestConstraint_Name(t *testing.T) { id := pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey()) success := map[string]*pqt.Constraint{ "public.user_id_pkey": pqt.PrimaryKey(pqt.NewTable("user"), id), "custom_schema.user_id_pkey": pqt.PrimaryKey(func() *pqt.Table { t := pqt.NewTable("user") s := pqt.NewSchema("custom_schema") s.AddTable(t) return t }(), id), "<missing table>": pqt.Check(nil, "a > b", id), "public.news_key": pqt.Unique(pqt.NewTable("news")), } for expected, given := range success { got := given.Name() if got != expected { t.Errorf("wrong name, expected %s got %s", expected, got) } } }
func (g *Generator) generateRepositoryUpsert(code *bytes.Buffer, table *pqt.Table) { if g.ver < 9.5 { return } entityName := g.name(table.Name) fmt.Fprintf(code, `func (r *%sRepositoryBase) %s(e *%sEntity, p *%sPatch, inf ...string) (*%sEntity, error) {`, entityName, g.name("Upsert"), entityName, entityName, entityName, ) fmt.Fprintf(code, ` insert := pqcomp.New(0, %d) update := insert.Compose(%d) `, len(table.Columns), len(table.Columns)) InsertLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue InsertLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(code, ` if e.%s != nil { insert.AddExpr(%s, "", e.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf(code, `insert.AddExpr(%s, "", e.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } fmt.Fprintln(code, "") } } fmt.Fprintln(code, "if len(inf) > 0 {") UpdateLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue UpdateLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(code, ` if p.%s != nil { update.AddExpr(%s, "=", p.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf(code, `update.AddExpr(%s, "=", p.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name)) } fmt.Fprintln(code, "") } } fmt.Fprintln(code, "}") fmt.Fprint(code, ` b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() > 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") } b.WriteString(" ON CONFLICT ") if len(inf) > 0 && update.Len() > 
0 { b.WriteString(" (") for j, i := range inf { if j != 0 { b.WriteString(", ") } b.WriteString(i) } b.WriteString(") ") b.WriteString(" DO UPDATE SET ") for update.Next() { if !update.First() { b.WriteString(", ") } b.WriteString(update.Key()) b.WriteString(" ") b.WriteString(update.Oper()) b.WriteString(" ") b.WriteString(update.PlaceHolder()) } } else { b.WriteString(" DO NOTHING ") } if insert.Len() > 0 { if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Upsert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( `) for _, c := range table.Columns { fmt.Fprintf(code, "&e.%s,\n", g.propertyName(c.Name)) } fmt.Fprint(code, `) if err != nil { return nil, err } return e, nil } `) }
func (g *Generator) generateRepositoryInsert(w io.Writer, table *pqt.Table) { entityName := g.name(table.Name) fmt.Fprintf(w, `func (r *%sRepositoryBase) %s(e *%sEntity) (*%sEntity, error) {`, entityName, g.name("Insert"), entityName, entityName) fmt.Fprintf(w, ` insert := pqcomp.New(0, %d) `, len(table.Columns)) ColumnsLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue ColumnsLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(w, ` if e.%s != nil { insert.AddExpr(%s, "", e.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf( w, `insert.AddExpr(%s, "", e.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } fmt.Fprintln(w, "") } } fmt.Fprint(w, ` b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() != 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Insert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( `) for _, c := range table.Columns { fmt.Fprintf(w, "&e.%s,\n", g.propertyName(c.Name)) } fmt.Fprint(w, `) if err != nil { return nil, err } return e, nil } `) }
// TestGenerator_Generate is a table-driven test of pqtsql Generator.Generate.
// Each case builds a single *pqt.Table, wraps it in a pqt.Schema, generates
// SQL for it and compares the result with the expected text.
//
// The expected output lists columns alphabetically, not in the order they
// were added to the table.
func TestGenerator_Generate(t *testing.T) {
	success := []struct {
		expected string
		given    *pqt.Table
	}{
		{
			// Temporary table placed in schema "schema".
			expected: `-- do not modify, generated by pqt CREATE TEMPORARY TABLE schema.user ( created_at TIMESTAMPTZ, password TEXT, username TEXT NOT NULL ); `,
			given: func() *pqt.Table {
				return pqt.NewTable("user", pqt.WithTemporary()).
					SetSchema(pqt.NewSchema("schema")).
					AddColumn(&pqt.Column{Name: "username", Type: pqt.TypeText(), NotNull: true}).
					AddColumn(&pqt.Column{Name: "password", Type: pqt.TypeText()}).
					AddColumn(&pqt.Column{Name: "created_at", Type: pqt.TypeTimestampTZ()})
			}(),
		},
		{
			// IF NOT EXISTS table exercising primary key, unique, foreign key and
			// check constraints, a DECIMAL precision and a column default.
			expected: `-- do not modify, generated by pqt CREATE TABLE IF NOT EXISTS table_name ( created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, created_by INTEGER NOT NULL, enabled BOOL, end_at TIMESTAMPTZ NOT NULL, id SERIAL, name TEXT, price DECIMAL(10,1), rel_id INTEGER, slug TEXT NOT NULL, start_at TIMESTAMPTZ NOT NULL, updated_at TIMESTAMPTZ, updated_by INTEGER, CONSTRAINT "public.table_name_id_pkey" PRIMARY KEY (id), CONSTRAINT "public.table_name_name_key" UNIQUE (name), CONSTRAINT "public.table_name_slug_key" UNIQUE (slug), CONSTRAINT "public.table_name_rel_id_fkey" FOREIGN KEY (rel_id) REFERENCES related_table (id), CONSTRAINT "public.table_name_start_at_end_at_check" CHECK ((start_at IS NULL AND end_at IS NULL) OR start_at < end_at) ); `,
			given: func() *pqt.Table {
				id := pqt.Column{Name: "id", Type: pqt.TypeSerial()}
				startAt := &pqt.Column{Name: "start_at", Type: pqt.TypeTimestampTZ(), NotNull: true}
				endAt := &pqt.Column{Name: "end_at", Type: pqt.TypeTimestampTZ(), NotNull: true}
				// related_table is built only for its side effect on id. Presumably
				// attaching id to a table is what lets rel_id's foreign key render
				// as "REFERENCES related_table (id)" - TODO confirm.
				_ = pqt.NewTable("related_table").
					AddColumn(&id)

				return pqt.NewTable("table_name", pqt.WithTableIfNotExists()).
					AddColumn(&pqt.Column{Name: "id", Type: pqt.TypeSerial(), PrimaryKey: true}).
					AddColumn(&pqt.Column{Name: "rel_id", Type: pqt.TypeInteger(), Reference: &id}).
					AddColumn(&pqt.Column{Name: "name", Type: pqt.TypeText(), Unique: true}).
					AddColumn(&pqt.Column{Name: "enabled", Type: pqt.TypeBool()}).
					AddColumn(&pqt.Column{Name: "price", Type: pqt.TypeDecimal(10, 1)}).
					AddColumn(startAt).
					AddColumn(endAt).
					AddColumn(pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))).
					AddColumn(&pqt.Column{Name: "created_by", Type: pqt.TypeInteger(), NotNull: true}).
					AddColumn(&pqt.Column{Name: "updated_at", Type: pqt.TypeTimestampTZ()}).
					AddColumn(&pqt.Column{Name: "updated_by", Type: pqt.TypeInteger()}).
					AddColumn(&pqt.Column{Name: "slug", Type: pqt.TypeText(), NotNull: true, Unique: true}).
					// startAt and endAt are passed as the columns this check covers.
					AddCheck("(start_at IS NULL AND end_at IS NULL) OR start_at < end_at", startAt, endAt)
			}(),
		},
	}

	for i, data := range success {
		// Each case is generated as a schema containing only the given table.
		q, err := pqtsql.NewGenerator().Generate(&pqt.Schema{
			Tables: []*pqt.Table{data.given},
		})
		if err != nil {
			t.Errorf("unexpected error for schema #%d : %s", i, err.Error())
			continue
		}
		if string(q) != data.expected {
			t.Errorf("wrong query, expected:\n'%s'\nbut got:\n'%s'", data.expected, q)
		}
	}
}