// TestHandlePayloadMessageWithNoMessageId tests that agent doesn't ack payload messages // that do not contain message ids func TestHandlePayloadMessageWithNoMessageId(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() credentialsManager := credentials.NewManager() ctx := context.Background() buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager) // test adding a payload message without the MessageId field payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String("t1"), }, }, } err := buffer.handleSingleMessage(payloadMessage) if err == nil { t.Error("Expected error while adding a task with no message id") } // test adding a payload message with blank MessageId payloadMessage.MessageId = aws.String("") err = buffer.handleSingleMessage(payloadMessage) if err == nil { t.Error("Expected error while adding a task with no message id") } }
func TestAcsWsUrl(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) taskEngine.EXPECT().Version().Return("Docker version result", nil) wsurl := handler.AcsWsUrl("http://endpoint.tld", "myCluster", "myContainerInstance", taskEngine) parsed, err := url.Parse(wsurl) if err != nil { t.Fatal("Should be able to parse url") } if parsed.Path != "/ws" { t.Fatal("Wrong path") } if parsed.Query().Get("clusterArn") != "myCluster" { t.Fatal("Wrong cluster") } if parsed.Query().Get("containerInstanceArn") != "myContainerInstance" { t.Fatal("Wrong cluster") } if parsed.Query().Get("agentVersion") != version.Version { t.Fatal("Wrong cluster") } if parsed.Query().Get("agentHash") != version.GitHashString() { t.Fatal("Wrong cluster") } if parsed.Query().Get("dockerVersion") != "Docker version result" { t.Fatal("Wrong docker version") } }
// TestCredentialsMessageNotAckedWhenTaskNotFound tests if credential messages // are not acked when the task arn in the message is not found in the task // engine func TestCredentialsMessageNotAckedWhenTaskNotFound(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() credentialsManager := credentials.NewManager() taskEngine := engine.NewMockTaskEngine(ctrl) // Return task not found from the engine for GetTaskByArn taskEngine.EXPECT().GetTaskByArn(taskArn).Return(nil, false) ctx, cancel := context.WithCancel(context.Background()) handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, taskEngine) // Start a goroutine to listen for acks. Cancelling the context stops the goroutine go func() { for { select { // We never expect the message to be acked case <-handler.ackRequest: t.Fatalf("Received ack when none expected") case <-ctx.Done(): return } } }() // Test adding a credentials message without the MessageId field err := handler.handleSingleMessage(message) if err == nil { t.Error("Expected error updating credentials when the message contains unexpected task arn") } cancel() }
// TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks tests if tasks with desired status // 'RUNNING' are added after tasks with desired status 'STOPPED' func TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecsClient := mock_api.NewMockECSClient(ctrl) taskEngine := engine.NewMockTaskEngine(ctrl) credentialsManager := credentials.NewManager() var tasksAddedToEngine []*api.Task taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) { tasksAddedToEngine = append(tasksAddedToEngine, task) }).Times(2) stoppedTaskArn := "stoppedTask" runningTaskArn := "runningTask" payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String(runningTaskArn), DesiredStatus: aws.String("RUNNING"), }, &ecsacs.Task{ Arn: aws.String(stoppedTaskArn), DesiredStatus: aws.String("STOPPED"), }, }, MessageId: aws.String(payloadMessageId), } ctx := context.Background() stateManager := statemanager.NewNoopStateManager() buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager) _, ok := buffer.addPayloadTasks(payloadMessage) if !ok { t.Error("addPayloadTasks returned false") } if len(tasksAddedToEngine) != 2 { t.Errorf("Incorrect number of tasks added to the engine. 
Expected: %d, got: %d", 2, len(tasksAddedToEngine)) } // Verify if stopped task is added before running task firstTaskAdded := tasksAddedToEngine[0] if firstTaskAdded.Arn != stoppedTaskArn { t.Errorf("Expected first task arn: %s, got: %s", stoppedTaskArn, firstTaskAdded.Arn) } if firstTaskAdded.DesiredStatus != api.TaskStopped { t.Errorf("Expected first task state be be: %s , got: %s", "STOPPED", firstTaskAdded.DesiredStatus.String()) } secondTaskAdded := tasksAddedToEngine[1] if secondTaskAdded.Arn != runningTaskArn { t.Errorf("Expected second task arn: %s, got: %s", runningTaskArn, secondTaskAdded.Arn) } if secondTaskAdded.DesiredStatus != api.TaskRunning { t.Errorf("Expected second task state be be: %s , got: %s", "RUNNNING", secondTaskAdded.DesiredStatus.String()) } }
func TestHandlerReconnects(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsclient := mock_api.NewMockECSClient(ctrl) statemanager := statemanager.NewNoopStateManager() closeWS := make(chan bool) server, serverIn, requests, errs, err := startMockAcsServer(t, closeWS) if err != nil { t.Fatal(err) } go func() { for { select { case <-requests: case <-errs: } } }() ecsclient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(10) taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) ended := make(chan bool, 1) go func() { handler.StartSession(ctx, handler.StartSessionArguments{ ContainerInstanceArn: "myArn", CredentialProvider: credentials.AnonymousCredentials, Config: &config.Config{Cluster: "someCluster"}, TaskEngine: taskEngine, ECSClient: ecsclient, StateManager: statemanager, AcceptInvalidCert: true, }) // This should never return ended <- true }() start := time.Now() for i := 0; i < 10; i++ { serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}` closeWS <- true } if time.Since(start) > 2*time.Second { t.Error("Test took longer than expected; backoff should not have occured for EOF") } select { case <-ended: t.Fatal("Should not have stopped session") default: } cancel() <-ended }
// TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter tests if
// the 'sendCredentials' URL parameter is set correctly for successive
// invocations of startACSSession
func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	ecsClient := mock_api.NewMockECSClient(ctrl)
	statemanager := statemanager.NewNoopStateManager()
	ctx, cancel := context.WithCancel(context.Background())
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	// Short heartbeat timeout/jitter keep each simulated session brief so
	// that ten successive sessions complete quickly
	args := StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           taskEngine,
		ECSClient:            ecsClient,
		StateManager:         statemanager,
		AcceptInvalidCert:    true,
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}
	session := newSessionResources(args)
	mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
	// Serve() returning io.EOF makes each startACSSession invocation
	// return, so the loop below can start the next session
	mockWsClient.EXPECT().Serve().Return(io.EOF).AnyTimes()
	gomock.InOrder(
		// When the websocket client connects to ACS for the first
		// time, 'sendCredentials' should be set to true
		mockWsClient.EXPECT().Connect().Do(func() {
			validateSendCredentialsInSession(t, session, "true")
		}).Return(nil),
		// For all subsequent connections to ACS, 'sendCredentials'
		// should be set to false
		mockWsClient.EXPECT().Connect().Do(func() {
			validateSendCredentialsInSession(t, session, "false")
		}).Return(nil).AnyTimes(),
	)
	backoff := utils.NewSimpleBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier)
	timer := newDisconnectionTimer(mockWsClient, args.time(), args.heartbeatTimeout(), args.heartbeatJitter())
	defer timer.Stop()
	go func() {
		// Start ten consecutive sessions; only the first one is
		// expected to send credentials
		for i := 0; i < 10; i++ {
			startACSSession(ctx, mockWsClient, timer, args, backoff, session)
		}
		cancel()
	}()
	// Wait for context to be cancelled
	select {
	case <-ctx.Done():
	}
}
// TestHandlePayloadMessageAckedWhenTaskAdded tests if the handler generates an ack
// after processing a payload message.
func TestHandlePayloadMessageAckedWhenTaskAdded(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	credentialsManager := credentials.NewManager()
	ctx, cancel := context.WithCancel(context.Background())
	taskEngine := engine.NewMockTaskEngine(ctrl)
	// Capture the task added to the engine for later verification
	var addedTask *api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		addedTask = task
	}).Times(1)
	// Capture the ack request; cancelling the context here signals that
	// the ack has been sent
	var ackRequested *ecsacs.AckRequest
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) {
		ackRequested = ackRequest
		cancel()
	}).Times(1)
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredentialsHandler{}, credentialsManager)
	go buffer.start()
	// Send a payload message
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String("t1"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	}
	err := buffer.handleSingleMessage(payloadMessage)
	if err != nil {
		t.Errorf("Error handling payload message: %v", err)
	}
	// Wait till we get an ack from the ackBuffer
	select {
	case <-ctx.Done():
	}
	// Verify the message id acked
	if aws.StringValue(ackRequested.MessageId) != payloadMessageId {
		t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(ackRequested.MessageId))
	}
	// Verify if task added == expected task
	expectedTask := &api.Task{
		Arn: "t1",
	}
	if !reflect.DeepEqual(addedTask, expectedTask) {
		t.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", expectedTask, addedTask)
	}
}
// TestHandlerReconnectsOnServeErrors tests if the handler retries to
// establish the session with ACS when ClientServer.Serve() returns errors
func TestHandlerReconnectsOnServeErrors(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes()
	statemanager := statemanager.NewNoopStateManager()
	ctx, cancel := context.WithCancel(context.Background())
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().Connect().Return(nil).AnyTimes()
	mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
	gomock.InOrder(
		// Serve fails 10 times
		mockWsClient.EXPECT().Serve().Return(io.EOF).Times(10),
		// Cancel trying to Serve ACS requests on the 11th attempt
		// Failure to retry on Serve() errors should cause the
		// test to time out as the context is never cancelled
		mockWsClient.EXPECT().Serve().Do(func() {
			cancel()
		}).Return(io.EOF),
	)
	session := &mockSession{mockWsClient}
	args := StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           taskEngine,
		ECSClient:            ecsClient,
		StateManager:         statemanager,
		AcceptInvalidCert:    true,
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}
	backoff := utils.NewSimpleBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier)
	go func() {
		startSession(ctx, args, backoff, session)
	}()
	// Wait for context to be cancelled
	select {
	case <-ctx.Done():
	}
}
// TestConnectionIsClosedOnIdle tests if the connection to ACS is closed
// when the channel is idle
func TestConnectionIsClosedOnIdle(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	statemanager := statemanager.NewNoopStateManager()
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes()
	mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes()
	mockWsClient.EXPECT().Connect().Return(nil)
	mockWsClient.EXPECT().Serve().Do(func() {
		// Pretend as if the maximum heartbeatTimeout duration has
		// been breached while Serving requests
		// (30ms sleep > 20ms _heartbeatTimeout configured below)
		time.Sleep(30 * time.Millisecond)
	}).Return(io.EOF)
	connectionClosed := make(chan bool)
	mockWsClient.EXPECT().Close().Do(func() {
		// Record connection closed
		connectionClosed <- true
	}).Return(nil)
	ctx := context.Background()
	backoff := utils.NewSimpleBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier)
	args := StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           taskEngine,
		ECSClient:            ecsClient,
		StateManager:         statemanager,
		AcceptInvalidCert:    true,
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}
	go func() {
		timer := newDisconnectionTimer(mockWsClient, args.time(), args.heartbeatTimeout(), args.heartbeatJitter())
		defer timer.Stop()
		startACSSession(ctx, mockWsClient, timer, args, backoff, &mockSession{})
	}()
	// Wait for connection to be closed. If the connection is not closed
	// due to inactivity, the test will time out
	<-connectionClosed
}
// TestHandlePayloadMessageAddTaskError tests that agent does not ack payload messages // when task engine fails to add tasks func TestHandlePayloadMessageAddTaskError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() credentialsManager := credentials.NewManager() // Return error from AddTask taskEngine.EXPECT().AddTask(gomock.Any()).Return(fmt.Errorf("oops")).Times(2) ctx := context.Background() buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager) // Test AddTask error with RUNNING task payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String("t1"), DesiredStatus: aws.String("RUNNING"), }, }, MessageId: aws.String(payloadMessageId), } err := buffer.handleSingleMessage(payloadMessage) if err == nil { t.Error("Expected error while adding the task") } payloadMessage = &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String("t1"), DesiredStatus: aws.String("STOPPED"), }, }, MessageId: aws.String(payloadMessageId), } // Test AddTask error with STOPPED task err = buffer.handleSingleMessage(payloadMessage) if err == nil { t.Error("Expected error while adding the task") } }
// TestHandleRefreshMessageAckedWhenCredentialsUpdated tests that a credential message // is ackd when the credentials are updated successfully func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() credentialsManager := credentials.NewManager() ctx, cancel := context.WithCancel(context.Background()) var ackRequested *ecsacs.IAMRoleCredentialsAckRequest mockWsClient := mock_wsclient.NewMockClientServer(ctrl) mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) { ackRequested = ackRequest cancel() }).Times(1) taskEngine := engine.NewMockTaskEngine(ctrl) // Return a task from the engine for GetTaskByArn taskEngine.EXPECT().GetTaskByArn(taskArn).Return(&api.Task{}, true) handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine) go handler.sendAcks() // test adding a credentials message without the MessageId field err := handler.handleSingleMessage(message) if err != nil { t.Errorf("Error updating credentials: %v", err) } // Wait till we get an ack from the ackBuffer select { case <-ctx.Done(): } if !reflect.DeepEqual(ackRequested, expectedAck) { t.Errorf("Message between expected and requested ack. Expected: %v, Requested: %v", expectedAck, ackRequested) } creds, exist := credentialsManager.GetTaskCredentials(credentialsId) if !exist { t.Errorf("Expected credentials to exist for the task") } if !reflect.DeepEqual(creds, expectedCredentials) { t.Errorf("Mismatch between expected credentials and credentials for task. Expected: %v, got: %v", expectedCredentials, creds) } }
// TestHandlePayloadMessageStateSaveError tests that agent does not ack payload messages
// when state saver fails to save state
func TestHandlePayloadMessageStateSaveError(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	credentialsManager := credentials.NewManager()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	// Save added task in the addedTask variable
	var addedTask *api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		addedTask = task
	}).Times(1)
	// State manager returns error on save
	stateManager := mock_statemanager.NewMockStateManager(ctrl)
	stateManager.EXPECT().Save().Return(fmt.Errorf("oops"))
	ctx := context.Background()
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager)
	// Check if handleSingleMessage returns an error when state manager returns error on Save()
	err := buffer.handleSingleMessage(&ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String("t1"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	})
	if err == nil {
		t.Error("Expected error while adding a task from statemanager")
	}
	// We expect task to be added to the engine even though it hasn't been saved
	expectedTask := &api.Task{
		Arn: "t1",
	}
	if !reflect.DeepEqual(addedTask, expectedTask) {
		t.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", expectedTask, addedTask)
	}
}
func TestAcsWsUrl(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) taskEngine.EXPECT().Version().Return("Docker version result", nil) wsurl := acsWsURL(acsURL, "myCluster", "myContainerInstance", taskEngine, &mockSession{}) parsed, err := url.Parse(wsurl) if err != nil { t.Fatal("Should be able to parse url") } if parsed.Path != "/ws" { t.Fatal("Wrong path") } if parsed.Query().Get("clusterArn") != "myCluster" { t.Fatal("Wrong cluster") } if parsed.Query().Get("containerInstanceArn") != "myContainerInstance" { t.Fatal("Wrong cluster") } if parsed.Query().Get("agentVersion") != version.Version { t.Fatal("Wrong cluster") } if parsed.Query().Get("agentHash") != version.GitHashString() { t.Fatal("Wrong cluster") } if parsed.Query().Get("dockerVersion") != "Docker version result" { t.Fatal("Wrong docker version") } if parsed.Query().Get(sendCredentialsURLParameterName) != "true" { t.Fatalf("Wrong value set for: %s", sendCredentialsURLParameterName) } if parsed.Query().Get("seqNum") != "1" { t.Fatal("Wrong seqNum") } }
// TestHandlerReconnectsOnDiscoverPollEndpointError tests if handler retries
// to establish the session with ACS on DiscoverPollEndpoint errors
func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	statemanager := statemanager.NewNoopStateManager()
	ctx, cancel := context.WithCancel(context.Background())
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().Connect().Return(nil).AnyTimes()
	mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
	mockWsClient.EXPECT().Serve().Do(func() {
		// Serve() cancels the context
		cancel()
	}).Return(io.EOF)
	session := &mockSession{mockWsClient}
	gomock.InOrder(
		// DiscoverPollEndpoint returns an error on its first invocation
		ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return("", fmt.Errorf("oops")).Times(1),
		// Second invocation returns a success
		ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).Times(1),
	)
	args := StartSessionArguments{
		ContainerInstanceArn: "myArn",
		CredentialProvider:   credentials.AnonymousCredentials,
		Config:               &config.Config{Cluster: "someCluster"},
		TaskEngine:           taskEngine,
		ECSClient:            ecsClient,
		StateManager:         statemanager,
		AcceptInvalidCert:    true,
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}
	backoff := utils.NewSimpleBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier)
	go func() {
		startSession(ctx, args, backoff, session)
	}()
	start := time.Now()
	// Wait for context to be cancelled
	select {
	case <-ctx.Done():
	}
	// Measure the duration between retries; a failed DiscoverPollEndpoint
	// should have triggered at least one backoff wait
	timeSinceStart := time.Since(start)
	if timeSinceStart < connectionBackoffMin {
		t.Errorf("Duration since start is less than minimum threshold for backoff: %s", timeSinceStart.String())
	}
	// The upper limit here should really be connectionBackoffMin + (connectionBackoffMin * jitter)
	// But, it can be off by a few milliseconds to account for execution of other instructions
	// In any case, it should never be higher than 2*connectionBackoffMin
	if timeSinceStart > 2*connectionBackoffMin {
		t.Errorf("Duration since start is greater than maximum anticipated wait time: %v", timeSinceStart.String())
	}
}
func TestHandlerDoesntLeakGouroutines(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsclient := mock_api.NewMockECSClient(ctrl) statemanager := statemanager.NewNoopStateManager() testTime := ttime.NewTestTime() ttime.SetTime(testTime) closeWS := make(chan bool) server, serverIn, requests, errs, err := startMockAcsServer(t, closeWS) if err != nil { t.Fatal(err) } go func() { for { select { case <-requests: case <-errs: } } }() timesConnected := 0 ecsclient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).AnyTimes().Do(func(_ interface{}) { timesConnected++ }) taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() taskEngine.EXPECT().AddTask(gomock.Any()).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) ended := make(chan bool, 1) go func() { handler.StartSession(ctx, handler.StartSessionArguments{"myArn", credentials.AnonymousCredentials, &config.Config{Cluster: "someCluster"}, taskEngine, ecsclient, statemanager, true}) ended <- true }() // Warm it up serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}` serverIn <- samplePayloadMessage beforeGoroutines := runtime.NumGoroutine() for i := 0; i < 100; i++ { serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}` serverIn <- samplePayloadMessage closeWS <- true } cancel() testTime.Cancel() <-ended afterGoroutines := runtime.NumGoroutine() t.Logf("Gorutines after 1 and after 100 acs messages: %v and %v", beforeGoroutines, afterGoroutines) if timesConnected < 50 { t.Fatal("Expected times connected to be a large number, was ", timesConnected) } if afterGoroutines > beforeGoroutines+5 { t.Error("Goroutine leak, oh no!") pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) } }
func TestHeartbeatOnlyWhenIdle(t *testing.T) { testTime := ttime.NewTestTime() ttime.SetTime(testTime) ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsclient := mock_api.NewMockECSClient(ctrl) statemanager := statemanager.NewNoopStateManager() closeWS := make(chan bool) server, serverIn, requestsChan, errChan, err := startMockAcsServer(t, closeWS) defer close(serverIn) go func() { for { <-requestsChan } }() if err != nil { t.Fatal(err) } // We're testing that it does not reconnect here; must be the case ecsclient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(1) taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) ended := make(chan bool, 1) go func() { handler.StartSession(ctx, handler.StartSessionArguments{ ContainerInstanceArn: "myArn", CredentialProvider: credentials.AnonymousCredentials, Config: &config.Config{Cluster: "someCluster"}, TaskEngine: taskEngine, ECSClient: ecsclient, StateManager: statemanager, AcceptInvalidCert: true, }) ended <- true }() taskAdded := make(chan bool) taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(interface{}) { taskAdded <- true }).Times(10) for i := 0; i < 10; i++ { serverIn <- samplePayloadMessage testTime.Warp(1 * time.Minute) <-taskAdded } select { case <-ended: t.Fatal("Should not have stop session") case err := <-errChan: t.Fatal("Error should not have been returned from server", err) default: } go server.Close() cancel() <-ended }
// TestPayloadBufferHandlerWithCredentials tests if the async payloadBufferHandler routine
// acks the payload message and credentials after adding tasks
func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	ctx, cancel := context.WithCancel(context.Background())
	credentialsManager := credentials.NewManager()
	// The payload message in the test consists of two tasks, record both of them in
	// the order in which they were added
	taskEngine := engine.NewMockTaskEngine(ctrl)
	var firstAddedTask *api.Task
	var secondAddedTask *api.Task
	gomock.InOrder(
		taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
			firstAddedTask = task
		}),
		taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
			secondAddedTask = task
		}),
	)
	// The payload message in the test consists of two tasks, with credentials set
	// for both. Record the credentials' ack and the payload message ack.
	// The expected ack ordering is: both credentials acks first, then the
	// payload ack.
	var payloadAckRequested *ecsacs.AckRequest
	var firstTaskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest
	var secondTaskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	gomock.InOrder(
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
			firstTaskCredentialsAckRequested = ackRequest
		}),
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
			secondTaskCredentialsAckRequested = ackRequest
		}),
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) {
			payloadAckRequested = ackRequest
			// Cancel the context when the ack for the payload message is received
			// This signals a successful workflow in the test
			cancel()
		}),
	)
	refreshCredsHandler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
	defer refreshCredsHandler.clearAcks()
	refreshCredsHandler.start()
	payloadHandler := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredsHandler, credentialsManager)
	go payloadHandler.start()
	// Fixture values for the two tasks and their IAM role credentials
	firstTaskArn := "t1"
	firstTaskCredentialsExpiration := "expiration"
	firstTaskCredentialsRoleArn := "r1"
	firstTaskCredentialsAccessKey := "akid"
	firstTaskCredentialsSecretKey := "skid"
	firstTaskCredentialsSessionToken := "token"
	firstTaskCredentialsId := "credsid1"
	secondTaskArn := "t2"
	secondTaskCredentialsExpiration := "expirationSecond"
	secondTaskCredentialsRoleArn := "r2"
	secondTaskCredentialsAccessKey := "akid2"
	secondTaskCredentialsSecretKey := "skid2"
	secondTaskCredentialsSessionToken := "token2"
	secondTaskCredentialsId := "credsid2"
	// Send a payload message to the payloadBufferChannel
	payloadHandler.messageBuffer <- &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String(firstTaskArn),
				RoleCredentials: &ecsacs.IAMRoleCredentials{
					AccessKeyId:     aws.String(firstTaskCredentialsAccessKey),
					Expiration:      aws.String(firstTaskCredentialsExpiration),
					RoleArn:         aws.String(firstTaskCredentialsRoleArn),
					SecretAccessKey: aws.String(firstTaskCredentialsSecretKey),
					SessionToken:    aws.String(firstTaskCredentialsSessionToken),
					CredentialsId:   aws.String(firstTaskCredentialsId),
				},
			},
			&ecsacs.Task{
				Arn: aws.String(secondTaskArn),
				RoleCredentials: &ecsacs.IAMRoleCredentials{
					AccessKeyId:     aws.String(secondTaskCredentialsAccessKey),
					Expiration:      aws.String(secondTaskCredentialsExpiration),
					RoleArn:         aws.String(secondTaskCredentialsRoleArn),
					SecretAccessKey: aws.String(secondTaskCredentialsSecretKey),
					SessionToken:    aws.String(secondTaskCredentialsSessionToken),
					CredentialsId:   aws.String(secondTaskCredentialsId),
				},
			},
		},
		MessageId:            aws.String(payloadMessageId),
		ClusterArn:           aws.String(cluster),
		ContainerInstanceArn: aws.String(containerInstance),
	}
	// Wait till we get an ack
	select {
	case <-ctx.Done():
	}
	// Verify if payloadMessageId read from the ack buffer is correct
	if aws.StringValue(payloadAckRequested.MessageId) != payloadMessageId {
		t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(payloadAckRequested.MessageId))
	}
	// Verify the correctness of the first task added to the engine and the
	// credentials ack generated for it
	expectedCredentialsAckForFirstTask := &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     aws.String(payloadMessageId),
		Expiration:    aws.String(firstTaskCredentialsExpiration),
		CredentialsId: aws.String(firstTaskCredentialsId),
	}
	expectedCredentialsForFirstTask := credentials.IAMRoleCredentials{
		AccessKeyId:     firstTaskCredentialsAccessKey,
		Expiration:      firstTaskCredentialsExpiration,
		RoleArn:         firstTaskCredentialsRoleArn,
		SecretAccessKey: firstTaskCredentialsSecretKey,
		SessionToken:    firstTaskCredentialsSessionToken,
		CredentialsId:   firstTaskCredentialsId,
	}
	err := validateTaskAndCredentials(firstTaskCredentialsAckRequested, expectedCredentialsAckForFirstTask, firstAddedTask, firstTaskArn, expectedCredentialsForFirstTask)
	if err != nil {
		t.Errorf("Error validating added task or credentials ack for the same: %v", err)
	}
	// Verify the correctness of the second task added to the engine and the
	// credentials ack generated for it
	expectedCredentialsAckForSecondTask := &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     aws.String(payloadMessageId),
		Expiration:    aws.String(secondTaskCredentialsExpiration),
		CredentialsId: aws.String(secondTaskCredentialsId),
	}
	expectedCredentialsForSecondTask := credentials.IAMRoleCredentials{
		AccessKeyId:     secondTaskCredentialsAccessKey,
		Expiration:      secondTaskCredentialsExpiration,
		RoleArn:         secondTaskCredentialsRoleArn,
		SecretAccessKey: secondTaskCredentialsSecretKey,
		SessionToken:    secondTaskCredentialsSessionToken,
		CredentialsId:   secondTaskCredentialsId,
	}
	err = validateTaskAndCredentials(secondTaskCredentialsAckRequested, expectedCredentialsAckForSecondTask, secondAddedTask, secondTaskArn, expectedCredentialsForSecondTask)
	if err != nil {
		t.Errorf("Error validating added task or credentials ack for the same: %v", err)
	}
}
// TestStartSessionHandlesRefreshCredentialsMessages tests the agent restart // scenario where the payload to refresh credentials is processed immediately on // connection establishment with ACS func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() closeWS := make(chan bool) server, serverIn, requestsChan, errChan, err := startMockAcsServer(t, closeWS) if err != nil { t.Fatal(err) } defer close(serverIn) ctx, cancel := context.WithCancel(context.Background()) go func() { for { select { case <-requestsChan: // Cancel the context when we get the ack request cancel() } } }() // DiscoverPollEndpoint returns the URL for the server that we started ecsClient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(1) taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() credentialsManager := mock_credentials.NewMockManager(ctrl) ended := make(chan bool, 1) go func() { StartSession(ctx, StartSessionArguments{ ContainerInstanceArn: "myArn", CredentialProvider: credentials.AnonymousCredentials, Config: &config.Config{Cluster: "someCluster"}, TaskEngine: taskEngine, ECSClient: ecsClient, StateManager: stateManager, AcceptInvalidCert: true, CredentialsManager: credentialsManager, }) // StartSession should never return unless the context is canceled ended <- true }() updatedCredentials := rolecredentials.TaskIAMRoleCredentials{} taskFromEngine := &api.Task{} credentialsIdInRefreshMessage := "credsId" // Ensure that credentials manager interface methods are invoked in the // correct order, with expected arguments gomock.InOrder( // Return a task from the engine for GetTaskByArn taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true), // The last invocation of SetCredentials is to update // credentials when a refresh message is 
recieved by the handler credentialsManager.EXPECT().SetTaskCredentials(gomock.Any()).Do(func(creds rolecredentials.TaskIAMRoleCredentials) { updatedCredentials = creds // Validate parsed credentials after the update expectedCreds := rolecredentials.TaskIAMRoleCredentials{ ARN: "t1", IAMRoleCredentials: rolecredentials.IAMRoleCredentials{ RoleArn: "r1", AccessKeyId: "newakid", SecretAccessKey: "newskid", SessionToken: "newstkn", Expiration: "later", CredentialsId: credentialsIdInRefreshMessage, }, } if !reflect.DeepEqual(updatedCredentials, expectedCreds) { t.Errorf("Mismatch between expected and credentials expected: %v, added: %v", expectedCreds, updatedCredentials) } }).Return(nil), ) serverIn <- sampleRefreshCredentialsMessage select { case <-ended: t.Fatal("Should not have stop session") case err := <-errChan: t.Fatal("Error should not have been returned from server", err) case <-ctx.Done(): // Context is canceled when requestsChan recieves an ack } // Validate that the correct credentialsId is set for the task credentialsIdFromTask := taskFromEngine.GetCredentialsId() if credentialsIdFromTask != credentialsIdInRefreshMessage { t.Errorf("Mismatch between expected and added credentials id for task, expected: %s, aded: %s", credentialsIdInRefreshMessage, credentialsIdFromTask) } go server.Close() <-ended }
// TestHandlePayloadMessageCredentialsAckedWhenTaskAdded tests if the handler generates // an ack after processing a payload message when the payload message contains a task // with an IAM Role. It also tests if the credentials ack is generated func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() ctx, cancel := context.WithCancel(context.Background()) credentialsManager := credentials.NewManager() taskEngine := engine.NewMockTaskEngine(ctrl) var addedTask *api.Task taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) { addedTask = task }).Times(1) var payloadAckRequested *ecsacs.AckRequest var taskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest mockWsClient := mock_wsclient.NewMockClientServer(ctrl) // The payload message in the test consists of a task, with credentials set // Record the credentials ack and the payload message ack gomock.InOrder( mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) { taskCredentialsAckRequested = ackRequest }), mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) { payloadAckRequested = ackRequest // Cancel the context when the ack for the payload message is received // This signals a successful workflow in the test cancel() }), ) refreshCredsHandler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine) defer refreshCredsHandler.clearAcks() refreshCredsHandler.start() payloadHandler := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredsHandler, credentialsManager) go payloadHandler.start() taskArn := "t1" credentialsExpiration := "expiration" credentialsRoleArn := "r1" credentialsAccessKey := "akid" 
credentialsSecretKey := "skid" credentialsSessionToken := "token" credentialsId := "credsid" // Send a payload message payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String(taskArn), RoleCredentials: &ecsacs.IAMRoleCredentials{ AccessKeyId: aws.String(credentialsAccessKey), Expiration: aws.String(credentialsExpiration), RoleArn: aws.String(credentialsRoleArn), SecretAccessKey: aws.String(credentialsSecretKey), SessionToken: aws.String(credentialsSessionToken), CredentialsId: aws.String(credentialsId), }, }, }, MessageId: aws.String(payloadMessageId), ClusterArn: aws.String(cluster), ContainerInstanceArn: aws.String(containerInstance), } err := payloadHandler.handleSingleMessage(payloadMessage) if err != nil { t.Errorf("Error handling payload message: %v", err) } // Wait till we get an ack from the ackBuffer select { case <-ctx.Done(): } // Verify the message id acked if aws.StringValue(payloadAckRequested.MessageId) != payloadMessageId { t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(payloadAckRequested.MessageId)) } // Verify the correctness of the task added to the engine and the // credentials ack generated for it expectedCredentialsAck := &ecsacs.IAMRoleCredentialsAckRequest{ MessageId: aws.String(payloadMessageId), Expiration: aws.String(credentialsExpiration), CredentialsId: aws.String(credentialsId), } expectedCredentials := credentials.IAMRoleCredentials{ AccessKeyId: credentialsAccessKey, Expiration: credentialsExpiration, RoleArn: credentialsRoleArn, SecretAccessKey: credentialsSecretKey, SessionToken: credentialsSessionToken, CredentialsId: credentialsId, } err = validateTaskAndCredentials(taskCredentialsAckRequested, expectedCredentialsAck, addedTask, taskArn, expectedCredentials) if err != nil { t.Errorf("Error validating added task or credentials ack for the same: %v", err) } }