func TestCSRF(t *testing.T) { e := vodka.New() req := test.NewRequest(vodka.GET, "/", nil) rec := test.NewResponseRecorder() c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ TokenLength: 16, }) h := csrf(func(c vodka.Context) error { return c.String(http.StatusOK, "test") }) // Generate CSRF token h(c) assert.Contains(t, rec.Header().Get(vodka.HeaderSetCookie), "_csrf") // Without CSRF cookie req = test.NewRequest(vodka.POST, "/", nil) rec = test.NewResponseRecorder() c = e.NewContext(req, rec) assert.Error(t, h(c)) // Empty/invalid CSRF token req = test.NewRequest(vodka.POST, "/", nil) rec = test.NewResponseRecorder() c = e.NewContext(req, rec) req.Header().Set(vodka.HeaderXCSRFToken, "") assert.Error(t, h(c)) // Valid CSRF token token := random.String(16) req.Header().Set(vodka.HeaderCookie, "_csrf="+token) req.Header().Set(vodka.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { assert.Equal(t, http.StatusOK, rec.Status()) } }
// CSRFWithConfig returns a CSRF middleware with config. // See `CSRF()`. func CSRFWithConfig(config CSRFConfig) vodka.MiddlewareFunc { // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper } if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } if config.ContextKey == "" { config.ContextKey = DefaultCSRFConfig.ContextKey } if config.CookieName == "" { config.CookieName = DefaultCSRFConfig.CookieName } if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } // Initialize parts := strings.Split(config.TokenLookup, ":") extractor := csrfTokenFromHeader(parts[1]) switch parts[0] { case "form": extractor = csrfTokenFromForm(parts[1]) case "query": extractor = csrfTokenFromQuery(parts[1]) } return func(next vodka.HandlerFunc) vodka.HandlerFunc { return func(c vodka.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() k, err := c.Cookie(config.CookieName) token := "" if err != nil { // Generate token token = random.String(config.TokenLength) } else { // Reuse token token = k.Value() } switch req.Method() { case vodka.GET, vodka.HEAD, vodka.OPTIONS, vodka.TRACE: default: // Validate token only for requests which are not defined as 'safe' by RFC7231 clientToken, err := extractor(c) if err != nil { return err } if !validateCSRFToken(token, clientToken) { return vodka.NewHTTPError(http.StatusForbidden, "csrf token is invalid") } } // Set CSRF cookie cookie := new(vodka.Cookie) cookie.SetName(config.CookieName) cookie.SetValue(token) if config.CookiePath != "" { cookie.SetPath(config.CookiePath) } if config.CookieDomain != "" { cookie.SetDomain(config.CookieDomain) } cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)) cookie.SetSecure(config.CookieSecure) cookie.SetHTTPOnly(config.CookieHTTPOnly) c.SetCookie(cookie) // Store token in the context c.Set(config.ContextKey, token) // Protect clients from caching the response c.Response().Header().Add(vodka.HeaderVary, vodka.HeaderCookie) return next(c) } } }