Example #1
0
func TestPriorityGroupModifyRequest(t *testing.T) {
	var order []string

	pg := NewGroup()

	tm50 := martiantest.NewModifier()
	tm50.RequestFunc(func(*http.Request) {
		order = append(order, "tm50")
	})
	pg.AddRequestModifier(tm50, 50)

	tm100a := martiantest.NewModifier()
	tm100a.RequestFunc(func(*http.Request) {
		order = append(order, "tm100a")
	})
	pg.AddRequestModifier(tm100a, 100)

	tm100b := martiantest.NewModifier()
	tm100b.RequestFunc(func(*http.Request) {
		order = append(order, "tm100b")
	})
	pg.AddRequestModifier(tm100b, 100)

	tm75 := martiantest.NewModifier()
	tm75.RequestFunc(func(*http.Request) {
		order = append(order, "tm75")
	})

	if err := pg.RemoveRequestModifier(tm75); err != ErrModifierNotFound {
		t.Fatalf("RemoveRequestModifier(): got %v, want ErrModifierNotFound", err)
	}

	pg.AddRequestModifier(tm75, 100)

	if err := pg.RemoveRequestModifier(tm75); err != nil {
		t.Fatalf("RemoveRequestModifier(): got %v, want no error", err)
	}

	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	if err := pg.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) {
		t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want)
	}
}
Example #2
0
// TestModifyRequest exercises the filter's request path in three states: no
// auth ID with auth required, no auth ID with auth not required, and an auth
// ID that has a registered request modifier.
func TestModifyRequest(t *testing.T) {
	f := NewFilter()

	tm := martiantest.NewModifier()
	f.SetRequestModifier("id", tm)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}

	// No ID, auth required.
	f.SetAuthRequired(true)

	ctx := session.FromContext(nil)
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := f.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	// The failure is expected on the auth context, not as a returned error.
	actx := FromContext(ctx)
	if actx.Error() == nil {
		t.Error("actx.Error(): got nil, want error")
	}
	if tm.RequestModified() {
		t.Error("tm.RequestModified(): got true, want false")
	}
	tm.Reset()

	// No ID, auth not required.
	f.SetAuthRequired(false)
	actx.SetError(nil)

	if err := f.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	// Report the auth context's error; the function-level err is the result
	// of http.NewRequest and is always nil at this point.
	if actx.Error() != nil {
		t.Errorf("actx.Error(): got %v, want no error", actx.Error())
	}
	if tm.RequestModified() {
		t.Error("tm.RequestModified(): got true, want false")
	}

	// Valid ID.
	actx.SetError(nil)
	actx.SetID("id")

	if err := f.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if actx.Error() != nil {
		t.Errorf("actx.Error(): got %v, want no error", actx.Error())
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
}
Example #3
0
// TestResultModifierAccessors checks that Result's RequestModifier and
// ResponseModifier accessors report nil or non-nil according to which field
// was populated.
func TestResultModifierAccessors(t *testing.T) {
	mod := martiantest.NewModifier()

	// Only the request modifier is set.
	reqOnly := &Result{reqmod: mod}
	if reqOnly.RequestModifier() == nil {
		t.Error("r.RequestModifier: got nil, want reqmod")
	}
	if reqOnly.ResponseModifier() != nil {
		t.Error("r.ResponseModifier: got resmod, want nil")
	}

	// Only the response modifier is set.
	resOnly := &Result{resmod: mod}
	if resOnly.RequestModifier() != nil {
		t.Error("r.RequestModifier: got reqmod, want nil")
	}
	if resOnly.ResponseModifier() == nil {
		t.Error("r.ResponseModifier: got nil, want resmod")
	}
}
Example #4
0
// TestModifyRequest checks the auth ID recorded on the auth context for
// requests with no remote address, a remote address with a port, and a remote
// address without a port, along with how errors from the wrapped request
// modifier are surfaced.
func TestModifyRequest(t *testing.T) {
	m := NewModifier()
	m.SetRequestModifier(nil)

	req, err := http.NewRequest("CONNECT", "https://www.example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	ctx := session.FromContext(nil)
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	actx := auth.FromContext(ctx)

	// No RemoteAddr has been set on the request yet, so no ID is expected.
	if got, want := actx.ID(), ""; got != want {
		t.Errorf("actx.ID(): got %q, want %q", got, want)
	}

	// IP with port and modifier with error.
	tm := martiantest.NewModifier()
	reqerr := errors.New("request error")
	tm.RequestError(reqerr)

	req.RemoteAddr = "1.1.1.1:8111"
	m.SetRequestModifier(tm)

	// The wrapped modifier's error is expected to be returned unchanged.
	if err := m.ModifyRequest(req); err != reqerr {
		t.Fatalf("ModifyRequest(): got %v, want %v", err, reqerr)
	}

	// The ID is expected to be the address without its port.
	if got, want := actx.ID(), "1.1.1.1"; got != want {
		t.Errorf("actx.ID(): got %q, want %q", got, want)
	}

	// IP without port and modifier with auth error.
	req.RemoteAddr = "4.4.4.4"

	autherr := errors.New("auth error")
	tm.RequestError(nil)
	tm.RequestFunc(func(req *http.Request) {
		ctx := martian.Context(req)
		actx := auth.FromContext(ctx)

		actx.SetError(autherr)
	})

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	// After the wrapped modifier records an auth error, the ID is expected to
	// be empty.
	if got, want := actx.ID(), ""; got != want {
		t.Errorf("actx.ID(): got %q, want %q", got, want)
	}
}
Example #5
0
func TestPriorityGroupModifyResponse(t *testing.T) {
	var order []string

	pg := NewGroup()

	tm50 := martiantest.NewModifier()
	tm50.ResponseFunc(func(*http.Response) {
		order = append(order, "tm50")
	})
	pg.AddResponseModifier(tm50, 50)

	tm100a := martiantest.NewModifier()
	tm100a.ResponseFunc(func(*http.Response) {
		order = append(order, "tm100a")
	})
	pg.AddResponseModifier(tm100a, 100)

	tm100b := martiantest.NewModifier()
	tm100b.ResponseFunc(func(*http.Response) {
		order = append(order, "tm100b")
	})
	pg.AddResponseModifier(tm100b, 100)

	tm75 := martiantest.NewModifier()
	tm75.ResponseFunc(func(*http.Response) {
		order = append(order, "tm75")
	})

	if err := pg.RemoveResponseModifier(tm75); err != ErrModifierNotFound {
		t.Fatalf("RemoveResponseModifier(): got %v, want ErrModifierNotFound", err)
	}

	pg.AddResponseModifier(tm75, 100)

	if err := pg.RemoveResponseModifier(tm75); err != nil {
		t.Fatalf("RemoveResponseModifier(): got %v, want no error", err)
	}

	res := proxyutil.NewResponse(200, nil, nil)
	if err := pg.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) {
		t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want)
	}
}
Example #6
0
func TestProxyAuth(t *testing.T) {
	m := NewModifier()
	tm := martiantest.NewModifier()
	m.SetRequestModifier(tm)
	m.SetResponseModifier(tm)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass"))

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}

	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	actx := auth.FromContext(ctx)
	if got, want := actx.ID(), "user:pass"; got != want {
		t.Fatalf("actx.ID(): got %q, want %q", got, want)
	}

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}

	res := proxyutil.NewResponse(200, nil, req)

	if err := m.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if err := m.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Proxy-Authenticate"), ""; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want)
	}
}
Example #7
0
// TestIntegrationHTTPDownstreamProxyError points the proxy at an unusable
// downstream proxy, issues a CONNECT, and expects a 502 whose second Warning
// header value carries the response modifier's error text.
func TestIntegrationHTTPDownstreamProxyError(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	// Set proxy's downstream proxy to invalid host:port to force failure.
	p.SetDownstreamProxy(&url.URL{
		Host: "[::1]:0",
	})
	p.SetTimeout(600 * time.Millisecond)

	tm := martiantest.NewModifier()
	reserr := errors.New("response error")
	tm.ResponseError(reserr)

	p.SetResponseModifier(tm)

	go p.Serve(l)

	// Open connection to upstream proxy.
	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response from upstream proxy, assuming downstream proxy failed to CONNECT.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}

	if got, want := res.StatusCode, 502; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	// Check the length before indexing so that a missing Warning value is
	// reported as a test failure rather than an index-out-of-range panic.
	warnings := res.Header["Warning"]
	if len(warnings) < 2 {
		t.Fatalf("res.Header[%q]: got %d values, want at least 2", "Warning", len(warnings))
	}
	if got, want := warnings[1], reserr.Error(); !strings.Contains(got, want) {
		t.Errorf("res.Header[%q][1]: got %q, want to contain %q", "Warning", got, want)
	}
}
Example #8
0
// TestModifyResponse covers three response-path cases: no wrapped response
// modifier, a wrapped modifier that returns an error, and a wrapped modifier
// that records an auth error on the auth context.
func TestModifyResponse(t *testing.T) {
	mod := NewModifier()
	mod.SetResponseModifier(nil)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	sess, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}

	martian.SetContext(req, sess)
	defer martian.RemoveContext(req)

	// With no wrapped modifier the call must succeed.
	resp := proxyutil.NewResponse(200, nil, req)
	if err := mod.ModifyResponse(resp); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	// A wrapped modifier's error is expected to be returned unchanged.
	wantErr := errors.New("response error")
	inner := martiantest.NewModifier()
	inner.ResponseError(wantErr)
	mod.SetResponseModifier(inner)

	if err := mod.ModifyResponse(resp); err != wantErr {
		t.Fatalf("ModifyResponse(): got %v, want %v", err, wantErr)
	}

	// Replace the returned error with an auth error recorded on the auth
	// context from inside the wrapped modifier.
	authFailure := errors.New("auth error")
	inner.ResponseError(nil)
	inner.ResponseFunc(func(r *http.Response) {
		auth.FromContext(martian.Context(r.Request)).SetError(authFailure)
	})

	authCtx := auth.FromContext(sess)
	authCtx.SetID("bad-auth")

	if err := mod.ModifyResponse(resp); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	// The test expects the status code to become 403 and the auth error to
	// remain on the context.
	if code := resp.StatusCode; code != 403 {
		t.Errorf("res.StatusCode: got %d, want %d", code, 403)
	}
	if got := authCtx.Error(); got != authFailure {
		t.Errorf("actx.Error(): got %v, want %v", got, authFailure)
	}
}
Example #9
0
// TestModifyResponse exercises the filter's response path in three states: no
// auth ID with auth required, no auth ID with auth not required, and an auth
// ID that has a registered response modifier.
func TestModifyResponse(t *testing.T) {
	f := NewFilter()

	tm := martiantest.NewModifier()
	f.SetResponseModifier("id", tm)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	res := proxyutil.NewResponse(200, nil, req)

	// No ID, auth required.
	f.SetAuthRequired(true)

	ctx, remove, err := martian.TestContext(req)
	if err != nil {
		t.Fatalf("martian.TestContext(): got %v, want no error", err)
	}
	defer remove()

	if err := f.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	// The failure is expected on the auth context, not as a returned error.
	actx := FromContext(ctx)
	if actx.Error() == nil {
		t.Error("actx.Error(): got nil, want error")
	}
	if tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got true, want false")
	}

	// No ID, no auth required.
	f.SetAuthRequired(false)
	actx.SetError(nil)

	if err := f.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got true, want false")
	}

	// Valid ID.
	actx.SetID("id")

	if err := f.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}
Example #10
0
// TestIntegrationSkipRoundTrip runs a proxy whose request modifier marks the
// round trip as skipped and checks that the client receives a 200 rather than
// the 500 the configured transport would return.
func TestIntegrationSkipRoundTrip(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	proxy := NewProxy()
	defer proxy.Close()

	// The transport is set up to answer 500; the test expects it to be
	// bypassed.
	rt := martiantest.NewTransport()
	rt.Respond(500)
	proxy.SetRoundTripper(rt)
	proxy.SetTimeout(200 * time.Millisecond)

	skipper := martiantest.NewModifier()
	skipper.RequestFunc(func(r *http.Request) {
		mctx := Context(r)
		mctx.SkipRoundTrip()
	})
	proxy.SetRequestModifier(skipper)

	go proxy.Serve(ln)

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer client.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// Sent in proxy form:
	//   GET http://example.com/ HTTP/1.1
	//   Host: example.com
	if err := req.WriteProxy(client); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	// Read the response produced for the skipped round trip.
	resp, err := http.ReadResponse(bufio.NewReader(client), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code != 200 {
		t.Errorf("res.StatusCode: got %d, want %d", code, 200)
	}
}
Example #11
0
// TestModifyResponseHaltsOnError checks that when the first response modifier
// in a group returns an error, the group returns that error and the second
// modifier is not run.
func TestModifyResponseHaltsOnError(t *testing.T) {
	fg := NewGroup()

	// Labelled as a response error to match the path under test; the check
	// below compares the error value itself, not its text.
	reserr := errors.New("response error")
	tm := martiantest.NewModifier()
	tm.ResponseError(reserr)
	fg.AddResponseModifier(tm)

	tm2 := martiantest.NewModifier()
	fg.AddResponseModifier(tm2)

	res := proxyutil.NewResponse(200, nil, nil)
	if err := fg.ModifyResponse(res); err != reserr {
		t.Fatalf("fg.ModifyResponse(): got %v, want %v", err, reserr)
	}

	if tm2.ResponseModified() {
		t.Error("tm2.ResponseModified(): got true, want false")
	}
}
Example #12
0
// TestIntegrationTransparentHTTP writes an origin-form (non-proxy) GET
// straight to the proxy and checks that it answers 200 and runs both the
// request and response modifiers.
func TestIntegrationTransparentHTTP(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	tr := martiantest.NewTransport()
	p.SetRoundTripper(tr)
	p.SetTimeout(200 * time.Millisecond)

	tm := martiantest.NewModifier()
	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET / HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	// Close the body, as the other integration tests in this file do.
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}
Example #13
0
// TestPriorityGroupModifyResponseHaltsOnError checks that when the
// priority-100 response modifier returns an error, the group returns that
// error and the priority-75 modifier is not run.
func TestPriorityGroupModifyResponseHaltsOnError(t *testing.T) {
	pg := NewGroup()

	reserr := errors.New("response error")
	tm := martiantest.NewModifier()
	tm.ResponseError(reserr)

	pg.AddResponseModifier(tm, 100)

	tm2 := martiantest.NewModifier()
	pg.AddResponseModifier(tm2, 75)

	res := proxyutil.NewResponse(200, nil, nil)
	if err := pg.ModifyResponse(res); err != reserr {
		t.Fatalf("ModifyResponse(): got %v, want %v", err, reserr)
	}

	if tm2.ResponseModified() {
		t.Error("tm2.ResponseModified(): got true, want false")
	}
}
Example #14
0
// TestNewStack runs a request and a response through the stack returned by
// NewStack and checks the resulting headers: the hop-by-hop header is gone,
// and X-Forwarded-For and Via are set on the request.
func TestNewStack(t *testing.T) {
	const hop = "Hop-By-Hop"

	stack, group := NewStack("martian")

	probe := martiantest.NewModifier()
	group.AddRequestModifier(probe)
	group.AddResponseModifier(probe)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	_, cleanup, err := martian.TestContext(req)
	if err != nil {
		t.Fatalf("martian.TestContext(): got %v, want no error", err)
	}
	defer cleanup()

	// Name a header in Connection so that it counts as hop-by-hop.
	req.Header.Set(hop, "true")
	req.Header.Set("Connection", hop)
	req.RemoteAddr = "10.0.0.1:5000"

	if err := stack.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	requestChecks := []struct {
		header string
		want   string
	}{
		{hop, ""},
		{"X-Forwarded-For", "10.0.0.1"},
		{"Via", "1.1 martian"},
	}
	for _, c := range requestChecks {
		if got := req.Header.Get(c.header); got != c.want {
			t.Errorf("req.Header.Get(%q): got %q, want %q", c.header, got, c.want)
		}
	}

	resp := proxyutil.NewResponse(200, nil, req)

	// Same hop-by-hop setup on the response side.
	resp.Header.Set(hop, "true")
	resp.Header.Set("Connection", hop)

	if err := stack.ModifyResponse(resp); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if got := resp.Header.Get(hop); got != "" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", hop, got, "")
	}
}
Example #15
0
func TestModifyRequestHaltsOnError(t *testing.T) {
	fg := NewGroup()

	reqerr := errors.New("request error")
	tm := martiantest.NewModifier()
	tm.RequestError(reqerr)
	fg.AddRequestModifier(tm)

	tm2 := martiantest.NewModifier()
	fg.AddRequestModifier(tm2)

	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	if err := fg.ModifyRequest(req); err != reqerr {
		t.Fatalf("fg.ModifyRequest(): got %v, want %v", err, reqerr)
	}

	if tm2.RequestModified() {
		t.Error("tm2.RequestModified(): got true, want false")
	}
}
Example #16
0
func TestModifyResponse(t *testing.T) {
	fg := NewGroup()
	tm := martiantest.NewModifier()

	fg.AddResponseModifier(tm)

	res := proxyutil.NewResponse(200, nil, nil)
	if err := fg.ModifyResponse(res); err != nil {
		t.Fatalf("fg.ModifyResponse(): got %v, want no error", err)
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}
Example #17
0
// TestModifyRequest checks that a header filter with no request modifier
// returns no error, and then, per table case, whether the wrapped modifier is
// run for a request carrying the given header name and values.
func TestModifyRequest(t *testing.T) {
	// A filter whose request modifier is nil must not return an error.
	bare := NewFilter("mARTian-teSTInG", "true")
	bare.SetRequestModifier(nil)

	plain, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	if err := bare.ModifyRequest(plain); err != nil {
		t.Errorf("ModifyRequest(): got %v, want no error", err)
	}

	// The filter is built with a mixed-case header name. The case keyed
	// "Martian-Testing" is expected to match even though "true" is its second
	// value; the unrelated header is not.
	cases := []struct {
		name   string
		values []string
		want   bool
	}{
		{name: "Martian-Production", values: []string{"true"}, want: false},
		{name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: true},
	}

	for idx, c := range cases {
		filter := NewFilter("mARTian-teSTInG", "true")
		inner := martiantest.NewModifier()
		filter.SetRequestModifier(inner)

		req, err := http.NewRequest("GET", "http://example.com", nil)
		if err != nil {
			t.Fatalf("%d. http.NewRequest(): got %v, want no error", idx, err)
		}
		// Assign the map entry directly so the header key is stored exactly
		// as spelled in the table.
		req.Header[c.name] = c.values

		if err := filter.ModifyRequest(req); err != nil {
			t.Fatalf("%d. ModifyRequest(): got %v, want no error", idx, err)
		}

		if got := inner.RequestModified(); got != c.want {
			t.Errorf("%d. tm.RequestModified(): got %t, want %t", idx, got, c.want)
		}
	}
}
Example #18
0
func TestModifyRequest(t *testing.T) {
	fg := NewGroup()
	tm := martiantest.NewModifier()

	fg.AddRequestModifier(tm)

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	if err := fg.ModifyRequest(req); err != nil {
		t.Fatalf("fg.ModifyRequest(): got %v, want no error", err)
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
}
Example #19
0
// TestModifyResponse checks that a header filter with no response modifier
// returns no error, and then, per table case, whether the wrapped modifier is
// run for a response carrying the given header name and values.
func TestModifyResponse(t *testing.T) {
	// A filter whose response modifier is nil must not return an error.
	bare := NewFilter("mARTian-teSTInG", "true")
	bare.SetResponseModifier(nil)

	plain := proxyutil.NewResponse(200, nil, nil)
	if err := bare.ModifyResponse(plain); err != nil {
		t.Errorf("ModifyResponse(): got %v, want no error", err)
	}

	// The filter is built with a mixed-case header name. The case keyed
	// "Martian-Testing" is expected to match even though "true" is its second
	// value; the unrelated header is not.
	cases := []struct {
		name   string
		values []string
		want   bool
	}{
		{name: "Martian-Production", values: []string{"true"}, want: false},
		{name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: true},
	}

	for idx, c := range cases {
		filter := NewFilter("mARTian-teSTInG", "true")
		inner := martiantest.NewModifier()
		filter.SetResponseModifier(inner)

		resp := proxyutil.NewResponse(200, nil, nil)
		// Assign the map entry directly so the header key is stored exactly
		// as spelled in the table.
		resp.Header[c.name] = c.values

		if err := filter.ModifyResponse(resp); err != nil {
			t.Fatalf("%d. ModifyResponse(): got %v, want no error", idx, err)
		}

		if got := inner.ResponseModified(); got != c.want {
			t.Errorf("%d. tm.ResponseModified(): got %t, want %t", idx, got, c.want)
		}
	}
}
Example #20
0
// TestFilter checks that a new filter has no modifiers registered for an ID
// and that, after registration, the accessors return the registered modifier.
func TestFilter(t *testing.T) {
	const id = "id"

	filter := NewFilter()

	// Nothing is registered yet, so both lookups must come back nil.
	if mod := filter.RequestModifier(id); mod != nil {
		t.Fatalf("f.RequestModifier(%q): got reqmod, want nil", id)
	}
	if mod := filter.ResponseModifier(id); mod != nil {
		t.Fatalf("f.ResponseModifier(%q): got resmod, want nil", id)
	}

	probe := martiantest.NewModifier()
	filter.SetRequestModifier(id, probe)
	filter.SetResponseModifier(id, probe)

	// Both lookups must now return the very modifier that was registered.
	if mod := filter.RequestModifier(id); mod != probe {
		t.Errorf("f.RequestModifier(%q): got nil, want martiantest.Modifier", id)
	}
	if mod := filter.ResponseModifier(id); mod != probe {
		t.Errorf("f.ResponseModifier(%q): got nil, want martiantest.Modifier", id)
	}
}
Example #21
0
func TestModifyRequest(t *testing.T) {
	m := NewModifier()
	tm := martiantest.NewModifier()

	m.SetRequestModifier(tm)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}

	m.SetRequestModifier(nil)

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
}
Example #22
0
// TestIntegrationConnect opens a CONNECT tunnel through the proxy to a local
// TLS server and then issues a GET over that tunnel. The CONNECT response is
// expected to be a 200 whose Warning header contains the response modifier's
// error text; the tunneled GET is expected to return the TLS server's 299
// with no such warning.
func TestIntegrationConnect(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	// Test TLS server.
	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	// A plain TCP listener is opened first and then wrapped with TLS below,
	// so the failure message names the call that is actually made.
	tl, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	tl = tls.NewListener(tl, mc.TLS())

	// The distinctive 299 status lets the final assertion identify responses
	// that came from this handler.
	go http.Serve(tl, http.HandlerFunc(
		func(rw http.ResponseWriter, req *http.Request) {
			rw.WriteHeader(299)
		}))

	tm := martiantest.NewModifier()
	reqerr := errors.New("request error")
	reserr := errors.New("response error")

	// Force the CONNECT request to dial the local TLS server.
	tm.RequestFunc(func(req *http.Request) {
		req.URL.Host = tl.Addr().String()
	})

	tm.RequestError(reqerr)
	tm.ResponseError(reserr)

	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	//
	// Rewritten to CONNECT to host:port in CONNECT request modifier.
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// CONNECT response after establishing tunnel.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
	if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
	}

	// Trust the generated authority so the TLS server's certificate verifies.
	roots := x509.NewCertPool()
	roots.AddCert(ca)

	tlsconn := tls.Client(conn, &tls.Config{
		ServerName: "example.com",
		RootCAs:    roots,
	})
	defer tlsconn.Close()

	req, err = http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Connection", "close")

	// GET / HTTP/1.1
	// Host: example.com
	// Connection: close
	if err := req.Write(tlsconn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 299; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}
	// The tunneled response must not carry the response modifier's warning.
	if got, want := res.Header.Get("Warning"), reserr.Error(); strings.Contains(got, want) {
		t.Errorf("res.Header.Get(%q): got %s, want to not contain %s", "Warning", got, want)
	}
}
Example #23
0
// TestHTTPThroughConnectWithMITM sends a CONNECT for port 80 through a proxy
// with MITM configured and then two plain-HTTP proxy-form GETs over the same
// connection. Every response is expected to be a 200, and the request
// modifier must only ever see GET or CONNECT methods.
func TestHTTPThroughConnectWithMITM(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	proxy := NewProxy()
	defer proxy.Close()

	// Every request has its round trip skipped, and its method is checked.
	skipper := martiantest.NewModifier()
	skipper.RequestFunc(func(r *http.Request) {
		mctx := NewContext(r)
		mctx.SkipRoundTrip()

		switch r.Method {
		case "GET", "CONNECT":
		default:
			t.Errorf("unexpected method on request handler: %v", r.Method)
		}
	})
	proxy.SetRequestModifier(skipper)

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}

	mitmConfig, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	proxy.SetMITM(mitmConfig)

	go proxy.Serve(ln)

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer client.Close()

	connectReq, err := http.NewRequest("CONNECT", "//example.com:80", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:80 HTTP/1.1
	// Host: example.com
	if err := connectReq.Write(client); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response to the CONNECT, whose round trip was skipped.
	connectRes, err := http.ReadResponse(bufio.NewReader(client), connectReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	connectRes.Body.Close()

	if code := connectRes.StatusCode; code != 200 {
		t.Errorf("res.StatusCode: got %d, want %d", code, 200)
	}

	// Two identical proxy-form GETs follow on the same connection:
	//   GET http://example.com/ HTTP/1.1
	//   Host: example.com
	for round := 0; round < 2; round++ {
		getReq, err := http.NewRequest("GET", "http://example.com", nil)
		if err != nil {
			t.Fatalf("http.NewRequest(): got %v, want no error", err)
		}

		if err := getReq.WriteProxy(client); err != nil {
			t.Fatalf("req.WriteProxy(): got %v, want no error", err)
		}

		// Response from skipped round trip.
		getRes, err := http.ReadResponse(bufio.NewReader(client), getReq)
		if err != nil {
			t.Fatalf("http.ReadResponse(): got %v, want no error", err)
		}
		getRes.Body.Close()

		if code := getRes.StatusCode; code != 200 {
			t.Errorf("res.StatusCode: got %d, want %d", code, 200)
		}
	}
}
Example #24
0
// TestIntegrationTransparentMITM serves the proxy behind a TLS listener,
// dials it directly over TLS with no CONNECT, and checks that the request
// reaches the round tripper with an https scheme and that both modifiers run.
func TestIntegrationTransparentMITM(t *testing.T) {
	t.Parallel()

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}

	mitmConfig, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	// The client verifies the generated certificate against the generated CA
	// certificate.
	trusted := x509.NewCertPool()
	trusted.AddCert(ca)

	// Start a TLS listener whose config generates certificates based on the
	// SNI of each connection.
	//
	// BUG: tls.Listen will not accept a tls.Config where Certificates is empty,
	// even though it is supported by tls.Server when GetCertificate is not nil.
	// A plain listener is therefore opened and wrapped with tls.NewListener.
	ln, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	ln = tls.NewListener(ln, mitmConfig.TLS())

	proxy := NewProxy()
	defer proxy.Close()

	// The transport echoes the request's URL scheme back in a response
	// header so the test can inspect it.
	rt := martiantest.NewTransport()
	rt.Func(func(r *http.Request) (*http.Response, error) {
		echo := proxyutil.NewResponse(200, nil, r)
		echo.Header.Set("Request-Scheme", r.URL.Scheme)
		return echo, nil
	})
	proxy.SetRoundTripper(rt)

	probe := martiantest.NewModifier()
	proxy.SetRequestModifier(probe)
	proxy.SetResponseModifier(probe)

	go proxy.Serve(ln)

	clientConfig := &tls.Config{
		// Verify the hostname is example.com.
		ServerName: "example.com",
		RootCAs:    trusted,
	}
	secure, err := tls.Dial("tcp", ln.Addr().String(), clientConfig)
	if err != nil {
		t.Fatalf("tls.Dial(): got %v, want no error", err)
	}
	defer secure.Close()

	req, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// Write the encrypted request directly, with no CONNECT:
	//   GET / HTTP/1.1
	//   Host: example.com
	if err := req.Write(secure); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	resp, err := http.ReadResponse(bufio.NewReader(secure), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code != 200 {
		t.Fatalf("res.StatusCode: got %d, want %d", code, 200)
	}
	if scheme := resp.Header.Get("Request-Scheme"); scheme != "https" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", scheme, "https")
	}

	if !probe.RequestModified() {
		t.Errorf("tm.RequestModified(): got false, want true")
	}
	if !probe.ResponseModified() {
		t.Errorf("tm.ResponseModified(): got false, want true")
	}
}
Example #25
0
// TestIntegrationTLSHandshakeErrorCallback checks that the callback registered
// with SetHandshakeErrorCallback on the MITM config receives the server-side
// handshake error when the client rejects the certificate presented inside a
// CONNECT tunnel.
func TestIntegrationTLSHandshakeErrorCallback(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	// Test TLS server.
	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	// Buffered so the callback can hand over its error and return even when
	// the test has already failed and nothing is receiving any more.
	cb := make(chan error, 1)
	mc.SetHandshakeErrorCallback(func(_ *http.Request, err error) { cb <- err })
	p.SetMITM(mc)

	tl, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	tl = tls.NewListener(tl, mc.TLS())
	// Closing the listener lets the http.Serve goroutine below terminate.
	defer tl.Close()

	go http.Serve(tl, http.HandlerFunc(
		func(rw http.ResponseWriter, req *http.Request) {
			rw.WriteHeader(200)
		}))

	tm := martiantest.NewModifier()

	// Force the CONNECT request to dial the local TLS server.
	// NOTE(review): tm is never installed on p in this test, so this rewrite
	// does not appear to take effect — confirm whether
	// p.SetRequestModifier(tm) was intended.
	tm.RequestFunc(func(req *http.Request) {
		req.URL.Host = tl.Addr().String()
	})

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	//
	// Rewritten to CONNECT to host:port in CONNECT request modifier.
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// CONNECT response after establishing tunnel.
	if _, err := http.ReadResponse(bufio.NewReader(conn), req); err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}

	tlsconn := tls.Client(conn, &tls.Config{
		ServerName: "example.com",
		// Client has no cert so it will get "x509: certificate signed by unknown authority" from the
		// handshake and send "remote error: bad certificate" to the server.
		RootCAs: x509.NewCertPool(),
	})
	defer tlsconn.Close()

	req, err = http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Connection", "close")

	// The write triggers the client handshake, which is expected to fail. The
	// nil guard reports an unexpected success as a test failure instead of a
	// nil-pointer panic on got.Error().
	got := req.Write(tlsconn)
	if want := "x509: certificate signed by unknown authority"; got == nil || !strings.Contains(got.Error(), want) {
		t.Fatalf("Got incorrect error from Client Handshake(), got: %v, want: %v", got, want)
	}

	// Bound the wait so a callback that never fires fails the test instead of
	// hanging it until the global test timeout.
	select {
	case got := <-cb:
		if want := "remote error: bad certificate"; got == nil || !strings.Contains(got.Error(), want) {
			t.Fatalf("Got incorrect error from Server Handshake(), got: %v, want: %v", got, want)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("handshake error callback: not invoked within 5s")
	}
}
Example #26
0
// TestIntegrationMITM opens a CONNECT tunnel through a proxy with MITM
// enabled, then speaks TLS over that tunnel using the generated CA as the only
// trusted root. The installed modifier is configured to return errors for both
// requests and responses; the test expects the response error text to show up
// in the Warning header of both the CONNECT response and the tunneled response.
func TestIntegrationMITM(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	proxy := NewProxy()
	defer proxy.Close()

	// The stub transport echoes the scheme it was asked for in a header.
	rt := martiantest.NewTransport()
	rt.Func(func(req *http.Request) (*http.Response, error) {
		resp := proxyutil.NewResponse(200, nil, req)
		resp.Header.Set("Request-Scheme", req.URL.Scheme)
		return resp, nil
	})
	proxy.SetRoundTripper(rt)
	proxy.SetTimeout(600 * time.Millisecond)

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}
	mitmConfig, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	proxy.SetMITM(mitmConfig)

	reqerr := errors.New("request error")
	reserr := errors.New("response error")

	failing := martiantest.NewModifier()
	failing.RequestError(reqerr)
	failing.ResponseError(reserr)
	proxy.SetRequestModifier(failing)
	proxy.SetResponseModifier(failing)

	go proxy.Serve(ln)

	// expectWarning reports a test error unless the Warning header of resp
	// contains the response modifier's error text.
	expectWarning := func(resp *http.Response) {
		warning := resp.Header.Get("Warning")
		if !strings.Contains(warning, reserr.Error()) {
			t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", warning, reserr.Error())
		}
	}

	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	connectReq, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	if err := connectReq.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response MITM'd from proxy.
	connectRes, err := http.ReadResponse(bufio.NewReader(conn), connectReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	if connectRes.StatusCode != 200 {
		t.Errorf("res.StatusCode: got %d, want %d", connectRes.StatusCode, 200)
	}
	expectWarning(connectRes)

	// Trust only the generated authority for the tunneled handshake.
	trusted := x509.NewCertPool()
	trusted.AddCert(ca)

	secure := tls.Client(conn, &tls.Config{
		ServerName: "example.com",
		RootCAs:    trusted,
	})
	defer secure.Close()

	getReq, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET / HTTP/1.1
	// Host: example.com
	if err := getReq.Write(secure); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response from MITM proxy.
	getRes, err := http.ReadResponse(bufio.NewReader(secure), getReq)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer getRes.Body.Close()

	if getRes.StatusCode != 200 {
		t.Errorf("res.StatusCode: got %d, want %d", getRes.StatusCode, 200)
	}
	if scheme := getRes.Header.Get("Request-Scheme"); scheme != "https" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", scheme, "https")
	}
	expectWarning(getRes)
}
Example #27
0
// TestFilterModifyRequest is a table-driven test of the URL filter. Each case
// builds a GET request for match, a filter from url, and installs a test
// modifier on the filter; want states whether that modifier is expected to
// have run after ModifyRequest.
func TestFilterModifyRequest(t *testing.T) {
	tt := []struct {
		want  bool
		match string
		url   *url.URL
	}{
		// Cases where the filter URL is expected to match the request.
		{
			match: "https://www.example.com",
			url:   &url.URL{Scheme: "https"},
			want:  true,
		},
		{
			match: "http://www.martian.local",
			url:   &url.URL{Host: "www.martian.local"},
			want:  true,
		},
		{
			match: "http://www.example.com/test",
			url:   &url.URL{Path: "/test"},
			want:  true,
		},
		{
			match: "http://www.example.com?test=true",
			url:   &url.URL{RawQuery: "test=true"},
			want:  true,
		},
		{
			match: "http://www.example.com#test",
			url:   &url.URL{Fragment: "test"},
			want:  true,
		},
		{
			match: "https://martian.local/test?test=true#test",
			url: &url.URL{
				Scheme:   "https",
				Host:     "martian.local",
				Path:     "/test",
				RawQuery: "test=true",
				Fragment: "test",
			},
			want: true,
		},
		// Cases where one component differs and no match is expected.
		{
			match: "https://www.example.com",
			url:   &url.URL{Scheme: "http"},
			want:  false,
		},
		{
			match: "http://www.martian.external",
			url:   &url.URL{Host: "www.martian.local"},
			want:  false,
		},
		{
			match: "http://www.example.com/testing",
			url:   &url.URL{Path: "/test"},
			want:  false,
		},
		{
			match: "http://www.example.com?test=false",
			url:   &url.URL{RawQuery: "test=true"},
			want:  false,
		},
		{
			match: "http://www.example.com#test",
			url:   &url.URL{Fragment: "testing"},
			want:  false,
		},
	}

	for i, tc := range tt {
		req, err := http.NewRequest("GET", tc.match, nil)
		if err != nil {
			t.Fatalf("%d. NewRequest(): got %v, want no error", i, err)
		}

		mod := NewFilter(tc.url)
		tm := martiantest.NewModifier()
		mod.SetRequestModifier(tm)

		if err := mod.ModifyRequest(req); err != nil {
			t.Fatalf("%d. ModifyRequest(): got %q, want no error", i, err)
		}

		// Include the case index, as the messages above do, so a failure can
		// be traced back to its table entry.
		if got := tm.RequestModified(); got != tc.want {
			t.Errorf("%d. tm.RequestModified(): got %t, want %t", i, got, tc.want)
		}
	}
}
Example #28
0
// TestIntegrationHTTP100Continue sends a POST with "Expect: 100-continue"
// through the proxy to a hand-rolled origin server, delays the request body by
// one second, and expects a 200 response whose body echoes the request body.
// It also expects the installed modifier to have seen both the request and the
// response.
func TestIntegrationHTTP100Continue(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	p.SetTimeout(2 * time.Second)

	sl, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	// Closing the listener unblocks Accept below if the test fails before the
	// proxy ever connects, so the origin goroutine does not leak.
	defer sl.Close()

	// Minimal origin server: accepts one connection, answers one request.
	// Failures are logged rather than reported through t because this
	// goroutine may outlive the test.
	go func() {
		conn, err := sl.Accept()
		if err != nil {
			log.Errorf("proxy_test: failed to accept connection: %v", err)
			return
		}
		defer conn.Close()

		log.Infof("proxy_test: accepted connection: %s", conn.RemoteAddr())

		req, err := http.ReadRequest(bufio.NewReader(conn))
		if err != nil {
			log.Errorf("proxy_test: failed to read request: %v", err)
			return
		}

		if req.Header.Get("Expect") != "100-continue" {
			log.Infof("proxy_test: received non 100-continue request")

			res := proxyutil.NewResponse(417, nil, req)
			res.Header.Set("Connection", "close")
			if err := res.Write(conn); err != nil {
				log.Errorf("proxy_test: failed to write 417 response: %v", err)
			}
			return
		}

		log.Infof("proxy_test: received 100-continue request")

		if _, err := conn.Write([]byte("HTTP/1.1 100 Continue\r\n\r\n")); err != nil {
			log.Errorf("proxy_test: failed to write 100-continue response: %v", err)
			return
		}

		log.Infof("proxy_test: sent 100-continue response")

		// Echo the request body back as the response body.
		res := proxyutil.NewResponse(200, req.Body, req)
		res.Header.Set("Connection", "close")
		if err := res.Write(conn); err != nil {
			log.Errorf("proxy_test: failed to write 200 response: %v", err)
			return
		}

		log.Infof("proxy_test: sent 200 response")
	}()

	tm := martiantest.NewModifier()
	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	host := sl.Addr().String()
	raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+
		"Host: %s\r\n"+
		"Content-Length: 12\r\n"+
		"Expect: 100-continue\r\n\r\n", host, host)

	if _, err := conn.Write([]byte(raw)); err != nil {
		t.Fatalf("conn.Write(headers): got %v, want no error", err)
	}

	// Send the 12-byte body only after a delay, so the headers and the body
	// arrive separately. Logged rather than reported through t: the test may
	// already have finished if ReadResponse failed.
	go func() {
		time.Sleep(time.Second)
		if _, err := conn.Write([]byte("body content")); err != nil {
			log.Errorf("proxy_test: failed to write request body: %v", err)
		}
	}()

	res, err := http.ReadResponse(bufio.NewReader(conn), nil)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	got, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
	}

	if want := []byte("body content"); !bytes.Equal(got, want) {
		t.Errorf("res.Body: got %q, want %q", got, want)
	}

	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}
Example #29
0
// TestIntegrationHTTP sends a proxy-form GET through the proxy and checks
// that a value stored on the request's context by the request modifier can be
// read back by the response modifier, which copies it into the Martian-Test
// response header.
func TestIntegrationHTTP(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	// Set nil modifiers first; the real ones are installed further down.
	p.SetRequestModifier(nil)
	p.SetResponseModifier(nil)

	tr := martiantest.NewTransport()
	p.SetRoundTripper(tr)
	p.SetTimeout(200 * time.Millisecond)

	tm := martiantest.NewModifier()

	tm.RequestFunc(func(req *http.Request) {
		ctx := Context(req)
		ctx.Set("martian.test", "true")
	})

	tm.ResponseFunc(func(res *http.Response) {
		ctx := Context(res.Request)
		v, _ := ctx.Get("martian.test")

		// Checked assertion: if the value is missing or not a string the
		// header stays unset and the assertion at the end of the test fails,
		// instead of this modifier panicking.
		if s, ok := v.(string); ok {
			res.Header.Set("Martian-Test", s)
		}
	})

	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET http://example.com/ HTTP/1.1
	// Host: example.com
	if err := req.WriteProxy(conn); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	// Release the response body, matching the other integration tests.
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	if got, want := res.Header.Get("Martian-Test"), "true"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want)
	}
}
Example #30
0
// TestProxyAuthInvalidCredentials wraps a test modifier that records an auth
// error on the session's auth context during both the request and response
// phases. It expects that error to be observable after each phase, and the
// response to end up with status http.StatusProxyAuthRequired and a
// "Basic" Proxy-Authenticate header.
func TestProxyAuthInvalidCredentials(t *testing.T) {
	mod := NewModifier()
	wantErr := errors.New("auth error")

	// rejectAuth records wantErr on the auth context associated with req.
	rejectAuth := func(req *http.Request) {
		auth.FromContext(martian.Context(req)).SetError(wantErr)
	}

	inner := martiantest.NewModifier()
	inner.RequestFunc(rejectAuth)
	inner.ResponseFunc(func(res *http.Response) {
		rejectAuth(res.Request)
	})

	mod.SetRequestModifier(inner)
	mod.SetResponseModifier(inner)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass"))

	sess, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}

	martian.SetContext(req, sess)
	defer martian.RemoveContext(req)

	// Request phase.
	if err := mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if !inner.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}

	authCtx := auth.FromContext(sess)
	if got := authCtx.Error(); got != wantErr {
		t.Fatalf("auth.Error(): got %v, want %v", got, wantErr)
	}
	// Clear the error so the response phase has to record it again.
	authCtx.SetError(nil)

	// Response phase.
	res := proxyutil.NewResponse(200, nil, req)
	if err := mod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if !inner.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}

	if got := authCtx.Error(); got != wantErr {
		t.Fatalf("auth.Error(): got %v, want %v", got, wantErr)
	}

	if res.StatusCode != http.StatusProxyAuthRequired {
		t.Errorf("res.StatusCode: got %d, want %d", res.StatusCode, http.StatusProxyAuthRequired)
	}
	if challenge := res.Header.Get("Proxy-Authenticate"); challenge != "Basic" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", challenge, "Basic")
	}
}