func generateBaseType(t pqt.Type, m int32) string { switch t { case pqt.TypeText(): return chooseType("string", "*ntypes.String", "*qtypes.String", m) case pqt.TypeBool(): return chooseType("bool", "*ntypes.Bool", "*ntypes.Bool", m) case pqt.TypeIntegerSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeInteger(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeIntegerBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeSerial(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeSerialSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeSerialBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): return chooseType("time.Time", "*time.Time", "*qtypes.Timestamp", m) case pqt.TypeReal(): return chooseType("float32", "*ntypes.Float32", "*ntypes.Float32", m) case pqt.TypeDoublePrecision(): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case pqt.TypeBytea(), pqt.TypeJSON(), pqt.TypeJSONB(): return "[]byte" case pqt.TypeUUID(): return "uuid.UUID" default: gt := t.String() switch { case strings.HasPrefix(gt, "SMALLINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "INTEGER["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "BIGINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "DOUBLE PRECISION["): return chooseType("pqt.ArrayFloat64", "pqt.ArrayFloat64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "TEXT["): return "pqt.ArrayString" case strings.HasPrefix(gt, "DECIMAL"), strings.HasPrefix(gt, "NUMERIC"): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "VARCHAR"): return chooseType("string", "*ntypes.String", "*qtypes.String", m) default: return "interface{}" } } }
func (g *Generator) generateRepositoryUpsert(code *bytes.Buffer, table *pqt.Table) { if g.ver < 9.5 { return } entityName := g.name(table.Name) fmt.Fprintf(code, `func (r *%sRepositoryBase) %s(e *%sEntity, p *%sPatch, inf ...string) (*%sEntity, error) {`, entityName, g.name("Upsert"), entityName, entityName, entityName, ) fmt.Fprintf(code, ` insert := pqcomp.New(0, %d) update := insert.Compose(%d) `, len(table.Columns), len(table.Columns)) InsertLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue InsertLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(code, ` if e.%s != nil { insert.AddExpr(%s, "", e.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf(code, `insert.AddExpr(%s, "", e.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } fmt.Fprintln(code, "") } } fmt.Fprintln(code, "if len(inf) > 0 {") UpdateLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue UpdateLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(code, ` if p.%s != nil { update.AddExpr(%s, "=", p.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf(code, `update.AddExpr(%s, "=", p.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name)) } fmt.Fprintln(code, "") } } fmt.Fprintln(code, "}") fmt.Fprint(code, ` b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() > 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") } b.WriteString(" ON CONFLICT ") if len(inf) > 0 && update.Len() > 0 { b.WriteString(" (") for j, i := range inf { if j != 0 { b.WriteString(", ") } b.WriteString(i) } b.WriteString(") ") b.WriteString(" DO UPDATE SET ") for update.Next() { if !update.First() { b.WriteString(", ") } b.WriteString(update.Key()) b.WriteString(" ") b.WriteString(update.Oper()) b.WriteString(" ") b.WriteString(update.PlaceHolder()) } } else { b.WriteString(" DO NOTHING ") } if insert.Len() > 0 { if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Upsert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( `) for _, c := range table.Columns { fmt.Fprintf(code, "&e.%s,\n", g.propertyName(c.Name)) } fmt.Fprint(code, `) if err != nil { return nil, err } return e, nil } `) }
func (g *Generator) generateRepositoryInsert(w io.Writer, table *pqt.Table) { entityName := g.name(table.Name) fmt.Fprintf(w, `func (r *%sRepositoryBase) %s(e *%sEntity) (*%sEntity, error) {`, entityName, g.name("Insert"), entityName, entityName) fmt.Fprintf(w, ` insert := pqcomp.New(0, %d) `, len(table.Columns)) ColumnsLoop: for _, c := range table.Columns { switch c.Type { case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall(): continue ColumnsLoop default: if g.canBeNil(c, modeOptional) { fmt.Fprintf(w, ` if e.%s != nil { insert.AddExpr(%s, "", e.%s) } `, g.propertyName(c.Name), g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } else { fmt.Fprintf( w, `insert.AddExpr(%s, "", e.%s)`, g.columnNameWithTableName(table.Name, c.Name), g.propertyName(c.Name), ) } fmt.Fprintln(w, "") } } fmt.Fprint(w, ` b := bytes.NewBufferString("INSERT INTO " + r.table) if insert.Len() != 0 { b.WriteString(" (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.Key()) } insert.Reset() b.WriteString(") VALUES (") for insert.Next() { if !insert.First() { b.WriteString(", ") } fmt.Fprintf(b, "%s", insert.PlaceHolder()) } b.WriteString(")") if len(r.columns) > 0 { b.WriteString(" RETURNING ") b.WriteString(strings.Join(r.columns, ", ")) } } if r.dbg { if err := r.log.Log("msg", b.String(), "function", "Insert"); err != nil { return nil, err } } err := r.db.QueryRow(b.String(), insert.Args()...).Scan( `) for _, c := range table.Columns { fmt.Fprintf(w, "&e.%s,\n", g.propertyName(c.Name)) } fmt.Fprint(w, `) if err != nil { return nil, err } return e, nil } `) }