コード例 #1
0
ファイル: martianhttp.go プロジェクト: haijianren/martian
// ServeHTTP installs a modifier supplied as JSON in the body of a POST
// request. The JSON is parsed and the resulting request and response
// modifiers replace the ones currently held by m. Any method other than
// POST is answered with 405 and an Allow header; unreadable bodies yield 500
// and unparsable JSON yields 400.
func (m *Modifier) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	// Guard: only POST carries a modifier payload.
	if method := req.Method; method != "POST" {
		rw.Header().Set("Allow", "POST")
		rw.WriteHeader(405)
		return
	}

	payload, readErr := ioutil.ReadAll(req.Body)
	if readErr != nil {
		http.Error(rw, readErr.Error(), 500)
		log.Errorf("error reading request body: %v", readErr)
		return
	}
	req.Body.Close()

	result, parseErr := parse.FromJSON(payload)
	if parseErr != nil {
		http.Error(rw, parseErr.Error(), 400)
		log.Errorf("error parsing JSON: %v", parseErr)
		return
	}

	m.SetRequestModifier(result.RequestModifier())
	m.SetResponseModifier(result.ResponseModifier())
}
コード例 #2
0
ファイル: proxy.go プロジェクト: rlugojr/martian
// handleLoop serves requests arriving on conn one after another until
// p.handle returns an error that isCloseable reports as terminal. Before each
// request the connection deadline is pushed p.timeout into the future. The
// connection is closed and p.conns released when the loop exits, including
// when the session or its context cannot be created.
func (p *Proxy) handleLoop(conn net.Conn) {
	p.conns.Add(1)
	defer p.conns.Done()
	defer conn.Close()

	sess, err := newSession()
	if err != nil {
		log.Errorf("martian: failed to create session: %v", err)
		return
	}

	ctx, err := withSession(sess)
	if err != nil {
		log.Errorf("martian: failed to create context: %v", err)
		return
	}

	reader := bufio.NewReader(conn)
	writer := bufio.NewWriter(conn)
	brw := bufio.NewReadWriter(reader, writer)

	for {
		conn.SetDeadline(time.Now().Add(p.timeout))
		if isCloseable(p.handle(ctx, conn, brw)) {
			break
		}
	}

	log.Debugf("martian: closing connection: %v", conn.RemoteAddr())
}
コード例 #3
0
ファイル: martianhttp.go プロジェクト: rlugojr/martian
func (m *Modifier) servePOST(rw http.ResponseWriter, req *http.Request) {
	body, err := ioutil.ReadAll(req.Body)
	if err != nil {
		http.Error(rw, err.Error(), 500)
		log.Errorf("martianhttp: error reading request body: %v", err)
		return
	}
	req.Body.Close()

	r, err := parse.FromJSON(body)
	if err != nil {
		http.Error(rw, err.Error(), 400)
		log.Errorf("martianhttp: error parsing JSON: %v", err)
		return
	}

	buf := new(bytes.Buffer)
	if err := json.Indent(buf, body, "", "  "); err != nil {
		http.Error(rw, err.Error(), 400)
		log.Errorf("martianhttp: error formatting JSON: %v", err)
		return
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	m.config = buf.Bytes()
	m.setRequestModifier(r.RequestModifier())
	m.setResponseModifier(r.ResponseModifier())
}
コード例 #4
0
ファイル: verify_handlers.go プロジェクト: haijianren/martian
// ServeHTTP writes out a JSON response containing a list of verification
// errors that occurred during the requests and responses sent to the proxy.
//
// Only GET is allowed; any other method receives 405 with an Allow header.
// Each configured verifier (h.reqv, h.resv) is consulted and its error, if
// any, is appended to the response.
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	rw.Header().Set("Content-Type", "application/json")

	if req.Method != "GET" {
		rw.Header().Set("Allow", "GET")
		rw.WriteHeader(http.StatusMethodNotAllowed)
		log.Errorf("verify: invalid request method: %s", req.Method)
		return
	}

	vres := &verifyResponse{}

	if h.reqv != nil {
		if err := h.reqv.VerifyRequests(); err != nil {
			appendError(vres, err)
		}
	}
	if h.resv != nil {
		if err := h.resv.VerifyResponses(); err != nil {
			appendError(vres, err)
		}
	}

	// The encoder streams into rw, so the implicit 200 status may already be
	// sent when an error surfaces; the failure can only be logged.
	if err := json.NewEncoder(rw).Encode(vres); err != nil {
		log.Errorf("verify: failed to encode response: %v", err)
	}
}
コード例 #5
0
ファイル: listener.go プロジェクト: shawnps/martian
// Write sends b over the underlying connection. The latency simulation is
// run through c.wonce, and the bytes are written in chunks no larger than
// the remaining capacity handed out by the write bucket c.wb. It returns the
// total number of bytes written and the first error encountered. io.EOF is
// returned without being logged; any other error is logged.
func (c *conn) Write(b []byte) (int, error) {
	c.wonce.Do(c.sleepLatency)

	var written int64
	for len(b) > 0 {
		// chunk records how much of b the callback attempted to write so
		// the loop can advance past it on success.
		var chunk int64

		n, err := c.wb.FillThrottle(func(remaining int64) (int64, error) {
			chunk = int64(len(b))
			if remaining < chunk {
				chunk = remaining
			}

			wn, werr := c.Conn.Write(b[:chunk])
			return int64(wn), werr
		})
		written += n

		if err == nil {
			b = b[chunk:]
			continue
		}

		if err != io.EOF {
			log.Errorf("trafficshape: failed write: %v", err)
		}
		return int(written), err
	}

	return int(written), nil
}
コード例 #6
0
ファイル: bucket.go プロジェクト: shawnps/martian
// Fill calls fn with the available capacity remaining (capacity-fill) and
// fills the bucket with the number of tokens returned by fn. If the remaining
// capacity is 0, Fill returns 0, nil. If the remaining capacity is < 0, Fill
// returns 0, ErrBucketOverflow.
//
// Fill returns the number of tokens reported by fn together with fn's error.
// Those tokens are added to the bucket even when the error is non-nil.
//
// fn is provided the remaining capacity as a soft maximum, fn is allowed to
// use more than the remaining capacity without incurring spillage, though this
// will cause subsequent calls to Fill to return ErrBucketOverflow until the
// next drain.
//
// If the bucket is closed when Fill is called, fn will not be executed and
// Fill will return with an error.
func (b *Bucket) Fill(fn func(int64) (int64, error)) (int64, error) {
	if b.closed() {
		log.Errorf("trafficshape: fill on closed bucket")
		return 0, errFillClosedBucket
	}

	fill := atomic.LoadInt64(&b.fill)
	capacity := atomic.LoadInt64(&b.capacity)

	switch {
	case fill < capacity:
		log.Debugf("trafficshape: under capacity (%d/%d)", fill, capacity)

		n, err := fn(capacity - fill)
		// Record the consumed tokens; the resulting fill level is not needed.
		atomic.AddInt64(&b.fill, n)

		return n, err
	case fill > capacity:
		log.Debugf("trafficshape: bucket overflow (%d/%d)", fill, capacity)

		return 0, ErrBucketOverflow
	}

	log.Debugf("trafficshape: bucket full (%d/%d)", fill, capacity)
	return 0, nil
}
コード例 #7
0
ファイル: handler.go プロジェクト: shawnps/martian
// ServeHTTP configures latency and bandwidth constraints.
//
// The "latency" query string parameter accepts a duration string in any format
// supported by time.ParseDuration.
// The "up" and "down" query string parameters accept integers as bits per
// second to be used for read and write throughput.
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	log.Debugf("trafficshape: configuration request")

	latency := req.FormValue("latency")
	if latency != "" {
		d, err := time.ParseDuration(latency)
		if err != nil {
			log.Errorf("trafficshape: invalid latency parameter: %v", err)
			http.Error(rw, fmt.Sprintf("invalid duration: %s", latency), 400)
			return
		}

		h.l.SetLatency(d)
	}

	up := req.FormValue("up")
	if up != "" {
		br, err := strconv.ParseInt(up, 10, 64)
		if err != nil {
			log.Errorf("trafficshape: invalid up parameter: %v", err)
			http.Error(rw, fmt.Sprintf("invalid upstream: %s", up), 400)
			return
		}

		h.l.SetWriteBitrate(br)
	}

	down := req.FormValue("down")
	if down != "" {
		br, err := strconv.ParseInt(down, 10, 64)
		if err != nil {
			log.Errorf("trafficshape: invalid down parameter: %v", err)
			http.Error(rw, fmt.Sprintf("invalid downstream: %s", down), 400)
			return
		}

		h.l.SetReadBitrate(br)
	}

	log.Debugf("trafficshape: configured successfully")
}
コード例 #8
0
ファイル: har_handlers.go プロジェクト: rlugojr/martian
// ServeHTTP clears all entries from the HAR log and replies 204 No Content.
// Only POST and DELETE are accepted. Other methods get 405 with one Allow
// header per permitted method.
func (h *resetHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	switch req.Method {
	case "POST", "DELETE":
		h.logger.Reset()
		rw.WriteHeader(http.StatusNoContent)
	default:
		hdr := rw.Header()
		hdr.Add("Allow", "POST")
		hdr.Add("Allow", "DELETE")
		rw.WriteHeader(http.StatusMethodNotAllowed)
		log.Errorf("har: method not allowed: %s", req.Method)
	}
}
コード例 #9
0
ファイル: har_handlers.go プロジェクト: rlugojr/martian
// ServeHTTP writes the log in HAR format to the response body as JSON. Only
// GET is allowed; other methods receive 405 with an Allow header.
func (h *exportHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	if req.Method != "GET" {
		rw.Header().Add("Allow", "GET")
		rw.WriteHeader(http.StatusMethodNotAllowed)
		log.Errorf("har: method not allowed: %s", req.Method)
		return
	}
	rw.Header().Set("Content-Type", "application/json; charset=utf-8")

	hl := h.logger.Export()
	// The encoder streams into rw, so the implicit 200 status may already be
	// sent when an error surfaces; the failure can only be logged.
	if err := json.NewEncoder(rw).Encode(hl); err != nil {
		log.Errorf("har: error encoding HAR log: %v", err)
	}
}
コード例 #10
0
ファイル: verify_handlers.go プロジェクト: haijianren/martian
// ServeHTTP clears the state of the configured request and response
// verifiers so that verification can run again, then replies 204. Only POST
// is accepted; other methods get 405 with an Allow header.
func (h *ResetHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	if req.Method == "POST" {
		if h.reqv != nil {
			h.reqv.ResetRequestVerifications()
		}
		if h.resv != nil {
			h.resv.ResetResponseVerifications()
		}
		rw.WriteHeader(204)
		return
	}

	rw.Header().Set("Allow", "POST")
	rw.WriteHeader(405)
	log.Errorf("verify: invalid request method: %s", req.Method)
}
コード例 #11
0
ファイル: listener.go プロジェクト: shawnps/martian
// Read fills b from the underlying connection. It runs the latency
// simulation through c.ronce and limits the read to the remaining capacity
// granted by the read bucket c.rb. Errors other than io.EOF are logged.
// Both the byte count and the error are returned to the caller unchanged.
func (c *conn) Read(b []byte) (int, error) {
	c.ronce.Do(c.sleepLatency)

	readChunk := func(budget int64) (int64, error) {
		// Never ask for more than b can hold.
		if want := int64(len(b)); want < budget {
			budget = want
		}

		got, err := c.Conn.Read(b[:budget])
		return int64(got), err
	}

	n, err := c.rb.FillThrottle(readChunk)
	if err != nil && err != io.EOF {
		log.Errorf("trafficshape: error on throttled read: %v", err)
	}

	return int(n), err
}
コード例 #12
0
ファイル: listener.go プロジェクト: shawnps/martian
// WriteTo writes data to w from the connection until the connection is
// exhausted or an error occurs. It optionally simulates connection latency
// and throttles throughput based on desired bandwidth constraints.
//
// As required by io.WriterTo, which io.Copy delegates to verbatim, reaching
// EOF on the connection is a successful completion. WriteTo then returns the
// byte count and a nil error. Any other error is logged and returned along
// with the number of bytes copied so far.
//
// NOTE(review): this method drains data *from* the connection yet uses the
// write bucket (c.wb / c.wonce), while Read uses c.rb / c.ronce. Confirm
// that this direction is intended.
func (c *conn) WriteTo(w io.Writer) (int64, error) {
	c.wonce.Do(c.sleepLatency)

	var total int64
	for {
		n, err := c.wb.FillThrottle(func(remaining int64) (int64, error) {
			return io.CopyN(w, c.Conn, remaining)
		})

		total += n

		switch {
		case err == io.EOF:
			log.Debugf("trafficshape: exhausted connection successfully")
			return total, nil
		case err != nil:
			log.Errorf("trafficshape: failed copying to writer: %v", err)
			return total, err
		}
	}
}
コード例 #13
0
ファイル: listener.go プロジェクト: shawnps/martian
// ReadFrom copies everything from r into the underlying connection. The
// latency simulation runs through c.ronce, and each chunk is capped by the
// remaining capacity granted by c.rb. Hitting EOF on r counts as success and
// yields a nil error. Any other error is logged and returned together with
// the number of bytes copied so far.
func (c *conn) ReadFrom(r io.Reader) (int64, error) {
	c.ronce.Do(c.sleepLatency)

	copyChunk := func(remaining int64) (int64, error) {
		return io.CopyN(c.Conn, r, remaining)
	}

	var copied int64
	for {
		n, err := c.rb.FillThrottle(copyChunk)
		copied += n

		switch err {
		case nil:
			continue
		case io.EOF:
			log.Debugf("trafficshape: exhausted reader successfully")
			return copied, nil
		default:
			log.Errorf("trafficshape: failed copying from reader: %v", err)
			return copied, err
		}
	}
}
コード例 #14
0
ファイル: proxy.go プロジェクト: rlugojr/martian
// Serve accepts connections from l and handles each one on its own
// goroutine via p.handleLoop, until p.Closing() reports true. When that
// happens Serve returns nil.
//
// Temporary accept errors are retried with exponential backoff, starting at
// 5ms and capped at one second. Any other accept error is logged and
// returned. TCP connections get keep-alive enabled with a three-minute
// period. l is closed when Serve returns.
func (p *Proxy) Serve(l net.Listener) error {
	defer l.Close()

	const (
		minDelay = 5 * time.Millisecond
		maxDelay = time.Second
	)

	var backoff time.Duration
	for !p.Closing() {
		conn, err := l.Accept()
		if err != nil {
			nerr, isNet := err.(net.Error)
			if !isNet || !nerr.Temporary() {
				log.Errorf("martian: failed to accept: %v", err)
				return err
			}

			if backoff == 0 {
				backoff = minDelay
			} else {
				backoff *= 2
			}
			if backoff > maxDelay {
				backoff = maxDelay
			}

			log.Debugf("martian: temporary error on accept: %v", err)
			time.Sleep(backoff)
			continue
		}

		backoff = 0
		log.Debugf("martian: accepted connection from %s", conn.RemoteAddr())

		if tcp, isTCP := conn.(*net.TCPConn); isTCP {
			tcp.SetKeepAlive(true)
			tcp.SetKeepAlivePeriod(3 * time.Minute)
		}

		go p.handleLoop(conn)
	}

	return nil
}
コード例 #15
0
ファイル: listener.go プロジェクト: shawnps/martian
// Accept blocks until the wrapped listener yields a connection. It then
// wraps that connection in a *conn carrying the listener's current latency
// and its shared read/write buckets. TCP connections additionally get
// keep-alive enabled with a three-minute period. Accept errors are logged
// and returned with a nil connection.
func (l *Listener) Accept() (net.Conn, error) {
	raw, err := l.Listener.Accept()
	if err != nil {
		log.Errorf("trafficshape: failed accepting connection: %v", err)
		return nil, err
	}

	if tcp, isTCP := raw.(*net.TCPConn); isTCP {
		log.Debugf("trafficshape: setting keep-alive for TCP connection")
		tcp.SetKeepAlive(true)
		tcp.SetKeepAlivePeriod(3 * time.Minute)
	}

	return &conn{
		Conn:    raw,
		latency: l.Latency(),
		rb:      l.rb,
		wb:      l.wb,
	}, nil
}
コード例 #16
0
ファイル: bucket.go プロジェクト: shawnps/martian
// FillThrottle calls fn with the available capacity remaining (capacity-fill)
// and fills the bucket with the number of tokens returned by fn. If the
// remaining capacity is <= 0, FillThrottle will wait for the next drain before
// running fn.
//
// If fn returns an error, it will be returned by FillThrottle along with the
// number of tokens processed by fn.
//
// fn is provided the remaining capacity as a soft maximum, fn is allowed to
// use more than the remaining capacity without incurring spillage.
//
// If the bucket is closed when FillThrottle is called, or while waiting for
// the next drain, fn will not be executed and FillThrottle will return with an
// error.
func (b *Bucket) FillThrottle(fn func(int64) (int64, error)) (int64, error) {
	loggedFull := false
	for {
		if b.closed() {
			log.Errorf("trafficshape: fill on closed bucket")
			return 0, errFillClosedBucket
		}

		fill := atomic.LoadInt64(&b.fill)
		capacity := atomic.LoadInt64(&b.capacity)

		if fill < capacity {
			log.Debugf("trafficshape: under capacity (%d/%d)", fill, capacity)

			n, err := fn(capacity - fill)
			// Record the consumed tokens; the resulting fill level is not needed.
			atomic.AddInt64(&b.fill, n)

			return n, err
		}

		// NOTE(review): nothing visible in this loop blocks (assuming
		// b.closed() is a non-blocking check). The "wait" is a busy poll of
		// the counters until a drain lowers fill below capacity or the bucket
		// closes. Report the full condition once per call instead of on every
		// spin, which would flood the debug log.
		if !loggedFull {
			log.Debugf("trafficshape: bucket full (%d/%d)", fill, capacity)
			loggedFull = true
		}
	}
}
コード例 #17
0
ファイル: har.go プロジェクト: shawnps/martian
func postData(req *http.Request) (*PostData, error) {
	// If the request has no body (no Content-Length and Transfer-Encoding isn't
	// chunked), skip the post data.
	if req.ContentLength <= 0 && len(req.TransferEncoding) == 0 {
		return nil, nil
	}

	ct := req.Header.Get("Content-Type")
	mt, ps, err := mime.ParseMediaType(ct)
	if err != nil {
		log.Errorf("har: cannot parse Content-Type header %q: %v", ct, err)
		mt = ct
	}

	pd := &PostData{
		MimeType: mt,
		Params:   []Param{},
	}

	mv := messageview.New()
	if err := mv.SnapshotRequest(req); err != nil {
		return nil, err
	}

	br, err := mv.BodyReader()
	if err != nil {
		return nil, err
	}

	switch mt {
	case "multipart/form-data":
		mpr := multipart.NewReader(br, ps["boundary"])

		for {
			p, err := mpr.NextPart()
			if err == io.EOF {
				break
			}
			if err != nil {
				return nil, err
			}
			defer p.Close()

			body, err := ioutil.ReadAll(p)
			if err != nil {
				return nil, err
			}

			pd.Params = append(pd.Params, Param{
				Name:        p.FormName(),
				Filename:    p.FileName(),
				ContentType: p.Header.Get("Content-Type"),
				Value:       string(body),
			})
		}
	case "application/x-www-form-urlencoded":
		body, err := ioutil.ReadAll(br)
		if err != nil {
			return nil, err
		}

		vs, err := url.ParseQuery(string(body))
		if err != nil {
			return nil, err
		}

		for n, vs := range vs {
			for _, v := range vs {
				pd.Params = append(pd.Params, Param{
					Name:  n,
					Value: v,
				})
			}
		}
	default:
		body, err := ioutil.ReadAll(br)
		if err != nil {
			return nil, err
		}

		pd.Text = string(body)
	}

	return pd, nil
}
コード例 #18
0
ファイル: proxy.go プロジェクト: rlugojr/martian
// handle reads a single request from brw and serves it.
//
// Requests that match a pattern registered on p.mux are answered locally by
// the configuration handler. CONNECT requests are handled in one of two ways:
//   - When p.mitm is set, they are MITM'd. The first byte sent after the 200
//     response decides whether a TLS handshake is performed before the next
//     request is read.
//   - Otherwise they are tunnelled to the upstream connection returned by
//     p.connect.
//
// All other requests are modified by p.reqmod and sent with p.roundTrip. The
// response is modified by p.resmod and then written back to brw.
//
// errClose is returned when the connection should not be reused; nil means
// another request may be read from it. Other errors come from building the
// session context, the MITM TLS handshake, or flushing a failed CONNECT
// response.
func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error {
	log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())

	req, err := http.ReadRequest(brw.Reader)
	if err != nil {
		if isCloseable(err) {
			log.Debugf("martian: connection closed prematurely: %v", err)
		} else {
			log.Errorf("martian: failed to read request: %v", err)
		}

		// TODO: TCPConn.WriteClose() to avoid sending an RST to the client.

		return errClose
	}
	defer req.Body.Close()

	if h, pattern := p.mux.Handler(req); pattern != "" {
		defer brw.Flush()

		closing := req.Close || p.Closing()

		log.Infof("martian: intercepted configuration request: %s", req.URL)
		rw := newResponseWriter(brw, closing)
		defer rw.Close()

		h.ServeHTTP(rw, req)

		// Call WriteHeader to ensure a response is sent, since the handler isn't
		// required to call WriteHeader/Write.
		rw.WriteHeader(200)

		if closing {
			return errClose
		}

		return nil
	}

	ctx, err = withSession(ctx.Session())
	if err != nil {
		log.Errorf("martian: failed to build new context: %v", err)
		return err
	}

	link(req, ctx)
	defer unlink(req)

	if tconn, ok := conn.(*tls.Conn); ok {
		ctx.Session().MarkSecure()

		cs := tconn.ConnectionState()
		req.TLS = &cs
	}

	req.URL.Scheme = "http"
	if ctx.Session().IsSecure() {
		log.Debugf("martian: forcing HTTPS inside secure session")
		req.URL.Scheme = "https"
	}

	req.RemoteAddr = conn.RemoteAddr().String()
	if req.URL.Host == "" {
		req.URL.Host = req.Host
	}

	log.Infof("martian: received request: %s", req.URL)

	if req.Method == "CONNECT" {
		if err := p.reqmod.ModifyRequest(req); err != nil {
			log.Errorf("martian: error modifying CONNECT request: %v", err)
			proxyutil.Warning(req.Header, err)
		}

		if p.mitm != nil {
			log.Debugf("martian: attempting MITM for connection: %s", req.Host)
			res := proxyutil.NewResponse(200, nil, req)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}

			if err := res.Write(brw); err != nil {
				log.Errorf("martian: got error while writing response back to client: %v", err)
			}
			if err := brw.Flush(); err != nil {
				log.Errorf("martian: got error while flushing response back to client: %v", err)
			}

			log.Debugf("martian: completed MITM for connection: %s", req.Host)

			b := make([]byte, 1)
			if _, err := brw.Read(b); err != nil {
				log.Errorf("martian: error peeking message through CONNECT tunnel to determine type: %v", err)
				// Nothing was read, so b[0] would be a meaningless zero byte.
				// The tunnel is unusable; do not replay garbage into it.
				return errClose
			}

			// Drain all of the rest of the buffered data. Reading exactly the
			// buffered length is served from the buffer without further I/O.
			buf := make([]byte, brw.Reader.Buffered())
			brw.Read(buf)

			// 22 is the TLS handshake.
			// https://tools.ietf.org/html/rfc5246#section-6.2.1
			if b[0] == 22 {
				// Prepend the previously read data to be read again by
				// http.ReadRequest.
				tlsconn := tls.Server(&peekedConn{conn, io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)}, p.mitm.TLSForHost(req.Host))

				if err := tlsconn.Handshake(); err != nil {
					p.mitm.HandshakeErrorCallback(req, err)
					return err
				}

				brw.Writer.Reset(tlsconn)
				brw.Reader.Reset(tlsconn)

				return p.handle(ctx, tlsconn, brw)
			}

			// Prepend the previously read data to be read again by http.ReadRequest.
			brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn))
			return p.handle(ctx, conn, brw)
		}

		log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
		res, cconn, cerr := p.connect(req)
		if cerr != nil {
			log.Errorf("martian: failed to CONNECT: %v", cerr)
			res = proxyutil.NewResponse(502, nil, req)
			proxyutil.Warning(res.Header, cerr)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}

			if err := res.Write(brw); err != nil {
				log.Errorf("martian: got error while writing response back to client: %v", err)
			}
			err := brw.Flush()
			if err != nil {
				log.Errorf("martian: got error while flushing response back to client: %v", err)
			}
			return err
		}
		defer res.Body.Close()
		defer cconn.Close()

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}

		if err := res.Write(brw); err != nil {
			log.Errorf("martian: got error while writing response back to client: %v", err)
		}
		if err := brw.Flush(); err != nil {
			log.Errorf("martian: got error while flushing response back to client: %v", err)
		}

		cbw := bufio.NewWriter(cconn)
		cbr := bufio.NewReader(cconn)
		defer cbw.Flush()

		copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
			if _, err := io.Copy(w, r); err != nil && err != io.EOF {
				log.Errorf("martian: failed to copy CONNECT tunnel: %v", err)
			}

			log.Debugf("martian: CONNECT tunnel finished copying")
			donec <- true
		}

		donec := make(chan bool, 2)
		go copySync(cbw, brw, donec)
		go copySync(brw, cbr, donec)

		log.Debugf("martian: established CONNECT tunnel, proxying traffic")
		<-donec
		<-donec
		log.Debugf("martian: closed CONNECT tunnel")

		return errClose
	}

	if err := p.reqmod.ModifyRequest(req); err != nil {
		log.Errorf("martian: error modifying request: %v", err)
		proxyutil.Warning(req.Header, err)
	}

	res, err := p.roundTrip(ctx, req)
	if err != nil {
		log.Errorf("martian: failed to round trip: %v", err)
		res = proxyutil.NewResponse(502, nil, req)
		proxyutil.Warning(res.Header, err)
	}
	defer res.Body.Close()

	if err := p.resmod.ModifyResponse(res); err != nil {
		log.Errorf("martian: error modifying response: %v", err)
		proxyutil.Warning(res.Header, err)
	}

	var closing error
	if req.Close || res.Close || p.Closing() {
		log.Debugf("martian: received close request: %v", req.RemoteAddr)
		res.Close = true
		closing = errClose
	}

	log.Debugf("martian: sent response: %v", req.URL)
	if err := res.Write(brw); err != nil {
		log.Errorf("martian: got error while writing response back to client: %v", err)
	}
	if err := brw.Flush(); err != nil {
		log.Errorf("martian: got error while flushing response back to client: %v", err)
	}

	return closing
}
コード例 #19
0
ファイル: proxy.go プロジェクト: xichen124/martian
// handle reads a single request from brw and serves it.
//
// Requests that match a pattern registered on p.mux are answered locally by
// the configuration handler. CONNECT requests are either MITM'd over TLS
// (when p.mitm is set) or tunnelled to the upstream connection returned by
// p.connect. All other requests are modified by p.reqmod and sent with
// p.roundTrip. The response is modified by p.resmod and written back to brw.
//
// errClose is returned when the connection should not be reused; nil means
// another request may be read from it. Failures while writing or flushing
// responses to the client are logged instead of being silently dropped.
func (p *Proxy) handle(ctx *session.Context, conn net.Conn, brw *bufio.ReadWriter) error {
	log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())

	req, err := http.ReadRequest(brw.Reader)
	if err != nil {
		if isCloseable(err) {
			log.Debugf("martian: connection closed prematurely: %v", err)
		} else {
			log.Errorf("martian: failed to read request: %v", err)
		}

		// TODO: TCPConn.WriteClose() to avoid sending an RST to the client.

		return errClose
	}
	defer req.Body.Close()

	if h, pattern := p.mux.Handler(req); pattern != "" {
		defer brw.Flush()

		closing := req.Close || p.Closing()

		log.Infof("martian: intercepted configuration request: %s", req.URL)
		rw := newResponseWriter(brw, closing)
		defer rw.Close()

		h.ServeHTTP(rw, req)

		// Call WriteHeader to ensure a response is sent, since the handler isn't
		// required to call WriteHeader/Write.
		rw.WriteHeader(200)

		if closing {
			return errClose
		}

		return nil
	}

	ctx, err = session.FromContext(ctx)
	if err != nil {
		log.Errorf("martian: failed to derive context: %v", err)
		return err
	}

	SetContext(req, ctx)
	defer RemoveContext(req)

	if tconn, ok := conn.(*tls.Conn); ok {
		ctx.GetSession().MarkSecure()

		cs := tconn.ConnectionState()
		req.TLS = &cs
	}

	req.URL.Scheme = "http"
	if ctx.GetSession().IsSecure() {
		log.Debugf("martian: forcing HTTPS inside secure session")
		req.URL.Scheme = "https"
	}

	req.RemoteAddr = conn.RemoteAddr().String()
	if req.URL.Host == "" {
		req.URL.Host = req.Host
	}

	log.Infof("martian: received request: %s", req.URL)

	if req.Method == "CONNECT" {
		if err := p.reqmod.ModifyRequest(req); err != nil {
			log.Errorf("martian: error modifying CONNECT request: %v", err)
			proxyutil.Warning(req.Header, err)
		}

		if p.mitm != nil {
			log.Debugf("martian: attempting MITM for connection: %s", req.Host)
			res := proxyutil.NewResponse(200, nil, req)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}

			if err := res.Write(brw); err != nil {
				log.Errorf("martian: got error while writing response back to client: %v", err)
			}
			if err := brw.Flush(); err != nil {
				log.Errorf("martian: got error while flushing response back to client: %v", err)
			}

			log.Debugf("martian: completed MITM for connection: %s", req.Host)

			tlsconn := tls.Server(conn, p.mitm.TLSForHost(req.Host))
			brw.Writer.Reset(tlsconn)
			brw.Reader.Reset(tlsconn)

			return p.handle(ctx, tlsconn, brw)
		}

		log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
		res, cconn, cerr := p.connect(req)
		if cerr != nil {
			log.Errorf("martian: failed to CONNECT: %v", cerr)
			res = proxyutil.NewResponse(502, nil, req)
			proxyutil.Warning(res.Header, cerr)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}

			if err := res.Write(brw); err != nil {
				log.Errorf("martian: got error while writing response back to client: %v", err)
			}
			err := brw.Flush()
			if err != nil {
				log.Errorf("martian: got error while flushing response back to client: %v", err)
			}
			return err
		}
		defer res.Body.Close()
		defer cconn.Close()

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}

		if err := res.Write(brw); err != nil {
			log.Errorf("martian: got error while writing response back to client: %v", err)
		}
		if err := brw.Flush(); err != nil {
			log.Errorf("martian: got error while flushing response back to client: %v", err)
		}

		cbw := bufio.NewWriter(cconn)
		cbr := bufio.NewReader(cconn)
		defer cbw.Flush()

		copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
			if _, err := io.Copy(w, r); err != nil && err != io.EOF {
				log.Errorf("martian: failed to copy CONNECT tunnel: %v", err)
			}
			donec <- true
		}

		donec := make(chan bool, 2)
		go copySync(cbw, brw, donec)
		go copySync(brw, cbr, donec)

		log.Debugf("martian: established CONNECT tunnel, proxying traffic")
		<-donec
		<-donec
		log.Debugf("martian: closed CONNECT tunnel")

		return errClose
	}

	if err := p.reqmod.ModifyRequest(req); err != nil {
		log.Errorf("martian: error modifying request: %v", err)
		proxyutil.Warning(req.Header, err)
	}

	res, err := p.roundTrip(ctx, req)
	if err != nil {
		log.Errorf("martian: failed to round trip: %v", err)
		res = proxyutil.NewResponse(502, nil, req)
		proxyutil.Warning(res.Header, err)
	}
	defer res.Body.Close()

	if err := p.resmod.ModifyResponse(res); err != nil {
		log.Errorf("martian: error modifying response: %v", err)
		proxyutil.Warning(res.Header, err)
	}

	var closing error
	if req.Close || p.Closing() {
		log.Debugf("martian: received close request: %v", req.RemoteAddr)
		res.Header.Add("Connection", "close")
		closing = errClose
	}

	log.Debugf("martian: sent response: %v", req.URL)
	if err := res.Write(brw); err != nil {
		log.Errorf("martian: got error while writing response back to client: %v", err)
	}
	if err := brw.Flush(); err != nil {
		log.Errorf("martian: got error while flushing response back to client: %v", err)
	}

	return closing
}
コード例 #20
0
ファイル: proxy_test.go プロジェクト: xichen124/martian
// TestIntegrationHTTP100Continue checks that a request sent through the
// proxy with "Expect: 100-continue" reaches a backend that answers
// 100 Continue. The body, sent one second after the headers, must then be
// echoed back in a 200 response. Both the request and the response must
// pass through the configured modifiers.
func TestIntegrationHTTP100Continue(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	p.SetTimeout(2 * time.Second)

	sl, err := net.Listen("tcp", "[::1]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	// The backend listener is owned by this test and must be released.
	defer sl.Close()

	go func() {
		conn, err := sl.Accept()
		if err != nil {
			log.Errorf("proxy_test: failed to accept connection: %v", err)
			return
		}
		defer conn.Close()

		log.Infof("proxy_test: accepted connection: %s", conn.RemoteAddr())

		req, err := http.ReadRequest(bufio.NewReader(conn))
		if err != nil {
			log.Errorf("proxy_test: failed to read request: %v", err)
			return
		}

		if req.Header.Get("Expect") == "100-continue" {
			log.Infof("proxy_test: received 100-continue request")

			conn.Write([]byte("HTTP/1.1 100 Continue\r\n\r\n"))

			log.Infof("proxy_test: sent 100-continue response")
		} else {
			log.Infof("proxy_test: received non 100-continue request")

			res := proxyutil.NewResponse(417, nil, req)
			res.Header.Set("Connection", "close")
			res.Write(conn)
			return
		}

		res := proxyutil.NewResponse(200, req.Body, req)
		res.Header.Set("Connection", "close")
		res.Write(conn)

		log.Infof("proxy_test: sent 200 response")
	}()

	tm := martiantest.NewModifier()
	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	host := sl.Addr().String()
	raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+
		"Host: %s\r\n"+
		"Content-Length: 12\r\n"+
		"Expect: 100-continue\r\n\r\n", host, host)

	if _, err := conn.Write([]byte(raw)); err != nil {
		t.Fatalf("conn.Write(headers): got %v, want no error", err)
	}

	// Delay the body so the backend's 100 Continue is exercised first.
	go func() {
		time.Sleep(time.Second)
		conn.Write([]byte("body content"))
	}()

	res, err := http.ReadResponse(bufio.NewReader(conn), nil)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	got, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
	}

	if want := []byte("body content"); !bytes.Equal(got, want) {
		t.Errorf("res.Body: got %q, want %q", got, want)
	}

	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}