Ejemplo n.º 1
0
// TestKeepContext checks that the request context is cleared after the router
// serves a request by default, and that it is left in place once the router's
// KeepContext field is set to true.
func TestKeepContext(t *testing.T) {
	noop := func(w http.ResponseWriter, r *http.Request) {}

	r := NewRouter()
	r.HandleFunc("/", noop).Name("func1")

	// newRequest builds a GET request for "/" and stores a marker value in its
	// context, so each phase can check whether the marker survives ServeHTTP.
	newRequest := func() *http.Request {
		req, err := http.NewRequest("GET", "http://localhost/", nil)
		if err != nil {
			t.Fatalf("creating request: %v", err)
		}
		context.Set(req, "t", 1)
		return req
	}

	// A nil ResponseWriter is passed on purpose: the registered handler is a
	// no-op and never writes a response.
	var res http.ResponseWriter

	// Default configuration: the marker must be gone after serving.
	req := newRequest()
	r.ServeHTTP(res, req)
	if _, ok := context.GetOk(req, "t"); ok {
		t.Error("Context should have been cleared at end of request")
	}

	// With KeepContext enabled: the marker must still be present.
	r.KeepContext = true

	req = newRequest()
	r.ServeHTTP(res, req)
	if _, ok := context.GetOk(req, "t"); !ok {
		t.Error("Context should NOT have been cleared at end of request")
	}
}
Ejemplo n.º 2
0
// envError stores a CSRF error in the request context.
//
// The error err is saved in the context associated with r under errorKey.
func envError(r *http.Request, err error) {
	context.Set(r, errorKey, err)
}
Ejemplo n.º 3
0
// ServeHTTP implements http.Handler for the csrf type.
//
// It ensures a real token of tokenLength bytes is available from the session
// store (generating and saving a new one if needed), publishes a masked copy
// in the request context under tokenKey, validates requests whose method is
// not listed in safeMethods, and finally calls the wrapped handler. On any
// failure the reason is stored in the request context and the configured
// error handler is invoked instead of the wrapped handler.
func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// reject records why the request failed and delegates the response to the
	// configured error handler. Every caller returns immediately afterwards.
	reject := func(reason error) {
		envError(r, reason)
		cs.opts.ErrorHandler.ServeHTTP(w, r)
	}

	// Look up the real token in the session store. An error here means the
	// cookie either failed HMAC validation or does not exist.
	realToken, err := cs.st.Get(r)
	if err != nil || len(realToken) != tokenLength {
		// Missing, unreadable or wrongly sized token: mint a fresh one. The
		// new token will (correctly) fail validation further down because it
		// can no longer match whatever token the request carries.
		if realToken, err = generateRandomBytes(tokenLength); err != nil {
			reject(err)
			return
		}

		// Persist the new real token through the session store.
		if err = cs.st.Save(realToken, w); err != nil {
			reject(err)
			return
		}
	}

	// Expose the masked token through the request context.
	context.Set(r, tokenKey, mask(realToken, r))

	// Only methods outside safeMethods (the "safe" methods of RFC7231) need
	// to be inspected.
	if !contains(safeMethods, r.Method) {
		// Enforce an origin check for HTTPS connections. As per the Django
		// CSRF implementation (https://goo.gl/vKA7GE) the Referer header is
		// almost always present for same-domain HTTP requests.
		// NOTE(review): net/http usually leaves r.URL.Scheme empty on
		// server-side requests; confirm something upstream populates it,
		// otherwise this branch never runs.
		if r.URL.Scheme == "https" {
			// An empty or unparseable Referer is treated as missing.
			referer, parseErr := url.Parse(r.Referer())
			if parseErr != nil || referer.String() == "" {
				reject(ErrNoReferer)
				return
			}

			if !sameOrigin(r.URL, referer) {
				reject(ErrBadReferer)
				return
			}
		}

		// A nil real token is never acceptable for an unsafe method.
		if realToken == nil {
			reject(ErrNoToken)
			return
		}

		// Unmask the combined (pad + masked) token sent with the request and
		// compare it against the real token.
		if !compareTokens(unmask(cs.requestToken(r)), realToken) {
			reject(ErrBadToken)
			return
		}
	}

	// Add Vary: Cookie so that clients are protected from caching the
	// response.
	w.Header().Add("Vary", "Cookie")

	// Validation passed: run the wrapped handler/router, then clear the
	// request context once it has completed.
	cs.h.ServeHTTP(w, r)
	context.Clear(r)
}
Ejemplo n.º 4
0
// setCurrentRoute stores val in the context associated with r under routeKey.
// NOTE(review): val is presumably the route matched for this request; confirm
// against the callers.
func setCurrentRoute(r *http.Request, val interface{}) {
	context.Set(r, routeKey, val)
}
Ejemplo n.º 5
0
// setVars stores val in the context associated with r under varsKey.
// NOTE(review): val is presumably the set of route variables extracted for
// this request; confirm against the callers.
func setVars(r *http.Request, val interface{}) {
	context.Set(r, varsKey, val)
}