Ejemplo n.º 1
0
// main is the entry point of the schema tool. It connects to the database
// backend named by --type at the DSN given by --database, deploys the SHIELD
// schema into it, and logs the schema version (db.CurrentSchema) on success.
func main() {
	log.Infof("starting schema...")

	// Both options are obligatory, so there are no defaults to fill in.
	options := struct {
		Driver string `goptions:"-t,--type, obligatory, description='Type of database backend'"`
		DSN    string `goptions:"-d,--database, obligatory, description='DSN of the database backend'"`
	}{}
	// NOTE(review): ParseAndFail is expected to print usage and terminate the
	// process on invalid input — confirm against the goptions documentation.
	goptions.ParseAndFail(&options)

	database := &db.DB{
		Driver: options.Driver,
		DSN:    options.DSN,
	}

	log.Debugf("connecting to %s database at %s", database.Driver, database.DSN)
	if err := database.Connect(); err != nil {
		log.Errorf("failed to connect to %s database at %s: %s",
			database.Driver, database.DSN, err)
		// Without a connection Setup cannot succeed; stop here instead of
		// attempting to deploy the schema anyway.
		return
	}

	if err := database.Setup(); err != nil {
		log.Errorf("failed to set up schema in %s database at %s: %s",
			database.Driver, database.DSN, err)
		return
	}

	log.Infof("deployed schema version %d", db.CurrentSchema)
}
Ejemplo n.º 2
0
// PurgeArchives runs one archive clean-up pass in two phases:
//
//  1. every archive returned by GetExpiredArchives is marked expired via
//     ExpireArchive;
//  2. every archive returned by GetArchivesNeedingPurge gets a purge task
//     (owned by "system", executed through s.PurgeAgent) which is handed to
//     ScheduleTask.
//
// Errors are logged and skipped so that one failing archive (or one failing
// query) does not prevent the rest of the pass from running.
func (s *Supervisor) PurgeArchives() {
	log.Debugf("scanning for archives needing to be expired")

	// mark archives past their retention policy as expired
	toExpire, err := s.Database.GetExpiredArchives()
	if err != nil {
		log.Errorf("error retrieving archives needing to be expired: %s", err)
	}
	for _, archive := range toExpire {
		log.Infof("archive %s has expiration date %s, marking as expired", archive.UUID, archive.ExpiresAt)
		if err := s.Database.ExpireArchive(archive.UUID); err != nil {
			log.Errorf("error marking archive %s as expired: %s", archive.UUID, err)
		}
	}

	// get archives that are not valid or purged
	toPurge, err := s.Database.GetArchivesNeedingPurge()
	if err != nil {
		log.Errorf("error retrieving archives to purge: %s", err)
	}

	for _, archive := range toPurge {
		log.Infof("requesting purge of archive %s due to status '%s'", archive.UUID, archive.Status)
		task, err := s.Database.CreatePurgeTask("system", archive, s.PurgeAgent)
		if err != nil {
			log.Errorf("error scheduling purge of archive %s: %s", archive.UUID, err)
			continue
		}
		s.ScheduleTask(task)
	}
}
Ejemplo n.º 3
0
// main is the entry point of shield-schema. It deploys the SHIELD database
// schema into the backend named by --type at the DSN given by --database.
//
// Exit behaviour: --help and --version print and exit 0; a command-line
// parse error prints the error and usage and returns; a missing --type or
// --database, a connection failure, or a schema setup failure exit 1.
func main() {
	log.Infof("starting schema...")

	options := struct {
		Help    bool   `goptions:"-h, --help, description='Show the help screen'"`
		Driver  string `goptions:"-t, --type, description='Type of database backend (postgres or mysql)'"`
		DSN     string `goptions:"-d,--database, description='DSN of the database backend'"`
		Version bool   `goptions:"-v, --version, description='Display the SHIELD version'"`
	}{} // no defaults
	if err := goptions.Parse(&options); err != nil {
		fmt.Printf("%s\n", err)
		goptions.PrintHelp()
		return
	}
	if options.Help {
		goptions.PrintHelp()
		os.Exit(0)
	}
	if options.Version {
		// An empty Version denotes a development build.
		if Version == "" {
			fmt.Printf("shield-schema (development)\n")
		} else {
			fmt.Printf("shield-schema v%s\n", Version)
		}
		os.Exit(0)
	}
	if options.Driver == "" {
		fmt.Fprintf(os.Stderr, "You must indicate which type of database you wish to manage, via the `--type` option.\n")
		os.Exit(1)
	}
	if options.DSN == "" {
		fmt.Fprintf(os.Stderr, "You must specify the DSN of your database, via the `--database` option.\n")
		os.Exit(1)
	}

	database := &db.DB{
		Driver: options.Driver,
		DSN:    options.DSN,
	}

	log.Debugf("connecting to %s database at %s", database.Driver, database.DSN)
	if err := database.Connect(); err != nil {
		log.Errorf("failed to connect to %s database at %s: %s",
			database.Driver, database.DSN, err)
		// Do not fall through into Setup without a connection, and report the
		// failure through the exit status like the validation errors above.
		os.Exit(1)
	}

	if err := database.Setup(); err != nil {
		log.Errorf("failed to set up schema in %s database at %s: %s",
			database.Driver, database.DSN, err)
		os.Exit(1)
	}

	log.Infof("deployed schema version %d", db.CurrentSchema)
}
Ejemplo n.º 4
0
// main is the entry point of shield-agent. It parses ShieldAgentOpts (log
// level defaults to "Info"), handles --help / --version, requires a
// configuration file via --config, sets up console logging, loads the agent
// configuration and runs the agent.
//
// A missing --config or a configuration that fails to load exits with
// status 1; --help and --version exit 0.
func main() {
	var opts ShieldAgentOpts
	opts.Log = "Info"
	if err := goptions.Parse(&opts); err != nil {
		fmt.Printf("%s\n", err)
		goptions.PrintHelp()
		return
	}
	if opts.Help {
		goptions.PrintHelp()
		os.Exit(0)
	}
	if opts.Version {
		// An empty Version denotes a development build.
		if Version == "" {
			fmt.Printf("shield-agent (development)\n")
		} else {
			fmt.Printf("shield-agent v%s\n", Version)
		}
		os.Exit(0)
	}
	if opts.ConfigFile == "" {
		fmt.Fprintf(os.Stderr, "You must specify a configuration file via `--config`\n")
		os.Exit(1)
	}

	log.SetupLogging(log.LogConfig{Type: "console", Level: opts.Log})
	log.Infof("starting agent")

	ag := agent.NewAgent()
	if err := ag.ReadConfig(opts.ConfigFile); err != nil {
		log.Errorf("configuration failed: %s", err)
		// Report the failure to the caller / service manager.
		os.Exit(1)
	}
	ag.Run()
}
Ejemplo n.º 5
0
// RequireAuth handles a request that arrived without valid credentials.
//
// It records the requested path as a flash in the gothic session, so the
// OAuth callback can redirect back to it. It then does one of two things:
//   - for paths ShouldOAuthRedirect accepts, it starts the OAuth flow via
//     gothic.BeginAuthHandler;
//   - for anything else, it answers 401 with "WWW-Authenticate: Bearer" so
//     API clients know OAuth is required.
//
// A session cookie that securecookie cannot decode is discarded and a new
// session is created; any other session-store error produces a 500.
func (oa OAuthenticator) RequireAuth(w http.ResponseWriter, r *http.Request) {
	sess, err := gothic.Store.Get(r, gothic.SessionName)
	if err != nil {
		if _, ok := err.(securecookie.Error); !ok {
			log.Errorf("%s", err)
			w.WriteHeader(500)
			w.Write([]byte("Unexpected error retrieving session data"))
			return
		}
		// The existing cookie could not be decoded (presumably it was issued
		// with different keys); drop it and start over with a fresh session.
		r.Header.Set("Cookie", "")
		sess, err = gothic.Store.New(r, gothic.SessionName)
		if err != nil {
			w.WriteHeader(500)
			w.Write([]byte("Failure generating a new session: " + err.Error()))
			return
		}
	}

	sess.AddFlash(r.URL.Path)
	// Best effort: a failed save is logged rather than aborting the request,
	// but it is no longer silently ignored.
	if err := sess.Save(r, w); err != nil {
		log.Errorf("failed to save session for %s: %s", r.URL.Path, err)
	}

	// only start oauth redirection if we're hitting the auth APIs, or web UI
	if ShouldOAuthRedirect(r.URL.Path) {
		// Log only the path: the full request carries session cookies.
		log.Debugf("Starting OAuth Process for request: %s", r.URL.Path)
		gothic.BeginAuthHandler(w, r)
		return
	}

	// otherwise set auth header for api clients to understand oauth is needed
	log.Debugf("Unauthenticated API Request received, OAuth required, sending 401")
	w.Header().Set("WWW-Authenticate", "Bearer")
	w.WriteHeader(401)
	w.Write([]byte("Unauthorized"))
}
Ejemplo n.º 6
0
// main is the entry point of shieldd, the SHIELD daemon. It parses
// ShieldOpts (log level defaults to "Info"), handles --help / --version,
// requires a configuration file, sets up console logging, loads the
// supervisor configuration, starts the API and the workers, and then blocks
// in the supervisor's main loop.
//
// A missing config, a config that fails to load, or a supervisor failure
// exit with status 1 so that callers and service managers see the failure.
func main() {
	supervisor.Version = Version
	var opts ShielddOpts
	opts.Log = "Info"
	if err := goptions.Parse(&opts); err != nil {
		fmt.Printf("%s\n", err)
		goptions.PrintHelp()
		return
	}

	if opts.Help {
		goptions.PrintHelp()
		os.Exit(0)
	}
	if opts.Version {
		if Version == "" {
			fmt.Printf("shieldd (development)\n")
		} else {
			fmt.Printf("shieldd v%s\n", Version)
		}
		os.Exit(0)
	}

	if opts.ConfigFile == "" {
		fmt.Fprintf(os.Stderr, "No config specified. Please try again using the -c/--config argument\n")
		os.Exit(1)
	}

	log.SetupLogging(log.LogConfig{Type: "console", Level: opts.Log})
	log.Infof("starting shield daemon")

	s := supervisor.NewSupervisor()
	if err := s.ReadConfig(opts.ConfigFile); err != nil {
		log.Errorf("Failed to load config: %s", err)
		os.Exit(1)
	}

	s.SpawnAPI()
	s.SpawnWorkers()

	if err := s.Run(); err != nil {
		log.Errorf("shield daemon failed: %s", err)
		log.Infof("stopping daemon")
		os.Exit(1)
	}
	log.Infof("stopping daemon")
}
Ejemplo n.º 7
0
// ServeOne accepts a single connection from l, performs the SSH server
// handshake with agent.config, and passes the resulting connection to
// handleConn. handleConn runs on its own goroutine when async is true and
// inline otherwise. Accept and handshake failures are logged and the
// connection is abandoned.
func (agent *Agent) ServeOne(l net.Listener, async bool) {
	netConn, acceptErr := l.Accept()
	if acceptErr != nil {
		log.Errorf("failed to accept: %s\n", acceptErr)
		return
	}

	sshConn, newChans, globalReqs, handshakeErr := ssh.NewServerConn(netConn, agent.config)
	if handshakeErr != nil {
		log.Errorf("handshake failed: %s\n", handshakeErr)
		return
	}

	if !async {
		agent.handleConn(sshConn, newChans, globalReqs)
		return
	}
	go agent.handleConn(sshConn, newChans, globalReqs)
}
Ejemplo n.º 8
0
// Start configures the web server via Setup and then serves HTTP on ws.Addr
// using the default ServeMux. It never returns normally: a Setup failure
// panics, and because ListenAndServe only returns with an error, a serving
// failure is logged and then panics as well.
func (ws *WebServer) Start() {
	if setupErr := ws.Setup(); setupErr != nil {
		panic("Could not set up WebServer for SHIELD: " + setupErr.Error())
	}

	log.Debugf("Starting WebServer on '%s'...", ws.Addr)
	if serveErr := http.ListenAndServe(ws.Addr, nil); serveErr != nil {
		log.Errorf("HTTP API failed %s", serveErr.Error())
		panic("Cannot setup WebServer, aborting. Check logs for details.")
	}
}
Ejemplo n.º 9
0
// JSON marshals thing and writes it to w as a 200 response with
// Content-Type application/json. If thing cannot be marshalled, the error is
// logged and a bare 500 (no body) is sent instead.
func JSON(w http.ResponseWriter, thing interface{}) {
	payload, marshalErr := json.Marshal(thing)
	if marshalErr == nil {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
		return
	}

	log.Errorf("Cannot marshal JSON: <%s>\n", marshalErr)
	w.WriteHeader(http.StatusInternalServerError)
}
Ejemplo n.º 10
0
// CheckSchedule walks the supervisor's job queue. For every job whose
// Runnable reports true it:
//   - creates a backup task owned by "system";
//   - hands the task to ScheduleTask;
//   - asks the job to compute its next run time.
//
// If task creation fails, the job is skipped without being rescheduled.
func (s *Supervisor) CheckSchedule() {
	for _, j := range s.jobq {
		if !j.Runnable() {
			continue
		}

		log.Infof("scheduling execution of job %s [%s]", j.Name, j.UUID)
		t, createErr := s.Database.CreateBackupTask("system", j)
		if createErr != nil {
			log.Errorf("job -> task conversion / database update failed: %s", createErr)
			continue
		}
		s.ScheduleTask(t)

		if rescheduleErr := j.Reschedule(); rescheduleErr != nil {
			log.Errorf("error encountered while determining next run of %s (%s): %s",
				j.UUID, j.Spec, rescheduleErr)
			continue
		}
		log.Infof("next run of %s [%s] which runs %s is at %s",
			j.Name, j.UUID, j.Spec, j.NextRun)
	}
}
Ejemplo n.º 11
0
// main starts a SHIELD agent. It parses the command line into
// ShieldAgentOpts (log level defaults to "Info") and configures console
// logging. It then loads the agent configuration from opts.ConfigFile and
// runs the agent. Parse and configuration errors are reported and main
// returns.
func main() {
	var opts ShieldAgentOpts
	opts.Log = "Info"

	if parseErr := goptions.Parse(&opts); parseErr != nil {
		fmt.Printf("%s\n", parseErr)
		goptions.PrintHelp()
		return
	}

	log.SetupLogging(log.LogConfig{Type: "console", Level: opts.Log})
	log.Infof("starting agent")

	a := agent.NewAgent()
	if cfgErr := a.ReadConfig(opts.ConfigFile); cfgErr != nil {
		log.Errorf("configuration failed: %s", cfgErr)
		return
	}
	a.Run()
}
Ejemplo n.º 12
0
// Resync reloads every job from the database, computes each job's next run
// time, and then replaces s.jobq with the freshly loaded list.
//
// A database error aborts the resync and is returned; s.jobq is left as it
// was. A per-job scheduling error is only logged, and that job is still
// included in the new queue.
func (s *Supervisor) Resync() error {
	jobs, err := s.Database.GetAllJobs(nil)
	if err != nil {
		return err
	}

	// calculate the initial run of each job
	for _, j := range jobs {
		if rescheduleErr := j.Reschedule(); rescheduleErr != nil {
			log.Errorf("error encountered while determining next run of %s [%s] which runs %s: %s",
				j.Name, j.UUID, j.Spec, rescheduleErr)
			continue
		}
		log.Infof("initial run of %s [%s] which runs %s is at %s",
			j.Name, j.UUID, j.Spec, j.NextRun)
	}

	s.jobq = jobs
	return nil
}
Ejemplo n.º 13
0
// FailUnfinishedTasks is a startup sweep. Every task the database still
// reports as running (db.RunningStatus) is marked failed with the current
// time. For such tasks that were backups with an associated archive, the
// archive is looked up and a purge task for it is created and scheduled.
//
// An error listing or failing tasks aborts the sweep and is returned. Errors
// while looking up or purging an archive are only logged.
func (s *Supervisor) FailUnfinishedTasks() error {
	tasks, err := s.Database.GetAllTasks(
		&db.TaskFilter{
			ForStatus: db.RunningStatus,
		},
	)
	if err != nil {
		return fmt.Errorf("Failed to sweep database of running tasks: %s", err)
	}

	now := time.Now()
	for _, task := range tasks {
		log.Warnf("Found task %s in 'running' state at startup; setting to 'failed'", task.UUID)
		if err := s.Database.FailTask(task.UUID, now); err != nil {
			return fmt.Errorf("Failed to sweep database of running tasks [%s]: %s", task.UUID, err)
		}
		if task.Op != db.BackupOperation || task.ArchiveUUID == nil {
			continue
		}

		archive, err := s.Database.GetArchive(task.ArchiveUUID)
		if err != nil {
			log.Warnf("Unable to retrieve archive %s (for task %s) from the database: %s",
				task.ArchiveUUID, task.UUID, err)
			continue
		}
		log.Warnf("Found archive %s for task %s, purging", archive.UUID, task.UUID)

		// Use a distinct name for the purge task. The original code shadowed
		// the loop's `task` here, so the error branch below dereferenced the
		// (failed) CreatePurgeTask result instead of the running task.
		purge, err := s.Database.CreatePurgeTask("", archive, s.PurgeAgent)
		if err != nil {
			log.Errorf("Failed to purge archive %s (for task %s, which was running at boot): %s",
				archive.UUID, task.UUID, err)
			continue
		}
		s.ScheduleTask(purge)
	}

	return nil
}
Ejemplo n.º 14
0
// Run is the supervisor's main loop.
//
// Startup:
//  1. connect to the database;
//  2. verify the schema version;
//  3. load the job queue (Resync);
//  4. fail tasks left "running" by a previous process (FailUnfinishedTasks);
//  5. reschedule pending tasks.
//
// Any startup failure is returned. After that, Run loops forever, reacting
// to:
//   - s.resync:  reload the job queue from the database;
//   - s.purge.C: expire and purge archives (PurgeArchives);
//   - s.tick.C:  schedule runnable jobs, cancel tasks past their TimeoutAt,
//     and hand queued tasks to workers without blocking;
//   - s.adhoc:   schedule ad hoc requests (ScheduleAdhoc);
//   - s.updates: record worker output and results in the database.
func (s *Supervisor) Run() error {
	// Error strings carry no trailing newline; callers format them when logging.
	if err := s.Database.Connect(); err != nil {
		return fmt.Errorf("failed to connect to %s database at %s: %s",
			s.Database.Driver, s.Database.DSN, err)
	}

	if err := s.Database.CheckCurrentSchema(); err != nil {
		return fmt.Errorf("database failed schema version check: %s", err)
	}

	if err := s.Resync(); err != nil {
		return err
	}
	if err := s.FailUnfinishedTasks(); err != nil {
		return err
	}
	if err := s.ReschedulePendingTasks(); err != nil {
		return err
	}

	for {
		select {
		case <-s.resync:
			if err := s.Resync(); err != nil {
				log.Errorf("resync error: %s", err)
			}

		case <-s.purge.C:
			s.PurgeArchives()

		case <-s.tick.C:
			s.CheckSchedule()

			// see if any tasks have been running past the timeout period
			if len(s.runq) > 0 {
				ok := true // stays true while no task has been removed
				lst := make([]*db.Task, 0)
				now := timestamp.Now()

				for _, runtask := range s.runq {
					if now.After(runtask.TimeoutAt) {
						// NOTE(review): the CancelTask result is not checked
						// here; its signature is not visible from this block.
						s.Database.CancelTask(runtask.UUID, now.Time())
						log.Errorf("shield timed out task '%s' after running for %v", runtask.UUID, s.Timeout)
						ok = false

					} else {
						lst = append(lst, runtask)
					}
				}

				// only replace the run queue when something was cancelled
				if !ok {
					s.runq = lst
				}
			}

			// see if we have anything in the schedule queue
		SchedQueue:
			for len(s.schedq) > 0 {
				// Non-blocking send: stop dispatching as soon as no worker is
				// ready to receive, so the loop never stalls on s.workers.
				select {
				case s.workers <- s.schedq[0]:
					s.Database.StartTask(s.schedq[0].UUID, time.Now())
					s.schedq[0].Attempts++
					log.Infof("sent a task to a worker")
					s.runq = append(s.runq, s.schedq[0])
					log.Debugf("added task to the runq")
					s.schedq = s.schedq[1:]
				default:
					break SchedQueue
				}
			}

		case adhoc := <-s.adhoc:
			s.ScheduleAdhoc(adhoc)

		case u := <-s.updates:
			switch u.Op {
			case STOPPED:
				log.Infof("  %s: job stopped at %s", u.Task, u.StoppedAt)
				s.RemoveTaskFromRunq(u.Task)
				if err := s.Database.CompleteTask(u.Task, u.StoppedAt); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				}

			case FAILED:
				log.Warnf("  %s: task failed!", u.Task)
				s.RemoveTaskFromRunq(u.Task)
				if err := s.Database.FailTask(u.Task, u.StoppedAt); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				}

			case OUTPUT:
				log.Infof("  %s> %s", u.Task, strings.Trim(u.Output, "\n"))
				if err := s.Database.UpdateTaskLog(u.Task, u.Output); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				}

			case RESTORE_KEY:
				log.Infof("  %s: restore key is %s", u.Task, u.Output)
				if id, err := s.Database.CreateTaskArchive(u.Task, u.Output, time.Now()); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				} else {
					// an archive produced by a failed task is recorded but invalidated
					if !u.TaskSuccess {
						s.Database.InvalidateArchive(id)
					}
				}

			case PURGE_ARCHIVE:
				log.Infof("  %s: archive %s purged from storage", u.Task, u.Archive)
				if err := s.Database.PurgeArchive(u.Archive); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				}

			default:
				log.Errorf("  %s: !! unrecognized op type", u.Task)
			}
		}
	}
}
Ejemplo n.º 15
0
// ScheduleAdhoc turns an ad hoc request into a scheduled task.
//
//   - db.BackupOperation: creates a backup task, owned by a.Owner, for the
//     job in s.jobq whose UUID matches a.JobUUID.
//   - db.RestoreOperation: creates a restore task, owned by a.Owner, from
//     archive a.ArchiveUUID onto target a.TargetUUID.
//
// When a.TaskUUIDChan is non-nil, the requester is sent one db.TaskInfo: the
// new task's UUID on success, or Err=true with a message on failure. Every
// failure path now replies; previously a missing archive or target, an
// unknown job, or an unsupported op returned without replying. That could
// leave a requester waiting on the channel forever.
func (s *Supervisor) ScheduleAdhoc(a *db.Task) {
	log.Infof("schedule adhoc %s job", a.Op)

	// reply delivers a result to the requester, if one asked for it.
	// NOTE(review): the send blocks until received, exactly as the original
	// sends did, so the requester must read one reply per request.
	reply := func(failed bool, info string) {
		if a.TaskUUIDChan != nil {
			a.TaskUUIDChan <- &db.TaskInfo{
				Err:  failed,
				Info: info,
			}
		}
	}

	switch a.Op {
	case db.BackupOperation:
		// expect a JobUUID to move to the schedq Immediately
		found := false
		for _, job := range s.jobq {
			if !uuid.Equal(job.UUID, a.JobUUID) {
				continue
			}
			found = true

			log.Infof("scheduling immediate (ad hoc) execution of job %s [%s]", job.Name, job.UUID)
			task, err := s.Database.CreateBackupTask(a.Owner, job)
			if err != nil {
				log.Errorf("job -> task conversion / database update failed: %s", err)
				reply(true, err.Error())
				continue
			}
			reply(false, task.UUID.String())

			s.ScheduleTask(task)
		}
		if !found {
			log.Errorf("unable to find job %s for ad hoc backup task", a.JobUUID)
			reply(true, "unable to find job for ad hoc backup task")
		}

	case db.RestoreOperation:
		archive, err := s.Database.GetArchive(a.ArchiveUUID)
		if err != nil {
			log.Errorf("unable to find archive %s for restore task: %s", a.ArchiveUUID, err)
			reply(true, err.Error())
			return
		}
		target, err := s.Database.GetTarget(a.TargetUUID)
		if err != nil {
			log.Errorf("unable to find target %s for restore task: %s", a.TargetUUID, err)
			reply(true, err.Error())
			return
		}
		task, err := s.Database.CreateRestoreTask(a.Owner, archive, target)
		if err != nil {
			log.Errorf("restore task database creation failed: %s", err)
			reply(true, err.Error())
			return
		}
		reply(false, task.UUID.String())

		s.ScheduleTask(task)

	default:
		log.Errorf("unsupported ad hoc operation %s", a.Op)
		reply(true, "unsupported ad hoc operation")
	}
}
Ejemplo n.º 16
0
// Setup prepares the web server for serving. In order, it:
//  1. connects to ws.Database;
//  2. when an OAuth provider is configured, builds the OAuth session store
//     and registers the provider with goth/gothic;
//  3. builds the protected API routes;
//  4. selects the OAuth or Basic user authenticator;
//  5. mounts everything on the default ServeMux at "/".
//
// Session-store types: "sqlite3", "postgres", "mock".
// Providers: "github", "cloudfoundry", "faux".
//
// An unrecognised session store or provider is reported as an error.
// Previously the code returned the outer `err` in those branches, which was
// always nil there, so a misconfiguration looked like a successful setup.
//
// NOTE(review): session auth/encryption keys are freshly generated on every
// call, so existing session cookies do not survive a restart.
func (ws *WebServer) Setup() error {
	log.Debugf("Configuring WebServer...")
	if err := ws.Database.Connect(); err != nil {
		log.Errorf("Failed to connect to %s database at %s: %s", ws.Database.Driver, ws.Database.DSN, err)
		return err
	}

	if ws.Auth.OAuth.Provider != "" {
		log.Debugf("Configuring OAuth Session store")
		maxSessionAge := ws.Auth.OAuth.Sessions.MaxAge
		authKey := securecookie.GenerateRandomKey(64)
		encKey := securecookie.GenerateRandomKey(32)
		switch ws.Auth.OAuth.Sessions.Type {
		case "sqlite3":
			log.Debugf("Using sqlite3 as a session store")
			store, err := sqlitestore.NewSqliteStore(ws.Auth.OAuth.Sessions.DSN, "http_sessions", "/", maxSessionAge, authKey, encKey)
			if err != nil {
				log.Errorf("Error setting up sessions database: %s", err)
				return err
			}
			gothic.Store = store
		case "postgres":
			log.Debugf("Using postgres as a session store")
			gothic.Store = pgstore.NewPGStore(ws.Auth.OAuth.Sessions.DSN, authKey, encKey)
			gothic.Store.(*pgstore.PGStore).Options.MaxAge = maxSessionAge
		case "mock":
			log.Debugf("Using mocked session store")
			// does nothing, to avoid being accidentally used in prod
		default:
			log.Errorf("Invalid DB Backend for OAuth sessions database")
			return fmt.Errorf("invalid DB backend %q for OAuth sessions database", ws.Auth.OAuth.Sessions.Type)
		}

		gob.Register(map[string]interface{}{})
		switch ws.Auth.OAuth.Provider {
		case "github":
			log.Debugf("Using github as the oauth provider")
			goth.UseProviders(github.New(ws.Auth.OAuth.Key, ws.Auth.OAuth.Secret, fmt.Sprintf("%s/v1/auth/github/callback", ws.Auth.OAuth.BaseURL), "read:org"))
			OAuthVerifier = &GithubVerifier{Orgs: ws.Auth.OAuth.Authorization.Orgs}
		case "cloudfoundry":
			log.Debugf("Using cloudfoundry as the oauth provider")
			goth.UseProviders(cloudfoundry.New(ws.Auth.OAuth.ProviderURL, ws.Auth.OAuth.Key, ws.Auth.OAuth.Secret, fmt.Sprintf("%s/v1/auth/cloudfoundry/callback", ws.Auth.OAuth.BaseURL), "openid,scim.read"))
			OAuthVerifier = &UAAVerifier{Groups: ws.Auth.OAuth.Authorization.Orgs, UAA: ws.Auth.OAuth.ProviderURL}
			p, err := goth.GetProvider("cloudfoundry")
			if err != nil {
				return err
			}
			p.(*cloudfoundry.Provider).Client = ws.Auth.OAuth.Client
		case "faux":
			log.Debugf("Using mocked OAuth provider")
			// does nothing, to avoid being accidentally used in prod
		default:
			log.Errorf("Invalid OAuth provider specified.")
			return fmt.Errorf("invalid OAuth provider %q", ws.Auth.OAuth.Provider)
		}

		gothic.GetProviderName = func(req *http.Request) (string, error) {
			return ws.Auth.OAuth.Provider, nil
		}

		// SetState stores a fresh random state value in the session so that
		// OAuthCallback can compare it with the state the provider echoes back.
		gothic.SetState = func(req *http.Request) string {
			sess, _ := gothic.Store.Get(req, gothic.SessionName)
			sess.Values["state"] = uuid.New()
			return sess.Values["state"].(string)
		}
	}

	protectedAPIs, err := ws.ProtectedAPIs()
	if err != nil {
		// Pass the error as an argument: it is not a format string.
		log.Errorf("Could not set up HTTP routes: %s", err)
		return err
	}

	if ws.Auth.OAuth.Provider != "" {
		log.Debugf("Enabling OAuth handlers for HTTP requests")
		UserAuthenticator = OAuthenticator{
			Cfg: ws.Auth.OAuth,
		}
	} else {
		log.Debugf("Enabling Basic Auth handlers for HTTP requests")
		UserAuthenticator = BasicAuthenticator{
			Cfg: ws.Auth.Basic,
		}
	}

	http.Handle("/", ws.UnauthenticatedResources(Authenticate(ws.Auth.Tokens, protectedAPIs)))
	return nil
}
Ejemplo n.º 17
0
// OAuthCallback returns the handler for the OAuth provider's redirect back
// to SHIELD. The handler runs these steps in order:
//  1. check that the session's stored "state" matches the callback's state
//     (403 otherwise);
//  2. require a "code" query parameter (403 otherwise);
//  3. complete authentication via gothic;
//  4. fetch the user's membership through OAuthVerifier using an OAuth2
//     client built on oa.Cfg.Client;
//  5. confirm the user is authorized (403 otherwise).
//
// On success it stores the user's nickname and membership in the session and
// redirects (302) to the first flashed path. API paths are excluded, unless
// they are CLI auth calls, to avoid auth redirection loops; otherwise it
// redirects to "/".
//
// Only the request path is logged. The full request would include session
// cookies and the OAuth authorization code.
func (oa OAuthenticator) OAuthCallback() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Debugf("Incoming Auth request: %s", r.URL.Path)
		sess, err := gothic.Store.Get(r, gothic.SessionName)
		if err != nil {
			log.Errorf("Error retrieving session info: %s", err)
			w.WriteHeader(500)
			return
		}
		log.Debugf("Processing oauth callback for '%s'", sess.ID)
		if gothic.GetState(r) != sess.Values["state"] {
			w.WriteHeader(403)
			w.Write([]byte("Unauthorized"))
			return
		}

		if r.URL.Query().Get("code") == "" {
			log.Errorf("No code detected in oauth callback: %s", r.URL.Path)
			w.WriteHeader(403)
			w.Write([]byte("No oauth code issued from provider"))
			return
		}

		user, err := gothic.CompleteUserAuth(w, r)
		if err != nil {
			log.Errorf("Error verifying oauth success: %s. Request: %s", err, r.URL.Path)
			w.WriteHeader(403)
			w.Write([]byte("UnOAuthorized"))
			return
		}

		log.Debugf("Authenticated user %#v", user)

		ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: user.AccessToken})
		ctx := context.WithValue(oauth2.NoContext, oauth2.HTTPClient, oa.Cfg.Client)
		tc := oauth2.NewClient(ctx, ts)

		log.Debugf("Checking authorization...")
		membership, err := OAuthVerifier.Membership(user, tc)
		if err != nil {
			log.Errorf("Error retrieving user membership: %s", err)
			w.WriteHeader(403)
			w.Write([]byte("Unable to verify your membership"))
			return
		}

		if !OAuthVerifier.Verify(user.NickName, membership) {
			log.Debugf("Authorization denied")
			w.WriteHeader(403)
			w.Write([]byte("You are not authorized to view this content"))
			return
		}

		log.Infof("Successful login for %s", user.NickName)

		redirect := "/"
		if flashes := sess.Flashes(); len(flashes) > 0 {
			if flash, ok := flashes[0].(string); ok {
				// don't redirect back to api calls, to prevent auth redirection loops
				if !apiCall.MatchString(flash) || cliAuthCall.MatchString(flash) {
					redirect = flash
				}
			}
		}

		sess.Values["User"] = user.NickName
		sess.Values["Membership"] = membership
		err = sess.Save(r, w)
		if err != nil {
			log.Errorf("Error saving session: %s", err)
			w.WriteHeader(500)
			w.Write([]byte("Unable to save authentication data. Check the SHIELD logs for more info."))
			return
		}

		http.Redirect(w, r, redirect, http.StatusFound) // checks auth
	})
}
Ejemplo n.º 18
0
// bail aborts an HTTP request. It sends a 500 Internal Server Error status
// (no body) and then logs the error e that caused the failure.
func bail(w http.ResponseWriter, e error) {
	w.WriteHeader(http.StatusInternalServerError)
	log.Errorf("Request bailed: <%s>\n", e)
}
Ejemplo n.º 19
0
// worker is the body of one supervisor worker goroutine. It loads the SSH
// client key from privateKeyFile once. It then executes each task received
// on work, until work is closed, against the task's remote agent, and
// reports to the supervisor on updates:
//   - OUTPUT carries task log text (remote "E:" lines and failure notices);
//   - RESTORE_KEY carries the restore key parsed from a backup's output;
//   - PURGE_ARCHIVE reports an archive that was purged successfully;
//   - every task ends with exactly one FAILED or STOPPED update.
//
// id only identifies this worker in log and task output.
func worker(id uint, privateKeyFile string, work chan *db.Task, updates chan WorkerUpdate) {
	config, err := agent.ConfigureSSHClient(privateKeyFile)
	if err != nil {
		log.Errorf("worker %d unable to read user key %s: %s; bailing out.\n",
			id, privateKeyFile, err)
		return
	}

	for t := range work {
		client := agent.NewClient(config)

		remote := t.Agent
		if remote == "" {
			updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
				Output: fmt.Sprintf("TASK FAILED!!  no remote agent specified for task %s\n", t.UUID)}
			updates <- WorkerUpdate{Task: t.UUID, Op: FAILED, StoppedAt: time.Now()}
			continue
		}

		err = client.Dial(remote)
		if err != nil {
			updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
				Output: fmt.Sprintf("TASK FAILED!!  shield worker %d unable to connect to %s (%s)\n", id, remote, err)}
			updates <- WorkerUpdate{Task: t.UUID, Op: FAILED, StoppedAt: time.Now()}
			continue
		}

		// Encode the request BEFORE starting the output collector below. If
		// encoding fails, client.Run is never called, so nothing would close
		// partial and the collector goroutine would leak, blocked forever.
		command, err := json.Marshal(WorkerRequest{
			Operation:      t.Op,
			TargetPlugin:   t.TargetPlugin,
			TargetEndpoint: t.TargetEndpoint,
			StorePlugin:    t.StorePlugin,
			StoreEndpoint:  t.StoreEndpoint,
			RestoreKey:     t.RestoreKey,
		})
		if err != nil {
			updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
				Output: fmt.Sprintf("TASK FAILED!! shield worker %d was unable to json encode the request bound for remote agent %s (%s)", id, remote, err),
			}
			updates <- WorkerUpdate{Task: t.UUID, Op: FAILED, StoppedAt: time.Now()}
			client.Close()
			continue
		}

		// Collect the remote command's output from partial:
		//   - "O:" lines (presumably stdout) are buffered and delivered as
		//     one string on final once partial is closed;
		//   - "E:" lines (presumably stderr) are forwarded immediately as
		//     OUTPUT updates.
		final := make(chan string)
		partial := make(chan string)

		go func(out chan string, up chan WorkerUpdate, taskUUID uuid.UUID, in chan string) {
			var buffer []string
			for s := range in {
				// A line shorter than its two-character tag cannot be
				// classified; skip it instead of panicking on the slice.
				if len(s) < 2 {
					continue
				}
				switch s[:2] {
				case "O:":
					buffer = append(buffer, s[2:])
				case "E:":
					up <- WorkerUpdate{
						Task:   taskUUID,
						Op:     OUTPUT,
						Output: s[2:] + "\n",
					}
				}
			}
			out <- strings.Join(buffer, "")
			close(out)
		}(final, updates, t.UUID, partial)

		// exec the command
		// NOTE(review): waiting on final below assumes client.Run closes
		// partial on every return path, including errors — confirm in agent.
		var jobFailed bool
		err = client.Run(partial, string(command))
		if err != nil {
			updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
				Output: fmt.Sprintf("TASK FAILED!!  shield worker %d failed to execute the command against the remote agent %s (%s)\n", id, remote, err)}
			jobFailed = true
		}
		client.Close()

		out := <-final
		if t.Op == db.BackupOperation {
			// parse JSON from standard output and get the restore key
			// (this might fail, we might not get a key, etc.)
			v := struct {
				Key string
			}{}

			buf := bytes.NewBufferString(out)
			dec := json.NewDecoder(buf)
			err := dec.Decode(&v)

			if err != nil {
				jobFailed = true
				updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
					Output: fmt.Sprintf("WORKER FAILED!!  shield worker %d failed to parse JSON response from remote agent %s (%s)\n", id, remote, err)}

			} else {
				if v.Key != "" {
					updates <- WorkerUpdate{
						Task:        t.UUID,
						Op:          RESTORE_KEY,
						TaskSuccess: !jobFailed,
						Output:      v.Key,
					}
				} else {
					jobFailed = true
					updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
						Output: fmt.Sprintf("TASK FAILED!! No restore key detected in worker %d. Cowardly refusing to create an archive record", id)}
				}
			}
		}

		if t.Op == db.PurgeOperation && !jobFailed {
			updates <- WorkerUpdate{
				Task:    t.UUID,
				Op:      PURGE_ARCHIVE,
				Archive: t.ArchiveUUID,
			}
		}

		// signal to the supervisor that we finished
		if jobFailed {
			updates <- WorkerUpdate{Task: t.UUID, Op: FAILED, StoppedAt: time.Now()}
		} else {
			updates <- WorkerUpdate{
				Task:      t.UUID,
				Op:        STOPPED,
				StoppedAt: time.Now(),
			}
		}
	}
}
Ejemplo n.º 20
0
// handleConn services one SSH server connection until its channel stream
// ends.
//
// Channels: only "session" channels are accepted; others are rejected. On
// each session, only "exec" requests are honoured. The request is parsed
// (ParseRequest), its paths are resolved against the agent, and it is run.
// Output is streamed to the channel as it is produced.
//
// When the command finishes, the agent sends one of two messages and then
// closes the channel:
//   - "exit-signal", if the command was killed or stopped by a signal;
//   - "exit-status" otherwise. The status is the process's exit code, or the
//     magic 16777216 when the request could not be executed at all.
func (agent *Agent) handleConn(conn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
	defer conn.Close()

	// Connection-level (global) requests must be serviced or the connection
	// can stall. This agent has no use for them, so reject them all.
	go ssh.DiscardRequests(reqs)

	for newChannel := range chans {
		if newChannel.ChannelType() != "session" {
			log.Errorf("rejecting unknown channel type: %s\n", newChannel.ChannelType())
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}

		channel, requests, err := newChannel.Accept()
		if err != nil {
			log.Errorf("failed to accept channel: %s\n", err)
			return
		}

		// NOTE(review): deferred inside the loop, so this only fires when
		// handleConn returns; exec paths below also close the channel explicitly.
		defer channel.Close()

		for req := range requests {
			if req.Type != "exec" {
				log.Errorf("rejecting non-exec channel request (type=%s)\n", req.Type)
				req.Reply(false, nil)
				continue
			}

			request, err := ParseRequest(req)
			if err != nil {
				log.Errorf("%s\n", err)
				req.Reply(false, nil)
				continue
			}

			if err = request.ResolvePaths(agent); err != nil {
				log.Errorf("%s\n", err)
				req.Reply(false, nil)
				continue
			}

			//log.Errorf("got an agent-request [%s]\n", request.JSON)
			req.Reply(true, nil)

			// drain output to the SSH channel stream
			output := make(chan string)
			done := make(chan int)
			go func(out io.Writer, in chan string, done chan int) {
				for {
					s, ok := <-in
					if !ok {
						break
					}
					fmt.Fprintf(out, "%s", s)
					log.Debugf("%s", s)
				}
				close(done)
			}(channel, output, done)

			// run the agent request
			err = request.Run(output)
			<-done
			var rc int
			if exitErr, ok := err.(*exec.ExitError); ok {
				sys := exitErr.ProcessState.Sys()
				// os.ProcessState.Sys() may not return syscall.WaitStatus on non-UNIX machines,
				// so currently this feature only works on UNIX, but shouldn't crash on other OSes
				// NOTE(review): on such systems rc stays 0 even though the
				// command failed — consider a non-zero fallback.
				if ws, ok := sys.(syscall.WaitStatus); ok {
					if ws.Exited() {
						rc = ws.ExitStatus()
					} else {
						var signal syscall.Signal
						if ws.Signaled() {
							signal = ws.Signal()
						}
						if ws.Stopped() {
							signal = ws.StopSignal()
						}
						sigStr, ok := SIGSTRING[signal]
						if !ok {
							sigStr = "ABRT" // use ABRT as catch-all signal for any that don't translate
							log.Infof("Task execution terminated due to %s, translating as ABRT for ssh transport", signal)
						} else {
							log.Infof("Task execution terminated due to SIG%s", sigStr)
						}
						sigMsg := struct {
							Signal     string
							CoreDumped bool
							Error      string
							Lang       string
						}{
							Signal:     sigStr,
							CoreDumped: false,
							Error:      fmt.Sprintf("shield-pipe terminated due to SIG%s", sigStr),
							Lang:       "en-US",
						}
						channel.SendRequest("exit-signal", false, ssh.Marshal(&sigMsg))
						channel.Close()
						continue
					}
				}
			} else if err != nil {
				// we got some kind of error that isn't a command execution error,
				// from a UNIX system, use an magical error code to signal this to
				// the shield daemon - 16777216
				log.Infof("Task could not execute: %s", err)
				rc = 16777216
			}

			log.Infof("Task completed with rc=%d", rc)
			byteCode := make([]byte, 4)
			binary.BigEndian.PutUint32(byteCode, uint32(rc)) // SSH protocol is big-endian byte ordering
			channel.SendRequest("exit-status", false, byteCode)
			channel.Close()
		}
	}
}