示例#1
0
// TestHeaders checks that headerMapper converts transport headers into
// prefixed HTTP headers and converts those HTTP headers back again.
func TestHeaders(t *testing.T) {
	type testCase struct {
		prefix    string
		transport transport.Headers
		http      http.Header
	}

	cases := []testCase{
		{
			prefix: ApplicationHeaderPrefix,
			transport: transport.HeadersFromMap(map[string]string{
				"foo":     "bar",
				"foo-bar": "hello",
			}),
			http: http.Header{
				"Rpc-Header-Foo":     []string{"bar"},
				"Rpc-Header-Foo-Bar": []string{"hello"},
			},
		},
	}

	for _, tc := range cases {
		mapper := headerMapper{tc.prefix}

		// HTTP -> transport: the prefix is stripped from each header name.
		gotTransport := mapper.FromHTTPHeaders(tc.http, transport.Headers{})
		assert.Equal(t, tc.transport, gotTransport)

		// transport -> HTTP: the prefix is added back.
		gotHTTP := mapper.ToHTTPHeaders(tc.transport, nil)
		assert.Equal(t, tc.http, gotHTTP)
	}
}
示例#2
0
// recordToResponse rebuilds a transport.Response from a cached record: the
// recorded headers are converted with transport.HeadersFromMap and the
// recorded body bytes are exposed through a reader whose Close is a no-op.
// cachedRecord is dereferenced unconditionally, so it must not be nil.
func (r *Recorder) recordToResponse(cachedRecord *record) transport.Response {
	headers := transport.HeadersFromMap(cachedRecord.Response.Headers)
	body := ioutil.NopCloser(bytes.NewReader(cachedRecord.Response.Body))
	return transport.Response{Headers: headers, Body: body}
}
示例#3
0
// TestEncodeAndDecodeHeaders verifies that encodeHeaders produces the
// expected wire bytes and that decodeHeaders parses those bytes back into
// the same headers.
func TestEncodeAndDecodeHeaders(t *testing.T) {
	cases := []struct {
		bytes   []byte
		headers map[string]string
	}{
		// No headers: the encoding is just a zero header count.
		{bytes: []byte{0x00, 0x00}, headers: nil},
		{
			bytes: []byte{
				0x00, 0x01, // header count = 1

				0x00, 0x05, // key length = 5
				'h', 'e', 'l', 'l', 'o',

				0x00, 0x05, // value length = 5
				'w', 'o', 'r', 'l', 'd',
			},
			headers: map[string]string{"hello": "world"},
		},
	}

	for _, tc := range cases {
		want := transport.HeadersFromMap(tc.headers)
		assert.Equal(t, tc.bytes, encodeHeaders(want))

		got, err := decodeHeaders(bytes.NewReader(tc.bytes))
		if !assert.NoError(t, err) {
			continue
		}
		assert.Equal(t, want, got)
	}
}
示例#4
0
// TestResponseWriter checks that headers added to the HTTP response writer
// are readable under "rpc-header-"-prefixed names on the underlying
// response, and that written body bytes pass through unchanged.
func TestResponseWriter(t *testing.T) {
	rec := httptest.NewRecorder()
	w := newResponseWriter(rec)

	w.AddHeaders(transport.HeadersFromMap(map[string]string{
		"foo":       "bar",
		"shard-key": "123",
	}))

	_, err := w.Write([]byte("hello"))
	require.NoError(t, err)

	got := rec.Header()
	assert.Equal(t, "bar", got.Get("rpc-header-foo"))
	assert.Equal(t, "123", got.Get("rpc-header-shard-key"))
	assert.Equal(t, "hello", rec.Body.String())
}
示例#5
0
文件: header.go 项目: yarpc/yarpc-go
// readHeaders reads headers using the given function to get the arg reader.
//
// This may be used with the Arg2Reader functions on InboundCall and
// OutboundCallResponse.
//
// If the format is JSON, the headers are expected to be JSON encoded. For
// every other format they are decoded with decodeHeaders, and the reader
// returned by getReader is closed on both the success path and the
// decode-failure path.
//
// This function always returns a non-nil Headers object in case of success.
// When an error is returned, the accompanying Headers may be empty or only
// partially populated and should not be relied upon.
func readHeaders(format tchannel.Format, getReader func() (tchannel.ArgReader, error)) (transport.Headers, error) {
	if format == tchannel.JSON {
		// JSON is special: the argument is decoded through tchannel's
		// ArgReader helper into a plain string map.
		var headers map[string]string
		err := tchannel.NewArgReader(getReader()).ReadJSON(&headers)
		return transport.HeadersFromMap(headers), err
	}

	r, err := getReader()
	if err != nil {
		return transport.Headers{}, err
	}

	headers, err := decodeHeaders(r)
	if err != nil {
		// Close the reader even though decoding failed so it is not leaked.
		// The decode error is the one worth reporting, so any secondary
		// Close error is deliberately dropped.
		_ = r.Close()
		return headers, err
	}

	return headers, r.Close()
}
示例#6
0
// TestHandlerHeaders verifies that the HTTP handler derives the request TTL
// from TTLMSHeader (milliseconds) and forwards only "Rpc-Header-"-prefixed
// headers, with the prefix stripped, to the registered unary handler.
func TestHandlerHeaders(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	tests := []struct {
		giveHeaders http.Header

		wantTTL     time.Duration
		wantHeaders map[string]string
	}{
		{
			giveHeaders: http.Header{
				TTLMSHeader:      {"1000"},
				"Rpc-Header-Foo": {"bar"},
			},
			wantTTL: time.Second,
			wantHeaders: map[string]string{
				"foo": "bar",
			},
		},
		{
			// Headers without the application prefix are not forwarded.
			giveHeaders: http.Header{
				TTLMSHeader: {"100"},
				"Rpc-Foo":   {"ignored"},
			},
			wantTTL:     100 * time.Millisecond,
			wantHeaders: map[string]string{},
		},
	}

	for _, tt := range tests {
		registry := transporttest.NewMockRegistry(mockCtrl)
		rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
		spec := transport.NewUnaryHandlerSpec(rpcHandler)
		registry.EXPECT().GetHandlerSpec("service", "hello").Return(spec, nil)

		wantReq := &transport.Request{
			Caller:    "caller",
			Service:   "service",
			Encoding:  raw.Encoding,
			Procedure: "hello",
			Headers:   transport.HeadersFromMap(tt.wantHeaders),
			Body:      bytes.NewReader([]byte("world")),
		}
		rpcHandler.EXPECT().Handle(
			transporttest.NewContextMatcher(t, transporttest.ContextTTL(tt.wantTTL)),
			transporttest.NewRequestMatcher(t, wantReq),
			gomock.Any(),
		).Return(nil)

		// Copy the case's headers via Add (which canonicalizes keys), then
		// set the routing headers every request needs.
		reqHeaders := http.Header{}
		for name, values := range tt.giveHeaders {
			for _, value := range values {
				reqHeaders.Add(name, value)
			}
		}
		reqHeaders.Set(CallerHeader, "caller")
		reqHeaders.Set(ServiceHeader, "service")
		reqHeaders.Set(EncodingHeader, "raw")
		reqHeaders.Set(ProcedureHeader, "hello")

		httpHandler := handler{Registry: registry}
		req := &http.Request{
			Method: "POST",
			Header: reqHeaders,
			Body:   ioutil.NopCloser(bytes.NewReader([]byte("world"))),
		}
		rec := httptest.NewRecorder()
		httpHandler.ServeHTTP(rec, req)
		assert.Equal(t, http.StatusOK, rec.Code, "expected 200 status code")
	}
}
示例#7
0
// TestHandlerErrors feeds arg2 header payloads in JSON and Thrift formats
// through the TChannel handler and checks that the registered unary handler
// receives the decoded headers (with lower-cased keys) and that no system
// error is reported.
//
// NOTE(review): despite the name, every case here expects a successful
// dispatch; consider renaming to reflect that it tests header decoding.
func TestHandlerErrors(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	tests := []struct {
		format  tchannel.Format
		headers []byte

		wantHeaders map[string]string
	}{
		{
			format:      tchannel.JSON,
			headers:     []byte(`{"Rpc-Header-Foo": "bar"}`),
			wantHeaders: map[string]string{"rpc-header-foo": "bar"},
		},
		{
			format: tchannel.Thrift,
			headers: []byte{
				0x00, 0x01, // 1 header
				0x00, 0x03, 'F', 'o', 'o', // Foo
				0x00, 0x03, 'B', 'a', 'r', // Bar
			},
			wantHeaders: map[string]string{"foo": "Bar"},
		},
	}

	for _, tt := range tests {
		rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
		registry := transporttest.NewMockRegistry(mockCtrl)

		spec := transport.NewUnaryHandlerSpec(rpcHandler)
		tchHandler := handler{Registry: registry}

		registry.EXPECT().GetHandlerSpec("service", "hello").Return(spec, nil)

		rpcHandler.EXPECT().Handle(
			transporttest.NewContextMatcher(t),
			transporttest.NewRequestMatcher(t,
				&transport.Request{
					Caller:    "caller",
					Service:   "service",
					Headers:   transport.HeadersFromMap(tt.wantHeaders),
					Encoding:  transport.Encoding(tt.format),
					Procedure: "hello",
					Body:      bytes.NewReader([]byte("world")),
				}),
			gomock.Any(),
		).Return(nil)

		respRecorder := newResponseRecorder()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		tchHandler.handle(ctx, &fakeInboundCall{
			service: "service",
			caller:  "caller",
			format:  tt.format,
			method:  "hello",
			arg2:    tt.headers,
			arg3:    []byte("world"),
			resp:    respRecorder,
		})

		assert.NoError(t, respRecorder.systemErr, "did not expect an error")

		// Cancel per iteration instead of deferring inside the loop, which
		// would keep every iteration's context and timer alive until the
		// whole test returns.
		cancel()
	}
}
示例#8
0
func TestResponseWriter(t *testing.T) {
	tests := []struct {
		format           tchannel.Format
		apply            func(*responseWriter)
		arg2             []byte
		arg3             []byte
		applicationError bool
	}{
		{
			format: tchannel.Raw,
			apply: func(w *responseWriter) {
				headers := transport.HeadersFromMap(map[string]string{"foo": "bar"})
				w.AddHeaders(headers)
				_, err := w.Write([]byte("hello "))
				require.NoError(t, err)
				_, err = w.Write([]byte("world"))
				require.NoError(t, err)
			},
			arg2: []byte{
				0x00, 0x01,
				0x00, 0x03, 'f', 'o', 'o',
				0x00, 0x03, 'b', 'a', 'r',
			},
			arg3: []byte("hello world"),
		},
		{
			format: tchannel.Raw,
			apply: func(w *responseWriter) {
				_, err := w.Write([]byte("foo"))
				require.NoError(t, err)
				_, err = w.Write([]byte("bar"))
				require.NoError(t, err)
			},
			arg2: []byte{0x00, 0x00},
			arg3: []byte("foobar"),
		},
		{
			format: tchannel.JSON,
			apply: func(w *responseWriter) {
				headers := transport.HeadersFromMap(map[string]string{"foo": "bar"})
				w.AddHeaders(headers)

				_, err := w.Write([]byte("{"))
				require.NoError(t, err)

				_, err = w.Write([]byte("}"))
				require.NoError(t, err)
			},
			arg2: []byte(`{"foo":"bar"}` + "\n"),
			arg3: []byte("{}"),
		},
		{
			format: tchannel.JSON,
			apply: func(w *responseWriter) {
				_, err := w.Write([]byte("{"))
				require.NoError(t, err)

				_, err = w.Write([]byte("}"))
				require.NoError(t, err)
			},
			arg2: []byte("{}\n"),
			arg3: []byte("{}"),
		},
		{
			format: tchannel.Raw,
			apply: func(w *responseWriter) {
				w.SetApplicationError()
				_, err := w.Write([]byte("hello"))
				require.NoError(t, err)
			},
			arg2:             []byte{0x00, 0x00},
			arg3:             []byte("hello"),
			applicationError: true,
		},
	}

	for _, tt := range tests {
		call := &fakeInboundCall{format: tt.format}
		resp := newResponseRecorder()
		call.resp = resp

		w := newResponseWriter(new(transport.Request), call)
		tt.apply(w)
		assert.NoError(t, w.Close())

		assert.Nil(t, resp.systemErr)
		assert.Equal(t, tt.arg2, resp.arg2.Bytes())
		assert.Equal(t, tt.arg3, resp.arg3.Bytes())

		if tt.applicationError {
			assert.True(t, resp.applicationError, "expected an application error")
		}
	}
}
示例#9
0
// TestReadAndWriteHeaders round-trips headers through writeHeaders and
// readHeaders for the Raw, JSON and Thrift formats.
func TestReadAndWriteHeaders(t *testing.T) {
	tests := []struct {
		format tchannel.Format

		// Headers are serialized in an undefined order, so the written
		// bytes must equal one of these two encodings.
		bytes   []byte
		orBytes []byte

		headers map[string]string
	}{
		{
			tchannel.Raw,
			[]byte{
				0x00, 0x02,
				0x00, 0x01, 'a', 0x00, 0x01, '1',
				0x00, 0x01, 'b', 0x00, 0x01, '2',
			},
			[]byte{
				0x00, 0x02,
				0x00, 0x01, 'b', 0x00, 0x01, '2',
				0x00, 0x01, 'a', 0x00, 0x01, '1',
			},
			map[string]string{"a": "1", "b": "2"},
		},
		{
			tchannel.JSON,
			[]byte(`{"a":"1","b":"2"}` + "\n"),
			[]byte(`{"b":"2","a":"1"}` + "\n"),
			map[string]string{"a": "1", "b": "2"},
		},
		{
			tchannel.Thrift,
			[]byte{
				0x00, 0x02,
				0x00, 0x01, 'a', 0x00, 0x01, '1',
				0x00, 0x01, 'b', 0x00, 0x01, '2',
			},
			[]byte{
				0x00, 0x02,
				0x00, 0x01, 'b', 0x00, 0x01, '2',
				0x00, 0x01, 'a', 0x00, 0x01, '1',
			},
			map[string]string{"a": "1", "b": "2"},
		},
	}

	for _, tt := range tests {
		want := transport.HeadersFromMap(tt.headers)

		buffer := newBufferArgWriter()
		getWriter := func() (tchannel.ArgWriter, error) { return buffer, nil }
		require.NoError(t, writeHeaders(tt.format, want, getWriter))

		// Accept either serialization order.
		written := buffer.Bytes()
		if !bytes.Equal(tt.bytes, written) {
			assert.Equal(t, tt.orBytes, written, "failed for %v", tt.format)
		}

		getReader := func() (tchannel.ArgReader, error) {
			return tchannel.ArgReader(ioutil.NopCloser(bytes.NewReader(written))), nil
		}
		got, err := readHeaders(tt.format, getReader)
		require.NoError(t, err)
		assert.Equal(t, want, got, "failed for %v", tt.format)
	}
}