func generateBaseType(t pqt.Type, m int32) string { switch t { case pqt.TypeText(): return chooseType("string", "*ntypes.String", "*qtypes.String", m) case pqt.TypeBool(): return chooseType("bool", "*ntypes.Bool", "*ntypes.Bool", m) case pqt.TypeIntegerSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeInteger(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeIntegerBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeSerial(): return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m) case pqt.TypeSerialSmall(): return chooseType("int16", "*int16", "*int16", m) case pqt.TypeSerialBig(): return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m) case pqt.TypeTimestamp(), pqt.TypeTimestampTZ(): return chooseType("time.Time", "*time.Time", "*qtypes.Timestamp", m) case pqt.TypeReal(): return chooseType("float32", "*ntypes.Float32", "*ntypes.Float32", m) case pqt.TypeDoublePrecision(): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case pqt.TypeBytea(), pqt.TypeJSON(), pqt.TypeJSONB(): return "[]byte" case pqt.TypeUUID(): return "uuid.UUID" default: gt := t.String() switch { case strings.HasPrefix(gt, "SMALLINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "INTEGER["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "BIGINT["): return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m) case strings.HasPrefix(gt, "DOUBLE PRECISION["): return chooseType("pqt.ArrayFloat64", "pqt.ArrayFloat64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "TEXT["): return "pqt.ArrayString" case strings.HasPrefix(gt, "DECIMAL"), strings.HasPrefix(gt, "NUMERIC"): return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m) case strings.HasPrefix(gt, "VARCHAR"): return chooseType("string", "*ntypes.String", "*qtypes.String", m) default: return "interface{}" } } }
func TestGenerator_Generate(t *testing.T) { success := []struct { expected string given *pqt.Table }{ { expected: `-- do not modify, generated by pqt CREATE TEMPORARY TABLE schema.user ( created_at TIMESTAMPTZ, password TEXT, username TEXT NOT NULL ); `, given: func() *pqt.Table { return pqt.NewTable("user", pqt.WithTemporary()). SetSchema(pqt.NewSchema("schema")). AddColumn(&pqt.Column{Name: "username", Type: pqt.TypeText(), NotNull: true}). AddColumn(&pqt.Column{Name: "password", Type: pqt.TypeText()}). AddColumn(&pqt.Column{Name: "created_at", Type: pqt.TypeTimestampTZ()}) }(), }, { expected: `-- do not modify, generated by pqt CREATE TABLE IF NOT EXISTS table_name ( created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, created_by INTEGER NOT NULL, enabled BOOL, end_at TIMESTAMPTZ NOT NULL, id SERIAL, name TEXT, price DECIMAL(10,1), rel_id INTEGER, slug TEXT NOT NULL, start_at TIMESTAMPTZ NOT NULL, updated_at TIMESTAMPTZ, updated_by INTEGER, CONSTRAINT "public.table_name_id_pkey" PRIMARY KEY (id), CONSTRAINT "public.table_name_name_key" UNIQUE (name), CONSTRAINT "public.table_name_slug_key" UNIQUE (slug), CONSTRAINT "public.table_name_rel_id_fkey" FOREIGN KEY (rel_id) REFERENCES related_table (id), CONSTRAINT "public.table_name_start_at_end_at_check" CHECK ((start_at IS NULL AND end_at IS NULL) OR start_at < end_at) ); `, given: func() *pqt.Table { id := pqt.Column{Name: "id", Type: pqt.TypeSerial()} startAt := &pqt.Column{Name: "start_at", Type: pqt.TypeTimestampTZ(), NotNull: true} endAt := &pqt.Column{Name: "end_at", Type: pqt.TypeTimestampTZ(), NotNull: true} _ = pqt.NewTable("related_table"). AddColumn(&id) return pqt.NewTable("table_name", pqt.WithTableIfNotExists()). AddColumn(&pqt.Column{Name: "id", Type: pqt.TypeSerial(), PrimaryKey: true}). AddColumn(&pqt.Column{Name: "rel_id", Type: pqt.TypeInteger(), Reference: &id}). AddColumn(&pqt.Column{Name: "name", Type: pqt.TypeText(), Unique: true}). AddColumn(&pqt.Column{Name: "enabled", Type: pqt.TypeBool()}). AddColumn(&pqt.Column{Name: "price", Type: pqt.TypeDecimal(10, 1)}). AddColumn(startAt). AddColumn(endAt). AddColumn(pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))). AddColumn(&pqt.Column{Name: "created_by", Type: pqt.TypeInteger(), NotNull: true}). AddColumn(&pqt.Column{Name: "updated_at", Type: pqt.TypeTimestampTZ()}). AddColumn(&pqt.Column{Name: "updated_by", Type: pqt.TypeInteger()}). AddColumn(&pqt.Column{Name: "slug", Type: pqt.TypeText(), NotNull: true, Unique: true}). AddCheck("(start_at IS NULL AND end_at IS NULL) OR start_at < end_at", startAt, endAt) }(), }, } for i, data := range success { q, err := pqtsql.NewGenerator().Generate(&pqt.Schema{ Tables: []*pqt.Table{data.given}, }) if err != nil { t.Errorf("unexpected error for schema #%d : %s", i, err.Error()) continue } if string(q) != data.expected { t.Errorf("wrong query, expected:\n'%s'\nbut got:\n'%s'", data.expected, q) } } }