예제 #1
1
// RoundTrip proxies the request to one of this location's endpoints and
// returns the response.
//
// The request body is first buffered according to the location's limits so
// that it can be replayed. Each iteration rewinds the body, asks the load
// balancer for an endpoint, rewrites a fresh copy of the original request for
// that endpoint and proxies it. If o.ShouldFailover reports true after an
// attempt, that attempt's response is closed and the loop tries again.
// Otherwise the attempt's response and error are returned as-is. The loop
// ends only through those returns or an error from rewinding/load balancing,
// so the number of retries is bounded by ShouldFailover and the load balancer.
func (l *HttpLocation) RoundTrip(req request.Request) (*http.Response, error) {
	// Get options and transport as one single read transaction.
	// Options and transport may change if someone calls SetOptions
	o, tr := l.GetOptionsAndTransport()
	originalRequest := req.GetHttpRequest()

	// Check request size first, if that exceeds the limit, we don't bother reading the request.
	if l.isRequestOverLimit(req) {
		return nil, errors.FromStatus(http.StatusRequestEntityTooLarge)
	}

	// Read the body while keeping this location's limits in mind. This reader controls the maximum bytes
	// to read into memory and disk. This reader returns an error if the total request size exceeds the
	// predefined MaxSizeBytes. This can occur if we got a chunked request; in this case ContentLength is
	// set to -1 and the reader would be an unbounded bufio in the http.Server.
	body, err := netutils.NewBodyBufferWithOptions(originalRequest.Body, netutils.BodyBufferOptions{
		MemBufferBytes: o.Limits.MaxMemBodyBytes,
		MaxSizeBytes:   o.Limits.MaxBodyBytes,
	})
	if err != nil {
		return nil, err
	}
	if body == nil {
		return nil, fmt.Errorf("Empty body")
	}

	// Set request body to buffered reader that can replay the read and execute Seek.
	// Note that we don't change the original request Body as it's handled by the http server.
	req.SetBody(body)
	defer body.Close()

	for {
		// Rewind the buffered body (offset 0, whence 0 = from the start) so
		// every attempt sends it in full.
		if _, err := req.GetBody().Seek(0, 0); err != nil {
			return nil, err
		}

		endpoint, err := l.loadBalancer.NextEndpoint(req)
		if err != nil {
			log.Errorf("Load Balancer failure: %s", err)
			return nil, err
		}

		// Adds headers, changes urls. Note that we rewrite request each time we proxy it to the
		// endpoint, so that each try gets a fresh start
		req.SetHttpRequest(l.copyRequest(originalRequest, req.GetBody(), endpoint))

		// ShouldFailover decides whether this attempt is final; if it is, its
		// response and error go straight back to the caller.
		response, err := l.proxyToEndpoint(tr, &o, endpoint, req)
		if !o.ShouldFailover(req) {
			return response, err
		}

		// We are about to retry and drop this attempt's response. Close its
		// body so the connection it holds is released instead of leaked; a
		// close error is not actionable for a response nobody will read.
		if response != nil && response.Body != nil {
			response.Body.Close()
		}
	}
}
예제 #2
0
// proxyToEndpoint proxies the request to the given endpoint, running the
// observer chain around the attempt and the middleware chain before it.
//
// Middlewares are asked in order to process the request. The first one that
// returns a non-nil response or a non-nil error short-circuits the chain, and
// its result is returned without contacting the endpoint. Otherwise the
// request is sent through tr and the round-trip duration is stored on the
// attempt. On return, the deferred calls run in reverse order: the middleware
// iterator is unwound, the attempt is added to req, then ObserveResponse runs.
func (l *HttpLocation) proxyToEndpoint(tr *http.Transport, o *Options, endpoint endpoint.Endpoint, req request.Request) (*http.Response, error) {

	a := &request.BaseAttempt{Endpoint: endpoint}

	l.observerChain.ObserveRequest(req)
	defer l.observerChain.ObserveResponse(req, a)
	defer req.AddAttempt(a)

	it := l.middlewareChain.GetIter()
	defer l.unwindIter(it, req, a)

	for v := it.Next(); v != nil; v = it.Next() {
		a.Response, a.Error = v.ProcessRequest(req)
		if a.Response != nil || a.Error != nil {
			// Move the iterator forward to count it again once we unwind the chain
			it.Next()
			// A middleware may reject with only an error and no response, so
			// the status must not be read through a nil *http.Response.
			status := "<nil>"
			if a.Response != nil {
				status = a.Response.Status
			}
			log.Errorf("Midleware intercepted request with response=%s, error=%s", status, a.Error)
			return a.Response, a.Error
		}
	}

	// Forward the request and mirror the response
	start := o.TimeProvider.UtcNow()
	a.Response, a.Error = tr.RoundTrip(req.GetHttpRequest())
	a.Duration = o.TimeProvider.UtcNow().Sub(start)
	return a.Response, a.Error
}
예제 #3
0
// RequestToClientIp is a TokenMapper that maps the request to the client IP
// taken from the request's RemoteAddr.
//
// RemoteAddr normally has the form "host:port", with IPv6 hosts in brackets
// ("[::1]:8080"). The host part is returned without brackets or port. An
// address with no port (e.g. "10.0.0.1") or an unbracketed IPv6 literal is
// returned unchanged. An error is returned when no host can be extracted.
func RequestToClientIp(req request.Request) (string, error) {
	addr := req.GetHttpRequest().RemoteAddr

	var ip string
	switch colons := strings.Count(addr, ":"); {
	case strings.HasPrefix(addr, "["):
		// Bracketed IPv6 literal, with or without a port. Take what is inside
		// the brackets; splitting on the first ':' would yield just "[".
		if end := strings.IndexByte(addr, ']'); end > 0 {
			ip = addr[1:end]
		}
	case colons == 1:
		// IPv4 address or host name followed by a port.
		ip = addr[:strings.IndexByte(addr, ':')]
	default:
		// No port at all, or an unbracketed IPv6 literal (which cannot
		// unambiguously carry a port).
		ip = addr
	}

	if len(ip) == 0 {
		return "", fmt.Errorf("Failed to parse client IP")
	}
	return ip, nil
}
예제 #4
0
// getMetrics returns the round-trip metrics attached to r under the
// cbreakerMetrics user-data key, or nil when nothing is attached.
// The stored value is type-asserted to *metrics.RoundTripMetrics, so a value
// of any other type under that key makes this panic.
func getMetrics(r request.Request) *metrics.RoundTripMetrics {
	if stored, found := r.GetUserData(cbreakerMetrics); found {
		return stored.(*metrics.RoundTripMetrics)
	}
	return nil
}
예제 #5
0
// NextEndpoint selects the endpoint for the next attempt of req in round-robin
// order. The whole selection runs under r.mutex.
//
// On the first attempt, or when the round-robin choice has not been tried for
// this request yet, that choice is returned directly. Otherwise (a failover
// that landed on an already-attempted endpoint) the rotation is advanced up to
// once per configured endpoint, looking for one req has not been sent to. If
// every candidate was already attempted, the last one picked is returned.
// Errors from the underlying selection are returned with a nil endpoint.
func (r *RoundRobin) NextEndpoint(req request.Request) (endpoint.Endpoint, error) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	e, err := r.nextEndpoint(req)
	if err != nil {
		return nil, err
	}
	// First try, or the round-robin choice is still fresh for this request:
	// use it, so that a not-yet-tried endpoint is not skipped in the rotation.
	if req.GetLastAttempt() == nil || !hasAttempted(req, e) {
		return e, nil
	}

	// Try to prevent failover to the same endpoint that we've seen before,
	// that reduces the probability of the scenario when failover hits same endpoint
	// on the next attempt and fails, so users will see a failed request.
	candidate := e
	for range r.endpoints {
		candidate, err = r.nextEndpoint(req)
		if err != nil {
			return nil, err
		}
		if !hasAttempted(req, candidate) {
			return candidate, nil
		}
	}
	return candidate, nil
}
예제 #6
0
// ProcessRequest is the TokenLimiter middleware hook. It maps r to a token and
// an amount, then consumes that amount from the token's bucket, creating the
// bucket on first use. When the bucket reports a positive delay, it rejects
// the request with an errors.StatusTooManyRequests "Too many requests" text
// response. A nil response and nil error let the request continue. Mapper and
// bucket errors are returned unchanged. Everything runs under tl.mutex.
func (tl *TokenLimiter) ProcessRequest(r request.Request) (*http.Response, error) {
	tl.mutex.Lock()
	defer tl.mutex.Unlock()

	token, amount, err := tl.mapper(r)
	if err != nil {
		return nil, err
	}

	bucketI, exists := tl.buckets.Get(token)
	if !exists {
		bucketI, err = NewTokenBucket(tl.rate, tl.options.Burst+1, tl.options.TimeProvider)
		if err != nil {
			return nil, err
		}
		// Expire the bucket after about 10 rate periods of inactivity, in whole
		// seconds plus one so the TTL is never zero. E.g. if rate is 100
		// requests/second per client ip the counters for this ip will expire
		// after 11 seconds of inactivity. Scaling before truncating keeps
		// sub-second periods from collapsing to the 1-second floor (a 500ms
		// period yields 6s instead of 1s).
		ttlSeconds := int(10*tl.rate.Period/time.Second) + 1
		tl.buckets.Set(token, bucketI, ttlSeconds)
	}
	bucket := bucketI.(*TokenBucket)
	delay, err := bucket.Consume(amount)
	if err != nil {
		return nil, err
	}
	if delay > 0 {
		return netutils.NewTextResponse(r.GetHttpRequest(), errors.StatusTooManyRequests, "Too many requests"), nil
	}
	return nil, nil
}
예제 #7
0
// ProcessRequest is the TokenLimiter middleware hook. It maps r to a token and
// an amount and computes the effective rates for r. It then reuses the
// token's bucket set, updating it to those rates, or creates one on first use,
// and consumes the amount from it. When the set reports a positive delay, it
// rejects the request with an errors.StatusTooManyRequests "Too many requests"
// text response. A nil response and nil error let the request continue.
// Everything runs under tl.mutex.
func (tl *TokenLimiter) ProcessRequest(r request.Request) (*http.Response, error) {
	tl.mutex.Lock()
	defer tl.mutex.Unlock()

	token, amount, err := tl.mapper(r)
	if err != nil {
		return nil, err
	}

	effectiveRates := tl.effectiveRates(r)
	bucketSetI, exists := tl.bucketSets.Get(token)
	var bucketSet *tokenBucketSet

	if exists {
		bucketSet = bucketSetI.(*tokenBucketSet)
		bucketSet.update(effectiveRates)
	} else {
		bucketSet = newTokenBucketSet(effectiveRates, tl.clock)
		// Expire the set after about 10 of its longest rate periods of
		// inactivity, in whole seconds plus one so the TTL is never zero. E.g.
		// if rate is 100 requests/second per client ip the counters for this ip
		// will expire after 11 seconds of inactivity. Scaling before truncating
		// keeps sub-second periods from collapsing to the 1-second floor.
		ttlSeconds := int(10*bucketSet.maxPeriod/time.Second) + 1
		tl.bucketSets.Set(token, bucketSet, ttlSeconds)
	}

	delay, err := bucketSet.consume(amount)
	if err != nil {
		return nil, err
	}
	if delay > 0 {
		return netutils.NewTextResponse(r.GetHttpRequest(), errors.StatusTooManyRequests, "Too many requests"), nil
	}
	return nil, nil
}
예제 #8
0
// hasAttempted reports whether any attempt already recorded on req was sent
// to an endpoint whose id equals endpoint's id.
func hasAttempted(req request.Request, endpoint endpoint.Endpoint) bool {
	attempts := req.GetAttempts()
	for i := 0; i < len(attempts); i++ {
		if attempts[i].GetEndpoint().GetId() != endpoint.GetId() {
			continue
		}
		return true
	}
	return false
}
예제 #9
0
// match delegates to the wrapped matcher when the request's HTTP method is one
// of m.methods, and returns nil otherwise.
func (m *methodMatcher) match(req request.Request) location.Location {
	for _, allowed := range m.methods {
		if req.GetHttpRequest().Method != allowed {
			continue
		}
		return m.matcher.match(req)
	}
	return nil
}
예제 #10
0
// Route resolves req to a location using the underlying expression router.
// Router errors are propagated. A nil result from the router yields
// (nil, nil). Any other result is type-asserted to location.Location.
func (e *ExpRouter) Route(req request.Request) (location.Location, error) {
	matched, err := e.r.Route(req.GetHttpRequest())
	switch {
	case err != nil:
		return nil, err
	case matched == nil:
		return nil, nil
	default:
		return matched.(location.Location), nil
	}
}
예제 #11
0
// match looks up the request's URL path in the trie and returns the matching
// location, as decided by the root node's match, or nil when the trie is
// empty. An empty path is looked up as "/".
func (p *trie) match(r request.Request) location.Location {
	if p.root != nil {
		path := r.GetHttpRequest().URL.Path
		if path == "" {
			path = "/"
		}
		return p.root.match(-1, path, r)
	}
	return nil
}
예제 #12
0
// checkCondition evaluates the breaker's condition for r, at most once per
// check period. It returns false without evaluating when c.timeToCheck()
// returns false. It also returns false when, after acquiring c.m, the current
// time is not after c.lastCheck, meaning another goroutine already advanced
// it. Otherwise it moves lastCheck to now+checkPeriod, attaches c.metrics to r
// under cbreakerMetrics so the condition function can read them, and returns
// c.condition(r).
func (c *CircuitBreaker) checkCondition(r request.Request) bool {
	if !c.timeToCheck() {
		return false
	}

	c.m.Lock()
	defer c.m.Unlock()

	// Re-check under the mutex: lastCheck may have been updated between the
	// unlocked timeToCheck call and acquiring the lock.
	if c.tm.UtcNow().After(c.lastCheck) {
		c.lastCheck = c.tm.UtcNow().Add(c.checkPeriod)
		r.SetUserData(cbreakerMetrics, c.metrics)
		return c.condition(r)
	}
	return false
}
예제 #13
0
// ProcessRequest rewrites the outgoing request before it is proxied. It never
// short-circuits the middleware chain: the returned response is always nil,
// and an error is returned only when the body's total size cannot be read.
//
// It sets the following headers:
//   - X-Forwarded-For: the peer IP, appended to any existing chain only when
//     TrustForwardHeader is set.
//   - X-Forwarded-Proto: the incoming value when trusted and non-empty,
//     otherwise "https" or "http" depending on req.TLS.
//   - X-Forwarded-Host: the Host, when it is non-empty.
//   - X-Forwarded-Server: rw.Hostname.
//
// It then strips hop-by-hop headers, sets ContentLength from the body's total
// size and clears TransferEncoding.
func (rw *Rewriter) ProcessRequest(r request.Request) (*http.Response, error) {
	req := r.GetHttpRequest()
	hdr := req.Header

	// Only a RemoteAddr that parses as "host:port" contributes to X-Forwarded-For.
	if ip, _, splitErr := net.SplitHostPort(req.RemoteAddr); splitErr == nil {
		chain := ip
		if prior, present := hdr[headers.XForwardedFor]; present && rw.TrustForwardHeader {
			chain = strings.Join(prior, ", ") + ", " + ip
		}
		hdr.Set(headers.XForwardedFor, chain)
	}

	scheme := "http"
	if req.TLS != nil {
		scheme = "https"
	}
	if incoming := hdr.Get(headers.XForwardedProto); incoming != "" && rw.TrustForwardHeader {
		scheme = incoming
	}
	hdr.Set(headers.XForwardedProto, scheme)

	if host := req.Host; host != "" {
		hdr.Set(headers.XForwardedHost, host)
	}
	hdr.Set(headers.XForwardedServer, rw.Hostname)

	// Hop-by-hop headers must not reach the backend. "Connection" matters most,
	// because we want a persistent connection regardless of what the client sent.
	netutils.RemoveHeaders(headers.HopHeaders, hdr)

	// The incoming request may lack a Content-Length or use chunked
	// TransferEncoding. Replace that framing with the body's known total size.
	size, err := r.GetBody().TotalSize()
	if err != nil {
		return nil, err
	}
	req.ContentLength = size
	req.TransferEncoding = []string{}

	return nil, nil
}
예제 #14
0
// ProcessRequest is the ConnectionLimiter middleware hook. It maps r to a
// token and an amount. The request is admitted only if adding amount to the
// token's current connection count stays within cl.maxConnections. Otherwise
// it returns an errors.StatusTooManyRequests text response that reports the
// limit and the current count. On admission, both the token's count and
// cl.totalConnections grow by amount, and a nil response with a nil error
// lets the request continue. Everything runs under cl.mutex.
func (cl *ConnectionLimiter) ProcessRequest(r request.Request) (*http.Response, error) {
	cl.mutex.Lock()
	defer cl.mutex.Unlock()

	token, amount, err := cl.mapper(r)
	if err != nil {
		return nil, err
	}

	connections := cl.connections[token]
	// Account for the amount being acquired. Checking only the current count
	// (connections >= max) would let a request with amount > 1 push the token
	// past the limit. For amount == 1 both checks are equivalent.
	if connections+amount > cl.maxConnections {
		return netutils.NewTextResponse(
			r.GetHttpRequest(),
			errors.StatusTooManyRequests,
			fmt.Sprintf("Connection limit reached. Max is: %d, yours: %d", cl.maxConnections, connections)), nil
	}

	cl.connections[token] += amount
	cl.totalConnections += int64(amount)
	return nil, nil
}
예제 #15
0
// isRequestOverLimit reports whether the request's declared ContentLength
// exceeds l.options.Limits.MaxBodyBytes. A non-positive limit disables the
// check. Requests of unknown length (ContentLength -1, e.g. chunked bodies)
// never exceed a positive limit here.
func (l *HttpLocation) isRequestOverLimit(req request.Request) bool {
	limit := l.options.Limits.MaxBodyBytes
	return limit > 0 && req.GetHttpRequest().ContentLength > limit
}
예제 #16
0
// IsNetworkError is a failover predicate. It reports whether the most recent
// attempt recorded on req ended with a non-nil error, and returns false when
// no attempt has been recorded yet.
func IsNetworkError(req request.Request) bool {
	attempts := req.GetAttempts()
	if len(attempts) == 0 {
		return false
	}
	last := attempts[len(attempts)-1]
	return last.GetError() != nil
}
예제 #17
0
// RoundTrip sends the request to the fixed URL l.Url through
// http.DefaultTransport. The request's URL is overwritten in place. As a
// Must-style helper, netutils.MustParseUrl presumably panics on an invalid
// l.Url (not verified here).
func (l *ConstHttpLocation) RoundTrip(r request.Request) (*http.Response, error) {
	httpReq := r.GetHttpRequest()
	target := netutils.MustParseUrl(l.Url)
	httpReq.URL = target
	return http.DefaultTransport.RoundTrip(httpReq)
}
예제 #18
0
// shouldRecordMetrics reports whether r carries the cbreakerRecordMetrics
// user-data key. Only the key's presence matters, not the stored value.
func (c *CircuitBreaker) shouldRecordMetrics(r request.Request) bool {
	if _, marked := r.GetUserData(cbreakerRecordMetrics); marked {
		return true
	}
	return false
}
예제 #19
0
// markToRecordMetrics flags r for metric recording by storing true under the
// cbreakerRecordMetrics user-data key.
func (c *CircuitBreaker) markToRecordMetrics(r request.Request) {
	const recordFlag = true
	r.SetUserData(cbreakerRecordMetrics, recordFlag)
}
예제 #20
0
// RequestToHost maps the request to its Host field. It never returns an error.
func RequestToHost(req request.Request) (string, error) {
	host := req.GetHttpRequest().Host
	return host, nil
}
예제 #21
0
// RequestToBytes maps the request to its body size in bytes, as reported by
// the request body's TotalSize. That method's error, if any, is returned as-is.
func RequestToBytes(req request.Request) (int64, error) {
	body := req.GetBody()
	return body.TotalSize()
}
예제 #22
0
// mapRequestToPath maps the request to the path component of its URL.
func mapRequestToPath(req request.Request) string {
	u := req.GetHttpRequest().URL
	return u.Path
}
예제 #23
0
// getHTTPResponse builds an *http.Response for the request carried by r from
// this Response's StatusCode, Body and ContentType via netutils.NewHttpResponse.
func (re *Response) getHTTPResponse(r request.Request) *http.Response {
	httpReq := r.GetHttpRequest()
	return netutils.NewHttpResponse(httpReq, re.StatusCode, re.Body, re.ContentType)
}