Example #1
0
// TestExportIgnoresOrphanedResponse checks that a response arriving after the
// logger has been reset does not produce an exported entry: its request was
// logged before the reset, so the response has nothing to attach to.
func TestExportIgnoresOrphanedResponse(t *testing.T) {
	harLogger := NewLogger("martian", "2.0.0")

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	if err = harLogger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	// Drop all logger state while the request is still awaiting its response.
	harLogger.Reset()

	if err = harLogger.ModifyResponse(proxyutil.NewResponse(200, nil, req)); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	entries := harLogger.Export().Log.Entries
	if got, want := len(entries), 0; got != want {
		t.Errorf("len(log.Entries): got %d, want %d", got, want)
	}
}
Example #2
0
// TestModifyRequest exercises the filter's request path in three auth states:
//   - no ID with auth required: an auth error is recorded and the registered
//     modifier does not run;
//   - no ID with auth not required: no auth error, and the modifier still does
//     not run;
//   - a registered ID ("id"): no auth error, and the modifier runs.
func TestModifyRequest(t *testing.T) {
	f := NewFilter()

	tm := martiantest.NewModifier()
	f.SetRequestModifier("id", tm)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}

	// No ID, auth required.
	f.SetAuthRequired(true)

	ctx := session.FromContext(nil)
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := f.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	actx := FromContext(ctx)
	if actx.Error() == nil {
		t.Error("actx.Error(): got nil, want error")
	}
	if tm.RequestModified() {
		t.Error("tm.RequestModified(): got true, want false")
	}
	tm.Reset()

	// No ID, auth not required.
	f.SetAuthRequired(false)
	actx.SetError(nil)

	if err := f.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	// Report the auth context's error; the outer err is from http.NewRequest
	// and is always nil at this point.
	if actx.Error() != nil {
		t.Errorf("actx.Error(): got %v, want no error", actx.Error())
	}
	if tm.RequestModified() {
		t.Error("tm.RequestModified(): got true, want false")
	}

	// Valid ID.
	actx.SetError(nil)
	actx.SetID("id")

	if err := f.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if actx.Error() != nil {
		t.Errorf("actx.Error(): got %v, want no error", actx.Error())
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
}
Example #3
0
// TestReset verifies that Reset discards entries the logger has already
// recorded: one request is logged and exported, then Reset empties the log.
func TestReset(t *testing.T) {
	harLogger := NewLogger("martian", "2.0.0")

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	if err = harLogger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if got, want := len(harLogger.Export().Log.Entries), 1; got != want {
		t.Fatalf("len(log.Entries): got %d, want %d", got, want)
	}

	harLogger.Reset()

	if got, want := len(harLogger.Export().Log.Entries), 0; got != want {
		t.Errorf("len(log.Entries): got %d, want %d", got, want)
	}
}
Example #4
0
// TestContext checks the request/context association: after SetContext,
// Context returns the same session context, and after RemoveContext it
// returns nil.
func TestContext(t *testing.T) {
	t.Parallel()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}

	SetContext(req, sctx)
	if got := Context(req); got != sctx {
		t.Errorf("Context(req): got %v, want %v", got, sctx)
	}

	RemoveContext(req)
	if got := Context(req); got != nil {
		t.Errorf("Context(req): got %v, want nil", got)
	}
}
Example #5
0
// TestNoModifiers checks that a Modifier with nil request and response
// modifiers handles a request and its response without returning an error.
func TestNoModifiers(t *testing.T) {
	mod := NewModifier()
	mod.SetRequestModifier(nil)
	mod.SetResponseModifier(nil)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	// Failures below are reported but not fatal, so both directions are
	// always exercised.
	if err = mod.ModifyRequest(req); err != nil {
		t.Errorf("ModifyRequest(): got %v, want no error", err)
	}

	res := proxyutil.NewResponse(200, nil, req)
	if err = mod.ModifyResponse(res); err != nil {
		t.Errorf("ModifyResponse(): got %v, want no error", err)
	}
}
Example #6
0
// TestExportSortsEntries logs several requests and checks that the exported
// entries are ordered by StartedDateTime, earliest first.
func TestExportSortsEntries(t *testing.T) {
	logger := NewLogger("martian", "2.0.0")
	count := 10

	for i := 0; i < count; i++ {
		req, err := http.NewRequest("GET", "http://example.com", nil)
		if err != nil {
			t.Fatalf("NewRequest(): got %v, want no error", err)
		}

		ctx, err := session.FromContext(nil)
		if err != nil {
			t.Fatalf("session.FromContext(): got %v, want no error", err)
		}
		martian.SetContext(req, ctx)
		// These defers run only when the test returns, so every request keeps
		// its context through the Export call below.
		defer martian.RemoveContext(req)

		if err := logger.ModifyRequest(req); err != nil {
			t.Fatalf("ModifyRequest(): got %v, want no error", err)
		}
	}

	log := logger.Export().Log

	// Guard the pairwise indexing below: with fewer entries than requests it
	// would panic instead of reporting a test failure.
	if got, want := len(log.Entries), count; got != want {
		t.Fatalf("len(log.Entries): got %d, want %d", got, want)
	}

	for i := 0; i < count-1; i++ {
		first := log.Entries[i]
		second := log.Entries[i+1]

		if got, want := first.StartedDateTime, second.StartedDateTime; got.After(want) {
			t.Errorf("entry.StartedDateTime: got %s, want to be before %s", got, want)
		}
	}
}
Example #7
0
// TestModifyRequest checks how the modifier derives the auth ID from
// req.RemoteAddr:
//   - with an empty RemoteAddr the ID is "";
//   - with "1.1.1.1:8111" the port is stripped to give "1.1.1.1", and an
//     error from the wrapped modifier is returned unchanged;
//   - when the wrapped modifier records an auth error, no error is returned
//     and the ID is reset to "".
func TestModifyRequest(t *testing.T) {
	m := NewModifier()
	m.SetRequestModifier(nil)

	req, err := http.NewRequest("CONNECT", "https://www.example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	ctx := session.FromContext(nil)
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	actx := auth.FromContext(ctx)

	if got, want := actx.ID(), ""; got != want {
		t.Errorf("actx.ID(): got %q, want %q", got, want)
	}

	// IP with port and modifier with error.
	tm := martiantest.NewModifier()
	reqerr := errors.New("request error")
	tm.RequestError(reqerr)

	req.RemoteAddr = "1.1.1.1:8111"
	m.SetRequestModifier(tm)

	if err := m.ModifyRequest(req); err != reqerr {
		t.Fatalf("ModifyRequest(): got %v, want %v", err, reqerr)
	}

	if got, want := actx.ID(), "1.1.1.1"; got != want {
		t.Errorf("actx.ID(): got %q, want %q", got, want)
	}

	// IP without port and modifier with auth error.
	req.RemoteAddr = "4.4.4.4"

	autherr := errors.New("auth error")
	tm.RequestError(nil)
	tm.RequestFunc(func(req *http.Request) {
		ctx := martian.Context(req)
		actx := auth.FromContext(ctx)

		actx.SetError(autherr)
	})

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if got, want := actx.ID(), ""; got != want {
		t.Errorf("actx.ID(): got %q, want %q", got, want)
	}
}
Example #8
0
// TestProxyAuth checks the success path of Basic proxy authentication:
// valid Proxy-Authorization credentials become the auth ID, the wrapped
// modifier runs on both the request and the response, and the response stays
// a plain 200 with no Proxy-Authenticate challenge.
func TestProxyAuth(t *testing.T) {
	mod := NewModifier()
	inner := martiantest.NewModifier()
	mod.SetRequestModifier(inner)
	mod.SetResponseModifier(inner)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass"))

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	if err = mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if got, want := auth.FromContext(sctx).ID(), "user:pass"; got != want {
		t.Fatalf("actx.ID(): got %q, want %q", got, want)
	}

	if err = mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if !inner.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}

	res := proxyutil.NewResponse(200, nil, req)

	// The response is modified twice, mirroring the two request passes above.
	for pass := 0; pass < 2; pass++ {
		if err = mod.ModifyResponse(res); err != nil {
			t.Fatalf("ModifyResponse(): got %v, want no error", err)
		}
	}

	if !inner.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Proxy-Authenticate"), ""; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want)
	}
}
Example #9
0
// TestModifyResponse covers the response path:
//   - with no response modifier the call succeeds;
//   - an error from the wrapped modifier is returned as-is;
//   - an auth error recorded by the wrapped modifier turns the response into
//     a 403 and stays visible on the auth context.
func TestModifyResponse(t *testing.T) {
	mod := NewModifier()
	mod.SetResponseModifier(nil)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	res := proxyutil.NewResponse(200, nil, req)
	if err = mod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	// Wrapped modifier that fails outright.
	inner := martiantest.NewModifier()
	resErr := errors.New("response error")
	inner.ResponseError(resErr)
	mod.SetResponseModifier(inner)

	if err = mod.ModifyResponse(res); err != resErr {
		t.Fatalf("ModifyResponse(): got %v, want %v", err, resErr)
	}

	// Wrapped modifier that succeeds but flags an auth failure.
	inner.ResponseError(nil)
	authErr := errors.New("auth error")
	inner.ResponseFunc(func(res *http.Response) {
		auth.FromContext(martian.Context(res.Request)).SetError(authErr)
	})

	actx := auth.FromContext(sctx)
	actx.SetID("bad-auth")

	if err = mod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if got, want := res.StatusCode, 403; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}

	if got, want := actx.Error(), authErr; got != want {
		t.Errorf("actx.Error(): got %v, want %v", got, want)
	}
}
Example #10
0
// TestModifyRequestBodyURLEncoded checks that an
// application/x-www-form-urlencoded request body is exported as PostData with
// the matching MimeType and one Param per form field.
func TestModifyRequestBodyURLEncoded(t *testing.T) {
	logger := NewLogger("martian", "2.0.0")

	body := strings.NewReader("first=true&second=false")
	req, err := http.NewRequest("POST", "http://example.com", body)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := logger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	log := logger.Export().Log
	// Fatal rather than Error: log.Entries[0] is indexed right below and would
	// panic on an empty log.
	if got, want := len(log.Entries), 1; got != want {
		t.Fatalf("len(log.Entries): got %v, want %v", got, want)
	}

	pd := log.Entries[0].Request.PostData
	if got, want := pd.MimeType, "application/x-www-form-urlencoded"; got != want {
		t.Errorf("PostData.MimeType: got %v, want %v", got, want)
	}

	if got, want := len(pd.Params), 2; got != want {
		t.Fatalf("len(PostData.Params): got %d, want %d", got, want)
	}

	for _, p := range pd.Params {
		var want string
		switch p.Name {
		case "first":
			want = "true"
		case "second":
			want = "false"
		default:
			t.Errorf("PostData.Params: got %q, want to not be present", p.Name)
			continue
		}

		if got := p.Value; got != want {
			t.Errorf("PostData.Params[%q]: got %q, want %q", p.Name, got, want)
		}
	}
}
Example #11
0
// TestViaModifier checks that the modifier appends "1.1 martian" to the Via
// header, both when the header is absent and when it already lists another
// hop. A Via header that already contains "1.1 martian" is treated as a
// request loop: ModifyRequest errors, the session is flagged to skip the round
// trip, and ModifyResponse errors and rewrites the response to a 400.
func TestViaModifier(t *testing.T) {
	mod := NewViaModifier("martian")
	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	res := proxyutil.NewResponse(200, nil, req)

	sctx := session.FromContext(nil)
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	// No existing Via header.
	if err = mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := req.Header.Get("Via"), "1.1 martian"; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Via", got, want)
	}

	if err = mod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	// Existing Via header from another hop.
	req.Header.Set("Via", "1.0 alpha")
	if err = mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := req.Header.Get("Via"), "1.0 alpha, 1.1 martian"; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Via", got, want)
	}

	// Via header that already contains this proxy: a loop.
	req.Header.Set("Via", "1.0 alpha, 1.1 martian, 1.1 beta")
	if err = mod.ModifyRequest(req); err == nil {
		t.Fatal("ModifyRequest(): got nil, want request loop error")
	}
	if !sctx.SkippingRoundTrip() {
		t.Error("ctx.SkippingRoundTrip(): got false, want true")
	}

	if err = mod.ModifyResponse(res); err == nil {
		t.Fatal("ModifyResponse(): got nil, want request loop error")
	}
	if got, want := res.StatusCode, 400; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Status, http.StatusText(400); got != want {
		t.Errorf("res.Status: got %q, want %q", got, want)
	}
}
Example #12
0
// handleLoop serves requests from conn, one p.handle call per iteration,
// until p.handle returns an error that isCloseable reports as terminal.
//
// The connection is registered with p.conns for the lifetime of the loop. A
// single session context is created here and passed to every p.handle call on
// this connection. Before each request the read/write deadline is pushed to
// p.timeout from now.
func (p *Proxy) handleLoop(conn net.Conn) {
	p.conns.Add(1)
	// Deferred in this order so conn is closed before p.conns is released.
	defer p.conns.Done()
	defer conn.Close()

	ctx := session.FromContext(nil)
	brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))

	for {
		conn.SetDeadline(time.Now().Add(p.timeout))

		err := p.handle(ctx, conn, brw)
		if !isCloseable(err) {
			continue
		}

		Infof("martian: closing connection: %v", conn.RemoteAddr())
		return
	}
}
Example #13
0
// TestHARExportsTime checks that an entry's Time reflects the delay between
// logging the request and logging its response. A 10ms sleep stands in for
// the round trip, and the exported Time must fall within [10, 100].
func TestHARExportsTime(t *testing.T) {
	harLogger := NewLogger("martian", "2.0.0")

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	if err = harLogger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	// Simulate fast network round trip.
	time.Sleep(10 * time.Millisecond)

	if err = harLogger.ModifyResponse(proxyutil.NewResponse(200, nil, req)); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	entries := harLogger.Export().Log.Entries
	if got, want := len(entries), 1; got != want {
		t.Fatalf("len(log.Entries): got %v, want %v", got, want)
	}

	// lo/hi avoid shadowing the min/max builtins.
	lo, hi := int64(10), int64(100)
	if got := entries[0].Time; got < lo || got > hi {
		t.Errorf("entry.Time: got %dms, want between %dms and %vms", got, lo, hi)
	}
}
Example #14
0
// TestModifyRequestBodyArbitraryContentType checks that a request body with
// no Content-Type header is exported as raw PostData.Text, with an empty
// MimeType and no parsed Params.
func TestModifyRequestBodyArbitraryContentType(t *testing.T) {
	logger := NewLogger("martian", "2.0.0")

	body := "arbitrary binary data"
	req, err := http.NewRequest("POST", "http://www.example.com", strings.NewReader(body))
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := logger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	log := logger.Export().Log
	// Fatal rather than Error: log.Entries[0] is indexed right below and would
	// panic on an empty log.
	if got, want := len(log.Entries), 1; got != want {
		t.Fatalf("len(log.Entries): got %d, want %d", got, want)
	}

	pd := log.Entries[0].Request.PostData
	if got, want := pd.MimeType, ""; got != want {
		t.Errorf("PostData.MimeType: got %q, want %q", got, want)
	}
	if got, want := len(pd.Params), 0; got != want {
		t.Errorf("len(PostData.Params): got %d, want %d", got, want)
	}

	if got, want := pd.Text, body; got != want {
		t.Errorf("PostData.Text: got %q, want %q", got, want)
	}
}
Example #15
0
// TestModifyRequestBodyMultipart checks that a multipart/form-data body is
// exported as PostData with one Param per part: a plain form field ("key")
// and a file part ("file") carrying its filename and content type.
func TestModifyRequestBodyMultipart(t *testing.T) {
	logger := NewLogger("martian", "2.0.0")

	// Build the multipart body. Any failure here leaves a malformed body, so
	// every step is fatal.
	body := new(bytes.Buffer)
	mpw := multipart.NewWriter(body)
	if err := mpw.SetBoundary("boundary"); err != nil {
		t.Fatalf("mpw.SetBoundary(): got %v, want no error", err)
	}

	if err := mpw.WriteField("key", "value"); err != nil {
		t.Fatalf("mpw.WriteField(): got %v, want no error", err)
	}

	w, err := mpw.CreateFormFile("file", "test.txt")
	if err != nil {
		t.Fatalf("mpw.CreateFormFile(): got %v, want no error", err)
	}
	if _, err := w.Write([]byte("file contents")); err != nil {
		t.Fatalf("Write(): got %v, want no error", err)
	}
	// Close writes the trailing boundary; without it the body is truncated.
	if err := mpw.Close(); err != nil {
		t.Fatalf("mpw.Close(): got %v, want no error", err)
	}

	req, err := http.NewRequest("POST", "http://example.com", body)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Content-Type", mpw.FormDataContentType())

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	if err := logger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	log := logger.Export().Log
	if got, want := len(log.Entries), 1; got != want {
		t.Fatalf("len(log.Entries): got %d, want %d", got, want)
	}

	pd := log.Entries[0].Request.PostData
	if got, want := pd.MimeType, "multipart/form-data"; got != want {
		t.Errorf("PostData.MimeType: got %q, want %q", got, want)
	}
	if got, want := len(pd.Params), 2; got != want {
		t.Errorf("len(PostData.Params): got %d, want %d", got, want)
	}

	for _, p := range pd.Params {
		var exp Param

		switch p.Name {
		case "key":
			exp = Param{
				Filename:    "",
				ContentType: "",
				Value:       "value",
			}
		case "file":
			exp = Param{
				Filename:    "test.txt",
				ContentType: "application/octet-stream",
				Value:       "file contents",
			}
		default:
			t.Errorf("pd.Params: got %q, want not to be present", p.Name)
			continue
		}

		if got, want := p.Filename, exp.Filename; got != want {
			t.Errorf("p.Filename: got %q, want %q", got, want)
		}
		if got, want := p.ContentType, exp.ContentType; got != want {
			t.Errorf("p.ContentType: got %q, want %q", got, want)
		}
		if got, want := p.Value, exp.Value; got != want {
			t.Errorf("p.Value: got %q, want %q", got, want)
		}
	}
}
Example #16
0
// TestModifyRequest checks the HAR entry produced for a GET request with a
// query string, a repeated header and a cookie. It verifies the log metadata
// (version and creator), the start time, the request line, the body/header
// sizes, the query string, the headers (including the Cookie header added by
// AddCookie), and the parsed cookies.
func TestModifyRequest(t *testing.T) {
	req, err := http.NewRequest("GET", "http://example.com/path?query=true", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Add("Request-Header", "first")
	req.Header.Add("Request-Header", "second")

	cookie := &http.Cookie{
		Name:  "request",
		Value: "cookie",
	}
	req.AddCookie(cookie)

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	logger := NewLogger("martian", "2.0.0")
	if err := logger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	log := logger.Export().Log
	if got, want := log.Version, "1.2"; got != want {
		t.Errorf("log.Version: got %q, want %q", got, want)
	}
	if got, want := log.Creator.Name, "martian"; got != want {
		t.Errorf("log.Creator.Name: got %q, want %q", got, want)
	}
	if got, want := log.Creator.Version, "2.0.0"; got != want {
		t.Errorf("log.Creator.Version: got %q, want %q", got, want)
	}

	if got, want := len(log.Entries), 1; got != want {
		t.Fatalf("len(log.Entries): got %d, want %d", got, want)
	}

	entry := log.Entries[0]
	if got, want := time.Since(entry.StartedDateTime), time.Second; got > want {
		t.Errorf("entry.StartedDateTime: got %s, want less than %s", got, want)
	}

	hreq := entry.Request
	if got, want := hreq.Method, "GET"; got != want {
		t.Errorf("hreq.Method: got %q, want %q", got, want)
	}

	if got, want := hreq.URL, "http://example.com/path?query=true"; got != want {
		t.Errorf("hreq.URL: got %q, want %q", got, want)
	}

	if got, want := hreq.HTTPVersion, "HTTP/1.1"; got != want {
		t.Errorf("hreq.HTTPVersion: got %q, want %q", got, want)
	}

	if got, want := hreq.BodySize, int64(0); got != want {
		t.Errorf("hreq.BodySize: got %d, want %d", got, want)
	}

	if got, want := hreq.HeadersSize, int64(-1); got != want {
		t.Errorf("hreq.HeadersSize: got %d, want %d", got, want)
	}

	// Both values are ints; %q would render want as a quoted rune ('\x01').
	if got, want := len(hreq.QueryString), 1; got != want {
		t.Fatalf("len(hreq.QueryString): got %d, want %d", got, want)
	}

	qs := hreq.QueryString[0]
	if got, want := qs.Name, "query"; got != want {
		t.Errorf("qs.Name: got %q, want %q", got, want)
	}
	if got, want := qs.Value, "true"; got != want {
		t.Errorf("qs.Value: got %q, want %q", got, want)
	}

	if got, want := len(hreq.Headers), 2; got != want {
		t.Fatalf("len(hreq.Headers): got %d, want %d", got, want)
	}

	for _, h := range hreq.Headers {
		var want string
		switch h.Name {
		case "Request-Header":
			want = "first, second"
		case "Cookie":
			want = cookie.String()
		default:
			t.Errorf("hreq.Headers: got %q, want header to not be present", h.Name)
			continue
		}

		if got := h.Value; got != want {
			t.Errorf("hreq.Headers[%q]: got %q, want %q", h.Name, got, want)
		}
	}

	if got, want := len(hreq.Cookies), 1; got != want {
		t.Fatalf("len(hreq.Cookies): got %d, want %d", got, want)
	}

	hcookie := hreq.Cookies[0]
	if got, want := hcookie.Name, "request"; got != want {
		t.Errorf("hcookie.Name: got %q, want %q", got, want)
	}
	if got, want := hcookie.Value, "cookie"; got != want {
		t.Errorf("hcookie.Value: got %q, want %q", got, want)
	}
}
Example #17
0
// TestProxyAuthInvalidCredentials checks the failure path of Basic proxy
// authentication. The wrapped modifier records an auth error on both the
// request and the response. The error must be visible on the auth context
// after each step, and the response must become a 407 carrying a "Basic"
// Proxy-Authenticate challenge.
func TestProxyAuthInvalidCredentials(t *testing.T) {
	mod := NewModifier()
	authErr := errors.New("auth error")

	// Both hooks reject the credentials by setting authErr on the request's
	// auth context.
	rejectAuth := func(req *http.Request) {
		auth.FromContext(martian.Context(req)).SetError(authErr)
	}

	inner := martiantest.NewModifier()
	inner.RequestFunc(rejectAuth)
	inner.ResponseFunc(func(res *http.Response) {
		rejectAuth(res.Request)
	})

	mod.SetRequestModifier(inner)
	mod.SetResponseModifier(inner)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass"))

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	if err = mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if !inner.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}

	actx := auth.FromContext(sctx)
	if got := actx.Error(); got != authErr {
		t.Fatalf("auth.Error(): got %v, want %v", got, authErr)
	}
	actx.SetError(nil)

	res := proxyutil.NewResponse(200, nil, req)
	if err = mod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if !inner.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}

	if got := actx.Error(); got != authErr {
		t.Fatalf("auth.Error(): got %v, want %v", got, authErr)
	}

	if got, want := res.StatusCode, http.StatusProxyAuthRequired; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Proxy-Authenticate"), "Basic"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want)
	}
}
Example #18
0
// TestModifyResponse checks the HAR response produced for a 301 with a body,
// a repeated header, a Location header and a Set-Cookie header. It verifies
// the status line, the body content, the headers, and every attribute of the
// parsed cookie.
func TestModifyResponse(t *testing.T) {
	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}

	ctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, ctx)
	defer martian.RemoveContext(req)

	res := proxyutil.NewResponse(301, strings.NewReader("response body"), req)
	res.Header.Add("Response-Header", "first")
	res.Header.Add("Response-Header", "second")
	res.Header.Set("Location", "google.com")

	expires := time.Now()
	cookie := &http.Cookie{
		Name:     "response",
		Value:    "cookie",
		Path:     "/",
		Domain:   "example.com",
		Expires:  expires,
		Secure:   true,
		HttpOnly: true,
	}
	res.Header.Set("Set-Cookie", cookie.String())

	logger := NewLogger("martian", "2.0.0")

	if err := logger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if err := logger.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	log := logger.Export().Log
	if got, want := len(log.Entries), 1; got != want {
		t.Fatalf("len(log.Entries): got %d, want %d", got, want)
	}

	hres := log.Entries[0].Response
	if got, want := hres.Status, 301; got != want {
		t.Errorf("hres.Status: got %d, want %d", got, want)
	}

	if got, want := hres.StatusText, "Moved Permanently"; got != want {
		t.Errorf("hres.StatusText: got %q, want %q", got, want)
	}

	if got, want := hres.HTTPVersion, "HTTP/1.1"; got != want {
		t.Errorf("hres.HTTPVersion: got %q, want %q", got, want)
	}

	if got, want := hres.Content.Text, []byte("response body"); !bytes.Equal(got, want) {
		t.Errorf("hres.Content.Text: got %q, want %q", got, want)
	}

	if got, want := len(hres.Headers), 3; got != want {
		t.Fatalf("len(hres.Headers): got %d, want %d", got, want)
	}

	for _, h := range hres.Headers {
		var want string
		switch h.Name {
		case "Response-Header":
			want = "first, second"
		case "Location":
			want = "google.com"
		case "Set-Cookie":
			want = cookie.String()
		default:
			t.Errorf("hres.Headers: got %q, want header to not be present", h.Name)
			continue
		}

		if got := h.Value; got != want {
			t.Errorf("hres.Headers[%q]: got %q, want %q", h.Name, got, want)
		}
	}

	if got, want := len(hres.Cookies), 1; got != want {
		t.Fatalf("len(hres.Cookies): got %d, want %d", got, want)
	}

	hcookie := hres.Cookies[0]
	if got, want := hcookie.Name, "response"; got != want {
		t.Errorf("hcookie.Name: got %q, want %q", got, want)
	}
	if got, want := hcookie.Value, "cookie"; got != want {
		t.Errorf("hcookie.Value: got %q, want %q", got, want)
	}
	if got, want := hcookie.Path, "/"; got != want {
		t.Errorf("hcookie.Path: got %q, want %q", got, want)
	}
	if got, want := hcookie.Domain, "example.com"; got != want {
		t.Errorf("hcookie.Domain: got %q, want %q", got, want)
	}
	// The Set-Cookie header carries Expires in http.TimeFormat, which has
	// second precision, so compare against expires truncated to the second.
	// The original check was inverted and reported an error only on a match.
	if got, want := hcookie.Expires, expires.Truncate(time.Second); !got.Equal(want) {
		t.Errorf("hcookie.Expires: got %s, want %s", got, want)
	}
	if !hcookie.HTTPOnly {
		t.Error("hcookie.HTTPOnly: got false, want true")
	}
	if !hcookie.Secure {
		t.Error("hcookie.Secure: got false, want true")
	}
}
Example #19
0
// handle reads a single HTTP request from brw and serves it on conn.
//
// Requests that match a pattern registered on p.mux are treated as proxy
// configuration requests and served locally. For every other request, a child
// session context is derived from ctx and attached to the request for the
// duration of the call.
//
// CONNECT requests go through p.reqmod and are then handled in one of two
// ways. If p.mitm is set, the proxy answers 200, wraps conn in a TLS server
// and recurses on the TLS connection. Otherwise it tunnels bytes to the
// upstream connection returned by p.connect until both copy directions
// finish. All other requests go through p.reqmod, p.roundTrip and p.resmod
// before the response is written back to brw.
//
// handle returns errClose when the caller should stop reading from conn: the
// request could not be read, the client or the proxy is closing, or a CONNECT
// tunnel has finished. Other return values are:
//   - the error from session.FromContext;
//   - the result of flushing brw after a failed CONNECT;
//   - the result of the recursive call made after MITM;
//   - nil.
func (p *Proxy) handle(ctx *session.Context, conn net.Conn, brw *bufio.ReadWriter) error {
	log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())

	req, err := http.ReadRequest(brw.Reader)
	if err != nil {
		if isCloseable(err) {
			log.Debugf("martian: connection closed prematurely: %v", err)
		} else {
			log.Errorf("martian: failed to read request: %v", err)
		}

		// TODO: TCPConn.WriteClose() to avoid sending an RST to the client.

		return errClose
	}
	defer req.Body.Close()

	// Configuration requests are served by the proxy itself.
	if h, pattern := p.mux.Handler(req); pattern != "" {
		defer brw.Flush()

		closing := req.Close || p.Closing()

		log.Infof("martian: intercepted configuration request: %s", req.URL)
		rw := newResponseWriter(brw, closing)
		defer rw.Close()

		h.ServeHTTP(rw, req)

		// Call WriteHeader to ensure a response is sent, since the handler isn't
		// required to call WriteHeader/Write.
		rw.WriteHeader(200)

		if closing {
			return errClose
		}

		return nil
	}

	ctx, err = session.FromContext(ctx)
	if err != nil {
		log.Errorf("martian: failed to derive context: %v", err)
		return err
	}

	SetContext(req, ctx)
	defer RemoveContext(req)

	if tconn, ok := conn.(*tls.Conn); ok {
		ctx.GetSession().MarkSecure()

		cs := tconn.ConnectionState()
		req.TLS = &cs
	}

	req.URL.Scheme = "http"
	if ctx.GetSession().IsSecure() {
		log.Debugf("martian: forcing HTTPS inside secure session")
		req.URL.Scheme = "https"
	}

	req.RemoteAddr = conn.RemoteAddr().String()
	if req.URL.Host == "" {
		req.URL.Host = req.Host
	}

	log.Infof("martian: received request: %s", req.URL)

	if req.Method == "CONNECT" {
		if err := p.reqmod.ModifyRequest(req); err != nil {
			log.Errorf("martian: error modifying CONNECT request: %v", err)
			proxyutil.Warning(req.Header, err)
		}

		if p.mitm != nil {
			log.Debugf("martian: attempting MITM for connection: %s", req.Host)
			res := proxyutil.NewResponse(200, nil, req)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}

			res.Write(brw)
			brw.Flush()

			log.Debugf("martian: completed MITM for connection: %s", req.Host)

			// Re-point the buffered reader/writer at the TLS layer and serve
			// the decrypted requests on the same connection.
			tlsconn := tls.Server(conn, p.mitm.TLSForHost(req.Host))
			brw.Writer.Reset(tlsconn)
			brw.Reader.Reset(tlsconn)

			return p.handle(ctx, tlsconn, brw)
		}

		log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
		res, cconn, cerr := p.connect(req)
		if cerr != nil {
			// Log cerr, not err: err here is the (nil) result of
			// session.FromContext, so the real failure would be lost.
			log.Errorf("martian: failed to CONNECT: %v", cerr)
			res = proxyutil.NewResponse(502, nil, req)
			proxyutil.Warning(res.Header, cerr)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}

			res.Write(brw)
			return brw.Flush()
		}
		defer res.Body.Close()
		defer cconn.Close()

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}

		res.Write(brw)
		brw.Flush()

		cbw := bufio.NewWriter(cconn)
		cbr := bufio.NewReader(cconn)
		defer cbw.Flush()

		copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
			io.Copy(w, r)
			donec <- true
		}

		// Capacity 2 so neither copier blocks when it signals completion.
		donec := make(chan bool, 2)
		go copySync(cbw, brw, donec)
		go copySync(brw, cbr, donec)

		log.Debugf("martian: established CONNECT tunnel, proxying traffic")
		<-donec
		<-donec
		log.Debugf("martian: closed CONNECT tunnel")

		return errClose
	}

	if err := p.reqmod.ModifyRequest(req); err != nil {
		log.Errorf("martian: error modifying request: %v", err)
		proxyutil.Warning(req.Header, err)
	}

	res, err := p.roundTrip(ctx, req)
	if err != nil {
		log.Errorf("martian: failed to round trip: %v", err)
		res = proxyutil.NewResponse(502, nil, req)
		proxyutil.Warning(res.Header, err)
	}
	defer res.Body.Close()

	if err := p.resmod.ModifyResponse(res); err != nil {
		log.Errorf("martian: error modifying response: %v", err)
		proxyutil.Warning(res.Header, err)
	}

	var closing error
	if req.Close || p.Closing() {
		log.Debugf("martian: received close request: %v", req.RemoteAddr)
		res.Header.Add("Connection", "close")
		closing = errClose
	}

	log.Debugf("martian: sent response: %v", req.URL)
	res.Write(brw)
	brw.Flush()

	return closing
}
Example #20
0
// TestExportHandlerServeHTTP checks the HTTP handlers around the logger.
// The export handler returns the logged request/response pair as HAR JSON.
// After a DELETE is served by the reset handler, the export handler returns
// an empty log.
func TestExportHandlerServeHTTP(t *testing.T) {
	harLogger := NewLogger("martian", "2.0.0")

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	sctx, err := session.FromContext(nil)
	if err != nil {
		t.Fatalf("session.FromContext(): got %v, want no error", err)
	}
	martian.SetContext(req, sctx)
	defer martian.RemoveContext(req)

	if err = harLogger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if err = harLogger.ModifyResponse(proxyutil.NewResponse(200, nil, req)); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	exporter := NewExportHandler(harLogger)

	// fetchHAR issues a GET to the export handler, checks for a 200, and
	// decodes the HAR document from the response body.
	fetchHAR := func() *HAR {
		getReq, err := http.NewRequest("GET", "/", nil)
		if err != nil {
			t.Fatalf("http.NewRequest(): got %v, want no error", err)
		}

		rec := httptest.NewRecorder()
		exporter.ServeHTTP(rec, getReq)
		if got, want := rec.Code, http.StatusOK; got != want {
			t.Errorf("rw.Code: got %d, want %d", got, want)
		}

		doc := &HAR{}
		if err := json.Unmarshal(rec.Body.Bytes(), doc); err != nil {
			t.Fatalf("json.Unmarshal(): got %v, want no error", err)
		}
		return doc
	}

	exported := fetchHAR()
	if got, want := len(exported.Log.Entries), 1; got != want {
		t.Fatalf("len(hl.Log.Entries): got %v, want %v", got, want)
	}

	first := exported.Log.Entries[0]
	if got, want := first.Request.URL, "http://example.com"; got != want {
		t.Errorf("Request.URL: got %q, want %q", got, want)
	}
	if got, want := first.Response.Status, 200; got != want {
		t.Errorf("Response.Status: got %d, want %d", got, want)
	}

	resetter := NewResetHandler(harLogger)
	delReq, err := http.NewRequest("DELETE", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	resetter.ServeHTTP(httptest.NewRecorder(), delReq)

	if got, want := len(fetchHAR().Log.Entries), 0; got != want {
		t.Errorf("len(Log.Entries): got %v, want %v", got, want)
	}
}