예제 #1
0
// runRetryTest builds a retryTest fixture, wires a Hyperbahn client to the
// channel supplied by withSetup, and then invokes f with the fixture. Once f
// returns, the expectations recorded on the fixture's mock are asserted.
func runRetryTest(t *testing.T, f func(r *retryTest)) {
	rt := &retryTest{}
	defer testutils.SetTimeout(t, time.Second)()
	rt.setup()
	defer testutils.ResetSleepStub(&timeSleep)

	body := func(hyperbahnCh *tchannel.Channel, addr string) {
		// The channel handed in by withSetup answers "ad" requests through the fixture.
		adHandlers := json.Handlers{"ad": rt.adHandler}
		json.Register(hyperbahnCh, adHandlers, nil)

		// Advertise failures cause warning log messages, so a log filter is
		// registered for that message on the client's channel.
		clientCh := testutils.NewServer(t, testutils.NewOpts().
			SetServiceName("my-client").
			AddLogFilter("Hyperbahn client registration failed", 10))
		defer clientCh.Close()

		rt.ch = clientCh
		clientOpts := &ClientOptions{
			Handler:      rt,
			FailStrategy: FailStrategyIgnore,
		}
		client, err := NewClient(clientCh, configFor(addr), clientOpts)
		rt.client = client
		require.NoError(t, err, "NewClient")
		defer rt.client.Close()

		f(rt)
		rt.mock.AssertExpectations(t)
	}
	withSetup(t, body)
}
예제 #2
0
// TestInitialAdvertiseFailedRetryBackoff checks the sleeps requested while the
// initial advertise keeps failing: each of the 5 observed sleeps must not
// exceed advertiseRetryInterval doubled once per prior attempt, and Advertise
// must ultimately return an error.
func TestInitialAdvertiseFailedRetryBackoff(t *testing.T) {
	defer testutils.SetTimeout(t, 2*time.Second)()

	clientOpts := stubbedSleep()
	sleepArgs, sleepBlock, sleepClose := testutils.SleepStub(&clientOpts.TimeSleep)

	// We expect to retry 5 times, with the upper bound on each sleep doubling
	// on every attempt.
	go func() {
		for attempt := uint(0); attempt < 5; attempt++ {
			maxSleepFor := advertiseRetryInterval * time.Duration(1<<attempt)
			got := <-sleepArgs
			// The bound is inclusive, so the message reports "<=" to match the check.
			assert.True(t, got <= maxSleepFor,
				"Initial advertise attempt %v expected sleep %v <= %v", attempt, got, maxSleepFor)
			// Unblock the stubbed sleep so the client can move on to the next attempt.
			sleepBlock <- struct{}{}
		}
		sleepClose()
	}()

	withSetup(t, func(hypCh *tchannel.Channel, hostPort string) {
		serverCh := testutils.NewServer(t, nil)
		defer serverCh.Close()

		client, err := NewClient(serverCh, configFor(hostPort), clientOpts)
		require.NoError(t, err, "NewClient")
		defer client.Close()
		// No "ad" handler is registered on hypCh in this test, so advertising is expected to fail.
		assert.Error(t, client.Advertise(), "Advertise without handler should fail")
	})
}
예제 #3
0
// testStreamArg starts a streaming call to the "echoStream" method on a
// verified server, runs a fixed sequence of write/read round trips over arg3,
// and then passes the still-open writer and reader to f.
func testStreamArg(t *testing.T, f func(argWriter ArgWriter, argReader ArgReader)) {
	defer testutils.SetTimeout(t, 2*time.Second)()
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	sh := streamHelper{t}
	WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) {
		handler := streamPartialHandler(t, true /* report errors */)
		ch.Register(handler, "echoStream")

		writer, reader := sh.startCall(ctx, ch, hostPort, ch.ServiceName())

		// For each value, write that single byte and expect to read back
		// exactly that many bytes, matching makeRepeatedBytes.
		for _, count := range []byte{0, 5, 100, 1} {
			require.NoError(t, writeFlushBytes(writer, []byte{count}), "arg3 write failed")

			echoed := make([]byte, int(count))
			_, err := io.ReadFull(reader, echoed)
			require.NoError(t, err, "arg3 read failed")

			assert.Equal(t, makeRepeatedBytes(count), echoed, "arg3 result mismatch")
		}

		f(writer, reader)
	})
}
예제 #4
0
// TestStatsCalls makes one successful call and one application-error call and
// validates the counters and timers reported on the client (outbound) and
// server (inbound) sides. Client and server each use their own stubbed clock.
func TestStatsCalls(t *testing.T) {
	defer testutils.SetTimeout(t, time.Second)()

	// Values handed to each clock stub; they are also the latencies expected below.
	const (
		clientStep = 100 * time.Millisecond
		serverStep = 50 * time.Millisecond
	)

	start := time.Date(2015, 2, 1, 10, 10, 0, 0, time.UTC)
	clientNow, setClientStep := testutils.NowStub(start)
	serverNow, setServerStep := testutils.NowStub(start)
	setClientStep(clientStep)
	setServerStep(serverStep)

	clientReporter := newRecordingStatsReporter()
	serverReporter := newRecordingStatsReporter()
	serverOpts := testutils.NewOpts().SetStatsReporter(serverReporter).SetTimeNow(serverNow)

	WithVerifiedServer(t, serverOpts, func(serverCh *Channel, hostPort string) {
		h := raw.Wrap(newTestHandler(t))
		for _, method := range []string{"echo", "app-error"} {
			serverCh.Register(h, method)
		}

		clientOpts := testutils.NewOpts().SetStatsReporter(clientReporter).SetTimeNow(clientNow)
		clientCh := testutils.NewClient(t, clientOpts)
		defer clientCh.Close()

		ctx, cancel := NewContext(time.Second * 5)
		defer cancel()

		// First call: plain echo, expected to succeed.
		_, _, _, err := raw.Call(ctx, clientCh, hostPort, testServiceName, "echo", []byte("Headers"), []byte("Body"))
		require.NoError(t, err)

		outTags := tagsForOutboundCall(serverCh, clientCh, "echo")
		clientReporter.Expected.IncCounter("outbound.calls.send", outTags, 1)
		clientReporter.Expected.IncCounter("outbound.calls.success", outTags, 1)
		clientReporter.Expected.RecordTimer("outbound.calls.per-attempt.latency", outTags, clientStep)
		clientReporter.Expected.RecordTimer("outbound.calls.latency", outTags, clientStep)
		inTags := tagsForInboundCall(serverCh, clientCh, "echo")
		serverReporter.Expected.IncCounter("inbound.calls.recvd", inTags, 1)
		serverReporter.Expected.IncCounter("inbound.calls.success", inTags, 1)
		serverReporter.Expected.RecordTimer("inbound.calls.latency", inTags, serverStep)

		// Second call: the transport succeeds but the response carries an application error.
		_, _, resp, err := raw.Call(ctx, clientCh, hostPort, testServiceName, "app-error", nil, nil)
		require.NoError(t, err)
		require.True(t, resp.ApplicationError(), "expected application error")

		outTags = tagsForOutboundCall(serverCh, clientCh, "app-error")
		clientReporter.Expected.IncCounter("outbound.calls.send", outTags, 1)
		clientReporter.Expected.IncCounter("outbound.calls.per-attempt.app-errors", outTags, 1)
		clientReporter.Expected.IncCounter("outbound.calls.app-errors", outTags, 1)
		clientReporter.Expected.RecordTimer("outbound.calls.per-attempt.latency", outTags, clientStep)
		clientReporter.Expected.RecordTimer("outbound.calls.latency", outTags, clientStep)
		inTags = tagsForInboundCall(serverCh, clientCh, "app-error")
		serverReporter.Expected.IncCounter("inbound.calls.recvd", inTags, 1)
		serverReporter.Expected.IncCounter("inbound.calls.app-errors", inTags, 1)
		serverReporter.Expected.RecordTimer("inbound.calls.latency", inTags, serverStep)
	})

	clientReporter.Validate(t)
	serverReporter.Validate(t)
}
예제 #5
0
// TestCloseSemanticsIsolated runs the shared close-semantics scenario with the
// isolated flag set, and checks for leaked goroutines once it finishes.
func TestCloseSemanticsIsolated(t *testing.T) {
	// Deferred first so that it runs last: the leak check must happen after
	// the SetTimeout cleanup has cleared the timeout.
	defer goroutines.VerifyNoLeaks(t, nil)
	defer testutils.SetTimeout(t, 2*time.Second)()

	ctx, cancel := NewContext(time.Second)
	defer cancel()

	(&closeSemanticsTest{t, ctx, true /* isolated */}).runTest(ctx)
}
예제 #6
0
// TestCloseSemantics runs the shared close-semantics scenario with the
// isolated flag unset, and checks for blocked goroutines once it finishes.
func TestCloseSemantics(t *testing.T) {
	// Deferred first so that it runs last: the goroutine check must happen
	// after the SetTimeout cleanup has cleared the timeout.
	defer VerifyNoBlockedGoroutines(t)
	defer testutils.SetTimeout(t, 2*time.Second)()

	ctx, cancel := NewContext(time.Second)
	defer cancel()

	(&closeSemanticsTest{t, ctx, false /* isolated */}).runTest(ctx)
}
예제 #7
0
// testStreamArg begins a call to the "echoStream" method on a verified server,
// sends an empty arg2, opens arg3 as a stream in both directions, runs a fixed
// sequence of write/read round trips, and then hands the still-open arg3
// writer and reader to f.
func testStreamArg(t *testing.T, f func(argWriter ArgWriter, argReader ArgReader)) {
	defer testutils.SetTimeout(t, 2*time.Second)()
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) {
		ch.Register(streamPartialHandler(t), "echoStream")

		call, err := ch.BeginCall(ctx, hostPort, ch.PeerInfo().ServiceName, "echoStream", nil)
		require.NoError(t, err, "BeginCall failed")
		// NoError (rather than Nil) reports the error text on failure and
		// matches the other assertions in this helper.
		require.NoError(t, NewArgWriter(call.Arg2Writer()).Write(nil), "arg2 write failed")

		argWriter, err := call.Arg3Writer()
		require.NoError(t, err, "Arg3Writer failed")

		// Flush arg3 to force the call to start without any arg3.
		require.NoError(t, argWriter.Flush(), "Arg3Writer flush failed")

		// Write out to the stream, and expect to get data
		response := call.Response()

		var arg2 []byte
		require.NoError(t, NewArgReader(response.Arg2Reader()).Read(&arg2), "Arg2Reader failed")
		require.False(t, response.ApplicationError(), "call failed")

		argReader, err := response.Arg3Reader()
		require.NoError(t, err, "Arg3Reader failed")

		// verifyBytes writes the single byte n and expects to read back exactly
		// n bytes equal to makeRepeatedBytes(n).
		verifyBytes := func(n byte) {
			_, err := argWriter.Write([]byte{n})
			require.NoError(t, err, "arg3 write failed")
			require.NoError(t, argWriter.Flush(), "arg3 flush failed")

			arg3 := make([]byte, int(n))
			_, err = io.ReadFull(argReader, arg3)
			require.NoError(t, err, "arg3 read failed")

			assert.Equal(t, makeRepeatedBytes(n), arg3, "arg3 result mismatch")
		}

		verifyBytes(0)
		verifyBytes(5)
		verifyBytes(100)
		verifyBytes(1)

		f(argWriter, argReader)
	})
}
예제 #8
0
// runRetryTest builds a retryTest fixture, wires a Hyperbahn client named
// "my-client" to the channel supplied by withSetup, and then invokes f with
// the fixture. Once f returns, the expectations recorded on the fixture's
// mock are asserted.
func runRetryTest(t *testing.T, f func(r *retryTest)) {
	rt := &retryTest{}
	defer testutils.SetTimeout(t, time.Second)()
	rt.setup()
	defer testutils.ResetSleepStub(&timeSleep)

	body := func(hyperbahnCh *tchannel.Channel, addr string) {
		// The channel handed in by withSetup answers "ad" requests through the fixture.
		json.Register(hyperbahnCh, json.Handlers{"ad": rt.adHandler}, nil)

		chOpts := &testutils.ChannelOpts{ServiceName: "my-client"}
		clientCh, err := testutils.NewServer(chOpts)
		require.NoError(t, err, "NewServer failed")
		defer clientCh.Close()

		rt.ch = clientCh
		hyperbahnOpts := &ClientOptions{
			Handler:      rt,
			FailStrategy: FailStrategyIgnore,
		}
		rt.client, err = NewClient(clientCh, configFor(addr), hyperbahnOpts)
		require.NoError(t, err, "NewClient")
		defer rt.client.Close()

		f(rt)
		rt.mock.AssertExpectations(t)
	}
	withSetup(t, body)
}
예제 #9
0
// TestStatsCalls makes one successful call and one application-error call and
// validates the counters and timers reported on the client (outbound) and
// server (inbound) sides. Client and server share a single stubbed clock, so
// the expected latencies depend on how many times each side reads the time.
func TestStatsCalls(t *testing.T) {
	defer testutils.SetTimeout(t, time.Second)()

	initialTime := time.Date(2015, 2, 1, 10, 10, 0, 0, time.UTC)
	nowStub, nowFn := testutils.NowStub(initialTime)
	// time.Now will be called in this order for each call:
	// sender records time they started sending
	// receiver records time the request is sent to application
	// receiver calculates application handler latency
	// sender records call latency
	// so expected inbound latency = incrementor, outbound = 3 * incrementor

	clientStats := newRecordingStatsReporter()
	serverStats := newRecordingStatsReporter()
	serverOpts := testutils.NewOpts().
		SetStatsReporter(serverStats).
		SetTimeNow(nowStub)
	WithVerifiedServer(t, serverOpts, func(serverCh *Channel, hostPort string) {
		handler := raw.Wrap(newTestHandler(t))
		serverCh.Register(handler, "echo")
		serverCh.Register(handler, "app-error")

		ch, err := testutils.NewClient(testutils.NewOpts().
			SetStatsReporter(clientStats).
			SetTimeNow(nowStub))
		require.NoError(t, err)
		// Close the client channel so the test does not leak it.
		defer ch.Close()

		ctx, cancel := NewContext(time.Second * 5)
		defer cancel()

		// Set now incrementor to 50ms, so expected Inbound latency is 50ms, outbound is 150ms.
		nowFn(50 * time.Millisecond)
		_, _, _, err = raw.Call(ctx, ch, hostPort, testServiceName, "echo", []byte("Headers"), []byte("Body"))
		require.NoError(t, err)

		outboundTags := tagsForOutboundCall(serverCh, ch, "echo")
		clientStats.Expected.IncCounter("outbound.calls.send", outboundTags, 1)
		clientStats.Expected.IncCounter("outbound.calls.success", outboundTags, 1)
		clientStats.Expected.RecordTimer("outbound.calls.latency", outboundTags, 150*time.Millisecond)
		inboundTags := tagsForInboundCall(serverCh, ch, "echo")
		serverStats.Expected.IncCounter("inbound.calls.recvd", inboundTags, 1)
		serverStats.Expected.IncCounter("inbound.calls.success", inboundTags, 1)
		serverStats.Expected.RecordTimer("inbound.calls.latency", inboundTags, 50*time.Millisecond)

		// Expected inbound latency = 70ms, outbound = 210ms.
		nowFn(70 * time.Millisecond)
		_, _, resp, err := raw.Call(ctx, ch, hostPort, testServiceName, "app-error", nil, nil)
		require.NoError(t, err)
		require.True(t, resp.ApplicationError(), "expected application error")

		outboundTags = tagsForOutboundCall(serverCh, ch, "app-error")
		clientStats.Expected.IncCounter("outbound.calls.send", outboundTags, 1)
		clientStats.Expected.IncCounter("outbound.calls.app-errors", outboundTags, 1)
		clientStats.Expected.RecordTimer("outbound.calls.latency", outboundTags, 210*time.Millisecond)
		inboundTags = tagsForInboundCall(serverCh, ch, "app-error")
		serverStats.Expected.IncCounter("inbound.calls.recvd", inboundTags, 1)
		serverStats.Expected.IncCounter("inbound.calls.app-errors", inboundTags, 1)
		serverStats.Expected.RecordTimer("inbound.calls.latency", inboundTags, 70*time.Millisecond)
	})

	clientStats.Validate(t)
	serverStats.Validate(t)
}
예제 #10
0
// TestStatsCalls is a table-driven check of the stats reported for a single
// call. Each case uses a fresh server, client, pair of stubbed clocks and pair
// of recording reporters, makes one call to the case's method, and validates
// the counters and timers on both sides.
func TestStatsCalls(t *testing.T) {
	defer testutils.SetTimeout(t, 2*time.Second)()

	tests := []struct {
		method  string
		wantErr bool // whether the response should carry an application error
	}{
		{
			method: "echo",
		},
		{
			method:  "app-error",
			wantErr: true,
		},
	}

	for _, tt := range tests {
		initialTime := time.Date(2015, 2, 1, 10, 10, 0, 0, time.UTC)
		clientNow, clientNowFn := testutils.NowStub(initialTime)
		serverNow, serverNowFn := testutils.NowStub(initialTime)
		clientNowFn(100 * time.Millisecond)
		serverNowFn(50 * time.Millisecond)

		clientStats := newRecordingStatsReporter()
		serverStats := newRecordingStatsReporter()
		serverOpts := testutils.NewOpts().
			SetStatsReporter(serverStats).
			SetTimeNow(serverNow)
		WithVerifiedServer(t, serverOpts, func(serverCh *Channel, hostPort string) {
			handler := raw.Wrap(newTestHandler(t))
			serverCh.Register(handler, "echo")
			serverCh.Register(handler, "app-error")

			ch := testutils.NewClient(t, testutils.NewOpts().
				SetStatsReporter(clientStats).
				SetTimeNow(clientNow))
			defer ch.Close()

			ctx, cancel := NewContext(time.Second * 5)
			defer cancel()

			// The call itself must always succeed; an application error is
			// reported through the response, not through err.
			_, _, resp, err := raw.Call(ctx, ch, hostPort, testutils.DefaultServerName, tt.method, nil, nil)
			require.NoError(t, err, "Call(%v) should not fail", tt.method)
			assert.Equal(t, tt.wantErr, resp.ApplicationError(), "Call(%v) check application error", tt.method)

			outboundTags := tagsForOutboundCall(serverCh, ch, tt.method)
			inboundTags := tagsForInboundCall(serverCh, ch, tt.method)

			// Stats expected regardless of the call's outcome.
			clientStats.Expected.IncCounter("outbound.calls.send", outboundTags, 1)
			clientStats.Expected.RecordTimer("outbound.calls.per-attempt.latency", outboundTags, 100*time.Millisecond)
			clientStats.Expected.RecordTimer("outbound.calls.latency", outboundTags, 100*time.Millisecond)
			serverStats.Expected.IncCounter("inbound.calls.recvd", inboundTags, 1)
			serverStats.Expected.RecordTimer("inbound.calls.latency", inboundTags, 50*time.Millisecond)

			// Outcome-specific counters.
			if tt.wantErr {
				clientStats.Expected.IncCounter("outbound.calls.per-attempt.app-errors", outboundTags, 1)
				clientStats.Expected.IncCounter("outbound.calls.app-errors", outboundTags, 1)
				serverStats.Expected.IncCounter("inbound.calls.app-errors", inboundTags, 1)
			} else {
				clientStats.Expected.IncCounter("outbound.calls.success", outboundTags, 1)
				serverStats.Expected.IncCounter("inbound.calls.success", inboundTags, 1)
			}
		})

		clientStats.Validate(t)
		serverStats.Validate(t)
	}
}
예제 #11
0
// TestStatsWithRetries drives RunWithRetry against a handler whose result is
// scripted per attempt, and validates the outbound counters and timers the
// client reports for each scenario in the table (send count, retries,
// per-attempt latency and overall latency).
func TestStatsWithRetries(t *testing.T) {
	defer testutils.SetTimeout(t, 2*time.Second)()
	// a is a short alias used to build the perAttemptLatencies values below.
	a := testutils.DurationArray

	initialTime := time.Date(2015, 2, 1, 10, 10, 0, 0, time.UTC)
	nowStub, nowFn := testutils.NowStub(initialTime)

	clientStats := newRecordingStatsReporter()
	ch := testutils.NewClient(t, testutils.NewOpts().
		SetStatsReporter(clientStats).
		SetTimeNow(nowStub))
	defer ch.Close()

	// The 10ms passed here matches the per-attempt latencies expected in the table below.
	nowFn(10 * time.Millisecond)
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	// TODO(review): it is not documented why NoRelay is required for this test.
	opts := testutils.NewOpts().NoRelay()
	WithVerifiedServer(t, opts, func(serverCh *Channel, hostPort string) {
		// respErr carries the error the "req" handler should return for the
		// next request; the retry function below fills it before each call.
		respErr := make(chan error, 1)
		testutils.RegisterFunc(serverCh, "req", func(ctx context.Context, args *raw.Args) (*raw.Res, error) {
			return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, <-respErr
		})
		ch.Peers().Add(serverCh.PeerInfo().HostPort)

		// timeNow is called at:
		// RunWithRetry start, per-attempt start, per-attempt end.
		// Each attempt therefore consumes 2 clock steps of overall latency,
		// while the per-attempt latency itself is a single step.
		tests := []struct {
			expectErr           error
			numFailures         int
			numAttempts         int
			overallLatency      time.Duration
			perAttemptLatencies []time.Duration
		}{
			{
				numFailures:         0,
				numAttempts:         1,
				perAttemptLatencies: a(10 * time.Millisecond),
				overallLatency:      20 * time.Millisecond,
			},
			{
				numFailures:         1,
				numAttempts:         2,
				perAttemptLatencies: a(10*time.Millisecond, 10*time.Millisecond),
				overallLatency:      40 * time.Millisecond,
			},
			{
				numFailures:         4,
				numAttempts:         5,
				perAttemptLatencies: a(10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond),
				overallLatency:      100 * time.Millisecond,
			},
			{
				numFailures:         5,
				numAttempts:         5,
				expectErr:           ErrServerBusy,
				perAttemptLatencies: a(10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond),
				overallLatency:      100 * time.Millisecond,
			},
		}

		for _, tt := range tests {
			// The reporter is shared across scenarios, so clear it before each one.
			clientStats.Reset()
			err := ch.RunWithRetry(ctx, func(ctx context.Context, rs *RequestState) error {
				// Script the handler's result for this attempt: success once the
				// attempt number exceeds numFailures, ErrServerBusy otherwise.
				if rs.Attempt > tt.numFailures {
					respErr <- nil
				} else {
					respErr <- ErrServerBusy
				}

				sc := ch.GetSubChannel(serverCh.ServiceName())
				_, err := raw.CallV2(ctx, sc, raw.CArgs{
					Method:      "req",
					CallOptions: &CallOptions{RequestState: rs},
				})
				return err
			})
			assert.Equal(t, tt.expectErr, err, "RunWithRetry unexpected error")

			outboundTags := tagsForOutboundCall(serverCh, ch, "req")
			if tt.expectErr == nil {
				clientStats.Expected.IncCounter("outbound.calls.success", outboundTags, 1)
			}
			clientStats.Expected.IncCounter("outbound.calls.send", outboundTags, int64(tt.numAttempts))
			for i, latency := range tt.perAttemptLatencies {
				clientStats.Expected.RecordTimer("outbound.calls.per-attempt.latency", outboundTags, latency)
				// Every attempt after the first is expected to be counted as a
				// retry, tagged with its retry index.
				if i > 0 {
					tags := tagsForOutboundCall(serverCh, ch, "req")
					tags["retry-count"] = fmt.Sprint(i)
					clientStats.Expected.IncCounter("outbound.calls.retries", tags, 1)
				}
			}
			clientStats.Expected.RecordTimer("outbound.calls.latency", outboundTags, tt.overallLatency)
			clientStats.Validate(t)
		}
	})
}
예제 #12
0
// TestFramesReleased runs concurrent ping + "swap" calls from several
// goroutines against a server using a recording frame pool, and then verifies
// that the pool reports no unreleased frames and that neither channel reports
// leftover message exchanges.
func TestFramesReleased(t *testing.T) {
	CheckStress(t)

	defer testutils.SetTimeout(t, 10*time.Second)()
	const (
		requestsPerGoroutine = 10
		numGoroutines        = 10
		maxRandArg           = 512 * 1024
	)

	var serverExchanges, clientExchanges string
	pool := NewRecordingFramePool()
	opts := testutils.NewOpts().
		SetServiceName("swap-server").
		SetFramePool(pool)
	WithVerifiedServer(t, opts, func(serverCh *Channel, hostPort string) {
		serverCh.Register(raw.Wrap(&swapper{t}), "swap")

		clientCh, err := NewChannel("swap-client", nil)
		require.NoError(t, err)
		defer clientCh.Close()

		// Create an active connection that can be shared by the goroutines by calling Ping.
		ctx, cancel := NewContext(time.Second)
		defer cancel()
		require.NoError(t, clientCh.Ping(ctx, hostPort))

		// doRequest performs one ping followed by one swap call. It is a
		// separate function so that each request's context is cancelled as
		// soon as the request finishes rather than when the worker exits.
		doRequest := func() {
			ctx, cancel := NewContext(time.Second * 5)
			defer cancel()

			// assert (not require) is used here: FailNow must only be called
			// from the goroutine running the test function.
			if !assert.NoError(t, clientCh.Ping(ctx, hostPort), "error during ping") {
				return
			}

			arg2 := testutils.RandBytes(rand.Intn(maxRandArg))
			arg3 := testutils.RandBytes(rand.Intn(maxRandArg))
			resArg2, resArg3, _, err := raw.Call(ctx, clientCh, hostPort, "swap-server", "swap", arg2, arg3)
			if !assert.NoError(t, err, "error during sendRecv") {
				return
			}

			// We expect the arguments to be swapped.
			if !bytes.Equal(arg3, resArg2) {
				t.Errorf("returned arg2 does not match expected:\n  got %v\n want %v", resArg2, arg3)
			}
			if !bytes.Equal(arg2, resArg3) {
				t.Errorf("returned arg3 does not match expected:\n  got %v\n want %v", resArg3, arg2)
			}
		}

		var wg sync.WaitGroup
		worker := func() {
			// Deferred so wg.Wait below can never hang on a worker that exits early.
			defer wg.Done()
			for i := 0; i < requestsPerGoroutine; i++ {
				doRequest()
			}
		}

		for i := 0; i < numGoroutines; i++ {
			wg.Add(1)
			go worker()
		}

		wg.Wait()

		serverExchanges = CheckEmptyExchanges(serverCh)
		clientExchanges = CheckEmptyExchanges(clientCh)
	})

	// Wait a few milliseconds for the closing of channels to take effect.
	time.Sleep(10 * time.Millisecond)

	if unreleasedCount, isEmpty := pool.CheckEmpty(); isEmpty != "" || unreleasedCount > 0 {
		t.Errorf("Frame pool has %v unreleased frames, errors:\n%v", unreleasedCount, isEmpty)
	}

	// Check the message exchanges and make sure they are all empty.
	if serverExchanges != "" {
		t.Errorf("Found uncleared message exchanges on server:\n%s", serverExchanges)
	}
	if clientExchanges != "" {
		t.Errorf("Found uncleared message exchanges on client:\n%s", clientExchanges)
	}
}
예제 #13
0
// TestCloseSemantics walks a channel (s1) through its close states while it
// has one pending incoming and one pending outgoing streaming call with a
// second server (s2), asserting at each step which calls are still accepted
// and which state s1 reports.
func TestCloseSemantics(t *testing.T) {
	// We defer the check as we want it to run after the SetTimeout clears the timeout.
	defer VerifyNoBlockedGoroutines(t)
	defer testutils.SetTimeout(t, 2*time.Second)()
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	// makeServer creates a server with two methods: "stream", which blocks
	// until the returned channel receives a value, and "call", which returns
	// immediately.
	makeServer := func(name string) (*Channel, chan struct{}) {
		ch, err := testutils.NewServer(&testutils.ChannelOpts{ServiceName: name})
		require.NoError(t, err)
		c := make(chan struct{})
		testutils.RegisterFunc(t, ch, "stream", func(ctx context.Context, args *raw.Args) (*raw.Res, error) {
			<-c
			return &raw.Res{}, nil
		})
		testutils.RegisterFunc(t, ch, "call", func(ctx context.Context, args *raw.Args) (*raw.Res, error) {
			return &raw.Res{}, nil
		})
		return ch, c
	}

	// withNewClient runs f with a fresh client channel and closes it afterwards.
	withNewClient := func(f func(ch *Channel)) {
		ch, err := testutils.NewClient(&testutils.ChannelOpts{ServiceName: "client"})
		require.NoError(t, err)
		// Deferred so the client is closed even if f aborts the test via require.
		defer ch.Close()
		f(ch)
	}

	// call makes a complete "call" request from one channel to another.
	call := func(from *Channel, to *Channel) error {
		toPeer := to.PeerInfo()
		_, _, _, err := raw.Call(ctx, from, toPeer.HostPort, toPeer.ServiceName, "call", nil, nil)
		return err
	}

	// callStream starts a "stream" request and returns a channel that is
	// signalled once the background reader has finished with the response.
	callStream := func(from *Channel, to *Channel) <-chan struct{} {
		c := make(chan struct{})

		toPeer := to.PeerInfo()
		call, err := from.BeginCall(ctx, toPeer.HostPort, toPeer.ServiceName, "stream", nil)
		require.NoError(t, err)
		require.NoError(t, NewArgWriter(call.Arg2Writer()).Write(nil), "write arg2")
		require.NoError(t, NewArgWriter(call.Arg3Writer()).Write(nil), "write arg3")

		go func() {
			// Always signal completion, even after a failed read, so the test
			// goroutine is not left waiting until the timeout fires.
			defer func() { c <- struct{}{} }()

			// assert (not require) is used here: FailNow must only be called
			// from the goroutine running the test function.
			var d []byte
			if !assert.NoError(t, NewArgReader(call.Response().Arg2Reader()).Read(&d), "read arg2 from %v to %v", from.PeerInfo(), to.PeerInfo()) {
				return
			}
			assert.NoError(t, NewArgReader(call.Response().Arg3Reader()).Read(&d), "read arg3")
		}()

		return c
	}

	s1, s1C := makeServer("s1")
	s2, s2C := makeServer("s2")

	// Make a call from s1 -> s2, and s2 -> s1
	call1 := callStream(s1, s2)
	call2 := callStream(s2, s1)

	// s1 and s2 are both open, so calls to it should be successful.
	withNewClient(func(ch *Channel) {
		require.NoError(t, call(ch, s1))
		require.NoError(t, call(ch, s2))
	})
	require.NoError(t, call(s1, s2))
	require.NoError(t, call(s2, s1))

	// Close s1, should no longer be able to call it.
	s1.Close()
	assert.Equal(t, ChannelStartClose, s1.State())
	withNewClient(func(ch *Channel) {
		assert.Error(t, call(ch, s1), "closed channel should not accept incoming calls")
		require.NoError(t, call(ch, s2),
			"closed channel with pending incoming calls should allow outgoing calls")
	})

	// Even an existing connection (e.g. from s2) should fail.
	assert.Equal(t, ErrChannelClosed, call(s2, s1), "closed channel should not accept incoming calls")

	require.NoError(t, call(s1, s2),
		"closed channel with pending incoming calls should allow outgoing calls")

	// Once the incoming connection is drained, outgoing calls should fail.
	s1C <- struct{}{}
	<-call2
	// Yield so other goroutines get a chance to run before the state is checked.
	runtime.Gosched()
	assert.Equal(t, ChannelInboundClosed, s1.State())
	require.Error(t, call(s1, s2),
		"closed channel with no pending incoming calls should not allow outgoing calls")

	// Now the channel should be completely closed as there are no pending connections.
	s2C <- struct{}{}
	<-call1
	runtime.Gosched()
	assert.Equal(t, ChannelClosed, s1.State())

	// Close s2 so we don't leave any goroutines running.
	s2.Close()
}
예제 #14
0
// TestFramesReleased runs concurrent ping/call/error-call traffic between a
// client and a server that share one recording frame pool, and then verifies
// that no goroutines leaked, the pool reports no unreleased frames, and
// neither channel reports leftover message exchanges.
func TestFramesReleased(t *testing.T) {
	CheckStress(t)

	defer testutils.SetTimeout(t, 10*time.Second)()
	const (
		requestsPerGoroutine = 10
		numGoroutines        = 10
	)

	framePool := NewRecordingFramePool()
	var serverExchanges, clientExchanges string

	serverOpts := testutils.NewOpts().
		SetServiceName("swap-server").
		SetFramePool(framePool).
		AddLogFilter("Couldn't find handler.", numGoroutines*requestsPerGoroutine)
	WithVerifiedServer(t, serverOpts, func(serverCh *Channel, hostPort string) {
		serverCh.Register(raw.Wrap(&swapper{t}), "swap")

		clientCh := testutils.NewClient(t, testutils.NewOpts().SetFramePool(framePool))
		defer clientCh.Close()

		// Ping once up front so an active connection exists that the workers can share.
		pingCtx, cancel := NewContext(time.Second)
		defer cancel()
		require.NoError(t, clientCh.Ping(pingCtx, hostPort))

		var workers sync.WaitGroup
		runWorker := func() {
			defer workers.Done()

			for req := 0; req < requestsPerGoroutine; req++ {
				doPingAndCall(t, clientCh, hostPort)
				doErrorCall(t, clientCh, hostPort)
			}
		}

		workers.Add(numGoroutines)
		for w := 0; w < numGoroutines; w++ {
			go runWorker()
		}
		workers.Wait()

		serverExchanges = CheckEmptyExchanges(serverCh)
		clientExchanges = CheckEmptyExchanges(clientCh)
	})

	// The test is still running at this point, so the timeout goroutine is
	// expected to be alive and is excluded from the leak check.
	leakOpts := &goroutines.VerifyOpts{
		Exclude: "testutils.SetTimeout",
	}
	goroutines.VerifyNoLeaks(t, leakOpts)

	if unreleasedCount, isEmpty := framePool.CheckEmpty(); isEmpty != "" || unreleasedCount > 0 {
		t.Errorf("Frame pool has %v unreleased frames, errors:\n%v", unreleasedCount, isEmpty)
	}

	// Both channels must have cleared every message exchange.
	if serverExchanges != "" {
		t.Errorf("Found uncleared message exchanges on server:\n%s", serverExchanges)
	}
	if clientExchanges != "" {
		t.Errorf("Found uncleared message exchanges on client:\n%s", clientExchanges)
	}
}