func paginate(next chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { // just a stub.. some ideas are to look at URL query params for something like // the page number, or the limit, and send a query cursor down the chain next.ServeHTTPC(ctx, w, r) }) }
// CloseNotify is a middleware that cancels ctx when the underlying // connection has gone away. It can be used to cancel long operations // on the server when the client disconnects before the response is ready. func CloseNotify(next chi.Handler) chi.Handler { fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { cn, ok := w.(http.CloseNotifier) if !ok { panic("middleware.CloseNotify expects http.ResponseWriter to implement http.CloseNotifier interface") } ctx, cancel := context.WithCancel(ctx) defer cancel() go func() { select { case <-ctx.Done(): return case <-cn.CloseNotify(): w.WriteHeader(StatusClientClosedRequest) cancel() return } }() next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(fn) }
func accountCtx(h chi.Handler) chi.Handler { handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { ctx = context.WithValue(ctx, "account", "account 123") h.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(handler) }
func sup2(next chi.Handler) chi.Handler { hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { ctx = context.WithValue(ctx, "sup2", "sup2") next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(hfn) }
// RequestID is a middleware that injects a request ID into the context of each // request. A request ID is a string of the form "host.example.com/random-0001", // where "random" is a base62 random string that uniquely identifies this go // process, and where the last number is an atomically incremented request // counter. func RequestID(next chi.Handler) chi.Handler { fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { myid := atomic.AddUint64(&reqid, 1) ctx = context.WithValue(ctx, RequestIDKey, fmt.Sprintf("%s-%06d", prefix, myid)) next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(fn) }
// Set the parent context in the middleware chain to something else. Useful // in the instance of having a global server context to signal all requests. func ParentContext(parent context.Context) func(next chi.Handler) chi.Handler { return func(next chi.Handler) chi.Handler { fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { pctx := context.WithValue(parent, chi.URLParamsCtxKey, chi.URLParams(ctx)) next.ServeHTTPC(pctx, w, r) } return chi.HandlerFunc(fn) } }
func GetLongID(next chi.Handler) chi.Handler { fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { ctx = context.WithValue(ctx, "id", strings.TrimPrefix(r.RequestURI, "/jobs/")) next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(fn) }
func (ja *JwtAuth) Handle(paramAliases ...string) func(chi.Handler) chi.Handler { return func(next chi.Handler) chi.Handler { hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { var tokenStr string var err error // Get token from query params tokenStr = r.URL.Query().Get("jwt") // Get token from other query param aliases if tokenStr == "" && paramAliases != nil && len(paramAliases) > 0 { for _, p := range paramAliases { tokenStr = r.URL.Query().Get(p) if tokenStr != "" { break } } } // Get token from authorization header if tokenStr == "" { bearer := r.Header.Get("Authorization") if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { tokenStr = bearer[7:] } } // Get token from cookie if tokenStr == "" { cookie, err := r.Cookie("jwt") if err == nil { tokenStr = cookie.Value } } // Token is required, cya if tokenStr == "" { err = errUnauthorized } // Verify the token token, err := ja.Decode(tokenStr) if err != nil || !token.Valid || token.Method != ja.signer { utils.Respond(w, 401, errUnauthorized) return } ctx = context.WithValue(ctx, "jwt", token.Raw) ctx = context.WithValue(ctx, "jwt.token", token) next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(hfn) } }
func AdminOnly(next chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { isAdmin, ok := ctx.Value("acl.admin").(bool) if !ok || !isAdmin { http.Error(w, http.StatusText(403), 403) return } next.ServeHTTPC(ctx, w, r) }) }
// Airbrake recoverer middleware to capture and report any panics to // airbrake.io. func AirbrakeRecoverer(apiKey string) func(chi.Handler) chi.Handler { airbrake.ApiKey = apiKey return func(next chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { if apiKey != "" { defer airbrake.CapturePanic(r) } next.ServeHTTPC(ctx, w, r) }) } }
func ArticleCtx(next chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { articleID := chi.URLParams(ctx)["articleID"] article, err := dbGetArticle(articleID) if err != nil { http.Error(w, http.StatusText(404), 404) return } ctx = context.WithValue(ctx, "article", article) next.ServeHTTPC(ctx, w, r) }) }
// AccessControl is an example that just prints the ACL route + operation that // is being checked without actually doing any checking func AccessControl(next chi.Handler) chi.Handler { hn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { route, ok := ctx.Value("acl.route").([]string) if !ok { http.Error(w, "undefined acl route", 403) return } // Put ACL code here log.Printf("Checking permission to %s %s", r.Method, strings.Join(route, " -> ")) next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(hn) }
// Timeout is a middleware that cancels ctx after a given timeout. func Timeout(timeout time.Duration) func(next chi.Handler) chi.Handler { return func(next chi.Handler) chi.Handler { fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(ctx, timeout) defer func() { cancel() if ctx.Err() == context.DeadlineExceeded { w.WriteHeader(StatusServerTimeout) } }() next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(fn) } }
// Recoverer is a middleware that recovers from panics, logs the panic (and a // backtrace), and returns a HTTP 500 (Internal Server Error) status if // possible. // // Recoverer prints a request ID if one is provided. func Recoverer(next chi.Handler) chi.Handler { fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { reqID := GetReqID(ctx) printPanic(reqID, err) debug.PrintStack() http.Error(w, http.StatusText(500), 500) } }() next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(fn) }
// Logger is a middleware that logs the start and end of each request, along // with some useful data about what was requested, what the response status was, // and how long it took to return. When standard output is a TTY, Logger will // print in color, otherwise it will print in black and white. // // Logger prints a request ID if one is provided. func Logger(next chi.Handler) chi.Handler { fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { reqID := GetReqID(ctx) prefix := requestPrefix(reqID, r) lw := wrapWriter(w) t1 := time.Now() defer func() { t2 := time.Now() printRequest(prefix, reqID, lw, t2.Sub(t1)) }() next.ServeHTTPC(ctx, lw, r) } return chi.HandlerFunc(fn) }
// Route builds a resource identifier from the middleware chain. This resource // identifier along with the operation (HTTP verb) can be used for determining // access to a resource func Route(part string) func(chi.Handler) chi.Handler { return func(next chi.Handler) chi.Handler { hn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { part = extractACLRoute(part) if part != "" { route, ok := ctx.Value("acl.route").([]string) if ok { route = append(route, part) } else { route = []string{part} } ctx = context.WithValue(ctx, "acl.route", route) } next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(hn) } }
// Authenticator is a default authentication middleware to enforce access following // the Verifier middleware. The Authenticator sends a 401 Unauthorized response for // all unverified tokens and passes the good ones through. It's just fine until you // decide to write something similar and customize your client response. func Authenticator(next chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { if jwtErr, ok := ctx.Value("jwt.err").(error); ok { if jwtErr != nil { http.Error(w, http.StatusText(401), 401) return } } jwtToken, ok := ctx.Value("jwt").(*jwt.Token) if !ok || jwtToken == nil || !jwtToken.Valid { http.Error(w, http.StatusText(401), 401) return } // Token is authenticated, pass it through next.ServeHTTPC(ctx, w, r) }) }
func Router() chi.Router { r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.RealIP) r.Use(middleware.Logger) r.Use(middleware.Recoverer) r.Use(func(h chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { ctx = context.WithValue(ctx, "example", true) h.ServeHTTPC(ctx, w, r) }) }) r.Get("/", apiIndex) return r }
func main() { r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.RealIP) r.Use(middleware.Logger) r.Use(func(h chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { ctx = context.WithValue(ctx, "example", true) h.ServeHTTPC(ctx, w, r) }) }) r.Get("/", apiIndex) r.Mount("/accounts", accountsRouter()) http.ListenAndServe(":3333", r) }
// Logger is a middleware that logs the start and end of each request, along // with some useful data about what was requested, what the response status was, // and how long it took to return. When standard output is a TTY, Logger will // print in color, otherwise it will print in black and white. // // Logger prints a request ID if one is provided. // // Logger has been designed explicitly to be Good Enough for use in small // applications and for people just getting started with Goji. It is expected // that applications will eventually outgrow this middleware and replace it with // a custom request logger, such as one that produces machine-parseable output, // outputs logs to a different service (e.g., syslog), or formats lines like // those printed elsewhere in the application. func Logger(next chi.Handler) chi.Handler { fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { reqID := GetReqID(ctx) printStart(reqID, r) lw := wrapWriter(w) t1 := time.Now() next.ServeHTTPC(ctx, lw, r) if lw.Status() == 0 { lw.WriteHeader(http.StatusOK) } t2 := time.Now() printEnd(reqID, lw, t2.Sub(t1)) } return chi.HandlerFunc(fn) }
//BodyParser loads builder with maxSize and tries to load the message. //if for some reason it can't parse the message, it will return an error. //if successful, it will put the processed data into context with key 'json_body' func BodyParser(builder func() interface{}, maxSize int64) func(chi.Handler) chi.Handler { return func(next chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { to := builder() if err := utils.StreamJSONToStructWithLimit(r.Body, to, maxSize); err != nil { utils.Respond(w, 422, err) return } //check for required fields if err := utils.JSONValidation(to); err != nil { utils.Respond(w, 400, err) return } ctx = context.WithValue(ctx, constants.CtxKeyParsedBody, to) next.ServeHTTPC(ctx, w, r) }) } }
func TestMore(t *testing.T) { r := chi.NewRouter() // Protected routes r.Group(func(r chi.Router) { r.Use(TokenAuth.Verifier) authenticator := func(next chi.Handler) chi.Handler { return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { if jwtErr, ok := ctx.Value("jwt.err").(error); ok { switch jwtErr { default: http.Error(w, http.StatusText(401), 401) return case jwtauth.ErrExpired: http.Error(w, "expired", 401) return case jwtauth.ErrUnauthorized: http.Error(w, http.StatusText(401), 401) return case nil: // no error } } jwtToken, ok := ctx.Value("jwt").(*jwt.Token) if !ok || jwtToken == nil || !jwtToken.Valid { http.Error(w, http.StatusText(401), 401) return } // Token is authenticated, pass it through next.ServeHTTPC(ctx, w, r) }) } r.Use(authenticator) r.Get("/admin", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { w.Write([]byte("protected")) }) }) // Public routes r.Group(func(r chi.Router) { r.Get("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { w.Write([]byte("welcome")) }) }) ts := httptest.NewServer(r) defer ts.Close() // sending unauthorized requests if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 && resp != "Unauthorized\n" { t.Fatalf(resp) } h := http.Header{} h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{})) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "Unauthorized\n" { t.Fatalf(resp) } h.Set("Authorization", "BEARER asdf") if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "Unauthorized\n" { t.Fatalf(resp) } h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000)) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "expired\n" { t.Fatalf(resp) } // sending authorized requests if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 200 && resp != "welcome" { t.Fatalf(resp) } h = newAuthHeader((jwtauth.Claims{}).SetExpiryIn(5 * time.Minute)) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 && resp != "protected" { t.Fatalf(resp) } }
func (ja *JwtAuth) Verify(paramAliases ...string) func(chi.Handler) chi.Handler { return func(next chi.Handler) chi.Handler { hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { var tokenStr string var err error // Get token from query params tokenStr = r.URL.Query().Get("jwt") // Get token from other query param aliases if tokenStr == "" && paramAliases != nil && len(paramAliases) > 0 { for _, p := range paramAliases { tokenStr = r.URL.Query().Get(p) if tokenStr != "" { break } } } // Get token from authorization header if tokenStr == "" { bearer := r.Header.Get("Authorization") if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { tokenStr = bearer[7:] } } // Get token from cookie if tokenStr == "" { cookie, err := r.Cookie("jwt") if err == nil { tokenStr = cookie.Value } } // Token is required, cya if tokenStr == "" { err = ErrUnauthorized } // Verify the token token, err := ja.Decode(tokenStr) if err != nil || !token.Valid || token.Method != ja.signer { switch err.Error() { case "token is expired": err = ErrExpired } ctx = ja.SetContext(ctx, token, err) next.ServeHTTPC(ctx, w, r) return } // Check expiry via "exp" claim if ja.IsExpired(token) { err = ErrExpired ctx = ja.SetContext(ctx, token, err) next.ServeHTTPC(ctx, w, r) return } // Valid! pass it down the context to an authenticator middleware ctx = ja.SetContext(ctx, token, err) next.ServeHTTPC(ctx, w, r) } return chi.HandlerFunc(hfn) } }