Example #1
0
// TestTable_AddRelationship_oneToOneUnidirectional checks that a one-to-one
// relationship added without pqt.WithBidirectional is recorded as an owned
// relationship of user_detail only, and that user gets no inversed
// relationship.
func TestTable_AddRelationship_oneToOneUnidirectional(t *testing.T) {
	newIDColumn := func() *pqt.Column {
		return pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())
	}

	user := pqt.NewTable("user").AddColumn(newIDColumn())
	userDetail := pqt.NewTable("user_detail").AddColumn(newIDColumn()).
		AddRelationship(pqt.OneToOne(
			user,
			pqt.WithInversedName("user"),
			pqt.WithOwnerName("details"),
		))

	if n := len(user.InversedRelationships); n != 0 {
		t.Fatalf("user should have 0 relationship, but has %d", n)
	}
	if n := len(userDetail.OwnedRelationships); n != 1 {
		t.Fatalf("user_detail should have 1 relationship, but has %d", n)
	}

	rel := userDetail.OwnedRelationships[0]
	if rel.InversedName != "user" {
		t.Errorf("user_detail relationship to user should be mapped by user")
	}
	if rel.InversedTable != user {
		t.Errorf("user_detail relationship to user should be mapped by user table")
	}
	if rel.Type != pqt.RelationshipTypeOneToOne {
		t.Errorf("user_detail relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToOne, rel.Type)
	}
}
Example #2
0
// TestWithColumnName checks that pqt.WithColumnName controls the name of the
// column that a unidirectional one-to-one relationship adds to the owning
// (comment) table, and that the referenced (user) table is left without an
// inversed relationship.
func TestWithColumnName(t *testing.T) {
	const columnName = "author"

	user := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey()))
	comment := pqt.NewTable("comment")

	comment.AddRelationship(pqt.OneToOne(user, pqt.WithColumnName(columnName)))

	if n := len(user.InversedRelationships); n != 0 {
		t.Fatalf("user table should have exactly 0 relationship, got %d", n)
	}
	if n := len(comment.OwnedRelationships); n != 1 {
		t.Fatalf("comment table should have exactly 1 relationship, got %d", n)
	}

	found := false
	for i := 0; i < len(comment.Columns) && !found; i++ {
		found = comment.Columns[i].Name == columnName
	}
	if !found {
		t.Errorf("comment table should have column with name %s", columnName)
	}
}
Example #3
0
// TestTable_AddRelationship_oneToOneSelfReferencing checks that a one-to-one
// relationship built with pqt.SelfReference is stored as a single owned
// relationship of the same table, owned by that table under the name
// "parent", and that no inversed relationship is recorded.
func TestTable_AddRelationship_oneToOneSelfReferencing(t *testing.T) {
	user := pqt.NewTable("user").AddColumn(pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey()))

	user.AddRelationship(pqt.OneToOne(
		pqt.SelfReference(),
		pqt.WithInversedName("child"),
		pqt.WithOwnerName("parent"),
	))

	owned := user.OwnedRelationships
	if len(owned) != 1 {
		t.Fatalf("user should have 1 owned relationship, but has %d", len(owned))
	}

	rel := owned[0]
	if rel.OwnerName != "parent" {
		t.Errorf("user relationship to user should be mapped by parent")
	}
	if rel.OwnerTable != user {
		t.Errorf("user relationship to user should be mapped by user table")
	}
	if rel.Type != pqt.RelationshipTypeOneToOne {
		t.Errorf("user relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToOne, rel.Type)
	}

	if n := len(user.InversedRelationships); n != 0 {
		t.Fatalf("user should have 0 inversed relationship, but has %d", n)
	}
}
Example #4
0
// TestTable_AddRelationship_oneToMany checks that a bidirectional one-to-many
// relationship added on user is recorded on both sides. user gets an inversed
// relationship owned by comment under the name "comments". comment gets an
// owned relationship pointing back at user under the name "author".
func TestTable_AddRelationship_oneToMany(t *testing.T) {
	newIDColumn := func() *pqt.Column {
		return pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())
	}

	user := pqt.NewTable("user").AddColumn(newIDColumn())
	comment := pqt.NewTable("comment").AddColumn(newIDColumn())

	user.AddRelationship(pqt.OneToMany(
		comment,
		pqt.WithBidirectional(),
		pqt.WithInversedName("author"),
		pqt.WithOwnerName("comments"),
	))

	// Inversed side: user.
	if n := len(user.InversedRelationships); n != 1 {
		t.Fatalf("user should have 1 inversed relationship, but has %d", n)
	}
	inversed := user.InversedRelationships[0]
	if inversed.OwnerName != "comments" {
		t.Errorf("user inversed relationship to comment should be mapped by comments")
	}
	if inversed.OwnerTable != comment {
		t.Errorf("user inversed relationship to comment should be mapped by comment table")
	}
	if inversed.Type != pqt.RelationshipTypeOneToMany {
		t.Errorf("user inversed relationship to comment should be one to many")
	}

	// Owning side: comment.
	if n := len(comment.OwnedRelationships); n != 1 {
		t.Fatalf("comment should have 1 owned relationship, but has %d", n)
	}
	owned := comment.OwnedRelationships[0]
	if owned.InversedName != "author" {
		t.Errorf("comment relationship to user should be mapped by author")
	}
	if owned.InversedTable != user {
		t.Errorf("comment relationship to user should be mapped by user table")
	}
	if owned.Type != pqt.RelationshipTypeOneToMany {
		t.Errorf("comment relationship to user should be %d, but is %d", pqt.RelationshipTypeOneToMany, owned.Type)
	}
}
Example #5
0
// generateBaseType returns the Go type name that generated code uses for a
// column of pqt type t.
//
// m is forwarded to chooseType, which picks one of three candidate names. The
// candidates look like plain / nullable (ntypes) / query (qtypes) variants;
// NOTE(review): confirm the meaning of m against chooseType.
//
// Types that are not matched exactly are classified by the prefix of their SQL
// representation (t.String()). These are integer, double-precision and text
// arrays, DECIMAL/NUMERIC and VARCHAR. Anything else maps to "interface{}".
func generateBaseType(t pqt.Type, m int32) string {
	switch t {
	case pqt.TypeText():
		return chooseType("string", "*ntypes.String", "*qtypes.String", m)
	case pqt.TypeBool():
		return chooseType("bool", "*ntypes.Bool", "*ntypes.Bool", m)
	case pqt.TypeIntegerSmall(), pqt.TypeSerialSmall():
		return chooseType("int16", "*int16", "*int16", m)
	case pqt.TypeInteger(), pqt.TypeSerial():
		return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m)
	case pqt.TypeIntegerBig(), pqt.TypeSerialBig():
		return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m)
	case pqt.TypeTimestamp(), pqt.TypeTimestampTZ():
		return chooseType("time.Time", "*time.Time", "*qtypes.Timestamp", m)
	case pqt.TypeReal():
		return chooseType("float32", "*ntypes.Float32", "*ntypes.Float32", m)
	case pqt.TypeDoublePrecision():
		return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m)
	case pqt.TypeBytea(), pqt.TypeJSON(), pqt.TypeJSONB():
		return "[]byte"
	case pqt.TypeUUID():
		return "uuid.UUID"
	}

	// No exact match: fall back to inspecting the SQL type name.
	sqlName := t.String()
	switch {
	case strings.HasPrefix(sqlName, "SMALLINT["),
		strings.HasPrefix(sqlName, "INTEGER["),
		strings.HasPrefix(sqlName, "BIGINT["):
		// Every integer array width is represented as an int64 array.
		return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m)
	case strings.HasPrefix(sqlName, "DOUBLE PRECISION["):
		return chooseType("pqt.ArrayFloat64", "pqt.ArrayFloat64", "*qtypes.Float64", m)
	case strings.HasPrefix(sqlName, "TEXT["):
		return "pqt.ArrayString"
	case strings.HasPrefix(sqlName, "DECIMAL"), strings.HasPrefix(sqlName, "NUMERIC"):
		return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m)
	case strings.HasPrefix(sqlName, "VARCHAR"):
		return chooseType("string", "*ntypes.String", "*qtypes.String", m)
	}
	return "interface{}"
}
Example #6
0
// TestConstraint_Name checks the names produced by Constraint.Name. A table
// without an explicit schema is expected under "public". A table added to a
// custom schema is expected under that schema. A constraint with a nil table is
// expected to report "<missing table>".
func TestConstraint_Name(t *testing.T) {
	id := pqt.NewColumn("id", pqt.TypeSerial(), pqt.WithPrimaryKey())

	// Keyed by the expected name; evaluation order of the entries is kept
	// because every primary key shares the same id column.
	cases := map[string]*pqt.Constraint{
		"public.user_id_pkey": pqt.PrimaryKey(pqt.NewTable("user"), id),
		"custom_schema.user_id_pkey": pqt.PrimaryKey(func() *pqt.Table {
			// Named tbl (not t) so the *testing.T is not shadowed.
			tbl := pqt.NewTable("user")
			schema := pqt.NewSchema("custom_schema")
			schema.AddTable(tbl)

			return tbl
		}(), id),
		"<missing table>": pqt.Check(nil, "a > b", id),
		"public.news_key": pqt.Unique(pqt.NewTable("news")),
	}

	for want, constraint := range cases {
		if got := constraint.Name(); got != want {
			t.Errorf("wrong name, expected %s got %s", want, got)
		}
	}
}
Example #7
0
// generateRepositoryUpsert writes the Upsert method of <Entity>RepositoryBase
// for table into code. Nothing is emitted when the target PostgreSQL version
// (g.ver) is below 9.5, because INSERT ... ON CONFLICT was introduced in 9.5.
//
// The generated method takes the entity to insert (e), a patch to apply on
// conflict (p) and the conflict target columns (inf). It:
//   - adds every non-serial column of e to the INSERT; nil-able columns are
//     added only when set;
//   - when inf is non-empty, adds every non-serial column of p to the
//     DO UPDATE SET clause, again skipping unset nil-able columns;
//   - emits ON CONFLICT (inf...) DO UPDATE SET ... when inf is given and there
//     is something to update, and ON CONFLICT DO NOTHING otherwise;
//   - writes DEFAULT VALUES when e has nothing to insert, because a bare
//     "INSERT INTO table ON CONFLICT ..." is not valid SQL;
//   - appends RETURNING r.columns and scans that row back into e.
//
// NOTE(review): if the DO NOTHING branch hits a conflict, PostgreSQL returns
// no row, so the generated Scan fails (presumably with sql.ErrNoRows).
func (g *Generator) generateRepositoryUpsert(code *bytes.Buffer, table *pqt.Table) {
	if g.ver < 9.5 {
		return
	}
	entityName := g.name(table.Name)

	fmt.Fprintf(code, `func (r *%sRepositoryBase) %s(e *%sEntity, p *%sPatch, inf ...string) (*%sEntity, error) {`,
		entityName, g.name("Upsert"),
		entityName, entityName, entityName,
	)
	fmt.Fprintf(code, `
		insert := pqcomp.New(0, %d)
		update := insert.Compose(%d)
	`, len(table.Columns), len(table.Columns))

	// INSERT values come from the entity e.
	for _, c := range table.Columns {
		switch c.Type {
		case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall():
			// Serial columns are left to the database.
			continue
		}
		prop := g.propertyName(c.Name)
		col := g.columnNameWithTableName(table.Name, c.Name)
		if g.canBeNil(c, modeOptional) {
			// Unset optional fields are omitted so the column default applies.
			fmt.Fprintf(code, `
					if e.%s != nil {
						insert.AddExpr(%s, "", e.%s)
					}
				`, prop, col, prop)
		} else {
			fmt.Fprintf(code, `insert.AddExpr(%s, "", e.%s)`, col, prop)
		}
		fmt.Fprintln(code, "")
	}

	// DO UPDATE SET values come from the patch p. They are only collected
	// when the caller names conflict target columns.
	fmt.Fprintln(code, "if len(inf) > 0 {")
	for _, c := range table.Columns {
		switch c.Type {
		case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall():
			continue
		}
		prop := g.propertyName(c.Name)
		col := g.columnNameWithTableName(table.Name, c.Name)
		if g.canBeNil(c, modeOptional) {
			fmt.Fprintf(code, `
					if p.%s != nil {
						update.AddExpr(%s, "=", p.%s)
					}
				`, prop, col, prop)
		} else {
			fmt.Fprintf(code, `update.AddExpr(%s, "=", p.%s)`, col, prop)
		}
		fmt.Fprintln(code, "")
	}
	fmt.Fprintln(code, "}")

	// Query assembly. Fprint (not Fprintf) is used, so the %s verbs below are
	// emitted literally into the generated code.
	fmt.Fprint(code, `
		b := bytes.NewBufferString("INSERT INTO " + r.table)

		if insert.Len() > 0 {
			b.WriteString(" (")
			for insert.Next() {
				if !insert.First() {
					b.WriteString(", ")
				}

				fmt.Fprintf(b, "%s", insert.Key())
			}
			insert.Reset()
			b.WriteString(") VALUES (")
			for insert.Next() {
				if !insert.First() {
					b.WriteString(", ")
				}

				fmt.Fprintf(b, "%s", insert.PlaceHolder())
			}
			b.WriteString(")")
		} else {
			b.WriteString(" DEFAULT VALUES")
		}
		b.WriteString(" ON CONFLICT ")
		if len(inf) > 0 && update.Len() > 0 {
			b.WriteString(" (")
			for j, i := range inf {
				if j != 0 {
					b.WriteString(", ")
				}
				b.WriteString(i)
			}
			b.WriteString(") ")
			b.WriteString(" DO UPDATE SET ")
			for update.Next() {
				if !update.First() {
					b.WriteString(", ")
				}

				b.WriteString(update.Key())
				b.WriteString(" ")
				b.WriteString(update.Oper())
				b.WriteString(" ")
				b.WriteString(update.PlaceHolder())
			}
		} else {
			b.WriteString(" DO NOTHING ")
		}
		if len(r.columns) > 0 {
			b.WriteString(" RETURNING ")
			b.WriteString(strings.Join(r.columns, ", "))
		}

		if r.dbg {
			if err := r.log.Log("msg", b.String(), "function", "Upsert"); err != nil {
				return nil, err
			}
		}

		err := r.db.QueryRow(b.String(), insert.Args()...).Scan(
	`)

	// Scan every column, serial ones included, back into the entity.
	for _, c := range table.Columns {
		fmt.Fprintf(code, "&e.%s,\n", g.propertyName(c.Name))
	}
	fmt.Fprint(code, `)
		if err != nil {
			return nil, err
		}

		return e, nil
	}
`)
}
Example #8
0
// generateRepositoryInsert writes the Insert method of <Entity>RepositoryBase
// for table into w.
//
// The generated method:
//   - adds every non-serial column of the entity to the INSERT; nil-able
//     columns are added only when set;
//   - writes DEFAULT VALUES when the entity has nothing to insert, because a
//     bare "INSERT INTO table" is not valid SQL;
//   - appends RETURNING r.columns and scans that row back into the entity.
func (g *Generator) generateRepositoryInsert(w io.Writer,
	table *pqt.Table) {
	entityName := g.name(table.Name)

	fmt.Fprintf(w, `func (r *%sRepositoryBase) %s(e *%sEntity) (*%sEntity, error) {`, entityName, g.name("Insert"), entityName, entityName)
	fmt.Fprintf(w, `
		insert := pqcomp.New(0, %d)
	`, len(table.Columns))

	for _, c := range table.Columns {
		switch c.Type {
		case pqt.TypeSerial(), pqt.TypeSerialBig(), pqt.TypeSerialSmall():
			// Serial columns are left to the database.
			continue
		}
		prop := g.propertyName(c.Name)
		col := g.columnNameWithTableName(table.Name, c.Name)
		if g.canBeNil(c, modeOptional) {
			// Unset optional fields are omitted so the column default applies.
			fmt.Fprintf(w, `
					if e.%s != nil {
						insert.AddExpr(%s, "", e.%s)
					}
				`, prop, col, prop)
		} else {
			fmt.Fprintf(w, `insert.AddExpr(%s, "", e.%s)`, col, prop)
		}
		fmt.Fprintln(w, "")
	}

	// Query assembly. Fprint (not Fprintf) is used, so the %s verbs below are
	// emitted literally into the generated code.
	fmt.Fprint(w, `
		b := bytes.NewBufferString("INSERT INTO " + r.table)

		if insert.Len() != 0 {
			b.WriteString(" (")
			for insert.Next() {
				if !insert.First() {
					b.WriteString(", ")
				}

				fmt.Fprintf(b, "%s", insert.Key())
			}
			insert.Reset()
			b.WriteString(") VALUES (")
			for insert.Next() {
				if !insert.First() {
					b.WriteString(", ")
				}

				fmt.Fprintf(b, "%s", insert.PlaceHolder())
			}
			b.WriteString(")")
		} else {
			b.WriteString(" DEFAULT VALUES")
		}
		if len(r.columns) > 0 {
			b.WriteString(" RETURNING ")
			b.WriteString(strings.Join(r.columns, ", "))
		}

		if r.dbg {
			if err := r.log.Log("msg", b.String(), "function", "Insert"); err != nil {
				return nil, err
			}
		}

		err := r.db.QueryRow(b.String(), insert.Args()...).Scan(
	`)

	// Scan every column, serial ones included, back into the entity.
	for _, c := range table.Columns {
		fmt.Fprintf(w, "&e.%s,\n", g.propertyName(c.Name))
	}
	fmt.Fprint(w, `)
		if err != nil {
			return nil, err
		}

		return e, nil
	}
`)
}
Example #9
0
// TestGenerator_Generate renders single-table schemas with the SQL generator
// and compares the output byte-for-byte with the expected DDL. It covers a
// temporary table in an explicit schema, and a CREATE TABLE IF NOT EXISTS
// table with primary key, unique, foreign key and check constraints.
func TestGenerator_Generate(t *testing.T) {
	cases := []struct {
		want  string
		table *pqt.Table
	}{
		{
			want: `-- do not modify, generated by pqt

CREATE TEMPORARY TABLE schema.user (
	created_at TIMESTAMPTZ,
	password TEXT,
	username TEXT NOT NULL
);

`,
			table: func() *pqt.Table {
				return pqt.NewTable("user", pqt.WithTemporary()).
					SetSchema(pqt.NewSchema("schema")).
					AddColumn(&pqt.Column{Name: "username", Type: pqt.TypeText(), NotNull: true}).
					AddColumn(&pqt.Column{Name: "password", Type: pqt.TypeText()}).
					AddColumn(&pqt.Column{Name: "created_at", Type: pqt.TypeTimestampTZ()})
			}(),
		},
		{
			want: `-- do not modify, generated by pqt

CREATE TABLE IF NOT EXISTS table_name (
	created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
	created_by INTEGER NOT NULL,
	enabled BOOL,
	end_at TIMESTAMPTZ NOT NULL,
	id SERIAL,
	name TEXT,
	price DECIMAL(10,1),
	rel_id INTEGER,
	slug TEXT NOT NULL,
	start_at TIMESTAMPTZ NOT NULL,
	updated_at TIMESTAMPTZ,
	updated_by INTEGER,

	CONSTRAINT "public.table_name_id_pkey" PRIMARY KEY (id),
	CONSTRAINT "public.table_name_name_key" UNIQUE (name),
	CONSTRAINT "public.table_name_slug_key" UNIQUE (slug),
	CONSTRAINT "public.table_name_rel_id_fkey" FOREIGN KEY (rel_id) REFERENCES related_table (id),
	CONSTRAINT "public.table_name_start_at_end_at_check" CHECK ((start_at IS NULL AND end_at IS NULL) OR start_at < end_at)
);

`,
			table: func() *pqt.Table {
				// relatedID lives in related_table and is the target of rel_id.
				relatedID := pqt.Column{Name: "id", Type: pqt.TypeSerial()}
				start := &pqt.Column{Name: "start_at", Type: pqt.TypeTimestampTZ(), NotNull: true}
				end := &pqt.Column{Name: "end_at", Type: pqt.TypeTimestampTZ(), NotNull: true}
				_ = pqt.NewTable("related_table").
					AddColumn(&relatedID)

				return pqt.NewTable("table_name", pqt.WithTableIfNotExists()).
					AddColumn(&pqt.Column{Name: "id", Type: pqt.TypeSerial(), PrimaryKey: true}).
					AddColumn(&pqt.Column{Name: "rel_id", Type: pqt.TypeInteger(), Reference: &relatedID}).
					AddColumn(&pqt.Column{Name: "name", Type: pqt.TypeText(), Unique: true}).
					AddColumn(&pqt.Column{Name: "enabled", Type: pqt.TypeBool()}).
					AddColumn(&pqt.Column{Name: "price", Type: pqt.TypeDecimal(10, 1)}).
					AddColumn(start).
					AddColumn(end).
					AddColumn(pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))).
					AddColumn(&pqt.Column{Name: "created_by", Type: pqt.TypeInteger(), NotNull: true}).
					AddColumn(&pqt.Column{Name: "updated_at", Type: pqt.TypeTimestampTZ()}).
					AddColumn(&pqt.Column{Name: "updated_by", Type: pqt.TypeInteger()}).
					AddColumn(&pqt.Column{Name: "slug", Type: pqt.TypeText(), NotNull: true, Unique: true}).
					AddCheck("(start_at IS NULL AND end_at IS NULL) OR start_at < end_at", start, end)
			}(),
		},
	}

	for i, tc := range cases {
		generator := pqtsql.NewGenerator()
		q, err := generator.Generate(&pqt.Schema{
			Tables: []*pqt.Table{tc.table},
		})

		switch {
		case err != nil:
			t.Errorf("unexpected error for schema #%d : %s", i, err.Error())
		case string(q) != tc.want:
			t.Errorf("wrong query, expected:\n'%s'\nbut got:\n'%s'", tc.want, q)
		}
	}
}