Ejemplo n.º 1
0
// TestColumn_DefaultOn checks that a column configured with pqt.WithDefault
// reports, through DefaultOn, the configured default value for each event the
// default was registered for.
func TestColumn_DefaultOn(t *testing.T) {
	cases := []struct {
		d string
		e []pqt.Event
	}{
		{d: "NOW()", e: []pqt.Event{pqt.EventUpdate}},
	}

	for _, tc := range cases {
		col := pqt.NewColumn("column", pqt.TypeTimestampTZ(), pqt.WithDefault(tc.d, tc.e...))

		for _, event := range tc.e {
			got, ok := col.DefaultOn(event)
			switch {
			case !ok:
				t.Errorf("missing default value for %s", event)
			case got != tc.d:
				t.Errorf("wrong value, expected %s but got %s", tc.d, got)
			}
		}
	}
}
Ejemplo n.º 2
0
// generateBaseType returns the Go type name that generated code uses for a
// column of pqt type t.
//
// For most types the three candidate names (a plain value type, then two
// pointer/wrapper variants) are handed to chooseType together with m; chooseType
// is defined elsewhere and presumably selects one candidate by mode — confirm
// there. Byte-like types, UUIDs and TEXT arrays map to a fixed name regardless
// of m. Types not matched directly are classified by the prefix of their SQL
// name (t.String()); anything still unrecognised becomes "interface{}".
func generateBaseType(t pqt.Type, m int32) string {
	switch t {
	case pqt.TypeText():
		return chooseType("string", "*ntypes.String", "*qtypes.String", m)
	case pqt.TypeBool():
		return chooseType("bool", "*ntypes.Bool", "*ntypes.Bool", m)
	case pqt.TypeIntegerSmall(), pqt.TypeSerialSmall():
		return chooseType("int16", "*int16", "*int16", m)
	case pqt.TypeInteger(), pqt.TypeSerial():
		return chooseType("int32", "*ntypes.Int32", "*ntypes.Int32", m)
	case pqt.TypeIntegerBig(), pqt.TypeSerialBig():
		return chooseType("int64", "*ntypes.Int64", "*qtypes.Int64", m)
	case pqt.TypeTimestamp(), pqt.TypeTimestampTZ():
		return chooseType("time.Time", "*time.Time", "*qtypes.Timestamp", m)
	case pqt.TypeReal():
		return chooseType("float32", "*ntypes.Float32", "*ntypes.Float32", m)
	case pqt.TypeDoublePrecision():
		return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m)
	case pqt.TypeBytea(), pqt.TypeJSON(), pqt.TypeJSONB():
		return "[]byte"
	case pqt.TypeUUID():
		return "uuid.UUID"
	}

	// Parameterised and array types are recognised by their SQL spelling.
	sqlName := t.String()
	switch {
	case strings.HasPrefix(sqlName, "SMALLINT["),
		strings.HasPrefix(sqlName, "INTEGER["),
		strings.HasPrefix(sqlName, "BIGINT["):
		return chooseType("pqt.ArrayInt64", "pqt.ArrayInt64", "*qtypes.Int64", m)
	case strings.HasPrefix(sqlName, "DOUBLE PRECISION["):
		return chooseType("pqt.ArrayFloat64", "pqt.ArrayFloat64", "*qtypes.Float64", m)
	case strings.HasPrefix(sqlName, "TEXT["):
		return "pqt.ArrayString"
	case strings.HasPrefix(sqlName, "DECIMAL"), strings.HasPrefix(sqlName, "NUMERIC"):
		return chooseType("float64", "*ntypes.Float64", "*qtypes.Float64", m)
	case strings.HasPrefix(sqlName, "VARCHAR"):
		return chooseType("string", "*ntypes.String", "*qtypes.String", m)
	}
	return "interface{}"
}
Ejemplo n.º 3
0
// timestampable adds two TIMESTAMPTZ columns to t:
//   - created_at: NOT NULL, with WithDefault("NOW()") and no explicit events;
//   - updated_at: no NOT NULL option, with a NOW() default registered for
//     pqt.EventUpdate.
func timestampable(t *pqt.Table) {
	createdAt := pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))
	updatedAt := pqt.NewColumn("updated_at", pqt.TypeTimestampTZ(), pqt.WithDefault("NOW()", pqt.EventUpdate))

	t.AddColumn(createdAt).AddColumn(updatedAt)
}
Ejemplo n.º 4
0
// generateRepositoryUpdateOneByPrimaryKey writes the "UpdateOneBy<PK>" method
// of the table's repository into w.
//
// The generated method receives the primary key value and a patch. The key is
// bound as the first query argument ($1). Every other column becomes a SET
// expression read from the matching patch field.
//
// A SET expression is wrapped in a "patch.<Field> != nil" check in two cases:
//   - timestamp columns that have a default for EventInsert/EventUpdate;
//   - columns without such a default that canBeNil reports as optional.
//
// Timestamp columns with an EventUpdate default also get an else branch that
// assigns that default instead. The UPDATE returns every column, and the
// generated code scans them into a new entity.
//
// Nothing is written for tables without a primary key.
func (g *Generator) generateRepositoryUpdateOneByPrimaryKey(w io.Writer, table *pqt.Table) {
	entityName := g.name(table.Name)
	pk, found := table.PrimaryKey()
	if !found {
		return
	}

	pkArg := g.private(pk.Name)
	fmt.Fprintf(w, "func (r *%sRepositoryBase) %s%s(%s %s, patch *%sPatch) (*%sEntity, error) {\n",
		entityName, g.name("UpdateOneBy"), g.public(pk.Name), pkArg, g.generateColumnTypeString(pk, modeMandatory), entityName, entityName)
	fmt.Fprintf(w, "update := pqcomp.New(1, %d)\n", len(table.Columns))
	fmt.Fprintf(w, "update.AddArg(%s)\n", pkArg)
	fmt.Fprintln(w, "")

	for _, col := range table.Columns {
		// The primary key identifies the row; it is never part of the SET list.
		if col == pk {
			continue
		}

		field := g.propertyName(col.Name)
		isTimestamp := col.Type == pqt.TypeTimestamp() || col.Type == pqt.TypeTimestampTZ()

		// guarded decides whether the SET expression sits inside a nil check.
		var guarded bool
		if _, hasDefault := col.DefaultOn(pqt.EventInsert, pqt.EventUpdate); hasDefault {
			guarded = isTimestamp
		} else {
			guarded = g.canBeNil(col, modeOptional)
		}

		if guarded {
			fmt.Fprintf(w, "if patch.%s != nil {\n", field)
		}

		fmt.Fprint(w, "update.AddExpr(")
		g.writeTableNameColumnNameTo(w, col.Table.Name, col.Name)
		fmt.Fprintf(w, ", pqcomp.Equal, patch.%s)\n", field)

		// Timestamp columns fall back to their update default when the patch
		// field is nil.
		if fallback, ok := col.DefaultOn(pqt.EventUpdate); ok && isTimestamp {
			fmt.Fprint(w, `} else {`)
			fmt.Fprint(w, "update.AddExpr(")
			g.writeTableNameColumnNameTo(w, col.Table.Name, col.Name)
			fmt.Fprintf(w, `, pqcomp.Equal, "%s")`, fallback)
		}

		if guarded {
			fmt.Fprint(w, "\n}\n")
		}
	}

	fmt.Fprintf(w, `
	if update.Len() == 0 {
		return nil, errors.New("%s update failure, nothing to update")
	}`, entityName)

	fmt.Fprintf(w, `
	query := "UPDATE %s SET "
	for update.Next() {
		if !update.First() {
			query += ", "
		}

		query += update.Key() + " " + update.Oper() + " " + update.PlaceHolder()
	}
	query += " WHERE %s = $1 RETURNING " + strings.Join(r.columns, ", ")
	var e %sEntity
	err := r.db.QueryRow(query, update.Args()...).Scan(
	`, table.FullName(), pk.Name, entityName)

	// Scan targets follow the order of table.Columns.
	for _, col := range table.Columns {
		fmt.Fprintf(w, "&e.%s,\n", g.propertyName(col.Name))
	}
	fmt.Fprint(w, `)
if err != nil {
	return nil, err
}


return &e, nil
}
`)
}
Ejemplo n.º 5
0
// generateRepositoryUpdateOneByUniqueConstraint writes one
// "UpdateOneBy<Col>[And<Col>...]" repository method into w for every unique
// constraint returned by tableConstraints(table).
//
// Each generated method receives the constraint's columns as arguments plus a
// patch. It binds the constraint columns as $1..$n and uses them in the WHERE
// clause. Every remaining column becomes a SET expression read from the
// matching patch field. The primary key (if any) and the constraint's own
// columns are excluded from the SET list.
//
// A SET expression is wrapped in a "patch.<Field> != nil" check in these cases:
//   - timestamp columns that have a default for EventInsert/EventUpdate;
//   - columns without such a default that canBeNil reports as optional;
//   - any column that receives a fallback branch (see below).
//
// Timestamp columns with an EventUpdate default get an else branch that
// assigns that default instead. When the repository's dbg flag is set, the
// generated method logs the query via r.log before running it. The UPDATE
// returns every column, and the generated code scans them into a new entity.
//
// Nothing is written when the table has no unique constraints.
func (g *Generator) generateRepositoryUpdateOneByUniqueConstraint(w io.Writer, table *pqt.Table) {
	entityName := g.name(table.Name)

	var unique []*pqt.Constraint
	for _, c := range tableConstraints(table) {
		if c.Type == pqt.ConstraintTypeUnique {
			unique = append(unique, c)
		}
	}
	if len(unique) == 0 {
		return
	}

	// The primary key does not depend on the constraint, so look it up once.
	pk, pkOK := table.PrimaryKey()

	for _, u := range unique {
		arguments := ""
		methodName := "UpdateOneBy"
		for i, c := range u.Columns {
			if i != 0 {
				methodName += "And"
				arguments += ", "
			}
			methodName += g.public(c.Name)
			arguments += fmt.Sprintf("%s %s", g.private(columnForeignName(c)), g.generateColumnTypeString(c, modeMandatory))
		}
		fmt.Fprintf(w, `func (r *%sRepositoryBase) %s(%s, patch *%sPatch) (*%sEntity, error) {
		`, entityName, g.name(methodName), arguments, entityName, entityName)
		fmt.Fprintf(w, "update := pqcomp.New(%d, %d)\n", len(u.Columns), len(table.Columns))
		for _, c := range u.Columns {
			fmt.Fprintf(w, "update.AddArg(%s)\n", g.private(columnForeignName(c)))
		}

	ColumnsLoop:
		for _, c := range table.Columns {
			if pkOK && c == pk {
				continue ColumnsLoop
			}
			// The constraint columns are bound as $1..$n and used in the WHERE
			// clause, so they must not also become SET expressions. A bare
			// `continue` here would only advance this inner loop; the labelled
			// form skips the column.
			for _, uc := range u.Columns {
				if c == uc {
					continue ColumnsLoop
				}
			}

			field := g.propertyName(c.Name)
			isTimestamp := c.Type == pqt.TypeTimestamp() || c.Type == pqt.TypeTimestampTZ()

			// Timestamp columns with an update default fall back to it when
			// the patch field is nil.
			fallback, hasFallback := c.DefaultOn(pqt.EventUpdate)
			hasFallback = hasFallback && isTimestamp

			// guarded decides whether the SET expression sits inside a nil check.
			var guarded bool
			if _, hasDefault := c.DefaultOn(pqt.EventInsert, pqt.EventUpdate); hasDefault {
				guarded = isTimestamp
			} else {
				guarded = g.canBeNil(c, modeOptional)
			}
			// The fallback is emitted as an "else" branch. That is only valid
			// Go when the nil check was opened, so force the guard in that case.
			guarded = guarded || hasFallback

			if guarded {
				fmt.Fprintf(w, "if patch.%s != nil {\n", field)
			}

			fmt.Fprint(w, "update.AddExpr(")
			g.writeTableNameColumnNameTo(w, c.Table.Name, c.Name)
			fmt.Fprintf(w, ", pqcomp.Equal, patch.%s)\n", field)

			if hasFallback {
				fmt.Fprint(w, `} else {`)
				fmt.Fprint(w, "update.AddExpr(")
				g.writeTableNameColumnNameTo(w, c.Table.Name, c.Name)
				fmt.Fprintf(w, `, pqcomp.Equal, "%s")`, fallback)
			}

			if guarded {
				fmt.Fprint(w, "\n}\n")
			}
		}

		fmt.Fprintf(w, `
	if update.Len() == 0 {
		return nil, errors.New("%s update failure, nothing to update")
	}`, entityName)

		fmt.Fprintf(w, `
	query := "UPDATE %s SET "
	for update.Next() {
		if !update.First() {
			query += ", "
		}

		query += update.Key() + " " + update.Oper() + " " + update.PlaceHolder()
	}
`, table.FullName())
		// WHERE clause: one "<column> = $i" term per constraint column,
		// matching the order of the AddArg calls above.
		fmt.Fprint(w, `query += " WHERE `)
		for i, c := range u.Columns {
			if i != 0 {
				fmt.Fprint(w, " AND ")
			}
			fmt.Fprintf(w, "%s = $%d", c.Name, i+1)
		}
		fmt.Fprintf(w, ` RETURNING " + strings.Join(r.columns, ", ")
	if r.dbg {
		if err := r.log.Log("msg", query, "function", "%s"); err != nil {
			return nil, err
		}
	}
	var e %sEntity
	err := r.db.QueryRow(query, update.Args()...).Scan(
	`, methodName, entityName)
		// Scan targets follow the order of table.Columns.
		for _, c := range table.Columns {
			fmt.Fprintf(w, "&e.%s,\n", g.propertyName(c.Name))
		}
		fmt.Fprint(w, `)
if err != nil {
	return nil, err
}


return &e, nil
}
`)
	}
}
Ejemplo n.º 6
0
// TestGenerator_Generate compares the SQL produced by pqtsql's generator with
// golden output for two tables:
//   - a temporary table inside a custom schema;
//   - an IF NOT EXISTS table with primary key, unique, foreign key and check
//     constraints.
func TestGenerator_Generate(t *testing.T) {
	temporaryUser := func() *pqt.Table {
		return pqt.NewTable("user", pqt.WithTemporary()).
			SetSchema(pqt.NewSchema("schema")).
			AddColumn(&pqt.Column{Name: "username", Type: pqt.TypeText(), NotNull: true}).
			AddColumn(&pqt.Column{Name: "password", Type: pqt.TypeText()}).
			AddColumn(&pqt.Column{Name: "created_at", Type: pqt.TypeTimestampTZ()})
	}

	constrained := func() *pqt.Table {
		relatedID := pqt.Column{Name: "id", Type: pqt.TypeSerial()}
		startAt := &pqt.Column{Name: "start_at", Type: pqt.TypeTimestampTZ(), NotNull: true}
		endAt := &pqt.Column{Name: "end_at", Type: pqt.TypeTimestampTZ(), NotNull: true}

		// Attach relatedID to its own table so the foreign key can name it.
		_ = pqt.NewTable("related_table").
			AddColumn(&relatedID)

		return pqt.NewTable("table_name", pqt.WithTableIfNotExists()).
			AddColumn(&pqt.Column{Name: "id", Type: pqt.TypeSerial(), PrimaryKey: true}).
			AddColumn(&pqt.Column{Name: "rel_id", Type: pqt.TypeInteger(), Reference: &relatedID}).
			AddColumn(&pqt.Column{Name: "name", Type: pqt.TypeText(), Unique: true}).
			AddColumn(&pqt.Column{Name: "enabled", Type: pqt.TypeBool()}).
			AddColumn(&pqt.Column{Name: "price", Type: pqt.TypeDecimal(10, 1)}).
			AddColumn(startAt).
			AddColumn(endAt).
			AddColumn(pqt.NewColumn("created_at", pqt.TypeTimestampTZ(), pqt.WithNotNull(), pqt.WithDefault("NOW()"))).
			AddColumn(&pqt.Column{Name: "created_by", Type: pqt.TypeInteger(), NotNull: true}).
			AddColumn(&pqt.Column{Name: "updated_at", Type: pqt.TypeTimestampTZ()}).
			AddColumn(&pqt.Column{Name: "updated_by", Type: pqt.TypeInteger()}).
			AddColumn(&pqt.Column{Name: "slug", Type: pqt.TypeText(), NotNull: true, Unique: true}).
			AddCheck("(start_at IS NULL AND end_at IS NULL) OR start_at < end_at", startAt, endAt)
	}

	cases := []struct {
		want  string
		table *pqt.Table
	}{
		{
			want: `-- do not modify, generated by pqt

CREATE TEMPORARY TABLE schema.user (
	created_at TIMESTAMPTZ,
	password TEXT,
	username TEXT NOT NULL
);

`,
			table: temporaryUser(),
		},
		{
			want: `-- do not modify, generated by pqt

CREATE TABLE IF NOT EXISTS table_name (
	created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
	created_by INTEGER NOT NULL,
	enabled BOOL,
	end_at TIMESTAMPTZ NOT NULL,
	id SERIAL,
	name TEXT,
	price DECIMAL(10,1),
	rel_id INTEGER,
	slug TEXT NOT NULL,
	start_at TIMESTAMPTZ NOT NULL,
	updated_at TIMESTAMPTZ,
	updated_by INTEGER,

	CONSTRAINT "public.table_name_id_pkey" PRIMARY KEY (id),
	CONSTRAINT "public.table_name_name_key" UNIQUE (name),
	CONSTRAINT "public.table_name_slug_key" UNIQUE (slug),
	CONSTRAINT "public.table_name_rel_id_fkey" FOREIGN KEY (rel_id) REFERENCES related_table (id),
	CONSTRAINT "public.table_name_start_at_end_at_check" CHECK ((start_at IS NULL AND end_at IS NULL) OR start_at < end_at)
);

`,
			table: constrained(),
		},
	}

	for i, tc := range cases {
		got, err := pqtsql.NewGenerator().Generate(&pqt.Schema{
			Tables: []*pqt.Table{tc.table},
		})
		if err != nil {
			t.Errorf("unexpected error for schema #%d : %s", i, err.Error())
			continue
		}

		if string(got) != tc.want {
			t.Errorf("wrong query, expected:\n'%s'\nbut got:\n'%s'", tc.want, got)
		}
	}
}