Ejemplo n.º 1
0
// main serves handle on localhost:8080 behind a middleware chain made of
// httpware's default error handler plus the contentware and logware
// middleware, each configured with its package Defaults.
func main() {
	// Compose chains together middleware.
	m := httpware.Compose(
		httpware.DefaultErrHandler,
		contentware.New(contentware.Defaults),
		logware.New(logware.Defaults),
	)

	// ListenAndServe only returns on failure (for example when the address is
	// already in use). Surface that instead of exiting silently with status 0.
	if err := http.ListenAndServe("localhost:8080", m.ThenFunc(handle)); err != nil {
		panic(err)
	}
}
Ejemplo n.º 2
0
// TestLog checks that the logging middleware writes an entry to
// conf.Logger.Out for each request. The entry must contain "success" for a
// normal request, the error message for handlers that return an httpware
// error, and the panic value for a panicking handler.
func TestLog(t *testing.T) {
	var buffer bytes.Buffer
	conf := Defaults
	conf.Logger.Out = &buffer
	m := httpware.Compose(
		httpware.DefaultErrHandler,
		New(conf),
	)
	s := httptest.NewServer(m.ThenFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
		if r.URL.Path == "/400" {
			return httpware.NewErr("didnt like your request", http.StatusBadRequest)
		}
		if r.URL.Path == "/500" {
			return httpware.NewErr("ahhhh it blew up", http.StatusInternalServerError)
		}
		if r.URL.Path == "/panic" {
			panic("PANIC!")
		}
		return nil
	}))
	// Shut the test server down so its listener and goroutines do not leak
	// into later tests.
	defer s.Close()

	cases := []struct {
		Path     string
		Expected string
	}{
		{
			Path:     "/",
			Expected: "success",
		},
		{
			Path:     "/400",
			Expected: "didnt like your request",
		},
		{
			Path:     "/500",
			Expected: "ahhhh it blew up",
		},
		{
			Path:     "/panic",
			Expected: "PANIC!",
		},
	}
	for _, c := range cases {
		resp, err := http.Get(s.URL + c.Path)
		if err != nil {
			t.Fatal("failed to make request:", err)
		}
		resp.Body.Close()
		// The buffer accumulates output from all requests made so far; each
		// expected string only occurs for its own route.
		got := buffer.String()
		if !strings.Contains(got, c.Expected) {
			t.Fatalf("expected log output to contain: '%s', got: \n%s", c.Expected, got)
		}
	}
}
Ejemplo n.º 3
0
// TestResponse checks that, after the middleware runs, ResponseTypeFromCtx
// reports JSON for a request with "Accept: application/json" and XML for one
// with "Accept: application/xml".
func TestResponse(t *testing.T) {
	c := httpware.Compose(
		httpware.DefaultErrHandler,
		New(Defaults),
	)
	// The handler runs on the server's goroutine, so it reports failures with
	// t.Error. t.Fatal (FailNow) is only valid on the test goroutine.
	s := httptest.NewServer(
		c.Then(
			httpware.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
				ct := ResponseTypeFromCtx(ctx)
				switch r.URL.Path {
				case "/test-json":
					if ct.Key != httpware.JSON {
						t.Error("expected json type")
					}
					return nil
				case "/test-xml":
					if ct.Key != httpware.XML {
						t.Error("expected xml type")
					}
					return nil
				}
				t.Error("this point should never have been reached")
				return nil
			}),
		),
	)
	defer s.Close()

	hc := http.Client{}

	// get requests path with the given Accept header and closes the response
	// body so the connection is released.
	get := func(path, accept string) {
		req, err := http.NewRequest("GET", s.URL+path, nil)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Set("Accept", accept)
		resp, err := hc.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		resp.Body.Close()
	}

	get("/test-json", "application/json")
	get("/test-xml", "application/xml")
}
Ejemplo n.º 4
0
// TestPagination checks the pagination values that PageFromCtx exposes to the
// handler:
//   - with no query, page.Start is 0 and page.Limit is Defaults.LimitDefault;
//   - with "start=10&limit=5", both parameters are parsed;
//   - a negative start or a zero limit is rejected with 400 Bad Request.
func TestPagination(t *testing.T) {
	m := httpware.Compose(
		httpware.DefaultErrHandler,
		New(Defaults),
	)
	// The handler runs on the server's goroutine, so it reports failures with
	// t.Error. t.Fatal (FailNow) is only valid on the test goroutine.
	s := httptest.NewServer(m.ThenFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
		page := PageFromCtx(ctx)
		switch r.URL.RawQuery {
		case "":
			if page.Start != 0 {
				t.Error("expected page.Start == 0")
			}
			if page.Limit != Defaults.LimitDefault {
				t.Error("expected page.Limit == Defaults.LimitDefault")
			}
		case "start=10&limit=5":
			if page.Start != 10 {
				t.Error("expected page.Start == 10")
			}
			if page.Limit != 5 {
				t.Error("expected page.Limit == 5")
			}
		default:
			t.Error("this point should not be reached")
		}
		return nil
	}))
	defer s.Close()

	// get requests s.URL+query, closes the body, and returns the status code.
	// A transport error fails the test instead of leaving a nil response.
	get := func(query string) int {
		resp, err := http.Get(s.URL + query)
		if err != nil {
			t.Fatalf("GET %q: %v", query, err)
		}
		resp.Body.Close()
		return resp.StatusCode
	}

	// Valid test cases:
	for _, q := range []string{"", "?start=10&limit=5"} {
		if code := get(q); code != http.StatusOK {
			t.Fatalf("expected status code %v, got: %v, for query %q", http.StatusOK, code, q)
		}
	}
	// Invalid test cases:
	if code := get("?start=-10&limit=5"); code != http.StatusBadRequest {
		t.Fatalf("expected status code %v, got: %v, while testing negative start param", http.StatusBadRequest, code)
	}
	if code := get("?start=10&limit=0"); code != http.StatusBadRequest {
		t.Fatalf("expected status code %v, got: %v, while testing zero limit param", http.StatusBadRequest, code)
	}
}
Ejemplo n.º 5
0
// TestWare checks that a request without a token is rejected with
// 401 Unauthorized. It also checks that a request carrying an HS256 token
// signed with the configured secret in an "Authorization: Bearer" header gets
// 200 OK.
func TestWare(t *testing.T) {
	secret := []byte("shh")
	m := httpware.Compose(
		httpware.DefaultErrHandler,
		New(Config{secret}),
	)
	s := httptest.NewServer(m.ThenFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
		return nil
	}))
	defer s.Close()

	// Generate token.
	tkn := jwt.New(jwt.SigningMethodHS256)
	tknStr, err := tkn.SignedString(secret)
	if err != nil {
		t.Fatal(err)
	}

	// Unauthorized
	unauthResp, err := http.Get(s.URL)
	if err != nil {
		t.Fatal(err)
	}
	unauthResp.Body.Close()
	if unauthResp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("expected status code %v, got %v", http.StatusUnauthorized, unauthResp.StatusCode)
	}

	// Authorized
	req, err := http.NewRequest("GET", s.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Authorization", "Bearer "+tknStr)
	c := &http.Client{}
	resp, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("expected status code %v, got %v", http.StatusOK, resp.StatusCode)
	}
}
Ejemplo n.º 6
0
// TestRemoteLimit checks that one request succeeds on its own. It then checks
// that once conf.RemoteLimit requests are in flight (each held open for one
// second by the /delay route), a further request is rejected with
// 429 Too Many Requests.
func TestRemoteLimit(t *testing.T) {
	conf := Config{
		RemoteLimit: 3,
		TotalLimit:  10,
	}
	m := httpware.Compose(
		httpware.DefaultErrHandler,
		New(conf),
	)
	s := httptest.NewServer(m.ThenFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
		if r.URL.Path == "/delay" {
			time.Sleep(1 * time.Second)
		}
		return nil
	}))
	// Close blocks until the outstanding /delay requests finish, so the
	// background requests below do not outlive the test's server.
	defer s.Close()

	resp, err := http.Get(s.URL)
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()
	if resp.StatusCode != 200 {
		t.Fatalf("expected status code %v, got %v", 200, resp.StatusCode)
	}

	// fire issues a best-effort request and releases its body. Errors are
	// ignored on purpose: this runs off the test goroutine and may finish after
	// the test has returned, so it must not touch t.
	fire := func(url string) {
		r, err := http.Get(url)
		if err == nil {
			r.Body.Close()
		}
	}

	// Occupy every slot the remote limit allows.
	for i := 1; i <= conf.RemoteLimit; i++ {
		go fire(s.URL + "/delay")
	}

	// Sleep for 100ms to make sure the other requests have reached the middleware.
	time.Sleep(100 * time.Millisecond)
	resp, err = http.Get(s.URL)
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()
	if resp.StatusCode != 429 {
		t.Fatalf("expected status code %v, got %v", 429, resp.StatusCode)
	}
}
Ejemplo n.º 7
0
// TestWare checks that the middleware leaves the handler's 204 No Content
// status intact and sets "Access-Control-Allow-Origin: *" on the response.
func TestWare(t *testing.T) {
	m := httpware.Compose(
		httpware.DefaultErrHandler,
		New(Defaults),
	)
	s := httptest.NewServer(m.ThenFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
		w.WriteHeader(http.StatusNoContent)
		return nil
	}))
	defer s.Close()

	// Send Request.
	resp, err := http.Get(s.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusNoContent {
		t.Fatalf("expected status code %v, got %v", http.StatusNoContent, resp.StatusCode)
	}
	if resp.Header.Get("Access-Control-Allow-Origin") != "*" {
		t.Fatal("expected Access-Control-Allow-Origin header to be set to '*'")
	}
}
Ejemplo n.º 8
0
// TestStreaming checks that each message passed to the context's sender
// reaches the client as a "data: <msg>" line followed by a blank line. It
// reads the first two messages of an endless stream of "hello".
func TestStreaming(t *testing.T) {
	m := httpware.Compose(
		httpware.DefaultErrHandler,
		New(Defaults),
	)
	s := httptest.NewServer(m.ThenFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
		sender := SenderFromCtx(ctx)
		for {
			// Stop once the client hangs up. Otherwise this handler spins
			// forever and the deferred s.Close would block indefinitely.
			select {
			case <-r.Context().Done():
				return nil
			default:
			}
			if err := sender.Send("hello"); err != nil {
				// A send failure is expected once the test stops reading and
				// closes the body. Do not call t.Fatal here: this runs on the
				// server's goroutine, where FailNow is not allowed.
				return nil
			}
		}
	}))
	defer s.Close()

	resp, err := http.Get(s.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Closing the body drops the connection, which ends the handler above.
	defer resp.Body.Close()

	// Two messages arrive as exactly four lines: "data: hello" and an empty
	// separator line, twice. Scanner.Text strips the new-line character, so it
	// is added back.
	const wantLines = 4
	const want = "data: hello\n\ndata: hello\n\n"
	scanner := bufio.NewScanner(resp.Body)
	var total string
	lines := 0
	for lines < wantLines && scanner.Scan() {
		total = total + scanner.Text() + "\n"
		lines++
	}
	if err := scanner.Err(); err != nil {
		t.Fatal("error reading stream: ", err)
	}
	// Without this check a stream that ends early would pass silently.
	if lines < wantLines {
		t.Fatalf("stream ended after %d lines, want %d: %q", lines, wantLines, total)
	}
	if total != want {
		t.Fatalf("unexpected messages: %s", total)
	}
}
Ejemplo n.º 9
0
// TestTotalLimit checks that conf.TotalLimit is enforced across two chains
// built from the same middleware instance. Requests alternate between a
// server for each chain. After conf.TotalLimit background requests to /delay
// (presumably kept open by testHandler — confirm), the next requests to either
// server must get 429 Too Many Requests.
func TestTotalLimit(t *testing.T) {
	// Define multilayer chaining.
	conf := Config{
		RemoteLimit: 60,
		TotalLimit:  10,
	}
	mid := New(conf)
	c0 := httpware.Compose(
		httpware.DefaultErrHandler,
		mid,
	)
	c1 := c0.With(testWare{})

	// Start test servers. Close blocks until their outstanding requests
	// finish, so the background requests do not outlive the servers.
	s := make([]*httptest.Server, 2)
	s[0] = httptest.NewServer(c0.ThenFunc(testHandler))
	defer s[0].Close()
	s[1] = httptest.NewServer(c1.ThenFunc(testHandler))
	defer s[1].Close()

	// fire issues a best-effort request and releases its body. Errors are
	// ignored on purpose: this runs off the test goroutine and may finish after
	// the test has returned, so it must not touch t.
	fire := func(url string) {
		r, err := http.Get(url)
		if err == nil {
			r.Body.Close()
		}
	}

	for i := 1; uint64(i) <= conf.TotalLimit+2; i++ {
		si := i % 2
		if uint64(i) > conf.TotalLimit {
			// Sleep for 100ms to make sure the other requests have reached the middleware.
			time.Sleep(100 * time.Millisecond)
			resp, err := http.Get(s[si].URL)
			if err != nil {
				t.Fatal(err)
			}
			resp.Body.Close()
			if resp.StatusCode != 429 {
				t.Fatalf("expected status code %v, got %v", 429, resp.StatusCode)
			}
		} else {
			go fire(s[si].URL + "/delay")
		}
	}
}
Ejemplo n.º 10
0
// ExampleAdapt wraps an httpware handler chain with Adapt and registers the
// result for GET /something on an httprouter router.
func ExampleAdapt() {
	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
		return nil
	}
	chain := httpware.Compose(httpware.DefaultErrHandler)
	router := httprouter.New()
	router.GET("/something", Adapt(chain.ThenFunc(handler)))
}