// testSupervisor creates a Supervisor wired to a fresh test storage, attaches
// nWorkers MPS instances served over local TLS test servers, and then runs
// test against it.
//
// Setup failures are reported with t.Error/t.Errorf and cause an early return
// without invoking test. The temporary directory is removed, every MPS is
// Reset and every test server is Closed by deferred calls when this function
// returns.
func testSupervisor(t *testing.T, nWorkers int, test func(*Supervisor)) {
	certDir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Error(err)
		return
	}
	defer os.RemoveAll(certDir)

	tlsClient, err := tlsconfig.NewClient(certDir)
	if err != nil {
		t.Error(err)
		return
	}

	testStorage = NewTestStorage()
	// Tests stay quiet: the logger handed to the supervisor writes to
	// ioutil.Discard.
	sup, err := NewSupervisor(SupervisorConfig{
		Storage:       testStorage,
		TLSHandshaker: tlsClient,
		Logger:        log.New(ioutil.Discard, "", 0),
	})
	if err != nil {
		t.Errorf("failed to create supervisor: %v", err)
		return
	}
	// The Logger field is also overwritten with a discarding logger.
	sup.Logger = log.New(ioutil.Discard, "", 0)

	for workerID := 0; workerID < nWorkers; workerID++ {
		worker, err := mps.NewMPS()
		if err != nil {
			t.Error(err)
			return
		}
		defer worker.Reset()

		// The server's TLS config is set directly instead of being
		// obtained through a handshake.
		srv := httptest.NewUnstartedServer(worker)
		srv.TLS, err = tlsClient.ServerConfig(net.ParseIP("127.0.0.1"))
		if err != nil {
			t.Error(err)
			return
		}
		srv.StartTLS()
		defer srv.Close()

		// AddWorker is given the server address without the URL scheme.
		hostPort := strings.TrimPrefix(srv.URL, "https://")
		err = sup.AddWorker(hostPort, int64(workerID))
		if err != nil {
			t.Errorf("could not add worker: %v", err)
			return
		}
	}

	test(sup)
}
// TestSupervisorReboot deploys five models through a first Supervisor, sends
// prediction requests to each of them, shuts that Supervisor down, and then
// builds a second Supervisor from the recorded workers and deployments. The
// same prediction requests must succeed against the second Supervisor.
func TestSupervisorReboot(t *testing.T) {
	certDir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Error(err)
		return
	}
	defer os.RemoveAll(certDir)

	tlsClient, err := tlsconfig.NewClient(certDir)
	if err != nil {
		t.Error(err)
		return
	}

	testStorage = NewTestStorage()
	// Logs are discarded for tests.
	firstSup, err := NewSupervisor(SupervisorConfig{
		Storage:       testStorage,
		TLSHandshaker: tlsClient,
		Logger:        log.New(ioutil.Discard, "", 0),
	})
	if err != nil {
		t.Errorf("failed to create supervisor: %v", err)
		return
	}

	// Start five MPS-backed TLS servers and remember their addresses so the
	// same workers can be handed to both supervisors.
	var addrs [5]string
	for slot := range addrs {
		worker, err := mps.NewMPS()
		if err != nil {
			t.Error(err)
			return
		}
		defer worker.Reset()

		// The server's TLS config is set directly instead of being
		// obtained through a handshake.
		srv := httptest.NewUnstartedServer(worker)
		srv.TLS, err = tlsClient.ServerConfig(net.ParseIP("127.0.0.1"))
		if err != nil {
			t.Error(err)
			return
		}
		srv.StartTLS()
		defer srv.Close()

		addrs[slot] = strings.TrimPrefix(srv.URL, "https://")
	}

	// deployedModel records what the storage hook saw for one deployment.
	type deployedModel struct {
		User, Model string
		Version     int
		Instances   []int64
	}
	deployments := map[int64]*deployedModel{}

	for workerID, addr := range addrs {
		err := firstSup.AddWorker(addr, int64(workerID))
		if err != nil {
			t.Errorf("could not add worker: %v", err)
			return
		}
	}

	user := "******"
	// Five distinct models are deployed.
	modelnames := []string{
		"hellor_0", "hellor_1", "hellor_2", "hellor_3", "hellor_4",
	}

	// Replace the storage's newDeployment hook so the test can track which
	// deployment id and instance ids belong to which model version. In
	// production this bookkeeping is done by the database.
	nInstances := 2
	var nextInstanceID int64 = 1
	var nextDeployID int64 = 1
	testStorage.newDeployment = func(owner, name string, ver int) (int64, []int64, error) {
		deployID := nextDeployID
		nextDeployID++

		// Two independent slices: one kept by the test, one returned to
		// the caller.
		tracked := make([]int64, nInstances)
		for k := range tracked {
			tracked[k] = nextInstanceID
			nextInstanceID++
		}
		returned := make([]int64, nInstances)
		copy(returned, tracked)

		deployments[deployID] = &deployedModel{
			User:      owner,
			Model:     name,
			Version:   ver,
			Instances: tracked,
		}
		return deployID, returned, nil
	}

	for _, name := range modelnames {
		err := firstSup.Deploy(user, name, 1)
		if err != nil {
			t.Errorf("could not deploy: %v", err)
			return
		}
	}

	wantDeployments := len(modelnames)
	gotDeployments := len(deployments)
	if gotDeployments != wantDeployments {
		t.Errorf("expected %d deployments, got: %d", wantDeployments, gotDeployments)
		return
	}
	for _, rec := range deployments {
		got := len(rec.Instances)
		if got != nInstances {
			t.Errorf("expected %d deployments per model, got %d", nInstances, got)
			return
		}
	}

	// checkPredictions fronts sup.Predict with a plain HTTP test server for
	// each model and posts 100 requests to it, requiring a 200 response each
	// time. It stops at the first failure.
	const payload = `{"name":"bigdatabob"}`
	checkPredictions := func(sup *Supervisor) {
		for _, name := range modelnames {
			front := httptest.NewServer(http.HandlerFunc(
				func(w http.ResponseWriter, r *http.Request) {
					sup.Predict(user, name, w, r)
				},
			))
			defer front.Close()

			for attempt := 0; attempt < 100; attempt++ {
				resp, err := http.Post(front.URL, "application/json", bytes.NewReader([]byte(payload)))
				if err != nil {
					t.Errorf("request failed: %v", err)
					return
				}
				raw, err := ioutil.ReadAll(resp.Body)
				resp.Body.Close()
				if err != nil {
					t.Errorf("could not read body: %v", err)
					return
				}
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected 200, got %s: %s", resp.Status, raw)
					return
				}
			}
		}
	}
	checkPredictions(firstSup)

	firstSup.shutdown()

	// Rebuild the worker and deployment lists that are passed to the second
	// supervisor's config.
	workers := make([]db.Worker, 0, len(addrs))
	for workerID, addr := range addrs {
		workers = append(workers, db.Worker{int64(workerID), addr})
	}

	deploymentReqs := make([]db.DeploymentReq, 0, len(deployments))
	for deployID, rec := range deployments {
		deploymentReqs = append(deploymentReqs, db.DeploymentReq{
			Username:         rec.User,
			Modelname:        rec.Model,
			Version:          rec.Version,
			LastDeployId:     deployID,
			ValidInstanceIds: rec.Instances,
		})
	}

	secondSup, err := NewSupervisor(SupervisorConfig{
		Storage:       testStorage,
		TLSHandshaker: tlsClient,
		Workers:       workers,
		Deployments:   deploymentReqs,
		// Logs are discarded for tests.
		Logger: log.New(ioutil.Discard, "", 0),
	})
	if err != nil {
		t.Errorf("failed to start second supervisor: %v", err)
		return
	}
	// The Logger field is also overwritten with a discarding logger.
	secondSup.Logger = log.New(ioutil.Discard, "", 0)

	checkPredictions(secondSup)
}
// Example #3
// 0
// NewApp constructs the application described by config.
//
// It opens the access log (os.Stderr when config.AccessLog is empty), connects
// to the database, checks that the expected static sub-directories exist,
// loads the TLS client from config.CertDir, starts the supervisor with the
// workers and deployment requests read from the database, and wires up the
// public, authenticated, admin and prediction routes.
//
// An error is returned when config is nil, when config.ModelReplication is
// less than 1, or when any of the setup steps above fails. On failure the
// access log file opened by this call is closed again.
func NewApp(config *AppConfig) (*App, error) {
	if config == nil {
		return nil, fmt.Errorf("app config cannot be nil")
	}
	if config.ModelReplication < 1 {
		return nil, fmt.Errorf("model replication level cannot be less than 1")
	}

	// ready is set just before the successful return; until then the
	// deferred cleanup below releases the access log file so that a failed
	// construction does not leak its descriptor.
	ready := false

	var accessLog io.Writer = os.Stderr
	if config.AccessLog != "" {
		flags := os.O_WRONLY | os.O_APPEND | os.O_CREATE
		file, err := os.OpenFile(config.AccessLog, flags, 0644)
		if err != nil {
			return nil, err
		}
		defer func() {
			if !ready {
				file.Close()
			}
		}()
		log.Println("using access log:", config.AccessLog)
		accessLog = file
	}

	appDB, err := sqlutil.NewReconnectingDB("mysql", config.DBConnStr)
	if err != nil {
		return nil, fmt.Errorf("could not connect to database: %v", err)
	}

	// do the static directories exist?
	for _, dir := range []string{"css", "js", "fonts", "img", "model"} {
		d := filepath.Join(config.StaticDir, dir)
		_, err := os.Stat(d)
		if err != nil {
			return nil, fmt.Errorf("directory %s doesn't exist", d)
		}
	}

	// serveDir serves the named sub-directory of the static directory under
	// the URL prefix of the same name.
	serveDir := func(dirName string) http.Handler {
		d := filepath.Join(config.StaticDir, dirName)
		return http.StripPrefix("/"+dirName, http.FileServer(http.Dir(d)))
	}

	store := sessions.NewCookieStore([]byte(config.SecretSessionKey))

	client, err := tlsconfig.NewClient(config.CertDir)
	if err != nil {
		return nil, fmt.Errorf("certificate loading failed: %v", err)
	}

	app := App{
		store:            store,
		db:               appDB,
		staticDir:        config.StaticDir,
		bundleDir:        config.BundleDir,
		modelLogsDir:     config.ModelLogs,
		modelReplication: config.ModelReplication,
		isdev:            config.IsDev,
		serviceName:      config.ServiceName,
		sharingDisabled:  config.DisableSharing,
	}

	c := alb.SupervisorConfig{
		Storage:       &storage{new(sync.Mutex), &app},
		TLSHandshaker: client,
	}

	// Read the known workers and deployment requests inside one transaction.
	// The deferred Rollback ends the transaction on every path, including the
	// error returns, and the transaction is finished before the supervisor
	// is started.
	loadState := func() error {
		tx, err := app.db.Begin()
		if err != nil {
			return err
		}
		defer tx.Rollback()

		if c.Workers, err = db.Workers(tx); err != nil {
			return err
		}
		c.Deployments, err = db.DeploymentRequests(tx)
		return err
	}
	if err := loadState(); err != nil {
		return nil, err
	}

	app.sup, err = alb.NewSupervisor(c)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize communcation with MPSs: %v", err)
	}
	app.sup.Logger = log.New(os.Stderr, "supervisor: ", log.LstdFlags)

	// r handles the routes registered directly on it; every other request
	// falls through to its NotFoundHandler, which is set further down.
	r := mux.NewRouter()
	r.PathPrefix("/css/").Handler(serveDir("css"))
	r.PathPrefix("/js/").Handler(serveDir("js"))
	r.PathPrefix("/fonts/").Handler(serveDir("fonts"))
	r.PathPrefix("/img/").Handler(serveDir("img"))

	r.HandleFunc("/login", app.handleLogin)
	r.HandleFunc("/register", app.handleRegister)
	r.HandleFunc("/verify-password", app.handleVerifyPassword)
	r.HandleFunc("/logout", app.handleLogout)

	r.Handle("/favicon.ico", app.serveFile("favicon.ico"))

	// deployment routes
	r.PathPrefix("/deployer").HandlerFunc(app.handleDeployment)
	r.HandleFunc("/verify", app.handleOldVerify)

	authedRouter := mux.NewRouter()
	authedRouter.NotFoundHandler = app.serveFile("404.html")

	authedRouter.PathPrefix("/model/").Handler(serveDir("model"))
	authedRouter.Handle("/", app.serveFile("index.html"))
	authedRouter.Handle("/account", app.serveFile("account.html"))
	authedRouter.HandleFunc("/user.json", app.handleUser)
	authedRouter.HandleFunc("/users/{name}", app.handleUserByName)

	// apikey stuff
	authedRouter.HandleFunc("/apikey", app.handleApikey)

	// model specific pages
	authedRouter.Handle("/models/{name}", app.serveFile("model/index.html"))
	authedRouter.HandleFunc("/models/{name}/json", app.handleModel)
	authedRouter.Handle("/models/{name}/scoring", app.serveFile("model/scoring.html"))
	authedRouter.Handle("/models/{name}/versions", app.serveFile("model/versions.html"))
	authedRouter.HandleFunc("/models/{name}/versions.json", app.handleModelVersions)
	authedRouter.Handle("/models/{name}/logs", app.serveFile("model/logs.html"))
	authedRouter.HandleFunc("/models/{name}/logs/json", app.handleModelLogs)
	authedRouter.Handle("/models/{name}/settings", app.serveFile("model/settings.html"))
	authedRouter.Handle("/models/{name}/form-builder", app.serveFile("model/form-builder.html"))
	authedRouter.HandleFunc("/models/{name}/redeploy/{version}", app.handleModelRedeploy)
	authedRouter.HandleFunc("/models/{name}/shared", app.handleModelSharedUsers)
	authedRouter.HandleFunc("/models/{name}/startshare/{user}", app.handleModelStartSharing)
	authedRouter.HandleFunc("/models/{name}/stopshare/{user}", app.handleModelStopSharing)
	authedRouter.HandleFunc("/models/{name}/example", app.handleModelExample)

	// model actions
	authedRouter.HandleFunc("/models/{name}/action/{action}", app.handleModelStateChange)

	authedRouter.HandleFunc("/models", app.handleUserModels)
	authedRouter.HandleFunc("/shared", app.handleSharedModels)
	authedRouter.HandleFunc("/whoami", app.handleWhoami)

	// Everything under /admin goes through app.restrictAdmin before reaching
	// the admin routes.
	adminRouter := mux.NewRouter()
	authedRouter.PathPrefix("/admin").Handler(app.restrictAdmin(adminRouter))

	adminRouter.Handle("/admin", app.serveFile("admin/index.html"))
	adminRouter.Handle("/admin/users", app.serveFile("admin/users.html"))
	adminRouter.HandleFunc("/admin/users.json", app.handleUsersData)
	adminRouter.HandleFunc("/admin/users/create", app.handleUsersCreate)
	adminRouter.HandleFunc("/admin/users/delete", app.handleUserDelete)
	adminRouter.HandleFunc("/admin/users/makeadmin", app.handleUserMakeAdmin)
	adminRouter.HandleFunc("/admin/users/unmakeadmin", app.handleUserUnmakeAdmin)
	adminRouter.HandleFunc("/admin/users/setpass", app.handleUserSetPass)

	adminRouter.Handle("/admin/models", app.serveFile("admin/models.html"))
	adminRouter.HandleFunc("/admin/models.json", app.handleAllUserModels)

	adminRouter.HandleFunc("/admin/servers", app.handleServers)
	adminRouter.HandleFunc("/admin/servers.json", app.handleServersData)
	adminRouter.HandleFunc("/admin/servers/remove", app.handleServersRemove)

	// get initial authentication for models
	auth, err := func() (*db.PredictionAuth, error) {
		tx, err := app.db.Begin()
		if err != nil {
			return nil, err
		}
		defer tx.Rollback()
		return db.GetAuth(tx)
	}()
	if err != nil {
		return nil, err
	}

	// must enforce no caching for jsx pages
	noCaching := func(handler http.Handler) http.Handler {
		hf := func(w http.ResponseWriter, r *http.Request) {
			// see: http://goo.gl/itaIDo
			w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
			w.Header().Set("Pragma", "no-cache")
			w.Header().Set("Expires", "0")

			handler.ServeHTTP(w, r)
		}
		return http.HandlerFunc(hf)
	}

	// Requests that match none of r's own routes are access-logged, marked
	// non-cacheable and passed through app.restrict to the authenticated
	// router.
	h := handlers.LoggingHandler(accessLog, noCaching(app.restrict(authedRouter)))
	r.NotFoundHandler = h

	predRouter := alb.NewPredictionRouter(r, app.sup, auth)

	// TODO: add api routes

	app.router = predRouter

	ready = true
	return &app, nil
}