// Compose together the middleware chain and wrap the handler with it func (z *Inlineware) On(handler interface{}) web.Handler { var wh web.Handler switch t := handler.(type) { case web.Handler: wh = t case func(web.C, http.ResponseWriter, *http.Request): wh = web.HandlerFunc(t) case func(http.ResponseWriter, *http.Request): wh = web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { t(w, r) }) default: panic(fmt.Sprintf("unsupported handler type: %T", t)) } if len(z.middlewares) == 0 { return wh } m := z.wrap(z.middlewares[len(z.middlewares)-1])(wh) for i := len(z.middlewares) - 2; i >= 0; i-- { f := z.wrap(z.middlewares[i]) m = f(m) } return m }
// Test that our form helpers correctly inject a token into the response body. func TestFormToken(t *testing.T) { s := web.New() s.Use(Protect(testKey)) // Make the token available outside of the handler for comparison. var token string s.Get("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { token = Token(c, r) t := template.Must((template.New("base").Parse(testTemplate))) t.Execute(w, map[string]interface{}{ TemplateTag: TemplateField(c, r), }) })) r, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", rr.Code, http.StatusOK) } if len(token) != base64.StdEncoding.EncodedLen(tokenLength*2) { t.Fatalf("token length invalid: got %v want %v", len(token), base64.StdEncoding.EncodedLen(tokenLength*2)) } if !strings.Contains(rr.Body.String(), token) { t.Fatalf("token not in response body: got %v want %v", rr.Body.String(), token) } }
func TestTemplateField(t *testing.T) { s := web.New() CSRF := Protect( testKey, FieldName(testFieldName), ) s.Use(CSRF) var token string var customTemplateField string s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { token = Token(c, r) customTemplateField = string(TemplateField(c, r)) })) r, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) expectedTemplateField := fmt.Sprintf(testTemplateField, testFieldName, token) if customTemplateField != expectedTemplateField { t.Fatalf("templateField not set correctly: got %v want %v", customTemplateField, expectedTemplateField) } }
func (application *Application) Route(action func(web.C, *http.Request) (string, int)) web.Handler { fn := func(c web.C, w http.ResponseWriter, r *http.Request) { c.Env["Content-Type"] = "text/html" body, code := action(c, r) if session, exists := c.Env["Session"]; exists { err := session.(*sessions.Session).Save(r, w) if err != nil { logrus.Errorf("Can't save session: %v", err) } } switch code { case http.StatusOK: if _, exists := c.Env["Content-Type"]; exists { w.Header().Set("Content-Type", c.Env["Content-Type"].(string)) } io.WriteString(w, body) case http.StatusSeeOther, http.StatusFound: http.Redirect(w, r, body, code) default: w.WriteHeader(code) io.WriteString(w, body) } } return web.HandlerFunc(fn) }
// Test that we can extract a CSRF token from a multipart form. func TestMultipartFormToken(t *testing.T) { s := web.New() s.Use(Protect(testKey)) // Make the token available outside of the handler for comparison. var token string s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { token = Token(c, r) t := template.Must((template.New("base").Parse(testTemplate))) t.Execute(w, map[string]interface{}{ TemplateTag: TemplateField(c, r), }) })) r, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) // Set up our multipart form var b bytes.Buffer mp := multipart.NewWriter(&b) wr, err := mp.CreateFormField(fieldName) if err != nil { t.Fatal(err) } wr.Write([]byte(token)) mp.Close() r, err = http.NewRequest("POST", "/", &b) if err != nil { t.Fatal(err) } // Add the multipart header. r.Header.Set("Content-Type", mp.FormDataContentType()) // Send back the issued cookie. setCookie(rr, r) rr = httptest.NewRecorder() s.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", rr.Code, http.StatusOK) } if body := rr.Body.String(); !strings.Contains(body, token) { t.Fatalf("token not in response body: got %v want %v", body, token) } }
// Protect is HTTP middleware that provides Cross-Site Request Forgery // protection. // // It securely generates a masked (unique-per-request) token that // can be embedded in the HTTP response (e.g. form field or HTTP header). // The original (unmasked) token is stored in the session, which is inaccessible // by an attacker (provided you are using HTTPS). Subsequent requests are // expected to include this token, which is compared against the session token. // Requests that do not provide a matching token are served with a HTTP 403 // 'Forbidden' error response. // // Example: // package main // // import ( // "github.com/goji/csrf" // "github.com/zenazn/goji" // ) // // func main() { // // Add the middleware to your router. // goji.Use(csrf.Protect([]byte("32-byte-long-auth-key"))) // goji.Get("/signup", GetSignupForm) // // POST requests without a valid token will return a HTTP 403 Forbidden. // goji.Post("/signup/post", PostSignupForm) // // goji.Serve() // } // // func GetSignupForm(c web.C, w http.ResponseWriter, r *http.Request) { // // signup_form.tmpl just needs a {{ .csrfField }} template tag for // // csrf.TemplateField to inject the CSRF token into. Easy! // t.ExecuteTemplate(w, "signup_form.tmpl", map[string]interface{ // csrf.TemplateTag: csrf.TemplateField(c, r), // }) // // We could also retrieve the token directly from csrf.Token(c, r) and // // set it in the request header - w.Header.Set("X-CSRF-Token", token) // // This is useful if your sending JSON to clients or a front-end JavaScript // // framework. // } // func Protect(authKey []byte, opts ...Option) func(*web.C, http.Handler) http.Handler { return func(c *web.C, h http.Handler) http.Handler { cs := parseOptions(h, opts...) // Set the defaults if no options have been specified if cs.opts.ErrorHandler == nil { cs.opts.ErrorHandler = web.HandlerFunc(unauthorizedHandler) } if cs.opts.MaxAge < 1 { // Default of 12 hours cs.opts.MaxAge = 3600 * 12 } if cs.opts.FieldName == "" { cs.opts.FieldName = fieldName } if cs.opts.CookieName == "" { cs.opts.CookieName = cookieName } if cs.opts.RequestHeader == "" { cs.opts.RequestHeader = headerName } // Create an authenticated securecookie instance. if cs.sc == nil { cs.sc = securecookie.New(authKey, nil) // Use JSON serialization (faster than one-off gob encoding) cs.sc.SetSerializer(securecookie.JSONEncoder{}) // Set the MaxAge of the underlying securecookie. cs.sc.MaxAge(cs.opts.MaxAge) } if cs.st == nil { // Default to the cookieStore cs.st = &cookieStore{ name: cs.opts.CookieName, maxAge: cs.opts.MaxAge, secure: cs.opts.Secure, httpOnly: cs.opts.HttpOnly, path: cs.opts.Path, domain: cs.opts.Domain, sc: cs.sc, } } // Initialize Goji's request context cs.c = c return *cs } }
// Test that idempotent methods return a 200 OK status and that non-idempotent // methods return a 403 Forbidden status when a CSRF cookie is not present. func TestMethods(t *testing.T) { s := web.New() s.Use(Protect(testKey)) s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { })) // Test idempontent ("safe") methods for _, method := range safeMethods { r, err := http.NewRequest(method, "/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", rr.Code, http.StatusOK) } if rr.Header().Get("Set-Cookie") == "" { t.Fatalf("cookie not set: got %q", rr.Header().Get("Set-Cookie")) } } // Test non-idempotent methods (should return a 403 without a cookie set) nonIdempotent := []string{"POST", "PUT", "DELETE", "PATCH"} for _, method := range nonIdempotent { r, err := http.NewRequest(method, "/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", rr.Code, http.StatusOK) } if rr.Header().Get("Set-Cookie") == "" { t.Fatalf("cookie not set: got %q", rr.Header().Get("Set-Cookie")) } } }
// Requests with no Referer header should fail. func TestNoReferer(t *testing.T) { s := web.New() s.Use(Protect(testKey)) s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {})) r, err := http.NewRequest("POST", "https://golang.org/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", rr.Code, http.StatusForbidden) } }
// TestBadCookie tests for failure when a cookie header is modified (malformed). func TestBadCookie(t *testing.T) { s := web.New() CSRF := Protect(testKey) s.Use(CSRF) var token string s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { token = Token(c, r) })) // Obtain a CSRF cookie via a GET request. r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) // POST the token back in the header. r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil) if err != nil { t.Fatal(err) } // Replace the cookie prefix badHeader := strings.Replace("_csrfToken=", rr.Header().Get("Set-Cookie"), "_badCookie", -1) r.Header.Set("Cookie", badHeader) r.Header.Set("X-CSRF-Token", token) r.Header.Set("Referer", "http://www.gorillatoolkit.org/") rr = httptest.NewRecorder() s.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("middleware failed to reject a bad cookie: got %v want %v", rr.Code, http.StatusForbidden) } }
// TestBadReferer checks that HTTPS requests with a Referer that does not // match the request URL correctly fail CSRF validation. func TestBadReferer(t *testing.T) { s := web.New() CSRF := Protect(testKey) s.Use(CSRF) var token string s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { token = Token(c, r) })) // Obtain a CSRF cookie via a GET request. r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) // POST the token back in the header. r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil) if err != nil { t.Fatal(err) } setCookie(rr, r) r.Header.Set("X-CSRF-Token", token) // Set a non-matching Referer header. r.Header.Set("Referer", "http://goji.io") rr = httptest.NewRecorder() s.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", rr.Code, http.StatusForbidden) } }
func (z *Inlineware) wrap(middleware interface{}) func(web.Handler) web.Handler { fn := func(wh web.Handler) web.Handler { return web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { newFn := func(ww http.ResponseWriter, rr *http.Request) { wh.ServeHTTPC(c, ww, rr) } var fn http.HandlerFunc switch mw := middleware.(type) { default: panic(fmt.Sprintf("unsupported middleware type: %T", mw)) case func(http.Handler) http.Handler: fn = mw(http.HandlerFunc(newFn)).ServeHTTP case func(*web.C, http.Handler) http.Handler: fn = mw(&c, http.HandlerFunc(newFn)).ServeHTTP } fn(w, r) }) } return fn }
func (r *Mux) Mount(path string, handlers ...interface{}) { h := append(r.middlewares, handlers...) subRouter := Use(middleware.SubRouter).On(r.chain(h...)) subRouterIndex := web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { if c.URLParams == nil { c.URLParams = make(map[string]string) } c.URLParams["*"] = "/" subRouter.ServeHTTPC(c, w, r) }) if path == "/" { path = "" } r.Mux.Get(path, subRouterIndex) r.Mux.Handle(path, subRouterIndex) if path != "" { r.Mux.Handle(path+"/", http.NotFound) } r.Mux.Handle(path+"/*", subRouter) }
// Tests that options functions are applied to the middleware. func TestOptions(t *testing.T) { var h http.Handler age := 86400 domain := "goji.io" path := "/forms/" header := "X-AUTH-TOKEN" field := "authenticity_token" errorHandler := unauthorizedHandler name := "_goji_goji_goji" testOpts := []Option{ MaxAge(age), Domain(domain), Path(path), HttpOnly(false), Secure(false), RequestHeader(header), FieldName(field), ErrorHandler(web.HandlerFunc(errorHandler)), CookieName(name), } // Parse our test options and check that they set the related struct fields. cs := parseOptions(h, testOpts...) if cs.opts.MaxAge != age { t.Errorf("MaxAge not set correctly: got %v want %v", cs.opts.MaxAge, age) } if cs.opts.Domain != domain { t.Errorf("Domain not set correctly: got %v want %v", cs.opts.Domain, domain) } if cs.opts.Path != path { t.Errorf("Path not set correctly: got %v want %v", cs.opts.Path, path) } if cs.opts.HttpOnly != false { t.Errorf("HttpOnly not set correctly: got %v want %v", cs.opts.HttpOnly, false) } if cs.opts.Secure != false { t.Errorf("Secure not set correctly: got %v want %v", cs.opts.Secure, false) } if cs.opts.RequestHeader != header { t.Errorf("RequestHeader not set correctly: got %v want %v", cs.opts.RequestHeader, header) } if cs.opts.FieldName != field { t.Errorf("FieldName not set correctly: got %v want %v", cs.opts.FieldName, field) } if !reflect.ValueOf(cs.opts.ErrorHandler).IsValid() { t.Errorf("ErrorHandler not set correctly: got %v want %v", reflect.ValueOf(cs.opts.ErrorHandler).IsValid(), reflect.ValueOf(errorHandler).IsValid()) } if cs.opts.CookieName != name { t.Errorf("CookieName not set correctly: got %v want %v", cs.opts.CookieName, name) } }
package csrf import ( "net/http" "net/http/httptest" "strings" "testing" "github.com/zenazn/goji/web" ) var testKey = []byte("keep-it-secret-keep-it-safe-----") var testHandler = web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {}) // TestProtect is a high-level test to make sure the middleware returns the // wrapped handler with a 200 OK status. func TestProtect(t *testing.T) { s := web.New() s.Use(Protect(testKey)) s.Get("/", testHandler) r, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() s.ServeHTTP(rr, r) if rr.Code != http.StatusOK {
// handler transformation xhandler.HandlerC -> web.Handler func handle(ctx context.Context, handlerc xhandler.HandlerC) web.Handler { return web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { newctx := context.WithValue(ctx, "urlparams", c.URLParams) handlerc.ServeHTTPC(newctx, w, r) }) }