Example #1
0
func TestModifyResponseResetAuth(t *testing.T) {
	auth := NewModifier()

	auth.SetRequestModifier(martian.RequestModifierFunc(
		func(ctx *martian.Context, req *http.Request) error {
			if ctx.Auth.ID != "secret:pass" {
				ctx.Auth.Error = errors.New("invalid auth")
			}
			return nil
		}))

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Proxy-Authorization", "Basic "+encode("wrong:pass"))

	ctx := martian.NewContext()
	// This will set ctx.Auth.Error since the ID isn't "secret:pass".
	if err := auth.ModifyRequest(ctx, req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	res := proxyutil.NewResponse(200, nil, req)
	if err := auth.ModifyResponse(ctx, res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if got, want := res.StatusCode, 407; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}

	if got, want := ctx.Auth.ID, ""; got != want {
		t.Errorf("ctx.Auth.ID: got %q, want %q", got, want)
	}
	if err := ctx.Auth.Error; err != nil {
		t.Errorf("ctx.Auth.Error: got %v, want no error", err)
	}

	// This will be successful because the ID is "secret:pass".
	req.Header.Set("Proxy-Authorization", "Basic "+encode("secret:pass"))
	if err := auth.ModifyRequest(ctx, req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	// Reset the response.
	res = proxyutil.NewResponse(200, nil, req)
	if err := auth.ModifyResponse(ctx, res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
}
Example #2
0
// TestModifyResponse verifies that ModifyResponse succeeds when no response
// modifier is installed and that, once a modifier records an auth error on the
// context, the response carries status 407 and a "Basic" Proxy-Authenticate
// header.
func TestModifyResponse(t *testing.T) {
	var (
		m   = NewModifier()
		ctx = martian.NewContext()
		res = proxyutil.NewResponse(200, nil, nil)
	)

	// No response modifier has been set at this point.
	if err := m.ModifyResponse(ctx, res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	// The modifier writes to the ctx captured from the enclosing test, not to
	// the context argument it receives.
	requireAuth := func(*martian.Context, *http.Response) error {
		ctx.Auth.Error = errors.New("auth is required")
		return nil
	}
	m.SetResponseModifier(martian.ResponseModifierFunc(requireAuth))

	if err := m.ModifyResponse(ctx, res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if code := res.StatusCode; code != http.StatusProxyAuthRequired {
		t.Errorf("res.StatusCode: got %d, want %d", code, http.StatusProxyAuthRequired)
	}
	if challenge := res.Header.Get("Proxy-Authenticate"); challenge != "Basic" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", challenge, "Basic")
	}
}
Example #3
0
// TestCopyModifier checks that a modifier built with
// NewCopyModifier("Original", "Copy") makes the "Copy" header carry the value
// of "Original" on both a request and a response.
func TestCopyModifier(t *testing.T) {
	mod := NewCopyModifier("Original", "Copy")

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// Request side.
	req.Header.Set("Original", "test")
	if err := mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if copied := req.Header.Get("Copy"); copied != "test" {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Copy", copied, "test")
	}

	// Response side.
	res := proxyutil.NewResponse(200, nil, req)
	res.Header.Set("Original", "test")
	if err := mod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if copied := res.Header.Get("Copy"); copied != "test" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Copy", copied, "test")
	}
}
Example #4
0
// TestModifyResponse verifies that Proxy.ModifyResponse succeeds before any
// response modifier is set and that a modifier installed through
// SetResponseModifier is invoked by a subsequent ModifyResponse call.
func TestModifyResponse(t *testing.T) {
	proxy := NewProxy(mitm)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(%q, ..., nil): got %v, want no error", "GET", err)
	}
	res := proxyutil.NewResponse(200, nil, req)

	// No response modifier has been set at this point.
	if err := proxy.ModifyResponse(NewContext(), res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	// Install a modifier that only records that it was called.
	var ran bool
	proxy.SetResponseModifier(ResponseModifierFunc(
		func(*Context, *http.Response) error {
			ran = true
			return nil
		}))

	if err := proxy.ModifyResponse(NewContext(), res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if !ran {
		t.Error("modifierRun: got false, want true")
	}
}
Example #5
0
// TestModifierFromJSON parses a "cookie.Modifier" configuration scoped to both
// requests and responses and checks that each resulting modifier adds exactly
// one cookie named "martian" with the value "value".
func TestModifierFromJSON(t *testing.T) {
	msg := []byte(`{
    "cookie.Modifier": {
      "scope": ["request", "response"],
		  "name": "martian",
			"value": "value"
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	req, err := http.NewRequest("GET", "http://example.com/path/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// Request scope.
	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}
	if err := reqmod.ModifyRequest(martian.NewContext(), req); err != nil {
		t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err)
	}

	reqCookies := req.Cookies()
	if n := len(reqCookies); n != 1 {
		t.Fatalf("len(req.Cookies): got %v, want %v", n, 1)
	}
	if name := reqCookies[0].Name; name != "martian" {
		t.Errorf("req.Cookies()[0].Name: got %v, want %v", name, "martian")
	}
	if value := reqCookies[0].Value; value != "value" {
		t.Errorf("req.Cookies()[0].Value: got %v, want %v", value, "value")
	}

	// Response scope.
	resmod := r.ResponseModifier()
	if resmod == nil {
		t.Fatal("resmod: got nil, want not nil")
	}

	res := proxyutil.NewResponse(200, nil, req)
	if err := resmod.ModifyResponse(martian.NewContext(), res); err != nil {
		t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
	}

	resCookies := res.Cookies()
	if n := len(resCookies); n != 1 {
		t.Fatalf("len(res.Cookies): got %v, want %v", n, 1)
	}
	if name := resCookies[0].Name; name != "martian" {
		t.Errorf("res.Cookies()[0].Name: got %v, want %v", name, "martian")
	}
	if value := resCookies[0].Value; value != "value" {
		t.Errorf("res.Cookies()[0].Value: got %v, want %v", value, "value")
	}
}
Example #6
0
// TestFromJSON parses a response-scoped "status.Modifier" configuration and
// checks that the resulting response modifier changes a 200 response to 400.
func TestFromJSON(t *testing.T) {
	const config = `{
    "status.Modifier": {
      "scope": ["response"],
      "statusCode": 400
    }
  }`

	r, err := parse.FromJSON([]byte(config))
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	resmod := r.ResponseModifier()
	if resmod == nil {
		t.Fatal("resmod: got nil, want not nil")
	}

	// Start from a 200 so the status change is observable.
	res := proxyutil.NewResponse(200, nil, nil)
	if err := resmod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if code := res.StatusCode; code != 400 {
		t.Errorf("res.StatusCode: got %d, want %d", code, 400)
	}
}
Example #7
0
// TestModifierFuncs checks that plain functions wrapped in RequestModifierFunc
// and ResponseModifierFunc are invoked through ModifyRequest and
// ModifyResponse, observed via a marker header each function sets.
func TestModifierFuncs(t *testing.T) {
	markRequest := func(req *http.Request) error {
		req.Header.Set("Request-Modified", "true")
		return nil
	}
	markResponse := func(res *http.Response) error {
		res.Header.Set("Response-Modified", "true")
		return nil
	}

	reqmod := RequestModifierFunc(markRequest)
	resmod := ResponseModifierFunc(markResponse)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := reqmod.ModifyRequest(req); err != nil {
		t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err)
	}
	if marker := req.Header.Get("Request-Modified"); marker != "true" {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Request-Modified", marker, "true")
	}

	res := proxyutil.NewResponse(200, nil, req)

	if err := resmod.ModifyResponse(res); err != nil {
		t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
	}
	if marker := res.Header.Get("Response-Modified"); marker != "true" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Response-Modified", marker, "true")
	}
}
Example #8
0
// TestStatusModifierOnResponse checks, for several status codes, that the
// status modifier sets both res.StatusCode and the matching res.Status line on
// a response that starts out as 200.
func TestStatusModifierOnResponse(t *testing.T) {
	statuses := []int{
		http.StatusForbidden,
		http.StatusOK,
		http.StatusTemporaryRedirect,
	}

	for i := range statuses {
		status := statuses[i]

		req, err := http.NewRequest("GET", "/", nil)
		if err != nil {
			t.Fatalf("NewRequest(): got %v, want no error", err)
		}

		res := proxyutil.NewResponse(200, nil, req)
		mod := NewModifier(status)

		if err := mod.ModifyResponse(res); err != nil {
			t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err)
		}

		if res.StatusCode != status {
			t.Errorf("%d. res.StatusCode: got %v, want %v", i, res.StatusCode, status)
		}

		// The expected status line is built from the code now on the response
		// and the standard text for the requested status.
		wantLine := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(status))
		if res.Status != wantLine {
			t.Errorf("%d. res.Status: got %q, want %q", i, res.Status, wantLine)
		}
	}
}
Example #9
0
func TestExportIgnoresOrphanedResponse(t *testing.T) {
	logger := NewLogger("martian", "2.0.0")

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	_, remove, err := martian.TestContext(req)
	if err != nil {
		t.Fatalf("martian.TestContext(): got %v, want no error", err)
	}
	defer remove()

	if err := logger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	// Reset before the response comes back.
	logger.Reset()

	res := proxyutil.NewResponse(200, nil, req)
	if err := logger.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	log := logger.Export().Log
	if got, want := len(log.Entries), 0; got != want {
		t.Errorf("len(log.Entries): got %d, want %d", got, want)
	}
}
Example #10
0
// TestVerifierFromJSON parses a response-scoped "status.Verifier" expecting
// status 400, runs it against a 200 response, and checks that
// VerifyResponses then reports an error.
func TestVerifierFromJSON(t *testing.T) {
	msg := []byte(`{
    "status.Verifier": {
      "scope": ["response"],
      "statusCode": 400
    }
  }`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}
	resmod := r.ResponseModifier()
	if resmod == nil {
		t.Fatal("resmod: got nil, want not nil")
	}
	// The parsed modifier must also implement the verifier interface. The
	// failure message names the assertion that is actually performed.
	resv, ok := resmod.(verify.ResponseVerifier)
	if !ok {
		t.Fatal("resmod.(verify.ResponseVerifier): got !ok, want ok")
	}

	req, err := http.NewRequest("GET", "http://www.example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// The response is a 200 while the verifier is configured for 400, so
	// verification is expected to fail.
	res := proxyutil.NewResponse(200, nil, req)
	if err := resv.ModifyResponse(martian.NewContext(), res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if err := resv.VerifyResponses(); err == nil {
		t.Error("VerifyResponses(): got nil, want not nil")
	}
}
Example #11
0
// connect opens the TCP connection used to serve a CONNECT request.
//
// When a downstream proxy URL is configured, the request is forwarded to that
// proxy and the proxy's response is returned together with the connection.
// Otherwise the host in req.URL is dialed directly and a synthesized 200
// response is returned.
//
// On any error the returned response and connection are nil, and a connection
// that was already dialed is closed before returning.
func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) {
	if p.proxyURL != nil {
		log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host)

		conn, err := net.Dial("tcp", p.proxyURL.Host)
		if err != nil {
			return nil, nil, err
		}
		pbw := bufio.NewWriter(conn)
		pbr := bufio.NewReader(conn)

		// Forward the CONNECT request. A failed write means the downstream
		// proxy never saw the full request, so waiting for a response would
		// be pointless.
		if err := req.Write(pbw); err != nil {
			conn.Close()
			return nil, nil, err
		}
		if err := pbw.Flush(); err != nil {
			conn.Close()
			return nil, nil, err
		}

		res, err := http.ReadResponse(pbr, req)
		if err != nil {
			// The caller gets no handle to the connection on error, so it
			// has to be released here.
			conn.Close()
			return nil, nil, err
		}

		return res, conn, nil
	}

	log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host)

	conn, err := net.Dial("tcp", req.URL.Host)
	if err != nil {
		return nil, nil, err
	}

	return proxyutil.NewResponse(200, nil, req), conn, nil
}
Example #12
0
// TestBodyModifier checks that the body modifier replaces the body of both a
// request and a response with "text", sets Content-Type to "text/plain",
// updates ContentLength to match, and clears a pre-existing Content-Encoding
// header.
func TestBodyModifier(t *testing.T) {
	mod := NewModifier([]byte("text"), "text/plain")

	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}
	// Set beforehand so the test can observe that it is removed.
	req.Header.Set("Content-Encoding", "gzip")

	if err := mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if got, want := req.Header.Get("Content-Type"), "text/plain"; got != want {
		t.Errorf("req.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
	}
	if got, want := req.ContentLength, int64(len([]byte("text"))); got != want {
		t.Errorf("req.ContentLength: got %d, want %d", got, want)
	}
	if got, want := req.Header.Get("Content-Encoding"), ""; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
	}

	got, err := ioutil.ReadAll(req.Body)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
	}
	req.Body.Close()

	// This check is about the request body, so the message says so.
	if want := []byte("text"); !bytes.Equal(got, want) {
		t.Errorf("req.Body: got %q, want %q", got, want)
	}

	res := proxyutil.NewResponse(200, nil, req)
	// Set beforehand so the test can observe that it is removed.
	res.Header.Set("Content-Encoding", "gzip")

	if err := mod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want {
		t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
	}
	if got, want := res.ContentLength, int64(len([]byte("text"))); got != want {
		t.Errorf("res.ContentLength: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Content-Encoding"), ""; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
	}

	got, err = ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
	}
	res.Body.Close()

	if want := []byte("text"); !bytes.Equal(got, want) {
		t.Errorf("res.Body: got %q, want %q", got, want)
	}
}
Example #13
0
// TestNoModifiers checks that ModifyRequest and ModifyResponse both succeed
// when the request and response modifiers have been explicitly set to nil.
func TestNoModifiers(t *testing.T) {
	mod := NewModifier()
	mod.SetRequestModifier(nil)
	mod.SetResponseModifier(nil)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	sess, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}

	// Associate the session with the request for the duration of the test.
	martian.SetContext(req, sess)
	defer martian.RemoveContext(req)

	if err := mod.ModifyRequest(req); err != nil {
		t.Errorf("ModifyRequest(): got %v, want no error", err)
	}

	if err := mod.ModifyResponse(proxyutil.NewResponse(200, nil, req)); err != nil {
		t.Errorf("ModifyResponse(): got %v, want no error", err)
	}
}
Example #14
0
// ExampleLogger logs a gzip-encoded, chunked request and a plain response with
// decoding enabled on the logger, printing each logged line to standard output
// with carriage returns stripped.
func ExampleLogger() {
	// Remove \r to make it easier to test with examples.
	printLine := func(line string) {
		fmt.Print(strings.Replace(line, "\r", "", -1))
	}

	l := NewLogger()
	l.SetLogFunc(printLine)
	l.SetDecode(true)

	// Build a gzip-compressed request body.
	compressed := new(bytes.Buffer)
	zw := gzip.NewWriter(compressed)
	zw.Write([]byte("request content"))
	zw.Close()

	req, err := http.NewRequest("GET", "http://example.com/path?querystring", compressed)
	if err != nil {
		fmt.Println(err)
		return
	}
	req.TransferEncoding = []string{"chunked"}
	req.Header.Set("Content-Encoding", "gzip")

	err = l.ModifyRequest(req)
	if err != nil {
		fmt.Println(err)
		return
	}

	res := proxyutil.NewResponse(200, strings.NewReader("response content"), req)
	res.ContentLength = 16
	res.Header.Set("Date", "Tue, 15 Nov 1994 08:12:31 GMT")
	res.Header.Set("Other-Header", "values")

	err = l.ModifyResponse(res)
	if err != nil {
		fmt.Println(err)
		return
	}
	// Output:
	// --------------------------------------------------------------------------------
	// Request to http://example.com/path?querystring
	// --------------------------------------------------------------------------------
	// GET http://example.com/path?querystring HTTP/1.1
	// Host: example.com
	// Transfer-Encoding: chunked
	// Content-Encoding: gzip
	//
	// request content
	//
	// --------------------------------------------------------------------------------
	//
	// --------------------------------------------------------------------------------
	// Response from http://example.com/path?querystring
	// --------------------------------------------------------------------------------
	// HTTP/1.1 200 OK
	// Content-Length: 16
	// Date: Tue, 15 Nov 1994 08:12:31 GMT
	// Other-Header: values
	//
	// response content
	// --------------------------------------------------------------------------------
}
Example #15
0
// TestModifyResponseNoModifier checks that ModifyResponse returns no error on
// a freshly created modifier that has no response modifier set.
func TestModifyResponseNoModifier(t *testing.T) {
	var (
		mod = NewModifier()
		ctx = martian.NewContext()
		res = proxyutil.NewResponse(200, nil, nil)
	)

	err := mod.ModifyResponse(ctx, res)
	if err != nil {
		t.Errorf("ModifyResponse(): got %v, want no error", err)
	}
}
Example #16
0
// roundTrip sends req through the proxy's round tripper and returns its
// result. If the context reports that the round trip is being skipped, no
// request is sent and a synthesized 200 response for req is returned instead.
func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, error) {
	if !ctx.SkippingRoundTrip() {
		return p.roundTripper.RoundTrip(req)
	}

	log.Debugf("martian: skipping round trip")
	res := proxyutil.NewResponse(200, nil, req)
	return res, nil
}
Example #17
0
// TestModifier exercises the test modifier: it checks that the configured
// errors are returned, the configured callbacks run, the modified flags and
// counters are updated, and that Reset clears the modified flags.
func TestModifier(t *testing.T) {
	var reqrun, resrun bool
	moderr := errors.New("modifier error")

	tm := NewModifier()

	// Request side: return moderr and record the invocation.
	tm.RequestError(moderr)
	tm.RequestFunc(func(*http.Request) { reqrun = true })

	// Response side: same error, separate invocation flag.
	tm.ResponseError(moderr)
	tm.ResponseFunc(func(*http.Response) { resrun = true })

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := tm.ModifyRequest(req); err != moderr {
		t.Fatalf("tm.ModifyRequest(): got %v, want %v", err, moderr)
	}
	if !tm.RequestModified() {
		t.Errorf("tm.RequestModified(): got false, want true")
	}
	if n := tm.RequestCount(); n != 1 {
		t.Errorf("tm.RequestCount(): got %d, want %d", n, 1)
	}
	if !reqrun {
		t.Error("reqrun: got false, want true")
	}

	res := proxyutil.NewResponse(200, nil, req)

	if err := tm.ModifyResponse(res); err != moderr {
		t.Fatalf("tm.ModifyResponse(): got %v, want %v", err, moderr)
	}
	if !tm.ResponseModified() {
		t.Errorf("tm.ResponseModified(): got false, want true")
	}
	if n := tm.ResponseCount(); n != 1 {
		t.Errorf("tm.ResponseCount(): got %d, want %d", n, 1)
	}
	if !resrun {
		t.Error("resrun: got false, want true")
	}

	// After a reset neither side may report having been modified.
	tm.Reset()

	if tm.RequestModified() {
		t.Error("tm.RequestModified(): got true, want false")
	}
	if tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got true, want false")
	}
}
Example #18
0
// TestResponseViewDecodeGzipContentEncoding snapshots a chunked response whose
// body is gzip-compressed and checks that the header reader returns the
// expected header bytes, and that the body and full readers return the
// decompressed content when the Decode option is passed.
func TestResponseViewDecodeGzipContentEncoding(t *testing.T) {
	// Build the gzip-compressed body.
	body := new(bytes.Buffer)
	gw := gzip.NewWriter(body)
	gw.Write([]byte("body content"))
	gw.Flush()
	gw.Close()

	res := proxyutil.NewResponse(200, body, nil)
	res.TransferEncoding = []string{"chunked"}
	res.Header.Set("Content-Encoding", "gzip")

	mv := New()
	if err := mv.SnapshotResponse(res); err != nil {
		t.Fatalf("SnapshotResponse(): got %v, want no error", err)
	}

	got, err := ioutil.ReadAll(mv.HeaderReader())
	if err != nil {
		t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err)
	}

	hdrwant := "HTTP/1.1 200 OK\r\n" +
		"Transfer-Encoding: chunked\r\n" +
		"Content-Encoding: gzip\r\n\r\n"

	if !bytes.Equal(got, []byte(hdrwant)) {
		t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant)
	}

	br, err := mv.BodyReader(Decode())
	if err != nil {
		t.Fatalf("mv.BodyReader(): got %v, want no error", err)
	}

	got, err = ioutil.ReadAll(br)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err)
	}

	bodywant := "body content"

	if !bytes.Equal(got, []byte(bodywant)) {
		t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant)
	}

	r, err := mv.Reader(Decode())
	if err != nil {
		t.Fatalf("mv.Reader(): got %v, want no error", err)
	}
	got, err = ioutil.ReadAll(r)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err)
	}

	// The full view is expected to end with a trailing CRLF after the body.
	if want := []byte(hdrwant + bodywant + "\r\n"); !bytes.Equal(got, want) {
		t.Fatalf("mv.Read(): got %q, want %q", got, want)
	}
}
Example #19
0
// TestModifyResponse exercises the auth filter's response path in three
// situations: no ID with auth required (an auth error is recorded and the
// modifier is skipped), no ID with auth not required (the modifier is
// skipped), and a matching ID (the modifier runs).
func TestModifyResponse(t *testing.T) {
	f := NewFilter()

	tm := martiantest.NewModifier()
	f.SetResponseModifier("id", tm)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	res := proxyutil.NewResponse(200, nil, req)

	// No ID, auth required.
	f.SetAuthRequired(true)

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}

	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := f.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	actx := FromContext(ctx)

	if actx.Error() == nil {
		t.Error("actx.Error(): got nil, want error")
	}
	// The message names the method that is actually being checked.
	if tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got true, want false")
	}

	// No ID, no auth required.
	f.SetAuthRequired(false)
	actx.SetError(nil)

	if err := f.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got true, want false")
	}

	// Valid ID.
	actx.SetID("id")

	if err := f.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}
// TestFilterFromJSON parses a "querystring.Filter" that wraps a
// "header.Modifier" and checks that, for a request whose query string contains
// param=true, the wrapped modifier sets the "Mod-Run" header on both the
// request and the response.
func TestFilterFromJSON(t *testing.T) {
	msg := []byte(`
	{
		"querystring.Filter": {
      "scope": ["request", "response"],
      "name": "param",
      "value": "true",
      "modifier": {
        "header.Modifier": {
          "scope": ["request", "response"],
          "name": "Mod-Run",
          "value": "true"
        }
      }
    }
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	// Request scope.
	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}

	req, err := http.NewRequest("GET", "https://martian.test?param=true", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// The same context is used for the request and the response pass.
	ctx := martian.NewContext()
	if err := reqmod.ModifyRequest(ctx, req); err != nil {
		t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err)
	}
	if ran := req.Header.Get("Mod-Run"); ran != "true" {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Mod-Run", ran, "true")
	}

	// Response scope.
	resmod := r.ResponseModifier()
	if resmod == nil {
		t.Fatalf("resmod: got nil, want not nil")
	}

	res := proxyutil.NewResponse(200, nil, req)
	if err := resmod.ModifyResponse(ctx, res); err != nil {
		t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
	}
	if ran := res.Header.Get("Mod-Run"); ran != "true" {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Mod-Run", ran, "true")
	}
}
Example #21
0
// handleRequest runs the request and response modifiers and performs the
// round trip to the destination server, writing the resulting response to rw.
//
// Every failure is answered with an error response written to rw and ends
// processing of the request: 400 for bad framing or a failing modifier, 502
// for a failing round trip.
//
// The returned closing flag is true once shouldCloseAfterReply has reported
// that the request headers ask for the connection to be closed; on the error
// paths before that check it is false.
func (p *Proxy) handleRequest(ctx *Context, rw *bufio.ReadWriter, req *http.Request) (closing bool) {
	if err := proxyutil.FixBadFraming(req.Header); err != nil {
		Errorf("proxyutil.FixBadFraming(): %v", err)
		proxyutil.NewErrorResponse(400, err, req).Write(rw)
		// Stop here: the client has already been sent a 400, so the request
		// must not be forwarded and no second response may be written.
		return
	}

	proxyutil.SetForwardedHeaders(req)
	proxyutil.SetViaHeader(req.Header, "1.1 martian")

	if err := p.ModifyRequest(ctx, req); err != nil {
		Errorf("martian.ModifyRequest(): %v", err)
		proxyutil.NewErrorResponse(400, err, req).Write(rw)
		return
	}

	// Decide on closing before the hop-by-hop headers are removed below.
	if shouldCloseAfterReply(req.Header) {
		Debugf("closing after reply")
		closing = true
	}

	proxyutil.RemoveHopByHopHeaders(req.Header)

	var res *http.Response
	var err error
	if !ctx.SkipRoundTrip {
		Debugf("proceed to round trip for %s", req.URL)

		res, err = p.RoundTripper.RoundTrip(req)
		if err != nil {
			Errorf("RoundTripper.RoundTrip(%s): %v", req.URL, err)
			proxyutil.NewErrorResponse(502, err, req).Write(rw)
			return
		}
	} else {
		// No upstream call is made; the response modifiers run against a
		// synthesized 200 response.
		Debugf("skipped round trip for %s", req.URL)
		res = proxyutil.NewResponse(200, nil, req)
	}

	proxyutil.RemoveHopByHopHeaders(res.Header)

	if err := p.ModifyResponse(ctx, res); err != nil {
		Errorf("martian.ModifyResponse(): %v", err)
		proxyutil.NewErrorResponse(400, err, req).Write(rw)
		return
	}

	if closing {
		res.Header.Set("Connection", "close")
		res.Close = true
	}

	if err := res.Write(rw); err != nil {
		Errorf("res.Write(): %v", err)
	}

	return
}
Example #22
0
// TestResponseView snapshots a plain response and checks the bytes produced by
// the header reader, the body reader and the combined reader, then confirms
// that the combined output can be parsed back by http.ReadResponse.
func TestResponseView(t *testing.T) {
	const (
		hdrwant = "HTTP/1.1 200 OK\r\n" +
			"Content-Length: 12\r\n" +
			"Response-Header: true\r\n\r\n"
		bodywant = "body content"
	)

	res := proxyutil.NewResponse(200, strings.NewReader("body content"), nil)
	res.ContentLength = 12
	res.Header.Set("Response-Header", "true")

	mv := New()
	if err := mv.SnapshotResponse(res); err != nil {
		t.Fatalf("SnapshotResponse(): got %v, want no error", err)
	}

	// Headers only.
	hdr, err := ioutil.ReadAll(mv.HeaderReader())
	if err != nil {
		t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err)
	}
	if !bytes.Equal(hdr, []byte(hdrwant)) {
		t.Fatalf("mv.HeaderReader(): got %q, want %q", hdr, hdrwant)
	}

	// Body only.
	br, err := mv.BodyReader()
	if err != nil {
		t.Fatalf("mv.BodyReader(): got %v, want no error", err)
	}
	content, err := ioutil.ReadAll(br)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err)
	}
	if !bytes.Equal(content, []byte(bodywant)) {
		t.Fatalf("mv.BodyReader(): got %q, want %q", content, bodywant)
	}

	// Headers followed by body.
	r, err := mv.Reader()
	if err != nil {
		t.Fatalf("mv.Reader(): got %v, want no error", err)
	}
	full, err := ioutil.ReadAll(r)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err)
	}
	if want := []byte(hdrwant + bodywant); !bytes.Equal(full, want) {
		t.Fatalf("mv.Read(): got %q, want %q", full, want)
	}

	// Sanity check to ensure it still parses.
	_, err = http.ReadResponse(bufio.NewReader(bytes.NewReader(full)), nil)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
}
Example #23
0
func TestProxyAuth(t *testing.T) {
	m := NewModifier()
	tm := martiantest.NewModifier()
	m.SetRequestModifier(tm)
	m.SetResponseModifier(tm)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass"))

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}

	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	actx := auth.FromContext(ctx)
	if got, want := actx.ID(), "user:pass"; got != want {
		t.Fatalf("actx.ID(): got %q, want %q", got, want)
	}

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}

	res := proxyutil.NewResponse(200, nil, req)

	if err := m.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if err := m.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Proxy-Authenticate"), ""; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want)
	}
}
Example #24
0
// TestModifyResponse checks the header filter's response path: it must not
// fail with a nil response modifier, and it must run the wrapped modifier only
// for the table entry expected to match the filter created with
// ("mARTian-teSTInG", "true").
func TestModifyResponse(t *testing.T) {
	// A nil response modifier must be tolerated.
	bare := NewFilter("mARTian-teSTInG", "true")
	bare.SetResponseModifier(nil)

	if err := bare.ModifyResponse(proxyutil.NewResponse(200, nil, nil)); err != nil {
		t.Errorf("ModifyResponse(): got %v, want no error", err)
	}

	type testCase struct {
		name   string
		values []string
		want   bool
	}
	cases := []testCase{
		{name: "Martian-Production", values: []string{"true"}, want: false},
		{name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: true},
	}

	for i := range cases {
		tc := cases[i]

		filter := NewFilter("mARTian-teSTInG", "true")
		tm := martiantest.NewModifier()
		filter.SetResponseModifier(tm)

		// The header is assigned directly on the map, using the name exactly
		// as written in the table.
		res := proxyutil.NewResponse(200, nil, nil)
		res.Header[tc.name] = tc.values

		if err := filter.ModifyResponse(res); err != nil {
			t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err)
		}

		if tm.ResponseModified() != tc.want {
			t.Errorf("%d. tm.ResponseModified(): got %t, want %t", i, tm.ResponseModified(), tc.want)
		}
	}
}
Example #25
0
// TestRemoveHopByHopHeaders checks that the hop-by-hop modifier strips every
// hop-by-hop header, including those named in the Connection header, from both
// a request and a response while keeping the end-to-end header.
//
// The request and the response each get their own header map. Sharing one map
// would hand the response a header set the request pass may already have
// stripped in place, making the response assertions pass regardless of what
// ModifyResponse does.
func TestRemoveHopByHopHeaders(t *testing.T) {
	m := NewHopByHopModifier()
	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// newHeader returns a fresh, unshared copy of the header fixture.
	newHeader := func() http.Header {
		return http.Header{
			// Additional hop-by-hop headers are listed in the
			// Connection header.
			"Connection": []string{
				"X-Connection",
				"X-Hop-By-Hop, close",
			},

			// RFC hop-by-hop headers.
			"Keep-Alive":          []string{},
			"Proxy-Authenticate":  []string{},
			"Proxy-Authorization": []string{},
			"Te":                  []string{},
			"Trailer":             []string{},
			"Transfer-Encoding":   []string{},
			"Upgrade":             []string{},

			// Hop-by-hop headers listed in the Connection header.
			"X-Connection": []string{},
			"X-Hop-By-Hop": []string{},

			// End-to-end header that should not be removed.
			"X-End-To-End": []string{},
		}
	}

	req.Header = newHeader()
	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if got, want := len(req.Header), 1; got != want {
		t.Fatalf("len(req.Header): got %d, want %d", got, want)
	}
	if _, ok := req.Header["X-End-To-End"]; !ok {
		t.Errorf("req.Header[%q]: got !ok, want ok", "X-End-To-End")
	}

	res := proxyutil.NewResponse(200, nil, req)
	res.Header = newHeader()
	if err := m.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if got, want := len(res.Header), 1; got != want {
		t.Fatalf("len(res.Header): got %d, want %d", got, want)
	}
	if _, ok := res.Header["X-End-To-End"]; !ok {
		t.Errorf("res.Header[%q]: got !ok, want ok", "X-End-To-End")
	}
}
Example #26
0
// TestModifyResponse covers three response-path cases of the modifier: a nil
// response modifier, a response modifier that returns an error (which must be
// passed through), and a response modifier that records an auth error on the
// auth context (which must produce a 403 and leave the error in place).
func TestModifyResponse(t *testing.T) {
	m := NewModifier()
	m.SetResponseModifier(nil)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}

	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	// Nil response modifier.
	res := proxyutil.NewResponse(200, nil, req)
	if err := m.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	// Modifier with error.
	reserr := errors.New("response error")
	tm := martiantest.NewModifier()
	tm.ResponseError(reserr)
	m.SetResponseModifier(tm)

	if err := m.ModifyResponse(res); err != reserr {
		t.Fatalf("ModifyResponse(): got %v, want %v", err, reserr)
	}

	// Modifier with auth error.
	autherr := errors.New("auth error")
	tm.ResponseError(nil)
	tm.ResponseFunc(func(res *http.Response) {
		// Look the auth context up through the response's request.
		auth.FromContext(martian.Context(res.Request)).SetError(autherr)
	})

	actx := auth.FromContext(ctx)
	actx.SetID("bad-auth")

	if err := m.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if code := res.StatusCode; code != 403 {
		t.Errorf("res.StatusCode: got %d, want %d", code, 403)
	}
	if got := actx.Error(); got != autherr {
		t.Errorf("actx.Error(): got %v, want %v", got, autherr)
	}
}
Example #27
0
// TestServeHTTPBuildsValidRequest drives the proxy through a CONNECT followed
// by a GET over TLS on the hijacked connection and checks that the request
// reaching the round tripper has an "https" scheme, the expected host and a
// non-empty RemoteAddr, and that the round tripper's 201 reaches the client.
func TestServeHTTPBuildsValidRequest(t *testing.T) {
	p := NewProxy(mitm)
	p.RoundTripper = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
		if got, want := req.URL.Scheme, "https"; got != want {
			t.Errorf("req.URL.Scheme: got %q, want %q", got, want)
		}
		if got, want := req.URL.Host, "www.example.com"; got != want {
			t.Errorf("req.URL.Host: got %q, want %q", got, want)
		}
		if req.RemoteAddr == "" {
			t.Error("req.RemoteAddr: got empty, want addr")
		}

		// A distinctive status lets the client side confirm that this round
		// tripper produced the response.
		return proxyutil.NewResponse(201, nil, req), nil
	})

	rc, wc := pipeWithTimeout()
	defer rc.Close()
	defer wc.Close()

	rw := newHijackRecorder(wc)
	req, err := http.NewRequest("CONNECT", "//www.example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// ServeHTTP runs concurrently while this goroutine plays the client.
	go p.ServeHTTP(rw, req)

	res, err := http.ReadResponse(bufio.NewReader(rc), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	res.Body.Close()

	tlsConn := tlsClient(rc, p.mitm.Authority, "www.example.com")
	req, err = http.NewRequest("GET", "https://www.example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Connection", "close")

	if err := req.Write(tlsConn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err = http.ReadResponse(bufio.NewReader(tlsConn), nil)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 201; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
}
Example #28
0
func TestModifyResponse(t *testing.T) {
	f := NewFilter()

	modifierRun := false
	f.SetResponseModifier("id", martian.ResponseModifierFunc(
		func(*martian.Context, *http.Response) error {
			modifierRun = true
			return nil
		}))

	res := proxyutil.NewResponse(200, nil, nil)
	ctx := martian.NewContext()

	// No ID, auth required.
	f.SetAuthRequired(true)

	if err := f.ModifyResponse(ctx, res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if ctx.Auth.Error == nil {
		t.Error("ctx.Auth.Error: got nil, want error")
	}
	if modifierRun {
		t.Error("modifierRun: got true, want false")
	}

	// No ID, no auth required.
	f.SetAuthRequired(false)
	ctx.Auth.Error = nil

	if err := f.ModifyResponse(ctx, res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if ctx.Auth.Error != nil {
		t.Errorf("ctx.Auth.Error: got %v, want no error", ctx.Auth.Error)
	}
	if modifierRun {
		t.Error("modifierRun: got true, want false")
	}

	// Valid ID.
	ctx.Auth.ID = "id"
	ctx.Auth.Error = nil

	if err := f.ModifyResponse(ctx, res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if ctx.Auth.Error != nil {
		t.Errorf("ctx.Auth.Error: got %v, want no error", ctx.Auth.Error)
	}
	if !modifierRun {
		t.Error("modifierRun: got false, want true")
	}
}
Example #29
0
// ExampleLogger logs a request and its response with body logging enabled,
// printing each logged line to standard output with carriage returns stripped.
func ExampleLogger() {
	// Remove \r to make it easier to test with examples.
	printLine := func(line string) {
		fmt.Print(strings.Replace(line, "\r", "", -1))
	}

	l := NewLogger()
	l.IncludeBody(true)
	l.SetLogFunc(printLine)

	reqBody := strings.NewReader("request content")
	req, err := http.NewRequest("GET", "http://example.com/path?querystring", reqBody)
	if err != nil {
		fmt.Println(err)
		return
	}
	req.RequestURI = req.URL.RequestURI()
	req.Header.Set("Other-Header", "values")
	req.Close = true

	err = l.ModifyRequest(req)
	if err != nil {
		fmt.Println(err)
		return
	}

	res := proxyutil.NewResponse(200, strings.NewReader("response content"), req)
	res.ContentLength = 16
	res.Header.Set("Date", "Tue, 15 Nov 1994 08:12:31 GMT")
	res.Header.Set("Other-Header", "values")

	err = l.ModifyResponse(res)
	if err != nil {
		fmt.Println(err)
		return
	}
	// Output:
	// --------------------------------------------------------------------------------
	// Request to http://example.com/path?querystring
	// --------------------------------------------------------------------------------
	// GET /path?querystring HTTP/1.1
	// Host: example.com
	// Connection: close
	// Other-Header: values
	//
	// request content
	// --------------------------------------------------------------------------------
	//
	// --------------------------------------------------------------------------------
	// Response from http://example.com/path?querystring
	// --------------------------------------------------------------------------------
	// HTTP/1.1 200 OK
	// Content-Length: 16
	// Date: Tue, 15 Nov 1994 08:12:31 GMT
	// Other-Header: values
	//
	// response content
	// --------------------------------------------------------------------------------
}
Example #30
0
// CopyHeaders configures the transport to answer every request with a 200 OK
// response in which each of the named headers is set to the value that header
// has on the request.
func (tr *Transport) CopyHeaders(names ...string) {
	echo := func(req *http.Request) (*http.Response, error) {
		res := proxyutil.NewResponse(200, nil, req)

		// Only the first value of each named request header is carried over,
		// since it is read with Header.Get.
		for i := range names {
			name := names[i]
			res.Header.Set(name, req.Header.Get(name))
		}

		return res, nil
	}

	tr.rtfunc = echo
}