// TestHandlePayloadMessageWithNoMessageId tests that agent doesn't ack payload messages // that do not contain message ids func TestHandlePayloadMessageWithNoMessageId(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() credentialsManager := credentials.NewManager() ctx := context.Background() buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager) // test adding a payload message without the MessageId field payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String("t1"), }, }, } err := buffer.handleSingleMessage(payloadMessage) if err == nil { t.Error("Expected error while adding a task with no message id") } // test adding a payload message with blank MessageId payloadMessage.MessageId = aws.String("") err = buffer.handleSingleMessage(payloadMessage) if err == nil { t.Error("Expected error while adding a task with no message id") } }
// TestCredentialsMessageNotAckedWhenTaskNotFound tests if credential messages // are not acked when the task arn in the message is not found in the task // engine func TestCredentialsMessageNotAckedWhenTaskNotFound(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() credentialsManager := credentials.NewManager() taskEngine := engine.NewMockTaskEngine(ctrl) // Return task not found from the engine for GetTaskByArn taskEngine.EXPECT().GetTaskByArn(taskArn).Return(nil, false) ctx, cancel := context.WithCancel(context.Background()) handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, taskEngine) // Start a goroutine to listen for acks. Cancelling the context stops the goroutine go func() { for { select { // We never expect the message to be acked case <-handler.ackRequest: t.Fatalf("Received ack when none expected") case <-ctx.Done(): return } } }() // Test adding a credentials message without the MessageId field err := handler.handleSingleMessage(message) if err == nil { t.Error("Expected error updating credentials when the message contains unexpected task arn") } cancel() }
// TestInvalidCredentialsMessageNotAcked tests if invalid credential messages // are not acked func TestInvalidCredentialsMessageNotAcked(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() credentialsManager := credentials.NewManager() ctx, cancel := context.WithCancel(context.Background()) handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, nil) // Start a goroutine to listen for acks. Cancelling the context stops the goroutine go func() { for { select { // We never expect the message to be acked case <-handler.ackRequest: t.Fatalf("Received ack when none expected") case <-ctx.Done(): return } } }() // test adding a credentials message without the MessageId field message := &ecsacs.IAMRoleCredentialsMessage{} err := handler.handleSingleMessage(message) if err == nil { t.Error("Expected error updating credentials when the message contains no message id") } cancel() }
// TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks tests if tasks with desired status // 'RUNNING' are added after tasks with desired status 'STOPPED' func TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecsClient := mock_api.NewMockECSClient(ctrl) taskEngine := engine.NewMockTaskEngine(ctrl) credentialsManager := credentials.NewManager() var tasksAddedToEngine []*api.Task taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) { tasksAddedToEngine = append(tasksAddedToEngine, task) }).Times(2) stoppedTaskArn := "stoppedTask" runningTaskArn := "runningTask" payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String(runningTaskArn), DesiredStatus: aws.String("RUNNING"), }, &ecsacs.Task{ Arn: aws.String(stoppedTaskArn), DesiredStatus: aws.String("STOPPED"), }, }, MessageId: aws.String(payloadMessageId), } ctx := context.Background() stateManager := statemanager.NewNoopStateManager() buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager) _, ok := buffer.addPayloadTasks(payloadMessage) if !ok { t.Error("addPayloadTasks returned false") } if len(tasksAddedToEngine) != 2 { t.Errorf("Incorrect number of tasks added to the engine. Expected: %d, got: %d", 2, len(tasksAddedToEngine)) } // Verify if stopped task is added before running task firstTaskAdded := tasksAddedToEngine[0] if firstTaskAdded.Arn != stoppedTaskArn { t.Errorf("Expected first task arn: %s, got: %s", stoppedTaskArn, firstTaskAdded.Arn) } if firstTaskAdded.DesiredStatus != api.TaskStopped { t.Errorf("Expected first task state be be: %s , got: %s", "STOPPED", firstTaskAdded.DesiredStatus.String()) } secondTaskAdded := tasksAddedToEngine[1] if secondTaskAdded.Arn != runningTaskArn { t.Errorf("Expected second task arn: %s, got: %s", runningTaskArn, secondTaskAdded.Arn) } if secondTaskAdded.DesiredStatus != api.TaskRunning { t.Errorf("Expected second task state be be: %s , got: %s", "RUNNNING", secondTaskAdded.DesiredStatus.String()) } }
// TestHandlePayloadMessageAckedWhenTaskAdded tests if the handler generates an ack // after processing a payload message. func TestHandlePayloadMessageAckedWhenTaskAdded(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() credentialsManager := credentials.NewManager() ctx, cancel := context.WithCancel(context.Background()) taskEngine := engine.NewMockTaskEngine(ctrl) var addedTask *api.Task taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) { addedTask = task }).Times(1) var ackRequested *ecsacs.AckRequest mockWsClient := mock_wsclient.NewMockClientServer(ctrl) mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) { ackRequested = ackRequest cancel() }).Times(1) buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredentialsHandler{}, credentialsManager) go buffer.start() // Send a payload message payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String("t1"), }, }, MessageId: aws.String(payloadMessageId), } err := buffer.handleSingleMessage(payloadMessage) if err != nil { t.Errorf("Error handling payload message: %v", err) } // Wait till we get an ack from the ackBuffer select { case <-ctx.Done(): } // Verify the message id acked if aws.StringValue(ackRequested.MessageId) != payloadMessageId { t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(ackRequested.MessageId)) } // Verify if task added == expected task expectedTask := &api.Task{ Arn: "t1", } if !reflect.DeepEqual(addedTask, expectedTask) { t.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", expectedTask, addedTask) } }
// TestHandlePayloadMessageAddTaskError tests that agent does not ack payload messages // when task engine fails to add tasks func TestHandlePayloadMessageAddTaskError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() credentialsManager := credentials.NewManager() // Return error from AddTask taskEngine.EXPECT().AddTask(gomock.Any()).Return(fmt.Errorf("oops")).Times(2) ctx := context.Background() buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager) // Test AddTask error with RUNNING task payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String("t1"), DesiredStatus: aws.String("RUNNING"), }, }, MessageId: aws.String(payloadMessageId), } err := buffer.handleSingleMessage(payloadMessage) if err == nil { t.Error("Expected error while adding the task") } payloadMessage = &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String("t1"), DesiredStatus: aws.String("STOPPED"), }, }, MessageId: aws.String(payloadMessageId), } // Test AddTask error with STOPPED task err = buffer.handleSingleMessage(payloadMessage) if err == nil { t.Error("Expected error while adding the task") } }
// TestHandleRefreshMessageAckedWhenCredentialsUpdated tests that a credential message // is ackd when the credentials are updated successfully func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() credentialsManager := credentials.NewManager() ctx, cancel := context.WithCancel(context.Background()) var ackRequested *ecsacs.IAMRoleCredentialsAckRequest mockWsClient := mock_wsclient.NewMockClientServer(ctrl) mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) { ackRequested = ackRequest cancel() }).Times(1) taskEngine := engine.NewMockTaskEngine(ctrl) // Return a task from the engine for GetTaskByArn taskEngine.EXPECT().GetTaskByArn(taskArn).Return(&api.Task{}, true) handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine) go handler.sendAcks() // test adding a credentials message without the MessageId field err := handler.handleSingleMessage(message) if err != nil { t.Errorf("Error updating credentials: %v", err) } // Wait till we get an ack from the ackBuffer select { case <-ctx.Done(): } if !reflect.DeepEqual(ackRequested, expectedAck) { t.Errorf("Message between expected and requested ack. Expected: %v, Requested: %v", expectedAck, ackRequested) } creds, exist := credentialsManager.GetTaskCredentials(credentialsId) if !exist { t.Errorf("Expected credentials to exist for the task") } if !reflect.DeepEqual(creds, expectedCredentials) { t.Errorf("Mismatch between expected credentials and credentials for task. Expected: %v, got: %v", expectedCredentials, creds) } }
// TestHandlePayloadMessageStateSaveError tests that agent does not ack payload messages // when state saver fails to save state func TestHandlePayloadMessageStateSaveError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecsClient := mock_api.NewMockECSClient(ctrl) credentialsManager := credentials.NewManager() taskEngine := engine.NewMockTaskEngine(ctrl) // Save added task in the addedTask variable var addedTask *api.Task taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) { addedTask = task }).Times(1) // State manager returns error on save stateManager := mock_statemanager.NewMockStateManager(ctrl) stateManager.EXPECT().Save().Return(fmt.Errorf("oops")) ctx := context.Background() buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager) // Check if handleSingleMessage returns an error when state manager returns error on Save() err := buffer.handleSingleMessage(&ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String("t1"), }, }, MessageId: aws.String(payloadMessageId), }) if err == nil { t.Error("Expected error while adding a task from statemanager") } // We expect task to be added to the engine even though it hasn't been saved expectedTask := &api.Task{ Arn: "t1", } if !reflect.DeepEqual(addedTask, expectedTask) { t.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", expectedTask, addedTask) } }
func setup(cfg *config.Config, t *testing.T) (TaskEngine, func(), credentials.Manager) { if testing.Short() { t.Skip("Skipping integ test in short mode") } if _, err := os.Stat("/var/run/docker.sock"); err != nil { t.Skip("Docker not running") } if os.Getenv("ECS_SKIP_ENGINE_INTEG_TEST") != "" { t.Skip("ECS_SKIP_ENGINE_INTEG_TEST") } clientFactory := dockerclient.NewFactory("unix:///var/run/docker.sock") dockerClient, err := NewDockerGoClient(clientFactory, false, cfg) if err != nil { t.Fatalf("Error creating Docker client: %v", err) } credentialsManager := credentials.NewManager() taskEngine := NewDockerTaskEngine(cfg, dockerClient, credentialsManager, eventstream.NewEventStream("ENGINEINTEGTEST", context.Background())) taskEngine.Init() return taskEngine, func() { taskEngine.Shutdown() }, credentialsManager }
func _main() int { defer log.Flush() flagset := flag.NewFlagSet("Amazon ECS Agent", flag.ContinueOnError) versionFlag := flagset.Bool("version", false, "Print the agent version information and exit") logLevel := flagset.String("loglevel", "", "Loglevel: [<crit>|<error>|<warn>|<info>|<debug>]") acceptInsecureCert := flagset.Bool("k", false, "Disable SSL certificate verification. We do not recommend setting this option.") licenseFlag := flagset.Bool("license", false, "Print the LICENSE and NOTICE files and exit") blackholeEc2Metadata := flagset.Bool("blackhole-ec2-metadata", false, "Blackhole the EC2 Metadata requests. Setting this option can cause the ECS Agent to fail to work properly. We do not recommend setting this option") err := flagset.Parse(os.Args[1:]) if err != nil { return exitcodes.ExitTerminal } if *licenseFlag { license := utils.NewLicenseProvider() text, err := license.GetText() if err != nil { fmt.Fprintln(os.Stderr, err) return exitcodes.ExitError } fmt.Println(text) return exitcodes.ExitSuccess } logger.SetLevel(*logLevel) ec2MetadataClient := ec2.DefaultClient if *blackholeEc2Metadata { ec2MetadataClient = ec2.NewBlackholeEC2MetadataClient() } log.Infof("Starting Agent: %s", version.String()) if *acceptInsecureCert { log.Warn("SSL certificate verification disabled. This is not recommended.") } log.Info("Loading configuration") cfg, cfgErr := config.NewConfig(ec2MetadataClient) // Load cfg and create Docker client before doing 'versionFlag' so that it has the DOCKER_HOST variable loaded if needed clientFactory := dockerclient.NewFactory(cfg.DockerEndpoint) dockerClient, err := engine.NewDockerGoClient(clientFactory, *acceptInsecureCert, cfg) if err != nil { log.Criticalf("Error creating Docker client: %v", err) return exitcodes.ExitError } ctx := context.Background() // Create the DockerContainerChange event stream for tcs containerChangeEventStream := eventstream.NewEventStream(ContainerChangeEventStream, ctx) containerChangeEventStream.StartListening() // Create credentials manager. This will be used by the task engine and // the credentials handler credentialsManager := credentials.NewManager() // Create image manager. This will be used by the task engine for saving image states state := dockerstate.NewDockerTaskEngineState() imageManager := engine.NewImageManager(cfg, dockerClient, state) if *versionFlag { versionableEngine := engine.NewTaskEngine(cfg, dockerClient, credentialsManager, containerChangeEventStream, imageManager, state) version.PrintVersion(versionableEngine) return exitcodes.ExitSuccess } sighandlers.StartDebugHandler() if cfgErr != nil { log.Criticalf("Error loading config: %v", err) // All required config values can be inferred from EC2 Metadata, so this error could be transient. return exitcodes.ExitError } log.Debug("Loaded config: " + cfg.String()) var currentEc2InstanceID, containerInstanceArn string var taskEngine engine.TaskEngine if cfg.Checkpoint { log.Info("Checkpointing is enabled. Attempting to load state") var previousCluster, previousEc2InstanceID, previousContainerInstanceArn string previousTaskEngine := engine.NewTaskEngine(cfg, dockerClient, credentialsManager, containerChangeEventStream, imageManager, state) // previousState is used to verify that our current runtime configuration is // compatible with our past configuration as reflected by our state-file previousState, err := initializeStateManager(cfg, previousTaskEngine, &previousCluster, &previousContainerInstanceArn, &previousEc2InstanceID) if err != nil { log.Criticalf("Error creating state manager: %v", err) return exitcodes.ExitTerminal } err = previousState.Load() if err != nil { log.Criticalf("Error loading previously saved state: %v", err) return exitcodes.ExitTerminal } if previousCluster != "" { // TODO Handle default cluster in a sane and unified way across the codebase configuredCluster := cfg.Cluster if configuredCluster == "" { log.Debug("Setting cluster to default; none configured") configuredCluster = config.DefaultClusterName } if previousCluster != configuredCluster { log.Criticalf("Data mismatch; saved cluster '%v' does not match configured cluster '%v'. Perhaps you want to delete the configured checkpoint file?", previousCluster, configuredCluster) return exitcodes.ExitTerminal } cfg.Cluster = previousCluster log.Infof("Restored cluster '%v'", cfg.Cluster) } if instanceIdentityDoc, err := ec2MetadataClient.InstanceIdentityDocument(); err == nil { currentEc2InstanceID = instanceIdentityDoc.InstanceId } else { log.Criticalf("Unable to access EC2 Metadata service to determine EC2 ID: %v", err) } if previousEc2InstanceID != "" && previousEc2InstanceID != currentEc2InstanceID { log.Warnf("Data mismatch; saved InstanceID '%s' does not match current InstanceID '%s'. Overwriting old datafile", previousEc2InstanceID, currentEc2InstanceID) // Reset taskEngine; all the other values are still default taskEngine = engine.NewTaskEngine(cfg, dockerClient, credentialsManager, containerChangeEventStream, imageManager, state) } else { // Use the values we loaded if there's no issue containerInstanceArn = previousContainerInstanceArn taskEngine = previousTaskEngine } } else { log.Info("Checkpointing not enabled; a new container instance will be created each time the agent is run") taskEngine = engine.NewTaskEngine(cfg, dockerClient, credentialsManager, containerChangeEventStream, imageManager, state) } stateManager, err := initializeStateManager(cfg, taskEngine, &cfg.Cluster, &containerInstanceArn, ¤tEc2InstanceID) if err != nil { log.Criticalf("Error creating state manager: %v", err) return exitcodes.ExitTerminal } capabilities := taskEngine.Capabilities() // We instantiate our own credentialProvider for use in acs/tcs. This tries // to mimic roughly the way it's instantiated by the SDK for a default // session. credentialProvider := defaults.CredChain(defaults.Config(), defaults.Handlers()) // Preflight request to make sure they're good if preflightCreds, err := credentialProvider.Get(); err != nil || preflightCreds.AccessKeyID == "" { log.Warnf("Error getting valid credentials (AKID %s): %v", preflightCreds.AccessKeyID, err) } client := api.NewECSClient(credentialProvider, cfg, httpclient.New(api.RoundtripTimeout, *acceptInsecureCert), ec2MetadataClient) if containerInstanceArn == "" { log.Info("Registering Instance with ECS") containerInstanceArn, err = client.RegisterContainerInstance("", capabilities) if err != nil { log.Errorf("Error registering: %v", err) if retriable, ok := err.(utils.Retriable); ok && !retriable.Retry() { return exitcodes.ExitTerminal } return exitcodes.ExitError } log.Infof("Registration completed successfully. I am running as '%s' in cluster '%s'", containerInstanceArn, cfg.Cluster) // Save our shiny new containerInstanceArn stateManager.Save() } else { log.Infof("Restored from checkpoint file. I am running as '%s' in cluster '%s'", containerInstanceArn, cfg.Cluster) _, err = client.RegisterContainerInstance(containerInstanceArn, capabilities) if err != nil { log.Errorf("Error re-registering: %v", err) if awserr, ok := err.(awserr.Error); ok && api.IsInstanceTypeChangedError(awserr) { log.Criticalf("The current instance type does not match the registered instance type. Please revert the instance type change, or alternatively launch a new instance. Error: %v", err) return exitcodes.ExitTerminal } return exitcodes.ExitError } } // Begin listening to the docker daemon and saving changes taskEngine.SetSaver(stateManager) imageManager.SetSaver(stateManager) taskEngine.MustInit() // start of the periodic image cleanup process if !cfg.ImageCleanupDisabled { go imageManager.StartImageCleanupProcess(ctx) } go sighandlers.StartTerminationHandler(stateManager, taskEngine) // Agent introspection api go handlers.ServeHttp(&containerInstanceArn, taskEngine, cfg) // Start serving the endpoint to fetch IAM Role credentials go credentialshandler.ServeHttp(credentialsManager, containerInstanceArn, cfg) // Start sending events to the backend go eventhandler.HandleEngineEvents(taskEngine, client, stateManager) deregisterInstanceEventStream := eventstream.NewEventStream(DeregisterContainerInstanceEventStream, ctx) deregisterInstanceEventStream.StartListening() telemetrySessionParams := tcshandler.TelemetrySessionParams{ ContainerInstanceArn: containerInstanceArn, CredentialProvider: credentialProvider, Cfg: cfg, DeregisterInstanceEventStream: deregisterInstanceEventStream, ContainerChangeEventStream: containerChangeEventStream, DockerClient: dockerClient, AcceptInvalidCert: *acceptInsecureCert, EcsClient: client, TaskEngine: taskEngine, } // Start metrics session in a go routine go tcshandler.StartMetricsSession(telemetrySessionParams) log.Info("Beginning Polling for updates") err = acshandler.StartSession(ctx, acshandler.StartSessionArguments{ AcceptInvalidCert: *acceptInsecureCert, Config: cfg, DeregisterInstanceEventStream: deregisterInstanceEventStream, ContainerInstanceArn: containerInstanceArn, CredentialProvider: credentialProvider, ECSClient: client, StateManager: stateManager, TaskEngine: taskEngine, CredentialsManager: credentialsManager, }) if err != nil { log.Criticalf("Unretriable error starting communicating with ACS: %v", err) return exitcodes.ExitTerminal } log.Critical("ACS Session handler should never exit") return exitcodes.ExitError }
// TestPayloadBufferHandlerWithCredentials tests if the async payloadBufferHandler routine // acks the payload message and credentials after adding tasks func TestPayloadBufferHandlerWithCredentials(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() ctx, cancel := context.WithCancel(context.Background()) credentialsManager := credentials.NewManager() // The payload message in the test consists of two tasks, record both of them in // the order in which they were added taskEngine := engine.NewMockTaskEngine(ctrl) var firstAddedTask *api.Task var secondAddedTask *api.Task gomock.InOrder( taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) { firstAddedTask = task }), taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) { secondAddedTask = task }), ) // The payload message in the test consists of two tasks, with credentials set // for both. Record the credentials' ack and the payload message ack var payloadAckRequested *ecsacs.AckRequest var firstTaskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest var secondTaskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest mockWsClient := mock_wsclient.NewMockClientServer(ctrl) gomock.InOrder( mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) { firstTaskCredentialsAckRequested = ackRequest }), mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) { secondTaskCredentialsAckRequested = ackRequest }), mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) { payloadAckRequested = ackRequest // Cancel the context when the ack for the payload message is received // This signals a successful workflow in the test cancel() }), ) refreshCredsHandler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine) defer refreshCredsHandler.clearAcks() refreshCredsHandler.start() payloadHandler := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredsHandler, credentialsManager) go payloadHandler.start() firstTaskArn := "t1" firstTaskCredentialsExpiration := "expiration" firstTaskCredentialsRoleArn := "r1" firstTaskCredentialsAccessKey := "akid" firstTaskCredentialsSecretKey := "skid" firstTaskCredentialsSessionToken := "token" firstTaskCredentialsId := "credsid1" secondTaskArn := "t2" secondTaskCredentialsExpiration := "expirationSecond" secondTaskCredentialsRoleArn := "r2" secondTaskCredentialsAccessKey := "akid2" secondTaskCredentialsSecretKey := "skid2" secondTaskCredentialsSessionToken := "token2" secondTaskCredentialsId := "credsid2" // Send a payload message to the payloadBufferChannel payloadHandler.messageBuffer <- &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String(firstTaskArn), RoleCredentials: &ecsacs.IAMRoleCredentials{ AccessKeyId: aws.String(firstTaskCredentialsAccessKey), Expiration: aws.String(firstTaskCredentialsExpiration), RoleArn: aws.String(firstTaskCredentialsRoleArn), SecretAccessKey: aws.String(firstTaskCredentialsSecretKey), SessionToken: aws.String(firstTaskCredentialsSessionToken), CredentialsId: aws.String(firstTaskCredentialsId), }, }, &ecsacs.Task{ Arn: aws.String(secondTaskArn), RoleCredentials: &ecsacs.IAMRoleCredentials{ AccessKeyId: aws.String(secondTaskCredentialsAccessKey), Expiration: aws.String(secondTaskCredentialsExpiration), RoleArn: aws.String(secondTaskCredentialsRoleArn), SecretAccessKey: aws.String(secondTaskCredentialsSecretKey), SessionToken: aws.String(secondTaskCredentialsSessionToken), CredentialsId: aws.String(secondTaskCredentialsId), }, }, }, MessageId: aws.String(payloadMessageId), ClusterArn: aws.String(cluster), ContainerInstanceArn: aws.String(containerInstance), } // Wait till we get an ack select { case <-ctx.Done(): } // Verify if payloadMessageId read from the ack buffer is correct if aws.StringValue(payloadAckRequested.MessageId) != payloadMessageId { t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(payloadAckRequested.MessageId)) } // Verify the correctness of the first task added to the engine and the // credentials ack generated for it expectedCredentialsAckForFirstTask := &ecsacs.IAMRoleCredentialsAckRequest{ MessageId: aws.String(payloadMessageId), Expiration: aws.String(firstTaskCredentialsExpiration), CredentialsId: aws.String(firstTaskCredentialsId), } expectedCredentialsForFirstTask := credentials.IAMRoleCredentials{ AccessKeyId: firstTaskCredentialsAccessKey, Expiration: firstTaskCredentialsExpiration, RoleArn: firstTaskCredentialsRoleArn, SecretAccessKey: firstTaskCredentialsSecretKey, SessionToken: firstTaskCredentialsSessionToken, CredentialsId: firstTaskCredentialsId, } err := validateTaskAndCredentials(firstTaskCredentialsAckRequested, expectedCredentialsAckForFirstTask, firstAddedTask, firstTaskArn, expectedCredentialsForFirstTask) if err != nil { t.Errorf("Error validating added task or credentials ack for the same: %v", err) } // Verify the correctness of the second task added to the engine and the // credentials ack generated for it expectedCredentialsAckForSecondTask := &ecsacs.IAMRoleCredentialsAckRequest{ MessageId: aws.String(payloadMessageId), Expiration: aws.String(secondTaskCredentialsExpiration), CredentialsId: aws.String(secondTaskCredentialsId), } expectedCredentialsForSecondTask := credentials.IAMRoleCredentials{ AccessKeyId: secondTaskCredentialsAccessKey, Expiration: secondTaskCredentialsExpiration, RoleArn: secondTaskCredentialsRoleArn, SecretAccessKey: secondTaskCredentialsSecretKey, SessionToken: secondTaskCredentialsSessionToken, CredentialsId: secondTaskCredentialsId, } err = validateTaskAndCredentials(secondTaskCredentialsAckRequested, expectedCredentialsAckForSecondTask, secondAddedTask, secondTaskArn, expectedCredentialsForSecondTask) if err != nil { t.Errorf("Error validating added task or credentials ack for the same: %v", err) } }
// TestHandlePayloadMessageCredentialsAckedWhenTaskAdded tests if the handler generates // an ack after processing a payload message when the payload message contains a task // with an IAM Role. It also tests if the credentials ack is generated func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() ecsClient := mock_api.NewMockECSClient(ctrl) stateManager := statemanager.NewNoopStateManager() ctx, cancel := context.WithCancel(context.Background()) credentialsManager := credentials.NewManager() taskEngine := engine.NewMockTaskEngine(ctrl) var addedTask *api.Task taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) { addedTask = task }).Times(1) var payloadAckRequested *ecsacs.AckRequest var taskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest mockWsClient := mock_wsclient.NewMockClientServer(ctrl) // The payload message in the test consists of a task, with credentials set // Record the credentials ack and the payload message ack gomock.InOrder( mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) { taskCredentialsAckRequested = ackRequest }), mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) { payloadAckRequested = ackRequest // Cancel the context when the ack for the payload message is received // This signals a successful workflow in the test cancel() }), ) refreshCredsHandler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine) defer refreshCredsHandler.clearAcks() refreshCredsHandler.start() payloadHandler := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredsHandler, credentialsManager) go payloadHandler.start() taskArn := "t1" credentialsExpiration := "expiration" credentialsRoleArn := "r1" credentialsAccessKey := "akid" credentialsSecretKey := "skid" credentialsSessionToken := "token" credentialsId := "credsid" // Send a payload message payloadMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{ &ecsacs.Task{ Arn: aws.String(taskArn), RoleCredentials: &ecsacs.IAMRoleCredentials{ AccessKeyId: aws.String(credentialsAccessKey), Expiration: aws.String(credentialsExpiration), RoleArn: aws.String(credentialsRoleArn), SecretAccessKey: aws.String(credentialsSecretKey), SessionToken: aws.String(credentialsSessionToken), CredentialsId: aws.String(credentialsId), }, }, }, MessageId: aws.String(payloadMessageId), ClusterArn: aws.String(cluster), ContainerInstanceArn: aws.String(containerInstance), } err := payloadHandler.handleSingleMessage(payloadMessage) if err != nil { t.Errorf("Error handling payload message: %v", err) } // Wait till we get an ack from the ackBuffer select { case <-ctx.Done(): } // Verify the message id acked if aws.StringValue(payloadAckRequested.MessageId) != payloadMessageId { t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(payloadAckRequested.MessageId)) } // Verify the correctness of the task added to the engine and the // credentials ack generated for it expectedCredentialsAck := &ecsacs.IAMRoleCredentialsAckRequest{ MessageId: aws.String(payloadMessageId), Expiration: aws.String(credentialsExpiration), CredentialsId: aws.String(credentialsId), } expectedCredentials := credentials.IAMRoleCredentials{ AccessKeyId: credentialsAccessKey, Expiration: credentialsExpiration, RoleArn: credentialsRoleArn, SecretAccessKey: credentialsSecretKey, SessionToken: credentialsSessionToken, CredentialsId: credentialsId, } err = validateTaskAndCredentials(taskCredentialsAckRequested, expectedCredentialsAck, addedTask, taskArn, expectedCredentials) if err != nil { t.Errorf("Error validating added task or credentials ack for the same: %v", err) } }
func TestHandlerDoesntLeakGouroutines(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := engine.NewMockTaskEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) statemanager := statemanager.NewNoopStateManager() closeWS := make(chan bool) server, serverIn, requests, errs, err := startMockAcsServer(t, closeWS) if err != nil { t.Fatal(err) } go func() { for { select { case <-requests: case <-errs: } } }() timesConnected := 0 ecsClient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).AnyTimes().Do(func(_ interface{}) { timesConnected++ }) taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() taskEngine.EXPECT().AddTask(gomock.Any()).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) ended := make(chan bool, 1) go func() { StartSession(ctx, StartSessionArguments{ ContainerInstanceArn: "myArn", CredentialProvider: credentials.AnonymousCredentials, Config: &config.Config{Cluster: "someCluster"}, TaskEngine: taskEngine, ECSClient: ecsClient, StateManager: statemanager, AcceptInvalidCert: true, CredentialsManager: rolecredentials.NewManager(), _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, }) ended <- true }() // Warm it up serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}` serverIn <- samplePayloadMessage beforeGoroutines := runtime.NumGoroutine() for i := 0; i < 100; i++ { serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}` serverIn <- samplePayloadMessage closeWS <- true } cancel() <-ended afterGoroutines := runtime.NumGoroutine() t.Logf("Gorutines after 1 and after 100 acs messages: %v and %v", beforeGoroutines, afterGoroutines) if timesConnected < 50 { t.Fatal("Expected times connected to be a large number, was ", timesConnected) } if afterGoroutines > beforeGoroutines+5 { t.Error("Goroutine leak, oh no!") pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) } }