예제 #1
0
// RequireAuth answers a request that failed authentication.
//
// It stores the requested URL path as a flash on the gothic session, then
// either starts the OAuth redirect flow (when ShouldOAuthRedirect accepts the
// path) or replies 401 with a `WWW-Authenticate: Bearer` challenge.
//
// A session cookie that securecookie cannot decode is discarded and replaced
// with a brand new session; any other session-store error produces a 500.
func (oa OAuthenticator) RequireAuth(w http.ResponseWriter, r *http.Request) {
	sess, err := gothic.Store.Get(r, gothic.SessionName)
	if err != nil {
		if _, ok := err.(securecookie.Error); !ok {
			log.Errorf("%s", err)
			w.WriteHeader(500)
			w.Write([]byte("Unexpected error retrieving session data"))
			return
		}

		// The cookie is undecodable: drop it from the request and start over
		// with a fresh session.
		r.Header.Set("Cookie", "")
		sess, err = gothic.Store.New(r, gothic.SessionName)
		if err != nil {
			w.WriteHeader(500)
			w.Write([]byte("Failure generating a new session: " + err.Error()))
			return
		}
	}

	sess.AddFlash(r.URL.Path)
	// Saving is best-effort (the auth challenge is still sent), but a failure
	// must not go unnoticed: without the flash the requested path is lost.
	if err := sess.Save(r, w); err != nil {
		log.Errorf("Error saving session: %s", err)
	}

	// only start oauth redirection if we're hitting the auth APIs, or web UI
	if ShouldOAuthRedirect(r.URL.Path) {
		log.Debugf("Starting OAuth Process for request: %s", r)
		gothic.BeginAuthHandler(w, r)
		return
	}

	// otherwise set auth header for api clients to understand oauth is needed
	log.Debugf("Unauthenticated API Request received, OAuth required, sending 401")
	w.Header().Set("WWW-Authenticate", "Bearer")
	w.WriteHeader(401)
	w.Write([]byte("Unauthorized"))
}
예제 #2
0
// IsAuthenticated reports whether the request carries an HTTP Basic
// Authorization header whose user and password both equal the configured
// ba.Cfg.User and ba.Cfg.Password.
//
// It returns false when the header is not of type Basic, is not valid
// base64, or does not decode to a "user:password" pair.
func (ba BasicAuthenticator) IsAuthenticated(r *http.Request) bool {
	authType, authToken := AuthHeader(r)
	// A Basic token is only base64("user:password"); logging it verbatim
	// would expose the password, so it is obfuscated like the values below.
	log.Debugf("Checking `Authorization: %s %s` against our configuration", authType, obfuscate(authToken))

	if strings.ToLower(authType) != "basic" {
		log.Infof("Received an invalid Authorization header type '%s' (not 'Basic')", authType)
		return false
	}

	decoded, err := base64.StdEncoding.DecodeString(authToken)
	if err != nil {
		log.Infof("Authorization header is corrupt: %s", err)
		return false
	}

	// Split on the first ':' only, so the password itself may contain colons.
	creds := strings.SplitN(string(decoded), ":", 2)
	if len(creds) != 2 {
		log.Infof("Authorization header is corrupt: '%s' does not contain a ':' delimiter",
			string(decoded))
		return false
	}

	log.Debugf("Received Authorization credentials for user '%s', password '%s'", creds[0], obfuscate(creds[1]))
	log.Debugf("checking against the configured credentials '%s', password '%s'", ba.Cfg.User, obfuscate(ba.Cfg.Password))
	return creds[0] == ba.Cfg.User && creds[1] == ba.Cfg.Password
}
예제 #3
0
// Authenticate wraps next in a handler that only lets authenticated requests
// through.
//
// A request passes when its X-Shield-Token header equals one of the values in
// tokens, or when UserAuthenticator.IsAuthenticated accepts it. Everything
// else is handed to UserAuthenticator.RequireAuth and next is never called.
func Authenticate(tokens map[string]string, next http.Handler) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		if supplied := r.Header.Get("X-Shield-Token"); supplied != "" {
			log.Debugf("Checking X-Shield-Token against available tokens")
			for label, known := range tokens {
				if supplied != known {
					continue
				}
				log.Debugf("Matched token %s!", label)
				next.ServeHTTP(w, r)
				return
			}
			log.Debugf("No tokens matched")
		}

		// No usable token: fall back to the configured user authenticator.
		if !UserAuthenticator.IsAuthenticated(r) {
			log.Debugf("Request not authenticated, denying")
			UserAuthenticator.RequireAuth(w, r)
			return
		}

		log.Debugf("Request was authenticated, continuing to process")
		next.ServeHTTP(w, r)
	}

	return http.HandlerFunc(guard)
}
예제 #4
0
// IsAuthenticated reports whether the request is from an authorized user.
//
// Two credential sources are understood:
//   - an `Authorization: Bearer <jwt>` header: the token must parse and
//     validate against oa.Cfg.JWTPublicKey with an RSA signing method, carry
//     an "expiration" claim in the future, and carry "user" and "membership"
//     claims;
//   - otherwise the gothic session: it must hold "User" and "Membership"
//     values.
//
// In both cases the final decision is delegated to OAuthVerifier.Verify.
func (oa OAuthenticator) IsAuthenticated(r *http.Request) bool {
	scheme, credential := AuthHeader(r)

	if strings.ToLower(scheme) != "bearer" {
		// No bearer token: rely on the browser session instead.
		session, err := gothic.Store.Get(r, gothic.SessionName)
		if err != nil {
			log.Debugf("Error retrieving session: %s", err)
			return false
		}

		name, isString := session.Values["User"].(string)
		if !isString {
			return false
		}
		groups, isMap := session.Values["Membership"].(map[string]interface{})
		if !isMap {
			return false
		}
		return OAuthVerifier.Verify(name, groups)
	}

	log.Debugf("Received bearer token auth request")

	// Only RSA-signed tokens are accepted; anything else is rejected before
	// the public key is handed out for verification.
	keyFor := func(token *jwt.Token) (interface{}, error) {
		if _, isRSA := token.Method.(*jwt.SigningMethodRSA); !isRSA {
			return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
		}
		return oa.Cfg.JWTPublicKey, nil
	}

	// jwt.Parse does both parsing and validating of the token
	parsed, err := jwt.Parse(credential, keyFor)
	if err != nil {
		return false
	}

	expiry, isNumber := parsed.Claims["expiration"].(float64)
	if !isNumber || int64(expiry) <= time.Now().Unix() {
		return false
	}

	claimedUser, isString := parsed.Claims["user"].(string)
	if !isString {
		log.Debugf("user claim is not a string: %#v", parsed.Claims["user"])
		return false
	}

	claimedMembership, isMap := parsed.Claims["membership"].(map[string]interface{})
	if !isMap {
		log.Debugf("membership claim is not a Membership: %#v", parsed.Claims["membership"])
		return false
	}

	return OAuthVerifier.Verify(claimedUser, claimedMembership)
}
예제 #5
0
파일: main.go 프로젝트: yacloud-io/shield
// main deploys the database schema: it parses the required --type and
// --database options, connects to that database and runs db.DB.Setup.
//
// Any failure is logged and ends the program without touching the schema
// further.
func main() {
	log.Infof("starting schema...")

	options := struct {
		Driver string `goptions:"-t,--type, obligatory, description='Type of database backend'"`
		DSN    string `goptions:"-d,--database, obligatory, description='DSN of the database backend'"`
	}{
	// No defaults
	}
	goptions.ParseAndFail(&options)

	database := &db.DB{
		Driver: options.Driver,
		DSN:    options.DSN,
	}

	log.Debugf("connecting to %s database at %s", database.Driver, database.DSN)
	if err := database.Connect(); err != nil {
		log.Errorf("failed to connect to %s database at %s: %s",
			database.Driver, database.DSN, err)
		// Without a connection Setup cannot succeed; stop here instead of
		// falling through and reporting a second, misleading error.
		return
	}

	if err := database.Setup(); err != nil {
		log.Errorf("failed to set up schema in %s database at %s: %s",
			database.Driver, database.DSN, err)
		return
	}

	log.Infof("deployed schema version %d", db.CurrentSchema)
}
예제 #6
0
// SpawnWorkers calls SpawnWorker s.Workers times, logging the zero-based
// index of each worker before it is spawned.
func (s *Supervisor) SpawnWorkers() {
	for n := uint(0); n < s.Workers; n++ {
		log.Debugf("spawning worker %d", n)
		s.SpawnWorker()
	}
}
예제 #7
0
// PurgeArchives runs one pass of archive housekeeping in two phases:
//
//  1. every archive returned by GetExpiredArchives is marked as expired;
//  2. for every archive returned by GetArchivesNeedingPurge a purge task is
//     created (owner "system", run on s.PurgeAgent) and scheduled.
//
// Errors are logged and never abort the pass; a failure on one archive only
// skips that archive.
func (s *Supervisor) PurgeArchives() {
	log.Debugf("scanning for archives needing to be expired")

	// phase 1: flag archives that are past their retention policy
	expired, err := s.Database.GetExpiredArchives()
	if err != nil {
		log.Errorf("error retrieving archives needing to be expired: %s", err.Error())
	}
	for _, a := range expired {
		log.Infof("marking archive %s has expiration date %s, marking as expired", a.UUID, a.ExpiresAt)
		if expireErr := s.Database.ExpireArchive(a.UUID); expireErr != nil {
			log.Errorf("error marking archive %s as expired: %s", a.UUID, expireErr)
		}
	}

	// phase 2: request removal of archives that are not valid or purged
	purgeable, err := s.Database.GetArchivesNeedingPurge()
	if err != nil {
		log.Errorf("error retrieving archives to purge: %s", err.Error())
	}
	for _, a := range purgeable {
		log.Infof("requesting purge of archive %s due to status '%s'", a.UUID, a.Status)
		purgeTask, taskErr := s.Database.CreatePurgeTask("system", a, s.PurgeAgent)
		if taskErr != nil {
			log.Errorf("error scheduling purge of archive %s: %s", a.UUID, taskErr)
			continue
		}
		s.ScheduleTask(purgeTask)
	}
}
예제 #8
0
// AuthHeader splits the request's Authorization header at the first space
// and returns its two parts: the scheme (e.g. "Basic", "Bearer") and the
// credentials that follow it.
//
// It returns two empty strings when the header is absent or contains no
// space. Only the scheme is ever logged: the credentials part is a password
// or token and must not end up in log files.
func AuthHeader(r *http.Request) (string, string) {
	header := r.Header.Get("Authorization")

	auth := strings.SplitN(header, " ", 2)
	if len(auth) != 2 {
		log.Debugf("Retrieving auth header: header is missing or malformed")
		return "", ""
	}

	log.Debugf("Retrieving auth header of type `%s` (credentials redacted)", auth[0])
	return auth[0], auth[1]
}
예제 #9
0
// main is the entry point of the shield-schema tool.
//
// It parses the command line, handles --help and --version, insists on both
// --type and --database, then connects to the database and deploys the
// schema with db.DB.Setup. A database failure is logged and terminates the
// process with exit status 1.
func main() {
	log.Infof("starting schema...")

	options := struct {
		Help    bool   `goptions:"-h, --help, description='Show the help screen'"`
		Driver  string `goptions:"-t, --type, description='Type of database backend (postgres or mysql)'"`
		DSN     string `goptions:"-d,--database, description='DSN of the database backend'"`
		Version bool   `goptions:"-v, --version, description='Display the SHIELD version'"`
	}{
	// No defaults
	}
	if err := goptions.Parse(&options); err != nil {
		fmt.Printf("%s\n", err)
		goptions.PrintHelp()
		return
	}
	if options.Help {
		goptions.PrintHelp()
		os.Exit(0)
	}
	if options.Version {
		if Version == "" {
			fmt.Printf("shield-schema (development)%s\n", Version)
		} else {
			fmt.Printf("shield-schema v%s\n", Version)
		}
		os.Exit(0)
	}
	if options.Driver == "" {
		fmt.Fprintf(os.Stderr, "You must indicate which type of database you wish to manage, via the `--type` option.\n")
		os.Exit(1)
	}
	if options.DSN == "" {
		fmt.Fprintf(os.Stderr, "You must specify the DSN of your database, via the `--database` option.\n")
		os.Exit(1)
	}

	database := &db.DB{
		Driver: options.Driver,
		DSN:    options.DSN,
	}

	log.Debugf("connecting to %s database at %s", database.Driver, database.DSN)
	if err := database.Connect(); err != nil {
		log.Errorf("failed to connect to %s database at %s: %s",
			database.Driver, database.DSN, err)
		// Setup cannot succeed without a connection; stop here and report
		// the failure through the exit status.
		os.Exit(1)
	}

	if err := database.Setup(); err != nil {
		log.Errorf("failed to set up schema in %s database at %s: %s",
			database.Driver, database.DSN, err)
		// A schema that was not deployed must not look like success to the
		// calling shell or script.
		os.Exit(1)
	}

	log.Infof("deployed schema version %d", db.CurrentSchema)
}
예제 #10
0
// Start configures the web server via Setup and then serves HTTP on ws.Addr
// using the default serve mux.
//
// It panics if Setup fails or if http.ListenAndServe returns an error, so it
// only returns by panicking.
func (ws *WebServer) Start() {
	if setupErr := ws.Setup(); setupErr != nil {
		panic("Could not set up WebServer for SHIELD: " + setupErr.Error())
	}

	log.Debugf("Starting WebServer on '%s'...", ws.Addr)
	if serveErr := http.ListenAndServe(ws.Addr, nil); serveErr != nil {
		log.Errorf("HTTP API failed %s", serveErr.Error())
		panic("Cannot setup WebServer, aborting. Check logs for details.")
	}
}
예제 #11
0
파일: db.go 프로젝트: yacloud-io/shield
// Query executes a data-returning statement (SELECT), given either as SQL
// text or as the name of a registered query, with args bound as parameters.
//
// It returns the resulting rows, or a nil result and the error when the
// statement cannot be obtained or its execution fails.
func (db *DB) Query(sql_or_name string, args ...interface{}) (*sql.Rows, error) {
	stmt, err := db.statement(sql_or_name)
	if err != nil {
		return nil, err
	}

	log.Debugf("Parameters: %v", args)

	rows, queryErr := stmt.Query(args...)
	if queryErr != nil {
		return nil, queryErr
	}
	return rows, nil
}
예제 #12
0
파일: db.go 프로젝트: yacloud-io/shield
// Exec executes a statement that returns no rows (INSERT, UPDATE, DELETE,
// etc.), given either as SQL text or as the name of a registered query, with
// args bound as parameters. The statement's result is discarded; only the
// error, if any, is returned.
func (db *DB) Exec(sql_or_name string, args ...interface{}) error {
	stmt, err := db.statement(sql_or_name)
	if err != nil {
		return err
	}

	log.Debugf("Parameters: %v", args)

	_, err = stmt.Exec(args...)
	return err
}
예제 #13
0
// Verify reports whether user may access the system: true when at least one
// entry of membership["Groups"] equals one of the allowed groups in
// uv.Groups.
//
// membership["Groups"] may be a []string or a []interface{} holding only
// strings; any other shape (including a missing key) denies access. When no
// allowed groups are configured, nobody is let in.
func (uv *UAAVerifier) Verify(user string, membership map[string]interface{}) bool {
	// If none specified, don't let anyone in
	if len(uv.Groups) == 0 {
		log.Debugf("No groups specified for authorization, denying access to '%s'.", user)
		return false
	}

	log.Debugf("User Groups: %#v", membership["Groups"])
	log.Debugf("Allowed Groups: %#v", uv.Groups)

	// Normalize the user's groups once, up front: the result does not depend
	// on which allowed group is being checked.
	groups, ok := membership["Groups"].([]string)
	if !ok {
		raw, isList := membership["Groups"].([]interface{})
		if !isList {
			log.Debugf("Unexpected data type for groups: %#v", membership["Groups"])
			return false
		}

		groups = make([]string, 0, len(raw))
		for _, item := range raw {
			name, isString := item.(string)
			if !isString {
				log.Debugf("Unexpected data type for group: %#v", item)
				return false
			}
			groups = append(groups, name)
		}
	}

	for _, target := range uv.Groups {
		log.Debugf("Seeing if '%s' is in UAA Group '%s'", user, target)

		for _, group := range groups {
			if group == target {
				log.Debugf("'%s' is an allowed group, granting access to '%s'", target, user)
				return true
			}
		}
	}

	log.Debugf("No groups matched")
	return false
}
예제 #14
0
// Verify reports whether user may access the system: true when at least one
// entry of membership["Orgs"] equals one of the allowed organizations in
// gv.Orgs.
//
// membership["Orgs"] may be a []string or a []interface{} holding only
// strings; any other shape (including a missing key) denies access. When no
// allowed orgs are configured, nobody is let in.
func (gv *GithubVerifier) Verify(user string, membership map[string]interface{}) bool {
	// If none specified, don't let anyone in
	if len(gv.Orgs) == 0 {
		log.Debugf("No orgs specified for authorization, denying access to '%s'.", user)
		return false
	}

	log.Debugf("User orgs: %#v", membership["Orgs"])
	log.Debugf("Allowed Orgs: %#v", gv.Orgs)

	// Normalize the user's orgs once, up front: the result does not depend on
	// which allowed org is being checked.
	orgs, ok := membership["Orgs"].([]string)
	if !ok {
		raw, isList := membership["Orgs"].([]interface{})
		if !isList {
			log.Debugf("Unexpected data type for groups: %#v", membership["Orgs"])
			return false
		}

		orgs = make([]string, 0, len(raw))
		for _, item := range raw {
			name, isString := item.(string)
			if !isString {
				log.Debugf("Unexpected data type for group: %#v", item)
				return false
			}
			orgs = append(orgs, name)
		}
	}

	for _, target := range gv.Orgs {
		log.Debugf("Seeing if '%s' is in GitHub Org '%s'", user, target)

		for _, org := range orgs {
			if org == target {
				log.Debugf("'%s' is an allowed org, granting access to '%s'", target, user)
				return true
			}
		}
	}

	log.Debugf("No orgs matched")
	return false
}
예제 #15
0
파일: db.go 프로젝트: yacloud-io/shield
// statement returns the prepared statement for a query given either as SQL
// text or as a registered query name.
//
// The input is passed through db.rebind and db.resolve to obtain the final
// SQL. Prepared statements are cached in db.qCache keyed by that SQL, so each
// distinct query is prepared at most once per cache. An error is returned
// when the database is not connected or when preparing the statement fails.
func (db *DB) statement(sql_or_name string) (*sql.Stmt, error) {
	query := db.resolve(db.rebind(sql_or_name))
	if db.connection == nil {
		return nil, fmt.Errorf("Not connected to database")
	}

	log.Debugf("Executing SQL: %s", query)

	if cached, ok := db.qCache[query]; ok {
		return cached, nil
	}

	stmt, err := db.connection.Prepare(query)
	if err != nil {
		return nil, err
	}
	db.qCache[query] = stmt
	return stmt, nil
}
예제 #16
0
// Run executes the request by spawning the `shield-pipe` command with the
// request's operation, plugins, endpoints and restore key passed through
// SHIELD_* environment variables (plus HOME, PATH, USER and LANG from the
// current process, and DEBUG=true when the log level is debug).
//
// Every line the command writes is sent to output, prefixed with "O:" for
// stdout and "E:" for stderr and terminated by a newline. Run closes output
// before returning, on both success and start failure. The returned error is
// whatever the pipe setup, cmd.Start or cmd.Wait reported.
func (req *Request) Run(output chan string) error {
	cmd := exec.Command("shield-pipe")

	log.Infof("Executing %s request using target %s and store %s via shield-pipe", req.Operation, req.TargetPlugin, req.StorePlugin)
	log.Debugf("Target Endpoint config: %s", req.TargetEndpoint)
	log.Debugf("Store Endpoint config: %s", req.StoreEndpoint)

	cmd.Env = []string{
		fmt.Sprintf("HOME=%s", os.Getenv("HOME")),
		fmt.Sprintf("PATH=%s", os.Getenv("PATH")),
		fmt.Sprintf("USER=%s", os.Getenv("USER")),
		fmt.Sprintf("LANG=%s", os.Getenv("LANG")),

		fmt.Sprintf("SHIELD_OP=%s", req.Operation),
		fmt.Sprintf("SHIELD_STORE_PLUGIN=%s", req.StorePlugin),
		fmt.Sprintf("SHIELD_STORE_ENDPOINT=%s", req.StoreEndpoint),
		fmt.Sprintf("SHIELD_TARGET_PLUGIN=%s", req.TargetPlugin),
		fmt.Sprintf("SHIELD_TARGET_ENDPOINT=%s", req.TargetEndpoint),
		fmt.Sprintf("SHIELD_RESTORE_KEY=%s", req.RestoreKey),
	}

	if log.LogLevel() == syslog.LOG_DEBUG {
		cmd.Env = append(cmd.Env, "DEBUG=true")
	}

	log.Debugf("ENV: %s", strings.Join(cmd.Env, ","))

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return err
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return err
	}

	var wg sync.WaitGroup
	// drain forwards every line read from in to out, tagged with prefix.
	drain := func(prefix string, out chan string, in io.Reader) {
		defer wg.Done()
		s := bufio.NewScanner(in)
		for s.Scan() {
			out <- fmt.Sprintf("%s:%s\n", prefix, s.Text())
		}

		// The scanner gives up on a read error or on a line longer than its
		// buffer. If we simply stopped reading, the child could block forever
		// writing to a full pipe and cmd.Wait would never return, so report
		// the problem and keep consuming (and discarding) the stream.
		if scanErr := s.Err(); scanErr != nil {
			log.Errorf("error reading %s stream of shield-pipe: %s", prefix, scanErr)
			scratch := make([]byte, 4096)
			for {
				if _, readErr := in.Read(scratch); readErr != nil {
					break
				}
			}
		}
	}

	wg.Add(2)
	go drain("E", output, stderr)
	go drain("O", output, stdout)

	err = cmd.Start()
	if err != nil {
		close(output)
		return err
	}

	// All reads from the pipes must complete before Wait is called, so wait
	// for both drainers first.
	wg.Wait()
	close(output)

	return cmd.Wait()
}
예제 #17
0
// Setup prepares the web server for serving:
//
//   - connects ws.Database;
//   - when an OAuth provider is configured, sets up the gothic session store
//     (sqlite3, postgres or mock), registers the goth provider (github,
//     cloudfoundry or faux) and the matching OAuthVerifier, and installs the
//     gothic provider-name and state hooks;
//   - builds the protected API routes;
//   - selects the OAuth or Basic UserAuthenticator;
//   - registers the resulting handler on "/" of the default serve mux.
//
// It returns a non-nil error when any step fails, including when the
// configured session store type or OAuth provider is not recognized.
func (ws *WebServer) Setup() error {
	log.Debugf("Configuring WebServer...")
	if err := ws.Database.Connect(); err != nil {
		log.Errorf("Failed to connect to %s database at %s: %s", ws.Database.Driver, ws.Database.DSN, err)
		return err
	}

	if ws.Auth.OAuth.Provider != "" {
		log.Debugf("Configuring OAuth Session store")
		maxSessionAge := ws.Auth.OAuth.Sessions.MaxAge
		// Keys are generated at every startup.
		authKey := securecookie.GenerateRandomKey(64)
		encKey := securecookie.GenerateRandomKey(32)
		switch ws.Auth.OAuth.Sessions.Type {
		case "sqlite3":
			log.Debugf("Using sqlite3 as a session store")
			store, err := sqlitestore.NewSqliteStore(ws.Auth.OAuth.Sessions.DSN, "http_sessions", "/", maxSessionAge, authKey, encKey)
			if err != nil {
				log.Errorf("Error setting up sessions database: %s", err)
				return err
			}
			gothic.Store = store
		case "postgres":
			log.Debugf("Using postgres as a session store")
			gothic.Store = pgstore.NewPGStore(ws.Auth.OAuth.Sessions.DSN, authKey, encKey)
			gothic.Store.(*pgstore.PGStore).Options.MaxAge = maxSessionAge
		case "mock":
			log.Debugf("Using mocked session store")
			// does nothing, to avoid being accidentally used in prod
		default:
			log.Errorf("Invalid DB Backend for OAuth sessions database")
			// Must be a real error: returning nil here would report a
			// half-configured server as successfully set up.
			return fmt.Errorf("invalid OAuth session store type %q", ws.Auth.OAuth.Sessions.Type)
		}

		// Membership data is stored in the session as a map, which gob needs
		// to know about before it can encode it.
		gob.Register(map[string]interface{}{})
		switch ws.Auth.OAuth.Provider {
		case "github":
			log.Debugf("Using github as the oauth provider")
			goth.UseProviders(github.New(ws.Auth.OAuth.Key, ws.Auth.OAuth.Secret, fmt.Sprintf("%s/v1/auth/github/callback", ws.Auth.OAuth.BaseURL), "read:org"))
			OAuthVerifier = &GithubVerifier{Orgs: ws.Auth.OAuth.Authorization.Orgs}
		case "cloudfoundry":
			log.Debugf("Using cloudfoundry as the oauth provider")
			goth.UseProviders(cloudfoundry.New(ws.Auth.OAuth.ProviderURL, ws.Auth.OAuth.Key, ws.Auth.OAuth.Secret, fmt.Sprintf("%s/v1/auth/cloudfoundry/callback", ws.Auth.OAuth.BaseURL), "openid,scim.read"))
			OAuthVerifier = &UAAVerifier{Groups: ws.Auth.OAuth.Authorization.Orgs, UAA: ws.Auth.OAuth.ProviderURL}
			p, err := goth.GetProvider("cloudfoundry")
			if err != nil {
				return err
			}
			p.(*cloudfoundry.Provider).Client = ws.Auth.OAuth.Client
		case "faux":
			log.Debugf("Using mocked session store")
			// does nothing, to avoid being accidentally used in prod
		default:
			log.Errorf("Invalid OAuth provider specified.")
			return fmt.Errorf("invalid OAuth provider %q", ws.Auth.OAuth.Provider)
		}

		gothic.GetProviderName = func(req *http.Request) (string, error) {
			return ws.Auth.OAuth.Provider, nil
		}

		gothic.SetState = func(req *http.Request) string {
			// NOTE(review): the error from Store.Get is ignored here; this
			// presumably relies on the store returning a usable session even
			// on error - confirm for every configured store.
			sess, _ := gothic.Store.Get(req, gothic.SessionName)
			sess.Values["state"] = uuid.New()
			return sess.Values["state"].(string)
		}
	}

	protectedAPIs, err := ws.ProtectedAPIs()
	if err != nil {
		// Use a constant format string so a '%' in the error text cannot be
		// misread as a formatting verb.
		log.Errorf("Could not set up HTTP routes: %s", err)
		return err
	}

	if ws.Auth.OAuth.Provider != "" {
		log.Debugf("Enabling OAuth handlers for HTTP requests")
		UserAuthenticator = OAuthenticator{
			Cfg: ws.Auth.OAuth,
		}
	} else {
		log.Debugf("Enabling Basic Auth handlers for HTTP requests")
		UserAuthenticator = BasicAuthenticator{
			Cfg: ws.Auth.Basic,
		}
	}

	http.Handle("/", ws.UnauthenticatedResources(Authenticate(ws.Auth.Tokens, protectedAPIs)))
	return nil
}
예제 #18
0
// Run is the supervisor's main loop.
//
// Startup: it connects to the database, checks the schema version, then calls
// Resync, FailUnfinishedTasks and ReschedulePendingTasks. An error from any
// of these is returned to the caller.
//
// After startup it loops forever, reacting to resync requests, the purge
// timer, the scheduling tick, ad hoc task requests and worker updates. Errors
// inside the loop are logged and never terminate it, so Run does not return
// once the loop has been entered.
func (s *Supervisor) Run() error {
	if err := s.Database.Connect(); err != nil {
		return fmt.Errorf("failed to connect to %s database at %s: %s\n",
			s.Database.Driver, s.Database.DSN, err)
	}

	if err := s.Database.CheckCurrentSchema(); err != nil {
		return fmt.Errorf("database failed schema version check: %s\n", err)
	}

	if err := s.Resync(); err != nil {
		return err
	}
	if err := s.FailUnfinishedTasks(); err != nil {
		return err
	}
	if err := s.ReschedulePendingTasks(); err != nil {
		return err
	}

	for {
		select {
		// a resync was requested
		case <-s.resync:
			if err := s.Resync(); err != nil {
				log.Errorf("resync error: %s", err)
			}

		// the purge timer fired
		case <-s.purge.C:
			s.PurgeArchives()

		// the scheduling tick fired
		case <-s.tick.C:
			s.CheckSchedule()

			// see if any tasks have been running past the timeout period
			if len(s.runq) > 0 {
				// ok stays true while no task has timed out; lst collects the
				// tasks that remain in the run queue
				ok := true
				lst := make([]*db.Task, 0)
				now := timestamp.Now()

				for _, runtask := range s.runq {
					if now.After(runtask.TimeoutAt) {
						s.Database.CancelTask(runtask.UUID, now.Time())
						log.Errorf("shield timed out task '%s' after running for %v", runtask.UUID, s.Timeout)
						ok = false

					} else {
						lst = append(lst, runtask)
					}
				}

				// only replace the run queue when at least one task was dropped
				if !ok {
					s.runq = lst
				}
			}

			// see if we have anything in the schedule queue
		SchedQueue:
			for len(s.schedq) > 0 {
				select {
				// non-blocking hand-off: only succeeds when a worker is ready
				// to receive the task at the head of the queue
				case s.workers <- s.schedq[0]:
					s.Database.StartTask(s.schedq[0].UUID, time.Now())
					s.schedq[0].Attempts++
					log.Infof("sent a task to a worker")
					s.runq = append(s.runq, s.schedq[0])
					log.Debugf("added task to the runq")
					s.schedq = s.schedq[1:]
				default:
					// no worker is ready; leave the rest for the next tick
					break SchedQueue
				}
			}

		// an ad hoc task was requested
		case adhoc := <-s.adhoc:
			s.ScheduleAdhoc(adhoc)

		// a worker reported progress on a task
		case u := <-s.updates:
			switch u.Op {
			case STOPPED:
				log.Infof("  %s: job stopped at %s", u.Task, u.StoppedAt)
				s.RemoveTaskFromRunq(u.Task)
				if err := s.Database.CompleteTask(u.Task, u.StoppedAt); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				}

			case FAILED:
				log.Warnf("  %s: task failed!", u.Task)
				s.RemoveTaskFromRunq(u.Task)
				if err := s.Database.FailTask(u.Task, u.StoppedAt); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				}

			case OUTPUT:
				log.Infof("  %s> %s", u.Task, strings.Trim(u.Output, "\n"))
				if err := s.Database.UpdateTaskLog(u.Task, u.Output); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				}

			case RESTORE_KEY:
				log.Infof("  %s: restore key is %s", u.Task, u.Output)
				if id, err := s.Database.CreateTaskArchive(u.Task, u.Output, time.Now()); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				} else {
					// the archive is recorded either way, but invalidated
					// when the task that produced it did not succeed
					if !u.TaskSuccess {
						s.Database.InvalidateArchive(id)
					}
				}

			case PURGE_ARCHIVE:
				log.Infof("  %s: archive %s purged from storage", u.Task, u.Archive)
				if err := s.Database.PurgeArchive(u.Archive); err != nil {
					log.Errorf("  %s: !! failed to update database - %s", u.Task, err)
				}

			default:
				log.Errorf("  %s: !! unrecognized op type", u.Task)
			}
		}
	}
}
예제 #19
0
파일: agent.go 프로젝트: yacloud-io/shield
// handleConn services one SSH connection until its channel stream ends, then
// closes the connection.
//
// Only "session" channels are accepted and, on those, only "exec" requests.
// Each exec request is parsed with ParseRequest, has its paths resolved
// against the agent, and is then run; its output is streamed back over the
// SSH channel. When the run finishes the outcome is reported to the peer as
// an "exit-status" request, or as an "exit-signal" request when the command
// did not exit normally, and the channel is closed.
//
// NOTE(review): the reqs parameter (connection-level requests) is not read
// anywhere in this function - confirm the caller services or discards it.
func (agent *Agent) handleConn(conn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
	defer conn.Close()

	for newChannel := range chans {
		if newChannel.ChannelType() != "session" {
			log.Errorf("rejecting unknown channel type: %s\n", newChannel.ChannelType())
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}

		channel, requests, err := newChannel.Accept()
		if err != nil {
			log.Errorf("failed to accept channel: %s\n", err)
			return
		}

		// deferred inside the loop: this runs when handleConn returns, not at
		// the end of the iteration; the channel is also closed explicitly
		// below once a request has been run
		defer channel.Close()

		for req := range requests {
			if req.Type != "exec" {
				log.Errorf("rejecting non-exec channel request (type=%s)\n", req.Type)
				req.Reply(false, nil)
				continue
			}

			request, err := ParseRequest(req)
			if err != nil {
				log.Errorf("%s\n", err)
				req.Reply(false, nil)
				continue
			}

			if err = request.ResolvePaths(agent); err != nil {
				log.Errorf("%s\n", err)
				req.Reply(false, nil)
				continue
			}

			//log.Errorf("got an agent-request [%s]\n", request.JSON)
			req.Reply(true, nil)

			// drain output to the SSH channel stream
			output := make(chan string)
			done := make(chan int)
			go func(out io.Writer, in chan string, done chan int) {
				// copy until the output channel is closed, then signal done
				for {
					s, ok := <-in
					if !ok {
						break
					}
					fmt.Fprintf(out, "%s", s)
					log.Debugf("%s", s)
				}
				close(done)
			}(channel, output, done)

			// run the agent request
			err = request.Run(output)
			// wait until everything sent on output has been written out
			<-done
			var rc int
			if exitErr, ok := err.(*exec.ExitError); ok {
				sys := exitErr.ProcessState.Sys()
				// os.ProcessState.Sys() may not return syscall.WaitStatus on non-UNIX machines,
				// so currently this feature only works on UNIX, but shouldn't crash on other OSes
				if ws, ok := sys.(syscall.WaitStatus); ok {
					if ws.Exited() {
						rc = ws.ExitStatus()
					} else {
						// the process did not exit normally: work out which
						// signal terminated or stopped it
						var signal syscall.Signal
						if ws.Signaled() {
							signal = ws.Signal()
						}
						if ws.Stopped() {
							signal = ws.StopSignal()
						}
						sigStr, ok := SIGSTRING[signal]
						if !ok {
							sigStr = "ABRT" // use ABRT as catch-all signal for any that don't translate
							log.Infof("Task execution terminted due to %s, translating as ABRT for ssh transport", signal)
						} else {
							log.Infof("Task execution terminated due to SIG%s", sigStr)
						}
						// payload of the "exit-signal" request sent to the peer
						sigMsg := struct {
							Signal     string
							CoreDumped bool
							Error      string
							Lang       string
						}{
							Signal:     sigStr,
							CoreDumped: false,
							Error:      fmt.Sprintf("shield-pipe terminated due to SIG%s", sigStr),
							Lang:       "en-US",
						}
						channel.SendRequest("exit-signal", false, ssh.Marshal(&sigMsg))
						channel.Close()
						// skip the exit-status report below
						continue
					}
				}
			} else if err != nil {
				// we got some kind of error that isn't a command execution error,
				// from a UNIX system, use an magical error code to signal this to
				// the shield daemon - 16777216
				log.Infof("Task could not execute: %s", err)
				rc = 16777216
			}

			log.Infof("Task completed with rc=%d", rc)
			byteCode := make([]byte, 4)
			binary.BigEndian.PutUint32(byteCode, uint32(rc)) // SSH protocol is big-endian byte ordering
			channel.SendRequest("exit-status", false, byteCode)
			channel.Close()
		}
	}
}
예제 #20
0
파일: config.go 프로젝트: yacloud-io/shield
// ReadConfig loads the YAML configuration file at path, fills in defaults
// for every unset option, and applies the result to the supervisor and to a
// newly created WebServer stored in s.Web.
//
// When an OAuth provider is configured, a base URL is mandatory; the JWT
// signing key is read from the configured PEM file or, if none is given, a
// 2048-bit RSA key is generated. An HTTP client whose TLS verification
// follows SkipSSLVerify is also prepared for the provider.
//
// An error is returned when the file cannot be read or parsed, when the
// OAuth base URL is missing, or when the signing key cannot be produced.
func (s *Supervisor) ReadConfig(path string) error {
	raw, err := ioutil.ReadFile(path)
	if err != nil {
		return err
	}

	var cfg Config
	if err = yaml.Unmarshal(raw, &cfg); err != nil {
		return err
	}

	// --- defaults for anything the file left unset ---
	if cfg.Addr == "" {
		cfg.Addr = ":8888"
	}
	if cfg.PrivateKeyFile == "" {
		cfg.PrivateKeyFile = "/etc/shield/ssh/server.key"
	}
	if cfg.WebRoot == "" {
		cfg.WebRoot = "/usr/share/shield/webui"
	}
	if cfg.Workers == 0 {
		cfg.Workers = 5
	}
	if cfg.PurgeAgent == "" {
		cfg.PurgeAgent = "localhost:5444"
	}
	if cfg.MaxTimeout == 0 {
		// interpreted as hours when copied to s.Timeout below
		cfg.MaxTimeout = 12
	}
	if cfg.Auth.Basic.User == "" {
		cfg.Auth.Basic.User = "******"
	}
	if cfg.Auth.Basic.Password == "" {
		cfg.Auth.Basic.Password = "******"
	}
	if cfg.Auth.OAuth.Sessions.MaxAge == 0 {
		cfg.Auth.OAuth.Sessions.MaxAge = 86400 * 30
	}

	// --- OAuth-only settings ---
	if cfg.Auth.OAuth.Provider != "" {
		if cfg.Auth.OAuth.BaseURL == "" {
			return fmt.Errorf("OAuth requested, but no external URL provided. Cannot proceed.")
		}

		var signingKey *rsa.PrivateKey
		if cfg.Auth.OAuth.SigningKey == "" {
			log.Debugf("No signing key specified, generating a random one")
			signingKey, err = rsa.GenerateKey(rand.Reader, 2048)
			if err != nil {
				return err
			}
		} else {
			pemData, readErr := ioutil.ReadFile(cfg.Auth.OAuth.SigningKey)
			if readErr != nil {
				return readErr
			}
			signingKey, err = jwt.ParseRSAPrivateKeyFromPEM(pemData)
			if err != nil {
				return err
			}
		}
		cfg.Auth.OAuth.JWTPrivateKey = signingKey
		cfg.Auth.OAuth.JWTPublicKey = &cfg.Auth.OAuth.JWTPrivateKey.PublicKey

		tlsSettings := &tls.Config{
			InsecureSkipVerify: cfg.SkipSSLVerify,
		}
		cfg.Auth.OAuth.Client = &http.Client{
			Transport: &http.Transport{
				TLSClientConfig: tlsSettings,
			},
		}

		cfg.Auth.OAuth.ProviderURL = strings.TrimSuffix(cfg.Auth.OAuth.ProviderURL, "/")
	}

	// --- apply to the supervisor ---
	if s.Database == nil {
		s.Database = &db.DB{}
	}
	s.Database.Driver = cfg.DatabaseType
	s.Database.DSN = cfg.DatabaseDSN
	s.PrivateKeyFile = cfg.PrivateKeyFile
	s.Workers = cfg.Workers
	s.PurgeAgent = cfg.PurgeAgent
	s.Timeout = time.Duration(cfg.MaxTimeout) * time.Hour

	// The web server works on its own copy of the database handle.
	s.Web = &WebServer{
		Database:   s.Database.Copy(),
		Addr:       cfg.Addr,
		WebRoot:    cfg.WebRoot,
		Auth:       cfg.Auth,
		Supervisor: s,
	}
	return nil
}
예제 #21
0
// OAuthCallback returns the handler for the OAuth provider's callback URL.
//
// The handler:
//   - loads the gothic session and checks the request's state against the
//     state stored in it (403 on mismatch);
//   - requires a "code" query parameter and completes the exchange with
//     gothic.CompleteUserAuth;
//   - fetches the user's membership through OAuthVerifier.Membership, using
//     an OAuth2 client built on oa.Cfg.Client, and checks it with
//     OAuthVerifier.Verify (403 when either step fails);
//   - stores "User" and "Membership" in the session and redirects (302) to
//     the path saved as a session flash, or to "/" when there is none or it
//     points at an API call other than the CLI auth call.
//
// Session load/save failures answer 500.
func (oa OAuthenticator) OAuthCallback() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Debugf("Incoming Auth request: %s", r)
		sess, err := gothic.Store.Get(r, gothic.SessionName)
		if err != nil {
			log.Errorf("Error retrieving session info: %s", err)
			w.WriteHeader(500)
			return
		}
		log.Debugf("Processing oauth callback for '%s'", sess.ID)
		if gothic.GetState(r) != sess.Values["state"] {
			w.WriteHeader(403)
			w.Write([]byte("Unauthorized"))
			return
		}

		if r.URL.Query().Get("code") == "" {
			log.Errorf("No code detected in oauth callback: %v", r)
			w.WriteHeader(403)
			w.Write([]byte("No oauth code issued from provider"))
			return
		}

		user, err := gothic.CompleteUserAuth(w, r)
		if err != nil {
			log.Errorf("Error verifying oauth success: %s. Request: %v", err, r)
			w.WriteHeader(403)
			w.Write([]byte("UnOAuthorized"))
			return
		}

		// Log the nickname only: dumping the whole user value would write the
		// OAuth access token into the logs.
		log.Debugf("Authenticated user '%s'", user.NickName)

		ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: user.AccessToken})
		ctx := context.WithValue(oauth2.NoContext, oauth2.HTTPClient, oa.Cfg.Client)
		tc := oauth2.NewClient(ctx, ts)

		log.Debugf("Checking authorization...")
		membership, err := OAuthVerifier.Membership(user, tc)
		if err != nil {
			log.Errorf("Error retreiving user membership: %s", err)
			w.WriteHeader(403)
			w.Write([]byte("Unable to verify your membership"))
			return
		}

		if !OAuthVerifier.Verify(user.NickName, membership) {
			log.Debugf("Authorization denied")
			w.WriteHeader(403)
			w.Write([]byte("You are not authorized to view this content"))
			return
		}

		log.Infof("Successful login for %s", user.NickName)

		redirect := "/"
		if flashes := sess.Flashes(); len(flashes) > 0 {
			if flash, ok := flashes[0].(string); ok {
				// don't redirect back to api calls, to prevent auth redirection loops
				if !apiCall.MatchString(flash) || cliAuthCall.MatchString(flash) {
					redirect = flash
				}
			}
		}

		sess.Values["User"] = user.NickName
		sess.Values["Membership"] = membership
		err = sess.Save(r, w)
		if err != nil {
			log.Errorf("Error saving session: %s", err)
			w.WriteHeader(500)
			w.Write([]byte("Unable to save authentication data. Check the SHIELD logs for more info."))
			return
		}

		http.Redirect(w, r, redirect, 302) // checks auth
	})
}