func TestProxy(t *testing.T) { goodOrigin := buildOrigin(false) defer goodOrigin.Close() prematureCloser := buildOrigin(true) defer prematureCloser.Close() server := httptest.NewServer(filters.Join( New(&Options{ IdleTimeout: 500 * time.Second, OnRequest: func(req *http.Request) { req.Header.Set(fakeRequestHeader, "faker") }, OnResponse: func(resp *http.Response) *http.Response { // Add fake response header resp.Header.Set(fakeResponseHeader, "fakeresp") return resp }, }), filters.Adapt(http.NotFoundHandler()))) defer server.Close() doTestProxy(t, goodOrigin, server, false) doTestProxy(t, goodOrigin, server, true) doTestProxy(t, prematureCloser, server, false) doTestProxy(t, prematureCloser, server, true) }
func basicServer(maxConns uint64, idleTimeout time.Duration) *Server { filterChain := filters.Join( commonfilter.New(&commonfilter.Options{ AllowLocalhost: testingLocal, }), httpconnect.New(&httpconnect.Options{ IdleTimeout: idleTimeout, }), forward.New(&forward.Options{ IdleTimeout: idleTimeout, }), ) // Create server srv := NewServer(filterChain) // Add net.Listener wrappers for inbound connections srv.AddListenerWrappers( // Limit max number of simultaneous connections func(ls net.Listener) net.Listener { return listeners.NewLimitedListener(ls, maxConns) }, // Close connections after 30 seconds of no activity func(ls net.Listener) net.Listener { return listeners.NewIdleConnListener(ls, idleTimeout) }, ) return srv }
// A proxy with a custom origin server connection timeout func impatientProxy(maxConns uint64, idleTimeout time.Duration) (string, error) { filterChain := filters.Join( httpconnect.New(&httpconnect.Options{ IdleTimeout: idleTimeout, }), forward.New(&forward.Options{ IdleTimeout: idleTimeout, }), ) srv := NewServer(filterChain) // Add net.Listener wrappers for inbound connections srv.AddListenerWrappers( // Close connections after 30 seconds of no activity func(ls net.Listener) net.Listener { return listeners.NewIdleConnListener(ls, time.Second*30) }, ) ready := make(chan string) wait := func(addr string) { ready <- addr } var err error go func(err *error) { if *err = srv.ListenAndServeHTTP("localhost:0", wait); err != nil { log.Errorf("Unable to serve: %v", err) } }(&err) return <-ready, err }
func main() { var err error _ = flag.CommandLine.Parse(os.Args[1:]) if *help { flag.Usage() return } // Logging // TODO: use real parameters err = logging.Init("instanceid", "version", "releasedate", "") if err != nil { log.Error(err) } filterChain := filters.Join( commonfilter.New(&commonfilter.Options{ AllowLocalhost: testingLocal, }), httpconnect.New(&httpconnect.Options{IdleTimeout: time.Duration(*idleClose) * time.Second}), pforward.New(&pforward.Options{Force: true, IdleTimeout: time.Duration(*idleClose) * time.Second}), ) // Create server srv := server.NewServer(filterChain) // Add net.Listener wrappers for inbound connections srv.AddListenerWrappers( // Limit max number of simultaneous connections func(ls net.Listener) net.Listener { return listeners.NewLimitedListener(ls, *maxConns) }, // Close connections after 30 seconds of no activity func(ls net.Listener) net.Listener { return listeners.NewIdleConnListener(ls, time.Duration(*idleClose)*time.Second) }, ) // Serve HTTP/S if *https { err = srv.ListenAndServeHTTPS(*addr, *keyfile, *certfile, nil) } else { err = srv.ListenAndServeHTTP(*addr, nil) } if err != nil { log.Errorf("Error serving: %v", err) } }
func TestFilterTunnelPorts(t *testing.T) { origin := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(200) w.Write([]byte("hi")) })) origin.StartTLS() defer origin.Close() ou, _ := url.Parse(origin.URL) _, _port, _ := net.SplitHostPort(ou.Host) port, _ := strconv.Atoi(_port) server := httptest.NewServer(filters.Join( New(&Options{AllowedPorts: []int{port, 443}, IdleTimeout: 30 * time.Second}), filters.Adapt(http.NotFoundHandler()))) defer server.Close() u, _ := url.Parse(server.URL) client := http.Client{Transport: &http.Transport{ Proxy: func(req *http.Request) (*url.URL, error) { return u, nil }, DisableKeepAlives: true, }} req, _ := http.NewRequest("CONNECT", "https://site.com:abc", nil) resp, _ := client.Do(req) assert.Nil(t, resp, "CONNECT request with non-integer port should fail with 400") req, _ = http.NewRequest("GET", "https://www.google.com/humans.txt", nil) resp, err := client.Do(req) if !assert.NoError(t, err) { return } _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode, "CONNECT request to allowed port should succeed") req, _ = http.NewRequest("CONNECT", fmt.Sprintf("https://site.com:%d", (port-1)), nil) resp, _ = client.Do(req) assert.Nil(t, resp, "CONNECT request to disallowed port should fail with 403") }