This repository has been archived by the owner on Apr 5, 2021. It is now read-only.
/
login.go
332 lines (278 loc) · 9.14 KB
/
login.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
package oauthmw
import (
"crypto/md5"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"goji.io"
"github.com/knq/sessionmw"
"golang.org/x/net/context"
"golang.org/x/oauth2"
)
// oauth2Context is the context passed to oauth2.Config.Exchange when a
// returned authorization code is redeemed by the login middleware.
//
// It replaces the deprecated oauth2.NoContext. It is kept as a package
// variable rather than using the request context, presumably so it can be
// swapped out (e.g. to supply a custom *http.Client via oauth2.HTTPClient) —
// TODO confirm against the package's tests.
var oauth2Context = context.Background()
// A CheckFn is an optional hook run by the login middleware after a
// successful OAuth2.0 exchange. It is passed the provider name, that
// provider's *oauth2.Config, and the redeemed token. The token has already
// passed token.Valid() by the time the hook is called.
//
// allow reports whether the login is permitted. When allow is false, msg is
// passed to the provider's ErrorFn (with status 500) and the login is
// rejected. When allow is true, msg is currently ignored by the return
// handler.
//
// NOTE(review): this type was previously documented as returning "a redirect
// URL (if any)", but the return handler never redirects to the returned
// string — confirm the intended contract.
type CheckFn func(provName string, conf *oauth2.Config, token *oauth2.Token) (msg string, allow bool)
// login holds the state of the OAuth2.0 login middleware.
//
// Its ServeHTTPC method serves the provider redirect and return pages. When
// required is set, it shows the protected (login) page to requests lacking a
// usable token. Otherwise the request is handed to the wrapped handler h.
type login struct {
// provider configuration: OAuth2 configs, paths/prefixes, session key, and
// the ErrorFn/TemplateFn callbacks used by the handlers
provider *Provider
// whether or not a valid login is required; when false, requests without a
// usable token are still passed on to h
required bool
// check function run after a successful exchange; may be nil (skipped)
checkFn CheckFn
// the protected handler, called once login requirements are satisfied
h goji.Handler
}
// sessionStore returns the oauthmw Store kept in the session under
// l.provider.SessionKey.
//
// A new empty Store is created and saved to the session in two cases: the
// session holds no Store, or it holds a value of some other type (which is
// logged as corruption). An existing Store whose States map is nil is given an
// empty map and saved back. Callers can therefore always write to States.
//
// The Store is held in the session by value, so the returned pointer refers
// to a local copy. Changes made through its States map or *Token are shared
// with the session's copy. Assignments to other fields (e.g. Provider) only
// persist if the Store is written back with sessionmw.Set.
func (l login) sessionStore(ctxt context.Context) *Store {
	obj, found := sessionmw.Get(ctxt, l.provider.SessionKey)
	store, isStore := obj.(Store)

	if found && !isStore {
		// this shouldn't ever happen ...
		log.Println("CORRUPTED/MALFORMED SESSION STORAGE. OVERWRITING")
	}

	// missing or malformed: start over with an empty store
	if !found || !isStore {
		store = Store{
			Provider: "",
			Token:    &oauth2.Token{},
			States:   make(map[string]StoreState),
		}
		sessionmw.Set(ctxt, l.provider.SessionKey, store)
		return &store
	}

	// writing to a nil map panics, and callers (addState, doReturn) write to
	// States unconditionally, so repair it and persist the repaired store
	if store.States == nil {
		store.States = make(map[string]StoreState)
		sessionmw.Set(ctxt, l.provider.SessionKey, store)
	}

	return &store
}
// addState records state in the session store as issued for provider provName.
//
// The entry is keyed by the md5 hex digest of state, which is how doReturn
// looks it up later. It expires StateLifetime from now and starts out
// unredeemed.
func (l login) addState(ctxt context.Context, provName, state string) {
	store := l.sessionStore(ctxt)

	digest := md5.Sum([]byte(state))
	entry := StoreState{
		Provider:   provName,
		Expiration: time.Now().Add(l.provider.StateLifetime),
		Redeemed:   false,
	}

	store.States[fmt.Sprintf("%x", digest)] = entry
}
// getSafeSessionID returns the md5 hex digest of the current session's id.
//
// This hashed value is what gets encoded into OAuth2 state (see
// doProtectedPage). It is also what doRedirect/doReturn compare against the
// decoded state's "sid".
func (l login) getSafeSessionID(ctxt context.Context) string {
	digest := md5.Sum([]byte(sessionmw.ID(ctxt)))
	return fmt.Sprintf("%x", digest)
}
// getToken returns the OAuth2 token saved in the session store.
//
// The results are (token, expired, ok):
//   - ok is false, and token nil, when the store holds no token;
//   - expired is true only when the token has a non-zero Expiry that is
//     already in the past.
func (l login) getToken(ctxt context.Context) (*oauth2.Token, bool, bool) {
	tok := l.sessionStore(ctxt).Token

	switch {
	case tok == nil:
		return nil, false, false
	case tok.Expiry.IsZero():
		// no expiry recorded: never considered expired
		return tok, false, true
	}

	return tok, time.Now().After(tok.Expiry), true
}
// doRedirect handles oauthmw redirects.
//
// Will validate passed state, and adds it to the session store.
func (l login) doRedirect(ctxt context.Context, res http.ResponseWriter, req *http.Request, provName string, stateDec map[string]string) {
prov, ok := l.provider.Configs[provName]
if !ok {
l.provider.ErrorFn(500, "invalid provider", res, req)
return
}
// verify state belongs to this session
if l.getSafeSessionID(ctxt) != stateDec["sid"] {
l.provider.ErrorFn(500, "forged sid in redirect", res, req)
return
}
// verify it matches provider
if provName != stateDec["provider"] {
l.provider.ErrorFn(500, "forged provider in redirect", res, req)
return
}
// store state to session
passedState := req.URL.Query().Get("state")
l.addState(ctxt, provName, passedState)
http.Redirect(res, req, prov.AuthCodeURL(passedState), 302)
}
// doReturn handles the provider's return (callback) page.
//
// It runs these checks in order:
//   - the decoded state carries this session's hashed id ("sid");
//   - the raw "state" query value was recorded by addState (looked up by its
//     md5 hex digest);
//   - the state has not expired (a zero Expiration never expires) and has not
//     already been redeemed;
//   - it was issued for the provider named in the decoded state, and that
//     provider still has a config;
//   - the decoded state has a "resource".
//
// The "code" query value is then redeemed with the provider config's
// Exchange. The resulting token must be Valid and, when a CheckFn is set,
// accepted by it. Every failure is reported through ErrorFn.
//
// On success:
//   - the token is saved in the session store, with its Expiry capped at
//     now+TokenLifetime when TokenLifetime > 0;
//   - the provider is recorded and the state is flagged as redeemed;
//   - the store is written back to the session;
//   - the client is redirected to the resource.
func (l login) doReturn(ctxt context.Context, res http.ResponseWriter, req *http.Request, stateDec map[string]string) {
	// verify state belongs to this session
	if l.getSafeSessionID(ctxt) != stateDec["sid"] {
		l.provider.ErrorFn(http.StatusInternalServerError, "forged sid in return", res, req)
		return
	}

	// grab state from session (addState keys it by the md5 hex digest)
	passedState := req.URL.Query().Get("state")
	stateKey := fmt.Sprintf("%x", md5.Sum([]byte(passedState)))
	sess := l.sessionStore(ctxt)
	storedState, ok := sess.States[stateKey]
	if !ok {
		l.provider.ErrorFn(http.StatusInternalServerError, "state not found in session", res, req)
		return
	}

	// verify that stored state has not expired yet
	if !storedState.Expiration.IsZero() && time.Now().After(storedState.Expiration) {
		l.provider.ErrorFn(http.StatusInternalServerError, "request expired. try again", res, req)
		return
	}

	// verify not already redeemed
	if storedState.Redeemed {
		l.provider.ErrorFn(http.StatusInternalServerError, "already redeemed. try again", res, req)
		return
	}

	// verify that stored provider is same as passed provider
	if stateDec["provider"] != storedState.Provider {
		l.provider.ErrorFn(http.StatusInternalServerError, "invalid provider", res, req)
		return
	}

	// grab redirect path
	resource, ok := stateDec["resource"]
	if !ok {
		l.provider.ErrorFn(http.StatusInternalServerError, "invalid resource", res, req)
		return
	}

	// the provider's config may be gone; Exchange on a nil *oauth2.Config
	// would panic
	conf := l.provider.Configs[storedState.Provider]
	if conf == nil {
		l.provider.ErrorFn(http.StatusInternalServerError, "invalid provider", res, req)
		return
	}

	// use code for oauth2 exchange
	token, err := conf.Exchange(oauth2Context, req.URL.Query().Get("code"))
	if err != nil {
		l.provider.ErrorFn(http.StatusInternalServerError, fmt.Sprintf("could not do exchange with %s", storedState.Provider), res, req)
		return
	}

	// verify token is valid
	if !token.Valid() {
		l.provider.ErrorFn(http.StatusForbidden, http.StatusText(http.StatusForbidden), res, req)
		return
	}

	// pass to checkFn; its message is only used when the login is rejected
	if l.checkFn != nil {
		if msg, allowed := l.checkFn(storedState.Provider, conf, token); !allowed {
			l.provider.ErrorFn(http.StatusInternalServerError, msg, res, req)
			return
		}
	}

	// save oauth2 token in session, writing through the existing pointer so
	// anything sharing it sees the new token (allocate one if missing)
	if sess.Token == nil {
		sess.Token = &oauth2.Token{}
	}
	*sess.Token = *token

	// cap the saved token's expiry at now+TokenLifetime when a lifetime is
	// configured; this must come after the copy above, or the copy would
	// overwrite the adjusted Expiry
	if l.provider.TokenLifetime > 0 {
		limit := time.Now().Add(l.provider.TokenLifetime)
		if sess.Token.Expiry.IsZero() || sess.Token.Expiry.After(limit) {
			sess.Token.Expiry = limit
		}
	}
	sess.Provider = storedState.Provider

	// flag redeemed status
	storedState.Redeemed = true
	sess.States[stateKey] = storedState

	// sessionStore returns a pointer to a copy of the session's (by-value)
	// Store, so write it back; otherwise the Provider assignment is lost
	sessionmw.Set(ctxt, l.provider.SessionKey, *sess)

	// redirect -- use 301 because token cannot be redeemed twice
	http.Redirect(res, req, resource, http.StatusMovedPermanently)
}
// redirectPath builds the URL of the redirect page for provider provName.
//
// The URL carries state as the query-escaped "state" parameter. Provider.Path
// is used as a prefix unless it is the root "/".
func (l login) redirectPath(provName, state string) string {
	// a root mount point contributes no prefix
	prefix := l.provider.Path
	if prefix == "/" {
		prefix = ""
	}

	return prefix + l.provider.RedirectPrefix + provName + "?state=" + url.QueryEscape(state)
}
// doProtectedPage responds to a request that needs a login.
//
// Each provider's state is built with EncodeState from three inputs: the
// hashed session id, the provider name, and the request path (the page to
// return to after login).
//
// With exactly one provider in ConfigsOrder, the client is redirected
// straight to that provider's redirect page. Otherwise TemplateFn receives a
// map of provider name -> redirect-page href so the user can pick a login
// mechanism.
func (l login) doProtectedPage(ctxt context.Context, res http.ResponseWriter, req *http.Request) {
	sid := l.getSafeSessionID(ctxt)
	resource := req.URL.Path
	order := l.provider.ConfigsOrder

	// a single provider needs no chooser page
	if len(order) == 1 {
		only := order[0]
		state, err := l.provider.EncodeState(sid, only, resource)
		if err != nil {
			l.provider.ErrorFn(http.StatusInternalServerError, fmt.Sprintf("could not encode state for %s", only), res, req)
			return
		}
		http.Redirect(res, req, l.redirectPath(only, state), http.StatusFound)
		return
	}

	links := make(map[string]interface{}, len(order))
	for _, name := range order {
		state, err := l.provider.EncodeState(sid, name, resource)
		if err != nil {
			l.provider.ErrorFn(http.StatusInternalServerError, fmt.Sprintf("could not encode state for %s (2)", name), res, req)
			return
		}
		links[name] = l.redirectPath(name, state)
	}

	l.provider.TemplateFn(res, req, links)
}
// ServeHTTPC handles oauth2 logic for the login middleware.
//
// For every request it:
//  1. prunes expired states from the session store, when CleanupStates is
//     set and at least MaxStates states are stored;
//  2. inspects the final path segment (including its leading "/"). If it
//     starts with PagePrefix and the "state" query value decodes, the request
//     is dispatched to doRedirect when the segment starts with RedirectPrefix
//     (the remainder names the provider), or to doReturn when the segment
//     equals ReturnName;
//  3. otherwise, serves the protected page when login is required and the
//     session has no usable token (missing, expired, or not Valid);
//  4. else passes the request on to the wrapped handler.
func (l login) ServeHTTPC(ctxt context.Context, res http.ResponseWriter, req *http.Request) {
	// the store is fetched (and created if absent) on every request
	if store := l.sessionStore(ctxt); l.provider.CleanupStates && len(store.States) >= l.provider.MaxStates {
		now := time.Now()
		for key, st := range store.States {
			if now.After(st.Expiration) {
				delete(store.States, key)
			}
		}
	}

	urlPath := req.URL.Path
	if idx := strings.LastIndexByte(urlPath, '/'); idx != -1 {
		last := urlPath[idx:]
		if strings.HasPrefix(last, l.provider.PagePrefix) {
			stateDec, err := l.provider.DecodeState(req.URL.Query().Get("state"))
			if err == nil {
				if strings.HasPrefix(last, l.provider.RedirectPrefix) {
					l.doRedirect(ctxt, res, req, strings.TrimPrefix(last, l.provider.RedirectPrefix), stateDec)
					return
				}
				if last == l.provider.ReturnName {
					l.doReturn(ctxt, res, req, stateDec)
					return
				}
			}
		}
	}

	token, expired, ok := l.getToken(ctxt)
	if l.required {
		usable := ok && !expired && token != nil && token.Valid()
		if !usable {
			l.doProtectedPage(ctxt, res, req)
			return
		}
	}

	l.h.ServeHTTPC(ctxt, res, req)
}