示例#1
0
// GetSession returns the session stored in the context for r under the
// "session" key, which is where the Session middleware puts it.
//
// When no session is stored, a new one is built from the configuration
// returned by GetConfig. Its directory is conf.Session.Dir; a relative
// directory is resolved against the directory of the running executable
// (os.Args[0]), and os.TempDir() is used if that resolution fails. The new
// session is named with util.UUID() and is returned without being stored
// in the context.
func GetSession(c context.Context, r *http.Request) context.Session {
	if stored, found := c.Get(r, context.BaseCtxKey("session")); found {
		return stored.(context.Session)
	}

	conf := GetConfig(c)

	dir := conf.Session.Dir
	if !filepath.IsAbs(dir) {
		resolved, err := filepath.Abs(path.Join(filepath.Dir(os.Args[0]), dir))
		if err != nil {
			// Best effort: fall back to the system temporary directory.
			resolved = os.TempDir()
		}
		dir = resolved
	}

	secret := []byte(conf.Session.Secret)
	cipher := []byte(conf.Session.Cipher)

	sess := context.NewSession(secret, cipher, dir)
	sess.SetName(util.UUID())

	return sess
}
示例#2
0
文件: session.go 项目: urandom/webfw
// Handler wraps ph so that requests it serves get a session attached to the
// context c.
//
// Setup, performed once when Handler is called:
//   - smw.Path is made absolute; a relative path is resolved against the
//     directory of the running executable (os.Args[0]).
//   - smw.MaxAge is parsed with time.ParseDuration when non-empty.
//   - When smw.CleanupInterval is non-empty, it and smw.CleanupMaxAge are
//     parsed and a goroutine is started that calls context.CleanupSessions
//     on every tick. Nothing in this function ever stops that goroutine.
//
// Handler panics if the path cannot be resolved or a duration fails to
// parse.
//
// For each request the returned handler:
//   - takes the part of r.RequestURI before "?" (or r.URL.Path when that is
//     empty) and, if it equals smw.Pattern+prefix or starts with
//     smw.Pattern+prefix+"/" for any entry of smw.IgnoreURLPrefix (with one
//     leading "/" stripped from the entry), passes the request straight to
//     ph without touching sessions;
//   - otherwise creates a session with smw.SessionGenerator, or with
//     context.NewSession when no generator is set, and reads it from the
//     request;
//   - stores the session and a "firstTimer" flag in c for the request;
//   - runs ph against a recording writer, copies the recorded headers to
//     the real writer, writes the session, and only then sends the recorded
//     status code and body.
func (smw Session) Handler(ph http.Handler, c context.Context) http.Handler {
	var abspath string
	var maxAge, cleanupInterval, cleanupMaxAge time.Duration

	if filepath.IsAbs(smw.Path) {
		abspath = smw.Path
	} else {
		var err error
		abspath, err = filepath.Abs(path.Join(filepath.Dir(os.Args[0]), smw.Path))

		if err != nil {
			panic(err)
		}
	}

	if smw.MaxAge != "" {
		var err error
		maxAge, err = time.ParseDuration(smw.MaxAge)

		if err != nil {
			panic(err)
		}
	}

	logger := webfw.GetLogger(c)

	if smw.CleanupInterval != "" {
		var err error
		cleanupInterval, err = time.ParseDuration(smw.CleanupInterval)

		if err != nil {
			panic(err)
		}

		// CleanupMaxAge is parsed unconditionally here, so an empty or
		// malformed value panics whenever a cleanup interval is configured.
		cleanupMaxAge, err = time.ParseDuration(smw.CleanupMaxAge)

		if err != nil {
			panic(err)
		}

		// Periodic cleanup; this loop has no exit condition.
		go func() {
			for range time.Tick(cleanupInterval) {
				logger.Print("Cleaning up old sessions")

				if err := context.CleanupSessions(abspath, cleanupMaxAge); err != nil {
					logger.Printf("Failed to clean up sessions: %v", err)
				}
			}
		}()
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		// Path portion of the request, without the query string.
		uriParts := strings.SplitN(r.RequestURI, "?", 2)
		if uriParts[0] == "" {
			uriParts[0] = r.URL.Path
		}

		ignore := false
		for _, prefix := range smw.IgnoreURLPrefix {
			// TrimPrefix instead of indexing prefix[0]: an empty entry must
			// not panic. It is treated the same as "/".
			prefix = strings.TrimPrefix(prefix, "/")

			if strings.HasPrefix(uriParts[0], smw.Pattern+prefix+"/") {
				ignore = true
				break
			}

			if uriParts[0] == smw.Pattern+prefix {
				ignore = true
				break
			}
		}

		if ignore {
			ph.ServeHTTP(w, r)
			return
		}

		firstTimer := false
		var sess context.Session

		if smw.SessionGenerator == nil {
			sess = context.NewSession(smw.Secret, smw.Cipher, abspath)
		} else {
			sess = smw.SessionGenerator(smw.Secret, smw.Cipher, abspath)
		}
		sess.SetMaxAge(maxAge)

		err := sess.Read(r, c)

		// Any read error other than ErrExpired / ErrNotExist gives the
		// session a fresh name and marks the visitor as a first-timer.
		if err != nil && err != context.ErrExpired && err != context.ErrNotExist {
			sess.SetName(util.UUID())
			firstTimer = true

			// A missing cookie is not worth logging.
			if err != context.ErrCookieNotExist {
				logger.Printf("Error reading session: %v", err)
			}
		}

		c.Set(r, context.BaseCtxKey("session"), sess)
		c.Set(r, context.BaseCtxKey("firstTimer"), firstTimer)

		// The wrapped handler writes into a recorder so that the session can
		// still be written to w before the status code and body go out.
		rec := util.NewRecorderHijacker(w)

		ph.ServeHTTP(rec, r)

		for k, v := range rec.Header() {
			w.Header()[k] = v
		}

		if err := sess.Write(w); err != nil {
			logger.Printf("Unable to write session: %v", err)
		}

		w.WriteHeader(rec.GetCode())
		w.Write(rec.GetBody().Bytes())
	}

	return http.HandlerFunc(handler)
}
示例#3
0
// TestI18NHandler runs a sequence of requests through the I18N middleware
// and checks the response for each one: requests without a language prefix
// are expected to redirect, requests under /bg/ and /en/ are expected to
// render the matching template data, a session attached to the request is
// expected to receive a "language" value, and requests under the ignored
// /css and /js prefixes are expected to return 200.
func TestI18NHandler(t *testing.T) {
	c := context.NewContext()
	ren := renderer.NewRenderer("testdata", "test.tmpl")
	c.SetGlobal(context.BaseCtxKey("renderer"), ren)

	mw := I18N{
		Languages:       []string{"en", "bg"},
		Pattern:         "/",
		Dir:             "testdata",
		IgnoreURLPrefix: []string{"/css", "/js"},
	}

	h := mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := ren.Render(w, nil, c.GetAll(r), "test_i18n.tmpl"); err != nil {
			t.Fatal(err)
		}
	}), c)

	// newRequest builds a GET request against localhost:8080 and sets
	// RequestURI to the given value.
	newRequest := func(requestURI string) *http.Request {
		req, _ := http.NewRequest("GET", "http://localhost:8080"+requestURI, nil)
		req.RequestURI = requestURI
		return req
	}

	// serve runs req through handler and returns what was recorded.
	serve := func(handler http.Handler, req *http.Request) *httptest.ResponseRecorder {
		out := httptest.NewRecorder()
		handler.ServeHTTP(out, req)
		return out
	}

	expectCode := func(out *httptest.ResponseRecorder, want int) {
		t.Helper()
		if out.Code != want {
			t.Fatalf("Expected code %d, got %d\n", want, out.Code)
		}
	}

	expectBody := func(out *httptest.ResponseRecorder, want string) {
		t.Helper()
		if body := out.Body.String(); !strings.Contains(body, want) {
			t.Fatalf("Expected body '%s' to contain '%s'\n", body, want)
		}
	}

	expectLocation := func(out *httptest.ResponseRecorder, want string) {
		t.Helper()
		if got := out.Header().Get("Location"); got != want {
			t.Fatalf("Expected a redirect to '%s', got '%s'\n", want, got)
		}
	}

	// Bare host, no RequestURI: redirect.
	rec := serve(h, newRequest(""))
	expectCode(rec, http.StatusFound)

	// Explicit language prefix: rendered with that language's data.
	rec = serve(h, newRequest("/bg/"))
	expectCode(rec, http.StatusOK)
	expectBody(rec, "test data bg")

	// Same, with a session attached to the request beforehand.
	req := newRequest("/en/")
	s := context.NewSession([]byte(""), nil, os.TempDir())
	s.SetName("test1")
	c.Set(req, context.BaseCtxKey("session"), s)

	rec = serve(h, req)
	expectCode(rec, http.StatusOK)
	expectBody(rec, "test data en")

	lang, ok := s.Get("language")
	if !ok {
		t.Fatalf("Expected the session to have a language key\n")
	}
	if lang.(string) != "en" {
		t.Fatalf("Expected session.language to be '%s', got '%s'\n", "en", lang.(string))
	}

	// Requests that must be redirected, in the order they are issued.
	redirects := []struct {
		requestURI string
		location   string
	}{
		{"/en", "/en/"},
		{"/foo/bar/baz", "/en/foo/bar/baz"},
		{"/foo/bar/baz?alpha=beta&gamma=delta", "/en/foo/bar/baz?alpha=beta&gamma=delta"},
		{"/foo/bar/baz?alpha=beta&gamma=delta#test", "/en/foo/bar/baz?alpha=beta&gamma=delta#test"},
	}

	for _, tc := range redirects {
		rec = serve(h, newRequest(tc.requestURI))
		expectCode(rec, http.StatusFound)
		expectLocation(rec, tc.location)
	}

	// Ignored prefixes, served through a fresh middleware instance that
	// wraps a no-op handler.
	req = newRequest("/css/foo")

	h = mw.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	}), c)

	rec = serve(h, req)
	expectCode(rec, http.StatusOK)

	rec = serve(h, newRequest("/js/foo"))
	expectCode(rec, http.StatusOK)
}