func TestPriorityGroupModifyRequest(t *testing.T) { var order []string pg := NewGroup() tm50 := martiantest.NewModifier() tm50.RequestFunc(func(*http.Request) { order = append(order, "tm50") }) pg.AddRequestModifier(tm50, 50) tm100a := martiantest.NewModifier() tm100a.RequestFunc(func(*http.Request) { order = append(order, "tm100a") }) pg.AddRequestModifier(tm100a, 100) tm100b := martiantest.NewModifier() tm100b.RequestFunc(func(*http.Request) { order = append(order, "tm100b") }) pg.AddRequestModifier(tm100b, 100) tm75 := martiantest.NewModifier() tm75.RequestFunc(func(*http.Request) { order = append(order, "tm75") }) if err := pg.RemoveRequestModifier(tm75); err != ErrModifierNotFound { t.Fatalf("RemoveRequestModifier(): got %v, want ErrModifierNotFound", err) } pg.AddRequestModifier(tm75, 100) if err := pg.RemoveRequestModifier(tm75); err != nil { t.Fatalf("RemoveRequestModifier(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://example.com/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := pg.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) { t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want) } }
func TestModifyRequest(t *testing.T) { f := NewFilter() tm := martiantest.NewModifier() f.SetRequestModifier("id", tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } // No ID, auth required. f.SetAuthRequired(true) ctx := session.FromContext(nil) martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := FromContext(ctx) if actx.Error() == nil { t.Error("actx.Error(): got nil, want error") } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } tm.Reset() // No ID, auth not required. f.SetAuthRequired(false) actx.SetError(nil) if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if actx.Error() != nil { t.Errorf("actx.Error(): got %v, want no error", err) } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } // Valid ID. actx.SetError(nil) actx.SetID("id") if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if actx.Error() != nil { t.Errorf("actx.Error(): got %v, want no error", actx.Error()) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } }
func TestResultModifierAccessors(t *testing.T) { tm := martiantest.NewModifier() r := &Result{ reqmod: tm, resmod: nil, } if reqmod := r.RequestModifier(); reqmod == nil { t.Error("r.RequestModifier: got nil, want reqmod") } if resmod := r.ResponseModifier(); resmod != nil { t.Error("r.ResponseModifier: got resmod, want nil") } r = &Result{ reqmod: nil, resmod: tm, } if reqmod := r.RequestModifier(); reqmod != nil { t.Errorf("r.RequestModifier: got reqmod, want nil") } if resmod := r.ResponseModifier(); resmod == nil { t.Error("r.ResponseModifier: got nil, want resmod") } }
func TestModifyRequest(t *testing.T) { m := NewModifier() m.SetRequestModifier(nil) req, err := http.NewRequest("CONNECT", "https://www.example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx := session.FromContext(nil) martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := auth.FromContext(ctx) if got, want := actx.ID(), ""; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } // IP with port and modifier with error. tm := martiantest.NewModifier() reqerr := errors.New("request error") tm.RequestError(reqerr) req.RemoteAddr = "1.1.1.1:8111" m.SetRequestModifier(tm) if err := m.ModifyRequest(req); err != reqerr { t.Fatalf("ModifyConnectRequest(): got %v, want %v", err, reqerr) } if got, want := actx.ID(), "1.1.1.1"; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } // IP without port and modifier with auth error. req.RemoteAddr = "4.4.4.4" autherr := errors.New("auth error") tm.RequestError(nil) tm.RequestFunc(func(req *http.Request) { ctx := martian.Context(req) actx := auth.FromContext(ctx) actx.SetError(autherr) }) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := actx.ID(), ""; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } }
func TestPriorityGroupModifyResponse(t *testing.T) { var order []string pg := NewGroup() tm50 := martiantest.NewModifier() tm50.ResponseFunc(func(*http.Response) { order = append(order, "tm50") }) pg.AddResponseModifier(tm50, 50) tm100a := martiantest.NewModifier() tm100a.ResponseFunc(func(*http.Response) { order = append(order, "tm100a") }) pg.AddResponseModifier(tm100a, 100) tm100b := martiantest.NewModifier() tm100b.ResponseFunc(func(*http.Response) { order = append(order, "tm100b") }) pg.AddResponseModifier(tm100b, 100) tm75 := martiantest.NewModifier() tm75.ResponseFunc(func(*http.Response) { order = append(order, "tm75") }) if err := pg.RemoveResponseModifier(tm75); err != ErrModifierNotFound { t.Fatalf("RemoveResponseModifier(): got %v, want ErrModifierNotFound", err) } pg.AddResponseModifier(tm75, 100) if err := pg.RemoveResponseModifier(tm75); err != nil { t.Fatalf("RemoveResponseModifier(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, nil) if err := pg.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) { t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want) } }
func TestProxyAuth(t *testing.T) { m := NewModifier() tm := martiantest.NewModifier() m.SetRequestModifier(tm) m.SetResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass")) ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := auth.FromContext(ctx) if got, want := actx.ID(), "user:pass"; got != want { t.Fatalf("actx.ID(): got %q, want %q", got, want) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } }
func TestIntegrationHTTPDownstreamProxyError(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() // Set proxy's downstream proxy to invalid host:port to force failure. p.SetDownstreamProxy(&url.URL{ Host: "[::1]:0", }) p.SetTimeout(600 * time.Millisecond) tm := martiantest.NewModifier() reserr := errors.New("response error") tm.ResponseError(reserr) p.SetResponseModifier(tm) go p.Serve(l) // Open connection to upstream proxy. conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:443 HTTP/1.1 // Host: example.com if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response from upstream proxy, assuming downstream proxy failed to CONNECT. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 502; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header["Warning"][1], reserr.Error(); !strings.Contains(got, want) { t.Errorf("res.Header.get(%q): got %q, want to contain %q", "Warning", got, want) } }
func TestModifyResponse(t *testing.T) { m := NewModifier() m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } // Modifier with error. tm := martiantest.NewModifier() reserr := errors.New("response error") tm.ResponseError(reserr) m.SetResponseModifier(tm) if err := m.ModifyResponse(res); err != reserr { t.Fatalf("ModifyResponse(): got %v, want %v", err, reserr) } // Modifier with auth error. tm.ResponseError(nil) autherr := errors.New("auth error") tm.ResponseFunc(func(res *http.Response) { ctx := martian.Context(res.Request) actx := auth.FromContext(ctx) actx.SetError(autherr) }) actx := auth.FromContext(ctx) actx.SetID("bad-auth") if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 403; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := actx.Error(), autherr; got != want { t.Errorf("actx.Error(): got %v, want %v", got, want) } }
func TestModifyResponse(t *testing.T) { f := NewFilter() tm := martiantest.NewModifier() f.SetResponseModifier("id", tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) // No ID, auth required. f.SetAuthRequired(true) ctx, remove, err := martian.TestContext(req) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } actx := FromContext(ctx) if actx.Error() == nil { t.Error("actx.Error(): got nil, want error") } if tm.ResponseModified() { t.Error("tm.RequestModified(): got true, want false") } // No ID, no auth required. f.SetAuthRequired(false) actx.SetError(nil) if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if tm.ResponseModified() { t.Error("tm.ResponseModified(): got true, want false") } // Valid ID. actx.SetID("id") if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } }
func TestIntegrationSkipRoundTrip(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() // Transport will be skipped, no 500. tr := martiantest.NewTransport() tr.Respond(500) p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := Context(req) ctx.SkipRoundTrip() }) p.SetRequestModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } // Response from skipped round trip. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestModifyResponseHaltsOnError(t *testing.T) { fg := NewGroup() reserr := errors.New("request error") tm := martiantest.NewModifier() tm.ResponseError(reserr) fg.AddResponseModifier(tm) tm2 := martiantest.NewModifier() fg.AddResponseModifier(tm2) res := proxyutil.NewResponse(200, nil, nil) if err := fg.ModifyResponse(res); err != reserr { t.Fatalf("fg.ModifyResponse(): got %v, want %v", err, reserr) } if tm2.ResponseModified() { t.Error("tm2.ResponseModified(): got true, want false") } }
func TestIntegrationTransparentHTTP(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() tr := martiantest.NewTransport() p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET / HTTP/1.1 // Host: www.example.com if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } }
func TestPriorityGroupModifyResponseHaltsOnError(t *testing.T) { pg := NewGroup() reserr := errors.New("response error") tm := martiantest.NewModifier() tm.ResponseError(reserr) pg.AddResponseModifier(tm, 100) tm2 := martiantest.NewModifier() pg.AddResponseModifier(tm2, 75) res := proxyutil.NewResponse(200, nil, nil) if err := pg.ModifyResponse(res); err != reserr { t.Fatalf("ModifyRequest(): got %v, want %v", err, reserr) } if tm2.ResponseModified() { t.Error("tm2.ResponseModified(): got true, want false") } }
func TestNewStack(t *testing.T) { stack, fg := NewStack("martian") tm := martiantest.NewModifier() fg.AddRequestModifier(tm) fg.AddResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() // Hop-by-hop header to be removed. req.Header.Set("Hop-By-Hop", "true") req.Header.Set("Connection", "Hop-By-Hop") req.RemoteAddr = "10.0.0.1:5000" if err := stack.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Hop-By-Hop"), ""; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Hop-By-Hop", got, want) } if got, want := req.Header.Get("X-Forwarded-For"), "10.0.0.1"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Forwarded-For", got, want) } if got, want := req.Header.Get("Via"), "1.1 martian"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Via", got, want) } res := proxyutil.NewResponse(200, nil, req) // Hop-by-hop header to be removed. res.Header.Set("Hop-By-Hop", "true") res.Header.Set("Connection", "Hop-By-Hop") if err := stack.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Hop-By-Hop"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Hop-By-Hop", got, want) } }
func TestModifyRequestHaltsOnError(t *testing.T) { fg := NewGroup() reqerr := errors.New("request error") tm := martiantest.NewModifier() tm.RequestError(reqerr) fg.AddRequestModifier(tm) tm2 := martiantest.NewModifier() fg.AddRequestModifier(tm2) req, err := http.NewRequest("GET", "http://example.com/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := fg.ModifyRequest(req); err != reqerr { t.Fatalf("fg.ModifyRequest(): got %v, want %v", err, reqerr) } if tm2.RequestModified() { t.Error("tm2.RequestModified(): got true, want false") } }
func TestModifyResponse(t *testing.T) { fg := NewGroup() tm := martiantest.NewModifier() fg.AddResponseModifier(tm) res := proxyutil.NewResponse(200, nil, nil) if err := fg.ModifyResponse(res); err != nil { t.Fatalf("fg.ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } }
func TestModifyRequest(t *testing.T) { f := NewFilter("mARTian-teSTInG", "true") f.SetRequestModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } tt := []struct { name string values []string want bool }{ { name: "Martian-Production", values: []string{"true"}, want: false, }, { name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: true, }, } for i, tc := range tt { f := NewFilter("mARTian-teSTInG", "true") tm := martiantest.NewModifier() f.SetRequestModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("%d. http.NewRequest(): got %v, want no error", i, err) } req.Header[tc.name] = tc.values if err := f.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %v, want no error", i, err) } if tm.RequestModified() != tc.want { t.Errorf("%d. tm.RequestModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } }
func TestModifyRequest(t *testing.T) { fg := NewGroup() tm := martiantest.NewModifier() fg.AddRequestModifier(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := fg.ModifyRequest(req); err != nil { t.Fatalf("fg.ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } }
func TestModifyResponse(t *testing.T) { f := NewFilter("mARTian-teSTInG", "true") f.SetResponseModifier(nil) res := proxyutil.NewResponse(200, nil, nil) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } tt := []struct { name string values []string want bool }{ { name: "Martian-Production", values: []string{"true"}, want: false, }, { name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: true, }, } for i, tc := range tt { f := NewFilter("mARTian-teSTInG", "true") tm := martiantest.NewModifier() f.SetResponseModifier(tm) res := proxyutil.NewResponse(200, nil, nil) res.Header[tc.name] = tc.values if err := f.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if tm.ResponseModified() != tc.want { t.Errorf("%d. tm.ResponseModified(): got %t, want %t", i, tm.ResponseModified(), tc.want) } } }
func TestFilter(t *testing.T) { f := NewFilter() if f.RequestModifier("id") != nil { t.Fatalf("f.RequestModifier(%q): got reqmod, want nil", "id") } if f.ResponseModifier("id") != nil { t.Fatalf("f.ResponseModifier(%q): got resmod, want nil", "id") } tm := martiantest.NewModifier() f.SetRequestModifier("id", tm) f.SetResponseModifier("id", tm) if f.RequestModifier("id") != tm { t.Errorf("f.RequestModifier(%q): got nil, want martiantest.Modifier", "id") } if f.ResponseModifier("id") != tm { t.Errorf("f.ResponseModifier(%q): got nil, want martiantest.Modifier", "id") } }
func TestModifyRequest(t *testing.T) { m := NewModifier() tm := martiantest.NewModifier() m.SetRequestModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } m.SetRequestModifier(nil) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } }
func TestIntegrationConnect(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() // Test TLS server. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } tl, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("tls.Listen(): got %v, want no error", err) } tl = tls.NewListener(tl, mc.TLS()) go http.Serve(tl, http.HandlerFunc( func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(299) })) tm := martiantest.NewModifier() reqerr := errors.New("request error") reserr := errors.New("response error") // Force the CONNECT request to dial the local TLS server. tm.RequestFunc(func(req *http.Request) { req.URL.Host = tl.Addr().String() }) tm.RequestError(reqerr) tm.ResponseError(reserr) p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:443 HTTP/1.1 // Host: example.com // // Rewritten to CONNECT to host:port in CONNECT request modifier. if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // CONNECT response after establishing tunnel. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) } roots := x509.NewCertPool() roots.AddCert(ca) tlsconn := tls.Client(conn, &tls.Config{ ServerName: "example.com", RootCAs: roots, }) defer tlsconn.Close() req, err = http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Connection", "close") // GET / HTTP/1.1 // Host: example.com // Connection: close if err := req.Write(tlsconn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } res, err = http.ReadResponse(bufio.NewReader(tlsconn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 299; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Warning"), reserr.Error(); strings.Contains(got, want) { t.Errorf("res.Header.Get(%q): got %s, want to not contain %s", "Warning", got, want) } }
func TestHTTPThroughConnectWithMITM(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := NewContext(req) ctx.SkipRoundTrip() if req.Method != "GET" && req.Method != "CONNECT" { t.Errorf("unexpected method on request handler: %v", req.Method) } }) p.SetRequestModifier(tm) ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } p.SetMITM(mc) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:80", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:80 HTTP/1.1 // Host: example.com if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response skipped round trip. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } req, err = http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } // Response from skipped round trip. res, err = http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } req, err = http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } // Response from skipped round trip. res, err = http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } }
func TestIntegrationTransparentMITM(t *testing.T) { t.Parallel() ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } // Start TLS listener with config that will generate certificates based on // SNI from connection. // // BUG: tls.Listen will not accept a tls.Config where Certificates is empty, // even though it is supported by tls.Server when GetCertificate is not nil. l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } l = tls.NewListener(l, mc.TLS()) p := NewProxy() defer p.Close() tr := martiantest.NewTransport() tr.Func(func(req *http.Request) (*http.Response, error) { res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Request-Scheme", req.URL.Scheme) return res, nil }) p.SetRoundTripper(tr) tm := martiantest.NewModifier() p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) roots := x509.NewCertPool() roots.AddCert(ca) tlsconn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{ // Verify the hostname is example.com. ServerName: "example.com", // The certificate will have been generated during MITM, so we need to // verify it with the generated CA certificate. RootCAs: roots, }) if err != nil { t.Fatalf("tls.Dial(): got %v, want no error", err) } defer tlsconn.Close() req, err := http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // Write Encrypted request directly, no CONNECT. // GET / HTTP/1.1 // Host: example.com if err := req.Write(tlsconn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(tlsconn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Request-Scheme"), "https"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want) } if !tm.RequestModified() { t.Errorf("tm.RequestModified(): got false, want true") } if !tm.ResponseModified() { t.Errorf("tm.ResponseModified(): got false, want true") } }
func TestIntegrationTLSHandshakeErrorCallback(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() // Test TLS server. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } cb := make(chan error) mc.SetHandshakeErrorCallback(func(_ *http.Request, err error) { cb <- err }) p.SetMITM(mc) tl, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("tls.Listen(): got %v, want no error", err) } tl = tls.NewListener(tl, mc.TLS()) go http.Serve(tl, http.HandlerFunc( func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(200) })) tm := martiantest.NewModifier() // Force the CONNECT request to dial the local TLS server. tm.RequestFunc(func(req *http.Request) { req.URL.Host = tl.Addr().String() }) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:443 HTTP/1.1 // Host: example.com // // Rewritten to CONNECT to host:port in CONNECT request modifier. if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // CONNECT response after establishing tunnel. if _, err := http.ReadResponse(bufio.NewReader(conn), req); err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } tlsconn := tls.Client(conn, &tls.Config{ ServerName: "example.com", // Client has no cert so it will get "x509: certificate signed by unknown authority" from the // handshake and send "remote error: bad certificate" to the server. RootCAs: x509.NewCertPool(), }) defer tlsconn.Close() req, err = http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Connection", "close") if got, want := req.Write(tlsconn), "x509: certificate signed by unknown authority"; !strings.Contains(got.Error(), want) { t.Fatalf("Got incorrect error from Client Handshake(), got: %v, want: %v", got, want) } if got, want := <-cb, "remote error: bad certificate"; !strings.Contains(got.Error(), want) { t.Fatalf("Got incorrect error from Server Handshake(), got: %v, want: %v", got, want) } }
func TestIntegrationMITM(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() tr := martiantest.NewTransport() tr.Func(func(req *http.Request) (*http.Response, error) { res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Request-Scheme", req.URL.Scheme) return res, nil }) p.SetRoundTripper(tr) p.SetTimeout(600 * time.Millisecond) ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } p.SetMITM(mc) tm := martiantest.NewModifier() reqerr := errors.New("request error") reserr := errors.New("response error") tm.RequestError(reqerr) tm.ResponseError(reserr) p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:443 HTTP/1.1 // Host: example.com if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response MITM'd from proxy. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) } roots := x509.NewCertPool() roots.AddCert(ca) tlsconn := tls.Client(conn, &tls.Config{ ServerName: "example.com", RootCAs: roots, }) defer tlsconn.Close() req, err = http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET / HTTP/1.1 // Host: example.com if err := req.Write(tlsconn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response from MITM proxy. res, err = http.ReadResponse(bufio.NewReader(tlsconn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Request-Scheme"), "https"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want) } if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) } }
func TestFilterModifyRequest(t *testing.T) { tt := []struct { want bool match string url *url.URL }{ { match: "https://www.example.com", url: &url.URL{Scheme: "https"}, want: true, }, { match: "http://www.martian.local", url: &url.URL{Host: "www.martian.local"}, want: true, }, { match: "http://www.example.com/test", url: &url.URL{Path: "/test"}, want: true, }, { match: "http://www.example.com?test=true", url: &url.URL{RawQuery: "test=true"}, want: true, }, { match: "http://www.example.com#test", url: &url.URL{Fragment: "test"}, want: true, }, { match: "https://martian.local/test?test=true#test", url: &url.URL{ Scheme: "https", Host: "martian.local", Path: "/test", RawQuery: "test=true", Fragment: "test", }, want: true, }, { match: "https://www.example.com", url: &url.URL{Scheme: "http"}, want: false, }, { match: "http://www.martian.external", url: &url.URL{Host: "www.martian.local"}, want: false, }, { match: "http://www.example.com/testing", url: &url.URL{Path: "/test"}, want: false, }, { match: "http://www.example.com?test=false", url: &url.URL{RawQuery: "test=true"}, want: false, }, { match: "http://www.example.com#test", url: &url.URL{Fragment: "testing"}, want: false, }, } for i, tc := range tt { req, err := http.NewRequest("GET", tc.match, nil) if err != nil { t.Fatalf("%d. NewRequest(): got %v, want no error", i, err) } mod := NewFilter(tc.url) tm := martiantest.NewModifier() mod.SetRequestModifier(tm) if err := mod.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %q, want no error", i, err) } if tm.RequestModified() != tc.want { t.Errorf("tm.RequestModified(): got %t, want %t", tm.RequestModified(), tc.want) } } }
func TestIntegrationHTTP100Continue(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() p.SetTimeout(2 * time.Second) sl, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } go func() { conn, err := sl.Accept() if err != nil { log.Errorf("proxy_test: failed to accept connection: %v", err) return } defer conn.Close() log.Infof("proxy_test: accepted connection: %s", conn.RemoteAddr()) req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { log.Errorf("proxy_test: failed to read request: %v", err) return } if req.Header.Get("Expect") == "100-continue" { log.Infof("proxy_test: received 100-continue request") conn.Write([]byte("HTTP/1.1 100 Continue\r\n\r\n")) log.Infof("proxy_test: sent 100-continue response") } else { log.Infof("proxy_test: received non 100-continue request") res := proxyutil.NewResponse(417, nil, req) res.Header.Set("Connection", "close") res.Write(conn) return } res := proxyutil.NewResponse(200, req.Body, req) res.Header.Set("Connection", "close") res.Write(conn) log.Infof("proxy_test: sent 200 response") }() tm := martiantest.NewModifier() p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() host := sl.Addr().String() raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+ "Host: %s\r\n"+ "Content-Length: 12\r\n"+ "Expect: 100-continue\r\n\r\n", host, host) if _, err := conn.Write([]byte(raw)); err != nil { t.Fatalf("conn.Write(headers): got %v, want no error", err) } go func() { select { case <-time.After(time.Second): conn.Write([]byte("body content")) } }() res, err := http.ReadResponse(bufio.NewReader(conn), nil) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } if want := []byte("body content"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } }
func TestIntegrationHTTP(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::1]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() p.SetRequestModifier(nil) p.SetResponseModifier(nil) tr := martiantest.NewTransport() p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := Context(req) ctx.Set("martian.test", "true") }) tm.ResponseFunc(func(res *http.Response) { ctx := Context(res.Request) v, _ := ctx.Get("martian.test") res.Header.Set("Martian-Test", v.(string)) }) p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } }
func TestProxyAuthInvalidCredentials(t *testing.T) { m := NewModifier() autherr := errors.New("auth error") tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := martian.Context(req) actx := auth.FromContext(ctx) actx.SetError(autherr) }) tm.ResponseFunc(func(res *http.Response) { ctx := martian.Context(res.Request) actx := auth.FromContext(ctx) actx.SetError(autherr) }) m.SetRequestModifier(tm) m.SetResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass")) ctx, err := session.FromContext(nil) if err != nil { t.Fatalf("session.FromContext(): got %v, want no error", err) } martian.SetContext(req, ctx) defer martian.RemoveContext(req) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } actx := auth.FromContext(ctx) if actx.Error() != autherr { t.Fatalf("auth.Error(): got %v, want %v", actx.Error(), autherr) } actx.SetError(nil) res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } if actx.Error() != autherr { t.Fatalf("auth.Error(): got %v, want %v", actx.Error(), autherr) } if got, want := res.StatusCode, http.StatusProxyAuthRequired; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), "Basic"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } }