Beispiel #1
0
// TestCallOnewayFailure verifies that when the oneway outbound returns an
// error, client.CallOneway reports an error to its caller.
func TestCallOnewayFailure(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	const (
		caller    = "caller"
		service   = "service"
		procedure = "procedure"
	)
	payload := []byte{1, 2, 3}

	onewayOut := transporttest.NewMockOnewayOutbound(ctrl)
	outbounds := transport.Outbounds{Oneway: onewayOut}
	client := New(channel.MultiOutbound(caller, service, outbounds))

	// The outbound is expected to receive the request assembled by the client
	// and answers with a failure instead of an ack.
	wantReq := &transport.Request{
		Caller:    caller,
		Service:   service,
		Procedure: procedure,
		Encoding:  Encoding,
		Body:      bytes.NewReader(payload),
	}
	onewayOut.EXPECT().
		CallOneway(gomock.Any(), transporttest.NewRequestMatcher(t, wantReq)).
		Return(nil, errors.New("some error"))

	reqMeta := yarpc.NewReqMeta().Procedure(procedure)
	_, err := client.CallOneway(context.Background(), reqMeta, payload)

	assert.Error(t, err)
}
Beispiel #2
0
// TestHandlerInternalFailure checks that an error returned by the registered
// unary handler makes the HTTP handler answer with a 5xx status and a body
// carrying the UnexpectedError message.
func TestHandlerInternalFailure(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	hdr := http.Header{}
	hdr.Set(ServiceHeader, "fake")
	hdr.Set(ProcedureHeader, "hello")
	hdr.Set(CallerHeader, "somecaller")
	hdr.Set(EncodingHeader, "raw")
	hdr.Set(TTLMSHeader, "1000")

	req := &http.Request{
		Method: "POST",
		Header: hdr,
		Body:   ioutil.NopCloser(bytes.NewReader([]byte{})),
	}

	// The unary handler must be invoked with a one second TTL and the request
	// described by the headers above; it then fails.
	wantCtx := transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second))
	wantReq := transporttest.NewRequestMatcher(t, &transport.Request{
		Caller:    "somecaller",
		Service:   "fake",
		Encoding:  raw.Encoding,
		Procedure: "hello",
		Body:      bytes.NewReader([]byte{}),
	})

	unary := transporttest.NewMockUnaryHandler(ctrl)
	unary.EXPECT().
		Handle(wantCtx, wantReq, gomock.Any()).
		Return(fmt.Errorf("great sadness"))

	registry := transporttest.NewMockRegistry(ctrl)
	registry.EXPECT().
		GetHandlerSpec("fake", "hello").
		Return(transport.NewUnaryHandlerSpec(unary), nil)

	httpHandler := handler{Registry: registry}
	recorder := httptest.NewRecorder()
	httpHandler.ServeHTTP(recorder, req)

	status := recorder.Code
	assert.True(t, status >= 500 && status < 600, "expected 500 level response")

	wantBody := `UnexpectedError: error for procedure "hello" of service "fake": great sadness` + "\n"
	assert.Equal(t, wantBody, recorder.Body.String())
}
Beispiel #3
0
// TestHandlerSucces checks that the HTTP handler forwards a well-formed
// request to the registered unary handler and, when that handler returns nil
// without writing a body, responds with status 200 and an empty body.
func TestHandlerSucces(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	headers := make(http.Header)
	headers.Set(CallerHeader, "moe")
	headers.Set(EncodingHeader, "raw")
	headers.Set(TTLMSHeader, "1000")
	headers.Set(ProcedureHeader, "nyuck")
	headers.Set(ServiceHeader, "curly")

	registry := transporttest.NewMockRegistry(mockCtrl)
	rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
	spec := transport.NewUnaryHandlerSpec(rpcHandler)

	registry.EXPECT().GetHandlerSpec("curly", "nyuck").Return(spec, nil)

	// The unary handler must see the TTL from the TTL header (1000ms) and the
	// request reconstructed from the HTTP headers and body.
	rpcHandler.EXPECT().Handle(
		transporttest.NewContextMatcher(t,
			transporttest.ContextTTL(time.Second),
		),
		transporttest.NewRequestMatcher(
			t, &transport.Request{
				Caller:    "moe",
				Service:   "curly",
				Encoding:  raw.Encoding,
				Procedure: "nyuck",
				Body:      bytes.NewReader([]byte("Nyuck Nyuck")),
			},
		),
		gomock.Any(),
	).Return(nil)

	httpHandler := handler{Registry: registry}
	req := &http.Request{
		Method: "POST",
		Header: headers,
		Body:   ioutil.NopCloser(bytes.NewReader([]byte("Nyuck Nyuck"))),
	}
	rw := httptest.NewRecorder()
	httpHandler.ServeHTTP(rw, req)

	// assert.Equal takes (expected, actual); keeping that order makes the
	// failure output label the values correctly.
	assert.Equal(t, http.StatusOK, rw.Code, "expected 200 code")
	assert.Equal(t, "", rw.Body.String())
}
Beispiel #4
0
// TestCall exercises client.Call against a mocked unary outbound. Each case
// describes the request sent, the chunks served as the response body, and
// either the expected response body and headers or the expected error message.
func TestCall(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	ctx := context.Background()

	caller := "caller"
	service := "service"

	tests := []struct {
		procedure    string
		headers      yarpc.Headers
		body         []byte
		responseBody [][]byte

		want        []byte
		wantErr     string
		wantHeaders yarpc.Headers
	}{
		{
			procedure:    "foo",
			body:         []byte{1, 2, 3},
			responseBody: [][]byte{{4}, {5}, {6}},
			want:         []byte{4, 5, 6},
		},
		{
			procedure:    "bar",
			body:         []byte{1, 2, 3},
			responseBody: [][]byte{{4}, {5}, nil, {6}},
			wantErr:      "error set by user",
		},
		{
			procedure:    "headers",
			headers:      yarpc.NewHeaders().With("x", "y"),
			body:         []byte{},
			responseBody: [][]byte{},
			want:         []byte{},
			wantHeaders:  yarpc.NewHeaders().With("a", "b"),
		},
	}

	for _, tt := range tests {
		outbound := transporttest.NewMockUnaryOutbound(mockCtrl)
		client := New(channel.MultiOutbound(caller, service,
			transport.Outbounds{
				Unary: outbound,
			}))

		// Queue every response chunk up front, then close the writer so the
		// reader handed to the response has a defined end.
		writer, responseBody := testreader.ChunkReader()
		for _, chunk := range tt.responseBody {
			writer <- chunk
		}
		close(writer)

		outbound.EXPECT().Call(gomock.Any(),
			transporttest.NewRequestMatcher(t,
				&transport.Request{
					Caller:    caller,
					Service:   service,
					Procedure: tt.procedure,
					Headers:   transport.Headers(tt.headers),
					Encoding:  Encoding,
					Body:      bytes.NewReader(tt.body),
				}),
		).Return(
			&transport.Response{
				Body:    ioutil.NopCloser(responseBody),
				Headers: transport.Headers(tt.wantHeaders),
			}, nil)

		resBody, res, err := client.Call(
			ctx,
			yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers),
			tt.body)

		if tt.wantErr != "" {
			if assert.Error(t, err) {
				// assert.Equal takes (expected, actual).
				assert.Equal(t, tt.wantErr, err.Error())
			}
			continue
		}

		if assert.NoError(t, err) {
			assert.Equal(t, tt.want, resBody)
			assert.Equal(t, tt.wantHeaders, res.Headers())
		}
	}
}
Beispiel #5
0
// TestCallOneway exercises client.CallOneway against a mocked oneway outbound
// that always acknowledges the request, and checks that the ack is returned to
// the caller.
func TestCallOneway(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	ctx := context.Background()

	caller := "caller"
	service := "service"

	tests := []struct {
		procedure string
		headers   yarpc.Headers
		body      []byte

		wantErr string
	}{
		{
			procedure: "foo",
			body:      []byte{1, 2, 3},
		},
		{
			procedure: "headers",
			headers:   yarpc.NewHeaders().With("x", "y"),
			body:      []byte{},
		},
	}

	for _, tt := range tests {
		outbound := transporttest.NewMockOnewayOutbound(mockCtrl)
		client := New(channel.MultiOutbound(caller, service,
			transport.Outbounds{
				Oneway: outbound,
			}))

		outbound.EXPECT().CallOneway(gomock.Any(),
			transporttest.NewRequestMatcher(t,
				&transport.Request{
					Caller:    caller,
					Service:   service,
					Procedure: tt.procedure,
					Headers:   transport.Headers(tt.headers),
					Encoding:  Encoding,
					Body:      bytes.NewReader(tt.body),
				}),
		).Return(&successAck{}, nil)

		ack, err := client.CallOneway(
			ctx,
			yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers),
			tt.body)

		if tt.wantErr != "" {
			if assert.Error(t, err) {
				// assert.Equal takes (expected, actual).
				assert.Equal(t, tt.wantErr, err.Error())
			}
			continue
		}

		// Only inspect the ack once the call is known to have succeeded;
		// otherwise an unexpected failure would surface as a nil dereference
		// instead of a readable assertion failure.
		if assert.NoError(t, err) {
			assert.Equal(t, "success", ack.String())
		}
	}
}
Beispiel #6
0
// TestSimpleRoundTripOneway sends oneway requests through the HTTP transport
// and checks that the outbound call returns an ack without waiting for the
// (deliberately slow) server-side handler to finish.
func TestSimpleRoundTripOneway(t *testing.T) {
	trans := httpTransport{t}

	type roundTrip struct {
		name           string
		requestHeaders transport.Headers
		requestBody    string
	}

	tests := []roundTrip{
		{
			name:           "hello world",
			requestHeaders: transport.NewHeaders().With("foo", "bar"),
			requestBody:    "hello world",
		},
		{
			name:           "empty",
			requestHeaders: transport.NewHeaders(),
			requestBody:    "",
		},
	}

	rootCtx := context.Background()

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// newRequest builds a fresh request (with its own body reader) so
			// the matcher and the actual call never share a reader.
			newRequest := func() *transport.Request {
				return &transport.Request{
					Caller:    testCaller,
					Service:   testService,
					Procedure: testProcedureOneway,
					Encoding:  raw.Encoding,
					Headers:   tt.requestHeaders,
					Body:      bytes.NewReader([]byte(tt.requestBody)),
				}
			}

			requestMatcher := transporttest.NewRequestMatcher(t, newRequest())

			// Closed by the handler once it has finished its simulated work.
			handlerDone := make(chan struct{})

			onewayHandler := onewayHandlerFunc(func(_ context.Context, r *transport.Request) error {
				assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r)

				// Simulated server-side work; the client is expected not to
				// wait for it.
				time.Sleep(5 * time.Second)

				// Signal completion so the client side can detect whether it
				// (wrongly) waited for the handler.
				close(handlerDone)

				return nil
			})

			registry := staticRegistry{OnewayHandler: onewayHandler}

			trans.WithRegistryOneway(registry, func(o transport.OnewayOutbound) {
				ack, err := o.CallOneway(rootCtx, newRequest())

				// A closed channel at this point means the call only returned
				// after the handler completed.
				select {
				case <-handlerDone:
					assert.Fail(t, "client waited for server handler to finish executing")
				default:
				}

				if !assert.NoError(t, err, "%T: oneway call failed for test '%v'", trans, tt.name) {
					return
				}
				assert.NotNil(t, ack)
			})
		})
	}
}
Beispiel #7
0
// TestSimpleRoundTrip sends unary requests through every transport listed in
// transports and checks that the response (or error) produced by the handler
// arrives at the caller in the expected form.
func TestSimpleRoundTrip(t *testing.T) {
	transports := []roundTripTransport{
		httpTransport{t},
		tchannelTransport{t},
	}

	tests := []struct {
		requestHeaders  transport.Headers
		requestBody     string
		responseHeaders transport.Headers
		responseBody    string
		responseError   error

		wantError func(error)
	}{
		{
			requestHeaders:  transport.NewHeaders().With("token", "1234"),
			requestBody:     "world",
			responseHeaders: transport.NewHeaders().With("status", "ok"),
			responseBody:    "hello, world",
		},
		{
			requestBody:   "foo",
			responseError: errors.HandlerUnexpectedError(fmt.Errorf("great sadness")),
			wantError: func(err error) {
				assert.True(t, transport.IsUnexpectedError(err), err)
				assert.Equal(t, "UnexpectedError: great sadness", err.Error())
			},
		},
		{
			requestBody:   "bar",
			responseError: errors.HandlerBadRequestError(fmt.Errorf("missing service name")),
			wantError: func(err error) {
				assert.True(t, transport.IsBadRequestError(err))
				assert.Equal(t, "BadRequest: missing service name", err.Error())
			},
		},
		{
			requestBody: "baz",
			responseError: errors.RemoteUnexpectedError(
				`UnexpectedError: error for procedure "foo" of service "bar": great sadness`,
			),
			wantError: func(err error) {
				assert.True(t, transport.IsUnexpectedError(err))
				assert.Equal(t,
					`UnexpectedError: error for procedure "hello" of service "testService": `+
						`UnexpectedError: error for procedure "foo" of service "bar": great sadness`,
					err.Error())
			},
		},
		{
			requestBody: "qux",
			responseError: errors.RemoteBadRequestError(
				`BadRequest: unrecognized procedure "echo" for service "derp"`,
			),
			wantError: func(err error) {
				assert.True(t, transport.IsUnexpectedError(err))
				assert.Equal(t,
					`UnexpectedError: error for procedure "hello" of service "testService": `+
						`BadRequest: unrecognized procedure "echo" for service "derp"`,
					err.Error())
			},
		},
	}

	rootCtx := context.Background()
	for _, tt := range tests {
		for _, trans := range transports {
			// Each case/transport pair runs in its own function so that the
			// deferred cancel releases the timeout context at the end of the
			// iteration instead of piling up until the whole test returns.
			func() {
				requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{
					Caller:    testCaller,
					Service:   testService,
					Procedure: testProcedure,
					Encoding:  raw.Encoding,
					Headers:   tt.requestHeaders,
					Body:      bytes.NewReader([]byte(tt.requestBody)),
				})

				handler := unaryHandlerFunc(func(_ context.Context, r *transport.Request, w transport.ResponseWriter) error {
					assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r)

					if tt.responseError != nil {
						return tt.responseError
					}

					if tt.responseHeaders.Len() > 0 {
						w.AddHeaders(tt.responseHeaders)
					}

					_, err := w.Write([]byte(tt.responseBody))
					assert.NoError(t, err, "failed to write response for %v", r)
					return err
				})

				ctx, cancel := context.WithTimeout(rootCtx, 200*time.Millisecond)
				defer cancel()

				registry := staticRegistry{Handler: handler}
				trans.WithRegistry(registry, func(o transport.UnaryOutbound) {
					res, err := o.Call(ctx, &transport.Request{
						Caller:    testCaller,
						Service:   testService,
						Procedure: testProcedure,
						Encoding:  raw.Encoding,
						Headers:   tt.requestHeaders,
						Body:      bytes.NewReader([]byte(tt.requestBody)),
					})

					if tt.wantError != nil {
						if assert.Error(t, err, "%T: expected error, got %v", trans, res) {
							tt.wantError(err)

							// none of the errors returned by Call can be valid
							// Handler errors.
							_, ok := err.(errors.HandlerError)
							assert.False(t, ok, "%T: %T must not be a HandlerError", trans, err)
						}
						return
					}

					responseMatcher := transporttest.NewResponseMatcher(t, &transport.Response{
						Headers: tt.responseHeaders,
						Body:    ioutil.NopCloser(bytes.NewReader([]byte(tt.responseBody))),
					})

					if assert.NoError(t, err, "%T: call failed", trans) {
						assert.True(t, responseMatcher.Matches(res), "%T: response mismatch", trans)
					}
				})
			}()
		}
	}
}
Beispiel #8
0
// TestHandlerHeaders checks how the HTTP handler turns incoming HTTP headers
// into the TTL and application headers seen by the unary handler: the TTL
// header sets the context TTL, "Rpc-Header-"-prefixed headers are passed on
// with the prefix stripped, and other "Rpc-" headers are not passed on.
func TestHandlerHeaders(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	type headerCase struct {
		giveHeaders http.Header

		wantTTL     time.Duration
		wantHeaders map[string]string
	}

	cases := []headerCase{
		{
			giveHeaders: http.Header{
				TTLMSHeader:      {"1000"},
				"Rpc-Header-Foo": {"bar"},
			},
			wantTTL:     time.Second,
			wantHeaders: map[string]string{"foo": "bar"},
		},
		{
			giveHeaders: http.Header{
				TTLMSHeader: {"100"},
				"Rpc-Foo":   {"ignored"},
			},
			wantTTL:     100 * time.Millisecond,
			wantHeaders: map[string]string{},
		},
	}

	for _, tc := range cases {
		registry := transporttest.NewMockRegistry(ctrl)
		unary := transporttest.NewMockUnaryHandler(ctrl)

		registry.EXPECT().
			GetHandlerSpec("service", "hello").
			Return(transport.NewUnaryHandlerSpec(unary), nil)

		httpHandler := handler{Registry: registry}

		// The unary handler must observe the TTL and headers derived from the
		// incoming HTTP headers.
		wantCtx := transporttest.NewContextMatcher(t, transporttest.ContextTTL(tc.wantTTL))
		wantReq := transporttest.NewRequestMatcher(t, &transport.Request{
			Caller:    "caller",
			Service:   "service",
			Encoding:  raw.Encoding,
			Procedure: "hello",
			Headers:   transport.HeadersFromMap(tc.wantHeaders),
			Body:      bytes.NewReader([]byte("world")),
		})
		unary.EXPECT().Handle(wantCtx, wantReq, gomock.Any()).Return(nil)

		// Start from the case-specific headers, then add the ones every
		// request in this test carries.
		hdr := http.Header{}
		for key, values := range tc.giveHeaders {
			for _, value := range values {
				hdr.Add(key, value)
			}
		}
		hdr.Set(CallerHeader, "caller")
		hdr.Set(ServiceHeader, "service")
		hdr.Set(EncodingHeader, "raw")
		hdr.Set(ProcedureHeader, "hello")

		recorder := httptest.NewRecorder()
		httpHandler.ServeHTTP(recorder, &http.Request{
			Method: "POST",
			Header: hdr,
			Body:   ioutil.NopCloser(bytes.NewReader([]byte("world"))),
		})
		assert.Equal(t, 200, recorder.Code, "expected 200 status code")
	}
}
Beispiel #9
0
// TestClient exercises the Thrift client's Call against a mocked protocol and
// a mocked unary outbound. Each case states which encode/decode calls and
// outbound call are expected and, optionally, a substring the returned error
// must contain.
func TestClient(t *testing.T) {
	tests := []struct {
		desc                 string
		giveRequestBody      envelope.Enveloper // outgoing request body
		giveResponseEnvelope *wire.Envelope     // returned on DecodeEnveloped()
		giveResponseBody     *wire.Value        // return on Decode()
		clientOptions        []ClientOption

		expectCall          bool           // whether outbound.Call is expected
		wantRequestEnvelope *wire.Envelope // expect EncodeEnveloped(x)
		wantRequestBody     *wire.Value    // expect Encode(x)
		wantError           string         // whether an error is expected
	}{
		{
			desc:            "happy case",
			clientOptions:   []ClientOption{Enveloped},
			giveRequestBody: fakeEnveloper(wire.Call),
			wantRequestEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
			expectCall: true,
			giveResponseEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Reply,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
		},
		{
			desc:             "happy case without enveloping",
			giveRequestBody:  fakeEnveloper(wire.Call),
			wantRequestBody:  valueptr(wire.NewValueStruct(wire.Struct{})),
			expectCall:       true,
			giveResponseBody: valueptr(wire.NewValueStruct(wire.Struct{})),
		},
		{
			desc:            "wrong envelope type for request",
			clientOptions:   []ClientOption{Enveloped},
			giveRequestBody: fakeEnveloper(wire.Reply),
			wantError: `failed to encode "thrift" request body for procedure ` +
				`"MyService::someMethod" of service "service": unexpected envelope type: Reply`,
		},
		{
			desc:            "TApplicationException",
			clientOptions:   []ClientOption{Enveloped},
			giveRequestBody: fakeEnveloper(wire.Call),
			wantRequestEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
			expectCall: true,
			giveResponseEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Exception,
				Value: wire.NewValueStruct(wire.Struct{Fields: []wire.Field{
					{ID: 1, Value: wire.NewValueString("great sadness")},
					{ID: 2, Value: wire.NewValueI32(7)},
				}}),
			},
			wantError: `thrift request to procedure "MyService::someMethod" of ` +
				`service "service" encountered an internal failure: ` +
				"TApplicationException{Message: great sadness, Type: PROTOCOL_ERROR}",
		},
		{
			desc:            "wrong envelope type for response",
			clientOptions:   []ClientOption{Enveloped},
			giveRequestBody: fakeEnveloper(wire.Call),
			wantRequestEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
			expectCall: true,
			giveResponseEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
			wantError: `failed to decode "thrift" response body for procedure ` +
				`"MyService::someMethod" of service "service": unexpected envelope type: Call`,
		},
	}

	for _, tt := range tests {
		// Each case runs in its own function so that its mock controller is
		// verified and its context cancelled as soon as the case ends, rather
		// than all deferred calls running when the whole test returns.
		func() {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			proto := NewMockProtocol(mockCtrl)

			if tt.wantRequestEnvelope != nil {
				proto.EXPECT().EncodeEnveloped(*tt.wantRequestEnvelope, gomock.Any()).
					Do(func(_ wire.Envelope, w io.Writer) {
						_, err := w.Write([]byte("irrelevant"))
						require.NoError(t, err, "Write() failed")
					}).Return(nil)
			}

			if tt.wantRequestBody != nil {
				proto.EXPECT().Encode(*tt.wantRequestBody, gomock.Any()).
					Do(func(_ wire.Value, w io.Writer) {
						_, err := w.Write([]byte("irrelevant"))
						require.NoError(t, err, "Write() failed")
					}).Return(nil)
			}

			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()

			trans := transporttest.NewMockUnaryOutbound(mockCtrl)
			if tt.expectCall {
				trans.EXPECT().Call(ctx,
					transporttest.NewRequestMatcher(t, &transport.Request{
						Caller:    "caller",
						Service:   "service",
						Encoding:  Encoding,
						Procedure: "MyService::someMethod",
						Body:      bytes.NewReader([]byte("irrelevant")),
					}),
				).Return(&transport.Response{
					Body: ioutil.NopCloser(bytes.NewReader([]byte("irrelevant"))),
				}, nil)
			}

			if tt.giveResponseEnvelope != nil {
				proto.EXPECT().DecodeEnveloped(gomock.Any()).Return(*tt.giveResponseEnvelope, nil)
			}

			if tt.giveResponseBody != nil {
				proto.EXPECT().Decode(gomock.Any(), wire.TStruct).Return(*tt.giveResponseBody, nil)
			}

			opts := tt.clientOptions
			opts = append(opts, Protocol(proto))
			c := New(Config{
				Service: "MyService",
				Channel: channel.MultiOutbound("caller", "service",
					transport.Outbounds{
						Unary: trans,
					}),
			}, opts...)

			_, _, err := c.Call(ctx, nil, tt.giveRequestBody)
			if tt.wantError != "" {
				if assert.Error(t, err, "%v: expected failure", tt.desc) {
					assert.Contains(t, err.Error(), tt.wantError, "%v: error mismatch", tt.desc)
				}
				return
			}
			assert.NoError(t, err, "%v: expected success", tt.desc)
		}()
	}
}
Beispiel #10
0
// TestClientOneway exercises the Thrift client's CallOneway against a mocked
// protocol and a mocked oneway outbound. Each case states which encode call
// and outbound call are expected and, optionally, a substring the returned
// error must contain.
func TestClientOneway(t *testing.T) {
	caller, service, procedure := "caller", "MyService", "someMethod"

	tests := []struct {
		desc            string
		giveRequestBody envelope.Enveloper // outgoing request body
		clientOptions   []ClientOption

		expectCall          bool           // whether outbound.Call is expected
		wantRequestEnvelope *wire.Envelope // expect EncodeEnveloped(x)
		wantRequestBody     *wire.Value    // expect Encode(x)
		wantError           string         // whether an error is expected
	}{
		{
			desc:            "happy case",
			giveRequestBody: fakeEnveloper(wire.Call),
			clientOptions:   []ClientOption{Enveloped},

			expectCall: true,
			wantRequestEnvelope: &wire.Envelope{
				Name:  procedure,
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
		},
		{
			desc:            "happy case without enveloping",
			giveRequestBody: fakeEnveloper(wire.Call),

			expectCall:      true,
			wantRequestBody: valueptr(wire.NewValueStruct(wire.Struct{})),
		},
		{
			desc:            "wrong envelope type for request",
			giveRequestBody: fakeEnveloper(wire.Reply),
			clientOptions:   []ClientOption{Enveloped},

			wantError: `failed to encode "thrift" request body for procedure ` +
				`"MyService::someMethod" of service "MyService": unexpected envelope type: Reply`,
		},
	}

	for _, tt := range tests {
		// Each case runs in its own function so that its mock controller is
		// verified as soon as the case ends, rather than all deferred Finish
		// calls running when the whole test returns.
		func() {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			proto := NewMockProtocol(mockCtrl)
			bodyBytes := []byte("irrelevant")

			if tt.wantRequestEnvelope != nil {
				proto.EXPECT().EncodeEnveloped(*tt.wantRequestEnvelope, gomock.Any()).
					Do(func(_ wire.Envelope, w io.Writer) {
						_, err := w.Write(bodyBytes)
						require.NoError(t, err, "Write() failed")
					}).Return(nil)
			}

			if tt.wantRequestBody != nil {
				proto.EXPECT().Encode(*tt.wantRequestBody, gomock.Any()).
					Do(func(_ wire.Value, w io.Writer) {
						_, err := w.Write(bodyBytes)
						require.NoError(t, err, "Write() failed")
					}).Return(nil)
			}

			ctx := context.Background()

			onewayOutbound := transporttest.NewMockOnewayOutbound(mockCtrl)

			requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{
				Caller:    caller,
				Service:   service,
				Encoding:  Encoding,
				Procedure: procedureName(service, procedure),
				Body:      bytes.NewReader(bodyBytes),
			})

			if tt.expectCall {
				if tt.wantError != "" {
					onewayOutbound.
						EXPECT().
						CallOneway(ctx, requestMatcher).
						Return(nil, errors.New(tt.wantError))
				} else {
					onewayOutbound.
						EXPECT().
						CallOneway(ctx, requestMatcher).
						Return(&successAck{}, nil)
				}
			}
			opts := tt.clientOptions
			opts = append(opts, Protocol(proto))

			c := New(Config{
				Service: service,
				Channel: channel.MultiOutbound(caller, service,
					transport.Outbounds{
						Oneway: onewayOutbound,
					}),
			}, opts...)

			ack, err := c.CallOneway(ctx, nil, tt.giveRequestBody)
			if tt.wantError != "" {
				if assert.Error(t, err, "%v: expected failure", tt.desc) {
					assert.Contains(t, err.Error(), tt.wantError, "%v: error mismatch", tt.desc)
				}
				return
			}

			// Only inspect the ack once the call is known to have succeeded;
			// otherwise an unexpected failure would surface as a nil
			// dereference instead of a readable assertion failure.
			if assert.NoError(t, err, "%v: expected success", tt.desc) {
				assert.Equal(t, "success", ack.String())
			}
		}()
	}
}
Beispiel #11
0
// TestHandlerErrors feeds inbound calls with JSON- and Thrift-formatted
// headers to the handler and checks that the unary handler receives the
// decoded headers and that no system error is recorded on the response.
func TestHandlerErrors(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	tests := []struct {
		format  tchannel.Format
		headers []byte

		wantHeaders map[string]string
	}{
		{
			format:      tchannel.JSON,
			headers:     []byte(`{"Rpc-Header-Foo": "bar"}`),
			wantHeaders: map[string]string{"rpc-header-foo": "bar"},
		},
		{
			format: tchannel.Thrift,
			headers: []byte{
				0x00, 0x01, // 1 header
				0x00, 0x03, 'F', 'o', 'o', // Foo
				0x00, 0x03, 'B', 'a', 'r', // Bar
			},
			wantHeaders: map[string]string{"foo": "Bar"},
		},
	}

	for _, tt := range tests {
		// Each case runs in its own function so that the deferred cancel
		// releases the context at the end of the case instead of when the
		// whole test returns.
		func() {
			rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
			registry := transporttest.NewMockRegistry(mockCtrl)

			spec := transport.NewUnaryHandlerSpec(rpcHandler)
			tchHandler := handler{Registry: registry}

			registry.EXPECT().GetHandlerSpec("service", "hello").Return(spec, nil)

			rpcHandler.EXPECT().Handle(
				transporttest.NewContextMatcher(t),
				transporttest.NewRequestMatcher(t,
					&transport.Request{
						Caller:    "caller",
						Service:   "service",
						Headers:   transport.HeadersFromMap(tt.wantHeaders),
						Encoding:  transport.Encoding(tt.format),
						Procedure: "hello",
						Body:      bytes.NewReader([]byte("world")),
					}),
				gomock.Any(),
			).Return(nil)

			respRecorder := newResponseRecorder()

			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()
			tchHandler.handle(ctx, &fakeInboundCall{
				service: "service",
				caller:  "caller",
				format:  tt.format,
				method:  "hello",
				arg2:    tt.headers,
				arg3:    []byte("world"),
				resp:    respRecorder,
			})

			assert.NoError(t, respRecorder.systemErr, "did not expect an error")
		}()
	}
}
Beispiel #12
0
// TestHandlerFailures drives the handler with inbound calls that are expected
// to fail and checks that the recorded system error has the expected tchannel
// status code and contains every expected message fragment.
func TestHandlerFailures(t *testing.T) {
	tests := []struct {
		desc string

		// context to use in the call; a default one is used otherwise.
		ctx     context.Context
		ctxFunc func() (context.Context, context.CancelFunc)

		sendCall   *fakeInboundCall
		expectCall func(*transporttest.MockUnaryHandler)

		wantErrors []string               // error message contents
		wantStatus tchannel.SystemErrCode // expected status
	}{
		{
			desc: "no timeout on context",
			ctx:  context.Background(),
			sendCall: &fakeInboundCall{
				service: "foo",
				caller:  "bar",
				method:  "hello",
				format:  tchannel.Raw,
				arg2:    []byte{0x00, 0x00},
				arg3:    []byte{0x00},
			},
			wantErrors: []string{"timeout required"},
			wantStatus: tchannel.ErrCodeBadRequest,
		},
		{
			desc: "arg2 reader error",
			sendCall: &fakeInboundCall{
				service: "foo",
				caller:  "bar",
				method:  "hello",
				format:  tchannel.Raw,
				arg2:    nil,
				arg3:    []byte{0x00},
			},
			wantErrors: []string{
				`BadRequest: failed to decode "raw" request headers for`,
				`procedure "hello" of service "foo" from caller "bar"`,
			},
			wantStatus: tchannel.ErrCodeBadRequest,
		},
		{
			desc: "arg2 parse error",
			sendCall: &fakeInboundCall{
				service: "foo",
				caller:  "bar",
				method:  "hello",
				format:  tchannel.JSON,
				arg2:    []byte("{not valid JSON}"),
				arg3:    []byte{0x00},
			},
			wantErrors: []string{
				`BadRequest: failed to decode "json" request headers for`,
				`procedure "hello" of service "foo" from caller "bar"`,
			},
			wantStatus: tchannel.ErrCodeBadRequest,
		},
		{
			desc: "arg3 reader error",
			sendCall: &fakeInboundCall{
				service: "foo",
				caller:  "bar",
				method:  "hello",
				format:  tchannel.Raw,
				arg2:    []byte{0x00, 0x00},
				arg3:    nil,
			},
			wantErrors: []string{
				`UnexpectedError: error for procedure "hello" of service "foo"`,
			},
			wantStatus: tchannel.ErrCodeUnexpected,
		},
		{
			desc: "internal error",
			sendCall: &fakeInboundCall{
				service: "foo",
				caller:  "bar",
				method:  "hello",
				format:  tchannel.Raw,
				arg2:    []byte{0x00, 0x00},
				arg3:    []byte{0x00},
			},
			expectCall: func(h *transporttest.MockUnaryHandler) {
				h.EXPECT().Handle(
					transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second)),
					transporttest.NewRequestMatcher(
						t, &transport.Request{
							Caller:    "bar",
							Service:   "foo",
							Encoding:  raw.Encoding,
							Procedure: "hello",
							Body:      bytes.NewReader([]byte{0x00}),
						},
					), gomock.Any(),
				).Return(fmt.Errorf("great sadness"))
			},
			wantErrors: []string{
				`UnexpectedError: error for procedure "hello" of service "foo":`,
				"great sadness",
			},
			wantStatus: tchannel.ErrCodeUnexpected,
		},
		{
			desc: "arg3 encode error",
			sendCall: &fakeInboundCall{
				service: "foo",
				caller:  "bar",
				method:  "hello",
				format:  tchannel.JSON,
				arg2:    []byte("{}"),
				arg3:    []byte("{}"),
			},
			expectCall: func(h *transporttest.MockUnaryHandler) {
				req := &transport.Request{
					Caller:    "bar",
					Service:   "foo",
					Encoding:  json.Encoding,
					Procedure: "hello",
					Body:      bytes.NewReader([]byte("{}")),
				}
				h.EXPECT().Handle(
					transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second)),
					transporttest.NewRequestMatcher(t, req),
					gomock.Any(),
				).Return(
					encoding.ResponseBodyEncodeError(req, errors.New(
						"serialization derp",
					)))
			},
			wantErrors: []string{
				`UnexpectedError: failed to encode "json" response body for`,
				`procedure "hello" of service "foo" from caller "bar":`,
				`serialization derp`,
			},
			wantStatus: tchannel.ErrCodeUnexpected,
		},
		{
			desc: "handler timeout",
			ctxFunc: func() (context.Context, context.CancelFunc) {
				return context.WithTimeout(context.Background(), time.Millisecond)
			},
			sendCall: &fakeInboundCall{
				service: "foo",
				caller:  "bar",
				method:  "waituntiltimeout",
				format:  tchannel.Raw,
				arg2:    []byte{0x00, 0x00},
				arg3:    []byte{0x00},
			},
			expectCall: func(h *transporttest.MockUnaryHandler) {
				req := &transport.Request{
					Service:   "foo",
					Caller:    "bar",
					Procedure: "waituntiltimeout",
					Encoding:  raw.Encoding,
					Body:      bytes.NewReader([]byte{0x00}),
				}
				h.EXPECT().Handle(
					transporttest.NewContextMatcher(
						t, transporttest.ContextTTL(time.Millisecond)),
					transporttest.NewRequestMatcher(t, req),
					gomock.Any(),
				).Do(func(ctx context.Context, _ *transport.Request, _ transport.ResponseWriter) {
					// Block until the one millisecond deadline fires.
					<-ctx.Done()
				}).Return(context.DeadlineExceeded)
			},
			wantErrors: []string{
				`tchannel error ErrCodeTimeout: Timeout: call to procedure "waituntiltimeout" of service "foo" from caller "bar" timed out after `},
			wantStatus: tchannel.ErrCodeTimeout,
		},
		{
			desc: "handler panic",
			sendCall: &fakeInboundCall{
				service: "foo",
				caller:  "bar",
				method:  "panic",
				format:  tchannel.Raw,
				arg2:    []byte{0x00, 0x00},
				arg3:    []byte{0x00},
			},
			expectCall: func(h *transporttest.MockUnaryHandler) {
				req := &transport.Request{
					Service:   "foo",
					Caller:    "bar",
					Procedure: "panic",
					Encoding:  raw.Encoding,
					Body:      bytes.NewReader([]byte{0x00}),
				}
				h.EXPECT().Handle(
					transporttest.NewContextMatcher(
						t, transporttest.ContextTTL(time.Second)),
					transporttest.NewRequestMatcher(t, req),
					gomock.Any(),
				).Do(func(context.Context, *transport.Request, transport.ResponseWriter) {
					panic("oops I panicked!")
				})
			},
			wantErrors: []string{
				`UnexpectedError: error for procedure "panic" of service "foo": panic: oops I panicked!`,
			},
			wantStatus: tchannel.ErrCodeUnexpected,
		},
	}

	for _, tt := range tests {
		// Each case runs in its own function so that its context is cancelled
		// and its mock controller verified when the case ends, including when
		// a require assertion aborts the test.
		func() {
			// Build exactly one context per case so that no cancel function is
			// created and then overwritten without being called.
			var (
				ctx    context.Context
				cancel context.CancelFunc = func() {}
			)
			switch {
			case tt.ctx != nil:
				ctx = tt.ctx
			case tt.ctxFunc != nil:
				ctx, cancel = tt.ctxFunc()
			default:
				ctx, cancel = context.WithTimeout(context.Background(), time.Second)
			}
			defer cancel()

			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			thandler := transporttest.NewMockUnaryHandler(mockCtrl)
			spec := transport.NewUnaryHandlerSpec(thandler)

			if tt.expectCall != nil {
				tt.expectCall(thandler)
			}

			resp := newResponseRecorder()
			tt.sendCall.resp = resp

			// AnyTimes: some cases fail before the registry is consulted.
			registry := transporttest.NewMockRegistry(mockCtrl)
			registry.EXPECT().GetHandlerSpec(tt.sendCall.service, tt.sendCall.method).
				Return(spec, nil).AnyTimes()

			handler{Registry: registry}.handle(ctx, tt.sendCall)
			err := resp.systemErr
			require.Error(t, err, "expected error for %q", tt.desc)

			systemErr, isSystemErr := err.(tchannel.SystemError)
			require.True(t, isSystemErr, "expected %v for %q to be a system error", err, tt.desc)
			assert.Equal(t, tt.wantStatus, systemErr.Code(), tt.desc)

			for _, msg := range tt.wantErrors {
				assert.Contains(
					t, err.Error(), msg,
					"error should contain message for %q", tt.desc)
			}
		}()
	}
}
Beispiel #13
0
// TestCall runs a table of cases through client.Call against a mocked unary
// outbound. Each case either succeeds, in which case the decoded response
// body and the response headers are compared with the expectations, or fails,
// in which case the error text must contain the expected message.
func TestCall(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	const (
		caller  = "caller"
		service = "service"
	)

	ctx := context.Background()

	tests := []struct {
		procedure       string
		headers         yarpc.Headers
		body            interface{}
		encodedRequest  string
		encodedResponse string

		// noCall is true when the request must never reach the outbound.
		noCall bool

		// Set either want, or both wantType and wantErr.
		want        interface{}   // expected response body
		wantHeaders yarpc.Headers // expected response headers
		wantType    reflect.Type  // type of response body
		wantErr     string        // substring of the expected error message
	}{
		{
			procedure:       "foo",
			body:            []string{"foo", "bar"},
			encodedRequest:  `["foo","bar"]`,
			encodedResponse: `{"success": true}`,
			want:            map[string]interface{}{"success": true},
		},
		{
			procedure:       "bar",
			body:            []int{1, 2, 3},
			encodedRequest:  `[1,2,3]`,
			encodedResponse: `invalid JSON`,
			wantType:        _typeOfMapInterface,
			wantErr:         `failed to decode "json" response body for procedure "bar" of service "service"`,
		},
		{
			procedure: "baz",
			body:      func() {}, // funcs cannot be json.Marshal'ed
			noCall:    true,
			wantType:  _typeOfMapInterface,
			wantErr:   `failed to encode "json" request body for procedure "baz" of service "service"`,
		},
		{
			procedure:       "requestHeaders",
			headers:         yarpc.NewHeaders().With("user-id", "42"),
			body:            map[string]interface{}{},
			encodedRequest:  "{}",
			encodedResponse: "{}",
			want:            map[string]interface{}{},
			wantHeaders:     yarpc.NewHeaders().With("success", "true"),
		},
	}

	for _, tc := range tests {
		outbound := transporttest.NewMockUnaryOutbound(mockCtrl)
		client := New(channel.MultiOutbound(caller, service,
			transport.Outbounds{Unary: outbound}))

		// Only register an expectation when the request is supposed to make
		// it to the outbound; otherwise any call fails the mock controller.
		if !tc.noCall {
			wantReq := &transport.Request{
				Caller:    caller,
				Service:   service,
				Procedure: tc.procedure,
				Encoding:  Encoding,
				Headers:   transport.Headers(tc.headers),
				Body:      bytes.NewReader([]byte(tc.encodedRequest)),
			}
			cannedRes := &transport.Response{
				Body: ioutil.NopCloser(
					bytes.NewReader([]byte(tc.encodedResponse))),
				Headers: transport.Headers(tc.wantHeaders),
			}
			outbound.EXPECT().
				Call(gomock.Any(), transporttest.NewRequestMatcher(t, wantReq)).
				Return(cannedRes, nil)
		}

		// The response is decoded into a zero value of the expected body's
		// type, taken from want when present and from wantType otherwise.
		var bodyType reflect.Type
		switch {
		case tc.want != nil:
			bodyType = reflect.TypeOf(tc.want)
		default:
			require.NotNil(t, tc.wantType, "wantType is required if want is nil")
			bodyType = tc.wantType
		}
		got := reflect.Zero(bodyType).Interface()

		reqMeta := yarpc.NewReqMeta().Procedure(tc.procedure).Headers(tc.headers)
		res, err := client.Call(ctx, reqMeta, tc.body, &got)

		if tc.wantErr != "" {
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), tc.wantErr)
			}
			continue
		}

		if !assert.NoError(t, err) {
			continue
		}
		assert.Equal(t, tc.wantHeaders, res.Headers())
		assert.Equal(t, tc.want, got)
	}
}
Beispiel #14
0
// TestCallOneway runs a table of cases through client.CallOneway against a
// mocked oneway outbound. Successful cases must return an ack whose String()
// is "success"; failing cases must return an error containing wantErr.
func TestCallOneway(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	ctx := context.Background()

	caller := "caller"
	service := "service"

	tests := []struct {
		procedure      string
		headers        yarpc.Headers
		body           interface{}
		encodedRequest string

		// whether the outbound receives the request
		noCall bool

		wantErr string // error message
	}{
		{
			procedure:      "foo",
			body:           []string{"foo", "bar"},
			encodedRequest: `["foo","bar"]` + "\n",
		},
		{
			procedure: "baz",
			body:      func() {}, // funcs cannot be json.Marshal'ed
			noCall:    true,
			wantErr:   `failed to encode "json" request body for procedure "baz" of service "service"`,
		},
		{
			procedure:      "requestHeaders",
			headers:        yarpc.NewHeaders().With("user-id", "42"),
			body:           map[string]interface{}{},
			encodedRequest: "{}\n",
		},
	}

	for _, tt := range tests {
		outbound := transporttest.NewMockOnewayOutbound(mockCtrl)
		client := New(channel.MultiOutbound(caller, service,
			transport.Outbounds{
				Oneway: outbound,
			}))

		// Only register an expectation when the request is supposed to reach
		// the outbound; an unexpected call then fails the mock controller.
		if !tt.noCall {
			reqMatcher := transporttest.NewRequestMatcher(t,
				&transport.Request{
					Caller:    caller,
					Service:   service,
					Procedure: tt.procedure,
					Encoding:  Encoding,
					Headers:   transport.Headers(tt.headers),
					Body:      bytes.NewReader([]byte(tt.encodedRequest)),
				})

			if tt.wantErr != "" {
				outbound.
					EXPECT().
					CallOneway(gomock.Any(), reqMatcher).
					Return(nil, errors.New(tt.wantErr))
			} else {
				outbound.
					EXPECT().
					CallOneway(gomock.Any(), reqMatcher).
					Return(&successAck{}, nil)
			}
		}

		ack, err := client.CallOneway(
			ctx,
			yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers),
			tt.body)

		if tt.wantErr != "" {
			// assert.Error does not stop the test, so only inspect the
			// message when an error was actually returned; calling
			// err.Error() on a nil error would panic.
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), tt.wantErr)
			}
			continue
		}

		// Likewise, only touch ack once the call is known to have succeeded;
		// on failure ack may be nil.
		if assert.NoError(t, err) {
			assert.Equal(t, "success", ack.String())
		}
	}
}