// TestHandlePayloadMessageWithNoMessageId verifies that handleSingleMessage
// returns an error for payload messages whose MessageId is either unset or
// empty
func TestHandlePayloadMessageWithNoMessageId(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	payloadHandler := newPayloadRequestHandler(
		context.Background(),
		engine.NewMockTaskEngine(mockCtrl),
		mock_api.NewMockECSClient(mockCtrl),
		clusterName,
		containerInstanceArn,
		nil,
		statemanager.NewNoopStateManager(),
		refreshCredentialsHandler{},
		credentials.NewManager(),
	)

	// A message that carries a task but leaves MessageId unset
	msg := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			{Arn: aws.String("t1")},
		},
	}
	if err := payloadHandler.handleSingleMessage(msg); err == nil {
		t.Error("Expected error while adding a task with no message id")
	}

	// The same message again, this time with an empty MessageId
	msg.MessageId = aws.String("")
	if err := payloadHandler.handleSingleMessage(msg); err == nil {
		t.Error("Expected error while adding a task with no message id")
	}
}
// TestAcsWsUrl checks that handler.AcsWsUrl returns a parseable URL whose path
// is "/ws" and whose query string carries the expected cluster, container
// instance, agent version, agent hash and docker version values.
func TestAcsWsUrl(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)

	taskEngine.EXPECT().Version().Return("Docker version result", nil)

	wsurl := handler.AcsWsUrl("http://endpoint.tld", "myCluster", "myContainerInstance", taskEngine)

	parsed, err := url.Parse(wsurl)
	if err != nil {
		t.Fatalf("Should be able to parse url %q: %v", wsurl, err)
	}

	if parsed.Path != "/ws" {
		t.Fatalf("Wrong path: %q", parsed.Path)
	}

	// Check every query parameter through one table so that a failure names
	// the parameter that mismatched together with both values
	query := parsed.Query()
	expectedParams := []struct {
		name  string
		value string
	}{
		{"clusterArn", "myCluster"},
		{"containerInstanceArn", "myContainerInstance"},
		{"agentVersion", version.Version},
		{"agentHash", version.GitHashString()},
		{"dockerVersion", "Docker version result"},
	}
	for _, param := range expectedParams {
		if got := query.Get(param.name); got != param.value {
			t.Fatalf("Wrong %s: expected %q, got %q", param.name, param.value, got)
		}
	}
}
// TestCredentialsMessageNotAckedWhenTaskNotFound tests if credential messages
// are not acked when the task arn in the message is not found in the task
// engine
func TestCredentialsMessageNotAckedWhenTaskNotFound(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	credentialsManager := credentials.NewManager()

	taskEngine := engine.NewMockTaskEngine(ctrl)
	// Return task not found from the engine for GetTaskByArn
	taskEngine.EXPECT().GetTaskByArn(taskArn).Return(nil, false)

	ctx, cancel := context.WithCancel(context.Background())
	handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, taskEngine)

	// Start a goroutine to listen for acks. Cancelling the context stops the
	// goroutine, and listenerDone is closed once it has exited
	listenerDone := make(chan struct{})
	go func() {
		defer close(listenerDone)
		for {
			select {
			// We never expect the message to be acked. Errorf is used
			// rather than Fatalf because FailNow may only be called from
			// the goroutine running the test function
			case <-handler.ackRequest:
				t.Errorf("Received ack when none expected")
			case <-ctx.Done():
				return
			}
		}
	}()

	// The engine mock reports the task as not found, so handling the
	// credentials message is expected to fail
	err := handler.handleSingleMessage(message)
	if err == nil {
		t.Error("Expected error updating credentials when the message contains unexpected task arn")
	}

	// Stop the listener and wait for it to exit so that it cannot report a
	// failure after the test has returned
	cancel()
	<-listenerDone
}
// TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks tests if tasks with desired status
// 'RUNNING' are added after tasks with desired status 'STOPPED'
func TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	taskEngine := engine.NewMockTaskEngine(ctrl)
	credentialsManager := credentials.NewManager()

	// Record every task handed to the engine, in the order of the AddTask calls
	var tasksAddedToEngine []*api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		tasksAddedToEngine = append(tasksAddedToEngine, task)
	}).Times(2)

	// The payload lists the RUNNING task first and the STOPPED task second
	stoppedTaskArn := "stoppedTask"
	runningTaskArn := "runningTask"
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn:           aws.String(runningTaskArn),
				DesiredStatus: aws.String("RUNNING"),
			},
			&ecsacs.Task{
				Arn:           aws.String(stoppedTaskArn),
				DesiredStatus: aws.String("STOPPED"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	}

	ctx := context.Background()
	stateManager := statemanager.NewNoopStateManager()
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager)
	_, ok := buffer.addPayloadTasks(payloadMessage)
	if !ok {
		t.Error("addPayloadTasks returned false")
	}
	// Stop here on a count mismatch: the checks below index into the slice
	// and would panic if fewer than two tasks were recorded
	if len(tasksAddedToEngine) != 2 {
		t.Fatalf("Incorrect number of tasks added to the engine. Expected: %d, got: %d", 2, len(tasksAddedToEngine))
	}

	// Verify if stopped task is added before running task
	firstTaskAdded := tasksAddedToEngine[0]
	if firstTaskAdded.Arn != stoppedTaskArn {
		t.Errorf("Expected first task arn: %s, got: %s", stoppedTaskArn, firstTaskAdded.Arn)
	}
	if firstTaskAdded.DesiredStatus != api.TaskStopped {
		t.Errorf("Expected first task state to be: %s, got: %s", "STOPPED", firstTaskAdded.DesiredStatus.String())
	}

	secondTaskAdded := tasksAddedToEngine[1]
	if secondTaskAdded.Arn != runningTaskArn {
		t.Errorf("Expected second task arn: %s, got: %s", runningTaskArn, secondTaskAdded.Arn)
	}
	if secondTaskAdded.DesiredStatus != api.TaskRunning {
		t.Errorf("Expected second task state to be: %s, got: %s", "RUNNING", secondTaskAdded.DesiredStatus.String())
	}
}
// TestHandlerReconnects sends a heartbeat and signals closeWS ten times against
// the mock ACS server, expecting one DiscoverPollEndpoint call per round. It
// fails if the rounds take longer than two seconds or if StartSession returns
// before the context is cancelled
func TestHandlerReconnects(t *testing.T) {
	const reconnectCount = 10

	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()
	mockEngine := engine.NewMockTaskEngine(mockCtrl)
	mockECS := mock_api.NewMockECSClient(mockCtrl)
	noopStateManager := statemanager.NewNoopStateManager()

	closeConn := make(chan bool)
	server, toClient, requests, errs, err := startMockAcsServer(t, closeConn)
	if err != nil {
		t.Fatal(err)
	}
	// Keep draining the requests and errs channels returned by the mock server
	go func() {
		for {
			select {
			case <-requests:
			case <-errs:
			}
		}
	}()

	mockECS.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(reconnectCount)
	mockEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	ctx, cancel := context.WithCancel(context.Background())
	sessionEnded := make(chan bool, 1)
	sessionArgs := handler.StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           mockEngine,
		ECSClient:            mockECS,
		StateManager:         noopStateManager,
		AcceptInvalidCert:    true,
	}
	go func() {
		handler.StartSession(ctx, sessionArgs)
		// Not expected to be reached before the context is cancelled
		sessionEnded <- true
	}()

	begin := time.Now()
	for round := 0; round < reconnectCount; round++ {
		toClient <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
		closeConn <- true
	}
	elapsed := time.Since(begin)
	if elapsed > 2*time.Second {
		t.Error("Test took longer than expected; backoff should not have occured for EOF")
	}

	// Non-blocking check that the session has not ended on its own
	select {
	case <-sessionEnded:
		t.Fatal("Should not have stopped session")
	default:
	}
	cancel()
	<-sessionEnded
}
// TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter runs
// startACSSession ten times against a mock websocket client and validates, on
// every Connect call, the 'sendCredentials' value held by the session: "true"
// for the first connection and "false" for all later ones
func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()
	mockEngine := engine.NewMockTaskEngine(mockCtrl)
	mockECS := mock_api.NewMockECSClient(mockCtrl)
	noopStateManager := statemanager.NewNoopStateManager()

	ctx, cancel := context.WithCancel(context.Background())
	wsClient := mock_wsclient.NewMockClientServer(mockCtrl)

	args := StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           mockEngine,
		ECSClient:            mockECS,
		StateManager:         noopStateManager,
		AcceptInvalidCert:    true,
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}
	session := newSessionResources(args)

	wsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
	wsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
	wsClient.EXPECT().Close().Return(nil).AnyTimes()
	wsClient.EXPECT().Serve().Return(io.EOF).AnyTimes()

	// connectExpecting registers a successful Connect call that validates the
	// session's 'sendCredentials' value at the moment of the call
	connectExpecting := func(want string) *gomock.Call {
		return wsClient.EXPECT().Connect().Do(func() {
			validateSendCredentialsInSession(t, session, want)
		}).Return(nil)
	}
	gomock.InOrder(
		// First connection to ACS
		connectExpecting("true"),
		// Every connection after the first one
		connectExpecting("false").AnyTimes(),
	)

	backoff := utils.NewSimpleBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier)
	timer := newDisconnectionTimer(wsClient, args.time(), args.heartbeatTimeout(), args.heartbeatJitter())
	defer timer.Stop()
	go func() {
		for attempt := 0; attempt < 10; attempt++ {
			startACSSession(ctx, wsClient, timer, args, backoff, session)
		}
		cancel()
	}()

	// Block until the goroutine above has finished all attempts
	<-ctx.Done()
}
// TestHandlePayloadMessageAckedWhenTaskAdded tests if the handler generates an ack
// after processing a payload message.
func TestHandlePayloadMessageAckedWhenTaskAdded(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	credentialsManager := credentials.NewManager()
	ctx, cancel := context.WithCancel(context.Background())
	// Cancel the handler's context on every exit path, including early
	// failures below
	defer cancel()

	taskEngine := engine.NewMockTaskEngine(ctrl)
	// Capture the task that gets added to the engine
	var addedTask *api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		addedTask = task
	}).Times(1)

	// Capture the ack request and cancel the context to signal its arrival
	var ackRequested *ecsacs.AckRequest
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) {
		ackRequested = ackRequest
		cancel()
	}).Times(1)
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredentialsHandler{}, credentialsManager)
	go buffer.start()

	// Send a payload message
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String("t1"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	}
	err := buffer.handleSingleMessage(payloadMessage)
	if err != nil {
		// Fail fast: the wait below only ends when the ack callback cancels
		// the context, and no ack is expected for a message that failed
		t.Fatalf("Error handling payload message: %v", err)
	}

	// Wait till we get an ack from the ackBuffer
	<-ctx.Done()

	// Verify the message id acked
	if aws.StringValue(ackRequested.MessageId) != payloadMessageId {
		t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(ackRequested.MessageId))
	}

	// Verify if task added == expected task
	expectedTask := &api.Task{
		Arn: "t1",
	}
	if !reflect.DeepEqual(addedTask, expectedTask) {
		t.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", expectedTask, addedTask)
	}
}
// TestHandlerReconnectsOnServeErrors makes ClientServer.Serve() return io.EOF
// ten times in a row and cancels the context on the eleventh call. The test
// only completes if startSession keeps retrying after each Serve() error
func TestHandlerReconnectsOnServeErrors(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	mockEngine := engine.NewMockTaskEngine(mockCtrl)
	mockEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	mockECS := mock_api.NewMockECSClient(mockCtrl)
	mockECS.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes()

	noopStateManager := statemanager.NewNoopStateManager()

	ctx, cancel := context.WithCancel(context.Background())
	wsClient := mock_wsclient.NewMockClientServer(mockCtrl)
	wsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
	wsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
	wsClient.EXPECT().Connect().Return(nil).AnyTimes()
	wsClient.EXPECT().Close().Return(nil).AnyTimes()
	gomock.InOrder(
		// The first ten Serve calls simply fail
		wsClient.EXPECT().Serve().Return(io.EOF).Times(10),
		// The eleventh Serve call cancels the context before failing. If
		// Serve() errors were not retried this call would never happen and
		// the test would time out waiting on the context
		wsClient.EXPECT().Serve().Do(func() {
			cancel()
		}).Return(io.EOF),
	)
	session := &mockSession{wsClient}

	args := StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           mockEngine,
		ECSClient:            mockECS,
		StateManager:         noopStateManager,
		AcceptInvalidCert:    true,
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}
	backoff := utils.NewSimpleBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier)
	go startSession(ctx, args, backoff, session)

	// Block until the eleventh Serve call has cancelled the context
	<-ctx.Done()
}
// TestConnectionIsClosedOnIdle starts an ACS session configured with a 20ms
// heartbeat timeout (10ms jitter), keeps Serve busy for 30ms, and then waits
// for Close to be called on the websocket client
func TestConnectionIsClosedOnIdle(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	mockEngine := engine.NewMockTaskEngine(mockCtrl)
	mockEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	mockECS := mock_api.NewMockECSClient(mockCtrl)
	noopStateManager := statemanager.NewNoopStateManager()

	wsClient := mock_wsclient.NewMockClientServer(mockCtrl)
	wsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes()
	wsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes()
	wsClient.EXPECT().Connect().Return(nil)
	wsClient.EXPECT().Serve().Do(func() {
		// Sleep for longer than the configured heartbeat timeout plus
		// jitter while "serving" requests
		time.Sleep(30 * time.Millisecond)
	}).Return(io.EOF)

	// closed receives a value when Close is called on the client
	closed := make(chan bool)
	wsClient.EXPECT().Close().Do(func() {
		closed <- true
	}).Return(nil)

	ctx := context.Background()
	backoff := utils.NewSimpleBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier)
	args := StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           mockEngine,
		ECSClient:            mockECS,
		StateManager:         noopStateManager,
		AcceptInvalidCert:    true,
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}
	go func() {
		idleTimer := newDisconnectionTimer(wsClient, args.time(), args.heartbeatTimeout(), args.heartbeatJitter())
		defer idleTimer.Stop()
		startACSSession(ctx, wsClient, idleTimer, args, backoff, &mockSession{})
	}()

	// Blocks until Close is called; if the connection is never closed the
	// test times out here
	<-closed
}
// TestHandlePayloadMessageAddTaskError verifies that handleSingleMessage
// returns an error when the task engine's AddTask fails, both for a task whose
// desired status is RUNNING and for one whose desired status is STOPPED
func TestHandlePayloadMessageAddTaskError(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	// AddTask fails on each of the two invocations
	mockEngine := engine.NewMockTaskEngine(mockCtrl)
	mockEngine.EXPECT().AddTask(gomock.Any()).Return(fmt.Errorf("oops")).Times(2)

	payloadHandler := newPayloadRequestHandler(
		context.Background(),
		mockEngine,
		mock_api.NewMockECSClient(mockCtrl),
		clusterName,
		containerInstanceArn,
		nil,
		statemanager.NewNoopStateManager(),
		refreshCredentialsHandler{},
		credentials.NewManager(),
	)

	// One fresh payload message per desired status, RUNNING first
	for _, desiredStatus := range []string{"RUNNING", "STOPPED"} {
		msg := &ecsacs.PayloadMessage{
			Tasks: []*ecsacs.Task{
				{
					Arn:           aws.String("t1"),
					DesiredStatus: aws.String(desiredStatus),
				},
			},
			MessageId: aws.String(payloadMessageId),
		}
		if err := payloadHandler.handleSingleMessage(msg); err == nil {
			t.Error("Expected error while adding the task")
		}
	}
}
// TestHandleRefreshMessageAckedWhenCredentialsUpdated tests that a credential message
// is acked when the credentials are updated successfully
func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	credentialsManager := credentials.NewManager()

	ctx, cancel := context.WithCancel(context.Background())
	// Cancel the handler's context on every exit path, including early
	// failures below
	defer cancel()
	var ackRequested *ecsacs.IAMRoleCredentialsAckRequest

	// Capture the ack request and cancel the context to signal its arrival
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
		ackRequested = ackRequest
		cancel()
	}).Times(1)

	taskEngine := engine.NewMockTaskEngine(ctrl)
	// Return a task from the engine for GetTaskByArn
	taskEngine.EXPECT().GetTaskByArn(taskArn).Return(&api.Task{}, true)

	handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
	go handler.sendAcks()

	// Handle a credentials message for a task that the engine knows about
	err := handler.handleSingleMessage(message)
	if err != nil {
		// Fail fast: the wait below only ends when the ack callback cancels
		// the context, and no ack is expected for a message that failed
		t.Fatalf("Error updating credentials: %v", err)
	}

	// Wait till we get an ack from the ackBuffer
	<-ctx.Done()

	if !reflect.DeepEqual(ackRequested, expectedAck) {
		t.Errorf("Mismatch between expected and requested ack. Expected: %v, Requested: %v", expectedAck, ackRequested)
	}

	// The credentials from the message should now be retrievable by id
	creds, exist := credentialsManager.GetTaskCredentials(credentialsId)
	if !exist {
		t.Errorf("Expected credentials to exist for the task")
	}
	if !reflect.DeepEqual(creds, expectedCredentials) {
		t.Errorf("Mismatch between expected credentials and credentials for task. Expected: %v, got: %v", expectedCredentials, creds)
	}
}
// TestHandlePayloadMessageStateSaveError verifies that handleSingleMessage
// returns an error when the state manager's Save() fails, and that the task
// was nevertheless handed to the task engine
func TestHandlePayloadMessageStateSaveError(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	// Remember the task that the engine is asked to add
	var recordedTask *api.Task
	mockEngine := engine.NewMockTaskEngine(mockCtrl)
	mockEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		recordedTask = task
	}).Times(1)

	// Saving state always fails
	failingStateManager := mock_statemanager.NewMockStateManager(mockCtrl)
	failingStateManager.EXPECT().Save().Return(fmt.Errorf("oops"))

	payloadHandler := newPayloadRequestHandler(
		context.Background(),
		mockEngine,
		mock_api.NewMockECSClient(mockCtrl),
		clusterName,
		containerInstanceArn,
		nil,
		failingStateManager,
		refreshCredentialsHandler{},
		credentials.NewManager(),
	)

	// The Save() failure must surface as an error from handleSingleMessage
	msg := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			{Arn: aws.String("t1")},
		},
		MessageId: aws.String(payloadMessageId),
	}
	if err := payloadHandler.handleSingleMessage(msg); err == nil {
		t.Error("Expected error while adding a task from statemanager")
	}

	// The engine should still have received the task despite the failed save
	wantTask := &api.Task{Arn: "t1"}
	if !reflect.DeepEqual(recordedTask, wantTask) {
		t.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", wantTask, recordedTask)
	}
}
// TestAcsWsUrl checks that acsWsURL returns a parseable URL whose path is
// "/ws" and whose query string carries the expected cluster, container
// instance, agent version, agent hash, docker version, send-credentials and
// sequence number values.
func TestAcsWsUrl(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)

	taskEngine.EXPECT().Version().Return("Docker version result", nil)

	wsurl := acsWsURL(acsURL, "myCluster", "myContainerInstance", taskEngine, &mockSession{})

	parsed, err := url.Parse(wsurl)
	if err != nil {
		t.Fatalf("Should be able to parse url %q: %v", wsurl, err)
	}

	if parsed.Path != "/ws" {
		t.Fatalf("Wrong path: %q", parsed.Path)
	}

	// Check every query parameter through one table so that a failure names
	// the parameter that mismatched together with both values
	query := parsed.Query()
	expectedParams := []struct {
		name  string
		value string
	}{
		{"clusterArn", "myCluster"},
		{"containerInstanceArn", "myContainerInstance"},
		{"agentVersion", version.Version},
		{"agentHash", version.GitHashString()},
		{"dockerVersion", "Docker version result"},
		{sendCredentialsURLParameterName, "true"},
		{"seqNum", "1"},
	}
	for _, param := range expectedParams {
		if got := query.Get(param.name); got != param.value {
			t.Fatalf("Wrong %s: expected %q, got %q", param.name, param.value, got)
		}
	}
}
// TestHandlerReconnectsOnDiscoverPollEndpointError makes DiscoverPollEndpoint
// fail once and succeed on the second call, then checks that the time taken to
// reach Serve() lies between connectionBackoffMin and 2*connectionBackoffMin
func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	mockEngine := engine.NewMockTaskEngine(mockCtrl)
	mockEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	mockECS := mock_api.NewMockECSClient(mockCtrl)
	noopStateManager := statemanager.NewNoopStateManager()

	ctx, cancel := context.WithCancel(context.Background())

	wsClient := mock_wsclient.NewMockClientServer(mockCtrl)
	wsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
	wsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
	wsClient.EXPECT().Connect().Return(nil).AnyTimes()
	wsClient.EXPECT().Close().Return(nil).AnyTimes()
	// Reaching Serve() ends the test by cancelling the context
	wsClient.EXPECT().Serve().Do(func() {
		cancel()
	}).Return(io.EOF)

	session := &mockSession{wsClient}

	gomock.InOrder(
		// The first lookup of the poll endpoint fails
		mockECS.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return("", fmt.Errorf("oops")).Times(1),
		// The second lookup succeeds
		mockECS.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).Times(1),
	)
	args := StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           mockEngine,
		ECSClient:            mockECS,
		StateManager:         noopStateManager,
		AcceptInvalidCert:    true,
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}
	backoff := utils.NewSimpleBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier)
	go startSession(ctx, args, backoff, session)
	begin := time.Now()

	// Block until Serve() has cancelled the context
	<-ctx.Done()

	// The elapsed time covers the wait between the failed and the successful
	// DiscoverPollEndpoint calls
	elapsed := time.Since(begin)
	if elapsed < connectionBackoffMin {
		t.Errorf("Duration since start is less than minimum threshold for backoff: %s", elapsed.String())
	}

	// Strictly, the ceiling would be connectionBackoffMin plus its jitter
	// share, with a few extra milliseconds for everything else that runs.
	// 2*connectionBackoffMin is used as a generous bound that should never
	// be exceeded
	if elapsed > 2*connectionBackoffMin {
		t.Errorf("Duration since start is greater than maximum anticipated wait time: %v", elapsed.String())
	}
}
// Example #15
// 0
// TestHandlerDoesntLeakGouroutines runs a session against the mock ACS server,
// sends a heartbeat and a payload message and signals closeWS one hundred
// times, and then compares the goroutine count before and after. It fails if
// DiscoverPollEndpoint was called fewer than 50 times or if the count grew by
// more than 5.
func TestHandlerDoesntLeakGouroutines(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	ecsclient := mock_api.NewMockECSClient(ctrl)
	statemanager := statemanager.NewNoopStateManager()
	testTime := ttime.NewTestTime()
	ttime.SetTime(testTime)

	closeWS := make(chan bool)
	server, serverIn, requests, errs, err := startMockAcsServer(t, closeWS)
	if err != nil {
		t.Fatal(err)
	}
	// Keep draining the requests and errs channels returned by the mock server
	go func() {
		for {
			select {
			case <-requests:
			case <-errs:
			}
		}
	}()

	// Count the endpoint discovery calls made by the session
	timesConnected := 0
	ecsclient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).AnyTimes().Do(func(_ interface{}) {
		timesConnected++
	})
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()
	taskEngine.EXPECT().AddTask(gomock.Any()).AnyTimes()

	ctx, cancel := context.WithCancel(context.Background())
	ended := make(chan bool, 1)
	go func() {
		// Keyed fields keep this literal valid if fields are added to or
		// reordered in the imported struct
		handler.StartSession(ctx, handler.StartSessionArguments{
			ContainerInstanceArn: "myArn",
			CredentialProvider:   credentials.AnonymousCredentials,
			Config:               &config.Config{Cluster: "someCluster"},
			TaskEngine:           taskEngine,
			ECSClient:            ecsclient,
			StateManager:         statemanager,
			AcceptInvalidCert:    true,
		})
		ended <- true
	}()
	// Warm it up
	serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
	serverIn <- samplePayloadMessage

	beforeGoroutines := runtime.NumGoroutine()
	for i := 0; i < 100; i++ {
		serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
		serverIn <- samplePayloadMessage
		closeWS <- true
	}

	cancel()
	testTime.Cancel()
	// Receiving from ended also orders the timesConnected writes before the
	// read below
	<-ended

	afterGoroutines := runtime.NumGoroutine()

	t.Logf("Gorutines after 1 and after 100 acs messages: %v and %v", beforeGoroutines, afterGoroutines)

	if timesConnected < 50 {
		t.Fatal("Expected times connected to be a large number, was ", timesConnected)
	}
	if afterGoroutines > beforeGoroutines+5 {
		t.Error("Goroutine leak, oh no!")
		// Dump all goroutine stacks to help locate the leak
		pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
	}

}
// Example #16
// 0
// TestHeartbeatOnlyWhenIdle sends ten payload messages to a running session,
// warping the test clock by one minute after each, and checks that the
// session neither ended nor caused the mock server to report an error.
// DiscoverPollEndpoint is expected exactly once, so any reconnect fails the
// test.
func TestHeartbeatOnlyWhenIdle(t *testing.T) {
	testTime := ttime.NewTestTime()
	ttime.SetTime(testTime)
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	ecsclient := mock_api.NewMockECSClient(ctrl)
	statemanager := statemanager.NewNoopStateManager()

	closeWS := make(chan bool)
	server, serverIn, requestsChan, errChan, err := startMockAcsServer(t, closeWS)
	// Check the error before touching any of the other return values
	if err != nil {
		t.Fatal(err)
	}
	defer close(serverIn)

	// Keep draining the requests channel returned by the mock server
	go func() {
		for {
			<-requestsChan
		}
	}()

	// We're testing that it does not reconnect here; must be the case
	ecsclient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(1)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	ctx, cancel := context.WithCancel(context.Background())
	ended := make(chan bool, 1)
	go func() {
		handler.StartSession(ctx, handler.StartSessionArguments{
			ContainerInstanceArn: "myArn",
			CredentialProvider:   credentials.AnonymousCredentials,
			Config:               &config.Config{Cluster: "someCluster"},
			TaskEngine:           taskEngine,
			ECSClient:            ecsclient,
			StateManager:         statemanager,
			AcceptInvalidCert:    true,
		})
		ended <- true
	}()

	// Each AddTask call signals taskAdded so the loop below advances only
	// after the payload has reached the engine
	taskAdded := make(chan bool)
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(interface{}) {
		taskAdded <- true
	}).Times(10)
	for i := 0; i < 10; i++ {
		serverIn <- samplePayloadMessage
		testTime.Warp(1 * time.Minute)
		<-taskAdded
	}

	// Non-blocking check that the session is still up and the server saw no error
	select {
	case <-ended:
		t.Fatal("Should not have stop session")
	case err := <-errChan:
		t.Fatal("Error should not have been returned from server", err)
	default:
	}
	go server.Close()
	cancel()
	<-ended
}
// TestPayloadBufferHandlerWithCredentials tests if the async payloadBufferHandler routine
// acks the payload message and credentials after adding tasks
func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	ctx, cancel := context.WithCancel(context.Background())
	credentialsManager := credentials.NewManager()

	// The payload message in the test consists of two tasks, record both of them in
	// the order in which they were added
	taskEngine := engine.NewMockTaskEngine(ctrl)
	var firstAddedTask *api.Task
	var secondAddedTask *api.Task
	gomock.InOrder(
		taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
			firstAddedTask = task
		}),
		taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
			secondAddedTask = task
		}),
	)

	// The payload message in the test consists of two tasks, with credentials set
	// for both. Record the credentials' ack and the payload message ack
	var payloadAckRequested *ecsacs.AckRequest
	var firstTaskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest
	var secondTaskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	gomock.InOrder(
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
			firstTaskCredentialsAckRequested = ackRequest
		}),
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
			secondTaskCredentialsAckRequested = ackRequest
		}),
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) {
			payloadAckRequested = ackRequest
			// Cancel the context when the ack for the payload message is received
			// This signals a successful workflow in the test
			cancel()
		}),
	)

	refreshCredsHandler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
	defer refreshCredsHandler.clearAcks()
	refreshCredsHandler.start()
	payloadHandler := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredsHandler, credentialsManager)
	go payloadHandler.start()

	firstTaskArn := "t1"
	firstTaskCredentialsExpiration := "expiration"
	firstTaskCredentialsRoleArn := "r1"
	firstTaskCredentialsAccessKey := "akid"
	firstTaskCredentialsSecretKey := "skid"
	firstTaskCredentialsSessionToken := "token"
	firstTaskCredentialsId := "credsid1"

	secondTaskArn := "t2"
	secondTaskCredentialsExpiration := "expirationSecond"
	secondTaskCredentialsRoleArn := "r2"
	secondTaskCredentialsAccessKey := "akid2"
	secondTaskCredentialsSecretKey := "skid2"
	secondTaskCredentialsSessionToken := "token2"
	secondTaskCredentialsId := "credsid2"

	// Send a payload message to the payloadBufferChannel
	payloadHandler.messageBuffer <- &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String(firstTaskArn),
				RoleCredentials: &ecsacs.IAMRoleCredentials{
					AccessKeyId:     aws.String(firstTaskCredentialsAccessKey),
					Expiration:      aws.String(firstTaskCredentialsExpiration),
					RoleArn:         aws.String(firstTaskCredentialsRoleArn),
					SecretAccessKey: aws.String(firstTaskCredentialsSecretKey),
					SessionToken:    aws.String(firstTaskCredentialsSessionToken),
					CredentialsId:   aws.String(firstTaskCredentialsId),
				},
			},
			&ecsacs.Task{
				Arn: aws.String(secondTaskArn),
				RoleCredentials: &ecsacs.IAMRoleCredentials{
					AccessKeyId:     aws.String(secondTaskCredentialsAccessKey),
					Expiration:      aws.String(secondTaskCredentialsExpiration),
					RoleArn:         aws.String(secondTaskCredentialsRoleArn),
					SecretAccessKey: aws.String(secondTaskCredentialsSecretKey),
					SessionToken:    aws.String(secondTaskCredentialsSessionToken),
					CredentialsId:   aws.String(secondTaskCredentialsId),
				},
			},
		},
		MessageId:            aws.String(payloadMessageId),
		ClusterArn:           aws.String(cluster),
		ContainerInstanceArn: aws.String(containerInstance),
	}

	// Wait till we get an ack
	select {
	case <-ctx.Done():
	}

	// Verify if payloadMessageId read from the ack buffer is correct
	if aws.StringValue(payloadAckRequested.MessageId) != payloadMessageId {
		t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(payloadAckRequested.MessageId))
	}

	// Verify the correctness of the first task added to the engine and the
	// credentials ack generated for it
	expectedCredentialsAckForFirstTask := &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     aws.String(payloadMessageId),
		Expiration:    aws.String(firstTaskCredentialsExpiration),
		CredentialsId: aws.String(firstTaskCredentialsId),
	}
	expectedCredentialsForFirstTask := credentials.IAMRoleCredentials{
		AccessKeyId:     firstTaskCredentialsAccessKey,
		Expiration:      firstTaskCredentialsExpiration,
		RoleArn:         firstTaskCredentialsRoleArn,
		SecretAccessKey: firstTaskCredentialsSecretKey,
		SessionToken:    firstTaskCredentialsSessionToken,
		CredentialsId:   firstTaskCredentialsId,
	}
	err := validateTaskAndCredentials(firstTaskCredentialsAckRequested, expectedCredentialsAckForFirstTask, firstAddedTask, firstTaskArn, expectedCredentialsForFirstTask)
	if err != nil {
		t.Errorf("Error validating added task or credentials ack for the same: %v", err)
	}

	// Verify the correctness of the second task added to the engine and the
	// credentials ack generated for it
	expectedCredentialsAckForSecondTask := &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     aws.String(payloadMessageId),
		Expiration:    aws.String(secondTaskCredentialsExpiration),
		CredentialsId: aws.String(secondTaskCredentialsId),
	}
	expectedCredentialsForSecondTask := credentials.IAMRoleCredentials{
		AccessKeyId:     secondTaskCredentialsAccessKey,
		Expiration:      secondTaskCredentialsExpiration,
		RoleArn:         secondTaskCredentialsRoleArn,
		SecretAccessKey: secondTaskCredentialsSecretKey,
		SessionToken:    secondTaskCredentialsSessionToken,
		CredentialsId:   secondTaskCredentialsId,
	}
	err = validateTaskAndCredentials(secondTaskCredentialsAckRequested, expectedCredentialsAckForSecondTask, secondAddedTask, secondTaskArn, expectedCredentialsForSecondTask)
	if err != nil {
		t.Errorf("Error validating added task or credentials ack for the same: %v", err)
	}
}
// TestStartSessionHandlesRefreshCredentialsMessages tests the agent restart
// scenario where the payload to refresh credentials is processed immediately on
// connection establishment with ACS.
//
// The test starts a mock ACS server, runs StartSession against it, pushes
// sampleRefreshCredentialsMessage to the client and then checks that:
//   - the task is looked up in the engine before the credentials are stored,
//   - the credentials handed to the credentials manager match the expected
//     values, and
//   - the credentials id from the refresh message is set on the task that
//     the engine returned.
func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	closeWS := make(chan bool)
	server, serverIn, requestsChan, errChan, err := startMockAcsServer(t, closeWS)
	if err != nil {
		t.Fatal(err)
	}
	defer close(serverIn)

	ctx, cancel := context.WithCancel(context.Background())
	// Release the context on every exit path (including t.Fatal) so that the
	// session goroutine started below is not left running with a live context.
	// Calling cancel more than once is safe.
	defer cancel()
	go func() {
		// Cancel the context when we get the ack request. Keep draining the
		// channel afterwards; the loop ends if the channel is ever closed
		// instead of spinning on a closed channel.
		for range requestsChan {
			cancel()
		}
	}()

	// DiscoverPollEndpoint returns the URL for the server that we started
	ecsClient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(1)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	credentialsManager := mock_credentials.NewMockManager(ctrl)

	ended := make(chan bool, 1)
	go func() {
		StartSession(ctx, StartSessionArguments{
			ContainerInstanceArn: "myArn",
			CredentialProvider:   credentials.AnonymousCredentials,
			Config:               &config.Config{Cluster: "someCluster"},
			TaskEngine:           taskEngine,
			ECSClient:            ecsClient,
			StateManager:         stateManager,
			AcceptInvalidCert:    true,
			CredentialsManager:   credentialsManager,
		})
		// StartSession should never return unless the context is canceled
		ended <- true
	}()

	taskFromEngine := &api.Task{}
	credentialsIdInRefreshMessage := "credsId"
	// Ensure that credentials manager interface methods are invoked in the
	// correct order, with expected arguments
	gomock.InOrder(
		// Return a task from the engine for GetTaskByArn
		taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true),
		// The last invocation of SetCredentials is to update
		// credentials when a refresh message is received by the handler
		credentialsManager.EXPECT().SetTaskCredentials(gomock.Any()).Do(func(creds rolecredentials.TaskIAMRoleCredentials) {
			// Validate parsed credentials after the update. The comparison is
			// done on the callback argument directly so that no variable is
			// shared between this callback and the test goroutine.
			expectedCreds := rolecredentials.TaskIAMRoleCredentials{
				ARN: "t1",
				IAMRoleCredentials: rolecredentials.IAMRoleCredentials{
					RoleArn:         "r1",
					AccessKeyId:     "newakid",
					SecretAccessKey: "newskid",
					SessionToken:    "newstkn",
					Expiration:      "later",
					CredentialsId:   credentialsIdInRefreshMessage,
				},
			}
			if !reflect.DeepEqual(creds, expectedCreds) {
				t.Errorf("Mismatch between expected and credentials expected: %v, added: %v", expectedCreds, creds)
			}
		}).Return(nil),
	)
	serverIn <- sampleRefreshCredentialsMessage

	select {
	case <-ended:
		t.Fatal("Should not have stop session")
	case serverErr := <-errChan:
		t.Fatal("Error should not have been returned from server", serverErr)
	case <-ctx.Done():
		// Context is canceled when requestsChan receives an ack
	}

	// Validate that the correct credentialsId is set for the task
	credentialsIdFromTask := taskFromEngine.GetCredentialsId()
	if credentialsIdFromTask != credentialsIdInRefreshMessage {
		t.Errorf("Mismatch between expected and added credentials id for task, expected: %s, aded: %s", credentialsIdInRefreshMessage, credentialsIdFromTask)
	}

	// Close the server in the background and wait for StartSession to return
	go server.Close()
	<-ended
}
// TestHandlePayloadMessageCredentialsAckedWhenTaskAdded tests if the handler generates
// an ack after processing a payload message when the payload message contains a task
// with an IAM Role. It also tests if the credentials ack is generated.
//
// The mocked websocket client expects the credentials ack first and the
// payload ack second; the payload ack cancels the test context, which is the
// signal the test waits on before validating the recorded requests.
func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	ctx, cancel := context.WithCancel(context.Background())
	// Make sure the handlers started with ctx are told to stop on every exit
	// path, not only when the payload ack is observed. Calling cancel more
	// than once is safe.
	defer cancel()
	credentialsManager := credentials.NewManager()

	taskEngine := engine.NewMockTaskEngine(ctrl)
	var addedTask *api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		addedTask = task
	}).Times(1)

	var payloadAckRequested *ecsacs.AckRequest
	var taskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	// The payload message in the test consists of a task, with credentials set
	// Record the credentials ack and the payload message ack
	gomock.InOrder(
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
			taskCredentialsAckRequested = ackRequest
		}),
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) {
			payloadAckRequested = ackRequest
			// Cancel the context when the ack for the payload message is received
			// This signals a successful workflow in the test
			cancel()
		}),
	)

	refreshCredsHandler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
	defer refreshCredsHandler.clearAcks()
	refreshCredsHandler.start()

	payloadHandler := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredsHandler, credentialsManager)
	go payloadHandler.start()

	taskArn := "t1"
	credentialsExpiration := "expiration"
	credentialsRoleArn := "r1"
	credentialsAccessKey := "akid"
	credentialsSecretKey := "skid"
	credentialsSessionToken := "token"
	credentialsId := "credsid"

	// Send a payload message
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String(taskArn),
				RoleCredentials: &ecsacs.IAMRoleCredentials{
					AccessKeyId:     aws.String(credentialsAccessKey),
					Expiration:      aws.String(credentialsExpiration),
					RoleArn:         aws.String(credentialsRoleArn),
					SecretAccessKey: aws.String(credentialsSecretKey),
					SessionToken:    aws.String(credentialsSessionToken),
					CredentialsId:   aws.String(credentialsId),
				},
			},
		},
		MessageId:            aws.String(payloadMessageId),
		ClusterArn:           aws.String(cluster),
		ContainerInstanceArn: aws.String(containerInstance),
	}
	err := payloadHandler.handleSingleMessage(payloadMessage)
	if err != nil {
		// Abort instead of continuing: the wait below only returns once the
		// payload ack has been made, and a message that failed to be handled
		// is presumably never acked (TODO confirm), so carrying on could
		// block the test until the global timeout.
		t.Fatalf("Error handling payload message: %v", err)
	}

	// Wait till we get an ack from the ackBuffer
	<-ctx.Done()

	// Verify the message id acked
	if aws.StringValue(payloadAckRequested.MessageId) != payloadMessageId {
		t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(payloadAckRequested.MessageId))
	}

	// Verify the correctness of the task added to the engine and the
	// credentials ack generated for it
	expectedCredentialsAck := &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     aws.String(payloadMessageId),
		Expiration:    aws.String(credentialsExpiration),
		CredentialsId: aws.String(credentialsId),
	}
	expectedCredentials := credentials.IAMRoleCredentials{
		AccessKeyId:     credentialsAccessKey,
		Expiration:      credentialsExpiration,
		RoleArn:         credentialsRoleArn,
		SecretAccessKey: credentialsSecretKey,
		SessionToken:    credentialsSessionToken,
		CredentialsId:   credentialsId,
	}
	err = validateTaskAndCredentials(taskCredentialsAckRequested, expectedCredentialsAck, addedTask, taskArn, expectedCredentials)
	if err != nil {
		t.Errorf("Error validating added task or credentials ack for the same: %v", err)
	}
}