예제 #1
0
// testProcessSendCommandMessage runs Processor.processSendCommandMessage against
// mocked collaborators (reply builder, MDS service, plugin runner) configured
// from testCase, then checks that every mock expectation was met and that the
// payload handed to the mocked SendReply decodes to testCase.ReplyPayload.
func testProcessSendCommandMessage(t *testing.T, testCase TestCaseSendCommand) {
	cancelFlag := task.NewChanneledCancelFlag()

	// method should call replyBuilder to format the response
	replyBuilderMock := new(MockedReplyBuilder)
	replyBuilderMock.On("BuildReply", mock.Anything, testCase.PluginResults).Return(testCase.ReplyPayload)

	// method should call the proper APIs on the MDS service;
	// the payload given to SendReply is captured so it can be inspected below
	mdsMock := new(MockedMDS)
	var replyPayload string
	mdsMock.On("SendReply", mock.Anything, *testCase.Msg.MessageId, mock.AnythingOfType("string")).Return(nil).Run(func(args mock.Arguments) {
		replyPayload = args.Get(2).(string)
	})
	mdsMock.On("DeleteMessage", mock.Anything, *testCase.Msg.MessageId).Return(nil)

	// create a mock sendResponse function that builds the reply, serializes it
	// and forwards it to the mocked SendReply so the sent reply can be asserted
	sendResponse := func(messageID string, pluginID string, results map[string]*contracts.PluginResult) {
		logger := context.NewMockDefault().Log()
		payloadDoc := replyBuilderMock.BuildReply(pluginID, results)
		payloadB, err := json.Marshal(payloadDoc)
		if err != nil {
			// report instead of returning silently, otherwise the failure only
			// shows up later as an unrelated unmarshal/equality error
			t.Errorf("marshaling reply payload: %v", err)
			return
		}
		if err := mdsMock.SendReply(logger, messageID, string(payloadB)); err != nil {
			t.Errorf("sending reply for message %q: %v", messageID, err)
		}
	}

	// method should call plugin runner with the given configuration;
	// the sendResponse callback argument is matched with mock.Anything
	pluginRunnerMock := new(MockedPluginRunner)
	pluginRunnerMock.On("RunPlugins", mock.Anything, *testCase.Msg.MessageId, testCase.PluginConfigs, mock.Anything, cancelFlag).Return(testCase.PluginResults)

	// call method under test
	// orchestrationRootDir is set to empty such that it can meet the test expectation.
	orchestrationRootDir := ""
	p := Processor{}
	p.processSendCommandMessage(context.NewMockDefault(), mdsMock, orchestrationRootDir, pluginRunnerMock.RunPlugins, cancelFlag, replyBuilderMock.BuildReply, sendResponse, testCase.Msg)

	// assert that the expectations were met
	pluginRunnerMock.AssertExpectations(t)
	replyBuilderMock.AssertExpectations(t)
	mdsMock.AssertExpectations(t)

	// check that the method sent the right reply
	var parsedReply messageContracts.SendReplyPayload
	err := json.Unmarshal([]byte(replyPayload), &parsedReply)
	assert.NoError(t, err)
	assert.Equal(t, testCase.ReplyPayload, parsedReply)
}
예제 #2
0
// prepareTestPollOnce builds the fixtures for the poll-once tests: a Processor
// wired to a mock context, a dummy agent configuration and a mocked MDS
// service, plus a TestCasePollOnce holding the same context and service mocks
// so the caller can set expectations on them.
func prepareTestPollOnce() (proc Processor, testCase TestCasePollOnce) {
	// mocks shared between the processor and the returned test case
	mockCtx := context.NewMockDefault()
	mockService := new(MockedMDS)

	// dummy agent information; the instance ID is the test destination
	info := contracts.AgentInfo{
		Name:      "EC2Config",
		Version:   "1",
		Lang:      "en-US",
		Os:        "linux",
		OsVersion: "1",
	}
	cfg := contracts.AgentConfiguration{
		AgentInfo:  info,
		InstanceID: testDestination,
	}

	proc = Processor{
		context: mockCtx,
		config:  cfg,
		service: mockService,
	}
	testCase = TestCasePollOnce{
		ContextMock: mockCtx,
		MdsMock:     mockService,
	}
	return proc, testCase
}
예제 #3
0
// testProcessCancelCommandMessage builds a cancel-command MDS message from
// testCase, passes it to Processor.processCancelCommandMessage together with a
// mocked MDS service and a mocked task pool, and asserts that the message was
// deleted and that Cancel was requested for the targeted message ID.
func testProcessCancelCommandMessage(t *testing.T, testCase TestCaseCancelCommand) {
	ctx := context.NewMockDefault()

	// build the cancel payload and wrap it in an MDS message
	// NOTE(review): no "." separates "aws.ssm" from MsgToCancelID here; confirm
	// that the test cases supply an ID that already carries the separator.
	payload := messageContracts.CancelPayload{
		CancelMessageID: "aws.ssm" + testCase.MsgToCancelID + "." + testCase.InstanceID,
	}
	content, err := jsonutil.Marshal(payload)
	if err != nil {
		t.Fatal(err)
	}
	cancelMsg := createMDSMessage(testCase.MsgID, content, "aws.ssm.cancelCommand.us.east.1.1", testCase.InstanceID)

	// the MDS service is expected to delete the cancel message
	serviceMock := new(MockedMDS)
	serviceMock.On("DeleteMessage", mock.Anything, *cancelMsg.MessageId).Return(nil)

	// the pool is expected to be asked to cancel the targeted command
	poolMock := new(task.MockedPool)
	poolMock.On("Cancel", payload.CancelMessageID).Return(true)

	// exercise the code under test
	processor := Processor{}
	processor.processCancelCommandMessage(ctx, serviceMock, poolMock, cancelMsg)

	// verify both mocks saw the calls configured above
	serviceMock.AssertExpectations(t)
	poolMock.AssertExpectations(t)
}
예제 #4
0
// TestRunPluginsWithRegistry tests that RunPlugins, given a plugin registry,
// executes every registered plugin with its configuration and returns their
// results. One of the plugins reports ResultStatusSuccessAndReboot, and the
// test finally checks that a reboot has been requested.
func TestRunPluginsWithRegistry(t *testing.T) {
	pluginNames := []string{"plugin1", "plugin2"}
	pluginConfigs := make(map[string]*contracts.Configuration)
	pluginResults := make(map[string]*contracts.PluginResult)
	pluginInstances := make(map[string]*plugin.Mock)
	pluginRegistry := plugin.PluginRegistry{}
	documentID := "TestDocument"

	// the response callback is irrelevant for this test, so it does nothing
	sendResponse := func(messageID string, pluginID string, results map[string]*contracts.PluginResult) {
	}

	var cancelFlag task.CancelFlag
	ctx := context.NewMockDefault()
	defaultTime := time.Now()
	for _, name := range pluginNames {

		// create an instance of our test object
		pluginInstances[name] = new(plugin.Mock)

		// setup expectations
		pluginConfigs[name] = &contracts.Configuration{}
		pluginResults[name] = &contracts.PluginResult{
			Output:        name,
			StartDateTime: defaultTime,
			EndDateTime:   defaultTime,
		}

		if name == "plugin2" {
			pluginResults[name].Status = contracts.ResultStatusSuccessAndReboot
		}

		pluginInstances[name].On("Execute", ctx, *pluginConfigs[name], cancelFlag).Return(*pluginResults[name])
		pluginRegistry[name] = pluginInstances[name]
	}

	// call the code we are testing
	outputs := RunPlugins(ctx, documentID, pluginConfigs, pluginRegistry, sendResponse, cancelFlag)

	// overwrite the timestamps of the actual results with the fixed time used
	// in the expectations so that the equality check below can succeed
	for _, result := range outputs {
		result.EndDateTime = defaultTime
		result.StartDateTime = defaultTime
	}

	// assert that the expectations were met
	for _, mockPlugin := range pluginInstances {
		mockPlugin.AssertExpectations(t)
	}
	ctx.AssertCalled(t, "Log")
	assert.Equal(t, pluginResults, outputs)

	// NOTE(review): presumably the reboot request is raised asynchronously and
	// this wait gives it time to land; confirm, and consider polling instead of
	// a fixed 10s sleep.
	time.Sleep(10 * time.Second)
	assert.True(t, rebooter.RebootRequested())
}
예제 #5
0
// testExecute checks the run command plugin's Execute method for one test
// case: it hands testExecution a callback that sets the mock expectations,
// calls Execute with the test case input and compares the result's output with
// the expected output of the test case.
func testExecute(t *testing.T, testCase TestCase) {
	testExecution(t, func(p *Plugin, mockCancelFlag *task.MockCancelFlag, mockExecuter *executers.MockCommandExecuter, mockS3Uploader *pluginutil.MockDefaultPlugin) {
		mockContext := context.NewMockDefault()

		// register the expectations on every mock involved
		setCancelFlagExpectations(mockCancelFlag)
		setExecuterExpectations(mockExecuter, testCase, mockCancelFlag, p)
		setS3UploaderExpectations(mockS3Uploader, testCase, p)

		// convert the typed test input into the raw form passed as a plugin property
		var rawInput interface{}
		remarshalErr := jsonutil.Remarshal(testCase.Input, &rawInput)
		assert.Nil(t, remarshalErr)

		properties := []interface{}{rawInput}
		expectedOutput := testCase.Output.String()

		// a fresh command ID is used as the bookkeeping file name
		// (message IDs have the format aws.ssm.<commandID>.<InstanceID>)
		commandID := uuid.NewV4().String()

		// run the plugin
		pluginConfig := contracts.Configuration{
			Properties:             properties,
			OutputS3BucketName:     s3BucketName,
			OutputS3KeyPrefix:      s3KeyPrefix,
			OrchestrationDirectory: orchestrationDirectory,
			BookKeepingFileName:    commandID,
		}
		result := p.Execute(mockContext, pluginConfig, mockCancelFlag)

		// verify the output (mocked object expectations are tested automatically by testExecution)
		assert.NotNil(t, result.StartDateTime)
		assert.NotNil(t, result.EndDateTime)
		assert.Equal(t, expectedOutput, result.Output)

		// the cancel flag must have been checked exactly once
		mockCancelFlag.AssertNumberOfCalls(t, "Canceled", 1)
	})
}
예제 #6
0
// TestExecute replaces the package-level updateAgent hook with a stub that
// reports exit code 1 and "error" on stderr, runs Plugin.Execute with a stub
// plugin input, and asserts that the returned result carries code 1 and an
// output containing "error".
func TestExecute(t *testing.T) {
	pluginInput := createStubPluginInput()
	pluginInput.TargetVersion = ""
	config := contracts.Configuration{}
	p := make([]interface{}, 1)
	p[0] = pluginInput
	config.Properties = p
	plugin := &Plugin{}

	mockCancelFlag := new(task.MockCancelFlag)
	mockContext := context.NewMockDefault()

	// Create stub; put the original hook back when the test ends so the stub
	// does not leak into other tests that run in the same package.
	originalUpdateAgent := updateAgent
	defer func() { updateAgent = originalUpdateAgent }()
	updateAgent = func(
		p *Plugin,
		config contracts.Configuration,
		log log.T,
		manager pluginHelper,
		util updateutil.T,
		rawPluginInput interface{},
		cancelFlag task.CancelFlag,
		outputS3BucketName string,
		outputS3KeyPrefix string,
		startTime time.Time) (out UpdatePluginOutput) {
		out = UpdatePluginOutput{}
		out.ExitCode = 1
		out.Stderr = "error"

		return out
	}

	// Setup mocks
	mockCancelFlag.On("Canceled").Return(false)
	mockCancelFlag.On("ShutDown").Return(false)
	mockCancelFlag.On("Wait").Return(false).After(100 * time.Millisecond)

	result := plugin.Execute(mockContext, config, mockCancelFlag)

	// expected value goes first so failure messages label the values correctly
	assert.Equal(t, 1, result.Code)
	assert.Contains(t, result.Output, "error")
}
예제 #7
0
// prepareTestProcessMessage builds the fixtures for the processMessage tests.
// It returns a Processor wired to mocks and stub callbacks, and a
// TestCaseProcessMessage holding the dummy message for testTopic, the same
// mocks, and pointers to the flags that the stub callbacks set when invoked.
func prepareTestProcessMessage(testTopic string) (proc Processor, testCase TestCaseProcessMessage) {
	mockCtx := context.NewMockDefault()

	// dummy message that would be passed to processMessage
	msg := ssmmds.Message{
		CreatedDate: &testCreatedDate,
		Destination: &testDestination,
		MessageId:   &testMessageId,
		Topic:       &testTopic,
	}

	// agent configuration with dummy agent info; the instance ID is the
	// destination of the message
	cfg := contracts.AgentConfiguration{
		AgentInfo: contracts.AgentInfo{
			Name:      "EC2Config",
			Version:   "1",
			Lang:      "en-US",
			Os:        "linux",
			OsVersion: "1",
		},
		InstanceID: *msg.Destination,
	}

	// mocked MDS service; expectations are left to the caller
	mockService := new(MockedMDS)

	// sendCommand and cancelCommand are handled by separate (mocked) pools
	sendPool := new(task.MockedPool)
	cancelPool := new(task.MockedPool)

	// stub callbacks that only record that they were called
	var docLevelResponseSent, dataPersisted bool
	recordDocLevelResponse := func(messageID string, resultStatus contracts.ResultStatus, documentTraceOutput string) {
		docLevelResponseSent = true
	}
	recordPersist := func(msg *ssmmds.Message, bookkeeping string) {
		dataPersisted = true
	}

	// processor assembled from everything above, with an empty orchestration root
	proc = Processor{
		context:              mockCtx,
		config:               cfg,
		service:              mockService,
		pluginRunner:         pluginRunner,
		sendCommandPool:      sendPool,
		cancelCommandPool:    cancelPool,
		sendDocLevelResponse: recordDocLevelResponse,
		orchestrationRootDir: "",
		persistData:          recordPersist,
	}

	testCase = TestCaseProcessMessage{
		ContextMock:               mockCtx,
		Message:                   msg,
		MdsMock:                   mockService,
		IsDocLevelResponseSent:    &docLevelResponseSent,
		IsDataPersisted:           &dataPersisted,
		SendCommandTaskPoolMock:   sendPool,
		CancelCommandTaskPoolMock: cancelPool,
	}

	return proc, testCase
}