func TestColumn_DefaultOn(t *testing.T) { success := []struct { d string e []pqt.Event }{ { d: "NOW()", e: []pqt.Event{pqt.EventUpdate}, }, } for _, data := range success { c := pqt.NewColumn("column", pqt.TypeTimestampTZ(), pqt.WithDefault(data.d, data.e...)) EventLoop: for _, e := range data.e { d, ok := c.DefaultOn(e) if !ok { t.Errorf("missing default value for %s", e) continue EventLoop } if d != data.d { t.Errorf("wrong value, expected %s but got %s", data.d, d) } } } }
func generateBaseType(t pqt.Type, m int32) string { switch t { case pqt.TypeText(): return chooseType("string", "*ntypes.String", "*qtypes.String", m) case pqt.TypeBool(): return chooseType("bool", "*ntypes.Bool", "*ntypes.Bool", m) case pqt.TypeIntegerSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeInteger(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeIntegerBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeSerial(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeSerialSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeSerialBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): return chooseType("time.Time", "*time.Time", "*qtypes.Timestamp", m) case pqt.TypeReal(): return chooseType("float32", "*ntypes.Float32", "*ntypes.Float32", m) case pqt.TypeDoublePrecision(): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case pqt.TypeBytea(), pqt.TypeJSON(), pqt.TypeJSONB(): return "[]byte" case pqt.TypeUUID(): return "uuid.UUID" default: gt := t.String() switch { case strings.HasPrefix(gt, "SMALLINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "INTEGER["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "BIGINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "DOUBLE PRECISION["): return chooseType("pqt.ArrayFloat64", "pqt.ArrayFloat64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "TEXT["): return "pqt.ArrayString" case strings.HasPrefix(gt, "DECIMAL"), strings.HasPrefix(gt, "NUMERIC"): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "VARCHAR"): return chooseType("string", "*ntypes.String", "*qtypes.String", m) default: 
return "interface{}" } } }
// timestampable appends two TIMESTAMPTZ columns to t:
//   - created_at: NOT NULL, with a NOW() default registered without explicit events;
//   - updated_at: nullable, with a NOW() default registered for pqt.EventUpdate.
func timestampable(t *pqt.Table) {
	createdAt := pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))
	updatedAt := pqt.NewColumn("updated_at", pqt.TypeTimestampTZ(), pqt.WithDefault("NOW()", pqt.EventUpdate))

	t.AddColumn(createdAt).AddColumn(updatedAt)
}
func (g *Generator) generateRepositoryUpdateOneByPrimaryKey(w io.Writer, table *pqt.Table) { entityName := g.name(table.Name) pk, ok := table.PrimaryKey() if !ok { return } fmt.Fprintf(w, "func (r *%sRepositoryBase) %s%s(%s %s, patch *%sPatch) (*%sEntity, error) {\n", entityName, g.name("UpdateOneBy"), g.public(pk.Name), g.private(pk.Name), g.generateColumnTypeString(pk, modeMandatory), entityName, entityName) fmt.Fprintf(w, "update := pqcomp.New(1, %d)\n", len(table.Columns)) fmt.Fprintf(w, "update.AddArg(%s)\n", g.private(pk.Name)) fmt.Fprintln(w, "") ColumnsLoop: for _, c := range table.Columns { if c == pk { continue ColumnsLoop } if _, ok := c.DefaultOn(pqt.EventInsert, pqt.EventUpdate); ok { switch c.Type { case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): fmt.Fprintf(w, "if patch.%s != nil {\n", g.propertyName(c.Name)) } } else if g.canBeNil(c, modeOptional) { fmt.Fprintf(w, "if patch.%s != nil {\n", g.propertyName(c.Name)) } fmt.Fprint(w, "update.AddExpr(") g.writeTableNameColumnNameTo(w, c.Table.Name, c.Name) fmt.Fprintf(w, ", pqcomp.Equal, patch.%s)\n", g.propertyName(c.Name)) if d, ok := c.DefaultOn(pqt.EventUpdate); ok { switch c.Type { case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): fmt.Fprint(w, `} else {`) fmt.Fprint(w, "update.AddExpr(") g.writeTableNameColumnNameTo(w, c.Table.Name, c.Name) fmt.Fprintf(w, `, pqcomp.Equal, "%s")`, d) } } if _, ok := c.DefaultOn(pqt.EventInsert, pqt.EventUpdate); ok { switch c.Type { case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): fmt.Fprint(w, "\n}\n") } } else if g.canBeNil(c, modeOptional) { fmt.Fprint(w, "\n}\n") } } fmt.Fprintf(w, ` if update.Len() == 0 { return nil, errors.New("%s update failure, nothing to update") }`, entityName) fmt.Fprintf(w, ` query := "UPDATE %s SET " for update.Next() { if !update.First() { query += ", " } query += update.Key() + " " + update.Oper() + " " + update.PlaceHolder() } query += " WHERE %s = $1 RETURNING " + strings.Join(r.columns, ", ") var e %sEntity err := 
r.db.QueryRow(query, update.Args()...).Scan( `, table.FullName(), pk.Name, entityName) for _, c := range table.Columns { fmt.Fprintf(w, "&e.%s,\n", g.propertyName(c.Name)) } fmt.Fprint(w, `) if err != nil { return nil, err } return &e, nil } `) }
func (g *Generator) generateRepositoryUpdateOneByUniqueConstraint(w io.Writer, table *pqt.Table) { entityName := g.name(table.Name) var unique []*pqt.Constraint for _, c := range tableConstraints(table) { if c.Type == pqt.ConstraintTypeUnique { unique = append(unique, c) } } if len(unique) < 1 { return } for _, u := range unique { arguments := "" methodName := "UpdateOneBy" for i, c := range u.Columns { if i != 0 { methodName += "And" arguments += ", " } methodName += g.public(c.Name) arguments += fmt.Sprintf("%s %s", g.private(columnForeignName(c)), g.generateColumnTypeString(c, modeMandatory)) } fmt.Fprintf(w, `func (r *%sRepositoryBase) %s(%s, patch *%sPatch) (*%sEntity, error) { `, entityName, g.name(methodName), arguments, entityName, entityName) fmt.Fprintf(w, "update := pqcomp.New(%d, %d)\n", len(u.Columns), len(table.Columns)) for _, c := range u.Columns { fmt.Fprintf(w, "update.AddArg(%s)\n", g.private(columnForeignName(c))) } pk, pkOK := table.PrimaryKey() ColumnsLoop: for _, c := range table.Columns { if pkOK && c == pk { continue ColumnsLoop } for _, uc := range u.Columns { if c == uc { continue } } if _, ok := c.DefaultOn(pqt.EventInsert, pqt.EventUpdate); ok { switch c.Type { case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): fmt.Fprintf(w, "if patch.%s != nil {\n", g.propertyName(c.Name)) } } else if g.canBeNil(c, modeOptional) { fmt.Fprintf(w, "if patch.%s != nil {\n", g.propertyName(c.Name)) } fmt.Fprint(w, "update.AddExpr(") g.writeTableNameColumnNameTo(w, c.Table.Name, c.Name) fmt.Fprintf(w, ", pqcomp.Equal, patch.%s)\n", g.propertyName(c.Name)) if d, ok := c.DefaultOn(pqt.EventUpdate); ok { switch c.Type { case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): fmt.Fprint(w, `} else {`) fmt.Fprint(w, "update.AddExpr(") g.writeTableNameColumnNameTo(w, c.Table.Name, c.Name) fmt.Fprintf(w, `, pqcomp.Equal, "%s")`, d) } } if _, ok := c.DefaultOn(pqt.EventInsert, pqt.EventUpdate); ok { switch c.Type { case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): 
fmt.Fprint(w, "\n}\n") } } else if g.canBeNil(c, modeOptional) { fmt.Fprint(w, "\n}\n") } } fmt.Fprintf(w, ` if update.Len() == 0 { return nil, errors.New("%s update failure, nothing to update") }`, entityName) fmt.Fprintf(w, ` query := "UPDATE %s SET " for update.Next() { if !update.First() { query += ", " } query += update.Key() + " " + update.Oper() + " " + update.PlaceHolder() } `, table.FullName()) fmt.Fprint(w, `query += " WHERE `) for i, c := range u.Columns { if i != 0 { fmt.Fprint(w, " AND ") } fmt.Fprintf(w, "%s = $%d", c.Name, i+1) } fmt.Fprintf(w, ` RETURNING " + strings.Join(r.columns, ", ") if r.dbg { if err := r.log.Log("msg", query, "function", "%s"); err != nil { return nil, err } } var e %sEntity err := r.db.QueryRow(query, update.Args()...).Scan( `, methodName, entityName) for _, c := range table.Columns { fmt.Fprintf(w, "&e.%s,\n", g.propertyName(c.Name)) } fmt.Fprint(w, `) if err != nil { return nil, err } return &e, nil } `) } }
func TestGenerator_Generate(t *testing.T) { success := []struct { expected string given *pqt.Table }{ { expected: `-- do not modify, generated by pqt CREATE TEMPORARY TABLE schema.user ( created_at TIMESTAMPTZ, password TEXT, username TEXT NOT NULL ); `, given: func() *pqt.Table { return pqt.NewTable("user", pqt.WithTemporary()). SetSchema(pqt.NewSchema("schema")). AddColumn(&pqt.Column{Name: "username", Type: pqt.TypeText(), NotNull: true}). AddColumn(&pqt.Column{Name: "password", Type: pqt.TypeText()}). AddColumn(&pqt.Column{Name: "created_at", Type: pqt.TypeTimestampTZ()}) }(), }, { expected: `-- do not modify, generated by pqt CREATE TABLE IF NOT EXISTS table_name ( created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, created_by INTEGER NOT NULL, enabled BOOL, end_at TIMESTAMPTZ NOT NULL, id SERIAL, name TEXT, price DECIMAL(10,1), rel_id INTEGER, slug TEXT NOT NULL, start_at TIMESTAMPTZ NOT NULL, updated_at TIMESTAMPTZ, updated_by INTEGER, CONSTRAINT "public.table_name_id_pkey" PRIMARY KEY (id), CONSTRAINT "public.table_name_name_key" UNIQUE (name), CONSTRAINT "public.table_name_slug_key" UNIQUE (slug), CONSTRAINT "public.table_name_rel_id_fkey" FOREIGN KEY (rel_id) REFERENCES related_table (id), CONSTRAINT "public.table_name_start_at_end_at_check" CHECK ((start_at IS NULL AND end_at IS NULL) OR start_at < end_at) ); `, given: func() *pqt.Table { id := pqt.Column{Name: "id", Type: pqt.TypeSerial()} startAt := &pqt.Column{Name: "start_at", Type: pqt.TypeTimestampTZ(), NotNull: true} endAt := &pqt.Column{Name: "end_at", Type: pqt.TypeTimestampTZ(), NotNull: true} _ = pqt.NewTable("related_table"). AddColumn(&id) return pqt.NewTable("table_name", pqt.WithTableIfNotExists()). AddColumn(&pqt.Column{Name: "id", Type: pqt.TypeSerial(), PrimaryKey: true}). AddColumn(&pqt.Column{Name: "rel_id", Type: pqt.TypeInteger(), Reference: &id}). AddColumn(&pqt.Column{Name: "name", Type: pqt.TypeText(), Unique: true}). AddColumn(&pqt.Column{Name: "enabled", Type: pqt.TypeBool()}). 
AddColumn(&pqt.Column{Name: "price", Type: pqt.TypeDecimal(10, 1)}). AddColumn(startAt). AddColumn(endAt). AddColumn(pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))). AddColumn(&pqt.Column{Name: "created_by", Type: pqt.TypeInteger(), NotNull: true}). AddColumn(&pqt.Column{Name: "updated_at", Type: pqt.TypeTimestampTZ()}). AddColumn(&pqt.Column{Name: "updated_by", Type: pqt.TypeInteger()}). AddColumn(&pqt.Column{Name: "slug", Type: pqt.TypeText(), NotNull: true, Unique: true}). AddCheck("(start_at IS NULL AND end_at IS NULL) OR start_at < end_at", startAt, endAt) }(), }, } for i, data := range success { q, err := pqtsql.NewGenerator().Generate(&pqt.Schema{ Tables: []*pqt.Table{data.given}, }) if err != nil { t.Errorf("unexpected error for schema #%d : %s", i, err.Error()) continue } if string(q) != data.expected { t.Errorf("wrong query, expected:\n'%s'\nbut got:\n'%s'", data.expected, q) } } }