예제 #1
0
// TestChain checks how Chain nests interceptors around retryInterceptor.
//
// The mock handler fails its first Handle call and succeeds on the second.
// The interceptor placed before retryInterceptor must be invoked once. The
// one placed after it must be invoked for both attempts, and the overall
// call must succeed.
func TestChain(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	writer := new(transporttest.FakeResponseWriter)
	treq := &transport.Request{
		Caller:    "somecaller",
		Service:   "someservice",
		Encoding:  transport.Encoding("raw"),
		Procedure: "hello",
		Body:      bytes.NewReader([]byte{1, 2, 3}),
	}

	// First attempt fails; the second expectation may only match after it.
	mockHandler := transporttest.NewMockUnaryHandler(ctrl)
	firstAttempt := mockHandler.EXPECT().Handle(ctx, treq, writer).Return(errors.New("great sadness"))
	mockHandler.EXPECT().Handle(ctx, treq, writer).After(firstAttempt).Return(nil)

	outer, inner := &countInterceptor{}, &countInterceptor{}
	chained := transport.ApplyInterceptor(mockHandler, Chain(outer, retryInterceptor, inner))
	err := chained.Handle(ctx, treq, writer)

	assert.NoError(t, err, "expected success")
	assert.Equal(t, 1, outer.Count, "expected outer interceptor to be called once")
	assert.Equal(t, 2, inner.Count, "expected inner interceptor to be called twice")
}
예제 #2
0
// TestChain checks how Chain nests filters around retryFilter.
//
// The mock outbound fails its first Call and returns a response on the
// second. The filter placed before retryFilter must be invoked once. The one
// placed after it must be invoked for both attempts, and the successful
// response must be passed back unchanged.
func TestChain(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	treq := &transport.Request{
		Caller:    "somecaller",
		Service:   "someservice",
		Encoding:  transport.Encoding("raw"),
		Procedure: "hello",
		Body:      bytes.NewReader([]byte{1, 2, 3}),
	}
	wantRes := &transport.Response{
		Body: ioutil.NopCloser(bytes.NewReader([]byte{4, 5, 6})),
	}

	// First attempt fails; the second expectation may only match after it.
	mockOutbound := transporttest.NewMockUnaryOutbound(ctrl)
	firstAttempt := mockOutbound.EXPECT().Call(ctx, treq).Return(nil, errors.New("great sadness"))
	mockOutbound.EXPECT().Call(ctx, treq).After(firstAttempt).Return(wantRes, nil)

	outer, inner := &countFilter{}, &countFilter{}
	chained := transport.ApplyFilter(mockOutbound, Chain(outer, retryFilter, inner))
	gotRes, err := chained.Call(ctx, treq)

	assert.NoError(t, err, "expected success")
	assert.Equal(t, 1, outer.Count, "expected outer filter to be called once")
	assert.Equal(t, 2, inner.Count, "expected inner filter to be called twice")
	assert.Equal(t, wantRes, gotRes, "expected response to match")
}
예제 #3
0
파일: handler.go 프로젝트: yarpc/yarpc-go
// callHandler turns an inbound TChannel call into a transport.Request and
// dispatches it.
//
// It validates the request, looks up the handler registered for the
// request's service and procedure, and invokes it. Only unary handlers are
// dispatched; any other handler type produces an UnsupportedTypeError.
//
// The context must carry a deadline; otherwise tchannel.ErrTimeoutRequired is
// returned before anything is read from the call.
func (h handler) callHandler(ctx context.Context, call inboundCall, start time.Time) error {
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		return tchannel.ErrTimeoutRequired
	}

	treq := new(transport.Request)
	treq.Caller = call.CallerName()
	treq.Service = call.ServiceName()
	treq.Encoding = transport.Encoding(call.Format())
	treq.Procedure = call.MethodString()

	ctx, headers, err := readRequestHeaders(ctx, call.Format(), call.Arg2Reader)
	if err != nil {
		return encoding.RequestHeadersDecodeError(treq, err)
	}
	treq.Headers = headers

	// Only calls backed by a real TChannel inbound call carry span data to extract.
	if tcall, isTChannel := call.(tchannelCall); isTChannel {
		tracer := h.deps.Tracer()
		ctx = tchannel.ExtractInboundSpan(ctx, tcall.InboundCall, headers.Items(), tracer)
	}

	body, err := call.Arg3Reader()
	if err != nil {
		return err
	}
	defer body.Close()
	treq.Body = body

	// Deferred after body.Close, so the writer is closed before the body on return.
	rw := newResponseWriter(treq, call)
	defer rw.Close() // TODO(abg): log if this errors

	if treq, err = request.Validate(ctx, treq); err != nil {
		return err
	}

	spec, err := h.Registry.GetHandlerSpec(treq.Service, treq.Procedure)
	if err != nil {
		return err
	}

	if handlerType := spec.Type(); handlerType != transport.Unary {
		return errors.UnsupportedTypeError{Transport: "TChannel", Type: string(handlerType)}
	}

	if treq, err = request.ValidateUnary(ctx, treq); err != nil {
		return err
	}
	return internal.SafelyCallUnaryHandler(ctx, spec.Unary(), start, treq, rw)
}
예제 #4
0
파일: handler.go 프로젝트: yarpc/yarpc-go
// callHandler turns an inbound HTTP request into a transport.Request and
// dispatches it.
//
// It validates the request, then dispatches it to the handler registered for
// its service and procedure. Unary and oneway handlers are supported; any
// other handler type produces an UnsupportedTypeError.
//
// The TTL header is parsed exactly once, before validation. The resulting
// context is shared by every handler type.
func (h handler) callHandler(w http.ResponseWriter, req *http.Request, start time.Time) error {
	treq := &transport.Request{
		Caller:    popHeader(req.Header, CallerHeader),
		Service:   popHeader(req.Header, ServiceHeader),
		Procedure: popHeader(req.Header, ProcedureHeader),
		Encoding:  transport.Encoding(popHeader(req.Header, EncodingHeader)),
		Headers:   applicationHeaders.FromHTTPHeaders(req.Header, transport.Headers{}),
		Body:      req.Body,
	}

	ctx := req.Context()

	v := request.Validator{Request: treq}
	// This is the only place the TTL header is read. A second popHeader for
	// TTLMSHeader would presumably find it already removed (as the name
	// suggests), so the unary path must reuse this context, not re-parse.
	ctx, cancel := v.ParseTTL(ctx, popHeader(req.Header, TTLMSHeader))
	defer cancel()

	ctx, span := h.createSpan(ctx, req, treq, start)

	treq, err := v.Validate(ctx)
	if err != nil {
		return err
	}

	spec, err := h.Registry.GetHandlerSpec(treq.Service, treq.Procedure)
	if err != nil {
		return updateSpanWithErr(span, err)
	}

	switch spec.Type() {
	case transport.Unary:
		defer span.Finish()

		// The deadline, if any, was already applied to ctx by ParseTTL above.
		treq, err = v.ValidateUnary(ctx)
		if err != nil {
			return err
		}
		err = internal.SafelyCallUnaryHandler(ctx, spec.Unary(), start, treq, newResponseWriter(w))

	case transport.Oneway:
		treq, err = v.ValidateOneway(ctx)
		if err != nil {
			return err
		}
		// The span is handed to handleOnewayRequest rather than finished here.
		err = handleOnewayRequest(ctx, span, treq, spec.Oneway())

	default:
		err = errors.UnsupportedTypeError{Transport: "HTTP", Type: string(spec.Type())}
	}

	return updateSpanWithErr(span, err)
}
예제 #5
0
// Request returns a new transport.Request with every field filled from the
// generator.
//
// Values are drawn in a fixed order: the body bytes first, then Caller,
// Service, Encoding, Procedure, Headers, ShardKey, RoutingKey and
// RoutingDelegate. Keep this order stable; tests pin hashes of requests
// produced from a fixed seed.
func (r *randomGenerator) Request() transport.Request {
	payload := []byte(r.Atom())

	var req transport.Request
	req.Caller = r.Atom()
	req.Service = r.Atom()
	req.Encoding = transport.Encoding(r.Atom())
	req.Procedure = r.Atom()
	req.Headers = r.Headers()
	req.ShardKey = r.Atom()
	req.RoutingKey = r.Atom()
	req.RoutingDelegate = r.Atom()
	req.Body = ioutil.NopCloser(bytes.NewReader(payload))
	return req
}
예제 #6
0
// TestReqMeta checks that ReqMetaBuilder.Build reflects what was set on the
// builder.
//
// An untouched builder must report zero values and headers equal to
// yarpc.NewHeaders(). A fully configured builder must report exactly the
// configured values.
func TestReqMeta(t *testing.T) {
	type testCase struct {
		build func(*ReqMetaBuilder) *ReqMetaBuilder

		wantEncoding  transport.Encoding
		wantHeaders   yarpc.Headers
		wantCaller    string
		wantProcedure string
		wantService   string
	}

	unchanged := func(b *ReqMetaBuilder) *ReqMetaBuilder { return b }
	fullyConfigured := func(b *ReqMetaBuilder) *ReqMetaBuilder {
		return b.
			Encoding(transport.Encoding("myencoding")).
			Headers(yarpc.NewHeaders().With("foo", "bar")).
			Caller("caller").
			Service("service").
			Procedure("procedure")
	}

	cases := []testCase{
		// Zero-valued expectations are left implicit.
		{build: unchanged, wantHeaders: yarpc.NewHeaders()},
		{
			build:         fullyConfigured,
			wantEncoding:  "myencoding",
			wantHeaders:   yarpc.NewHeaders().With("foo", "bar"),
			wantCaller:    "caller",
			wantService:   "service",
			wantProcedure: "procedure",
		},
	}

	for _, tc := range cases {
		meta := tc.build(NewReqMetaBuilder()).Build()
		assert.Equal(t, tc.wantEncoding, meta.Encoding())
		assert.Equal(t, tc.wantHeaders, meta.Headers())
		assert.Equal(t, tc.wantCaller, meta.Caller())
		assert.Equal(t, tc.wantProcedure, meta.Procedure())
		assert.Equal(t, tc.wantService, meta.Service())
	}
}
예제 #7
0
// TestHandlerErrors feeds the TChannel handler one JSON-format call and one
// Thrift-format call.
//
// For each call it checks that the arg2 header bytes reach the registered
// unary handler as the expected application headers. It also checks that no
// system error is recorded on the response.
func TestHandlerErrors(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	tests := []struct {
		format  tchannel.Format
		headers []byte

		wantHeaders map[string]string
	}{
		{
			format:      tchannel.JSON,
			headers:     []byte(`{"Rpc-Header-Foo": "bar"}`),
			wantHeaders: map[string]string{"rpc-header-foo": "bar"},
		},
		{
			format: tchannel.Thrift,
			headers: []byte{
				0x00, 0x01, // 1 header
				0x00, 0x03, 'F', 'o', 'o', // Foo
				0x00, 0x03, 'B', 'a', 'r', // Bar
			},
			wantHeaders: map[string]string{"foo": "Bar"},
		},
	}

	for _, tt := range tests {
		rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
		registry := transporttest.NewMockRegistry(mockCtrl)

		spec := transport.NewUnaryHandlerSpec(rpcHandler)
		tchHandler := handler{Registry: registry}

		registry.EXPECT().GetHandlerSpec("service", "hello").Return(spec, nil)

		rpcHandler.EXPECT().Handle(
			transporttest.NewContextMatcher(t),
			transporttest.NewRequestMatcher(t,
				&transport.Request{
					Caller:    "caller",
					Service:   "service",
					Headers:   transport.HeadersFromMap(tt.wantHeaders),
					Encoding:  transport.Encoding(tt.format),
					Procedure: "hello",
					Body:      bytes.NewReader([]byte("world")),
				}),
			gomock.Any(),
		).Return(nil)

		respRecorder := newResponseRecorder()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		tchHandler.handle(ctx, &fakeInboundCall{
			service: "service",
			caller:  "caller",
			format:  tt.format,
			method:  "hello",
			arg2:    tt.headers,
			arg3:    []byte("world"),
			resp:    respRecorder,
		})
		// Release this case's context now; a defer inside the loop would keep
		// every iteration's timer alive until the whole test returns.
		cancel()

		assert.NoError(t, respRecorder.systemErr, "did not expect an error")
	}
}
예제 #8
0
// TestHash pins the hash of a request generated from seed 42.
//
// It then checks that changing any single request field (Caller, Service,
// Encoding, Procedure, Headers, ShardKey, RoutingKey, RoutingDelegate or
// Body) changes the hash.
func TestHash(t *testing.T) {
	rgen := newRandomGenerator(42)
	request := rgen.Request()

	// Every struct copy of request shares the same Body reader, and hashing
	// needs the body contents. Snapshot the bytes once so the reference and
	// each variant get their own unread reader. Otherwise a variant could hash
	// an already-drained body and differ for the wrong reason.
	body, err := ioutil.ReadAll(request.Body)
	require.NoError(t, err)
	withFreshBody := func(req transport.Request) transport.Request {
		req.Body = ioutil.NopCloser(bytes.NewReader(body))
		return req
	}

	recorder := NewRecorder(t)
	reference := withFreshBody(request)
	requestRecord := recorder.requestToRequestRecord(&reference)
	referenceHash := recorder.hashRequestRecord(&requestRecord)

	require.Equal(t, "7195d5a712201d2a", referenceHash)

	// Mutators run in this order so rgen is consumed in a fixed sequence.
	tests := []struct {
		field  string
		mutate func(*transport.Request)
	}{
		{"Caller", func(r *transport.Request) { r.Caller = rgen.Atom() }},
		{"Service", func(r *transport.Request) { r.Service = rgen.Atom() }},
		{"Encoding", func(r *transport.Request) { r.Encoding = transport.Encoding(rgen.Atom()) }},
		{"Procedure", func(r *transport.Request) { r.Procedure = rgen.Atom() }},
		{"Headers", func(r *transport.Request) { r.Headers = rgen.Headers() }},
		{"ShardKey", func(r *transport.Request) { r.ShardKey = rgen.Atom() }},
		{"RoutingKey", func(r *transport.Request) { r.RoutingKey = rgen.Atom() }},
		{"RoutingDelegate", func(r *transport.Request) { r.RoutingDelegate = rgen.Atom() }},
		// Previously this case assigned request.Body, leaving the hashed copy unchanged.
		{"Body", func(r *transport.Request) {
			r.Body = ioutil.NopCloser(bytes.NewReader([]byte(rgen.Atom())))
		}},
	}

	for _, tt := range tests {
		r := withFreshBody(request)
		tt.mutate(&r)
		record := recorder.requestToRequestRecord(&r)
		assert.NotEqual(t, referenceHash, recorder.hashRequestRecord(&record),
			"changing %s should change the request hash", tt.field)
	}
}