예제 #1
0
// TestNopFilter verifies that wrapping a UnaryOutbound with
// transport.NopFilter forwards calls untouched: the inner outbound sees the
// same context and request, and its response is handed back unchanged.
func TestNopFilter(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	request := &transport.Request{
		Caller:    "somecaller",
		Service:   "someservice",
		Encoding:  raw.Encoding,
		Procedure: "hello",
		Body:      bytes.NewReader([]byte{1, 2, 3}),
	}
	response := &transport.Response{
		Body: ioutil.NopCloser(bytes.NewReader([]byte{4, 5, 6})),
	}

	inner := transporttest.NewMockUnaryOutbound(mockCtrl)
	inner.EXPECT().Call(ctx, request).Return(response, nil)

	filtered := transport.ApplyFilter(inner, transport.NopFilter)
	got, err := filtered.Call(ctx, request)
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, response, got)
}
예제 #2
0
// TestChain checks filter composition through Chain. The mocked outbound
// fails its first call and succeeds on the second, and retryFilter sits
// between two counting filters. The test expects the outer counter to see
// one call, the inner counter to see two, and the successful response to be
// returned.
func TestChain(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	req := &transport.Request{
		Caller:    "somecaller",
		Service:   "someservice",
		Encoding:  transport.Encoding("raw"),
		Procedure: "hello",
		Body:      bytes.NewReader([]byte{1, 2, 3}),
	}
	res := &transport.Response{
		Body: ioutil.NopCloser(bytes.NewReader([]byte{4, 5, 6})),
	}

	out := transporttest.NewMockUnaryOutbound(mockCtrl)
	// The second expectation is gated on the first, so the failing attempt is
	// always served before the successful one.
	failedAttempt := out.EXPECT().Call(ctx, req).Return(nil, errors.New("great sadness"))
	out.EXPECT().Call(ctx, req).After(failedAttempt).Return(res, nil)

	outer := &countFilter{}
	inner := &countFilter{}
	filtered := transport.ApplyFilter(out, Chain(outer, retryFilter, inner))
	gotRes, err := filtered.Call(ctx, req)

	assert.NoError(t, err, "expected success")
	assert.Equal(t, 1, outer.Count, "expected outer filter to be called once")
	assert.Equal(t, 2, inner.Count, "expected inner filter to be called twice")
	assert.Equal(t, res, gotRes, "expected response to match")
}
예제 #3
0
// TestCall drives the raw client's Call through a mocked UnaryOutbound. For
// each case it checks the request the outbound receives (caller, service,
// procedure, headers, encoding and body). It also checks the body, headers
// or error that Call returns.
//
// responseBody is written chunk by chunk into a testreader.ChunkReader that
// backs the mocked response. The "bar" case contains a nil chunk and
// expects Call to fail with "error set by user". Presumably the reader turns
// a nil chunk into that read error; confirm in testreader.
func TestCall(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	ctx := context.Background()

	caller := "caller"
	service := "service"

	tests := []struct {
		procedure    string
		headers      yarpc.Headers
		body         []byte
		responseBody [][]byte

		want        []byte
		wantErr     string
		wantHeaders yarpc.Headers
	}{
		{
			procedure:    "foo",
			body:         []byte{1, 2, 3},
			responseBody: [][]byte{{4}, {5}, {6}},
			want:         []byte{4, 5, 6},
		},
		{
			procedure:    "bar",
			body:         []byte{1, 2, 3},
			responseBody: [][]byte{{4}, {5}, nil, {6}},
			wantErr:      "error set by user",
		},
		{
			procedure:    "headers",
			headers:      yarpc.NewHeaders().With("x", "y"),
			body:         []byte{},
			responseBody: [][]byte{},
			want:         []byte{},
			wantHeaders:  yarpc.NewHeaders().With("a", "b"),
		},
	}

	for _, tt := range tests {
		outbound := transporttest.NewMockUnaryOutbound(mockCtrl)
		client := New(channel.MultiOutbound(caller, service,
			transport.Outbounds{
				Unary: outbound,
			}))

		writer, responseBody := testreader.ChunkReader()
		for _, chunk := range tt.responseBody {
			writer <- chunk
		}
		close(writer)

		outbound.EXPECT().Call(gomock.Any(),
			transporttest.NewRequestMatcher(t,
				&transport.Request{
					Caller:    caller,
					Service:   service,
					Procedure: tt.procedure,
					Headers:   transport.Headers(tt.headers),
					Encoding:  Encoding,
					Body:      bytes.NewReader(tt.body),
				}),
		).Return(
			&transport.Response{
				Body:    ioutil.NopCloser(responseBody),
				Headers: transport.Headers(tt.wantHeaders),
			}, nil)

		resBody, res, err := client.Call(
			ctx,
			yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers),
			tt.body)

		if tt.wantErr != "" {
			if assert.Error(t, err, "%v: expected failure", tt.procedure) {
				// testify's Equal takes (expected, actual); keeping that order
				// makes a mismatch report the values the right way round.
				assert.Equal(t, tt.wantErr, err.Error(), "%v: error mismatch", tt.procedure)
			}
			continue
		}
		if assert.NoError(t, err, "%v: expected success", tt.procedure) {
			assert.Equal(t, tt.want, resBody, "%v: response body mismatch", tt.procedure)
			assert.Equal(t, tt.wantHeaders, res.Headers(), "%v: response headers mismatch", tt.procedure)
		}
	}
}
예제 #4
0
// TestStartStopFailures builds a dispatcher from ten mock inbounds and ten
// mock unary outbounds per case. It checks how a failure from a single
// component's Start or Stop surfaces from dispatcher.Start or
// dispatcher.Stop.
//
// A mock whose Start fails is given no Stop expectation, so gomock flags an
// attempt to stop it. In the start-failure cases the loop never calls
// dispatcher.Stop, yet the remaining mocks still expect Stop. The
// expectations therefore require dispatcher.Start itself to stop the other
// components when it fails. All expectations share one controller and are
// verified by Finish when the test returns.
func TestStartStopFailures(t *testing.T) {
	tests := []struct {
		desc string

		inbounds  func(*gomock.Controller) []transport.Inbound
		outbounds func(*gomock.Controller) Outbounds

		wantStartErr string
		wantStopErr  string
	}{
		{
			desc: "all success",
			inbounds: func(mockCtrl *gomock.Controller) []transport.Inbound {
				inbounds := make([]transport.Inbound, 10)
				for i := range inbounds {
					in := transporttest.NewMockInbound(mockCtrl)
					in.EXPECT().Start(gomock.Any(), gomock.Any()).Return(nil)
					in.EXPECT().Stop().Return(nil)
					inbounds[i] = in
				}
				return inbounds
			},
			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
				outbounds := make(Outbounds, 10)
				for i := 0; i < 10; i++ {
					out := transporttest.NewMockUnaryOutbound(mockCtrl)
					out.EXPECT().Start(gomock.Any()).Return(nil)
					out.EXPECT().Stop().Return(nil)
					outbounds[fmt.Sprintf("service-%v", i)] =
						transport.Outbounds{
							Unary: out,
						}
				}
				return outbounds
			},
		},
		{
			desc: "inbound 6 start failure",
			inbounds: func(mockCtrl *gomock.Controller) []transport.Inbound {
				inbounds := make([]transport.Inbound, 10)
				for i := range inbounds {
					in := transporttest.NewMockInbound(mockCtrl)
					if i == 6 {
						in.EXPECT().Start(gomock.Any(), gomock.Any()).Return(errors.New("great sadness"))
					} else {
						in.EXPECT().Start(gomock.Any(), gomock.Any()).Return(nil)
						in.EXPECT().Stop().Return(nil)
					}
					inbounds[i] = in
				}
				return inbounds
			},
			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
				outbounds := make(Outbounds, 10)
				for i := 0; i < 10; i++ {
					out := transporttest.NewMockUnaryOutbound(mockCtrl)
					out.EXPECT().Start(gomock.Any()).Return(nil)
					out.EXPECT().Stop().Return(nil)
					outbounds[fmt.Sprintf("service-%v", i)] =
						transport.Outbounds{
							Unary: out,
						}
				}
				return outbounds
			},
			wantStartErr: "great sadness",
		},
		{
			desc: "inbound 7 stop failure",
			inbounds: func(mockCtrl *gomock.Controller) []transport.Inbound {
				inbounds := make([]transport.Inbound, 10)
				for i := range inbounds {
					in := transporttest.NewMockInbound(mockCtrl)
					in.EXPECT().Start(gomock.Any(), gomock.Any()).Return(nil)
					if i == 7 {
						in.EXPECT().Stop().Return(errors.New("great sadness"))
					} else {
						in.EXPECT().Stop().Return(nil)
					}
					inbounds[i] = in
				}
				return inbounds
			},
			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
				outbounds := make(Outbounds, 10)
				for i := 0; i < 10; i++ {
					out := transporttest.NewMockUnaryOutbound(mockCtrl)
					out.EXPECT().Start(gomock.Any()).Return(nil)
					out.EXPECT().Stop().Return(nil)
					outbounds[fmt.Sprintf("service-%v", i)] =
						transport.Outbounds{
							Unary: out,
						}
				}
				return outbounds
			},
			wantStopErr: "great sadness",
		},
		{
			desc: "outbound 5 start failure",
			inbounds: func(mockCtrl *gomock.Controller) []transport.Inbound {
				inbounds := make([]transport.Inbound, 10)
				for i := range inbounds {
					in := transporttest.NewMockInbound(mockCtrl)
					in.EXPECT().Start(gomock.Any(), gomock.Any()).Return(nil)
					in.EXPECT().Stop().Return(nil)
					inbounds[i] = in
				}
				return inbounds
			},
			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
				outbounds := make(Outbounds, 10)
				for i := 0; i < 10; i++ {
					out := transporttest.NewMockUnaryOutbound(mockCtrl)
					if i == 5 {
						out.EXPECT().Start(gomock.Any()).Return(errors.New("something went wrong"))
					} else {
						out.EXPECT().Start(gomock.Any()).Return(nil)
						out.EXPECT().Stop().Return(nil)
					}
					outbounds[fmt.Sprintf("service-%v", i)] =
						transport.Outbounds{
							Unary: out,
						}
				}
				return outbounds
			},
			wantStartErr: "something went wrong",
			// TODO: Include the name of the outbound in the error message
		},
		{
			// Every inbound succeeds here; the failure comes from the Stop of
			// outbound "service-7".
			desc: "outbound 7 stop failure",
			inbounds: func(mockCtrl *gomock.Controller) []transport.Inbound {
				inbounds := make([]transport.Inbound, 10)
				for i := range inbounds {
					in := transporttest.NewMockInbound(mockCtrl)
					in.EXPECT().Start(gomock.Any(), gomock.Any()).Return(nil)
					in.EXPECT().Stop().Return(nil)
					inbounds[i] = in
				}
				return inbounds
			},
			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
				outbounds := make(Outbounds, 10)
				for i := 0; i < 10; i++ {
					out := transporttest.NewMockUnaryOutbound(mockCtrl)
					out.EXPECT().Start(gomock.Any()).Return(nil)
					if i == 7 {
						out.EXPECT().Stop().Return(errors.New("something went wrong"))
					} else {
						out.EXPECT().Stop().Return(nil)
					}
					outbounds[fmt.Sprintf("service-%v", i)] =
						transport.Outbounds{
							Unary: out,
						}
				}
				return outbounds
			},
			wantStopErr: "something went wrong",
			// TODO: Include the name of the outbound in the error message
		},
	}

	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()
	for _, tt := range tests {
		dispatcher := NewDispatcher(Config{
			Name:      "test",
			Inbounds:  tt.inbounds(mockCtrl),
			Outbounds: tt.outbounds(mockCtrl),
		})

		err := dispatcher.Start()
		if tt.wantStartErr != "" {
			if assert.Error(t, err, "%v: expected Start() to fail", tt.desc) {
				assert.Contains(t, err.Error(), tt.wantStartErr, tt.desc)
			}
			// Start failed, so there is nothing for this loop to stop.
			continue
		}
		if !assert.NoError(t, err, "%v: expected Start() to succeed", tt.desc) {
			continue
		}

		err = dispatcher.Stop()
		if tt.wantStopErr == "" {
			assert.NoError(t, err, "%v: expected Stop() to succeed", tt.desc)
			continue
		}
		if assert.Error(t, err, "%v: expected Stop() to fail", tt.desc) {
			assert.Contains(t, err.Error(), tt.wantStopErr, tt.desc)
		}
	}
}
예제 #5
0
// TestClient exercises the Thrift client's Call path against a mocked
// Protocol and a mocked UnaryOutbound. Each case specifies:
//   - what the client should encode: wantRequestEnvelope when the
//     Enveloped option is used, wantRequestBody otherwise;
//   - whether the outbound should be called;
//   - what the protocol decodes from the response;
//   - which error, if any, Call should return (matched as a substring).
func TestClient(t *testing.T) {
	tests := []struct {
		desc                 string
		giveRequestBody      envelope.Enveloper // outgoing request body
		giveResponseEnvelope *wire.Envelope     // returned on DecodeEnveloped()
		giveResponseBody     *wire.Value        // return on Decode()
		clientOptions        []ClientOption

		expectCall          bool           // whether outbound.Call is expected
		wantRequestEnvelope *wire.Envelope // expect EncodeEnveloped(x)
		wantRequestBody     *wire.Value    // expect Encode(x)
		wantError           string         // whether an error is expected
	}{
		{
			desc:            "happy case",
			clientOptions:   []ClientOption{Enveloped},
			giveRequestBody: fakeEnveloper(wire.Call),
			wantRequestEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
			expectCall: true,
			giveResponseEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Reply,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
		},
		{
			desc:             "happy case without enveloping",
			giveRequestBody:  fakeEnveloper(wire.Call),
			wantRequestBody:  valueptr(wire.NewValueStruct(wire.Struct{})),
			expectCall:       true,
			giveResponseBody: valueptr(wire.NewValueStruct(wire.Struct{})),
		},
		{
			desc:            "wrong envelope type for request",
			clientOptions:   []ClientOption{Enveloped},
			giveRequestBody: fakeEnveloper(wire.Reply),
			wantError: `failed to encode "thrift" request body for procedure ` +
				`"MyService::someMethod" of service "service": unexpected envelope type: Reply`,
		},
		{
			desc:            "TApplicationException",
			clientOptions:   []ClientOption{Enveloped},
			giveRequestBody: fakeEnveloper(wire.Call),
			wantRequestEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
			expectCall: true,
			giveResponseEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Exception,
				Value: wire.NewValueStruct(wire.Struct{Fields: []wire.Field{
					{ID: 1, Value: wire.NewValueString("great sadness")},
					{ID: 2, Value: wire.NewValueI32(7)},
				}}),
			},
			wantError: `thrift request to procedure "MyService::someMethod" of ` +
				`service "service" encountered an internal failure: ` +
				"TApplicationException{Message: great sadness, Type: PROTOCOL_ERROR}",
		},
		{
			desc:            "wrong envelope type for response",
			clientOptions:   []ClientOption{Enveloped},
			giveRequestBody: fakeEnveloper(wire.Call),
			wantRequestEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
			expectCall: true,
			giveResponseEnvelope: &wire.Envelope{
				Name:  "someMethod",
				SeqID: 1,
				Type:  wire.Call,
				Value: wire.NewValueStruct(wire.Struct{}),
			},
			wantError: `failed to decode "thrift" response body for procedure ` +
				`"MyService::someMethod" of service "service": unexpected envelope type: Call`,
		},
	}

	for _, tt := range tests {
		// Each case runs inside its own function literal so that the deferred
		// mockCtrl.Finish and cancel fire when the case ends. A bare defer in
		// the loop body would postpone every Finish, and keep every context
		// alive, until TestClient itself returned.
		func() {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			proto := NewMockProtocol(mockCtrl)

			if tt.wantRequestEnvelope != nil {
				proto.EXPECT().EncodeEnveloped(*tt.wantRequestEnvelope, gomock.Any()).
					Do(func(_ wire.Envelope, w io.Writer) {
						_, err := w.Write([]byte("irrelevant"))
						require.NoError(t, err, "Write() failed")
					}).Return(nil)
			}

			if tt.wantRequestBody != nil {
				proto.EXPECT().Encode(*tt.wantRequestBody, gomock.Any()).
					Do(func(_ wire.Value, w io.Writer) {
						_, err := w.Write([]byte("irrelevant"))
						require.NoError(t, err, "Write() failed")
					}).Return(nil)
			}

			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()

			trans := transporttest.NewMockUnaryOutbound(mockCtrl)
			if tt.expectCall {
				trans.EXPECT().Call(ctx,
					transporttest.NewRequestMatcher(t, &transport.Request{
						Caller:    "caller",
						Service:   "service",
						Encoding:  Encoding,
						Procedure: "MyService::someMethod",
						Body:      bytes.NewReader([]byte("irrelevant")),
					}),
				).Return(&transport.Response{
					Body: ioutil.NopCloser(bytes.NewReader([]byte("irrelevant"))),
				}, nil)
			}

			if tt.giveResponseEnvelope != nil {
				proto.EXPECT().DecodeEnveloped(gomock.Any()).Return(*tt.giveResponseEnvelope, nil)
			}

			if tt.giveResponseBody != nil {
				proto.EXPECT().Decode(gomock.Any(), wire.TStruct).Return(*tt.giveResponseBody, nil)
			}

			// Build a fresh slice so appending the Protocol option can never
			// write into the backing array of the table's clientOptions.
			opts := make([]ClientOption, 0, len(tt.clientOptions)+1)
			opts = append(opts, tt.clientOptions...)
			opts = append(opts, Protocol(proto))
			c := New(Config{
				Service: "MyService",
				Channel: channel.MultiOutbound("caller", "service",
					transport.Outbounds{
						Unary: trans,
					}),
			}, opts...)

			_, _, err := c.Call(ctx, nil, tt.giveRequestBody)
			if tt.wantError != "" {
				if assert.Error(t, err, "%v: expected failure", tt.desc) {
					assert.Contains(t, err.Error(), tt.wantError, "%v: error mismatch", tt.desc)
				}
				return
			}
			assert.NoError(t, err, "%v: expected success", tt.desc)
		}()
	}
}
예제 #6
0
// TestInjectClientSuccess checks the behavior of yarpc.InjectClients on a
// variety of target structs. It must not panic. Tagged client fields it can
// build (the registered knownClient, and the json/raw clients) must come
// back non-nil. Fields that were already set must stay set, and fields that
// are untagged or unexported must stay nil.
func TestInjectClientSuccess(t *testing.T) {
	type unknownClient interface{}

	type knownClient interface{}
	// Register a builder for knownClient. The returned function is deferred,
	// presumably to unregister it. The local is not named "clear", so it does
	// not shadow the Go 1.21 builtin.
	clearBuilder := yarpc.RegisterClientBuilder(
		func(transport.Channel) knownClient { return knownClient(struct{}{}) })
	defer clearBuilder()

	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	tests := []struct {
		name   string
		target interface{}

		// list of services for which Channel() should return successfully
		knownServices []string

		// list of field names in target we expect to be nil or non-nil
		wantNil    []string
		wantNonNil []string
	}{
		{
			name:   "empty",
			target: &struct{}{},
		},
		{
			name: "unknown service non-nil",
			target: &struct {
				Client json.Client `service:"foo"`
			}{
				Client: json.New(channel.MultiOutbound(
					"foo",
					"bar",
					transport.Outbounds{
						Unary: transporttest.NewMockUnaryOutbound(mockCtrl),
					})),
			},
			wantNonNil: []string{"Client"},
		},
		{
			name: "unknown type untagged",
			target: &struct {
				Client unknownClient `notservice:"foo"`
			}{},
			wantNil: []string{"Client"},
		},
		{
			name: "unknown type non-nil",
			target: &struct {
				Client unknownClient `service:"foo"`
			}{Client: unknownClient(struct{}{})},
			wantNonNil: []string{"Client"},
		},
		{
			name:          "known type",
			knownServices: []string{"foo"},
			target: &struct {
				Client knownClient `service:"foo"`
			}{},
			wantNonNil: []string{"Client"},
		},
		{
			name:          "default encodings",
			knownServices: []string{"jsontest", "rawtest"},
			target: &struct {
				JSON json.Client `service:"jsontest"`
				Raw  raw.Client  `service:"rawtest"`
			}{},
			wantNonNil: []string{"JSON", "Raw"},
		},
		{
			name: "unexported field",
			target: &struct {
				rawClient raw.Client `service:"rawtest"`
			}{},
			wantNil: []string{"rawClient"},
		},
	}

	for _, tt := range tests {
		cp := newMockChannelProvier(mockCtrl, tt.knownServices...)
		assert.NotPanics(t, func() {
			yarpc.InjectClients(cp, tt.target)
		}, tt.name)

		target := reflect.ValueOf(tt.target).Elem()

		// FieldByName returns an invalid Value for an unknown name, and
		// IsNil panics on that. A table typo should fail cleanly instead.
		for _, fieldName := range tt.wantNil {
			field := target.FieldByName(fieldName)
			if assert.True(t, field.IsValid(), "%v: no field %q in target", tt.name, fieldName) {
				assert.True(t, field.IsNil(), "%v: expected %q to be nil", tt.name, fieldName)
			}
		}

		for _, fieldName := range tt.wantNonNil {
			field := target.FieldByName(fieldName)
			if assert.True(t, field.IsValid(), "%v: no field %q in target", tt.name, fieldName) {
				assert.False(t, field.IsNil(), "%v: expected %q to be non-nil", tt.name, fieldName)
			}
		}
	}
}
예제 #7
0
// TestCall exercises the JSON client's Call against a mocked UnaryOutbound.
// For each case it checks the request the outbound receives, including the
// JSON-encoded body and the headers. It also checks what Call returns: the
// decoded response body and response headers, or the error. Cases with
// noCall set register no outbound expectation, because the request body
// cannot be encoded.
func TestCall(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	ctx := context.Background()

	caller, service := "caller", "service"

	tests := []struct {
		procedure       string
		headers         yarpc.Headers
		body            interface{}
		encodedRequest  string
		encodedResponse string

		// whether the outbound receives the request
		noCall bool

		// Either want, or wantType and wantErr must be set.
		want        interface{} // expected response body
		wantHeaders yarpc.Headers
		wantType    reflect.Type // type of response body
		wantErr     string       // error message
	}{
		{
			procedure:       "foo",
			body:            []string{"foo", "bar"},
			encodedRequest:  `["foo","bar"]`,
			encodedResponse: `{"success": true}`,
			want:            map[string]interface{}{"success": true},
		},
		{
			procedure:       "bar",
			body:            []int{1, 2, 3},
			encodedRequest:  `[1,2,3]`,
			encodedResponse: `invalid JSON`,
			wantType:        _typeOfMapInterface,
			wantErr:         `failed to decode "json" response body for procedure "bar" of service "service"`,
		},
		{
			procedure: "baz",
			body:      func() {}, // funcs cannot be json.Marshal'ed
			noCall:    true,
			wantType:  _typeOfMapInterface,
			wantErr:   `failed to encode "json" request body for procedure "baz" of service "service"`,
		},
		{
			procedure:       "requestHeaders",
			headers:         yarpc.NewHeaders().With("user-id", "42"),
			body:            map[string]interface{}{},
			encodedRequest:  "{}",
			encodedResponse: "{}",
			want:            map[string]interface{}{},
			wantHeaders:     yarpc.NewHeaders().With("success", "true"),
		},
	}

	for _, tt := range tests {
		outbound := transporttest.NewMockUnaryOutbound(mockCtrl)
		client := New(channel.MultiOutbound(caller, service,
			transport.Outbounds{Unary: outbound}))

		if !tt.noCall {
			expectedReq := &transport.Request{
				Caller:    caller,
				Service:   service,
				Procedure: tt.procedure,
				Encoding:  Encoding,
				Headers:   transport.Headers(tt.headers),
				Body:      bytes.NewReader([]byte(tt.encodedRequest)),
			}
			stubRes := &transport.Response{
				Body:    ioutil.NopCloser(bytes.NewReader([]byte(tt.encodedResponse))),
				Headers: transport.Headers(tt.wantHeaders),
			}
			outbound.EXPECT().
				Call(gomock.Any(), transporttest.NewRequestMatcher(t, expectedReq)).
				Return(stubRes, nil)
		}

		// The response is decoded into a zero value of the expected type. That
		// type comes from want when given, and from wantType otherwise.
		wantType := tt.wantType
		if tt.want == nil {
			require.NotNil(t, tt.wantType, "wantType is required if want is nil")
		} else {
			wantType = reflect.TypeOf(tt.want)
		}
		resBody := reflect.Zero(wantType).Interface()

		reqMeta := yarpc.NewReqMeta().Procedure(tt.procedure).Headers(tt.headers)
		res, err := client.Call(ctx, reqMeta, tt.body, &resBody)

		if tt.wantErr == "" {
			if assert.NoError(t, err) {
				assert.Equal(t, tt.wantHeaders, res.Headers())
				assert.Equal(t, tt.want, resBody)
			}
			continue
		}
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), tt.wantErr)
		}
	}
}