func SessionGet(app *forest.App, manager SessionManager) func(ctx *bear.Context) { return func(ctx *bear.Context) { cookieName := forest.SessionID createEmptySession := func(sessionID string) { path := app.Config.CookiePath if path == "" { path = "/" } cookieValue := sessionID duration := app.Duration("Cookie") // Reset the cookie. app.SetCookie(ctx, path, cookieName, cookieValue, duration) manager.CreateEmpty(sessionID, ctx) ctx.Next() } cookie, err := ctx.Request.Cookie(cookieName) if err != nil || cookie.Value == "" { createEmptySession(uuid.New()) return } sessionID := cookie.Value userID, userJSON, err := manager.Read(sessionID) if err != nil || userID == "" || userJSON == "" { createEmptySession(uuid.New()) return } if err := manager.Create(sessionID, userID, userJSON, ctx); err != nil { println(fmt.Sprintf("error creating session: %s", err)) defer func(sessionID string, userID string) { if err := manager.Delete(sessionID, userID); err != nil { println(fmt.Sprintf("error deleting session: %s", err)) } }(sessionID, userID) createEmptySession(uuid.New()) return } // If SessionRefresh is set to false, the session will not refresh; // if it's not set or if it's set to true, the session is refreshed. refresh, ok := ctx.Get(forest.SessionRefresh).(bool) if !ok || refresh { path := app.Config.CookiePath if path == "" { path = "/" } cookieName := forest.SessionID cookieValue := sessionID duration := app.Duration("Cookie") // Refresh the cookie. app.SetCookie(ctx, path, cookieName, cookieValue, duration) err := manager.Update(sessionID, userID, userJSON, app.Duration("Session")) if err != nil { println(fmt.Sprintf("error updating session: %s", err)) } } ctx.Next() } }
// ErrorsUnauthorized returns a ware that writes a failure response with
// status 401. The message is app.Error("Unauthorized") passed through
// safeErrorMessage.
func ErrorsUnauthorized(app *forest.App) func(ctx *bear.Context) {
	return func(ctx *bear.Context) {
		message := safeErrorMessage(app, ctx, app.Error("Unauthorized"))
		app.Response(ctx, http.StatusUnauthorized, forest.Failure, message).Write(nil)
	}
}
// ErrorsServerError returns a ware that writes a failure response with
// status 500. The message is app.Error("Generic") passed through
// safeErrorMessage.
func ErrorsServerError(app *forest.App) func(ctx *bear.Context) {
	return func(ctx *bear.Context) {
		message := safeErrorMessage(app, ctx, app.Error("Generic"))
		app.Response(ctx, http.StatusInternalServerError, forest.Failure, message).Write(nil)
	}
}
// ErrorsBadRequest returns a ware that writes a failure response with
// status 400. The message is app.Error("Generic") passed through
// safeErrorMessage.
func ErrorsBadRequest(app *forest.App) func(ctx *bear.Context) {
	return func(ctx *bear.Context) {
		message := safeErrorMessage(app, ctx, app.Error("Generic"))
		app.Response(ctx, http.StatusBadRequest, forest.Failure, message).Write(nil)
	}
}
func Authenticate(app *forest.App) func(ctx *bear.Context) { return func(ctx *bear.Context) { userID, ok := ctx.Get(forest.SessionUserID).(string) if !ok || len(userID) == 0 { app.Response(ctx, http.StatusUnauthorized, forest.Failure, app.Error("Unauthorized")).Write(nil) return } ctx.Next() } }
// safeErrorFilter chooses the error to expose for err.
//
// When app.Config.Debug is set, err is returned unchanged. Otherwise, if
// app.SafeErrorFilter is configured and returns a non-nil error for err, that
// error is returned. In every other case a new error carrying the friendly
// text is returned.
func safeErrorFilter(app *forest.App, err error, friendly string) error {
	if app.Config.Debug {
		return err
	}
	if app.SafeErrorFilter != nil {
		if safe := app.SafeErrorFilter(err); safe != nil {
			return safe
		}
	}
	// friendly is data, not a format string: pass it through an explicit
	// verb so any '%' it contains is reproduced literally.
	return fmt.Errorf("%s", friendly)
}
func makeRequest(t *testing.T, app *forest.App, params *requested, want *wanted) (*http.Response, *forest.Response) { var request *http.Request method := params.method auth := params.auth path := params.path body := params.body if body != nil { request, _ = http.NewRequest(method, path, bytes.NewBuffer(body)) } else { request, _ = http.NewRequest(method, path, nil) } if len(auth) > 0 { request.AddCookie(&http.Cookie{Name: forest.SessionID, Value: auth}) } response := httptest.NewRecorder() app.ServeHTTP(response, request) responseData := new(forest.Response) responseBody, err := ioutil.ReadAll(response.Body) if err != nil { t.Error(err) return nil, responseData } if err := json.Unmarshal(responseBody, responseData); err != nil { t.Errorf("unmarshal error: %v when attempting to read: %s", err, string(responseBody)) return nil, responseData } if response.Code != want.code { t.Errorf("%s %s want: %d (%s) got: %d %s, body: %s", method, path, want.code, http.StatusText(want.code), response.Code, http.StatusText(response.Code), string(responseBody)) return nil, responseData } if responseData.Success != want.success { t.Errorf("%s %s should return success: %t", method, path, want.success) return nil, responseData } return &http.Response{Header: response.Header()}, responseData }
func InstallSessionWares(app *forest.App, manager SessionManager) { app.InstallWare("SessionDel", SessionDel(app, manager), forest.WareInstalled) app.InstallWare("SessionGet", SessionGet(app, manager), forest.WareInstalled) app.InstallWare("SessionSet", SessionSet(app, manager), forest.WareInstalled) }
func ErrorsNotFound(app *forest.App) func(ctx *bear.Context) { return func(ctx *bear.Context) { message := safeErrorMessage(app, ctx, app.Error("NotFound")) app.Response(ctx, http.StatusNotFound, forest.Failure, message).Write(nil) } }
func SessionDel(app *forest.App, manager SessionManager) func(ctx *bear.Context) { return func(ctx *bear.Context) { sessionID, ok := ctx.Get(forest.SessionID).(string) if !ok { err := fmt.Errorf("SessionDel %s: %v", forest.SessionID, ctx.Get(forest.SessionID)) ctx.Set(forest.Error, err) message := safeErrorMessage(app, ctx, app.Error("Generic")) app.Response(ctx, http.StatusInternalServerError, forest.Failure, message).Write(nil) return } userID, ok := ctx.Get(forest.SessionUserID).(string) if !ok { err := fmt.Errorf("SessionDel %s: %v", forest.SessionUserID, ctx.Get(forest.SessionUserID)) ctx.Set(forest.Error, err) message := safeErrorMessage(app, ctx, app.Error("Generic")) app.Response(ctx, http.StatusInternalServerError, forest.Failure, message).Write(nil) return } if err := manager.Delete(sessionID, userID); err != nil { ctx.Set(forest.Error, err) message := safeErrorMessage(app, ctx, app.Error("Generic")) app.Response(ctx, http.StatusInternalServerError, forest.Failure, message).Write(nil) return } ctx.Next() } }
func BodyParser(app *forest.App) func(ctx *bear.Context) { return func(ctx *bear.Context) { destination, ok := ctx.Get(forest.Body).(Populater) if !ok { ctx.Set(forest.Error, fmt.Errorf("(*forest.App).BodyParser unitialized")) message := safeErrorMessage(app, ctx, app.Error("Parse")) app.Response(ctx, http.StatusInternalServerError, forest.Failure, message).Write(nil) return } if ctx.Request.Body == nil { ctx.Set(forest.SafeError, fmt.Errorf("%s: body is empty", app.Error("Parse"))) message := safeErrorMessage(app, ctx, app.Error("Parse")) app.Response(ctx, http.StatusBadRequest, forest.Failure, message).Write(nil) return } if err := destination.Populate(ctx.Request.Body); err != nil { ctx.Set(forest.SafeError, fmt.Errorf("%s: %s", app.Error("Parse"), err)) message := safeErrorMessage(app, ctx, app.Error("Parse")) app.Response(ctx, http.StatusBadRequest, forest.Failure, message).Write(nil) return } ctx.Next() } }
func InstallSecurityWares(app *forest.App) { app.InstallWare("Authenticate", Authenticate(app), forest.WareInstalled) app.InstallWare("CSRF", CSRF(app), forest.WareInstalled) }
func InstallErrorWares(app *forest.App) { app.InstallWare("BadRequest", ErrorsBadRequest(app), forest.WareInstalled) app.InstallWare("Conflict", ErrorsConflict(app), forest.WareInstalled) app.InstallWare("MethodNotAllowed", ErrorsMethodNotAllowed(app), forest.WareInstalled) app.InstallWare("NotFound", ErrorsNotFound(app), forest.WareInstalled) app.InstallWare("ServerError", ErrorsServerError(app), forest.WareInstalled) app.InstallWare("Unauthorized", ErrorsUnauthorized(app), forest.WareInstalled) }
// InstallBodyParser installs the BodyParser ware on app under the key
// "BodyParser".
func InstallBodyParser(app *forest.App) {
	parser := BodyParser(app)
	app.InstallWare("BodyParser", parser, forest.WareInstalled)
}
func CSRF(app *forest.App) func(ctx *bear.Context) { type postBody struct { SessionID string `json:"sessionid"` // forest.SessionID == "sessionid" } return func(ctx *bear.Context) { if ctx.Request.Body == nil { app.Response(ctx, http.StatusBadRequest, forest.Failure, app.Error("CSRF")).Write(nil) return } pb := new(postBody) body, _ := ioutil.ReadAll(ctx.Request.Body) if body == nil || len(body) < 2 { // smallest JSON body is {}, 2 chars app.Response( ctx, http.StatusBadRequest, forest.Failure, app.Error("Parse")).Write(nil) return } // set ctx.Request.Body back to an untouched io.ReadCloser ctx.Request.Body = ioutil.NopCloser(bytes.NewBuffer(body)) if err := json.Unmarshal(body, pb); err != nil { app.Response( ctx, http.StatusBadRequest, forest.Failure, app.Error("Parse")+": "+err.Error()).Write(nil) return } sessionID, ok := ctx.Get(forest.SessionID).(string) if !ok || sessionID != pb.SessionID { app.Response( ctx, http.StatusBadRequest, forest.Failure, app.Error("CSRF")).Write(nil) return } ctx.Next() } }