func TestExportIgnoresOrphanedResponse(t *testing.T) { logger := NewLogger("martian", "2.0.0") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } // Reset before the response comes back. logger.Reset() res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 0; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } }
func TestModifyRequest(t *testing.T) { f := NewFilter() tm := martiantest.NewModifier() f.SetRequestModifier("id", tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } // No ID, auth required. f.SetAuthRequired(true) ctx := session.FromContext(nil) martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := FromContext(ctx) if actx.Error() == nil { t.Error("actx.Error(): got nil, want error") } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } tm.Reset() // No ID, auth not required. f.SetAuthRequired(false) actx.SetError(nil) if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if actx.Error() != nil { t.Errorf("actx.Error(): got %v, want no error", err) } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } // Valid ID. actx.SetError(nil) actx.SetID("id") if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if actx.Error() != nil { t.Errorf("actx.Error(): got %v, want no error", actx.Error()) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } }
func TestReset(t *testing.T) { logger := NewLogger("martian", "2.0.0") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } logger.Reset() log = logger.Export().Log if got, want := len(log.Entries), 0; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } }
func TestContext(t *testing.T) { t.Parallel() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } want, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } SetContext(req, want) if got := Context(req); got != want { t.Errorf("Context(req): got %v, want %v", got, want) } RemoveContext(req) if got := Context(req); got != nil { t.Errorf("Context(req): got %v, want nil", got) } }
func TestNoModifiers(t *testing.T) { m := NewModifier() m.SetRequestModifier(nil) m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } }
func TestExportSortsEntries(t *testing.T) { logger := NewLogger("martian", "2.0.0") count := 10 for i := 0; i < count; i++ { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } } log := logger.Export().Log for i := 0; i < count-1; i++ { first := log.Entries[i] second := log.Entries[i+1] if got, want := first.StartedDateTime, second.StartedDateTime; got.After(want) { t.Errorf("entry.StartedDateTime: got %s, want to be before %s", got, want) } } }
func TestModifyRequest(t *testing.T) { m := NewModifier() m.SetRequestModifier(nil) req, err := http.NewRequest("CONNECT", "https://www.example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx := session.FromContext(nil) martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := auth.FromContext(ctx) if got, want := actx.ID(), ""; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } // IP with port and modifier with error. tm := martiantest.NewModifier() reqerr := errors.New("request error") tm.RequestError(reqerr) req.RemoteAddr = "1.1.1.1:8111" m.SetRequestModifier(tm) if err := m.ModifyRequest(req); err != reqerr { t.Fatalf("ModifyConnectRequest(): got %v, want %v", err, reqerr) } if got, want := actx.ID(), "1.1.1.1"; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } // IP without port and modifier with auth error. req.RemoteAddr = "4.4.4.4" autherr := errors.New("auth error") tm.RequestError(nil) tm.RequestFunc(func(req *http.Request) { ctx := martian.Context(req) actx := auth.FromContext(ctx) actx.SetError(autherr) }) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := actx.ID(), ""; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } }
func TestProxyAuth(t *testing.T) { m := NewModifier() tm := martiantest.NewModifier() m.SetRequestModifier(tm) m.SetResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass")) ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := auth.FromContext(ctx) if got, want := actx.ID(), "user:pass"; got != want { t.Fatalf("actx.ID(): got %q, want %q", got, want) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } }
func TestModifyResponse(t *testing.T) { m := NewModifier() m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } // Modifier with error. tm := martiantest.NewModifier() reserr := errors.New("response error") tm.ResponseError(reserr) m.SetResponseModifier(tm) if err := m.ModifyResponse(res); err != reserr { t.Fatalf("ModifyResponse(): got %v, want %v", err, reserr) } // Modifier with auth error. tm.ResponseError(nil) autherr := errors.New("auth error") tm.ResponseFunc(func(res *http.Response) { ctx := martian.Context(res.Request) actx := auth.FromContext(ctx) actx.SetError(autherr) }) actx := auth.FromContext(ctx) actx.SetID("bad-auth") if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 403; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := actx.Error(), autherr; got != want { t.Errorf("actx.Error(): got %v, want %v", got, want) } }
func TestModifyRequestBodyURLEncoded(t *testing.T) { logger := NewLogger("martian", "2.0.0") body := strings.NewReader("first=true&second=false") req, err := http.NewRequest("POST", "http://example.com", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Errorf("len(log.Entries): got %v, want %v", got, want) } pd := log.Entries[0].Request.PostData if got, want := pd.MimeType, "application/x-www-form-urlencoded"; got != want { t.Errorf("PostData.MimeType: got %v, want %v", got, want) } if got, want := len(pd.Params), 2; got != want { t.Fatalf("len(PostData.Params): got %d, want %d", got, want) } for _, p := range pd.Params { var want string switch p.Name { case "first": want = "true" case "second": want = "false" default: t.Errorf("PostData.Params: got %q, want to not be present", p.Name) continue } if got := p.Value; got != want { t.Errorf("PostData.Params[%q]: got %q, want %q", p.Name, got, want) } } }
func TestViaModifier(t *testing.T) { m := NewViaModifier("martian") req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) ctx := session.FromContext(nil) martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Via"), "1.1 martian"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Via", got, want) } if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } req.Header.Set("Via", "1.0 alpha") if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Via"), "1.0 alpha, 1.1 martian"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Via", got, want) } req.Header.Set("Via", "1.0 alpha, 1.1 martian, 1.1 beta") if err := m.ModifyRequest(req); err == nil { t.Fatal("ModifyRequest(): got nil, want request loop error") } if !ctx.SkippingRoundTrip() { t.Errorf("ctx.SkippingRoundTrip(): got false, want true") } if err := m.ModifyResponse(res); err == nil { t.Fatal("ModifyResponse(): got nil, want request loop error") } if got, want := res.StatusCode, 400; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Status, http.StatusText(400); got != want { t.Errorf("res.Status: got %q, want %q", got, want) } }
func (p *Proxy) handleLoop(conn net.Conn) { p.conns.Add(1) defer p.conns.Done() defer conn.Close() ctx := session.FromContext(nil) brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) for { deadline := time.Now().Add(p.timeout) conn.SetDeadline(deadline) if err := p.handle(ctx, conn, brw); isCloseable(err) { Infof("martian: closing connection: %v", conn.RemoteAddr()) return } } }
func TestHARExportsTime(t *testing.T) { logger := NewLogger("martian", "2.0.0") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } // Simulate fast network round trip. time.Sleep(10 * time.Millisecond) res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %v, want %v", got, want) } entry := log.Entries[0] min, max := int64(10), int64(100) if got := entry.Time; got < min || got > max { t.Errorf("entry.Time: got %dms, want between %dms and %vms", got, min, max) } }
func TestModifyRequestBodyArbitraryContentType(t *testing.T) { logger := NewLogger("martian", "2.0.0") body := "arbitrary binary data" req, err := http.NewRequest("POST", "http://www.example.com", strings.NewReader(body)) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } pd := log.Entries[0].Request.PostData if got, want := pd.MimeType, ""; got != want { t.Errorf("PostData.MimeType: got %q, want %q", got, want) } if got, want := len(pd.Params), 0; got != want { t.Errorf("len(PostData.Params): got %d, want %d", got, want) } if got, want := pd.Text, body; got != want { t.Errorf("PostData.Text: got %q, want %q", got, want) } }
func TestModifyRequestBodyMultipart(t *testing.T) { logger := NewLogger("martian", "2.0.0") body := new(bytes.Buffer) mpw := multipart.NewWriter(body) mpw.SetBoundary("boundary") if err := mpw.WriteField("key", "value"); err != nil { t.Errorf("mpw.WriteField(): got %v, want no error", err) } w, err := mpw.CreateFormFile("file", "test.txt") if _, err = w.Write([]byte("file contents")); err != nil { t.Fatalf("Write(): got %v, want no error", err) } mpw.Close() req, err := http.NewRequest("POST", "http://example.com", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Type", mpw.FormDataContentType()) ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } pd := log.Entries[0].Request.PostData if got, want := pd.MimeType, "multipart/form-data"; got != want { t.Errorf("PostData.MimeType: got %q, want %q", got, want) } if got, want := len(pd.Params), 2; got != want { t.Errorf("PostData.Params: got %d, want %d", got, want) } for _, p := range pd.Params { var want Param switch p.Name { case "key": want = Param{ Filename: "", ContentType: "", Value: "value", } case "file": want = Param{ Filename: "test.txt", ContentType: "application/octet-stream", Value: "file contents", } default: t.Errorf("pd.Params: got %q, want not to be present", p.Name) continue } if got, want := p.Filename, want.Filename; got != want { t.Errorf("p.Filename: got %q, want %q", got, want) } if got, want := p.ContentType, want.ContentType; got != want { t.Errorf("p.ContentType: got %q, want %q", got, want) } if got, want := p.Value, want.Value; got != want { t.Errorf("p.Value: got %q, want %q", got, want) } } }
func TestModifyRequest(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com/path?query=true", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Add("Request-Header", "first") req.Header.Add("Request-Header", "second") cookie := &http.Cookie{ Name: "request", Value: "cookie", } req.AddCookie(cookie) ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) logger := NewLogger("martian", "2.0.0") if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := log.Version, "1.2"; got != want { t.Errorf("log.Version: got %q, want %q", got, want) } if got, want := log.Creator.Name, "martian"; got != want { t.Errorf("log.Creator.Name: got %q, want %q", got, want) } if got, want := log.Creator.Version, "2.0.0"; got != want { t.Errorf("log.Creator.Version: got %q, want %q", got, want) } if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } entry := log.Entries[0] if got, want := time.Since(entry.StartedDateTime), time.Second; got > want { t.Errorf("entry.StartedDateTime: got %s, want less than %s", got, want) } hreq := entry.Request if got, want := hreq.Method, "GET"; got != want { t.Errorf("hreq.Method: got %q, want %q", got, want) } if got, want := hreq.URL, "http://example.com/path?query=true"; got != want { t.Errorf("hreq.URL: got %q, want %q", got, want) } if got, want := hreq.HTTPVersion, "HTTP/1.1"; got != want { t.Errorf("hreq.HTTPVersion: got %q, want %q", got, want) } if got, want := hreq.BodySize, int64(0); got != want { t.Errorf("hreq.BodySize: got %d, want %d", got, want) } if got, want := hreq.HeadersSize, int64(-1); got != want { t.Errorf("hreq.HeadersSize: got %d, want %d", got, want) } if got, want := len(hreq.QueryString), 1; got != want { t.Fatalf("len(hreq.QueryString): got %d, want %q", got, want) } qs := hreq.QueryString[0] if got, want := qs.Name, "query"; got != want { t.Errorf("qs.Name: got %q, want %q", got, want) } if got, want := qs.Value, "true"; got != want { t.Errorf("qs.Value: got %q, want %q", got, want) } if got, want := len(hreq.Headers), 2; got != want { t.Fatalf("len(hreq.Headers): got %d, want %d", got, want) } for _, h := range hreq.Headers { var want string switch h.Name { case "Request-Header": want = "first, second" case "Cookie": want = cookie.String() default: t.Errorf("hreq.Headers: got %q, want header to not be present", h.Name) continue } if got := h.Value; got != want { t.Errorf("hreq.Headers[%q]: got %q, want %q", h.Name, got, want) } } if got, want := len(hreq.Cookies), 1; got != want { t.Fatalf("len(hreq.Cookies): got %d, want %d", got, want) } hcookie := hreq.Cookies[0] if got, want := hcookie.Name, "request"; got != want { t.Errorf("hcookie.Name: got %q, want %q", got, want) } if got, want := hcookie.Value, "cookie"; got != want { t.Errorf("hcookie.Value: got %q, want %q", got, want) } }
func TestProxyAuthInvalidCredentials(t *testing.T) { m := NewModifier() autherr := errors.New("auth error") tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := martian.Context(req) actx := auth.FromContext(ctx) actx.SetError(autherr) }) tm.ResponseFunc(func(res *http.Response) { ctx := martian.Context(res.Request) actx := auth.FromContext(ctx) actx.SetError(autherr) }) m.SetRequestModifier(tm) m.SetResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass")) ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } actx := auth.FromContext(ctx) if actx.Error() != autherr { t.Fatalf("auth.Error(): got %v, want %v", actx.Error(), autherr) } actx.SetError(nil) res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } if actx.Error() != autherr { t.Fatalf("auth.Error(): got %v, want %v", actx.Error(), autherr) } if got, want := res.StatusCode, http.StatusProxyAuthRequired; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), "Basic"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } }
func TestModifyResponse(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) res := proxyutil.NewResponse(301, strings.NewReader("response body"), req) res.Header.Add("Response-Header", "first") res.Header.Add("Response-Header", "second") res.Header.Set("Location", "google.com") expires := time.Now() cookie := &http.Cookie{ Name: "response", Value: "cookie", Path: "/", Domain: "example.com", Expires: expires, Secure: true, HttpOnly: true, } res.Header.Set("Set-Cookie", cookie.String()) logger := NewLogger("martian", "2.0.0") if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } hres := log.Entries[0].Response if got, want := hres.Status, 301; got != want { t.Errorf("hres.Status: got %d, want %d", got, want) } if got, want := hres.StatusText, "Moved Permanently"; got != want { t.Errorf("hres.StatusText: got %q, want %q", got, want) } if got, want := hres.HTTPVersion, "HTTP/1.1"; got != want { t.Errorf("hres.HTTPVersion: got %q, want %q", got, want) } if got, want := hres.Content.Text, []byte("response body"); !bytes.Equal(got, want) { t.Errorf("hres.Content.Text: got %q, want %q", got, want) } if got, want := len(hres.Headers), 3; got != want { t.Fatalf("len(hreq.Headers): got %d, want %d", got, want) } for _, h := range hres.Headers { var want string switch h.Name { case "Response-Header": want = "first, second" case "Location": want = "google.com" case "Set-Cookie": want = cookie.String() default: t.Errorf("hres.Headers: got %q, want header to not be present", h.Name) continue } if got := h.Value; got != want { t.Errorf("hres.Headers[%q]: got %q, want %q", h.Name, got, want) } } if got, want := len(hres.Cookies), 1; got != want { t.Fatalf("len(hres.Cookies): got %d, want %d", got, want) } hcookie := hres.Cookies[0] if got, want := hcookie.Name, "response"; got != want { t.Errorf("hcookie.Name: got %q, want %q", got, want) } if got, want := hcookie.Value, "cookie"; got != want { t.Errorf("hcookie.Value: got %q, want %q", got, want) } if got, want := hcookie.Path, "/"; got != want { t.Errorf("hcookie.Path: got %q, want %q", got, want) } if got, want := hcookie.Domain, "example.com"; got != want { t.Errorf("hcookie.Domain: got %q, want %q", got, want) } if got, want := hcookie.Expires, expires; got.Equal(want) { t.Errorf("hcookie.Expires: got %s, want %s", got, want) } if !hcookie.HTTPOnly { t.Error("hcookie.HTTPOnly: got false, want true") } if !hcookie.Secure { t.Error("hcookie.Secure: got false, want true") } }
func (p *Proxy) handle(ctx *session.Context, conn net.Conn, brw *bufio.ReadWriter) error { log.Debugf("martian: waiting for request: %v", conn.RemoteAddr()) req, err := http.ReadRequest(brw.Reader) if err != nil { if isCloseable(err) { log.Debugf("martian: connection closed prematurely: %v", err) } else { log.Errorf("martian: failed to read request: %v", err) } // TODO: TCPConn.WriteClose() to avoid sending an RST to the client. return errClose } defer req.Body.Close() if h, pattern := p.mux.Handler(req); pattern != "" { defer brw.Flush() closing := req.Close || p.Closing() log.Infof("martian: intercepted configuration request: %s", req.URL) rw := newResponseWriter(brw, closing) defer rw.Close() h.ServeHTTP(rw, req) // Call WriteHeader to ensure a response is sent, since the handler isn't // required to call WriteHeader/Write. rw.WriteHeader(200) if closing { return errClose } return nil } ctx, err = session.FromContext(ctx) if err != nil { log.Errorf("martian: failed to derive context: %v", err) return err } SetContext(req, ctx) defer RemoveContext(req) if tconn, ok := conn.(*tls.Conn); ok { ctx.GetSession().MarkSecure() cs := tconn.ConnectionState() req.TLS = &cs } req.URL.Scheme = "http" if ctx.GetSession().IsSecure() { log.Debugf("martian: forcing HTTPS inside secure session") req.URL.Scheme = "https" } req.RemoteAddr = conn.RemoteAddr().String() if req.URL.Host == "" { req.URL.Host = req.Host } log.Infof("martian: received request: %s", req.URL) if req.Method == "CONNECT" { if err := p.reqmod.ModifyRequest(req); err != nil { log.Errorf("martian: error modifying CONNECT request: %v", err) proxyutil.Warning(req.Header, err) } if p.mitm != nil { log.Debugf("martian: attempting MITM for connection: %s", req.Host) res := proxyutil.NewResponse(200, nil, req) if err := p.resmod.ModifyResponse(res); err != nil { log.Errorf("martian: error modifying CONNECT response: %v", err) proxyutil.Warning(res.Header, err) } res.Write(brw) brw.Flush() log.Debugf("martian: completed MITM for connection: %s", req.Host) tlsconn := tls.Server(conn, p.mitm.TLSForHost(req.Host)) brw.Writer.Reset(tlsconn) brw.Reader.Reset(tlsconn) return p.handle(ctx, tlsconn, brw) } log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host) res, cconn, cerr := p.connect(req) if cerr != nil { log.Errorf("martian: failed to CONNECT: %v", err) res = proxyutil.NewResponse(502, nil, req) proxyutil.Warning(res.Header, cerr) if err := p.resmod.ModifyResponse(res); err != nil { log.Errorf("martian: error modifying CONNECT response: %v", err) proxyutil.Warning(res.Header, err) } res.Write(brw) return brw.Flush() } defer res.Body.Close() defer cconn.Close() if err := p.resmod.ModifyResponse(res); err != nil { log.Errorf("martian: error modifying CONNECT response: %v", err) proxyutil.Warning(res.Header, err) } res.Write(brw) brw.Flush() cbw := bufio.NewWriter(cconn) cbr := bufio.NewReader(cconn) defer cbw.Flush() copySync := func(w io.Writer, r io.Reader, donec chan<- bool) { io.Copy(w, r) donec <- true } donec := make(chan bool, 2) go copySync(cbw, brw, donec) go copySync(brw, cbr, donec) log.Debugf("martian: established CONNECT tunnel, proxying traffic") <-donec <-donec log.Debugf("martian: closed CONNECT tunnel") return errClose } if err := p.reqmod.ModifyRequest(req); err != nil { log.Errorf("martian: error modifying request: %v", err) proxyutil.Warning(req.Header, err) } res, err := p.roundTrip(ctx, req) if err != nil { log.Errorf("martian: failed to round trip: %v", err) res = proxyutil.NewResponse(502, nil, req) proxyutil.Warning(res.Header, err) } defer res.Body.Close() if err := p.resmod.ModifyResponse(res); err != nil { log.Errorf("martian: error modifying response: %v", err) proxyutil.Warning(res.Header, err) } var closing error if req.Close || p.Closing() { log.Debugf("martian: received close request: %v", req.RemoteAddr) res.Header.Add("Connection", "close") closing = errClose } log.Debugf("martian: sent response: %v", req.URL) res.Write(brw) brw.Flush() return closing }
func TestExportHandlerServeHTTP(t *testing.T) { logger := NewLogger("martian", "2.0.0") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } h := NewExportHandler(logger) req, err = http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, http.StatusOK; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } hl := &HAR{} if err := json.Unmarshal(rw.Body.Bytes(), hl); err != nil { t.Fatalf("json.Unmarshal(): got %v, want no error", err) } if got, want := len(hl.Log.Entries), 1; got != want { t.Fatalf("len(hl.Log.Entries): got %v, want %v", got, want) } entry := hl.Log.Entries[0] if got, want := entry.Request.URL, "http://example.com"; got != want { t.Errorf("Request.URL: got %q, want %q", got, want) } if got, want := entry.Response.Status, 200; got != want { t.Errorf("Response.Status: got %d, want %d", got, want) } rh := NewResetHandler(logger) req, err = http.NewRequest("DELETE", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw = httptest.NewRecorder() rh.ServeHTTP(rw, req) req, err = http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw = httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, http.StatusOK; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } hl = &HAR{} if err := json.Unmarshal(rw.Body.Bytes(), hl); err != nil { t.Fatalf("json.Unmarshal(): got %v, want no error", err) } if got, want := len(hl.Log.Entries), 0; got != want { t.Errorf("len(Log.Entries): got %v, want %v", got, want) } }