Ejemplo n.º 1
0
Archivo: main.go Proyecto: raj347/chi
// paginate is a placeholder pagination middleware. It currently forwards every
// request to next unchanged. A real implementation might read page/limit
// query parameters and pass a query cursor down the chain via ctx.
func paginate(next chi.Handler) chi.Handler {
	forward := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		// TODO: derive page/limit from r.URL.Query() and attach a cursor to ctx.
		next.ServeHTTPC(ctx, w, r)
	}
	return chi.HandlerFunc(forward)
}
Ejemplo n.º 2
0
// CloseNotify is a middleware that cancels ctx when the underlying
// connection has gone away. It can be used to cancel long operations
// on the server when the client disconnects before the response is ready.
//
// It panics if the http.ResponseWriter does not implement http.CloseNotifier.
//
// If the client disconnects while next is running, the response status is set
// to StatusClientClosedRequest once next returns. The status is written from
// the handler's own goroutine because http.ResponseWriter is not safe for
// concurrent use and must not be used after ServeHTTP has returned.
func CloseNotify(next chi.Handler) chi.Handler {
	fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		cn, ok := w.(http.CloseNotifier)
		if !ok {
			panic("middleware.CloseNotify expects http.ResponseWriter to implement http.CloseNotifier interface")
		}

		ctx, cancel := context.WithCancel(ctx)
		defer cancel()

		// clientGone is closed by the watcher goroutine when the client disconnects.
		clientGone := make(chan struct{})

		// The watcher only signals and cancels; it never touches w. It exits when
		// the client goes away or when ctx is cancelled (at the latest by the
		// deferred cancel above), so it cannot outlive the request indefinitely.
		go func() {
			select {
			case <-ctx.Done():
			case <-cn.CloseNotify():
				close(clientGone)
				cancel()
			}
		}()

		next.ServeHTTPC(ctx, w, r)

		// Record the client-closed status from this goroutine, after next is done.
		select {
		case <-clientGone:
			w.WriteHeader(StatusClientClosedRequest)
		default:
		}
	}

	return chi.HandlerFunc(fn)
}
Ejemplo n.º 3
0
Archivo: main.go Proyecto: vladdy/chi
// accountCtx is example middleware. It stores the fixed string "account 123"
// in the request context under the key "account" and then calls h.
func accountCtx(h chi.Handler) chi.Handler {
	return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		withAccount := context.WithValue(ctx, "account", "account 123")
		h.ServeHTTPC(withAccount, w, r)
	})
}
Ejemplo n.º 4
0
Archivo: main.go Proyecto: vladdy/chi
// sup2 is example middleware. It stores the string "sup2" in the request
// context under the key "sup2" and then calls next.
func sup2(next chi.Handler) chi.Handler {
	return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		tagged := context.WithValue(ctx, "sup2", "sup2")
		next.ServeHTTPC(tagged, w, r)
	})
}
Ejemplo n.º 5
0
// RequestID is a middleware that injects a request ID into the context of each
// request, under RequestIDKey. A request ID has the form
// "host.example.com/random-000001". Here "random" is a base62 random string
// that uniquely identifies this go process. The trailing number is an
// atomically incremented request counter, zero-padded to at least six digits.
func RequestID(next chi.Handler) chi.Handler {
	return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		// Atomic increment: concurrent requests each receive a distinct sequence number.
		seq := atomic.AddUint64(&reqid, 1)
		id := fmt.Sprintf("%s-%06d", prefix, seq)
		next.ServeHTTPC(context.WithValue(ctx, RequestIDKey, id), w, r)
	})
}
Ejemplo n.º 6
0
// ParentContext returns a middleware that replaces the request context with
// parent before calling the next handler. This is useful when a global server
// context should be able to signal every in-flight request.
//
// Only the chi URL parameters are carried over from the incoming context, under
// chi.URLParamsCtxKey. Every other value stored in the incoming context is not
// visible downstream, and neither is its deadline or cancellation.
func ParentContext(parent context.Context) func(next chi.Handler) chi.Handler {
	wrap := func(next chi.Handler) chi.Handler {
		return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			params := chi.URLParams(ctx)
			next.ServeHTTPC(context.WithValue(parent, chi.URLParamsCtxKey, params), w, r)
		})
	}
	return wrap
}
Ejemplo n.º 7
0
// GetLongID is a middleware that stores the job ID from the request path in the
// context under the key "id". The ID is the escaped request path with a leading
// "/jobs/" removed. Paths without that prefix are stored unchanged.
//
// The path is read from r.URL rather than r.RequestURI. That way a query string
// (e.g. "/jobs/42?verbose=1") does not leak into the ID.
func GetLongID(next chi.Handler) chi.Handler {
	fn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		id := strings.TrimPrefix(r.URL.EscapedPath(), "/jobs/")
		ctx = context.WithValue(ctx, "id", id)

		next.ServeHTTPC(ctx, w, r)
	}

	return chi.HandlerFunc(fn)
}
Ejemplo n.º 8
0
// Handle returns a middleware that requires a valid JWT on every request.
//
// The token is looked up in this order:
//   - the "jwt" query parameter;
//   - each of paramAliases, treated as query parameter names;
//   - an "Authorization: Bearer <token>" header;
//   - a "jwt" cookie.
//
// The request is rejected via utils.Respond with status 401 and
// errUnauthorized in any of these cases:
//   - no token is found;
//   - ja.Decode fails;
//   - the decoded token is not valid;
//   - the token was not signed with ja.signer.
//
// Otherwise the raw token string is stored in the context under "jwt" and the
// decoded token under "jwt.token", and then next is called.
func (ja *JwtAuth) Handle(paramAliases ...string) func(chi.Handler) chi.Handler {
	return func(next chi.Handler) chi.Handler {
		hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			// Parse the query string once; URL.Query re-parses on every call.
			query := r.URL.Query()

			// Get token from query params
			tokenStr := query.Get("jwt")

			// Get token from other query param aliases (first non-empty wins)
			for _, p := range paramAliases {
				if tokenStr != "" {
					break
				}
				tokenStr = query.Get(p)
			}

			// Get token from authorization header. The first six bytes must
			// spell "bearer" in any case; the token starts after the separator
			// at index 6.
			if tokenStr == "" {
				bearer := r.Header.Get("Authorization")
				if len(bearer) > 7 && strings.EqualFold(bearer[:6], "BEARER") {
					tokenStr = bearer[7:]
				}
			}

			// Get token from cookie
			if tokenStr == "" {
				if cookie, err := r.Cookie("jwt"); err == nil {
					tokenStr = cookie.Value
				}
			}

			// Token is required; reject without trying to decode an empty string.
			if tokenStr == "" {
				utils.Respond(w, 401, errUnauthorized)
				return
			}

			// Verify the token
			token, err := ja.Decode(tokenStr)
			if err != nil || !token.Valid || token.Method != ja.signer {
				utils.Respond(w, 401, errUnauthorized)
				return
			}

			ctx = context.WithValue(ctx, "jwt", token.Raw)
			ctx = context.WithValue(ctx, "jwt.token", token)

			next.ServeHTTPC(ctx, w, r)
		}
		return chi.HandlerFunc(hfn)
	}
}
Ejemplo n.º 9
0
Archivo: main.go Proyecto: raj347/chi
// AdminOnly is middleware that lets a request through only when the context
// value "acl.admin" is the bool true. Anything else (missing, not a bool, or
// false) is answered with 403 Forbidden, and next is not called.
func AdminOnly(next chi.Handler) chi.Handler {
	guard := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		// A failed type assertion yields false, which is treated like "not admin".
		if isAdmin, _ := ctx.Value("acl.admin").(bool); isAdmin {
			next.ServeHTTPC(ctx, w, r)
			return
		}
		http.Error(w, http.StatusText(403), 403)
	}
	return chi.HandlerFunc(guard)
}
Ejemplo n.º 10
0
// AirbrakeRecoverer returns a recoverer middleware that captures panics and
// reports them to airbrake.io via airbrake.CapturePanic.
//
// Calling AirbrakeRecoverer sets the package-level airbrake.ApiKey to apiKey.
// When apiKey is empty, the returned middleware does not defer
// airbrake.CapturePanic; it simply calls the next handler.
func AirbrakeRecoverer(apiKey string) func(chi.Handler) chi.Handler {
	airbrake.ApiKey = apiKey
	enabled := apiKey != ""

	return func(next chi.Handler) chi.Handler {
		return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			if enabled {
				defer airbrake.CapturePanic(r)
			}
			next.ServeHTTPC(ctx, w, r)
		})
	}
}
Ejemplo n.º 11
0
Archivo: main.go Proyecto: raj347/chi
// ArticleCtx is middleware that loads the article named by the "articleID" URL
// parameter and stores it in the context under "article". Any error from
// dbGetArticle is answered with 404 Not Found, and next is not called.
func ArticleCtx(next chi.Handler) chi.Handler {
	load := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		id := chi.URLParams(ctx)["articleID"]

		article, err := dbGetArticle(id)
		if err != nil {
			http.Error(w, http.StatusText(404), 404)
			return
		}

		next.ServeHTTPC(context.WithValue(ctx, "article", article), w, r)
	}
	return chi.HandlerFunc(load)
}
Ejemplo n.º 12
0
// AccessControl is an example middleware. It logs the ACL route and the HTTP
// method being checked, without enforcing any permission. It rejects a request
// with 403 "undefined acl route" only when the context has no []string stored
// under "acl.route".
func AccessControl(next chi.Handler) chi.Handler {
	return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		route, found := ctx.Value("acl.route").([]string)
		if !found {
			http.Error(w, "undefined acl route", 403)
			return
		}

		// Placeholder: real permission checks belong here.
		log.Printf("Checking permission to %s %s", r.Method, strings.Join(route, " -> "))

		next.ServeHTTPC(ctx, w, r)
	})
}
Ejemplo n.º 13
0
// Timeout returns a middleware that hands the next handler a context that is
// cancelled after timeout. Handlers must watch ctx.Done() for this to have any
// effect; the middleware itself does not interrupt them. After next returns
// (or panics), the response status is set to StatusServerTimeout if the
// deadline was exceeded.
func Timeout(timeout time.Duration) func(next chi.Handler) chi.Handler {
	return func(next chi.Handler) chi.Handler {
		return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			timed, cancel := context.WithTimeout(ctx, timeout)
			defer func() {
				// Cancelling first releases the timer. A context's first error
				// sticks, so a deadline that already fired is still reported below.
				cancel()
				if timed.Err() == context.DeadlineExceeded {
					w.WriteHeader(StatusServerTimeout)
				}
			}()

			next.ServeHTTPC(timed, w, r)
		})
	}
}
Ejemplo n.º 14
0
// Recoverer is a middleware that recovers from panics raised by the next
// handler. On a panic it does three things:
//   - passes the request ID from the context (if any) and the panic value to
//     printPanic;
//   - prints the goroutine's stack trace;
//   - responds with 500 Internal Server Error.
//
// If the handler had already written a response, the 500 may not reach the
// client.
func Recoverer(next chi.Handler) chi.Handler {
	return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		// recover must be called directly from this deferred function to work.
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			printPanic(GetReqID(ctx), rec)
			debug.PrintStack()
			http.Error(w, http.StatusText(500), 500)
		}()

		next.ServeHTTPC(ctx, w, r)
	})
}
Ejemplo n.º 15
0
Archivo: logger.go Proyecto: raj347/chi
// Logger is a middleware that logs each request. The log includes what was
// requested, the response status recorded by the wrapped writer, and how long
// the request took. The output format (including any coloring when standard
// output is a TTY) is decided by requestPrefix and printRequest.
//
// Logger passes the request ID, if one is in the context, to the printers. The
// log call is deferred, so it also runs when next panics.
func Logger(next chi.Handler) chi.Handler {
	return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		reqID := GetReqID(ctx)
		prefix := requestPrefix(reqID, r)
		lw := wrapWriter(w)

		start := time.Now()
		defer func() {
			printRequest(prefix, reqID, lw, time.Since(start))
		}()

		next.ServeHTTPC(ctx, lw, r)
	})
}
Ejemplo n.º 16
0
// Route returns a middleware that appends one segment to the ACL route stored
// in the context under "acl.route" (a []string). The accumulated segments,
// together with the HTTP method, can be used to decide access to a resource.
//
// On each request, part is passed through extractACLRoute. An empty result
// leaves the context untouched.
func Route(part string) func(chi.Handler) chi.Handler {
	return func(next chi.Handler) chi.Handler {
		hn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			// Use a per-request local. Assigning to the captured `part` would race
			// between concurrent requests. It would also feed extractACLRoute its
			// own previous output on later requests.
			segment := extractACLRoute(part)
			if segment != "" {
				// A missing or non-[]string value yields a nil slice, which
				// appends like an empty one.
				route, _ := ctx.Value("acl.route").([]string)
				// Cap the slice at its length so append always allocates and
				// never writes into a backing array the parent context still uses.
				route = append(route[:len(route):len(route)], segment)
				ctx = context.WithValue(ctx, "acl.route", route)
			}
			next.ServeHTTPC(ctx, w, r)
		}
		return chi.HandlerFunc(hn)
	}
}
Ejemplo n.º 17
0
// Authenticator is a default authentication middleware that enforces access
// after the Verifier middleware has run. It responds 401 Unauthorized in two
// cases:
//   - the context holds an error under "jwt.err";
//   - the context value "jwt" is not a non-nil, valid *jwt.Token.
//
// Otherwise the request is passed through. Write your own variant to customize
// the response sent to the client.
func Authenticator(next chi.Handler) chi.Handler {
	unauthorized := func(w http.ResponseWriter) {
		http.Error(w, http.StatusText(401), 401)
	}

	return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		// A nil error stored in the context fails this type assertion, so a
		// successful assertion always means verification reported a failure.
		if _, failed := ctx.Value("jwt.err").(error); failed {
			unauthorized(w)
			return
		}

		token, ok := ctx.Value("jwt").(*jwt.Token)
		if !ok || token == nil || !token.Valid {
			unauthorized(w)
			return
		}

		// Token is authenticated, pass it through
		next.ServeHTTPC(ctx, w, r)
	})
}
Ejemplo n.º 18
0
// Router builds the application's chi router. It installs the RequestID,
// RealIP, Logger and Recoverer middlewares in that order. It then adds a
// middleware that sets the context value "example" to true. Finally it routes
// GET "/" to apiIndex.
func Router() chi.Router {
	mux := chi.NewRouter()

	mux.Use(middleware.RequestID)
	mux.Use(middleware.RealIP)
	mux.Use(middleware.Logger)
	mux.Use(middleware.Recoverer)

	// Sets "example" = true in the context of every request routed through mux.
	markExample := func(next chi.Handler) chi.Handler {
		return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			next.ServeHTTPC(context.WithValue(ctx, "example", true), w, r)
		})
	}
	mux.Use(markExample)

	mux.Get("/", apiIndex)

	return mux
}
Ejemplo n.º 19
0
Archivo: main.go Proyecto: vladdy/chi
// main wires up the example router and serves it on port 3333. The router has:
//   - request IDs, real client IPs and request logging;
//   - a middleware that sets the context value "example" to true;
//   - the index handler at "/";
//   - the accounts sub-router mounted under "/accounts".
//
// http.ListenAndServe only returns on failure, e.g. when the port is already
// in use. That error is surfaced with panic; otherwise the program would exit
// silently with status 0 without serving anything.
func main() {
	r := chi.NewRouter()

	r.Use(middleware.RequestID)
	r.Use(middleware.RealIP)
	r.Use(middleware.Logger)

	r.Use(func(h chi.Handler) chi.Handler {
		return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			ctx = context.WithValue(ctx, "example", true)
			h.ServeHTTPC(ctx, w, r)
		})
	})

	r.Get("/", apiIndex)

	r.Mount("/accounts", accountsRouter())

	if err := http.ListenAndServe(":3333", r); err != nil {
		panic(err)
	}
}
Ejemplo n.º 20
0
// Logger is a middleware that logs the start and end of each request. Along
// with each request it logs what was requested, what the response status was,
// and how long it took to return. Formatting (including any coloring when
// standard output is a TTY) is done by printStart and printEnd.
//
// Logger passes the request ID, if one is in the context, to the printers.
//
// Logger has been designed explicitly to be Good Enough for use in small
// applications and for people just getting started with Goji. Applications are
// expected to eventually outgrow it and replace it with a custom request
// logger. That logger might produce machine-parseable output, send logs to a
// different service (e.g., syslog), or format lines like those printed
// elsewhere in the application.
func Logger(next chi.Handler) chi.Handler {
	return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
		reqID := GetReqID(ctx)
		printStart(reqID, r)

		lw := wrapWriter(w)
		start := time.Now()
		next.ServeHTTPC(ctx, lw, r)

		// A handler that never set a status implicitly answers 200 OK. Write it
		// explicitly so lw presumably reports 200 to printEnd instead of 0.
		if lw.Status() == 0 {
			lw.WriteHeader(http.StatusOK)
		}

		printEnd(reqID, lw, time.Since(start))
	})
}
Ejemplo n.º 21
0
// BodyParser returns a middleware that decodes the JSON request body into a
// fresh value obtained from builder. It passes maxSize as the size limit to
// utils.StreamJSONToStructWithLimit. Responses are sent via utils.Respond:
//   - a decode failure is answered with 422 Unprocessable Entity;
//   - a failed utils.JSONValidation check is answered with 400 Bad Request.
//
// On success the decoded value is stored in the context under
// constants.CtxKeyParsedBody, and then next is called.
func BodyParser(builder func() interface{}, maxSize int64) func(chi.Handler) chi.Handler {
	return func(next chi.Handler) chi.Handler {
		return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			body := builder()

			if decodeErr := utils.StreamJSONToStructWithLimit(r.Body, body, maxSize); decodeErr != nil {
				utils.Respond(w, 422, decodeErr)
				return
			}

			// Reject payloads that decoded but are missing required fields.
			if invalid := utils.JSONValidation(body); invalid != nil {
				utils.Respond(w, 400, invalid)
				return
			}

			next.ServeHTTPC(context.WithValue(ctx, constants.CtxKeyParsedBody, body), w, r)
		})
	}
}
Ejemplo n.º 22
0
// TestMore checks the Verifier middleware combined with a custom authenticator.
// /admin must answer 401 for a missing, wrongly signed, malformed or expired
// token, and 200 for a valid one. / is public.
//
// Every check requires BOTH the status and the body to match; a mismatch in
// either one fails the test.
func TestMore(t *testing.T) {
	r := chi.NewRouter()

	// Protected routes
	r.Group(func(r chi.Router) {
		r.Use(TokenAuth.Verifier)

		authenticator := func(next chi.Handler) chi.Handler {
			return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
				if jwtErr, ok := ctx.Value("jwt.err").(error); ok {
					switch jwtErr {
					default:
						http.Error(w, http.StatusText(401), 401)
						return
					case jwtauth.ErrExpired:
						http.Error(w, "expired", 401)
						return
					case jwtauth.ErrUnauthorized:
						http.Error(w, http.StatusText(401), 401)
						return
					case nil:
						// no error
					}
				}

				jwtToken, ok := ctx.Value("jwt").(*jwt.Token)
				if !ok || jwtToken == nil || !jwtToken.Valid {
					http.Error(w, http.StatusText(401), 401)
					return
				}

				// Token is authenticated, pass it through
				next.ServeHTTPC(ctx, w, r)
			})
		}
		r.Use(authenticator)

		r.Get("/admin", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("protected"))
		})
	})

	// Public routes
	r.Group(func(r chi.Router) {
		r.Get("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("welcome"))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// sending unauthorized requests
	if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "Unauthorized\n" {
		t.Fatalf("no token: got %v %q, want 401 %q", status, resp, "Unauthorized\n")
	}

	h := http.Header{}
	h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
	if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
		t.Fatalf("wrong signing key: got %v %q, want 401 %q", status, resp, "Unauthorized\n")
	}
	h.Set("Authorization", "BEARER asdf")
	if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
		t.Fatalf("malformed token: got %v %q, want 401 %q", status, resp, "Unauthorized\n")
	}

	h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000))
	if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" {
		t.Fatalf("expired token: got %v %q, want 401 %q", status, resp, "expired\n")
	}

	// sending authorized requests
	if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 200 || resp != "welcome" {
		t.Fatalf("public route: got %v %q, want 200 %q", status, resp, "welcome")
	}

	h = newAuthHeader((jwtauth.Claims{}).SetExpiryIn(5 * time.Minute))
	if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected" {
		t.Fatalf("valid token: got %v %q, want 200 %q", status, resp, "protected")
	}
}
Ejemplo n.º 23
0
// Verify returns a middleware that looks for a JWT on each request, decodes and
// validates it, and records the outcome in the context via ja.SetContext. A
// downstream authenticator middleware acts on that outcome. Verify never
// rejects a request itself: next is always called exactly once.
//
// The token is looked up in this order:
//   - the "jwt" query parameter;
//   - each of paramAliases, treated as query parameter names;
//   - an "Authorization: Bearer <token>" header;
//   - a "jwt" cookie.
//
// The error passed to SetContext is:
//   - ErrUnauthorized when no token was found;
//   - ErrUnauthorized when ja.Decode succeeded but the token is not valid or
//     was not signed with ja.signer;
//   - ErrExpired when decoding failed with "token is expired";
//   - ErrExpired when ja.IsExpired reports the token as expired;
//   - the ja.Decode error for any other decode failure;
//   - nil for a valid token.
func (ja *JwtAuth) Verify(paramAliases ...string) func(chi.Handler) chi.Handler {
	return func(next chi.Handler) chi.Handler {
		hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
			// Parse the query string once; URL.Query re-parses on every call.
			query := r.URL.Query()

			// Get token from query params
			tokenStr := query.Get("jwt")

			// Get token from other query param aliases (first non-empty wins)
			for _, p := range paramAliases {
				if tokenStr != "" {
					break
				}
				tokenStr = query.Get(p)
			}

			// Get token from authorization header. The first six bytes must
			// spell "bearer" in any case; the token starts after the separator
			// at index 6.
			if tokenStr == "" {
				bearer := r.Header.Get("Authorization")
				if len(bearer) > 7 && strings.EqualFold(bearer[:6], "BEARER") {
					tokenStr = bearer[7:]
				}
			}

			// Get token from cookie
			if tokenStr == "" {
				if cookie, err := r.Cookie("jwt"); err == nil {
					tokenStr = cookie.Value
				}
			}

			// Decode even when tokenStr is empty. That way token has the type
			// SetContext expects; the error is replaced with ErrUnauthorized below.
			token, err := ja.Decode(tokenStr)
			switch {
			case tokenStr == "":
				// Token is required; report it as missing rather than as a parse error.
				err = ErrUnauthorized
			case err != nil:
				// NOTE(review): matching on message text depends on the JWT
				// library's wording; confirm it still reads "token is expired".
				if err.Error() == "token is expired" {
					err = ErrExpired
				}
			case !token.Valid || token.Method != ja.signer:
				// Decode succeeded but the token is unusable. err is nil here, so
				// it must be set explicitly (calling err.Error() would panic).
				err = ErrUnauthorized
			case ja.IsExpired(token):
				// Check expiry via "exp" claim
				err = ErrExpired
			}

			// Pass the outcome down the context to an authenticator middleware.
			ctx = ja.SetContext(ctx, token, err)
			next.ServeHTTPC(ctx, w, r)
		}
		return chi.HandlerFunc(hfn)
	}
}