// TestTaskManagerRunTask runs a single mock-engine task through a task manager
// that talks to a fake queue server, and verifies that the worker reports the
// run as completed.
func TestTaskManagerRunTask(t *testing.T) {
	// resolved is flipped by the fake queue once the completed end-point is hit.
	// NOTE(review): it is written from the HTTP handler goroutine and read from
	// the test goroutine without explicit synchronization; presumably tm.run
	// only returns after the completed request has finished — confirm, or use
	// an atomic flag like TestWorkerShutdown does for its counter.
	resolved := false
	var serverURL string

	// Fake queue: answers the artifact creation requests for the task log and
	// records when the run is reported as completed.
	handler := func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		switch {
		case strings.Contains(path, "/artifacts/public/logs/live_backing.log"):
			json.NewEncoder(w).Encode(&queue.S3ArtifactResponse{
				PutURL: serverURL,
			})
		case strings.Contains(path, "/artifacts/public/logs/live.log"):
			json.NewEncoder(w).Encode(&queue.RedirectArtifactResponse{})
		case strings.Contains(path, "/task/abc/runs/1/completed"):
			resolved = true
			w.Header().Set("Content-Type", "application/json; charset=UTF-8")
			json.NewEncoder(w).Encode(&queue.TaskStatusResponse{})
		}
	}

	s := httptest.NewServer(http.HandlerFunc(handler))
	serverURL = s.URL
	defer s.Close()

	tempPath := filepath.Join(os.TempDir(), slugid.Nice())
	tempStorage, err := runtime.NewTemporaryStorage(tempPath)
	if err != nil {
		t.Fatal(err)
	}

	localServer, err := webhookserver.NewLocalServer(
		[]byte{127, 0, 0, 1}, 60000,
		"", 0,
		"example.com",
		"",
		"",
		"",
		10*time.Minute,
	)
	if err != nil {
		// Abort right away: continuing would wire an unusable server into the
		// environment and fail later in a far less obvious way.
		t.Fatal(err)
	}

	// Named collector rather than gc to avoid shadowing the gc package.
	collector := &gc.GarbageCollector{}
	environment := &runtime.Environment{
		GarbageCollector: collector,
		TemporaryStorage: tempStorage,
		WebHookServer:    localServer,
	}
	engineProvider := engines.Engines()["mock"]
	engine, err := engineProvider.NewEngine(engines.EngineOptions{
		Environment: environment,
		Log:         logger.WithField("engine", "mock"),
	})
	if err != nil {
		t.Fatal(err)
	}

	// Point the task manager at the fake queue.
	cfg := &configType{
		QueueBaseURL: serverURL,
	}

	tm, err := newTaskManager(cfg, engine, MockPlugin{}, environment, logger.WithField("test", "TestTaskManagerRunTask"), collector)
	if err != nil {
		t.Fatal(err)
	}

	// Claim for task "abc", run 1; this must match the completed path the fake
	// queue is watching for.
	claim := &taskClaim{
		taskID: "abc",
		runID:  1,
		taskClaim: &queue.TaskClaimResponse{
			Credentials: struct {
				AccessToken string `json:"accessToken"`
				Certificate string `json:"certificate"`
				ClientID    string `json:"clientId"`
			}{
				AccessToken: "123",
				ClientID:    "abc",
				Certificate: "",
			},
			TakenUntil: tcclient.Time(time.Now().Add(time.Minute * 5)),
		},
		definition: &queue.TaskDefinitionResponse{
			Payload: []byte(`{"delay": 1,"function": "write-log","argument": "Hello World"}`),
		},
	}
	tm.run(claim)
	assert.True(t, resolved, "Task was not resolved")
}
Example #2
0
// New will create a worker given configuration matching the schema from
// ConfigSchema(). The log parameter is optional and if nil is given a default
// logrus logger will be used.
//
// The configured temporary folder is removed before temporary storage is
// set up in it. The level of the logger is set to the configured log level; if
// that level cannot be parsed the logger keeps its current level and a warning
// is logged.
//
// An error is returned if the configuration is invalid, if the selected engine
// is unknown or has no configuration, or if any component fails to initialize.
func New(config interface{}, log *logrus.Logger) (*Worker, error) {
	// Validate and map configuration to c
	var c configType
	if err := schematypes.MustMap(ConfigSchema(), config, &c); err != nil {
		return nil, fmt.Errorf("Invalid configuration: %s", err)
	}

	// Wipe whatever is left in the temporary folder, then set up fresh
	// temporary storage in it.
	if err := os.RemoveAll(c.TemporaryFolder); err != nil {
		return nil, fmt.Errorf("Failed to remove temporaryFolder: %s, error: %s",
			c.TemporaryFolder, err)
	}
	tempStorage, err := runtime.NewTemporaryStorage(c.TemporaryFolder)
	if err != nil {
		return nil, fmt.Errorf("Failed to create temporary folder, error: %s", err)
	}

	// Create logger
	if log == nil {
		log = logrus.New()
	}
	// Only apply the configured level when it parses. Assigning the value
	// returned alongside an error would set the zero level, which silences
	// everything below panic.
	if level, lerr := logrus.ParseLevel(c.LogLevel); lerr != nil {
		log.Warnf("Ignoring invalid logLevel '%s', error: %s", c.LogLevel, lerr)
	} else {
		log.Level = level
	}

	// Setup WebHookServer
	localServer, err := webhookserver.NewLocalServer(
		net.ParseIP(c.ServerIP), c.ServerPort,
		c.NetworkInterface, c.ExposedPort,
		c.DNSDomain,
		c.DNSSecret,
		c.TLSCertificate,
		c.TLSKey,
		time.Duration(c.MaxLifeCycle)*time.Second,
	)
	if err != nil {
		return nil, err
	}

	// Create environment (collector rather than gc, to avoid shadowing the
	// gc package)
	collector := gc.New(c.TemporaryFolder, c.MinimumDiskSpace, c.MinimumMemory)
	env := &runtime.Environment{
		GarbageCollector: collector,
		Log:              log,
		TemporaryStorage: tempStorage,
		WebHookServer:    localServer,
	}

	// Ensure that engine configuration was provided for the engine selected
	engineConfig, ok := c.Engines[c.Engine]
	if !ok {
		return nil, fmt.Errorf("Invalid configuration: The key 'engines.%s' must "+
			"be specified when engine '%s' is selected", c.Engine, c.Engine)
	}

	// Find engine provider. The schema should ensure it exists, but check
	// anyway so an unknown engine is reported as an error rather than as a
	// call on a missing provider.
	provider, ok := engines.Engines()[c.Engine]
	if !ok {
		return nil, fmt.Errorf("Invalid configuration: Unknown engine '%s'",
			c.Engine)
	}
	engine, err := provider.NewEngine(engines.EngineOptions{
		Environment: env,
		Log:         env.Log.WithField("engine", c.Engine),
		Config:      engineConfig,
	})
	if err != nil {
		return nil, fmt.Errorf("Engine initialization failed, error: %s", err)
	}

	// Initialize plugin manager
	pm, err := plugins.NewPluginManager(plugins.PluginOptions{
		Environment: env,
		Engine:      engine,
		Log:         env.Log.WithField("plugin", "plugin-manager"),
		Config:      c.Plugins,
	})
	if err != nil {
		return nil, fmt.Errorf("Plugin initialization failed, error: %s", err)
	}

	// Create the task manager that ties engine, plugins and environment
	// together
	tm, err := newTaskManager(
		&c, engine, pm, env,
		env.Log.WithField("component", "task-manager"), collector,
	)
	if err != nil {
		return nil, err
	}

	return &Worker{
		log:    env.Log.WithField("component", "worker"),
		tm:     tm,
		sm:     runtime.NewShutdownManager("local"),
		env:    env,
		server: localServer,
		done:   make(chan struct{}),
	}, nil
}
// TestWorkerShutdown starts two long-running mock-engine tasks, stops the task
// manager while they are running, and verifies that both runs are reported to
// the fake queue as exceptions with reason "worker-shutdown".
func TestWorkerShutdown(t *testing.T) {
	// resCount counts exception reports received by the fake queue; it is
	// updated from handler goroutines, hence the atomic access.
	var resCount int32
	var serverURL string

	handler := func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.Path, "/artifacts/public/logs/live_backing.log") {
			json.NewEncoder(w).Encode(&queue.S3ArtifactResponse{
				PutURL: serverURL,
			})
			return
		}

		if strings.Contains(r.URL.Path, "/artifacts/public/logs/live.log") {
			json.NewEncoder(w).Encode(&queue.RedirectArtifactResponse{})
			return
		}

		if strings.Contains(r.URL.Path, "exception") {
			var exception queue.TaskExceptionRequest
			err := json.NewDecoder(r.Body).Decode(&exception)
			// Ignore errors for now
			if err != nil {
				return
			}

			assert.Equal(t, "worker-shutdown", exception.Reason)
			atomic.AddInt32(&resCount, 1)

			w.Header().Set("Content-Type", "application/json; charset=UTF-8")
			json.NewEncoder(w).Encode(&queue.TaskStatusResponse{})
		}
	}

	s := httptest.NewServer(http.HandlerFunc(handler))
	serverURL = s.URL
	defer s.Close()

	tempPath := filepath.Join(os.TempDir(), slugid.Nice())
	tempStorage, err := runtime.NewTemporaryStorage(tempPath)
	if err != nil {
		t.Fatal(err)
	}

	localServer, err := webhookserver.NewLocalServer(
		[]byte{127, 0, 0, 1}, 60000,
		"", 0,
		"example.com",
		"",
		"",
		"",
		10*time.Minute,
	)
	if err != nil {
		// Abort right away: continuing would wire an unusable server into the
		// environment and fail later in a far less obvious way.
		t.Fatal(err)
	}

	// Named collector rather than gc to avoid shadowing the gc package.
	collector := &gc.GarbageCollector{}
	environment := &runtime.Environment{
		GarbageCollector: collector,
		TemporaryStorage: tempStorage,
		WebHookServer:    localServer,
	}
	engineProvider := engines.Engines()["mock"]
	engine, err := engineProvider.NewEngine(engines.EngineOptions{
		Environment: environment,
		Log:         logger.WithField("engine", "mock"),
	})
	if err != nil {
		t.Fatal(err)
	}

	// Point the task manager at the fake queue.
	cfg := &configType{
		QueueBaseURL: serverURL,
	}
	tm, err := newTaskManager(cfg, engine, MockPlugin{}, environment, logger.WithField("test", "TestWorkerShutdown"), collector)
	if err != nil {
		t.Fatal(err)
	}

	// Two claims whose payload asks for a delay of 5000, far longer than the
	// pause below, so both are still running when the manager is stopped.
	claims := []*taskClaim{
		{
			taskID: "abc",
			runID:  1,
			definition: &queue.TaskDefinitionResponse{
				Payload: []byte(`{"delay": 5000,"function": "write-log","argument": "Hello World"}`),
			},
			taskClaim: &queue.TaskClaimResponse{
				Credentials: struct {
					AccessToken string `json:"accessToken"`
					Certificate string `json:"certificate"`
					ClientID    string `json:"clientId"`
				}{
					AccessToken: "123",
					ClientID:    "abc",
					Certificate: "",
				},
				TakenUntil: tcclient.Time(time.Now().Add(time.Minute * 5)),
			},
		},
		{
			taskID: "def",
			runID:  0,
			definition: &queue.TaskDefinitionResponse{
				Payload: []byte(`{"delay": 5000,"function": "write-log","argument": "Hello World"}`),
			},
			taskClaim: &queue.TaskClaimResponse{
				Credentials: struct {
					AccessToken string `json:"accessToken"`
					Certificate string `json:"certificate"`
					ClientID    string `json:"clientId"`
				}{
					AccessToken: "123",
					ClientID:    "abc",
					Certificate: "",
				},
				TakenUntil: tcclient.Time(time.Now().Add(time.Minute * 5)),
			},
		},
	}

	// Run every claim in its own goroutine; wg lets the test wait until all
	// tm.run calls have returned.
	var wg sync.WaitGroup
	wg.Add(len(claims))
	for _, c := range claims {
		go func(claim *taskClaim) {
			defer wg.Done()
			tm.run(claim)
		}(c)
	}

	// Give the task manager time to register both tasks as running.
	time.Sleep(500 * time.Millisecond)
	assert.Equal(t, 2, len(tm.RunningTasks()))
	close(tm.done)
	tm.Stop()

	wg.Wait()
	assert.Equal(t, 0, len(tm.RunningTasks()))
	assert.Equal(t, int32(2), atomic.LoadInt32(&resCount))
}