func TestHandlerPanic(t *testing.T) { inbound := NewInbound("localhost:0") serverDispatcher := yarpc.NewDispatcher(yarpc.Config{ Name: "yarpc-test", Inbounds: []transport.Inbound{inbound}, }) serverDispatcher.Register([]transport.Registrant{ { Procedure: "panic", HandlerSpec: transport.NewUnaryHandlerSpec(panickedHandler{}), }, }) require.NoError(t, serverDispatcher.Start()) defer serverDispatcher.Stop() clientDispatcher := yarpc.NewDispatcher(yarpc.Config{ Name: "yarpc-test-client", Outbounds: yarpc.Outbounds{ "yarpc-test": { Unary: NewOutbound(fmt.Sprintf("http://%s", inbound.Addr().String())), }, }, }) require.NoError(t, clientDispatcher.Start()) defer clientDispatcher.Stop() client := raw.New(clientDispatcher.Channel("yarpc-test")) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, _, err := client.Call(ctx, yarpc.NewReqMeta().Procedure("panic"), []byte{}) assert.True(t, transport.IsUnexpectedError(err), "Must be an UnexpectedError") assert.Equal(t, `UnexpectedError: error for procedure "panic" of service "yarpc-test": panic: oops I panicked!`, err.Error()) }
func TestSimpleRoundTrip(t *testing.T) { transports := []roundTripTransport{ httpTransport{t}, tchannelTransport{t}, } tests := []struct { requestHeaders transport.Headers requestBody string responseHeaders transport.Headers responseBody string responseError error wantError func(error) }{ { requestHeaders: transport.NewHeaders().With("token", "1234"), requestBody: "world", responseHeaders: transport.NewHeaders().With("status", "ok"), responseBody: "hello, world", }, { requestBody: "foo", responseError: errors.HandlerUnexpectedError(fmt.Errorf("great sadness")), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err), err) assert.Equal(t, "UnexpectedError: great sadness", err.Error()) }, }, { requestBody: "bar", responseError: errors.HandlerBadRequestError(fmt.Errorf("missing service name")), wantError: func(err error) { assert.True(t, transport.IsBadRequestError(err)) assert.Equal(t, "BadRequest: missing service name", err.Error()) }, }, { requestBody: "baz", responseError: errors.RemoteUnexpectedError( `UnexpectedError: error for procedure "foo" of service "bar": great sadness`, ), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err)) assert.Equal(t, `UnexpectedError: error for procedure "hello" of service "testService": `+ `UnexpectedError: error for procedure "foo" of service "bar": great sadness`, err.Error()) }, }, { requestBody: "qux", responseError: errors.RemoteBadRequestError( `BadRequest: unrecognized procedure "echo" for service "derp"`, ), wantError: func(err error) { assert.True(t, transport.IsUnexpectedError(err)) assert.Equal(t, `UnexpectedError: error for procedure "hello" of service "testService": `+ `BadRequest: unrecognized procedure "echo" for service "derp"`, err.Error()) }, }, } rootCtx := context.Background() for _, tt := range tests { for _, trans := range transports { requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{ Caller: testCaller, Service: testService, Procedure: testProcedure, Encoding: raw.Encoding, Headers: tt.requestHeaders, Body: bytes.NewReader([]byte(tt.requestBody)), }) handler := unaryHandlerFunc(func(_ context.Context, r *transport.Request, w transport.ResponseWriter) error { assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r) if tt.responseError != nil { return tt.responseError } if tt.responseHeaders.Len() > 0 { w.AddHeaders(tt.responseHeaders) } _, err := w.Write([]byte(tt.responseBody)) assert.NoError(t, err, "failed to write response for %v", r) return err }) ctx, cancel := context.WithTimeout(rootCtx, 200*time.Millisecond) defer cancel() registry := staticRegistry{Handler: handler} trans.WithRegistry(registry, func(o transport.UnaryOutbound) { res, err := o.Call(ctx, &transport.Request{ Caller: testCaller, Service: testService, Procedure: testProcedure, Encoding: raw.Encoding, Headers: tt.requestHeaders, Body: bytes.NewReader([]byte(tt.requestBody)), }) if tt.wantError != nil { if assert.Error(t, err, "%T: expected error, got %v", trans, res) { tt.wantError(err) // none of the errors returned by Call can be valid // Handler errors. _, ok := err.(errors.HandlerError) assert.False(t, ok, "%T: %T must not be a HandlerError", trans, err) } } else { responseMatcher := transporttest.NewResponseMatcher(t, &transport.Response{ Headers: tt.responseHeaders, Body: ioutil.NopCloser(bytes.NewReader([]byte(tt.responseBody))), }) if assert.NoError(t, err, "%T: call failed", trans) { assert.True(t, responseMatcher.Matches(res), "%T: response mismatch", trans) } } }) } } }