func TestOutboundHeaders(t *testing.T) { tests := []struct { desc string context context.Context headers transport.Headers wantHeaders map[string]string }{ { desc: "application headers", headers: transport.NewHeaders().With("foo", "bar").With("baz", "Qux"), wantHeaders: map[string]string{ "Rpc-Header-Foo": "bar", "Rpc-Header-Baz": "Qux", }, }, } for _, tt := range tests { server := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() for k, v := range tt.wantHeaders { assert.Equal( t, v, r.Header.Get(k), "%v: header %v did not match", tt.desc, k) } }, )) defer server.Close() ctx := tt.context if ctx == nil { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(context.Background(), time.Second) defer cancel() } out := NewOutbound(server.URL) require.NoError(t, out.Start(transport.NoDeps), "failed to start outbound") defer out.Stop() res, err := out.Call(ctx, &transport.Request{ Caller: "caller", Service: "service", Encoding: raw.Encoding, Headers: tt.headers, Procedure: "hello", Body: bytes.NewReader([]byte("world")), }) if !assert.NoError(t, err, "%v: call failed", tt.desc) { continue } if !assert.NoError(t, res.Body.Close(), "%v: failed to close response body") { continue } } }
func TestResponseWriterAddHeadersAfterWrite(t *testing.T) { call := &fakeInboundCall{format: tchannel.Raw, resp: newResponseRecorder()} w := newResponseWriter(new(transport.Request), call) w.Write([]byte("foo")) assert.Panics(t, func() { w.AddHeaders(transport.NewHeaders().With("foo", "bar")) }) }
// Headers returns a new randomized header. func (r *randomGenerator) Headers() transport.Headers { headers := transport.NewHeaders() size := 2 + r.randsrc.Intn(6) for i := 0; i < size; i++ { headers = headers.With(r.Atom(), r.Atom()) } return headers }
func TestHandleSuccessWithResponseHeaders(t *testing.T) { h := func(ctx context.Context, r yarpc.ReqMeta, _ *simpleRequest) (*simpleResponse, yarpc.ResMeta, error) { resMeta := yarpc.NewResMeta().Headers(yarpc.NewHeaders().With("foo", "bar")) return &simpleResponse{Success: true}, resMeta, nil } handler := jsonHandler{ reader: structReader{reflect.TypeOf(simpleRequest{})}, handler: reflect.ValueOf(h), } resw := new(transporttest.FakeResponseWriter) err := handler.Handle(context.Background(), &transport.Request{ Procedure: "simpleCall", Encoding: "json", Body: jsonBody(`{"name": "foo", "attributes": {"bar": 42}}`), }, resw) require.NoError(t, err) assert.Equal(t, transport.NewHeaders().With("foo", "bar"), resw.Headers) }
func TestOutboundHeaders(t *testing.T) { tests := []struct { context context.Context headers transport.Headers wantHeaders []byte wantError string }{ { headers: transport.NewHeaders().With("contextfoo", "bar"), wantHeaders: []byte{ 0x00, 0x01, 0x00, 0x0A, 'c', 'o', 'n', 't', 'e', 'x', 't', 'f', 'o', 'o', 0x00, 0x03, 'b', 'a', 'r', }, }, { headers: transport.NewHeaders().With("Foo", "bar"), wantHeaders: []byte{ 0x00, 0x01, 0x00, 0x03, 'f', 'o', 'o', 0x00, 0x03, 'b', 'a', 'r', }, }, } for _, tt := range tests { server := testutils.NewServer(t, nil) defer server.Close() hostport := server.PeerInfo().HostPort server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc( func(ctx context.Context, call *tchannel.InboundCall) { headers, body, err := readArgs(call) if assert.NoError(t, err, "failed to read request") { assert.Equal(t, tt.wantHeaders, headers, "headers did not match") assert.Equal(t, []byte("world"), body) } err = writeArgs(call.Response(), []byte{0x00, 0x00}, []byte("bye!")) assert.NoError(t, err, "failed to write response") })) for _, getOutbound := range newOutbounds { out := getOutbound(testutils.NewClient(t, &testutils.ChannelOpts{ ServiceName: "caller", }), hostport) require.NoError(t, out.Start(transport.NoDeps), "failed to start outbound") defer out.Stop() ctx := tt.context if ctx == nil { ctx = context.Background() } ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() res, err := out.Call( ctx, &transport.Request{ Caller: "caller", Service: "service", Encoding: raw.Encoding, Procedure: "hello", Headers: tt.headers, Body: bytes.NewReader([]byte("world")), }, ) if tt.wantError != "" { if assert.Error(t, err, "expected error") { assert.Contains(t, err.Error(), tt.wantError) } } else { if assert.NoError(t, err, "call failed") { defer res.Body.Close() } } } } }
func TestRawHandler(t *testing.T) { // handler to use for test cases where the handler should not be called handlerNotCalled := func(ctx context.Context, reqMeta yarpc.ReqMeta, body []byte) ([]byte, yarpc.ResMeta, error) { t.Errorf("unexpected call handle(%v, %v)", reqMeta, body) return nil, nil, fmt.Errorf("unexpected call handle(%v, %v)", reqMeta, body) } tests := []struct { procedure string headers transport.Headers bodyChunks [][]byte handler UnaryHandler wantErr string wantHeaders transport.Headers wantBody []byte }{ { procedure: "foo", bodyChunks: [][]byte{ {1, 2, 3}, {4, 5, 6}, }, handler: func(ctx context.Context, reqMeta yarpc.ReqMeta, body []byte) ([]byte, yarpc.ResMeta, error) { assert.Equal(t, "foo", reqMeta.Procedure()) assert.Equal(t, []byte{1, 2, 3, 4, 5, 6}, body) return []byte("hello"), nil, nil }, wantBody: []byte("hello"), }, { procedure: "bar", bodyChunks: [][]byte{ {1, 2, 3}, nil, // triggers a read error {4, 5, 6}, }, handler: handlerNotCalled, wantErr: "error set by user", // TODO consistent error messages between languages }, { procedure: "baz", bodyChunks: [][]byte{}, handler: func(ctx context.Context, reqMeta yarpc.ReqMeta, body []byte) ([]byte, yarpc.ResMeta, error) { assert.Equal(t, []byte{}, body) return nil, nil, fmt.Errorf("great sadness") }, wantErr: "great sadness", }, { procedure: "responseHeaders", bodyChunks: [][]byte{}, handler: func(ctx context.Context, reqMeta yarpc.ReqMeta, body []byte) ([]byte, yarpc.ResMeta, error) { resMeta := yarpc.NewResMeta().Headers(yarpc.NewHeaders().With("hello", "world")) return []byte{}, resMeta, nil }, wantHeaders: transport.NewHeaders().With("hello", "world"), }, } for _, tt := range tests { handler := rawUnaryHandler{tt.handler} resw := new(transporttest.FakeResponseWriter) writer, chunkReader := testreader.ChunkReader() for _, chunk := range tt.bodyChunks { writer <- chunk } close(writer) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err := 
handler.Handle(ctx, &transport.Request{ Procedure: tt.procedure, Headers: tt.headers, Encoding: "raw", Body: chunkReader, }, resw) if tt.wantErr != "" { if assert.Error(t, err) { assert.Equal(t, err.Error(), tt.wantErr) } } else { if assert.NoError(t, err) { assert.Equal(t, tt.wantHeaders, resw.Headers) assert.Equal(t, tt.wantBody, resw.Body.Bytes(), "body does not match for %s", tt.procedure) } } } }
// TestSimpleRoundTripOneway sends oneway requests over the HTTP transport and
// checks that the handler receives each request unchanged and that
// CallOneway returns a non-nil ack before the handler has finished running.
func TestSimpleRoundTripOneway(t *testing.T) {
	trans := httpTransport{t}

	tests := []struct {
		name           string
		requestHeaders transport.Headers
		requestBody    string
	}{
		{
			name:           "hello world",
			requestHeaders: transport.NewHeaders().With("foo", "bar"),
			requestBody:    "hello world",
		},
		{
			name:           "empty",
			requestHeaders: transport.NewHeaders(),
			requestBody:    "",
		},
	}

	rootCtx := context.Background()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// The request the handler is expected to receive.
			requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{
				Caller:    testCaller,
				Service:   testService,
				Procedure: testProcedureOneway,
				Encoding:  raw.Encoding,
				Headers:   tt.requestHeaders,
				Body:      bytes.NewReader([]byte(tt.requestBody)),
			})

			// Closed by the handler once it finishes. The client side asserts it
			// is still open when CallOneway returns.
			handlerDone := make(chan struct{})

			onewayHandler := onewayHandlerFunc(func(_ context.Context, r *transport.Request) error {
				assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r)

				// Pretend to work: this delay should not slow down tests since it is a
				// server-side operation.
				// NOTE(review): if WithRegistryOneway waits for in-flight handlers
				// when shutting down, this adds 5s per subtest. Confirm.
				time.Sleep(5 * time.Second)

				// Close the channel to tell the client (which should not be waiting
				// for a response) that the handler finished executing.
				close(handlerDone)
				return nil
			})

			registry := staticRegistry{OnewayHandler: onewayHandler}
			trans.WithRegistryOneway(registry, func(o transport.OnewayOutbound) {
				ack, err := o.CallOneway(rootCtx, &transport.Request{
					Caller:    testCaller,
					Service:   testService,
					Procedure: testProcedureOneway,
					Encoding:  raw.Encoding,
					Headers:   tt.requestHeaders,
					Body:      bytes.NewReader([]byte(tt.requestBody)),
				})

				// Non-blocking check (default case) made immediately after the call.
				select {
				case <-handlerDone:
					// The handler already closed the channel, so the client must
					// have waited for the server to complete the request.
					assert.Fail(t, "client waited for server handler to finish executing")
				default:
				}

				if assert.NoError(t, err, "%T: oneway call failed for test '%v'", trans, tt.name) {
					assert.NotNil(t, ack)
				}
			})
		})
	}
}
func TestSimpleRoundTrip(t *testing.T) { transports := []roundTripTransport{ httpTransport{t}, tchannelTransport{t}, } tests := []struct { requestHeaders transport.Headers requestBody string responseHeaders transport.Headers responseBody string responseError error wantError func(error) }{ { requestHeaders: transport.NewHeaders().With("token", "1234"), requestBody: "world", responseHeaders: transport.NewHeaders().With("status", "ok"), responseBody: "hello, world", }, { requestBody: "foo", responseError: errors.HandlerUnexpectedError(fmt.Errorf("great sadness")), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err), err) assert.Equal(t, "UnexpectedError: great sadness", err.Error()) }, }, { requestBody: "bar", responseError: errors.HandlerBadRequestError(fmt.Errorf("missing service name")), wantError: func(err error) { assert.True(t, transport.IsBadRequestError(err)) assert.Equal(t, "BadRequest: missing service name", err.Error()) }, }, { requestBody: "baz", responseError: errors.RemoteUnexpectedError( `UnexpectedError: error for procedure "foo" of service "bar": great sadness`, ), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err)) assert.Equal(t, `UnexpectedError: error for procedure "hello" of service "testService": `+ `UnexpectedError: error for procedure "foo" of service "bar": great sadness`, err.Error()) }, }, { requestBody: "qux", responseError: errors.RemoteBadRequestError( `BadRequest: unrecognized procedure "echo" for service "derp"`, ), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err)) assert.Equal(t, `UnexpectedError: error for procedure "hello" of service "testService": `+ `BadRequest: unrecognized procedure "echo" for service "derp"`, err.Error()) }, }, } rootCtx := context.Background() for _, tt := range tests { for _, trans := range transports { requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{ Caller: testCaller, Service: testService, Procedure: 
testProcedure, Encoding: raw.Encoding, Headers: tt.requestHeaders, Body: bytes.NewReader([]byte(tt.requestBody)), }) handler := unaryHandlerFunc(func(_ context.Context, r *transport.Request, w transport.ResponseWriter) error { assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r) if tt.responseError != nil { return tt.responseError } if tt.responseHeaders.Len() > 0 { w.AddHeaders(tt.responseHeaders) } _, err := w.Write([]byte(tt.responseBody)) assert.NoError(t, err, "failed to write response for %v", r) return err }) ctx, cancel := context.WithTimeout(rootCtx, 200*time.Millisecond) defer cancel() registry := staticRegistry{Handler: handler} trans.WithRegistry(registry, func(o transport.UnaryOutbound) { res, err := o.Call(ctx, &transport.Request{ Caller: testCaller, Service: testService, Procedure: testProcedure, Encoding: raw.Encoding, Headers: tt.requestHeaders, Body: bytes.NewReader([]byte(tt.requestBody)), }) if tt.wantError != nil { if assert.Error(t, err, "%T: expected error, got %v", trans, res) { tt.wantError(err) // none of the errors returned by Call can be valid // Handler errors. _, ok := err.(errors.HandlerError) assert.False(t, ok, "%T: %T must not be a HandlerError", trans, err) } } else { responseMatcher := transporttest.NewResponseMatcher(t, &transport.Response{ Headers: tt.responseHeaders, Body: ioutil.NopCloser(bytes.NewReader([]byte(tt.responseBody))), }) if assert.NoError(t, err, "%T: call failed", trans) { assert.True(t, responseMatcher.Matches(res), "%T: response mismatch", trans) } } }) } } }
// Call makes a HTTP request func (o *Outbound) Call(ctx context.Context, treq *transport.Request) (*transport.Response, error) { if !o.started.Load() { // panic because there's no recovery from this panic(errOutboundNotStarted) } start := time.Now() deadline, _ := ctx.Deadline() ttl := deadline.Sub(start) peer, err := o.getPeerForRequest(ctx, treq) if err != nil { return nil, err } endRequest := peer.StartRequest() defer endRequest() req, err := o.createRequest(peer, treq) if err != nil { return nil, err } req.Header = applicationHeaders.ToHTTPHeaders(treq.Headers, nil) ctx, req, span := o.withOpentracingSpan(ctx, req, treq, start) defer span.Finish() req = o.withCoreHeaders(req, treq, ttl) client, err := o.getHTTPClient(peer) if err != nil { return nil, err } response, err := client.Do(req.WithContext(ctx)) if err != nil { // Workaround borrowed from ctxhttp until // https://github.com/golang/go/issues/17711 is resolved. select { case <-ctx.Done(): err = ctx.Err() default: } span.SetTag("error", true) span.LogEvent(err.Error()) if err == context.DeadlineExceeded { end := time.Now() return nil, errors.ClientTimeoutError(treq.Service, treq.Procedure, end.Sub(start)) } return nil, err } span.SetTag("http.status_code", response.StatusCode) if response.StatusCode >= 200 && response.StatusCode < 300 { appHeaders := applicationHeaders.FromHTTPHeaders( response.Header, transport.NewHeaders()) return &transport.Response{ Headers: appHeaders, Body: response.Body, }, nil } return nil, getErrFromResponse(response) }