func TestIntegrationTransparentHTTP(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() tr := martiantest.NewTransport() p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET / HTTP/1.1 // Host: www.example.com if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } }
func TestIntegrationSkipRoundTrip(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() // Transport will be skipped, no 500. tr := martiantest.NewTransport() tr.Respond(500) p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := Context(req) ctx.SkipRoundTrip() }) p.SetRequestModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } // Response from skipped round trip. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestIntegrationFailedRoundTrip(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() tr := martiantest.NewTransport() trerr := errors.New("round trip error") tr.RespondError(trerr) p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } // Response from failed round trip. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 502; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Warning"), trerr.Error(); !strings.Contains(got, want) { t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) } }
func TestIntegrationTemporaryTimeout(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Liste(): got %v, want no error", err) } p := NewProxy() defer p.Close() tr := martiantest.NewTransport() p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) // Start the proxy with a listener that will return a temporary error on // Accept() three times. go p.Serve(newTimeoutListener(l, 3)) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Connection", "close") // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestIntegrationTransparentMITM(t *testing.T) { t.Parallel() ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } // Start TLS listener with config that will generate certificates based on // SNI from connection. // // BUG: tls.Listen will not accept a tls.Config where Certificates is empty, // even though it is supported by tls.Server when GetCertificate is not nil. l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } l = tls.NewListener(l, mc.TLS()) p := NewProxy() defer p.Close() tr := martiantest.NewTransport() tr.Func(func(req *http.Request) (*http.Response, error) { res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Request-Scheme", req.URL.Scheme) return res, nil }) p.SetRoundTripper(tr) tm := martiantest.NewModifier() p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) roots := x509.NewCertPool() roots.AddCert(ca) tlsconn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{ // Verify the hostname is example.com. ServerName: "example.com", // The certificate will have been generated during MITM, so we need to // verify it with the generated CA certificate. RootCAs: roots, }) if err != nil { t.Fatalf("tls.Dial(): got %v, want no error", err) } defer tlsconn.Close() req, err := http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // Write Encrypted request directly, no CONNECT. // GET / HTTP/1.1 // Host: example.com if err := req.Write(tlsconn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(tlsconn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Request-Scheme"), "https"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want) } if !tm.RequestModified() { t.Errorf("tm.RequestModified(): got false, want true") } if !tm.ResponseModified() { t.Errorf("tm.ResponseModified(): got false, want true") } }
func TestIntegrationMITM(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() tr := martiantest.NewTransport() tr.Func(func(req *http.Request) (*http.Response, error) { res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Request-Scheme", req.URL.Scheme) return res, nil }) p.SetRoundTripper(tr) p.SetTimeout(600 * time.Millisecond) ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } p.SetMITM(mc) tm := martiantest.NewModifier() reqerr := errors.New("request error") reserr := errors.New("response error") tm.RequestError(reqerr) tm.ResponseError(reserr) p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:443 HTTP/1.1 // Host: example.com if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response MITM'd from proxy. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) } roots := x509.NewCertPool() roots.AddCert(ca) tlsconn := tls.Client(conn, &tls.Config{ ServerName: "example.com", RootCAs: roots, }) defer tlsconn.Close() req, err = http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET / HTTP/1.1 // Host: example.com if err := req.Write(tlsconn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response from MITM proxy. res, err = http.ReadResponse(bufio.NewReader(tlsconn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Request-Scheme"), "https"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want) } if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) } }
func TestIntegrationConnectDownstreamProxy(t *testing.T) { t.Parallel() // Start first proxy to use as downstream. dl, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } downstream := NewProxy() defer downstream.Close() dtr := martiantest.NewTransport() dtr.Respond(299) downstream.SetRoundTripper(dtr) ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } downstream.SetMITM(mc) go downstream.Serve(dl) // Start second proxy as upstream proxy, will CONNECT to downstream proxy. ul, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } upstream := NewProxy() defer upstream.Close() // Set upstream proxy's downstream proxy to the host:port of the first proxy. upstream.SetDownstreamProxy(&url.URL{ Host: dl.Addr().String(), }) go upstream.Serve(ul) // Open connection to upstream proxy. conn, err := net.Dial("tcp", ul.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:443 HTTP/1.1 // Host: example.com if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response from downstream proxy starting MITM. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } roots := x509.NewCertPool() roots.AddCert(ca) tlsconn := tls.Client(conn, &tls.Config{ // Validate the hostname. ServerName: "example.com", // The certificate will have been MITM'd, verify using the MITM CA // certificate. RootCAs: roots, }) defer tlsconn.Close() req, err = http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET / HTTP/1.1 // Host: example.com if err := req.Write(tlsconn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response from MITM in downstream proxy. res, err = http.ReadResponse(bufio.NewReader(tlsconn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 299; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } }
func TestIntegrationHTTPDownstreamProxy(t *testing.T) { t.Parallel() // Start first proxy to use as downstream. dl, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } downstream := NewProxy() defer downstream.Close() dtr := martiantest.NewTransport() dtr.Respond(299) downstream.SetRoundTripper(dtr) downstream.SetTimeout(600 * time.Millisecond) go downstream.Serve(dl) // Start second proxy as upstream proxy, will write to downstream proxy. ul, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } upstream := NewProxy() defer upstream.Close() // Set upstream proxy's downstream proxy to the host:port of the first proxy. upstream.SetDownstreamProxy(&url.URL{ Host: dl.Addr().String(), }) upstream.SetTimeout(600 * time.Millisecond) go upstream.Serve(ul) // Open connection to upstream proxy. conn, err := net.Dial("tcp", ul.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } // Response from downstream proxy. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 299; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } }
func TestIntegrationHTTP(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() p.SetRequestModifier(nil) p.SetResponseModifier(nil) tr := martiantest.NewTransport() p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := Context(req) ctx.Set("martian.test", "true") }) tm.ResponseFunc(func(res *http.Response) { ctx := Context(res.Request) v, _ := ctx.Get("martian.test") res.Header.Set("Martian-Test", v.(string)) }) p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } }
func TestIntegration(t *testing.T) { ptr := martiantest.NewTransport() proxy := martian.NewProxy() defer proxy.Close() proxy.SetRoundTripper(ptr) l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } go proxy.Serve(l) m := NewModifier() proxy.SetRequestModifier(m) proxy.SetResponseModifier(m) mux := http.NewServeMux() mux.Handle("/", m) s := httptest.NewServer(mux) defer s.Close() body := strings.NewReader(`{ "header.Modifier": { "scope": ["request", "response"], "name": "Martian-Test", "value": "true" } }`) res, err := http.Post(s.URL, "application/json", body) if err != nil { t.Fatalf("http.Post(%s): got %v, want no error", s.URL, err) } res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } tr := &http.Transport{ Proxy: http.ProxyURL(&url.URL{ Scheme: "http", Host: l.Addr().String(), }), } defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Connection", "close") res, err = tr.RoundTrip(req) if err != nil { t.Fatalf("transport.RoundTrip(%q): got %v, want no error", req.URL, err) } res.Body.Close() if got, want := res.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } }