func TestSetForwardHeaders(t *testing.T) { xfp := "X-Forwarded-Proto" xff := "X-Forwarded-For" m := NewForwardedModifier() req, err := http.NewRequest("GET", "http://martian.local", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.RemoteAddr = "10.0.0.1:8112" if m.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get(xfp), "http"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xfp, got, want) } if got, want := req.Header.Get(xff), "10.0.0.1"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xff, got, want) } // Test with existing X-Forwarded-For. req.RemoteAddr = "12.12.12.12" if m.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get(xff), "10.0.0.1, 12.12.12.12"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xff, got, want) } }
func TestModifierFromJSON(t *testing.T) { msg := []byte(`{ "cookie.Modifier": { "scope": ["request", "response"], "name": "martian", "value": "value" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://example.com/path/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } if err := reqmod.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := len(req.Cookies()), 1; got != want { t.Fatalf("len(req.Cookies): got %v, want %v", got, want) } if got, want := req.Cookies()[0].Name, "martian"; got != want { t.Errorf("req.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := req.Cookies()[0].Value, "value"; got != want { t.Errorf("req.Cookies()[0].Value: got %v, want %v", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := len(res.Cookies()), 1; got != want { t.Fatalf("len(res.Cookies): got %v, want %v", got, want) } if got, want := res.Cookies()[0].Name, "martian"; got != want { t.Errorf("res.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := res.Cookies()[0].Value, "value"; got != want { t.Errorf("res.Cookies()[0].Value: got %v, want %v", got, want) } }
func TestRemoveHopByHopHeaders(t *testing.T) { m := NewHopByHopModifier() req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } hs := http.Header{ // Additional hop-by-hop headers are listed in the // Connection header. "Connection": []string{ "X-Connection", "X-Hop-By-Hop, close", }, // RFC hop-by-hop headers. "Keep-Alive": []string{}, "Proxy-Authenticate": []string{}, "Proxy-Authorization": []string{}, "Te": []string{}, "Trailer": []string{}, "Transfer-Encoding": []string{}, "Upgrade": []string{}, // Hop-by-hop headers listed in the Connection header. "X-Connection": []string{}, "X-Hop-By-Hop": []string{}, // End-to-end header that should not be removed. "X-End-To-End": []string{}, } req.Header = hs if err := m.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := len(req.Header), 1; got != want { t.Fatalf("len(req.Header): got %d, want %d", got, want) } if _, ok := req.Header["X-End-To-End"]; !ok { t.Errorf("req.Header[%q]: got !ok, want ok", "X-End-To-End") } res := proxyutil.NewResponse(200, nil, req) res.Header = hs if err := m.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := len(res.Header), 1; got != want { t.Fatalf("len(res.Header): got %d, want %d", got, want) } if _, ok := res.Header["X-End-To-End"]; !ok { t.Errorf("res.Header[%q]: got !ok, want ok", "X-End-To-End") } }
func TestFilterWithQueryStringNameAndValue(t *testing.T) { name, value := "name", "value" nameMatcher, err := regexp.Compile(name) if err != nil { t.Fatalf("regexp.Compile(%q): got %v, want no error", name, err) } valueMatcher, err := regexp.Compile(value) if err != nil { t.Fatalf("regexp.Compile(%q): got %v, want no error", value, err) } modifierRun := false f := func(*martian.Context, *http.Request) error { modifierRun = true return nil } filter, err := NewFilter(nameMatcher, valueMatcher) if err != nil { t.Fatalf("NewFilter(): got %v, want no error", err) } filter.SetRequestModifier(martian.RequestModifierFunc(f)) v := url.Values{} v.Add("nomatch", "value") req, err := http.NewRequest("POST", "http://martian.local?name=value", strings.NewReader(v.Encode())) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") if err := filter.ModifyRequest(martian.NewContext(), req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if !modifierRun { t.Error("modifierRun: got false, want true") } v = url.Values{} req, err = http.NewRequest("POST", "http://martian.local", strings.NewReader(v.Encode())) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } modifierRun = false if err := filter.ModifyRequest(martian.NewContext(), req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if modifierRun { t.Error("modifierRun: got true, want false") } }
func TestFilterFromJSON(t *testing.T) { msg := []byte(`{ "header.Filter": { "scope": ["request", "response"], "name": "Martian-Passthru", "value": "true", "modifier": { "header.Modifier" : { "scope": ["request", "response"], "name": "Martian-Testing", "value": "true" } } } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Martian-Passthru", "true") if err := reqmod.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Martian-Testing"), "true"; got != want { t.Fatalf("req.Header.Get(%q): got %q, want %q", "Martian-Testing", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, nil) res.Header.Set("Martian-Passthru", "true") if err := resmod.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Martian-Testing"), "true"; got != want { t.Fatalf("res.Header.Get(%q): got %q, want %q", "Martian-Testing", got, want) } }
func TestVerifierFromJSON(t *testing.T) { msg := []byte(`{ "header.Verifier": { "scope": ["request", "response"], "name": "Martian-Test", "value": "true" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } reqv, ok := reqmod.(verify.RequestVerifier) if !ok { t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok") } req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqv.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := reqv.VerifyRequests(); err == nil { t.Error("VerifyRequests(): got nil, want not nil") } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } resv, ok := resmod.(verify.ResponseVerifier) if !ok { t.Fatal("resmod.(verify.ResponseVerifier): got !ok, want ok") } res := proxyutil.NewResponse(200, nil, req) if err := resv.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if err := resv.VerifyResponses(); err == nil { t.Error("VerifyResponses(): got nil, want not nil") } }
func TestFromJSON(t *testing.T) { msg := []byte(`{ "status.Modifier": { "scope": ["response"], "statusCode": 400 } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, nil) if err := resmod.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 400; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestModifierFromJSON(t *testing.T) { msg := []byte(` { "querystring.Modifier": { "scope": ["request"], "name": "param", "value": "true" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://martian.test", nil) if err != nil { t.Fatalf("http.NewRequest(): got %q, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatalf("reqmod: got nil, want not nil") } ctx := martian.NewContext() if err := reqmod.ModifyRequest(ctx, req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.URL.Query().Get("param"), "true"; got != want { t.Errorf("req.URL.Query().Get(%q): got %q, want %q", "param", got, want) } }
func TestIntegration(t *testing.T) { server := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = "http" r.URL.Host = r.Host w.Header().Set("Martian-URL", r.URL.String()) })) defer server.Close() u := &url.URL{ Scheme: "http", Host: server.Listener.Addr().String(), } m := NewModifier(u) req, err := http.NewRequest("GET", "https://example.com/test", nil) if err != nil { t.Fatalf("http.NewRequest(%q, %q, nil): got %v, want no error", "GET", "http://example.com/test", err) } if err := m.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("http.DefaultClient.Do(): got %v, want no error", err) } want := "http://example.com/test" if got := res.Header.Get("Martian-URL"); got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-URL", got, want) } }
func TestVerifierFromJSON(t *testing.T) { msg := []byte(`{ "url.Verifier": { "scope": ["request"], "scheme": "https", "host": "www.martian.proxy", "path": "/testing", "query": "test=true" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } reqv, ok := reqmod.(verify.RequestVerifier) if !ok { t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok") } req, err := http.NewRequest("GET", "https://www.martian.proxy/testing?test=false", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqv.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := reqv.VerifyRequests(); err == nil { t.Error("VerifyRequests(): got nil, want not nil") } }
func TestFailureWithMissingKey(t *testing.T) { v, err := NewVerifier("foo", "bar") if err != nil { t.Fatalf("NewVerifier(%q, %q): got %v, want no error", "foo", "bar", err) } req, err := http.NewRequest("GET", "http://www.google.com?fizz=bar", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := v.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } merr, ok := v.VerifyRequests().(*verify.MultiError) if !ok { t.Fatal("VerifyRequests(): got nil, want *verify.MultiError") } errs := merr.Errors() if len(errs) != 1 { t.Fatalf("len(merr.Errors()): got %d, want 1", len(errs)) } expectErr := "request(http://www.google.com?fizz=bar) param verification error: key foo not found" for i := range errs { if got, want := errs[i].Error(), expectErr; got != want { t.Errorf("%d. err.Error(): mismatched error output\ngot: %s\nwant: %s", i, got, want) } } }
func TestModifyRequest(t *testing.T) { m := NewModifier() ctx := martian.NewContext() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass")) if err := m.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := ctx.Auth.ID, "user:pass"; got != want { t.Fatalf("ctx.Auth.ID: got %q, want %q", got, want) } modifierRun := false m.SetRequestModifier(martian.RequestModifierFunc( func(*martian.Context, *http.Request) error { modifierRun = true return nil })) if err := m.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !modifierRun { t.Error("modifierRun: got false, want true") } }
func TestStatusModifierOnResponse(t *testing.T) { for i, status := range []int{ http.StatusForbidden, http.StatusOK, http.StatusTemporaryRedirect, } { req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) mod := NewModifier(status) if err := mod.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if got, want := res.StatusCode, status; got != want { t.Errorf("%d. res.StatusCode: got %v, want %v", i, got, want) } if got, want := res.Status, http.StatusText(status); got != want { t.Errorf("%d. res.Status: got %q, want %q", i, got, want) } } }
func TestModifyResponse(t *testing.T) { m := NewModifier() ctx := martian.NewContext() res := proxyutil.NewResponse(200, nil, nil) if err := m.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } m.SetResponseModifier(martian.ResponseModifierFunc( func(*martian.Context, *http.Response) error { ctx.Auth.Error = errors.New("auth is required") return nil })) if err := m.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, http.StatusProxyAuthRequired; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), "Basic"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } }
// ExampleLogger runs one request and one response through a logger whose log
// function prints each logged line to standard output, so that the printed
// text can be compared with the expected-output comment at the end of the
// function.
func ExampleLogger() {
	l := NewLogger()
	l.SetLogFunc(func(line string) {
		// Remove \r to make it easier to test with examples.
		fmt.Print(strings.Replace(line, "\r", "", -1))
	})

	req, err := http.NewRequest("GET", "http://example.com/path?querystring", strings.NewReader("body"))
	if err != nil {
		fmt.Println(err)
		return
	}
	// Populate the request fields that this example sets explicitly: the
	// request URI, one extra header, and the close flag (presumably the
	// source of the "Connection: close" line below — confirm against the
	// logger).
	req.RequestURI = req.URL.RequestURI()
	req.Header.Set("Other-Header", "values")
	req.Close = true

	if err := l.ModifyRequest(martian.NewContext(), req); err != nil {
		fmt.Println(err)
		return
	}

	// The response uses a fixed Date header so the output is reproducible.
	res := proxyutil.NewResponse(200, strings.NewReader("body"), req)
	res.Header.Set("Date", "Tue, 15 Nov 1994 08:12:31 GMT")
	res.Header.Set("Other-Header", "values")

	if err := l.ModifyResponse(martian.NewContext(), res); err != nil {
		fmt.Println(err)
		return
	}

	// Output:
	// --------------------------------------------------------------------------------
	// Request to http://example.com/path?querystring
	// --------------------------------------------------------------------------------
	// GET /path?querystring HTTP/1.1
	// Host: example.com
	// Connection: close
	// Other-Header: values
	// --------------------------------------------------------------------------------
	//
	// --------------------------------------------------------------------------------
	// Response from http://example.com/path?querystring
	// --------------------------------------------------------------------------------
	// HTTP/1.1 200 OK
	// Date: Tue, 15 Nov 1994 08:12:31 GMT
	// Other-Header: values
	// --------------------------------------------------------------------------------
}
func TestModifyResponseNoModifier(t *testing.T) { m := NewModifier() res := proxyutil.NewResponse(200, nil, nil) if err := m.ModifyResponse(martian.NewContext(), res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } }
func TestFilterFromJSON(t *testing.T) { msg := []byte(` { "querystring.Filter": { "scope": ["request", "response"], "name": "param", "value": "true", "modifier": { "header.Modifier": { "scope": ["request", "response"], "name": "Mod-Run", "value": "true" } } } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", "https://martian.test?param=true", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx := martian.NewContext() if err := reqmod.ModifyRequest(ctx, req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Mod-Run"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(ctx, res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Mod-Run"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } }
func TestFilterWithQueryParamNameAndNoValue(t *testing.T) { name := "name" nameMatcher, err := regexp.Compile(name) if err != nil { t.Fatalf("regexp.Compile(%q): got %v, want no error", name, err) } modifierRun := false f := func(*martian.Context, *http.Request) error { modifierRun = true return nil } filter, err := NewFilter(nameMatcher, nil) if err != nil { t.Fatalf("NewFilter(): got %v, want no error", err) } filter.SetRequestModifier(martian.RequestModifierFunc(f)) req, err := http.NewRequest("GET", "http://martian.local?name", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := filter.ModifyRequest(martian.NewContext(), req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if !modifierRun { t.Error("modifierRun: got false, want true") } req, err = http.NewRequest("GET", "http://martian.local?test", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } modifierRun = false if err := filter.ModifyRequest(martian.NewContext(), req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if modifierRun { t.Error("modifierRun: got true, want false") } }
func TestServeHTTP(t *testing.T) { m := NewModifier() msg := []byte(` { "header.Modifier": { "scope": ["request", "response"], "name": "Martian-Test", "value": "true" } }`) req, err := http.NewRequest("POST", "/martian/modifiers?id=id", bytes.NewBuffer(msg)) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Type", "application/json") rw := httptest.NewRecorder() m.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } req, err = http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("m.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("m.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } }
func TestModifyRequest(t *testing.T) { f := NewFilter() modifierRun := false f.SetRequestModifier("id", martian.RequestModifierFunc( func(*martian.Context, *http.Request) error { modifierRun = true return nil })) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } ctx := martian.NewContext() // No ID, auth required. f.SetAuthRequired(true) if err := f.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if ctx.Auth.Error == nil { t.Error("ctx.Auth.Error: got nil, want error") } if modifierRun { t.Error("modifierRun: got true, want false") } // No ID, auth not required. f.SetAuthRequired(false) ctx.Auth.Error = nil if err := f.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if ctx.Auth.Error != nil { t.Errorf("ctx.Auth.Error: got %v, want no error", err) } if modifierRun { t.Error("modifierRun: got true, want false") } // Valid ID. ctx.Auth.ID = "id" ctx.Auth.Error = nil if err := f.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if ctx.Auth.Error != nil { t.Errorf("ctx.Auth.Error: got %v, want no error", ctx.Auth.Error) } if !modifierRun { t.Error("modifierRun: got false, want true") } }
func TestModifyResponse(t *testing.T) { m := NewModifier() m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := martian.TestContext(req) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } // Modifier with error. tm := martiantest.NewModifier() reserr := errors.New("response error") tm.ResponseError(reserr) m.SetResponseModifier(tm) if err := m.ModifyResponse(res); err != reserr { t.Fatalf("ModifyResponse(): got %v, want %v", err, reserr) } // Modifier with auth error. tm.ResponseError(nil) autherr := errors.New("auth error") tm.ResponseFunc(func(res *http.Response) { ctx := martian.NewContext(res.Request) actx := auth.FromContext(ctx) actx.SetError(autherr) }) actx := auth.FromContext(ctx) actx.SetID("bad-auth") if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 403; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := actx.Error(), autherr; got != want { t.Errorf("actx.Error(): got %v, want %v", got, want) } }
func TestBadFramingMultipleContentLengths(t *testing.T) { m := NewBadFramingModifier() req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header["Content-Length"] = []string{"42", "42, 42"} if err := m.ModifyRequest(martian.NewContext(), req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header["Content-Length"], []string{"42"}; !reflect.DeepEqual(got, want) { t.Errorf("req.Header[%q]: got %v, want %v", "Content-Length", got, want) } req.Header["Content-Length"] = []string{"42", "32, 42"} if err := m.ModifyRequest(martian.NewContext(), req); err == nil { t.Error("ModifyRequest(): got nil, want error") } }
func TestNewModifier(t *testing.T) { tt := []struct { want string url *url.URL }{ { want: "https://www.example.com", url: &url.URL{Scheme: "https"}, }, { want: "http://www.martian.local", url: &url.URL{Host: "www.martian.local"}, }, { want: "http://www.example.com/test", url: &url.URL{Path: "/test"}, }, { want: "http://www.example.com?test=true", url: &url.URL{RawQuery: "test=true"}, }, { want: "http://www.example.com#test", url: &url.URL{Fragment: "test"}, }, { want: "https://martian.local/test?test=true#test", url: &url.URL{ Scheme: "https", Host: "martian.local", Path: "/test", RawQuery: "test=true", Fragment: "test", }, }, } for i, tc := range tt { req, err := http.NewRequest("GET", "http://www.example.com", nil) if err != nil { t.Fatalf("%d. NewRequest(): got %v, want no error", i, err) } mod := NewModifier(tc.url) if err := mod.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("%d. ModifyRequest(): got %q, want no error", i, err) } if got := req.URL.String(); got != tc.want { t.Errorf("%d. req.URL: got %q, want %q", i, got, tc.want) } } }
func TestCookieModifier(t *testing.T) { cookie := &http.Cookie{ Name: "name", Value: "value", } mod := NewModifier(cookie) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := mod.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := len(req.Cookies()), 1; got != want { t.Errorf("len(req.Cookies): got %v, want %v", got, want) } if got, want := req.Cookies()[0].Name, cookie.Name; got != want { t.Errorf("req.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := req.Cookies()[0].Value, cookie.Value; got != want { t.Errorf("req.Cookies()[0].Value: got %v, want %v", got, want) } res := proxyutil.NewResponse(200, nil, req) if err := mod.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := len(res.Cookies()), 1; got != want { t.Errorf("len(res.Cookies): got %v, want %v", got, want) } if got, want := res.Cookies()[0].Name, cookie.Name; got != want { t.Errorf("res.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := res.Cookies()[0].Value, cookie.Value; got != want { t.Errorf("res.Cookies()[0].Value: got %v, want %v", got, want) } }
func TestModifyRequestNoModifier(t *testing.T) { m := NewModifier() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(martian.NewContext(), req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } }
// ModifyResponse sets the status code to 400 Bad Request if a loop was // detected in the request. func (m *viaModifier) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) if err, _ := ctx.Get(viaLoopKey); err != nil { res.StatusCode = 400 res.Status = http.StatusText(400) return err.(error) } return nil }
func TestModifyResponseResetAuth(t *testing.T) { auth := NewModifier() auth.SetRequestModifier(martian.RequestModifierFunc( func(ctx *martian.Context, req *http.Request) error { if ctx.Auth.ID != "secret:pass" { ctx.Auth.Error = errors.New("invalid auth") } return nil })) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("wrong:pass")) ctx := martian.NewContext() // This will set ctx.Auth.Error since the ID isn't "secret:pass". if err := auth.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := auth.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 407; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := ctx.Auth.ID, ""; got != want { t.Errorf("ctx.Auth.ID: got %q, want %q", got, want) } if err := ctx.Auth.Error; err != nil { t.Errorf("ctx.Auth.Error: got %v, want no error", err) } // This will be successful because the ID is "secret:pass". req.Header.Set("Proxy-Authorization", "Basic "+encode("secret:pass")) if err := auth.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } // Reset the response. res = proxyutil.NewResponse(200, nil, req) if err := auth.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestBadFramingTransferEncodingAndContentLength(t *testing.T) { m := NewBadFramingModifier() req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header["Transfer-Encoding"] = []string{"gzip, chunked"} req.Header["Content-Length"] = []string{"42"} if err := m.ModifyRequest(martian.NewContext(), req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if _, ok := req.Header["Content-Length"]; ok { t.Fatalf("req.Header[%q]: got ok, want !ok", "Content-Length") } req.Header.Set("Transfer-Encoding", "gzip, identity") req.Header.Del("Content-Length") if err := m.ModifyRequest(martian.NewContext(), req); err == nil { t.Error("ModifyRequest(): got nil, want error") } }
func TestViaModifier(t *testing.T) { m := NewViaModifier("1.1 martian") req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Via"), "1.1 martian"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Via", got, want) } req.Header.Set("Via", "1.0 alpha") if err := m.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Via"), "1.0 alpha, 1.1 martian"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Via", got, want) } }
// ModifyResponse runs resmod.ModifyResponse and modifies the response to // include the correct status code and headers if auth error is present. // // If an error is returned from resmod.ModifyResponse it is returned. func (m *Modifier) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) actx := auth.FromContext(ctx) err := m.resmod.ModifyResponse(res) if actx.Error() != nil { res.StatusCode = http.StatusProxyAuthRequired res.Header.Set("Proxy-Authenticate", "Basic") } return err }