func (h jsonHandler) Handle(ctx context.Context, treq *transport.Request, rw transport.ResponseWriter) error { if err := encoding.Expect(treq, Encoding); err != nil { return err } reqBody, err := h.reader.Read(json.NewDecoder(treq.Body)) if err != nil { return encoding.RequestBodyDecodeError(treq, err) } reqMeta := meta.FromTransportRequest(treq) results := h.handler.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqMeta), reqBody}) if err := results[2].Interface(); err != nil { return err.(error) } if resMeta, ok := results[1].Interface().(yarpc.ResMeta); ok { meta.ToTransportResponseWriter(resMeta, rw) } result := results[0].Interface() if err := json.NewEncoder(rw).Encode(result); err != nil { return encoding.ResponseBodyEncodeError(treq, err) } return nil }
func TestHandlerFailures(t *testing.T) { tests := []struct { desc string // context to use in the callm a default one is used otherwise. ctx context.Context ctxFunc func() (context.Context, context.CancelFunc) sendCall *fakeInboundCall expectCall func(*transporttest.MockUnaryHandler) wantErrors []string // error message contents wantStatus tchannel.SystemErrCode // expected status }{ { desc: "no timeout on context", ctx: context.Background(), sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: []byte{0x00}, }, wantErrors: []string{"timeout required"}, wantStatus: tchannel.ErrCodeBadRequest, }, { desc: "arg2 reader error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.Raw, arg2: nil, arg3: []byte{0x00}, }, wantErrors: []string{ `BadRequest: failed to decode "raw" request headers for`, `procedure "hello" of service "foo" from caller "bar"`, }, wantStatus: tchannel.ErrCodeBadRequest, }, { desc: "arg2 parse error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.JSON, arg2: []byte("{not valid JSON}"), arg3: []byte{0x00}, }, wantErrors: []string{ `BadRequest: failed to decode "json" request headers for`, `procedure "hello" of service "foo" from caller "bar"`, }, wantStatus: tchannel.ErrCodeBadRequest, }, { desc: "arg3 reader error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: nil, }, wantErrors: []string{ `UnexpectedError: error for procedure "hello" of service "foo"`, }, wantStatus: tchannel.ErrCodeUnexpected, }, { desc: "internal error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: []byte{0x00}, }, expectCall: func(h *transporttest.MockUnaryHandler) { h.EXPECT().Handle( transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second)), transporttest.NewRequestMatcher( t, &transport.Request{ Caller: "bar", Service: "foo", Encoding: raw.Encoding, Procedure: "hello", Body: bytes.NewReader([]byte{0x00}), }, ), gomock.Any(), ).Return(fmt.Errorf("great sadness")) }, wantErrors: []string{ `UnexpectedError: error for procedure "hello" of service "foo":`, "great sadness", }, wantStatus: tchannel.ErrCodeUnexpected, }, { desc: "arg3 encode error", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "hello", format: tchannel.JSON, arg2: []byte("{}"), arg3: []byte("{}"), }, expectCall: func(h *transporttest.MockUnaryHandler) { req := &transport.Request{ Caller: "bar", Service: "foo", Encoding: json.Encoding, Procedure: "hello", Body: bytes.NewReader([]byte("{}")), } h.EXPECT().Handle( transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second)), transporttest.NewRequestMatcher(t, req), gomock.Any(), ).Return( encoding.ResponseBodyEncodeError(req, errors.New( "serialization derp", ))) }, wantErrors: []string{ `UnexpectedError: failed to encode "json" response body for`, `procedure "hello" of service "foo" from caller "bar":`, `serialization derp`, }, wantStatus: tchannel.ErrCodeUnexpected, }, { desc: "handler timeout", ctxFunc: func() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), time.Millisecond) }, sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "waituntiltimeout", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: []byte{0x00}, }, expectCall: func(h *transporttest.MockUnaryHandler) { req := &transport.Request{ Service: "foo", Caller: "bar", Procedure: "waituntiltimeout", Encoding: raw.Encoding, Body: bytes.NewReader([]byte{0x00}), } h.EXPECT().Handle( transporttest.NewContextMatcher( t, transporttest.ContextTTL(time.Millisecond)), transporttest.NewRequestMatcher(t, req), gomock.Any(), ).Do(func(ctx context.Context, _ *transport.Request, _ transport.ResponseWriter) { <-ctx.Done() }).Return(context.DeadlineExceeded) }, wantErrors: []string{ `tchannel error ErrCodeTimeout: Timeout: call to procedure "waituntiltimeout" of service "foo" from caller "bar" timed out after `}, wantStatus: tchannel.ErrCodeTimeout, }, { desc: "handler panic", sendCall: &fakeInboundCall{ service: "foo", caller: "bar", method: "panic", format: tchannel.Raw, arg2: []byte{0x00, 0x00}, arg3: []byte{0x00}, }, expectCall: func(h *transporttest.MockUnaryHandler) { req := &transport.Request{ Service: "foo", Caller: "bar", Procedure: "panic", Encoding: raw.Encoding, Body: bytes.NewReader([]byte{0x00}), } h.EXPECT().Handle( transporttest.NewContextMatcher( t, transporttest.ContextTTL(time.Second)), transporttest.NewRequestMatcher(t, req), gomock.Any(), ).Do(func(context.Context, *transport.Request, transport.ResponseWriter) { panic("oops I panicked!") }) }, wantErrors: []string{ `UnexpectedError: error for procedure "panic" of service "foo": panic: oops I panicked!`, }, wantStatus: tchannel.ErrCodeUnexpected, }, } for _, tt := range tests { ctx, cancel := context.WithTimeout(context.Background(), time.Second) if tt.ctx != nil { ctx = tt.ctx } else if tt.ctxFunc != nil { ctx, cancel = tt.ctxFunc() } defer cancel() mockCtrl := gomock.NewController(t) thandler := transporttest.NewMockUnaryHandler(mockCtrl) spec := transport.NewUnaryHandlerSpec(thandler) if tt.expectCall != nil { tt.expectCall(thandler) } resp := newResponseRecorder() tt.sendCall.resp = resp registry := transporttest.NewMockRegistry(mockCtrl) registry.EXPECT().GetHandlerSpec(tt.sendCall.service, tt.sendCall.method). Return(spec, nil).AnyTimes() handler{Registry: registry}.handle(ctx, tt.sendCall) err := resp.systemErr require.Error(t, err, "expected error for %q", tt.desc) systemErr, isSystemErr := err.(tchannel.SystemError) require.True(t, isSystemErr, "expected %v for %q to be a system error", err, tt.desc) assert.Equal(t, tt.wantStatus, systemErr.Code(), tt.desc) for _, msg := range tt.wantErrors { assert.Contains( t, err.Error(), msg, "error should contain message for %q", tt.desc) } mockCtrl.Finish() } }
func (t thriftUnaryHandler) Handle(ctx context.Context, treq *transport.Request, rw transport.ResponseWriter) error { if err := encoding.Expect(treq, Encoding); err != nil { return err } body, err := ioutil.ReadAll(treq.Body) if err != nil { return err } // We disable enveloping if either the client or the transport requires it. proto := t.Protocol if !t.Enveloping { proto = disableEnvelopingProtocol{ Protocol: proto, Type: wire.Call, // we only decode requests } } envelope, err := proto.DecodeEnveloped(bytes.NewReader(body)) if err != nil { return encoding.RequestBodyDecodeError(treq, err) } if envelope.Type != wire.Call { return encoding.RequestBodyDecodeError( treq, errUnexpectedEnvelopeType(envelope.Type)) } reqMeta := meta.FromTransportRequest(treq) res, err := t.UnaryHandler.Handle(ctx, reqMeta, envelope.Value) if err != nil { return err } if resType := res.Body.EnvelopeType(); resType != wire.Reply { return encoding.ResponseBodyEncodeError( treq, errUnexpectedEnvelopeType(resType)) } value, err := res.Body.ToWire() if err != nil { return err } if res.IsApplicationError { rw.SetApplicationError() } resMeta := res.Meta if resMeta != nil { meta.ToTransportResponseWriter(resMeta, rw) } err = proto.EncodeEnveloped(wire.Envelope{ Name: res.Body.MethodName(), Type: res.Body.EnvelopeType(), SeqID: envelope.SeqID, Value: value, }, rw) if err != nil { return encoding.ResponseBodyEncodeError(treq, err) } return nil }