func testServer(t *testing.T) (cancel, step func(), resp <-chan *http.Response) { var ( ctx, cancelfn = context.WithCancel(context.Background()) stepch = make(chan bool) endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil } response = make(chan *http.Response) handler = httptransport.Server{ Context: ctx, Endpoint: endpoint, DecodeRequestFunc: func(*http.Request) (interface{}, error) { return struct{}{}, nil }, EncodeResponseFunc: func(http.ResponseWriter, interface{}) error { return nil }, Before: []httptransport.RequestFunc{func(ctx context.Context, r *http.Request) context.Context { return ctx }}, After: []httptransport.ResponseFunc{func(ctx context.Context, w http.ResponseWriter) { return }}, Logger: log.NewNopLogger(), } ) go func() { server := httptest.NewServer(handler) defer server.Close() resp, err := http.Get(server.URL) if err != nil { t.Error(err) return } response <- resp }() return cancelfn, func() { stepch <- true }, response }
func TestCancelAfterRequest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) resp, err := doRequest(ctx) // Cancel before reading the body. // Request.Body should still be readable after the context is canceled. cancel() b, err := ioutil.ReadAll(resp.Body) if err != nil || string(b) != requestBody { t.Fatalf("could not read body: %q %v", b, err) } }
func TestCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(requestDuration / 2) cancel() }() resp, err := doRequest(ctx) if resp != nil || err == nil { t.Fatalf("expected error, didn't get one. resp: %v", resp) } if err != ctx.Err() { t.Fatalf("expected error from context but got: %v", err) } }
// Endpoint returns a usable endpoint that will invoke the RPC specified by // the client. func (c Client) Endpoint() endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { ctx, cancel := context.WithCancel(ctx) defer cancel() req, err := http.NewRequest(c.Method, c.URL.String(), nil) if err != nil { return nil, fmt.Errorf("NewRequest: %v", err) } if err = c.EncodeRequestFunc(req, request); err != nil { return nil, fmt.Errorf("Encode: %v", err) } for _, f := range c.Before { ctx = f(ctx, req) } var resp *http.Response if c.Client != nil { resp, err = c.Client.Do(req) } else { resp, err = http.DefaultClient.Do(req) } if err != nil { return nil, fmt.Errorf("Do: %v", err) } defer func() { _ = resp.Body.Close() }() response, err := c.DecodeResponseFunc(resp) if err != nil { return nil, fmt.Errorf("Decode: %v", err) } return response, nil } }
// ServeHTTP implements http.Handler. func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if s.ErrorEncoder == nil { s.ErrorEncoder = defaultErrorEncoder } ctx, cancel := context.WithCancel(s.Context) defer cancel() for _, f := range s.Before { ctx = f(ctx, r) } request, err := s.DecodeRequestFunc(r) if err != nil { _ = s.Logger.Log("err", err) s.ErrorEncoder(w, badRequestError{err}) return } response, err := s.Endpoint(ctx, request) if err != nil { _ = s.Logger.Log("err", err) s.ErrorEncoder(w, err) return } for _, f := range s.After { f(ctx, w) } if err := s.EncodeResponseFunc(w, response); err != nil { _ = s.Logger.Log("err", err) s.ErrorEncoder(w, err) return } }