예제 #1
0
// TestTraceCaptureHeaders checks that a tracer configured with specific
// request and response header names writes those headers into the trace record.
func (s *TraceSuite) TestTraceCaptureHeaders(c *C) {
	// Headers the backend replies with and the client sends, respectively.
	wantResp := http.Header{
		"X-Re-1": {"6", "7"},
		"X-Re-2": {"2", "3"},
	}
	wantReq := http.Header{
		"X-Req-A": {"1", "2"},
		"X-Req-B": {"3", "4"},
	}

	backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		utils.CopyHeaders(w.Header(), wantResp)
		w.Write([]byte("hello"))
	})

	// Separate sinks for the log output and for the JSON trace record.
	var logOut, traceOut bytes.Buffer
	logger := utils.NewFileLogger(&logOut, utils.INFO)

	tracer, err := New(backend, &traceOut, Logger(logger), RequestHeaders("X-Req-B", "X-Req-A"), ResponseHeaders("X-Re-1", "X-Re-2"))
	c.Assert(err, IsNil)

	server := httptest.NewServer(tracer)
	defer server.Close()

	resp, _, err := testutils.Get(server.URL+"/hello", testutils.Headers(wantReq))
	c.Assert(err, IsNil)
	c.Assert(resp.StatusCode, Equals, http.StatusOK)

	// The trace buffer is expected to hold a single JSON-encoded record.
	var rec *Record
	c.Assert(json.Unmarshal(traceOut.Bytes(), &rec), IsNil)

	c.Assert(rec.Request.Headers, DeepEquals, wantReq)
	c.Assert(rec.Response.Headers, DeepEquals, wantResp)
}
예제 #2
0
// ServeHTTP sends a copy of req through the forwarder's round tripper and
// relays the resulting status code, headers and body to w.
//
// A round-trip failure is logged and delegated to the error handler. A failure
// while streaming the body is logged only, because the status line has already
// been written by then.
func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// Keep the monotonic clock reading (no .UTC()) so the logged duration is
	// not affected by wall-clock adjustments.
	start := time.Now()
	response, err := f.roundTripper.RoundTrip(f.copyRequest(req, req.URL))
	if err != nil {
		f.log.Errorf("Error forwarding to %v, err: %v", req.URL, err)
		f.errHandler.ServeHTTP(w, req, err)
		return
	}
	// Close the backend body on every exit path, including a panic in the copy below.
	defer response.Body.Close()

	elapsed := time.Since(start)
	if req.TLS != nil {
		f.log.Infof("Round trip: %v, code: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
			req.URL, response.StatusCode, elapsed,
			req.TLS.Version,
			req.TLS.DidResume,
			req.TLS.CipherSuite,
			req.TLS.ServerName)
	} else {
		f.log.Infof("Round trip: %v, code: %v, duration: %v",
			req.URL, response.StatusCode, elapsed)
	}

	utils.CopyHeaders(w.Header(), response.Header)
	w.WriteHeader(response.StatusCode)
	written, err := io.Copy(w, response.Body)
	if err != nil {
		f.log.Errorf("Error copying response body from %v, err: %v", req.URL, err)
	}
	// NOTE(review): on a standard net/http ResponseWriter, header changes made
	// after WriteHeader have no effect; presumably this is meant for buffering
	// writers whose header map is read after the handler returns — confirm.
	if written != 0 {
		w.Header().Set(ContentLength, strconv.FormatInt(written, 10))
	}
}
예제 #3
0
// copyRequest builds the outgoing request for req, aimed at the scheme and
// host of u. The URL and header map are fresh copies; every other field starts
// as a shallow copy of req. The configured rewriter, if any, is applied last.
func (f *Forwarder) copyRequest(req *http.Request, u *url.URL) *http.Request {
	target := utils.CopyURL(req.URL)
	target.Scheme = u.Scheme
	target.Host = u.Host
	target.Opaque = req.RequestURI
	// RequestURI already carries the raw query; clear it to avoid duplicating it.
	target.RawQuery = ""

	headers := make(http.Header)
	utils.CopyHeaders(headers, req.Header)

	// Shallow copy first, then replace the reference-typed fields prepared above.
	clone := *req
	clone.URL = target
	clone.Header = headers

	// The client's Host header is forwarded only when passHost is set.
	if !f.passHost {
		clone.Host = u.Host
	}
	clone.Proto, clone.ProtoMajor, clone.ProtoMinor = "HTTP/1.1", 1, 1

	// Clear the close flag so connections to the backend servers can be kept alive.
	clone.Close = false

	if f.rewriter != nil {
		f.rewriter.Rewrite(&clone)
	}
	return &clone
}
예제 #4
0
// Headers returns a request option that copies h into the options' header
// set, creating that header set first when it is still nil.
func Headers(h http.Header) ReqOption {
	return func(opts *ReqOpts) error {
		dst := opts.Headers
		if dst == nil {
			dst = make(http.Header)
			opts.Headers = dst
		}
		utils.CopyHeaders(dst, h)
		return nil
	}
}
예제 #5
0
// ServeHTTP rewrites the request URL when it matches the handler's regexp and
// then either redirects the client or passes the request on to the next
// handler. When rewriteBody is set, the downstream response is buffered, run
// through Apply, and sent with a recomputed Content-Length.
//
// Requests that do not match the regexp go to the next handler untouched.
// Failures while expanding or parsing the rewritten URL, and failures while
// rewriting the response body, are reported through the error handler.
func (rw *rewriteHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	oldURL := rawURL(req)

	// only continue if the Regexp param matches the URL
	if !rw.regexp.MatchString(oldURL) {
		rw.next.ServeHTTP(w, req)
		return
	}

	// apply a rewrite regexp to the URL
	newURL := rw.regexp.ReplaceAllString(oldURL, rw.replacement)

	// replace any variables that may be in there
	rewrittenURL := &bytes.Buffer{}
	if err := ApplyString(newURL, rewrittenURL, req); err != nil {
		rw.errHandler.ServeHTTP(w, req, err)
		return
	}

	// parse the rewritten URL and replace request URL with it
	parsedURL, err := url.Parse(rewrittenURL.String())
	if err != nil {
		rw.errHandler.ServeHTTP(w, req, err)
		return
	}

	// Redirect only when the regexp replacement actually changed the URL.
	if rw.redirect && newURL != oldURL {
		(&redirectHandler{u: parsedURL}).ServeHTTP(w, req)
		return
	}

	req.URL = parsedURL

	// make sure the request URI corresponds to the rewritten URL
	req.RequestURI = req.URL.RequestURI()

	if !rw.rewriteBody {
		rw.next.ServeHTTP(w, req)
		return
	}

	// Capture the downstream response so its body can be rewritten before
	// anything reaches the client.
	bw := &bufferWriter{header: make(http.Header), buffer: &bytes.Buffer{}}
	newBody := &bytes.Buffer{}

	rw.next.ServeHTTP(bw, req)

	if err := Apply(bw.buffer, newBody, req); err != nil {
		log.Errorf("Failed to rewrite response body: %v", err)
		// Nothing has been written to w yet; report the failure instead of
		// returning silently, which would give the client an empty implicit 200.
		rw.errHandler.ServeHTTP(w, req, err)
		return
	}

	utils.CopyHeaders(w.Header(), bw.Header())
	// The body length may have changed, so the downstream Content-Length is replaced.
	w.Header().Set("Content-Length", strconv.Itoa(newBody.Len()))
	w.WriteHeader(bw.code)
	if _, err := io.Copy(w, newBody); err != nil {
		log.Errorf("Failed to write rewritten response body: %v", err)
	}
}
예제 #6
0
// copyRequest returns a copy of req that reads from body and declares
// bodySize as its content length. The URL and header map are fresh copies;
// the remaining fields are shallow copies of req.
func (s *Streamer) copyRequest(req *http.Request, body io.ReadCloser, bodySize int64) *http.Request {
	headers := make(http.Header)
	utils.CopyHeaders(headers, req.Header)

	out := new(http.Request)
	*out = *req
	out.URL = utils.CopyURL(req.URL)
	out.Header = headers
	out.ContentLength = bodySize
	// Drop any TransferEncoding set earlier: the request has been transformed
	// from chunked encoding into one with a known length.
	out.TransferEncoding = []string{}
	// http.Transport closes the request body on any error; the close process is
	// controlled here, so the closer is replaced with a no-op.
	out.Body = ioutil.NopCloser(body)
	return out
}
예제 #7
0
// MakeRequest performs a single HTTP request against url, configured by opts,
// and returns the response together with its fully read body.
//
// The method defaults to GET. Redirects are never followed: a redirect response
// surfaces as a non-nil error alongside the redirect response itself. Keep-alives
// are disabled, and for URLs starting with "https" certificate verification is
// skipped. The response body is closed before returning, so callers must use
// the returned byte slice rather than response.Body.
//
// An error is returned if an option fails, the request cannot be constructed,
// the round trip fails, or the body cannot be read.
func MakeRequest(url string, opts ...ReqOption) (*http.Response, []byte, error) {
	o := &ReqOpts{}
	for _, apply := range opts {
		if err := apply(o); err != nil {
			return nil, nil, err
		}
	}

	if o.Method == "" {
		o.Method = "GET"
	}
	// An invalid method or URL yields a nil request; fail here instead of
	// dereferencing it below.
	request, err := http.NewRequest(o.Method, url, strings.NewReader(o.Body))
	if err != nil {
		return nil, nil, err
	}
	if o.Headers != nil {
		utils.CopyHeaders(request.Header, o.Headers)
	}

	if o.Auth != nil {
		request.Header.Set("Authorization", o.Auth.String())
	}

	if len(o.Host) != 0 {
		request.Host = o.Host
	}

	tr := &http.Transport{DisableKeepAlives: true}
	if strings.HasPrefix(url, "https") {
		// NOTE(review): certificate verification is disabled for https URLs;
		// acceptable only for test traffic against self-signed servers.
		tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	}

	client := &http.Client{
		Transport: tr,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return fmt.Errorf("No redirects")
		},
	}
	response, err := client.Do(request)
	if err != nil {
		// When CheckRedirect fails, Do returns the redirect response (body
		// already closed) together with the error; pass both on.
		return response, nil, err
	}
	defer response.Body.Close()

	bodyBytes, err := ioutil.ReadAll(response.Body)
	return response, bodyBytes, err
}
예제 #8
0
// ServeHTTP buffers the request body, passes the request to the next handler
// with the response captured in a size-limited buffer, and replays the request
// while the retry predicate asks for another attempt. The buffered response of
// the final attempt is copied to w.
//
// Failures (body over limit, buffering errors, rewind errors) are reported
// through the error handler.
func (s *Streamer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if err := s.checkLimit(req); err != nil {
		s.log.Infof("request body over limit: %v", err)
		s.errHandler.ServeHTTP(w, req, err)
		return
	}

	// Read the body while keeping limits in mind. This reader controls the maximum bytes
	// to read into memory and disk. This reader returns an error if the total request size exceeds the
	// predefined MaxSizeBytes. This can occur if we got a chunked request; in this case ContentLength would be set to -1
	// and the reader would be an unbounded bufio in the http.Server
	body, err := multibuf.New(req.Body, multibuf.MaxBytes(s.maxRequestBodyBytes), multibuf.MemBytes(s.memRequestBodyBytes))
	if err != nil || body == nil {
		s.errHandler.ServeHTTP(w, req, err)
		return
	}

	// The buffered body can replay the read and execute Seek.
	// Note that we don't change the original request body as it's handled by the http server
	// and we don't want to mess with the standard library
	defer body.Close()

	// We need to set ContentLength based on known request size. The incoming request may have been
	// sent without content length or using chunked TransferEncoding
	totalSize, err := body.Size()
	if err != nil {
		s.log.Errorf("failed to get size, err %v", err)
		s.errHandler.ServeHTTP(w, req, err)
		return
	}

	outreq := s.copyRequest(req, body, totalSize)

	for attempt := 1; ; {
		// Each attempt runs in its own function call so that its response
		// buffers are released before the next attempt starts.
		if s.serveAttempt(w, req, outreq, attempt) {
			return
		}

		attempt++
		// Rewind the buffered body so the next attempt sends it from the start.
		if _, err := body.Seek(0, 0); err != nil {
			s.log.Errorf("Failed to rewind: error: %v", err)
			s.errHandler.ServeHTTP(w, req, err)
			return
		}
		outreq = s.copyRequest(req, body, totalSize)
		s.log.Infof("retry Request(%v %v) attempt %v", req.Method, req.URL, attempt)
	}
}

// serveAttempt runs one attempt of outreq against the next handler with the
// response captured in a size-limited buffer. It returns true when a response
// (the buffered one or an error) has been written to w, and false when the
// retry predicate requested another attempt. The attempt's buffers are closed
// before it returns.
func (s *Streamer) serveAttempt(w http.ResponseWriter, req, outreq *http.Request, attempt int) bool {
	// We create a special writer that will limit the response size, buffer it to disk if necessary
	writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(s.maxResponseBodyBytes), multibuf.MemBytes(s.memResponseBodyBytes))
	if err != nil {
		s.errHandler.ServeHTTP(w, req, err)
		return true
	}

	// We are mimicking http.ResponseWriter to replace writer with our special writer
	b := &bufferWriter{
		header: make(http.Header),
		buffer: writer,
	}
	defer b.Close()

	s.next.ServeHTTP(b, outreq)

	var reader multibuf.MultiReader
	if b.expectBody(outreq) {
		rdr, err := writer.Reader()
		if err != nil {
			s.log.Errorf("failed to read response, err %v", err)
			s.errHandler.ServeHTTP(w, req, err)
			return true
		}
		defer rdr.Close()
		reader = rdr
	}

	// Deliver the response when retries are disabled, the attempt budget is
	// used up, or the predicate does not ask for a retry.
	if (s.retryPredicate == nil || attempt > DefaultMaxRetryAttempts) ||
		!s.retryPredicate(&context{r: req, attempt: attempt, responseCode: b.code, log: s.log}) {
		utils.CopyHeaders(w.Header(), b.Header())
		w.WriteHeader(b.code)
		if reader != nil {
			io.Copy(w, reader)
		}
		return true
	}
	return false
}