// TestHandlePayloadMessageWithNoMessageId tests that the agent doesn't ack payload messages
// that do not contain message ids
func TestHandlePayloadMessageWithNoMessageId(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	credentialsManager := credentials.NewManager()

	ctx := context.Background()
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager)

	// test adding a payload message without the MessageId field
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String("t1"),
			},
		},
	}
	err := buffer.handleSingleMessage(payloadMessage)
	if err == nil {
		t.Error("Expected error while adding a task with no message id")
	}

	// test adding a payload message with blank MessageId
	payloadMessage.MessageId = aws.String("")
	err = buffer.handleSingleMessage(payloadMessage)

	if err == nil {
		t.Error("Expected error while adding a task with no message id")
	}

}
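
// A minimal sketch of the kind of MessageId validation the two cases above
// exercise, assuming handleSingleMessage rejects payload messages whose
// MessageId is missing or empty (illustrative only; the real handler may
// differ in how it reports the error):
func validatePayloadMessageId(payloadMessage *ecsacs.PayloadMessage) error {
	if payloadMessage == nil || aws.StringValue(payloadMessage.MessageId) == "" {
		return fmt.Errorf("payload message has no message id")
	}
	return nil
}
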
// TestCredentialsMessageNotAckedWhenTaskNotFound tests if credential messages
// are not acked when the task arn in the message is not found in the task
// engine
func TestCredentialsMessageNotAckedWhenTaskNotFound(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	credentialsManager := credentials.NewManager()

	taskEngine := engine.NewMockTaskEngine(ctrl)
	// Return task not found from the engine for GetTaskByArn
	taskEngine.EXPECT().GetTaskByArn(taskArn).Return(nil, false)

	ctx, cancel := context.WithCancel(context.Background())
	handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, taskEngine)

	// Start a goroutine to listen for acks. Cancelling the context stops the goroutine
	go func() {
		for {
			select {
			// We never expect the message to be acked
			case <-handler.ackRequest:
				t.Fatalf("Received ack when none expected")
			case <-ctx.Done():
				return
			}
		}
	}()

	// Handle a credentials message for a task that is not in the task engine
	err := handler.handleSingleMessage(message)
	if err == nil {
		t.Error("Expected error updating credentials when the message contains unexpected task arn")
	}
	cancel()
}
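
// The package-level fixtures referenced above (taskArn, cluster, containerInstance,
// message) are defined elsewhere in the original test file. A hypothetical
// illustration of such a refresh message fixture, assuming the usual MessageId,
// TaskArn and RoleCredentials fields on ecsacs.IAMRoleCredentialsMessage
// (this is not the real fixture):
var exampleRefreshMessage = &ecsacs.IAMRoleCredentialsMessage{
	MessageId: aws.String("123"),
	TaskArn:   aws.String("t1"),
	RoleCredentials: &ecsacs.IAMRoleCredentials{
		AccessKeyId:     aws.String("akid"),
		Expiration:      aws.String("expiration"),
		RoleArn:         aws.String("r1"),
		SecretAccessKey: aws.String("skid"),
		SessionToken:    aws.String("token"),
		CredentialsId:   aws.String("credsid"),
	},
}
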
// TestInvalidCredentialsMessageNotAcked tests if invalid credential messages
// are not acked
func TestInvalidCredentialsMessageNotAcked(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	credentialsManager := credentials.NewManager()

	ctx, cancel := context.WithCancel(context.Background())
	handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, nil)

	// Start a goroutine to listen for acks. Cancelling the context stops the goroutine
	go func() {
		for {
			select {
			// We never expect the message to be acked
			case <-handler.ackRequest:
				t.Fatalf("Received ack when none expected")
			case <-ctx.Done():
				return
			}
		}
	}()

	// test adding a credentials message without the MessageId field
	message := &ecsacs.IAMRoleCredentialsMessage{}
	err := handler.handleSingleMessage(message)
	if err == nil {
		t.Error("Expected error updating credentials when the message contains no message id")
	}
	cancel()
}
// TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks tests if tasks with desired status
// 'RUNNING' are added after tasks with desired status 'STOPPED'
func TestAddPayloadTaskAddsNonStoppedTasksAfterStoppedTasks(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	taskEngine := engine.NewMockTaskEngine(ctrl)
	credentialsManager := credentials.NewManager()

	var tasksAddedToEngine []*api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		tasksAddedToEngine = append(tasksAddedToEngine, task)
	}).Times(2)

	stoppedTaskArn := "stoppedTask"
	runningTaskArn := "runningTask"
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn:           aws.String(runningTaskArn),
				DesiredStatus: aws.String("RUNNING"),
			},
			&ecsacs.Task{
				Arn:           aws.String(stoppedTaskArn),
				DesiredStatus: aws.String("STOPPED"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	}

	ctx := context.Background()
	stateManager := statemanager.NewNoopStateManager()
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager)
	_, ok := buffer.addPayloadTasks(payloadMessage)
	if !ok {
		t.Error("addPayloadTasks returned false")
	}
	if len(tasksAddedToEngine) != 2 {
		t.Errorf("Incorrect number of tasks added to the engine. Expected: %d, got: %d", 2, len(tasksAddedToEngine))
	}

	// Verify that the stopped task is added before the running task
	firstTaskAdded := tasksAddedToEngine[0]
	if firstTaskAdded.Arn != stoppedTaskArn {
		t.Errorf("Expected first task arn: %s, got: %s", stoppedTaskArn, firstTaskAdded.Arn)
	}
	if firstTaskAdded.DesiredStatus != api.TaskStopped {
		t.Errorf("Expected first task state be be: %s , got: %s", "STOPPED", firstTaskAdded.DesiredStatus.String())
	}

	secondTaskAdded := tasksAddedToEngine[1]
	if secondTaskAdded.Arn != runningTaskArn {
		t.Errorf("Expected second task arn: %s, got: %s", runningTaskArn, secondTaskAdded.Arn)
	}
	if secondTaskAdded.DesiredStatus != api.TaskRunning {
		t.Errorf("Expected second task state be be: %s , got: %s", "RUNNNING", secondTaskAdded.DesiredStatus.String())
	}
}
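
// The ordering asserted above implies that addPayloadTasks partitions tasks by
// desired status and submits the STOPPED group to the engine first. A minimal
// sketch of such a partition (an assumption for illustration, not the handler's
// actual code):
func splitStoppedAndNonStopped(tasks []*api.Task) (stopped, nonStopped []*api.Task) {
	for _, task := range tasks {
		if task.DesiredStatus == api.TaskStopped {
			stopped = append(stopped, task)
		} else {
			nonStopped = append(nonStopped, task)
		}
	}
	return stopped, nonStopped
}
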
// TestHandlePayloadMessageAckedWhenTaskAdded tests if the handler generates an ack
// after processing a payload message.
func TestHandlePayloadMessageAckedWhenTaskAdded(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	credentialsManager := credentials.NewManager()
	ctx, cancel := context.WithCancel(context.Background())

	taskEngine := engine.NewMockTaskEngine(ctrl)
	var addedTask *api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		addedTask = task
	}).Times(1)

	var ackRequested *ecsacs.AckRequest
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) {
		ackRequested = ackRequest
		cancel()
	}).Times(1)
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredentialsHandler{}, credentialsManager)
	go buffer.start()

	// Send a payload message
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String("t1"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	}
	err := buffer.handleSingleMessage(payloadMessage)
	if err != nil {
		t.Errorf("Error handling payload message: %v", err)
	}

	// Wait till we get an ack from the ackBuffer
	<-ctx.Done()
	// Verify that the correct message id was acked
	if aws.StringValue(ackRequested.MessageId) != payloadMessageId {
		t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(ackRequested.MessageId))
	}

	// Verify that the added task matches the expected task
	expectedTask := &api.Task{
		Arn: "t1",
	}
	if !reflect.DeepEqual(addedTask, expectedTask) {
		t.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", expectedTask, addedTask)
	}
}
// TestHandlePayloadMessageAddTaskError tests that the agent does not ack payload messages
// when task engine fails to add tasks
func TestHandlePayloadMessageAddTaskError(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	credentialsManager := credentials.NewManager()

	// Return error from AddTask
	taskEngine.EXPECT().AddTask(gomock.Any()).Return(fmt.Errorf("oops")).Times(2)

	ctx := context.Background()
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager)

	// Test AddTask error with RUNNING task
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn:           aws.String("t1"),
				DesiredStatus: aws.String("RUNNING"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	}
	err := buffer.handleSingleMessage(payloadMessage)
	if err == nil {
		t.Error("Expected error while adding the task")
	}

	payloadMessage = &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn:           aws.String("t1"),
				DesiredStatus: aws.String("STOPPED"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	}
	// Test AddTask error with STOPPED task
	err = buffer.handleSingleMessage(payloadMessage)
	if err == nil {
		t.Error("Expected error while adding the task")
	}
}
// TestHandleRefreshMessageAckedWhenCredentialsUpdated tests that a credential message
// is acked when the credentials are updated successfully
func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	credentialsManager := credentials.NewManager()

	ctx, cancel := context.WithCancel(context.Background())
	var ackRequested *ecsacs.IAMRoleCredentialsAckRequest

	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
		ackRequested = ackRequest
		cancel()
	}).Times(1)

	taskEngine := engine.NewMockTaskEngine(ctrl)
	// Return a task from the engine for GetTaskByArn
	taskEngine.EXPECT().GetTaskByArn(taskArn).Return(&api.Task{}, true)

	handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
	go handler.sendAcks()

	// Handle a valid credentials refresh message
	err := handler.handleSingleMessage(message)
	if err != nil {
		t.Errorf("Error updating credentials: %v", err)
	}

	// Wait till we get an ack from the ackBuffer
	<-ctx.Done()

	if !reflect.DeepEqual(ackRequested, expectedAck) {
		t.Errorf("Message between expected and requested ack. Expected: %v, Requested: %v", expectedAck, ackRequested)
	}

	creds, exist := credentialsManager.GetTaskCredentials(credentialsId)
	if !exist {
		t.Errorf("Expected credentials to exist for the task")
	}
	if !reflect.DeepEqual(creds, expectedCredentials) {
		t.Errorf("Mismatch between expected credentials and credentials for task. Expected: %v, got: %v", expectedCredentials, creds)
	}
}
// TestHandlePayloadMessageStateSaveError tests that the agent does not ack payload messages
// when state saver fails to save state
func TestHandlePayloadMessageStateSaveError(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	credentialsManager := credentials.NewManager()

	taskEngine := engine.NewMockTaskEngine(ctrl)
	// Save added task in the addedTask variable
	var addedTask *api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		addedTask = task
	}).Times(1)

	// State manager returns error on save
	stateManager := mock_statemanager.NewMockStateManager(ctrl)
	stateManager.EXPECT().Save().Return(fmt.Errorf("oops"))

	ctx := context.Background()
	buffer := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, nil, stateManager, refreshCredentialsHandler{}, credentialsManager)

	// Check if handleSingleMessage returns an error when state manager returns error on Save()
	err := buffer.handleSingleMessage(&ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String("t1"),
			},
		},
		MessageId: aws.String(payloadMessageId),
	})
	if err == nil {
		t.Error("Expected error while adding a task from statemanager")
	}

	// We expect task to be added to the engine even though it hasn't been saved
	expectedTask := &api.Task{
		Arn: "t1",
	}
	if !reflect.DeepEqual(addedTask, expectedTask) {
		t.Errorf("Mismatch between expected and added tasks, expected: %v, added: %v", expectedTask, addedTask)
	}
}
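
// The assertions above imply an ordering in which the task is added to the
// engine first and state is saved afterwards, so a failed Save still leaves
// the task in the engine. A sketch of that ordering under those assumptions
// (not the handler's actual code; saver is a minimal stand-in interface):
type saver interface {
	Save() error
}

func addTaskThenSaveSketch(taskEngine engine.TaskEngine, sm saver, task *api.Task) error {
	if err := taskEngine.AddTask(task); err != nil {
		return err
	}
	return sm.Save()
}
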
func setup(cfg *config.Config, t *testing.T) (TaskEngine, func(), credentials.Manager) {
	if testing.Short() {
		t.Skip("Skipping integ test in short mode")
	}
	if _, err := os.Stat("/var/run/docker.sock"); err != nil {
		t.Skip("Docker not running")
	}
	if os.Getenv("ECS_SKIP_ENGINE_INTEG_TEST") != "" {
		t.Skip("ECS_SKIP_ENGINE_INTEG_TEST")
	}

	clientFactory := dockerclient.NewFactory("unix:///var/run/docker.sock")
	dockerClient, err := NewDockerGoClient(clientFactory, false, cfg)
	if err != nil {
		t.Fatalf("Error creating Docker client: %v", err)
	}
	credentialsManager := credentials.NewManager()
	taskEngine := NewDockerTaskEngine(cfg, dockerClient, credentialsManager, eventstream.NewEventStream("ENGINEINTEGTEST", context.Background()))
	taskEngine.Init()
	return taskEngine, func() {
		taskEngine.Shutdown()
	}, credentialsManager
}
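
// A sketch of how an integration test might use setup (illustrative only; the
// real tests construct their config differently):
func exampleSetupUsage(t *testing.T) {
	cfg := &config.Config{Cluster: "integ-test"} // hypothetical config for illustration
	taskEngine, done, credentialsManager := setup(cfg, t)
	defer done()
	// taskEngine and credentialsManager would back the actual test scenario.
	_ = taskEngine
	_ = credentialsManager
}
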
func _main() int {
	defer log.Flush()
	flagset := flag.NewFlagSet("Amazon ECS Agent", flag.ContinueOnError)
	versionFlag := flagset.Bool("version", false, "Print the agent version information and exit")
	logLevel := flagset.String("loglevel", "", "Loglevel: [<crit>|<error>|<warn>|<info>|<debug>]")
	acceptInsecureCert := flagset.Bool("k", false, "Disable SSL certificate verification. We do not recommend setting this option.")
	licenseFlag := flagset.Bool("license", false, "Print the LICENSE and NOTICE files and exit")
	blackholeEc2Metadata := flagset.Bool("blackhole-ec2-metadata", false, "Blackhole the EC2 Metadata requests. Setting this option can cause the ECS Agent to fail to work properly.  We do not recommend setting this option")
	err := flagset.Parse(os.Args[1:])
	if err != nil {
		return exitcodes.ExitTerminal
	}

	if *licenseFlag {
		license := utils.NewLicenseProvider()
		text, err := license.GetText()
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			return exitcodes.ExitError
		}
		fmt.Println(text)
		return exitcodes.ExitSuccess
	}

	logger.SetLevel(*logLevel)
	ec2MetadataClient := ec2.DefaultClient
	if *blackholeEc2Metadata {
		ec2MetadataClient = ec2.NewBlackholeEC2MetadataClient()
	}

	log.Infof("Starting Agent: %s", version.String())
	if *acceptInsecureCert {
		log.Warn("SSL certificate verification disabled. This is not recommended.")
	}
	log.Info("Loading configuration")
	cfg, cfgErr := config.NewConfig(ec2MetadataClient)
	// Load cfg and create Docker client before doing 'versionFlag' so that it has the DOCKER_HOST variable loaded if needed
	clientFactory := dockerclient.NewFactory(cfg.DockerEndpoint)
	dockerClient, err := engine.NewDockerGoClient(clientFactory, *acceptInsecureCert, cfg)
	if err != nil {
		log.Criticalf("Error creating Docker client: %v", err)
		return exitcodes.ExitError
	}

	ctx := context.Background()
	// Create the DockerContainerChange event stream for tcs
	containerChangeEventStream := eventstream.NewEventStream(ContainerChangeEventStream, ctx)
	containerChangeEventStream.StartListening()

	// Create credentials manager. This will be used by the task engine and
	// the credentials handler
	credentialsManager := credentials.NewManager()
	// Create image manager. This will be used by the task engine for saving image states
	state := dockerstate.NewDockerTaskEngineState()
	imageManager := engine.NewImageManager(cfg, dockerClient, state)
	if *versionFlag {
		versionableEngine := engine.NewTaskEngine(cfg, dockerClient, credentialsManager, containerChangeEventStream, imageManager, state)
		version.PrintVersion(versionableEngine)
		return exitcodes.ExitSuccess
	}

	sighandlers.StartDebugHandler()

	if cfgErr != nil {
		log.Criticalf("Error loading config: %v", err)
		// All required config values can be inferred from EC2 Metadata, so this error could be transient.
		return exitcodes.ExitError
	}
	log.Debug("Loaded config: " + cfg.String())

	var currentEc2InstanceID, containerInstanceArn string
	var taskEngine engine.TaskEngine

	if cfg.Checkpoint {
		log.Info("Checkpointing is enabled. Attempting to load state")
		var previousCluster, previousEc2InstanceID, previousContainerInstanceArn string
		previousTaskEngine := engine.NewTaskEngine(cfg, dockerClient, credentialsManager, containerChangeEventStream, imageManager, state)
		// previousState is used to verify that our current runtime configuration is
		// compatible with our past configuration as reflected by our state-file
		previousState, err := initializeStateManager(cfg, previousTaskEngine, &previousCluster, &previousContainerInstanceArn, &previousEc2InstanceID)
		if err != nil {
			log.Criticalf("Error creating state manager: %v", err)
			return exitcodes.ExitTerminal
		}

		err = previousState.Load()
		if err != nil {
			log.Criticalf("Error loading previously saved state: %v", err)
			return exitcodes.ExitTerminal
		}

		if previousCluster != "" {
			// TODO Handle default cluster in a sane and unified way across the codebase
			configuredCluster := cfg.Cluster
			if configuredCluster == "" {
				log.Debug("Setting cluster to default; none configured")
				configuredCluster = config.DefaultClusterName
			}
			if previousCluster != configuredCluster {
				log.Criticalf("Data mismatch; saved cluster '%v' does not match configured cluster '%v'. Perhaps you want to delete the configured checkpoint file?", previousCluster, configuredCluster)
				return exitcodes.ExitTerminal
			}
			cfg.Cluster = previousCluster
			log.Infof("Restored cluster '%v'", cfg.Cluster)
		}

		if instanceIdentityDoc, err := ec2MetadataClient.InstanceIdentityDocument(); err == nil {
			currentEc2InstanceID = instanceIdentityDoc.InstanceId
		} else {
			log.Criticalf("Unable to access EC2 Metadata service to determine EC2 ID: %v", err)
		}

		if previousEc2InstanceID != "" && previousEc2InstanceID != currentEc2InstanceID {
			log.Warnf("Data mismatch; saved InstanceID '%s' does not match current InstanceID '%s'. Overwriting old datafile", previousEc2InstanceID, currentEc2InstanceID)

			// Reset taskEngine; all the other values are still default
			taskEngine = engine.NewTaskEngine(cfg, dockerClient, credentialsManager, containerChangeEventStream, imageManager, state)
		} else {
			// Use the values we loaded if there's no issue
			containerInstanceArn = previousContainerInstanceArn
			taskEngine = previousTaskEngine
		}
	} else {
		log.Info("Checkpointing not enabled; a new container instance will be created each time the agent is run")
		taskEngine = engine.NewTaskEngine(cfg, dockerClient, credentialsManager, containerChangeEventStream, imageManager, state)
	}

	stateManager, err := initializeStateManager(cfg, taskEngine, &cfg.Cluster, &containerInstanceArn, &currentEc2InstanceID)
	if err != nil {
		log.Criticalf("Error creating state manager: %v", err)
		return exitcodes.ExitTerminal
	}

	capabilities := taskEngine.Capabilities()

	// We instantiate our own credentialProvider for use in acs/tcs. This tries
	// to mimic roughly the way it's instantiated by the SDK for a default
	// session.
	credentialProvider := defaults.CredChain(defaults.Config(), defaults.Handlers())
	// Preflight request to make sure the credentials are valid
	if preflightCreds, err := credentialProvider.Get(); err != nil || preflightCreds.AccessKeyID == "" {
		log.Warnf("Error getting valid credentials (AKID %s): %v", preflightCreds.AccessKeyID, err)
	}
	client := api.NewECSClient(credentialProvider, cfg, httpclient.New(api.RoundtripTimeout, *acceptInsecureCert), ec2MetadataClient)

	if containerInstanceArn == "" {
		log.Info("Registering Instance with ECS")
		containerInstanceArn, err = client.RegisterContainerInstance("", capabilities)
		if err != nil {
			log.Errorf("Error registering: %v", err)
			if retriable, ok := err.(utils.Retriable); ok && !retriable.Retry() {
				return exitcodes.ExitTerminal
			}
			return exitcodes.ExitError
		}
		log.Infof("Registration completed successfully. I am running as '%s' in cluster '%s'", containerInstanceArn, cfg.Cluster)
		// Save our shiny new containerInstanceArn
		stateManager.Save()
	} else {
		log.Infof("Restored from checkpoint file. I am running as '%s' in cluster '%s'", containerInstanceArn, cfg.Cluster)
		_, err = client.RegisterContainerInstance(containerInstanceArn, capabilities)
		if err != nil {
			log.Errorf("Error re-registering: %v", err)
			if awserr, ok := err.(awserr.Error); ok && api.IsInstanceTypeChangedError(awserr) {
				log.Criticalf("The current instance type does not match the registered instance type. Please revert the instance type change, or alternatively launch a new instance. Error: %v", err)
				return exitcodes.ExitTerminal
			}
			return exitcodes.ExitError
		}
	}

	// Begin listening to the docker daemon and saving changes
	taskEngine.SetSaver(stateManager)
	imageManager.SetSaver(stateManager)
	taskEngine.MustInit()

	// Start the periodic image cleanup process
	if !cfg.ImageCleanupDisabled {
		go imageManager.StartImageCleanupProcess(ctx)
	}

	go sighandlers.StartTerminationHandler(stateManager, taskEngine)

	// Agent introspection API
	go handlers.ServeHttp(&containerInstanceArn, taskEngine, cfg)

	// Start serving the endpoint to fetch IAM Role credentials
	go credentialshandler.ServeHttp(credentialsManager, containerInstanceArn, cfg)

	// Start sending events to the backend
	go eventhandler.HandleEngineEvents(taskEngine, client, stateManager)

	deregisterInstanceEventStream := eventstream.NewEventStream(DeregisterContainerInstanceEventStream, ctx)
	deregisterInstanceEventStream.StartListening()

	telemetrySessionParams := tcshandler.TelemetrySessionParams{
		ContainerInstanceArn: containerInstanceArn,
		CredentialProvider:   credentialProvider,
		Cfg:                  cfg,
		DeregisterInstanceEventStream: deregisterInstanceEventStream,
		ContainerChangeEventStream:    containerChangeEventStream,
		DockerClient:                  dockerClient,
		AcceptInvalidCert:             *acceptInsecureCert,
		EcsClient:                     client,
		TaskEngine:                    taskEngine,
	}

	// Start metrics session in a go routine
	go tcshandler.StartMetricsSession(telemetrySessionParams)

	log.Info("Beginning Polling for updates")
	err = acshandler.StartSession(ctx, acshandler.StartSessionArguments{
		AcceptInvalidCert: *acceptInsecureCert,
		Config:            cfg,
		DeregisterInstanceEventStream: deregisterInstanceEventStream,
		ContainerInstanceArn:          containerInstanceArn,
		CredentialProvider:            credentialProvider,
		ECSClient:                     client,
		StateManager:                  stateManager,
		TaskEngine:                    taskEngine,
		CredentialsManager:            credentialsManager,
	})
	if err != nil {
		log.Criticalf("Unretriable error starting communicating with ACS: %v", err)
		return exitcodes.ExitTerminal
	}
	log.Critical("ACS Session handler should never exit")
	return exitcodes.ExitError
}
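
// _main returns an exit code instead of calling os.Exit directly so that the
// deferred log.Flush above runs. The program entry point is then expected to
// be a thin wrapper along these lines (a sketch, not necessarily the file's
// actual main):
func main() {
	os.Exit(_main())
}
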
// TestPayloadBufferHandlerWithCredentials tests if the async payloadBufferHandler routine
// acks the payload message and credentials after adding tasks
func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	ctx, cancel := context.WithCancel(context.Background())
	credentialsManager := credentials.NewManager()

	// The payload message in this test consists of two tasks; record both of
	// them in the order in which they were added
	taskEngine := engine.NewMockTaskEngine(ctrl)
	var firstAddedTask *api.Task
	var secondAddedTask *api.Task
	gomock.InOrder(
		taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
			firstAddedTask = task
		}),
		taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
			secondAddedTask = task
		}),
	)

	// The payload message in this test consists of two tasks, each with
	// credentials set. Record the credentials acks and the payload message ack
	var payloadAckRequested *ecsacs.AckRequest
	var firstTaskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest
	var secondTaskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	gomock.InOrder(
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
			firstTaskCredentialsAckRequested = ackRequest
		}),
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
			secondTaskCredentialsAckRequested = ackRequest
		}),
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) {
			payloadAckRequested = ackRequest
			// Cancel the context when the ack for the payload message is received
			// This signals a successful workflow in the test
			cancel()
		}),
	)

	refreshCredsHandler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
	defer refreshCredsHandler.clearAcks()
	refreshCredsHandler.start()
	payloadHandler := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredsHandler, credentialsManager)
	go payloadHandler.start()

	firstTaskArn := "t1"
	firstTaskCredentialsExpiration := "expiration"
	firstTaskCredentialsRoleArn := "r1"
	firstTaskCredentialsAccessKey := "akid"
	firstTaskCredentialsSecretKey := "skid"
	firstTaskCredentialsSessionToken := "token"
	firstTaskCredentialsId := "credsid1"

	secondTaskArn := "t2"
	secondTaskCredentialsExpiration := "expirationSecond"
	secondTaskCredentialsRoleArn := "r2"
	secondTaskCredentialsAccessKey := "akid2"
	secondTaskCredentialsSecretKey := "skid2"
	secondTaskCredentialsSessionToken := "token2"
	secondTaskCredentialsId := "credsid2"

	// Send a payload message to the payloadBufferChannel
	payloadHandler.messageBuffer <- &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String(firstTaskArn),
				RoleCredentials: &ecsacs.IAMRoleCredentials{
					AccessKeyId:     aws.String(firstTaskCredentialsAccessKey),
					Expiration:      aws.String(firstTaskCredentialsExpiration),
					RoleArn:         aws.String(firstTaskCredentialsRoleArn),
					SecretAccessKey: aws.String(firstTaskCredentialsSecretKey),
					SessionToken:    aws.String(firstTaskCredentialsSessionToken),
					CredentialsId:   aws.String(firstTaskCredentialsId),
				},
			},
			&ecsacs.Task{
				Arn: aws.String(secondTaskArn),
				RoleCredentials: &ecsacs.IAMRoleCredentials{
					AccessKeyId:     aws.String(secondTaskCredentialsAccessKey),
					Expiration:      aws.String(secondTaskCredentialsExpiration),
					RoleArn:         aws.String(secondTaskCredentialsRoleArn),
					SecretAccessKey: aws.String(secondTaskCredentialsSecretKey),
					SessionToken:    aws.String(secondTaskCredentialsSessionToken),
					CredentialsId:   aws.String(secondTaskCredentialsId),
				},
			},
		},
		MessageId:            aws.String(payloadMessageId),
		ClusterArn:           aws.String(cluster),
		ContainerInstanceArn: aws.String(containerInstance),
	}

	// Wait till we get an ack
	<-ctx.Done()

	// Verify if payloadMessageId read from the ack buffer is correct
	if aws.StringValue(payloadAckRequested.MessageId) != payloadMessageId {
		t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(payloadAckRequested.MessageId))
	}

	// Verify the correctness of the first task added to the engine and the
	// credentials ack generated for it
	expectedCredentialsAckForFirstTask := &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     aws.String(payloadMessageId),
		Expiration:    aws.String(firstTaskCredentialsExpiration),
		CredentialsId: aws.String(firstTaskCredentialsId),
	}
	expectedCredentialsForFirstTask := credentials.IAMRoleCredentials{
		AccessKeyId:     firstTaskCredentialsAccessKey,
		Expiration:      firstTaskCredentialsExpiration,
		RoleArn:         firstTaskCredentialsRoleArn,
		SecretAccessKey: firstTaskCredentialsSecretKey,
		SessionToken:    firstTaskCredentialsSessionToken,
		CredentialsId:   firstTaskCredentialsId,
	}
	err := validateTaskAndCredentials(firstTaskCredentialsAckRequested, expectedCredentialsAckForFirstTask, firstAddedTask, firstTaskArn, expectedCredentialsForFirstTask)
	if err != nil {
		t.Errorf("Error validating added task or credentials ack for the same: %v", err)
	}

	// Verify the correctness of the second task added to the engine and the
	// credentials ack generated for it
	expectedCredentialsAckForSecondTask := &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     aws.String(payloadMessageId),
		Expiration:    aws.String(secondTaskCredentialsExpiration),
		CredentialsId: aws.String(secondTaskCredentialsId),
	}
	expectedCredentialsForSecondTask := credentials.IAMRoleCredentials{
		AccessKeyId:     secondTaskCredentialsAccessKey,
		Expiration:      secondTaskCredentialsExpiration,
		RoleArn:         secondTaskCredentialsRoleArn,
		SecretAccessKey: secondTaskCredentialsSecretKey,
		SessionToken:    secondTaskCredentialsSessionToken,
		CredentialsId:   secondTaskCredentialsId,
	}
	err = validateTaskAndCredentials(secondTaskCredentialsAckRequested, expectedCredentialsAckForSecondTask, secondAddedTask, secondTaskArn, expectedCredentialsForSecondTask)
	if err != nil {
		t.Errorf("Error validating added task or credentials ack for the same: %v", err)
	}
}
// TestHandlePayloadMessageCredentialsAckedWhenTaskAdded tests if the handler generates
// an ack after processing a payload message when the payload message contains a task
// with an IAM Role. It also tests if the credentials ack is generated
func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()
	ctx, cancel := context.WithCancel(context.Background())
	credentialsManager := credentials.NewManager()

	taskEngine := engine.NewMockTaskEngine(ctrl)
	var addedTask *api.Task
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(task *api.Task) {
		addedTask = task
	}).Times(1)

	var payloadAckRequested *ecsacs.AckRequest
	var taskCredentialsAckRequested *ecsacs.IAMRoleCredentialsAckRequest
	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	// The payload message in this test consists of a single task with credentials
	// set. Record the credentials ack and the payload message ack
	gomock.InOrder(
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
			taskCredentialsAckRequested = ackRequest
		}),
		mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.AckRequest) {
			payloadAckRequested = ackRequest
			// Cancel the context when the ack for the payload message is received
			// This signals a successful workflow in the test
			cancel()
		}),
	)

	refreshCredsHandler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
	defer refreshCredsHandler.clearAcks()
	refreshCredsHandler.start()

	payloadHandler := newPayloadRequestHandler(ctx, taskEngine, ecsClient, clusterName, containerInstanceArn, mockWsClient, stateManager, refreshCredsHandler, credentialsManager)
	go payloadHandler.start()

	taskArn := "t1"
	credentialsExpiration := "expiration"
	credentialsRoleArn := "r1"
	credentialsAccessKey := "akid"
	credentialsSecretKey := "skid"
	credentialsSessionToken := "token"
	credentialsId := "credsid"

	// Send a payload message
	payloadMessage := &ecsacs.PayloadMessage{
		Tasks: []*ecsacs.Task{
			&ecsacs.Task{
				Arn: aws.String(taskArn),
				RoleCredentials: &ecsacs.IAMRoleCredentials{
					AccessKeyId:     aws.String(credentialsAccessKey),
					Expiration:      aws.String(credentialsExpiration),
					RoleArn:         aws.String(credentialsRoleArn),
					SecretAccessKey: aws.String(credentialsSecretKey),
					SessionToken:    aws.String(credentialsSessionToken),
					CredentialsId:   aws.String(credentialsId),
				},
			},
		},
		MessageId:            aws.String(payloadMessageId),
		ClusterArn:           aws.String(cluster),
		ContainerInstanceArn: aws.String(containerInstance),
	}
	err := payloadHandler.handleSingleMessage(payloadMessage)
	if err != nil {
		t.Errorf("Error handling payload message: %v", err)
	}

	// Wait till we get an ack from the ackBuffer
	<-ctx.Done()

	// Verify that the correct message id was acked
	if aws.StringValue(payloadAckRequested.MessageId) != payloadMessageId {
		t.Errorf("Message Id mismatch. Expected: %s, got: %s", payloadMessageId, aws.StringValue(payloadAckRequested.MessageId))
	}

	// Verify the correctness of the task added to the engine and the
	// credentials ack generated for it
	expectedCredentialsAck := &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     aws.String(payloadMessageId),
		Expiration:    aws.String(credentialsExpiration),
		CredentialsId: aws.String(credentialsId),
	}
	expectedCredentials := credentials.IAMRoleCredentials{
		AccessKeyId:     credentialsAccessKey,
		Expiration:      credentialsExpiration,
		RoleArn:         credentialsRoleArn,
		SecretAccessKey: credentialsSecretKey,
		SessionToken:    credentialsSessionToken,
		CredentialsId:   credentialsId,
	}
	err = validateTaskAndCredentials(taskCredentialsAckRequested, expectedCredentialsAck, addedTask, taskArn, expectedCredentials)
	if err != nil {
		t.Errorf("Error validating added task or credentials ack for the same: %v", err)
	}
}
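
// validateTaskAndCredentials is defined elsewhere in the original test file.
// A hypothetical sketch of the checks the payload-with-credentials tests above
// rely on; the real helper may do more (for example, verifying the credentials
// registered with the credentials manager):
func validateTaskAndCredentialsSketch(ackRequested, expectedAck *ecsacs.IAMRoleCredentialsAckRequest,
	addedTask *api.Task, expectedTaskArn string, expectedCredentials credentials.IAMRoleCredentials) error {
	if !reflect.DeepEqual(ackRequested, expectedAck) {
		return fmt.Errorf("mismatch between expected and requested credentials ack; expected: %v, requested: %v", expectedAck, ackRequested)
	}
	if addedTask == nil || addedTask.Arn != expectedTaskArn {
		return fmt.Errorf("expected task with arn %s to be added to the engine, got: %v", expectedTaskArn, addedTask)
	}
	// Verification of expectedCredentials against the credentials manager is
	// omitted in this sketch.
	_ = expectedCredentials
	return nil
}
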
func TestHandlerDoesntLeakGoroutines(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := engine.NewMockTaskEngine(ctrl)
	ecsClient := mock_api.NewMockECSClient(ctrl)
	stateManager := statemanager.NewNoopStateManager()

	closeWS := make(chan bool)
	server, serverIn, requests, errs, err := startMockAcsServer(t, closeWS)
	if err != nil {
		t.Fatal(err)
	}
	go func() {
		for {
			select {
			case <-requests:
			case <-errs:
			}
		}
	}()

	timesConnected := 0
	ecsClient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).AnyTimes().Do(func(_ interface{}) {
		timesConnected++
	})
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()
	taskEngine.EXPECT().AddTask(gomock.Any()).AnyTimes()

	ctx, cancel := context.WithCancel(context.Background())
	ended := make(chan bool, 1)
	go func() {
		StartSession(ctx, StartSessionArguments{
			ContainerInstanceArn: "myArn",
			CredentialProvider:   credentials.AnonymousCredentials,
			Config:               &config.Config{Cluster: "someCluster"},
			TaskEngine:           taskEngine,
			ECSClient:            ecsClient,
			StateManager:         stateManager,
			AcceptInvalidCert:    true,
			CredentialsManager:   rolecredentials.NewManager(),
			_heartbeatTimeout:    20 * time.Millisecond,
			_heartbeatJitter:     10 * time.Millisecond,
		})
		ended <- true
	}()
	// Warm it up
	serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
	serverIn <- samplePayloadMessage

	beforeGoroutines := runtime.NumGoroutine()
	for i := 0; i < 100; i++ {
		serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
		serverIn <- samplePayloadMessage
		closeWS <- true
	}

	cancel()
	<-ended

	afterGoroutines := runtime.NumGoroutine()

	t.Logf("Gorutines after 1 and after 100 acs messages: %v and %v", beforeGoroutines, afterGoroutines)

	if timesConnected < 50 {
		t.Fatal("Expected times connected to be a large number, was ", timesConnected)
	}
	if afterGoroutines > beforeGoroutines+5 {
		t.Error("Goroutine leak, oh no!")
		pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
	}
}