func TestCallOnewayFailure(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() ctx := context.Background() caller := "caller" service := "service" procedure := "procedure" body := []byte{1, 2, 3} outbound := transporttest.NewMockOnewayOutbound(mockCtrl) client := New(channel.MultiOutbound(caller, service, transport.Outbounds{ Oneway: outbound, })) outbound.EXPECT().CallOneway(gomock.Any(), transporttest.NewRequestMatcher(t, &transport.Request{ Service: service, Caller: caller, Procedure: procedure, Encoding: Encoding, Body: bytes.NewReader(body), }), ).Return(nil, errors.New("some error")) _, err := client.CallOneway( ctx, yarpc.NewReqMeta().Procedure(procedure), body) assert.Error(t, err) }
func TestHandlerInternalFailure(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() headers := make(http.Header) headers.Set(CallerHeader, "somecaller") headers.Set(EncodingHeader, "raw") headers.Set(TTLMSHeader, "1000") headers.Set(ProcedureHeader, "hello") headers.Set(ServiceHeader, "fake") request := http.Request{ Method: "POST", Header: headers, Body: ioutil.NopCloser(bytes.NewReader([]byte{})), } rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl) rpcHandler.EXPECT().Handle( transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second)), transporttest.NewRequestMatcher( t, &transport.Request{ Caller: "somecaller", Service: "fake", Encoding: raw.Encoding, Procedure: "hello", Body: bytes.NewReader([]byte{}), }, ), gomock.Any(), ).Return(fmt.Errorf("great sadness")) registry := transporttest.NewMockRegistry(mockCtrl) spec := transport.NewUnaryHandlerSpec(rpcHandler) registry.EXPECT().GetHandlerSpec("fake", "hello").Return(spec, nil) httpHandler := handler{Registry: registry} httpResponse := httptest.NewRecorder() httpHandler.ServeHTTP(httpResponse, &request) code := httpResponse.Code assert.True(t, code >= 500 && code < 600, "expected 500 level response") assert.Equal(t, `UnexpectedError: error for procedure "hello" of service "fake": great sadness`+"\n", httpResponse.Body.String()) }
func TestHandlerSucces(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() headers := make(http.Header) headers.Set(CallerHeader, "moe") headers.Set(EncodingHeader, "raw") headers.Set(TTLMSHeader, "1000") headers.Set(ProcedureHeader, "nyuck") headers.Set(ServiceHeader, "curly") registry := transporttest.NewMockRegistry(mockCtrl) rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl) spec := transport.NewUnaryHandlerSpec(rpcHandler) registry.EXPECT().GetHandlerSpec("curly", "nyuck").Return(spec, nil) rpcHandler.EXPECT().Handle( transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second), ), transporttest.NewRequestMatcher( t, &transport.Request{ Caller: "moe", Service: "curly", Encoding: raw.Encoding, Procedure: "nyuck", Body: bytes.NewReader([]byte("Nyuck Nyuck")), }, ), gomock.Any(), ).Return(nil) httpHandler := handler{Registry: registry} req := &http.Request{ Method: "POST", Header: headers, Body: ioutil.NopCloser(bytes.NewReader([]byte("Nyuck Nyuck"))), } rw := httptest.NewRecorder() httpHandler.ServeHTTP(rw, req) code := rw.Code assert.Equal(t, code, 200, "expected 200 code") assert.Equal(t, rw.Body.String(), "") }
func TestCall(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() ctx := context.Background() caller := "caller" service := "service" tests := []struct { procedure string headers yarpc.Headers body []byte responseBody [][]byte want []byte wantErr string wantHeaders yarpc.Headers }{ { procedure: "foo", body: []byte{1, 2, 3}, responseBody: [][]byte{{4}, {5}, {6}}, want: []byte{4, 5, 6}, }, { procedure: "bar", body: []byte{1, 2, 3}, responseBody: [][]byte{{4}, {5}, nil, {6}}, wantErr: "error set by user", }, { procedure: "headers", headers: yarpc.NewHeaders().With("x", "y"), body: []byte{}, responseBody: [][]byte{}, want: []byte{}, wantHeaders: yarpc.NewHeaders().With("a", "b"), }, } for _, tt := range tests { outbound := transporttest.NewMockUnaryOutbound(mockCtrl) client := New(channel.MultiOutbound(caller, service, transport.Outbounds{ Unary: outbound, })) writer, responseBody := testreader.ChunkReader() for _, chunk := range tt.responseBody { writer <- chunk } close(writer) outbound.EXPECT().Call(gomock.Any(), transporttest.NewRequestMatcher(t, &transport.Request{ Caller: caller, Service: service, Procedure: tt.procedure, Headers: transport.Headers(tt.headers), Encoding: Encoding, Body: bytes.NewReader(tt.body), }), ).Return( &transport.Response{ Body: ioutil.NopCloser(responseBody), Headers: transport.Headers(tt.wantHeaders), }, nil) resBody, res, err := client.Call( ctx, yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers), tt.body) if tt.wantErr != "" { if assert.Error(t, err) { assert.Equal(t, err.Error(), tt.wantErr) } } else { if assert.NoError(t, err) { assert.Equal(t, tt.want, resBody) assert.Equal(t, tt.wantHeaders, res.Headers()) } } } }
func TestCallOneway(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() ctx := context.Background() caller := "caller" service := "service" tests := []struct { procedure string headers yarpc.Headers body []byte wantErr string wantHeaders yarpc.Headers }{ { procedure: "foo", body: []byte{1, 2, 3}, }, { procedure: "headers", headers: yarpc.NewHeaders().With("x", "y"), body: []byte{}, }, } for _, tt := range tests { outbound := transporttest.NewMockOnewayOutbound(mockCtrl) client := New(channel.MultiOutbound(caller, service, transport.Outbounds{ Oneway: outbound, })) outbound.EXPECT().CallOneway(gomock.Any(), transporttest.NewRequestMatcher(t, &transport.Request{ Caller: caller, Service: service, Procedure: tt.procedure, Headers: transport.Headers(tt.headers), Encoding: Encoding, Body: bytes.NewReader(tt.body), }), ).Return(&successAck{}, nil) ack, err := client.CallOneway( ctx, yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers), tt.body) if tt.wantErr != "" { if assert.Error(t, err) { assert.Equal(t, err.Error(), tt.wantErr) } } else { assert.Equal(t, "success", ack.String()) } } }
// TestSimpleRoundTripOneway sends a oneway request over the HTTP transport for
// each case. It checks that the handler receives a matching request, that
// CallOneway returns a non-nil ack without error, and that CallOneway returned
// before the handler finished executing.
func TestSimpleRoundTripOneway(t *testing.T) {
	trans := httpTransport{t}

	tests := []struct {
		name           string
		requestHeaders transport.Headers
		requestBody    string
	}{
		{
			name:           "hello world",
			requestHeaders: transport.NewHeaders().With("foo", "bar"),
			requestBody:    "hello world",
		},
		{
			name:           "empty",
			requestHeaders: transport.NewHeaders(),
			requestBody:    "",
		},
	}

	rootCtx := context.Background()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{
				Caller:    testCaller,
				Service:   testService,
				Procedure: testProcedureOneway,
				Encoding:  raw.Encoding,
				Headers:   tt.requestHeaders,
				Body:      bytes.NewReader([]byte(tt.requestBody)),
			})

			// Closed by the handler after its simulated work. The client side
			// polls it to detect whether CallOneway blocked on the handler.
			handlerDone := make(chan struct{})

			// NOTE(review): this handler uses the subtest's t. If the transport
			// runs it after the subtest has returned, that use is unsafe. Confirm
			// the ordering the HTTP oneway inbound guarantees.
			onewayHandler := onewayHandlerFunc(func(_ context.Context, r *transport.Request) error {
				assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r)

				// Pretend to work. The intent is that this server-side delay does not
				// slow down the test. NOTE(review): that holds only if
				// WithRegistryOneway does not wait for in-flight handlers on
				// shutdown, which is not visible here.
				time.Sleep(5 * time.Second)

				// Close the channel to signal that the handler finished executing.
				// The client should not have been waiting for this.
				close(handlerDone)

				return nil
			})

			registry := staticRegistry{OnewayHandler: onewayHandler}
			trans.WithRegistryOneway(registry, func(o transport.OnewayOutbound) {
				ack, err := o.CallOneway(rootCtx, &transport.Request{
					Caller:    testCaller,
					Service:   testService,
					Procedure: testProcedureOneway,
					Encoding:  raw.Encoding,
					Headers:   tt.requestHeaders,
					Body:      bytes.NewReader([]byte(tt.requestBody)),
				})

				// Non-blocking check: if the handler has already closed the channel,
				// the client must have waited for the handler to complete.
				select {
				case <-handlerDone:
					assert.Fail(t, "client waited for server handler to finish executing")
				default:
				}

				if assert.NoError(t, err, "%T: oneway call failed for test '%v'", trans, tt.name) {
					assert.NotNil(t, ack)
				}
			})
		})
	}
}
func TestSimpleRoundTrip(t *testing.T) { transports := []roundTripTransport{ httpTransport{t}, tchannelTransport{t}, } tests := []struct { requestHeaders transport.Headers requestBody string responseHeaders transport.Headers responseBody string responseError error wantError func(error) }{ { requestHeaders: transport.NewHeaders().With("token", "1234"), requestBody: "world", responseHeaders: transport.NewHeaders().With("status", "ok"), responseBody: "hello, world", }, { requestBody: "foo", responseError: errors.HandlerUnexpectedError(fmt.Errorf("great sadness")), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err), err) assert.Equal(t, "UnexpectedError: great sadness", err.Error()) }, }, { requestBody: "bar", responseError: errors.HandlerBadRequestError(fmt.Errorf("missing service name")), wantError: func(err error) { assert.True(t, transport.IsBadRequestError(err)) assert.Equal(t, "BadRequest: missing service name", err.Error()) }, }, { requestBody: "baz", responseError: errors.RemoteUnexpectedError( `UnexpectedError: error for procedure "foo" of service "bar": great sadness`, ), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err)) assert.Equal(t, `UnexpectedError: error for procedure "hello" of service "testService": `+ `UnexpectedError: error for procedure "foo" of service "bar": great sadness`, err.Error()) }, }, { requestBody: "qux", responseError: errors.RemoteBadRequestError( `BadRequest: unrecognized procedure "echo" for service "derp"`, ), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err)) assert.Equal(t, `UnexpectedError: error for procedure "hello" of service "testService": `+ `BadRequest: unrecognized procedure "echo" for service "derp"`, err.Error()) }, }, } rootCtx := context.Background() for _, tt := range tests { for _, trans := range transports { requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{ Caller: testCaller, Service: testService, Procedure: 
testProcedure, Encoding: raw.Encoding, Headers: tt.requestHeaders, Body: bytes.NewReader([]byte(tt.requestBody)), }) handler := unaryHandlerFunc(func(_ context.Context, r *transport.Request, w transport.ResponseWriter) error { assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r) if tt.responseError != nil { return tt.responseError } if tt.responseHeaders.Len() > 0 { w.AddHeaders(tt.responseHeaders) } _, err := w.Write([]byte(tt.responseBody)) assert.NoError(t, err, "failed to write response for %v", r) return err }) ctx, cancel := context.WithTimeout(rootCtx, 200*time.Millisecond) defer cancel() registry := staticRegistry{Handler: handler} trans.WithRegistry(registry, func(o transport.UnaryOutbound) { res, err := o.Call(ctx, &transport.Request{ Caller: testCaller, Service: testService, Procedure: testProcedure, Encoding: raw.Encoding, Headers: tt.requestHeaders, Body: bytes.NewReader([]byte(tt.requestBody)), }) if tt.wantError != nil { if assert.Error(t, err, "%T: expected error, got %v", trans, res) { tt.wantError(err) // none of the errors returned by Call can be valid // Handler errors. _, ok := err.(errors.HandlerError) assert.False(t, ok, "%T: %T must not be a HandlerError", trans, err) } } else { responseMatcher := transporttest.NewResponseMatcher(t, &transport.Response{ Headers: tt.responseHeaders, Body: ioutil.NopCloser(bytes.NewReader([]byte(tt.responseBody))), }) if assert.NoError(t, err, "%T: call failed", trans) { assert.True(t, responseMatcher.Matches(res), "%T: response mismatch", trans) } } }) } } }
func TestHandlerHeaders(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() tests := []struct { giveHeaders http.Header wantTTL time.Duration wantHeaders map[string]string }{ { giveHeaders: http.Header{ TTLMSHeader: {"1000"}, "Rpc-Header-Foo": {"bar"}, }, wantTTL: time.Second, wantHeaders: map[string]string{ "foo": "bar", }, }, { giveHeaders: http.Header{ TTLMSHeader: {"100"}, "Rpc-Foo": {"ignored"}, }, wantTTL: 100 * time.Millisecond, wantHeaders: map[string]string{}, }, } for _, tt := range tests { registry := transporttest.NewMockRegistry(mockCtrl) rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl) spec := transport.NewUnaryHandlerSpec(rpcHandler) registry.EXPECT().GetHandlerSpec("service", "hello").Return(spec, nil) httpHandler := handler{Registry: registry} rpcHandler.EXPECT().Handle( transporttest.NewContextMatcher(t, transporttest.ContextTTL(tt.wantTTL), ), transporttest.NewRequestMatcher(t, &transport.Request{ Caller: "caller", Service: "service", Encoding: raw.Encoding, Procedure: "hello", Headers: transport.HeadersFromMap(tt.wantHeaders), Body: bytes.NewReader([]byte("world")), }), gomock.Any(), ).Return(nil) headers := http.Header{} for k, vs := range tt.giveHeaders { for _, v := range vs { headers.Add(k, v) } } headers.Set(CallerHeader, "caller") headers.Set(ServiceHeader, "service") headers.Set(EncodingHeader, "raw") headers.Set(ProcedureHeader, "hello") req := &http.Request{ Method: "POST", Header: headers, Body: ioutil.NopCloser(bytes.NewReader([]byte("world"))), } rw := httptest.NewRecorder() httpHandler.ServeHTTP(rw, req) assert.Equal(t, 200, rw.Code, "expected 200 status code") } }
func TestClient(t *testing.T) { tests := []struct { desc string giveRequestBody envelope.Enveloper // outgoing request body giveResponseEnvelope *wire.Envelope // returned on DecodeEnveloped() giveResponseBody *wire.Value // return on Decode() clientOptions []ClientOption expectCall bool // whether outbound.Call is expected wantRequestEnvelope *wire.Envelope // expect EncodeEnveloped(x) wantRequestBody *wire.Value // expect Encode(x) wantError string // whether an error is expected }{ { desc: "happy case", clientOptions: []ClientOption{Enveloped}, giveRequestBody: fakeEnveloper(wire.Call), wantRequestEnvelope: &wire.Envelope{ Name: "someMethod", SeqID: 1, Type: wire.Call, Value: wire.NewValueStruct(wire.Struct{}), }, expectCall: true, giveResponseEnvelope: &wire.Envelope{ Name: "someMethod", SeqID: 1, Type: wire.Reply, Value: wire.NewValueStruct(wire.Struct{}), }, }, { desc: "happy case without enveloping", giveRequestBody: fakeEnveloper(wire.Call), wantRequestBody: valueptr(wire.NewValueStruct(wire.Struct{})), expectCall: true, giveResponseBody: valueptr(wire.NewValueStruct(wire.Struct{})), }, { desc: "wrong envelope type for request", clientOptions: []ClientOption{Enveloped}, giveRequestBody: fakeEnveloper(wire.Reply), wantError: `failed to encode "thrift" request body for procedure ` + `"MyService::someMethod" of service "service": unexpected envelope type: Reply`, }, { desc: "TApplicationException", clientOptions: []ClientOption{Enveloped}, giveRequestBody: fakeEnveloper(wire.Call), wantRequestEnvelope: &wire.Envelope{ Name: "someMethod", SeqID: 1, Type: wire.Call, Value: wire.NewValueStruct(wire.Struct{}), }, expectCall: true, giveResponseEnvelope: &wire.Envelope{ Name: "someMethod", SeqID: 1, Type: wire.Exception, Value: wire.NewValueStruct(wire.Struct{Fields: []wire.Field{ {ID: 1, Value: wire.NewValueString("great sadness")}, {ID: 2, Value: wire.NewValueI32(7)}, }}), }, wantError: `thrift request to procedure "MyService::someMethod" of ` + `service "service" 
encountered an internal failure: ` + "TApplicationException{Message: great sadness, Type: PROTOCOL_ERROR}", }, { desc: "wrong envelope type for response", clientOptions: []ClientOption{Enveloped}, giveRequestBody: fakeEnveloper(wire.Call), wantRequestEnvelope: &wire.Envelope{ Name: "someMethod", SeqID: 1, Type: wire.Call, Value: wire.NewValueStruct(wire.Struct{}), }, expectCall: true, giveResponseEnvelope: &wire.Envelope{ Name: "someMethod", SeqID: 1, Type: wire.Call, Value: wire.NewValueStruct(wire.Struct{}), }, wantError: `failed to decode "thrift" response body for procedure ` + `"MyService::someMethod" of service "service": unexpected envelope type: Call`, }, } for _, tt := range tests { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() proto := NewMockProtocol(mockCtrl) if tt.wantRequestEnvelope != nil { proto.EXPECT().EncodeEnveloped(*tt.wantRequestEnvelope, gomock.Any()). Do(func(_ wire.Envelope, w io.Writer) { _, err := w.Write([]byte("irrelevant")) require.NoError(t, err, "Write() failed") }).Return(nil) } if tt.wantRequestBody != nil { proto.EXPECT().Encode(*tt.wantRequestBody, gomock.Any()). 
Do(func(_ wire.Value, w io.Writer) { _, err := w.Write([]byte("irrelevant")) require.NoError(t, err, "Write() failed") }).Return(nil) } ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() trans := transporttest.NewMockUnaryOutbound(mockCtrl) if tt.expectCall { trans.EXPECT().Call(ctx, transporttest.NewRequestMatcher(t, &transport.Request{ Caller: "caller", Service: "service", Encoding: Encoding, Procedure: "MyService::someMethod", Body: bytes.NewReader([]byte("irrelevant")), }), ).Return(&transport.Response{ Body: ioutil.NopCloser(bytes.NewReader([]byte("irrelevant"))), }, nil) } if tt.giveResponseEnvelope != nil { proto.EXPECT().DecodeEnveloped(gomock.Any()).Return(*tt.giveResponseEnvelope, nil) } if tt.giveResponseBody != nil { proto.EXPECT().Decode(gomock.Any(), wire.TStruct).Return(*tt.giveResponseBody, nil) } opts := tt.clientOptions opts = append(opts, Protocol(proto)) c := New(Config{ Service: "MyService", Channel: channel.MultiOutbound("caller", "service", transport.Outbounds{ Unary: trans, }), }, opts...) _, _, err := c.Call(ctx, nil, tt.giveRequestBody) if tt.wantError != "" { if assert.Error(t, err, "%v: expected failure", tt.desc) { assert.Contains(t, err.Error(), tt.wantError, "%v: error mismatch", tt.desc) } } else { assert.NoError(t, err, "%v: expected success", tt.desc) } } }
func TestClientOneway(t *testing.T) { caller, service, procedure := "caller", "MyService", "someMethod" tests := []struct { desc string giveRequestBody envelope.Enveloper // outgoing request body clientOptions []ClientOption expectCall bool // whether outbound.Call is expected wantRequestEnvelope *wire.Envelope // expect EncodeEnveloped(x) wantRequestBody *wire.Value // expect Encode(x) wantError string // whether an error is expected }{ { desc: "happy case", giveRequestBody: fakeEnveloper(wire.Call), clientOptions: []ClientOption{Enveloped}, expectCall: true, wantRequestEnvelope: &wire.Envelope{ Name: procedure, SeqID: 1, Type: wire.Call, Value: wire.NewValueStruct(wire.Struct{}), }, }, { desc: "happy case without enveloping", giveRequestBody: fakeEnveloper(wire.Call), expectCall: true, wantRequestBody: valueptr(wire.NewValueStruct(wire.Struct{})), }, { desc: "wrong envelope type for request", giveRequestBody: fakeEnveloper(wire.Reply), clientOptions: []ClientOption{Enveloped}, wantError: `failed to encode "thrift" request body for procedure ` + `"MyService::someMethod" of service "MyService": unexpected envelope type: Reply`, }, } for _, tt := range tests { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() proto := NewMockProtocol(mockCtrl) bodyBytes := []byte("irrelevant") if tt.wantRequestEnvelope != nil { proto.EXPECT().EncodeEnveloped(*tt.wantRequestEnvelope, gomock.Any()). Do(func(_ wire.Envelope, w io.Writer) { _, err := w.Write(bodyBytes) require.NoError(t, err, "Write() failed") }).Return(nil) } if tt.wantRequestBody != nil { proto.EXPECT().Encode(*tt.wantRequestBody, gomock.Any()). 
Do(func(_ wire.Value, w io.Writer) { _, err := w.Write(bodyBytes) require.NoError(t, err, "Write() failed") }).Return(nil) } ctx := context.Background() onewayOutbound := transporttest.NewMockOnewayOutbound(mockCtrl) requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{ Caller: caller, Service: service, Encoding: Encoding, Procedure: procedureName(service, procedure), Body: bytes.NewReader(bodyBytes), }) if tt.expectCall { if tt.wantError != "" { onewayOutbound. EXPECT(). CallOneway(ctx, requestMatcher). Return(nil, errors.New(tt.wantError)) } else { onewayOutbound. EXPECT(). CallOneway(ctx, requestMatcher). Return(&successAck{}, nil) } } opts := tt.clientOptions opts = append(opts, Protocol(proto)) c := New(Config{ Service: service, Channel: channel.MultiOutbound(caller, service, transport.Outbounds{ Oneway: onewayOutbound, }), }, opts...) ack, err := c.CallOneway(ctx, nil, tt.giveRequestBody) if tt.wantError != "" { if assert.Error(t, err, "%v: expected failure", tt.desc) { assert.Contains(t, err.Error(), tt.wantError, "%v: error mismatch", tt.desc) } } else { assert.NoError(t, err, "%v: expected success", tt.desc) assert.Equal(t, "success", ack.String()) } } }
func TestHandlerErrors(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() tests := []struct { format tchannel.Format headers []byte wantHeaders map[string]string }{ { format: tchannel.JSON, headers: []byte(`{"Rpc-Header-Foo": "bar"}`), wantHeaders: map[string]string{"rpc-header-foo": "bar"}, }, { format: tchannel.Thrift, headers: []byte{ 0x00, 0x01, // 1 header 0x00, 0x03, 'F', 'o', 'o', // Foo 0x00, 0x03, 'B', 'a', 'r', // Bar }, wantHeaders: map[string]string{"foo": "Bar"}, }, } for _, tt := range tests { rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl) registry := transporttest.NewMockRegistry(mockCtrl) spec := transport.NewUnaryHandlerSpec(rpcHandler) tchHandler := handler{Registry: registry} registry.EXPECT().GetHandlerSpec("service", "hello").Return(spec, nil) rpcHandler.EXPECT().Handle( transporttest.NewContextMatcher(t), transporttest.NewRequestMatcher(t, &transport.Request{ Caller: "caller", Service: "service", Headers: transport.HeadersFromMap(tt.wantHeaders), Encoding: transport.Encoding(tt.format), Procedure: "hello", Body: bytes.NewReader([]byte("world")), }), gomock.Any(), ).Return(nil) respRecorder := newResponseRecorder() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() tchHandler.handle(ctx, &fakeInboundCall{ service: "service", caller: "caller", format: tt.format, method: "hello", arg2: tt.headers, arg3: []byte("world"), resp: respRecorder, }) assert.NoError(t, respRecorder.systemErr, "did not expect an error") } }
func TestHandlerFailures(t *testing.T) { tests := []struct { desc string // context to use in the callm a default one is used otherwise. ctx context.Context ctxFunc func() (context.Context, context.CancelFunc) sendCall *fakeInboundCall expectCall func(*transporttest.MockUnaryHandler) wantErrors []string // error message contents wantStatus tchannel.SystemErrCode // expected status }{ { desc: "no timeout on context", ctx: context.Background(), sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: []byte{0x00}, }, wantErrors: []string{"timeout required"}, wantStatus: tchannel.ErrCodeBadRequest, }, { desc: "arg2 reader error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.Raw, arg2: nil, arg3: []byte{0x00}, }, wantErrors: []string{ `BadRequest: failed to decode "raw" request headers for`, `procedure "hello" of service "foo" from caller "bar"`, }, wantStatus: tchannel.ErrCodeBadRequest, }, { desc: "arg2 parse error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.JSON, arg2: []byte("{not valid JSON}"), arg3: []byte{0x00}, }, wantErrors: []string{ `BadRequest: failed to decode "json" request headers for`, `procedure "hello" of service "foo" from caller "bar"`, }, wantStatus: tchannel.ErrCodeBadRequest, }, { desc: "arg3 reader error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: nil, }, wantErrors: []string{ `UnexpectedError: error for procedure "hello" of service "foo"`, }, wantStatus: tchannel.ErrCodeUnexpected, }, { desc: "internal error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: []byte{0x00}, }, expectCall: func(h *transporttest.MockUnaryHandler) { h.EXPECT().Handle( transporttest.NewContextMatcher(t, 
transporttest.ContextTTL(time.Second)), transporttest.NewRequestMatcher( t, &transport.Request{ Caller: "bar", Service: "foo", Encoding: raw.Encoding, Procedure: "hello", Body: bytes.NewReader([]byte{0x00}), }, ), gomock.Any(), ).Return(fmt.Errorf("great sadness")) }, wantErrors: []string{ `UnexpectedError: error for procedure "hello" of service "foo":`, "great sadness", }, wantStatus: tchannel.ErrCodeUnexpected, }, { desc: "arg3 encode error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.JSON, arg2: []byte("{}"), arg3: []byte("{}"), }, expectCall: func(h *transporttest.MockUnaryHandler) { req := &transport.Request{ Caller: "bar", Service: "foo", Encoding: json.Encoding, Procedure: "hello", Body: bytes.NewReader([]byte("{}")), } h.EXPECT().Handle( transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second)), transporttest.NewRequestMatcher(t, req), gomock.Any(), ).Return( encoding.ResponseBodyEncodeError(req, errors.New( "serialization derp", ))) }, wantErrors: []string{ `UnexpectedError: failed to encode "json" response body for`, `procedure "hello" of service "foo" from caller "bar":`, `serialization derp`, }, wantStatus: tchannel.ErrCodeUnexpected, }, { desc: "handler timeout", ctxFunc: func() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), time.Millisecond) }, sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "waituntiltimeout", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: []byte{0x00}, }, expectCall: func(h *transporttest.MockUnaryHandler) { req := &transport.Request{ Service: "foo", Caller: "bar", Procedure: "waituntiltimeout", Encoding: raw.Encoding, Body: bytes.NewReader([]byte{0x00}), } h.EXPECT().Handle( transporttest.NewContextMatcher( t, transporttest.ContextTTL(time.Millisecond)), transporttest.NewRequestMatcher(t, req), gomock.Any(), ).Do(func(ctx context.Context, _ *transport.Request, _ transport.ResponseWriter) { 
<-ctx.Done() }).Return(context.DeadlineExceeded) }, wantErrors: []string{ `tchannel error ErrCodeTimeout: Timeout: call to procedure "waituntiltimeout" of service "foo" from caller "bar" timed out after `}, wantStatus: tchannel.ErrCodeTimeout, }, { desc: "handler panic", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "panic", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: []byte{0x00}, }, expectCall: func(h *transporttest.MockUnaryHandler) { req := &transport.Request{ Service: "foo", Caller: "bar", Procedure: "panic", Encoding: raw.Encoding, Body: bytes.NewReader([]byte{0x00}), } h.EXPECT().Handle( transporttest.NewContextMatcher( t, transporttest.ContextTTL(time.Second)), transporttest.NewRequestMatcher(t, req), gomock.Any(), ).Do(func(context.Context, *transport.Request, transport.ResponseWriter) { panic("oops I panicked!") }) }, wantErrors: []string{ `UnexpectedError: error for procedure "panic" of service "foo": panic: oops I panicked!`, }, wantStatus: tchannel.ErrCodeUnexpected, }, } for _, tt := range tests { ctx, cancel := context.WithTimeout(context.Background(), time.Second) if tt.ctx != nil { ctx = tt.ctx } else if tt.ctxFunc != nil { ctx, cancel = tt.ctxFunc() } defer cancel() mockCtrl := gomock.NewController(t) thandler := transporttest.NewMockUnaryHandler(mockCtrl) spec := transport.NewUnaryHandlerSpec(thandler) if tt.expectCall != nil { tt.expectCall(thandler) } resp := newResponseRecorder() tt.sendCall.resp = resp registry := transporttest.NewMockRegistry(mockCtrl) registry.EXPECT().GetHandlerSpec(tt.sendCall.service, tt.sendCall.method). 
Return(spec, nil).AnyTimes() handler{Registry: registry}.handle(ctx, tt.sendCall) err := resp.systemErr require.Error(t, err, "expected error for %q", tt.desc) systemErr, isSystemErr := err.(tchannel.SystemError) require.True(t, isSystemErr, "expected %v for %q to be a system error", err, tt.desc) assert.Equal(t, tt.wantStatus, systemErr.Code(), tt.desc) for _, msg := range tt.wantErrors { assert.Contains( t, err.Error(), msg, "error should contain message for %q", tt.desc) } mockCtrl.Finish() } }
func TestCall(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() ctx := context.Background() caller := "caller" service := "service" tests := []struct { procedure string headers yarpc.Headers body interface{} encodedRequest string encodedResponse string // whether the outbound receives the request noCall bool // Either want, or wantType and wantErr must be set. want interface{} // expected response body wantHeaders yarpc.Headers wantType reflect.Type // type of response body wantErr string // error message }{ { procedure: "foo", body: []string{"foo", "bar"}, encodedRequest: `["foo","bar"]`, encodedResponse: `{"success": true}`, want: map[string]interface{}{"success": true}, }, { procedure: "bar", body: []int{1, 2, 3}, encodedRequest: `[1,2,3]`, encodedResponse: `invalid JSON`, wantType: _typeOfMapInterface, wantErr: `failed to decode "json" response body for procedure "bar" of service "service"`, }, { procedure: "baz", body: func() {}, // funcs cannot be json.Marshal'ed noCall: true, wantType: _typeOfMapInterface, wantErr: `failed to encode "json" request body for procedure "baz" of service "service"`, }, { procedure: "requestHeaders", headers: yarpc.NewHeaders().With("user-id", "42"), body: map[string]interface{}{}, encodedRequest: "{}", encodedResponse: "{}", want: map[string]interface{}{}, wantHeaders: yarpc.NewHeaders().With("success", "true"), }, } for _, tt := range tests { outbound := transporttest.NewMockUnaryOutbound(mockCtrl) client := New(channel.MultiOutbound(caller, service, transport.Outbounds{ Unary: outbound, })) if !tt.noCall { outbound.EXPECT().Call(gomock.Any(), transporttest.NewRequestMatcher(t, &transport.Request{ Caller: caller, Service: service, Procedure: tt.procedure, Encoding: Encoding, Headers: transport.Headers(tt.headers), Body: bytes.NewReader([]byte(tt.encodedRequest)), }), ).Return( &transport.Response{ Body: ioutil.NopCloser( bytes.NewReader([]byte(tt.encodedResponse))), Headers: 
transport.Headers(tt.wantHeaders), }, nil) } var wantType reflect.Type if tt.want != nil { wantType = reflect.TypeOf(tt.want) } else { require.NotNil(t, tt.wantType, "wantType is required if want is nil") wantType = tt.wantType } resBody := reflect.Zero(wantType).Interface() res, err := client.Call( ctx, yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers), tt.body, &resBody, ) if tt.wantErr != "" { if assert.Error(t, err) { assert.Contains(t, err.Error(), tt.wantErr) } } else { if assert.NoError(t, err) { assert.Equal(t, tt.wantHeaders, res.Headers()) assert.Equal(t, tt.want, resBody) } } } }
func TestCallOneway(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() ctx := context.Background() caller := "caller" service := "service" tests := []struct { procedure string headers yarpc.Headers body interface{} encodedRequest string // whether the outbound receives the request noCall bool wantErr string // error message }{ { procedure: "foo", body: []string{"foo", "bar"}, encodedRequest: `["foo","bar"]` + "\n", }, { procedure: "baz", body: func() {}, // funcs cannot be json.Marshal'ed noCall: true, wantErr: `failed to encode "json" request body for procedure "baz" of service "service"`, }, { procedure: "requestHeaders", headers: yarpc.NewHeaders().With("user-id", "42"), body: map[string]interface{}{}, encodedRequest: "{}\n", }, } for _, tt := range tests { outbound := transporttest.NewMockOnewayOutbound(mockCtrl) client := New(channel.MultiOutbound(caller, service, transport.Outbounds{ Oneway: outbound, })) if !tt.noCall { reqMatcher := transporttest.NewRequestMatcher(t, &transport.Request{ Caller: caller, Service: service, Procedure: tt.procedure, Encoding: Encoding, Headers: transport.Headers(tt.headers), Body: bytes.NewReader([]byte(tt.encodedRequest)), }) if tt.wantErr != "" { outbound. EXPECT(). CallOneway(gomock.Any(), reqMatcher). Return(nil, errors.New(tt.wantErr)) } else { outbound. EXPECT(). CallOneway(gomock.Any(), reqMatcher). Return(&successAck{}, nil) } } ack, err := client.CallOneway( ctx, yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers), tt.body) if tt.wantErr != "" { assert.Error(t, err) assert.Contains(t, err.Error(), tt.wantErr) } else { assert.NoError(t, err, "") assert.Equal(t, ack.String(), "success") } } }