func TestModifyResponseResetAuth(t *testing.T) { auth := NewModifier() auth.SetRequestModifier(martian.RequestModifierFunc( func(ctx *martian.Context, req *http.Request) error { if ctx.Auth.ID != "secret:pass" { ctx.Auth.Error = errors.New("invalid auth") } return nil })) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("wrong:pass")) ctx := martian.NewContext() // This will set ctx.Auth.Error since the ID isn't "secret:pass". if err := auth.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := auth.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 407; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := ctx.Auth.ID, ""; got != want { t.Errorf("ctx.Auth.ID: got %q, want %q", got, want) } if err := ctx.Auth.Error; err != nil { t.Errorf("ctx.Auth.Error: got %v, want no error", err) } // This will be successful because the ID is "secret:pass". req.Header.Set("Proxy-Authorization", "Basic "+encode("secret:pass")) if err := auth.ModifyRequest(ctx, req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } // Reset the response. res = proxyutil.NewResponse(200, nil, req) if err := auth.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestModifyResponse(t *testing.T) { m := NewModifier() ctx := martian.NewContext() res := proxyutil.NewResponse(200, nil, nil) if err := m.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } m.SetResponseModifier(martian.ResponseModifierFunc( func(*martian.Context, *http.Response) error { ctx.Auth.Error = errors.New("auth is required") return nil })) if err := m.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, http.StatusProxyAuthRequired; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), "Basic"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } }
func TestCopyModifier(t *testing.T) { m := NewCopyModifier("Original", "Copy") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Original", "test") if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Copy"), "test"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Copy", got, want) } res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Original", "test") if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Copy"), "test"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Copy", got, want) } }
func TestModifyResponse(t *testing.T) { p := NewProxy(mitm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(%q, ..., nil): got %v, want no error", "GET", err) } res := proxyutil.NewResponse(200, nil, req) if err := p.ModifyResponse(NewContext(), res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } modifierRun := false f := func(*Context, *http.Response) error { modifierRun = true return nil } p.SetResponseModifier(ResponseModifierFunc(f)) if err := p.ModifyResponse(NewContext(), res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !modifierRun { t.Error("modifierRun: got false, want true") } }
func TestModifierFromJSON(t *testing.T) { msg := []byte(`{ "cookie.Modifier": { "scope": ["request", "response"], "name": "martian", "value": "value" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://example.com/path/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } if err := reqmod.ModifyRequest(martian.NewContext(), req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := len(req.Cookies()), 1; got != want { t.Fatalf("len(req.Cookies): got %v, want %v", got, want) } if got, want := req.Cookies()[0].Name, "martian"; got != want { t.Errorf("req.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := req.Cookies()[0].Value, "value"; got != want { t.Errorf("req.Cookies()[0].Value: got %v, want %v", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := len(res.Cookies()), 1; got != want { t.Fatalf("len(res.Cookies): got %v, want %v", got, want) } if got, want := res.Cookies()[0].Name, "martian"; got != want { t.Errorf("res.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := res.Cookies()[0].Value, "value"; got != want { t.Errorf("res.Cookies()[0].Value: got %v, want %v", got, want) } }
func TestFromJSON(t *testing.T) { msg := []byte(`{ "status.Modifier": { "scope": ["response"], "statusCode": 400 } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, nil) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 400; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestModifierFuncs(t *testing.T) { reqmod := RequestModifierFunc( func(req *http.Request) error { req.Header.Set("Request-Modified", "true") return nil }) resmod := ResponseModifierFunc( func(res *http.Response) error { res.Header.Set("Response-Modified", "true") return nil }) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Request-Modified"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Request-Modified", got, want) } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Response-Modified"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Response-Modified", got, want) } }
func TestStatusModifierOnResponse(t *testing.T) { for i, status := range []int{ http.StatusForbidden, http.StatusOK, http.StatusTemporaryRedirect, } { req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) mod := NewModifier(status) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if got, want := res.StatusCode, status; got != want { t.Errorf("%d. res.StatusCode: got %v, want %v", i, got, want) } if got, want := res.Status, fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(status)); got != want { t.Errorf("%d. res.Status: got %q, want %q", i, got, want) } } }
func TestExportIgnoresOrphanedResponse(t *testing.T) { logger := NewLogger("martian", "2.0.0") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } // Reset before the response comes back. logger.Reset() res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 0; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } }
func TestVerifierFromJSON(t *testing.T) { msg := []byte(`{ "status.Verifier": { "scope": ["response"], "statusCode": 400 } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } resv, ok := resmod.(verify.ResponseVerifier) if !ok { t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok") } req, err := http.NewRequest("GET", "http://www.example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := resv.ModifyResponse(martian.NewContext(), res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if err := resv.VerifyResponses(); err == nil { t.Error("VerifyResponses(): got nil, want not nil") } }
func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) { if p.proxyURL != nil { log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host) conn, err := net.Dial("tcp", p.proxyURL.Host) if err != nil { return nil, nil, err } pbw := bufio.NewWriter(conn) pbr := bufio.NewReader(conn) req.Write(pbw) pbw.Flush() res, err := http.ReadResponse(pbr, req) if err != nil { return nil, nil, err } return res, conn, nil } log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host) conn, err := net.Dial("tcp", req.URL.Host) if err != nil { return nil, nil, err } return proxyutil.NewResponse(200, nil, req), conn, nil }
func TestBodyModifier(t *testing.T) { mod := NewModifier([]byte("text"), "text/plain") req, err := http.NewRequest("GET", "/", strings.NewReader("")) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Encoding", "gzip") if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Content-Type"), "text/plain"; got != want { t.Errorf("req.Header.Get(%q): got %v, want %v", "Content-Type", got, want) } if got, want := req.ContentLength, int64(len([]byte("text"))); got != want { t.Errorf("req.ContentLength: got %d, want %d", got, want) } if got, want := req.Header.Get("Content-Encoding"), ""; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want) } got, err := ioutil.ReadAll(req.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } req.Body.Close() if want := []byte("text"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Content-Encoding", "gzip") if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want { t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want) } if got, want := res.ContentLength, int64(len([]byte("text"))); got != want { t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got, want := res.Header.Get("Content-Encoding"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want) } got, err = ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("text"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } }
func TestNoModifiers(t *testing.T) { m := NewModifier() m.SetRequestModifier(nil) m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } }
func ExampleLogger() { l := NewLogger() l.SetLogFunc(func(line string) { // Remove \r to make it easier to test with examples. fmt.Print(strings.Replace(line, "\r", "", -1)) }) l.SetDecode(true) buf := new(bytes.Buffer) gw := gzip.NewWriter(buf) gw.Write([]byte("request content")) gw.Close() req, err := http.NewRequest("GET", "http://example.com/path?querystring", buf) if err != nil { fmt.Println(err) return } req.TransferEncoding = []string{"chunked"} req.Header.Set("Content-Encoding", "gzip") if err := l.ModifyRequest(req); err != nil { fmt.Println(err) return } res := proxyutil.NewResponse(200, strings.NewReader("response content"), req) res.ContentLength = 16 res.Header.Set("Date", "Tue, 15 Nov 1994 08:12:31 GMT") res.Header.Set("Other-Header", "values") if err := l.ModifyResponse(res); err != nil { fmt.Println(err) return } // Output: // -------------------------------------------------------------------------------- // Request to http://example.com/path?querystring // -------------------------------------------------------------------------------- // GET http://example.com/path?querystring HTTP/1.1 // Host: example.com // Transfer-Encoding: chunked // Content-Encoding: gzip // // request content // // -------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------- // Response from http://example.com/path?querystring // -------------------------------------------------------------------------------- // HTTP/1.1 200 OK // Content-Length: 16 // Date: Tue, 15 Nov 1994 08:12:31 GMT // Other-Header: values // // response content // -------------------------------------------------------------------------------- }
func TestModifyResponseNoModifier(t *testing.T) { m := NewModifier() res := proxyutil.NewResponse(200, nil, nil) if err := m.ModifyResponse(martian.NewContext(), res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } }
func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, error) { if ctx.SkippingRoundTrip() { log.Debugf("martian: skipping round trip") return proxyutil.NewResponse(200, nil, req), nil } return p.roundTripper.RoundTrip(req) }
func TestModifier(t *testing.T) { var reqrun bool var resrun bool moderr := errors.New("modifier error") tm := NewModifier() tm.RequestError(moderr) tm.RequestFunc(func(*http.Request) { reqrun = true }) tm.ResponseError(moderr) tm.ResponseFunc(func(*http.Response) { resrun = true }) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := tm.ModifyRequest(req); err != moderr { t.Fatalf("tm.ModifyRequest(): got %v, want %v", err, moderr) } if !tm.RequestModified() { t.Errorf("tm.RequestModified(): got false, want true") } if tm.RequestCount() != 1 { t.Errorf("tm.RequestCount(): got %d, want %d", tm.RequestCount(), 1) } if !reqrun { t.Error("reqrun: got false, want true") } res := proxyutil.NewResponse(200, nil, req) if err := tm.ModifyResponse(res); err != moderr { t.Fatalf("tm.ModifyResponse(): got %v, want %v", err, moderr) } if !tm.ResponseModified() { t.Errorf("tm.ResponseModified(): got false, want true") } if tm.ResponseCount() != 1 { t.Errorf("tm.ResponseCount(): got %d, want %d", tm.ResponseCount(), 1) } if !resrun { t.Error("resrun: got false, want true") } tm.Reset() if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } if tm.ResponseModified() { t.Error("tm.ResponseModified(): got true, want false") } }
func TestResponseViewDecodeGzipContentEncoding(t *testing.T) { body := new(bytes.Buffer) gw := gzip.NewWriter(body) gw.Write([]byte("body content")) gw.Flush() gw.Close() res := proxyutil.NewResponse(200, body, nil) res.TransferEncoding = []string{"chunked"} res.Header.Set("Content-Encoding", "gzip") mv := New() if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "Content-Encoding: gzip\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader(Decode()) if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, wt o error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } r, err := mv.Reader(Decode()) if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant + "\r\n"); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } }
func TestModifyResponse(t *testing.T) { f := NewFilter() tm := martiantest.NewModifier() f.SetResponseModifier("id", tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) // No ID, auth required. f.SetAuthRequired(true) ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } actx := FromContext(ctx) if actx.Error() == nil { t.Error("actx.Error(): got nil, want error") } if tm.ResponseModified() { t.Error("tm.RequestModified(): got true, want false") } // No ID, no auth required. f.SetAuthRequired(false) actx.SetError(nil) if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if tm.ResponseModified() { t.Error("tm.ResponseModified(): got true, want false") } // Valid ID. actx.SetID("id") if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } }
func TestFilterFromJSON(t *testing.T) { msg := []byte(` { "querystring.Filter": { "scope": ["request", "response"], "name": "param", "value": "true", "modifier": { "header.Modifier": { "scope": ["request", "response"], "name": "Mod-Run", "value": "true" } } } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", "https://martian.test?param=true", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx := martian.NewContext() if err := reqmod.ModifyRequest(ctx, req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Mod-Run"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(ctx, res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Mod-Run"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } }
// handleRequest runs the request and response modifiers and performs the roundtrip to the destination server. func (p *Proxy) handleRequest(ctx *Context, rw *bufio.ReadWriter, req *http.Request) (closing bool) { if err := proxyutil.FixBadFraming(req.Header); err != nil { Errorf("proxyutil.FixBadFraming(): %v", err) proxyutil.NewErrorResponse(400, err, req).Write(rw) } proxyutil.SetForwardedHeaders(req) proxyutil.SetViaHeader(req.Header, "1.1 martian") if err := p.ModifyRequest(ctx, req); err != nil { Errorf("martian.ModifyRequest(): %v", err) proxyutil.NewErrorResponse(400, err, req).Write(rw) return } if shouldCloseAfterReply(req.Header) { Debugf("closing after reply") closing = true } proxyutil.RemoveHopByHopHeaders(req.Header) var res *http.Response var err error if !ctx.SkipRoundTrip { Debugf("proceed to round trip for %s", req.URL) res, err = p.RoundTripper.RoundTrip(req) if err != nil { Errorf("RoundTripper.RoundTrip(%s): %v", req.URL, err) proxyutil.NewErrorResponse(502, err, req).Write(rw) return } } else { Debugf("skipped round trip for %s", req.URL) res = proxyutil.NewResponse(200, nil, req) } proxyutil.RemoveHopByHopHeaders(res.Header) if err := p.ModifyResponse(ctx, res); err != nil { Errorf("martian.ModifyResponse(): %v", err) proxyutil.NewErrorResponse(400, err, req).Write(rw) return } if closing { res.Header.Set("Connection", "close") res.Close = true } if err := res.Write(rw); err != nil { Errorf("res.Write(): %v", err) } return }
func TestResponseView(t *testing.T) { body := strings.NewReader("body content") res := proxyutil.NewResponse(200, body, nil) res.ContentLength = 12 res.Header.Set("Response-Header", "true") mv := New() if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "HTTP/1.1 200 OK\r\n" + "Content-Length: 12\r\n" + "Response-Header: true\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } r, err := mv.Reader() if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } // Sanity check to ensure it still parses. if _, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(got)), nil); err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } }
func TestProxyAuth(t *testing.T) { m := NewModifier() tm := martiantest.NewModifier() m.SetRequestModifier(tm) m.SetResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass")) ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := auth.FromContext(ctx) if got, want := actx.ID(), "user:pass"; got != want { t.Fatalf("actx.ID(): got %q, want %q", got, want) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } }
func TestModifyResponse(t *testing.T) { f := NewFilter("mARTian-teSTInG", "true") f.SetResponseModifier(nil) res := proxyutil.NewResponse(200, nil, nil) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } tt := []struct { name string values []string want bool }{ { name: "Martian-Production", values: []string{"true"}, want: false, }, { name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: true, }, } for i, tc := range tt { f := NewFilter("mARTian-teSTInG", "true") tm := martiantest.NewModifier() f.SetResponseModifier(tm) res := proxyutil.NewResponse(200, nil, nil) res.Header[tc.name] = tc.values if err := f.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if tm.ResponseModified() != tc.want { t.Errorf("%d. tm.ResponseModified(): got %t, want %t", i, tm.ResponseModified(), tc.want) } } }
func TestRemoveHopByHopHeaders(t *testing.T) { m := NewHopByHopModifier() req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } hs := http.Header{ // Additional hop-by-hop headers are listed in the // Connection header. "Connection": []string{ "X-Connection", "X-Hop-By-Hop, close", }, // RFC hop-by-hop headers. "Keep-Alive": []string{}, "Proxy-Authenticate": []string{}, "Proxy-Authorization": []string{}, "Te": []string{}, "Trailer": []string{}, "Transfer-Encoding": []string{}, "Upgrade": []string{}, // Hop-by-hop headers listed in the Connection header. "X-Connection": []string{}, "X-Hop-By-Hop": []string{}, // End-to-end header that should not be removed. "X-End-To-End": []string{}, } req.Header = hs if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := len(req.Header), 1; got != want { t.Fatalf("len(req.Header): got %d, want %d", got, want) } if _, ok := req.Header["X-End-To-End"]; !ok { t.Errorf("req.Header[%q]: got !ok, want ok", "X-End-To-End") } res := proxyutil.NewResponse(200, nil, req) res.Header = hs if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := len(res.Header), 1; got != want { t.Fatalf("len(res.Header): got %d, want %d", got, want) } if _, ok := res.Header["X-End-To-End"]; !ok { t.Errorf("res.Header[%q]: got !ok, want ok", "X-End-To-End") } }
func TestModifyResponse(t *testing.T) { m := NewModifier() m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } // Modifier with error. tm := martiantest.NewModifier() reserr := errors.New("response error") tm.ResponseError(reserr) m.SetResponseModifier(tm) if err := m.ModifyResponse(res); err != reserr { t.Fatalf("ModifyResponse(): got %v, want %v", err, reserr) } // Modifier with auth error. tm.ResponseError(nil) autherr := errors.New("auth error") tm.ResponseFunc(func(res *http.Response) { ctx := martian.Context(res.Request) actx := auth.FromContext(ctx) actx.SetError(autherr) }) actx := auth.FromContext(ctx) actx.SetID("bad-auth") if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 403; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := actx.Error(), autherr; got != want { t.Errorf("actx.Error(): got %v, want %v", got, want) } }
func TestServeHTTPBuildsValidRequest(t *testing.T) { p := NewProxy(mitm) p.RoundTripper = RoundTripFunc(func(req *http.Request) (*http.Response, error) { if got, want := req.URL.Scheme, "https"; got != want { t.Errorf("req.URL.Scheme: got %q, want %q", got, want) } if got, want := req.URL.Host, "www.example.com"; got != want { t.Errorf("req.URL.Host: got %q, want %q", got, want) } if req.RemoteAddr == "" { t.Error("req.RemoteAddr: got empty, want addr") } return proxyutil.NewResponse(201, nil, req), nil }) rc, wc := pipeWithTimeout() defer rc.Close() defer wc.Close() rw := newHijackRecorder(wc) req, err := http.NewRequest("CONNECT", "//www.example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } go p.ServeHTTP(rw, req) res, err := http.ReadResponse(bufio.NewReader(rc), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } res.Body.Close() tlsConn := tlsClient(rc, p.mitm.Authority, "www.example.com") req, err = http.NewRequest("GET", "https://www.example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no erro", err) } req.Header.Set("Connection", "close") if err := req.Write(tlsConn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } res, err = http.ReadResponse(bufio.NewReader(tlsConn), nil) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 201; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestModifyResponse(t *testing.T) { f := NewFilter() modifierRun := false f.SetResponseModifier("id", martian.ResponseModifierFunc( func(*martian.Context, *http.Response) error { modifierRun = true return nil })) res := proxyutil.NewResponse(200, nil, nil) ctx := martian.NewContext() // No ID, auth required. f.SetAuthRequired(true) if err := f.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if ctx.Auth.Error == nil { t.Error("ctx.Auth.Error: got nil, want error") } if modifierRun { t.Error("modifierRun: got true, want false") } // No ID, no auth required. f.SetAuthRequired(false) ctx.Auth.Error = nil if err := f.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if ctx.Auth.Error != nil { t.Errorf("ctx.Auth.Error: got %v, want no error", ctx.Auth.Error) } if modifierRun { t.Error("modifierRun: got true, want false") } // Valid ID. ctx.Auth.ID = "id" ctx.Auth.Error = nil if err := f.ModifyResponse(ctx, res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if ctx.Auth.Error != nil { t.Errorf("ctx.Auth.Error: got %v, want no error", ctx.Auth.Error) } if !modifierRun { t.Error("modifierRun: got false, want true") } }
func ExampleLogger() { l := NewLogger() l.IncludeBody(true) l.SetLogFunc(func(line string) { // Remove \r to make it easier to test with examples. fmt.Print(strings.Replace(line, "\r", "", -1)) }) req, err := http.NewRequest("GET", "http://example.com/path?querystring", strings.NewReader("request content")) if err != nil { fmt.Println(err) return } req.RequestURI = req.URL.RequestURI() req.Header.Set("Other-Header", "values") req.Close = true if err := l.ModifyRequest(req); err != nil { fmt.Println(err) return } res := proxyutil.NewResponse(200, strings.NewReader("response content"), req) res.ContentLength = 16 res.Header.Set("Date", "Tue, 15 Nov 1994 08:12:31 GMT") res.Header.Set("Other-Header", "values") if err := l.ModifyResponse(res); err != nil { fmt.Println(err) return } // Output: // -------------------------------------------------------------------------------- // Request to http://example.com/path?querystring // -------------------------------------------------------------------------------- // GET /path?querystring HTTP/1.1 // Host: example.com // Connection: close // Other-Header: values // // request content // -------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------- // Response from http://example.com/path?querystring // -------------------------------------------------------------------------------- // HTTP/1.1 200 OK // Content-Length: 16 // Date: Tue, 15 Nov 1994 08:12:31 GMT // Other-Header: values // // response content // -------------------------------------------------------------------------------- }
// CopyHeaders sets the transport to respond with a 200 OK response with // headers copied from the request to the response verbatim. func (tr *Transport) CopyHeaders(names ...string) { tr.rtfunc = func(req *http.Request) (*http.Response, error) { res := proxyutil.NewResponse(200, nil, req) for _, n := range names { res.Header.Set(n, req.Header.Get(n)) } return res, nil } }