예제 #1
0
// TestOutboundHeaders verifies that application headers attached to an
// outgoing request are delivered to the HTTP server under header names
// carrying the "Rpc-Header-" prefix, with their values unchanged.
func TestOutboundHeaders(t *testing.T) {
	tests := []struct {
		desc    string
		context context.Context
		headers transport.Headers

		wantHeaders map[string]string
	}{
		{
			desc:    "application headers",
			headers: transport.NewHeaders().With("foo", "bar").With("baz", "Qux"),
			wantHeaders: map[string]string{
				"Rpc-Header-Foo": "bar",
				"Rpc-Header-Baz": "Qux",
			},
		},
	}

	for _, tt := range tests {
		// Run each case in its own function so the deferred server, outbound,
		// and context cleanup happen when the case finishes instead of piling
		// up until the whole test returns.
		func() {
			server := httptest.NewServer(http.HandlerFunc(
				func(w http.ResponseWriter, r *http.Request) {
					defer r.Body.Close()
					for k, v := range tt.wantHeaders {
						assert.Equal(
							t, v, r.Header.Get(k), "%v: header %v did not match", tt.desc, k)
					}
				},
			))
			defer server.Close()

			ctx := tt.context
			if ctx == nil {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(context.Background(), time.Second)
				defer cancel()
			}

			out := NewOutbound(server.URL)
			require.NoError(t, out.Start(transport.NoDeps), "%v: failed to start outbound", tt.desc)
			defer out.Stop()

			res, err := out.Call(ctx, &transport.Request{
				Caller:    "caller",
				Service:   "service",
				Encoding:  raw.Encoding,
				Headers:   tt.headers,
				Procedure: "hello",
				Body:      bytes.NewReader([]byte("world")),
			})
			if !assert.NoError(t, err, "%v: call failed", tt.desc) {
				return
			}

			// The format string needs tt.desc; without it the message would
			// render as "%!v(MISSING)".
			assert.NoError(t, res.Body.Close(), "%v: failed to close response body", tt.desc)
		}()
	}
}
예제 #2
0
func TestResponseWriterAddHeadersAfterWrite(t *testing.T) {
	call := &fakeInboundCall{format: tchannel.Raw, resp: newResponseRecorder()}
	w := newResponseWriter(new(transport.Request), call)
	w.Write([]byte("foo"))
	assert.Panics(t, func() {
		w.AddHeaders(transport.NewHeaders().With("foo", "bar"))
	})
}
예제 #3
0
// Headers returns a new randomized header.
// Headers returns a new randomized header.
//
// It performs between two and seven insertions, each using a freshly
// generated key and value from r.Atom (key generated first). If Atom happens
// to repeat a key, the later insertion presumably replaces the earlier one, so
// the final header count may be smaller.
func (r *randomGenerator) Headers() transport.Headers {
	result := transport.NewHeaders()
	for remaining := r.randsrc.Intn(6) + 2; remaining > 0; remaining-- {
		key := r.Atom()
		value := r.Atom()
		result = result.With(key, value)
	}
	return result
}
예제 #4
0
// TestHandleSuccessWithResponseHeaders checks that headers a JSON handler
// returns through yarpc.ResMeta end up on the transport response writer.
func TestHandleSuccessWithResponseHeaders(t *testing.T) {
	userHandler := func(ctx context.Context, r yarpc.ReqMeta, _ *simpleRequest) (*simpleResponse, yarpc.ResMeta, error) {
		meta := yarpc.NewResMeta().Headers(yarpc.NewHeaders().With("foo", "bar"))
		return &simpleResponse{Success: true}, meta, nil
	}

	jh := jsonHandler{
		reader:  structReader{reflect.TypeOf(simpleRequest{})},
		handler: reflect.ValueOf(userHandler),
	}

	writer := new(transporttest.FakeResponseWriter)
	req := &transport.Request{
		Procedure: "simpleCall",
		Encoding:  "json",
		Body:      jsonBody(`{"name": "foo", "attributes": {"bar": 42}}`),
	}
	require.NoError(t, jh.Handle(context.Background(), req, writer))

	want := transport.NewHeaders().With("foo", "bar")
	assert.Equal(t, want, writer.Headers)
}
예제 #5
0
// TestOutboundHeaders verifies that application headers on an outgoing
// request reach the TChannel server as encoded header bytes: a 2-byte pair
// count followed by 2-byte length-prefixed keys and values. Keys are expected
// to be lower-cased. The check is repeated for every outbound constructor in
// newOutbounds.
func TestOutboundHeaders(t *testing.T) {
	tests := []struct {
		context context.Context
		headers transport.Headers

		wantHeaders []byte
		wantError   string
	}{
		{
			headers: transport.NewHeaders().With("contextfoo", "bar"),
			wantHeaders: []byte{
				0x00, 0x01,
				0x00, 0x0A, 'c', 'o', 'n', 't', 'e', 'x', 't', 'f', 'o', 'o',
				0x00, 0x03, 'b', 'a', 'r',
			},
		},
		{
			headers: transport.NewHeaders().With("Foo", "bar"),
			wantHeaders: []byte{
				0x00, 0x01,
				0x00, 0x03, 'f', 'o', 'o',
				0x00, 0x03, 'b', 'a', 'r',
			},
		},
	}

	for _, tt := range tests {
		// Each case, and each outbound within it, runs in its own function so
		// that the deferred cleanups fire when that unit finishes. Without this,
		// servers, outbounds, timers, and response bodies accumulate until the
		// whole test returns.
		func() {
			server := testutils.NewServer(t, nil)
			defer server.Close()
			hostport := server.PeerInfo().HostPort

			server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
				func(ctx context.Context, call *tchannel.InboundCall) {
					headers, body, err := readArgs(call)
					if assert.NoError(t, err, "failed to read request") {
						assert.Equal(t, tt.wantHeaders, headers, "headers did not match")
						assert.Equal(t, []byte("world"), body)
					}

					err = writeArgs(call.Response(), []byte{0x00, 0x00}, []byte("bye!"))
					assert.NoError(t, err, "failed to write response")
				}))

			for _, getOutbound := range newOutbounds {
				func() {
					out := getOutbound(testutils.NewClient(t, &testutils.ChannelOpts{
						ServiceName: "caller",
					}), hostport)
					require.NoError(t, out.Start(transport.NoDeps), "failed to start outbound")
					defer out.Stop()

					ctx := tt.context
					if ctx == nil {
						ctx = context.Background()
					}
					ctx, cancel := context.WithTimeout(ctx, time.Second)
					defer cancel()

					res, err := out.Call(
						ctx,
						&transport.Request{
							Caller:    "caller",
							Service:   "service",
							Encoding:  raw.Encoding,
							Procedure: "hello",
							Headers:   tt.headers,
							Body:      bytes.NewReader([]byte("world")),
						},
					)
					if tt.wantError != "" {
						if assert.Error(t, err, "expected error") {
							assert.Contains(t, err.Error(), tt.wantError)
						}
						return
					}
					if assert.NoError(t, err, "call failed") {
						defer res.Body.Close()
					}
				}()
			}
		}()
	}
}
예제 #6
0
// TestRawHandler drives rawUnaryHandler with request bodies delivered in
// chunks. It checks four things:
//   - the wrapped handler sees the concatenated body;
//   - a body read error prevents the handler from being called;
//   - handler errors are returned;
//   - response headers and body are written to the response writer.
func TestRawHandler(t *testing.T) {
	// handler to use for test cases where the handler should not be called
	handlerNotCalled :=
		func(ctx context.Context, reqMeta yarpc.ReqMeta, body []byte) ([]byte, yarpc.ResMeta, error) {
			t.Errorf("unexpected call handle(%v, %v)", reqMeta, body)
			return nil, nil, fmt.Errorf("unexpected call handle(%v, %v)", reqMeta, body)
		}

	tests := []struct {
		procedure  string
		headers    transport.Headers
		bodyChunks [][]byte
		handler    UnaryHandler

		wantErr     string
		wantHeaders transport.Headers
		wantBody    []byte
	}{
		{
			procedure: "foo",
			bodyChunks: [][]byte{
				{1, 2, 3},
				{4, 5, 6},
			},
			handler: func(ctx context.Context, reqMeta yarpc.ReqMeta, body []byte) ([]byte, yarpc.ResMeta, error) {
				assert.Equal(t, "foo", reqMeta.Procedure())
				assert.Equal(t, []byte{1, 2, 3, 4, 5, 6}, body)
				return []byte("hello"), nil, nil
			},
			wantBody: []byte("hello"),
		},
		{
			procedure: "bar",
			bodyChunks: [][]byte{
				{1, 2, 3},
				nil, // triggers a read error
				{4, 5, 6},
			},
			handler: handlerNotCalled,
			wantErr: "error set by user",
			// TODO consistent error messages between languages
		},
		{
			procedure:  "baz",
			bodyChunks: [][]byte{},
			handler: func(ctx context.Context, reqMeta yarpc.ReqMeta, body []byte) ([]byte, yarpc.ResMeta, error) {
				assert.Equal(t, []byte{}, body)
				return nil, nil, fmt.Errorf("great sadness")
			},
			wantErr: "great sadness",
		},
		{
			procedure:  "responseHeaders",
			bodyChunks: [][]byte{},
			handler: func(ctx context.Context, reqMeta yarpc.ReqMeta, body []byte) ([]byte, yarpc.ResMeta, error) {
				resMeta := yarpc.NewResMeta().Headers(yarpc.NewHeaders().With("hello", "world"))
				return []byte{}, resMeta, nil
			},
			wantHeaders: transport.NewHeaders().With("hello", "world"),
		},
	}

	for _, tt := range tests {
		handler := rawUnaryHandler{tt.handler}
		resw := new(transporttest.FakeResponseWriter)

		writer, chunkReader := testreader.ChunkReader()
		for _, chunk := range tt.bodyChunks {
			writer <- chunk
		}
		close(writer)

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		err := handler.Handle(ctx, &transport.Request{
			Procedure: tt.procedure,
			Headers:   tt.headers,
			Encoding:  "raw",
			Body:      chunkReader,
		}, resw)
		// Handle is synchronous, so the context can be released right away.
		// A defer here would keep every iteration's timer alive until the
		// whole test returns.
		cancel()

		if tt.wantErr != "" {
			if assert.Error(t, err, "expected an error for %s", tt.procedure) {
				// testify's Equal takes (expected, actual); this order keeps
				// the failure output accurate.
				assert.Equal(t, tt.wantErr, err.Error(),
					"error does not match for %s", tt.procedure)
			}
			continue
		}

		if assert.NoError(t, err, "unexpected error for %s", tt.procedure) {
			assert.Equal(t, tt.wantHeaders, resw.Headers,
				"headers do not match for %s", tt.procedure)
			assert.Equal(t, tt.wantBody, resw.Body.Bytes(),
				"body does not match for %s", tt.procedure)
		}
	}
}
예제 #7
0
// TestSimpleRoundTripOneway sends oneway requests over the HTTP transport.
// It checks that the server handler receives the exact request, and that the
// client gets an ack without waiting for the deliberately slow handler to
// finish.
func TestSimpleRoundTripOneway(t *testing.T) {
	trans := httpTransport{t}

	cases := []struct {
		name           string
		requestHeaders transport.Headers
		requestBody    string
	}{
		{
			name:           "hello world",
			requestHeaders: transport.NewHeaders().With("foo", "bar"),
			requestBody:    "hello world",
		},
		{
			name:           "empty",
			requestHeaders: transport.NewHeaders(),
			requestBody:    "",
		},
	}

	rootCtx := context.Background()

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Build a fresh request each time: the body reader is consumed on use,
			// so the matcher and the outgoing call each need their own.
			makeRequest := func() *transport.Request {
				return &transport.Request{
					Caller:    testCaller,
					Service:   testService,
					Procedure: testProcedureOneway,
					Encoding:  raw.Encoding,
					Headers:   tc.requestHeaders,
					Body:      bytes.NewReader([]byte(tc.requestBody)),
				}
			}

			matcher := transporttest.NewRequestMatcher(t, makeRequest())
			finished := make(chan struct{})

			slowHandler := onewayHandlerFunc(func(_ context.Context, r *transport.Request) error {
				assert.True(t, matcher.Matches(r), "request mismatch: received %v", r)

				// Simulate server-side work; a oneway client must not block on it.
				time.Sleep(5 * time.Second)

				// Signal that the handler has completed. The client should have
				// returned long before this happens.
				close(finished)
				return nil
			})

			trans.WithRegistryOneway(staticRegistry{OnewayHandler: slowHandler}, func(o transport.OnewayOutbound) {
				ack, err := o.CallOneway(rootCtx, makeRequest())

				select {
				case <-finished:
					// The handler already finished, so the call must have waited for it.
					assert.Fail(t, "client waited for server handler to finish executing")
				default:
					// Handler still running: the client returned without waiting.
				}

				if !assert.NoError(t, err, "%T: oneway call failed for test '%v'", trans, tc.name) {
					return
				}
				assert.NotNil(t, ack)
			})
		})
	}
}
예제 #8
0
// TestSimpleRoundTrip performs a unary call over each transport (HTTP and
// TChannel) for every case. It checks one of two outcomes:
//   - the server receives the exact request and the client receives the
//     expected response headers and body; or
//   - the expected error reaches the client. That error must never be a
//     HandlerError.
func TestSimpleRoundTrip(t *testing.T) {
	transports := []roundTripTransport{
		httpTransport{t},
		tchannelTransport{t},
	}

	tests := []struct {
		requestHeaders  transport.Headers
		requestBody     string
		responseHeaders transport.Headers
		responseBody    string
		responseError   error

		wantError func(error)
	}{
		{
			requestHeaders:  transport.NewHeaders().With("token", "1234"),
			requestBody:     "world",
			responseHeaders: transport.NewHeaders().With("status", "ok"),
			responseBody:    "hello, world",
		},
		{
			requestBody:   "foo",
			responseError: errors.HandlerUnexpectedError(fmt.Errorf("great sadness")),
			wantError: func(err error) {
				assert.True(t, transport.IsUnexpectedError(err), err)
				assert.Equal(t, "UnexpectedError: great sadness", err.Error())
			},
		},
		{
			requestBody:   "bar",
			responseError: errors.HandlerBadRequestError(fmt.Errorf("missing service name")),
			wantError: func(err error) {
				assert.True(t, transport.IsBadRequestError(err))
				assert.Equal(t, "BadRequest: missing service name", err.Error())
			},
		},
		{
			requestBody: "baz",
			responseError: errors.RemoteUnexpectedError(
				`UnexpectedError: error for procedure "foo" of service "bar": great sadness`,
			),
			wantError: func(err error) {
				assert.True(t, transport.IsUnexpectedError(err))
				assert.Equal(t,
					`UnexpectedError: error for procedure "hello" of service "testService": `+
						`UnexpectedError: error for procedure "foo" of service "bar": great sadness`,
					err.Error())
			},
		},
		{
			requestBody: "qux",
			responseError: errors.RemoteBadRequestError(
				`BadRequest: unrecognized procedure "echo" for service "derp"`,
			),
			wantError: func(err error) {
				assert.True(t, transport.IsUnexpectedError(err))
				assert.Equal(t,
					`UnexpectedError: error for procedure "hello" of service "testService": `+
						`BadRequest: unrecognized procedure "echo" for service "derp"`,
					err.Error())
			},
		},
	}

	rootCtx := context.Background()
	for _, tt := range tests {
		for _, trans := range transports {
			requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{
				Caller:    testCaller,
				Service:   testService,
				Procedure: testProcedure,
				Encoding:  raw.Encoding,
				Headers:   tt.requestHeaders,
				Body:      bytes.NewReader([]byte(tt.requestBody)),
			})

			handler := unaryHandlerFunc(func(_ context.Context, r *transport.Request, w transport.ResponseWriter) error {
				assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r)

				if tt.responseError != nil {
					return tt.responseError
				}

				if tt.responseHeaders.Len() > 0 {
					w.AddHeaders(tt.responseHeaders)
				}

				_, err := w.Write([]byte(tt.responseBody))
				assert.NoError(t, err, "failed to write response for %v", r)
				return err
			})

			registry := staticRegistry{Handler: handler}
			trans.WithRegistry(registry, func(o transport.UnaryOutbound) {
				// Create the deadline inside the callback so that any setup
				// WithRegistry performs before invoking it does not eat into the
				// 200ms call budget. Cancelling here, rather than deferring in the
				// enclosing loop, releases the timer as soon as this case ends.
				ctx, cancel := context.WithTimeout(rootCtx, 200*time.Millisecond)
				defer cancel()

				res, err := o.Call(ctx, &transport.Request{
					Caller:    testCaller,
					Service:   testService,
					Procedure: testProcedure,
					Encoding:  raw.Encoding,
					Headers:   tt.requestHeaders,
					Body:      bytes.NewReader([]byte(tt.requestBody)),
				})

				if tt.wantError != nil {
					if assert.Error(t, err, "%T: expected error, got %v", trans, res) {
						tt.wantError(err)

						// none of the errors returned by Call can be valid
						// Handler errors.
						_, ok := err.(errors.HandlerError)
						assert.False(t, ok, "%T: %T must not be a HandlerError", trans, err)
					}
					return
				}

				responseMatcher := transporttest.NewResponseMatcher(t, &transport.Response{
					Headers: tt.responseHeaders,
					Body:    ioutil.NopCloser(bytes.NewReader([]byte(tt.responseBody))),
				})

				if assert.NoError(t, err, "%T: call failed", trans) {
					assert.True(t, responseMatcher.Matches(res), "%T: response mismatch", trans)
				}
			})
		}
	}
}
예제 #9
0
파일: outbound.go 프로젝트: yarpc/yarpc-go
// Call makes a HTTP request to a peer selected for treq and returns the
// response.
//
// It panics with errOutboundNotStarted if the outbound has not been started.
//
// On a 2xx status, the returned Response carries the application headers
// decoded from the HTTP response headers, plus the HTTP response body as-is.
// The caller is responsible for closing that body. Any other status is turned
// into an error by getErrFromResponse.
//
// If the HTTP round trip fails and ctx is done by then, ctx's error replaces
// the transport error. A deadline expiry is reported as a ClientTimeoutError
// carrying the time elapsed since Call began.
func (o *Outbound) Call(ctx context.Context, treq *transport.Request) (*transport.Response, error) {
	if !o.started.Load() {
		// panic because there's no recovery from this
		panic(errOutboundNotStarted)
	}
	start := time.Now()
	// NOTE(review): the ok result of Deadline is ignored. A ctx without a
	// deadline yields the zero time here, making ttl a large negative
	// duration. Confirm that callers always supply a deadline.
	deadline, _ := ctx.Deadline()
	ttl := deadline.Sub(start)

	peer, err := o.getPeerForRequest(ctx, treq)
	if err != nil {
		return nil, err
	}
	// Whatever StartRequest tracks on the peer ends when Call returns. For a
	// successful call, that is before the caller has read the response body.
	endRequest := peer.StartRequest()
	defer endRequest()

	req, err := o.createRequest(peer, treq)
	if err != nil {
		return nil, err
	}

	// Assigning req.Header replaces any headers createRequest may have set.
	// The tracing helper and withCoreHeaders run afterwards and presumably add
	// their own headers on top of these application headers.
	req.Header = applicationHeaders.ToHTTPHeaders(treq.Headers, nil)
	// ctx is replaced by the context returned alongside the span; that
	// context is the one attached to the outgoing request below.
	ctx, req, span := o.withOpentracingSpan(ctx, req, treq, start)
	defer span.Finish()
	req = o.withCoreHeaders(req, treq, ttl)

	client, err := o.getHTTPClient(peer)
	if err != nil {
		return nil, err
	}

	response, err := client.Do(req.WithContext(ctx))

	if err != nil {
		// Workaround borrowed from ctxhttp until
		// https://github.com/golang/go/issues/17711 is resolved.
		select {
		case <-ctx.Done():
			err = ctx.Err()
		default:
		}

		span.SetTag("error", true)
		span.LogEvent(err.Error())
		// Surface deadline expiry as a typed timeout error that names the
		// service/procedure and reports how long the call ran.
		if err == context.DeadlineExceeded {
			end := time.Now()
			return nil, errors.ClientTimeoutError(treq.Service, treq.Procedure, end.Sub(start))
		}

		return nil, err
	}

	span.SetTag("http.status_code", response.StatusCode)

	if response.StatusCode >= 200 && response.StatusCode < 300 {
		appHeaders := applicationHeaders.FromHTTPHeaders(
			response.Header, transport.NewHeaders())
		// The body is handed to the caller unread; closing it is their job.
		return &transport.Response{
			Headers: appHeaders,
			Body:    response.Body,
		}, nil
	}

	// NOTE(review): response.Body is not closed here. This presumably relies
	// on getErrFromResponse to consume/close it — verify.
	return nil, getErrFromResponse(response)
}