/
sqlbuilder.go
196 lines (164 loc) · 3.91 KB
/
sqlbuilder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
// Package sqlbuilder generates SQL queries programmatically through a chainable interface.
package sqlbuilder
import (
"database/sql"
"encoding/json"
"fmt"
sqlx "github.com/jmoiron/sqlx"
"github.com/visionmedia/go-debug"
"os"
"strings"
)
// Debug output for generated SQL.
var (
	// debugEnabled is true when the DEBUG environment variable contains the
	// substring "sql". It is evaluated once, at package initialization.
	debugEnabled = strings.Contains(os.Getenv("DEBUG"), "sql")

	// Debug is the go-debug logger named "sql"; the Exec*/Get* helpers use it
	// to print each generated statement together with its JSON-encoded
	// arguments.
	Debug = debug.Debug("sql")
)
// Logical connectives; presumably used to join constraints (AND / OR) by the
// constraint builder elsewhere in the package — verify against its callers.
const (
	gate_and = iota
	gate_or
)
// Join kinds; presumably consumed when rendering table joins elsewhere in the
// package (join_none meaning "no join") — TODO confirm against the table code.
const (
	join_none = iota
	join_inner
	join_left
	join_right
	join_outer
)
// Statement kinds stored in Query.action; GetSQL switches on these to pick
// the matching SQL generator.
const (
	action_select = iota
	action_insert
	action_update
	action_delete
	action_union
	action_count
)
// PlaceholderFunction renders the bind-parameter placeholder for the
// argument at the given 1-based position (for example "?" or "$3").
type PlaceholderFunction func(int) string
// Query accumulates the parts of a single SQL statement through chained
// method calls. The kind of statement is chosen by action (one of the
// action_* constants) and rendered by GetSQL / GetFullSQL.
type Query struct {
	action int // statement kind rendered by GetSQL (action_select, action_insert, ...)

	fields []string // presumably the selected/affected column list — verify
	tables []*table // presumably the target tables and their joins — TODO confirm

	// cache is allocated by newQuery, but GetFullSQL renders with its own
	// fresh VarCache. NOTE(review): confirm whether anything still reads it.
	cache *VarCache

	having *constraint // presumably the HAVING clause
	where  *constraint // presumably the WHERE clause

	groups   groups
	ordering ordering

	data map[string]interface{} // presumably column -> value for INSERT/UPDATE — verify

	limit, offset int // set by Limit and Offset

	placeholder PlaceholderFunction // set by Placeholder; nil selects "$N" placeholders
	returning   string              // presumably a RETURNING clause — TODO confirm
	unions      []SQLProvider       // presumably the member queries of a UNION — verify
}
// SQLProvider is anything that can render itself as SQL text given a
// VarCache; *Query implements it, which is what lets a query be embedded as
// a subquery. Implementations are presumably expected to register their
// bound values through the cache so placeholders stay numbered consistently.
type SQLProvider interface {
	GetSQL(*VarCache) string
}
// VarCache collects the bound argument values produced while rendering one
// query and hands out the placeholder text for each of them (see add).
type VarCache struct {
// placeholder renders the placeholder for a 1-based argument position;
// when nil, add falls back to the "$N" form.
placeholder PlaceholderFunction
// vars holds the bound values in the order their placeholders were issued.
vars []interface{}
}
// add appends val as the next bound argument and returns the placeholder
// text that refers to it. Positions are 1-based. Without a configured
// placeholder function the "$N" form (as used by PostgreSQL drivers) is
// produced.
func (v *VarCache) add(val interface{}) string {
	v.vars = append(v.vars, val)
	pos := len(v.vars)
	if v.placeholder == nil {
		return fmt.Sprintf("$%d", pos)
	}
	return v.placeholder(pos)
}
// group names a single field together with a sort direction; presumably an
// entry of a GROUP BY / ORDER BY list — verify against the groups/ordering types.
type group struct {
// field is the column or expression being grouped/ordered by.
field string
// descending presumably selects DESC instead of the default ascending order.
descending bool
}
// Limit records the row limit for the query and returns q for chaining.
// How (and whether) the value is rendered is up to the statement builder,
// presumably as a LIMIT clause on SELECT — TODO confirm.
func (q *Query) Limit(limit int) *Query {
	q.limit = limit
	return q
}
// Offset records how many rows to skip and returns q for chaining.
// Rendering is left to the statement builder, presumably as an OFFSET
// clause on SELECT — TODO confirm.
func (q *Query) Offset(offset int) *Query {
	q.offset = offset
	return q
}
// Placeholder sets the function used to render bind-parameter placeholders
// (the "?" in INSERT INTO t (a, b) VALUES (?, ?)). The function receives the
// 1-based position of each argument. When no function is set, placeholders
// are rendered as $1, $2, .... It returns q for chaining.
func (q *Query) Placeholder(placeholder PlaceholderFunction) *Query {
	q.placeholder = placeholder
	return q
}
// newQuery returns an empty Query (action_select, the zero action) with its
// cache field pre-allocated.
func newQuery() *Query {
	return &Query{cache: new(VarCache)}
}
// GetFullSQL renders the query and returns the SQL text together with the
// bound argument values, in placeholder order, ready to pass to
// sql.DB.Exec or sql.DB.Query. A fresh VarCache is used on every call, so
// placeholders are numbered from 1 each time.
func (q *Query) GetFullSQL() (string, []interface{}) {
	cache := &VarCache{placeholder: q.placeholder}
	// Render in a separate statement: GetSQL appends to cache.vars, and the Go
	// spec leaves unspecified whether a non-call operand such as cache.vars is
	// read before or after a call in the same return list.
	query := q.GetSQL(cache)
	return query, cache.vars
}
// GetSQL renders the query as SQL, passing cache to the generator that
// matches q.action. It satisfies SQLProvider, which is what allows a *Query
// to be used as a subquery. An unrecognized action renders as "".
func (q *Query) GetSQL(cache *VarCache) string {
	switch q.action {
	case action_select:
		return q.getSelectSQL(cache)
	case action_insert:
		return q.getInsertSQL(cache)
	case action_update:
		return q.getUpdateSQL(cache)
	case action_delete:
		return q.getDeleteSQL(cache)
	case action_union:
		return q.getUnionSQL(cache)
	case action_count:
		return q.getCountSQL(cache)
	default:
		return ""
	}
}
// GetCount renders the query in its count form (action_count), runs it on
// db and returns the single scanned value. q.action is overwritten for the
// duration of the call and restored afterwards, so the same *Query must not
// be used concurrently while GetCount runs.
func (q *Query) GetCount(db *sqlx.DB) (int, error) {
	// The deferred call captures the current action as its argument now.
	defer func(saved int) { q.action = saved }(q.action)
	q.action = action_count

	var n int
	if err := q.GetValue(db, &n); err != nil {
		return 0, err
	}
	return n, nil
}
// debugQuery writes a generated statement and its bound arguments to the
// "sql" debug logger when debugEnabled is set. Logging is best-effort: if
// the arguments cannot be JSON-encoded, the encoding error is logged in
// their place rather than being silently dropped.
func debugQuery(query string, vars []interface{}) {
	if !debugEnabled {
		return
	}
	marshaled, err := json.Marshal(vars)
	if err != nil {
		Debug("%s, <unencodable args: %v>", query, err)
		return
	}
	Debug("%s, %s", query, string(marshaled))
}

// ExecWrite runs a write statement (INSERT/UPDATE/DELETE) built from q on
// db and returns the driver's sql.Result.
func (q *Query) ExecWrite(db *sqlx.DB) (sql.Result, error) {
	// Named query rather than sql so the database/sql package is not shadowed.
	query, vars := q.GetFullSQL()
	debugQuery(query, vars)
	return db.Exec(query, vars...)
}

// ExecRead runs a read statement (SELECT) built from q on db and returns the
// resulting rows. The caller is responsible for closing them.
func (q *Query) ExecRead(db *sqlx.DB) (*sqlx.Rows, error) {
	query, vars := q.GetFullSQL()
	debugQuery(query, vars)
	return db.Queryx(query, vars...)
}

// GetResult runs q on db and scans a single row into result using sqlx's
// DB.Get (which reports an empty result set as an error).
func (q *Query) GetResult(db *sqlx.DB, result interface{}) error {
	query, vars := q.GetFullSQL()
	debugQuery(query, vars)
	return db.Get(result, query, vars...)
}
// GetValue runs q as a read query on db and scans the first row into val,
// which must be a single destination (so the query should yield exactly one
// column). If the query returns no rows, sql.ErrNoRows is returned, matching
// sqlx's DB.Get; iteration errors reported by the driver take precedence.
func (q *Query) GetValue(db *sqlx.DB, val interface{}) error {
	rows, err := q.ExecRead(db)
	if err != nil {
		return err
	}
	defer rows.Close()

	// Next's result must be checked: scanning after it returns false yields
	// an opaque "rows are closed" style error instead of a clear no-rows signal.
	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return err
		}
		return sql.ErrNoRows
	}
	return rows.Scan(val)
}