Пример #1
0
// testSupervisor creates a Supervisor backed by a fresh test storage, attaches
// nWorkers TLS-enabled httptest servers (each serving an MPS instance) as its
// workers, and then calls test with that supervisor.
//
// The worker servers and MPS instances stay alive for the duration of test;
// they are torn down by deferred calls when testSupervisor returns. Setup
// failures are reported with t.Error/t.Errorf and make testSupervisor return
// without calling test.
func testSupervisor(t *testing.T, nWorkers int, test func(*Supervisor)) {
	t.Helper()

	tempdir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Error(err)
		return
	}
	defer os.RemoveAll(tempdir)

	client, err := tlsconfig.NewClient(tempdir)
	if err != nil {
		t.Error(err)
		return
	}

	// testStorage is declared outside this function; give this run a fresh one.
	testStorage = NewTestStorage()
	c := SupervisorConfig{
		Storage:       testStorage,
		TLSHandshaker: client,
		// Discard logs for tests
		Logger: log.New(ioutil.Discard, "", 0),
	}
	supervisor, err := NewSupervisor(c)
	if err != nil {
		t.Errorf("failed to create supervisor: %v", err)
		return
	}
	// Discard logs for tests. NOTE(review): c.Logger already asks for this;
	// confirm whether NewSupervisor applies it and drop one of the two.
	supervisor.Logger = log.New(ioutil.Discard, "", 0)

	for i := 0; i < nWorkers; i++ {
		// Named worker rather than mps so the mps package is not shadowed.
		worker, err := mps.NewMPS()
		if err != nil {
			t.Error(err)
			return
		}
		// Deferred inside the loop on purpose: every worker must outlive the
		// call to test below.
		defer worker.Reset()

		// configure server's TLS config directly rather than with a handshake
		server := httptest.NewUnstartedServer(worker)
		// NewUnstartedServer already holds an open listener, so register the
		// Close before anything below can fail and return early.
		defer server.Close()
		server.TLS, err = client.ServerConfig(net.ParseIP("127.0.0.1"))
		if err != nil {
			t.Error(err)
			return
		}
		server.StartTLS()

		addr := strings.TrimPrefix(server.URL, "https://")
		if err := supervisor.AddWorker(addr, int64(i)); err != nil {
			t.Errorf("could not add worker: %v", err)
			return
		}
	}
	test(supervisor)
}
Пример #2
0
// TestSupervisorReboot checks that a second supervisor, built from the workers
// and deployment records of a first one, can serve the same models.
//
// It starts five TLS worker servers and deploys five models (two instances
// each) through supervisor1. It sends 100 prediction requests per model and
// then shuts supervisor1 down. Finally it builds supervisor2 from the recorded
// workers and deployments and sends the same prediction requests again.
func TestSupervisorReboot(t *testing.T) {
	tempdir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tempdir)

	client, err := tlsconfig.NewClient(tempdir)
	if err != nil {
		t.Fatal(err)
	}

	// testStorage is declared outside this function; give this run a fresh one.
	testStorage = NewTestStorage()
	c := SupervisorConfig{
		Storage:       testStorage,
		TLSHandshaker: client,
		// Discard logs for tests
		Logger: log.New(ioutil.Discard, "", 0),
	}
	supervisor1, err := NewSupervisor(c)
	if err != nil {
		t.Fatalf("failed to create supervisor: %v", err)
	}

	var addrs [5]string
	for i := range addrs {
		// Named worker rather than mps so the mps package is not shadowed.
		worker, err := mps.NewMPS()
		if err != nil {
			t.Fatal(err)
		}
		// Deferred inside the loop on purpose: the workers must stay up for
		// both supervisors.
		defer worker.Reset()

		// configure server's TLS config directly rather than with a handshake
		server := httptest.NewUnstartedServer(worker)
		// NewUnstartedServer already holds an open listener, so register the
		// Close before the TLS setup below can fail.
		defer server.Close()
		server.TLS, err = client.ServerConfig(net.ParseIP("127.0.0.1"))
		if err != nil {
			t.Fatal(err)
		}
		server.StartTLS()

		addrs[i] = strings.TrimPrefix(server.URL, "https://")
	}

	// Model records what was deployed under a deployment ID; the records are
	// later turned into the DeploymentReqs supervisor2 is started from.
	type Model struct {
		User, Model string
		Version     int
		Instances   []int64
	}
	deployments := map[int64]*Model{}

	for i, addr := range addrs {
		if err := supervisor1.AddWorker(addr, int64(i)); err != nil {
			t.Fatalf("could not add worker: %v", err)
		}
	}

	user := "******"
	// Deploy 5 different models
	modelnames := []string{
		"hellor_0", "hellor_1", "hellor_2", "hellor_3", "hellor_4",
	}

	// Override the storage's NewDeployment function so we can track which
	// deployment is associated with which model version.
	// This is normally done by the database.
	nInstances := 2
	var nextID int64 = 1
	var nextDeployID int64 = 1
	testStorage.newDeployment = func(user, model string, version int) (deployID int64, instIDs []int64, err error) {
		deployID = nextDeployID
		nextDeployID++

		// Two slices holding the same IDs: presumably so that anything the
		// caller does to the returned slice cannot alter the recorded one.
		ids := make([]int64, nInstances)
		instIDs = make([]int64, nInstances)
		for i := range ids {
			ids[i] = nextID
			instIDs[i] = nextID
			nextID++
		}

		deployments[deployID] = &Model{User: user, Model: model, Version: version, Instances: ids}
		return deployID, instIDs, nil
	}
	for _, model := range modelnames {
		if err := supervisor1.Deploy(user, model, 1); err != nil {
			t.Fatalf("could not deploy: %v", err)
		}
	}

	nExp := len(modelnames)
	if n := len(deployments); n != nExp {
		t.Fatalf("expected %d deployments, got: %d", nExp, n)
	}
	for _, model := range deployments {
		if n := len(model.Instances); n != nInstances {
			t.Fatalf("expected %d instances per deployment, got %d", nInstances, n)
		}
	}

	// testPred sends 100 prediction requests for each model through super,
	// reporting the first failure with t.Errorf and skipping remaining models.
	testPred := func(super *Supervisor) {
		for _, model := range modelnames {
			// Each model gets its own front-end server, closed as soon as that
			// model's requests are done, rather than piling up deferred closes.
			ok := func(model string) bool {
				hf := func(w http.ResponseWriter, r *http.Request) {
					super.Predict(user, model, w, r)
				}
				s := httptest.NewServer(http.HandlerFunc(hf))
				defer s.Close()
				for i := 0; i < 100; i++ {
					body := bytes.NewReader([]byte(`{"name":"bigdatabob"}`))
					resp, err := http.Post(s.URL, "application/json", body)
					if err != nil {
						t.Errorf("request failed: %v", err)
						return false
					}
					respbody, err := ioutil.ReadAll(resp.Body)
					resp.Body.Close()
					if err != nil {
						t.Errorf("could not read body: %v", err)
						return false
					}
					if resp.StatusCode != http.StatusOK {
						t.Errorf("expected 200, got %s: %s", resp.Status, respbody)
						return false
					}
				}
				return true
			}(model)
			if !ok {
				return
			}
		}
	}
	testPred(supervisor1)

	supervisor1.shutdown()

	// Rebuild the worker list and deployment requests that supervisor2 is
	// started from.
	workers := make([]db.Worker, len(addrs))
	for i, addr := range addrs {
		// NOTE(review): unkeyed literal of an imported struct (go vet
		// "composites"); key the fields once their names are confirmed.
		workers[i] = db.Worker{int64(i), addr}
	}

	deploymentReqs := make([]db.DeploymentReq, 0, len(deployments))
	for deployID, model := range deployments {
		deploymentReqs = append(deploymentReqs, db.DeploymentReq{
			Username:         model.User,
			Modelname:        model.Model,
			Version:          model.Version,
			LastDeployId:     deployID,
			ValidInstanceIds: model.Instances,
		})
	}

	c2 := SupervisorConfig{
		Storage:       testStorage,
		TLSHandshaker: client,
		Workers:       workers,
		Deployments:   deploymentReqs,
		// Discard logs for tests
		Logger: log.New(ioutil.Discard, "", 0),
	}
	supervisor2, err := NewSupervisor(c2)
	if err != nil {
		t.Fatalf("failed to start second supervisor: %v", err)
	}
	// Shut supervisor2 down like supervisor1. Deferred last, so it runs before
	// the worker servers are closed.
	defer supervisor2.shutdown()
	// Discard logs for tests. NOTE(review): c2.Logger already asks for this;
	// confirm whether NewSupervisor applies it and drop one of the two.
	supervisor2.Logger = log.New(ioutil.Discard, "", 0)

	testPred(supervisor2)
}