func main() { // Parse command-line flags for this system. var ( listenAddress = flag.String("addr", "", "Address to listen to incoming requests on.") ldapAddress = flag.String("ldapAddr", "", "Address to connect to LDAP.") ldapBindDN = flag.String("ldapBindDN", "", "LDAP DN to bind to for login.") ldapInsecure = flag.Bool("insecureLDAP", false, "INSECURE: Don't use TLS for LDAP connection.") ldapBindPassword = flag.String("ldapBindPassword", "", "LDAP password for bind.") statsdHost = flag.String("statsHost", "", "Address to send statsd metrics to.") iamAccount = flag.String("iamaccount", "", "AWS Account ID for generating IAM Role ARNs") enableLDAPRoles = flag.Bool("ldaproles", false, "Enable role support using LDAP directory.") roleAttribute = flag.String("roleattribute", "", "Group attribute to get role from.") defaultRole = flag.String("role", "", "AWS role to assume by default.") configFile = flag.String("conf", "/etc/hologram/server.json", "Config file to load.") cacheTimeout = flag.Int("cachetime", 3600, "Time in seconds after which to refresh LDAP user cache.") debugMode = flag.Bool("debug", false, "Enable debug mode.") config Config ) flag.Parse() // Enable debug log output if the user requested it. if *debugMode { log.DebugMode(true) log.Debug("Enabling debug log output. Use sparingly.") } // Parse in options from the given config file. log.Debug("Loading configuration from %s", *configFile) configContents, configErr := ioutil.ReadFile(*configFile) if configErr != nil { log.Errorf("Could not read from config file. The error was: %s", configErr.Error()) os.Exit(1) } configParseErr := json.Unmarshal(configContents, &config) if configParseErr != nil { log.Errorf("Error in parsing config file: %s", configParseErr.Error()) os.Exit(1) } // Merge in command flag options. if *ldapAddress != "" { config.LDAP.Host = *ldapAddress } if *ldapInsecure { config.LDAP.InsecureLDAP = true } if *ldapBindDN != "" { config.LDAP.Bind.DN = *ldapBindDN } if *ldapBindPassword != "" { config.LDAP.Bind.Password = *ldapBindPassword } if *statsdHost != "" { config.Stats = *statsdHost } if *iamAccount != "" { config.AWS.Account = *iamAccount } if *listenAddress != "" { config.Listen = *listenAddress } if *defaultRole != "" { config.AWS.DefaultRole = *defaultRole } if *enableLDAPRoles { config.LDAP.EnableLDAPRoles = true } if *roleAttribute != "" { config.LDAP.RoleAttribute = *roleAttribute } if *cacheTimeout != 3600 { config.CacheTimeout = *cacheTimeout } var stats g2s.Statter var statsErr error if config.LDAP.UserAttr == "" { config.LDAP.UserAttr = "cn" } if config.Stats == "" { log.Debug("No statsd server specified; no metrics will be emitted by this program.") stats = g2s.Noop() } else { stats, statsErr = g2s.Dial("udp", config.Stats) if statsErr != nil { log.Errorf("Error connecting to statsd: %s. No metrics will be emitted by this program.", statsErr.Error()) stats = g2s.Noop() } else { log.Debug("This program will emit metrics to %s", config.Stats) } } // Setup the server state machine that responds to requests. auth, err := aws.GetAuth(os.Getenv("HOLOGRAM_AWSKEY"), os.Getenv("HOLOGRAM_AWSSECRET"), "", time.Now()) if err != nil { log.Errorf("Error getting instance credentials: %s", err.Error()) os.Exit(1) } stsConnection := sts.New(auth, aws.Regions["us-east-1"]) credentialsService := server.NewDirectSessionTokenService(config.AWS.Account, stsConnection) var ldapServer *ldap.Conn // Connect to the LDAP server using TLS or not depending on the config if config.LDAP.InsecureLDAP { log.Debug("Connecting to LDAP at server %s (NOT using TLS).", config.LDAP.Host) ldapServer, err = ldap.Dial("tcp", config.LDAP.Host) if err != nil { log.Errorf("Could not dial LDAP! %s", err.Error()) os.Exit(1) } } else { // Connect to the LDAP server with sample credentials. tlsConfig := &tls.Config{ InsecureSkipVerify: true, } log.Debug("Connecting to LDAP at server %s.", config.LDAP.Host) ldapServer, err = ldap.DialTLS("tcp", config.LDAP.Host, tlsConfig) if err != nil { log.Errorf("Could not dial LDAP! %s", err.Error()) os.Exit(1) } } if bindErr := ldapServer.Bind(config.LDAP.Bind.DN, config.LDAP.Bind.Password); bindErr != nil { log.Errorf("Could not bind to LDAP! %s", bindErr.Error()) os.Exit(1) } ldapCache, err := server.NewLDAPUserCache(ldapServer, stats, config.LDAP.UserAttr, config.LDAP.BaseDN, config.LDAP.EnableLDAPRoles, config.LDAP.RoleAttribute) if err != nil { log.Errorf("Top-level error in LDAPUserCache layer: %s", err.Error()) os.Exit(1) } serverHandler := server.New(ldapCache, credentialsService, config.AWS.DefaultRole, stats, ldapServer, config.LDAP.UserAttr, config.LDAP.BaseDN, config.LDAP.EnableLDAPRoles) server, err := remote.NewServer(config.Listen, serverHandler.HandleConnection) // Wait for a signal from the OS to shutdown. terminate := make(chan os.Signal) signal.Notify(terminate, syscall.SIGINT, syscall.SIGTERM) // SIGUSR1 and SIGUSR2 should make Hologram enable and disable debug logging, // respectively. debugEnable := make(chan os.Signal) debugDisable := make(chan os.Signal) signal.Notify(debugEnable, syscall.SIGUSR1) signal.Notify(debugDisable, syscall.SIGUSR2) // SIGHUP should make Hologram server reload its cache of user information // from LDAP. reloadCacheSigHup := make(chan os.Signal) signal.Notify(reloadCacheSigHup, syscall.SIGHUP) // Reload the cache based on time set in configuration cacheTimeoutTicker := time.NewTicker(time.Duration(config.CacheTimeout) * time.Second) log.Info("Hologram server is online, waiting for termination.") WaitForTermination: for { select { case <-terminate: break WaitForTermination case <-debugEnable: log.Info("Enabling debug mode.") log.DebugMode(true) case <-debugDisable: log.Info("Disabling debug mode.") log.DebugMode(false) case <-reloadCacheSigHup: log.Info("Force-reloading user cache.") ldapCache.Update() case <-cacheTimeoutTicker.C: log.Info("Cache timeout. Reloading user cache.") ldapCache.Update() } } log.Info("Caught signal; shutting down now.") server.Close() }
func TestServerStateMachine(t *testing.T) { // This silly thing is needed for equality testing for the LDAP dummy. neededModifyRequest := ldap.NewModifyRequest("something") neededModifyRequest.Add("sshPublicKey", []string{"test"}) Convey("Given a state machine setup with a null logger", t, func() { authenticator := &DummyAuthenticator{&server.User{Username: "******"}} ldap := &DummyLDAP{ username: "******", password: "******", sshKeys: []string{}, req: neededModifyRequest, } testServer := server.New(authenticator, &dummyCredentials{}, "default", g2s.Noop(), ldap, "cn", "dc=testdn,dc=com", false) r, w := io.Pipe() testConnection := protocol.NewMessageConnection(ReadWriter(r, w)) go testServer.HandleConnection(testConnection) Convey("When a ping message comes in", func() { testPing := &protocol.Message{Ping: &protocol.Ping{}} testConnection.Write(testPing) Convey("Then the server should respond with a pong response.", func() { recvMsg, recvErr := testConnection.Read() So(recvErr, ShouldBeNil) So(recvMsg.GetPing(), ShouldNotBeNil) }) }) Convey("After an AssumeRequest", func() { role := "testrole" msg := &protocol.Message{ ServerRequest: &protocol.ServerRequest{ AssumeRole: &protocol.AssumeRole{ Role: &role, }, }, } testConnection.Write(msg) msg, err := testConnection.Read() if err != nil { t.Fatal(err) } Convey("it should challenge, then send credentials on success", func() { challenge := msg.GetServerResponse().GetChallenge().GetChallenge() So(len(challenge), ShouldEqual, 64) format := "test" sig := []byte("ssss") challengeResponseMsg := &protocol.Message{ ServerRequest: &protocol.ServerRequest{ ChallengeResponse: &protocol.SSHChallengeResponse{ Format: &format, Signature: sig, }, }, } testConnection.Write(challengeResponseMsg) credsMsg, err := testConnection.Read() if err != nil { t.Fatal(err) } So(credsMsg, ShouldNotBeNil) So(credsMsg.GetServerResponse(), ShouldNotBeNil) So(credsMsg.GetServerResponse().GetCredentials(), ShouldNotBeNil) creds := credsMsg.GetServerResponse().GetCredentials() So(creds.GetAccessKeyId(), ShouldEqual, "access_key") So(creds.GetSecretAccessKey(), ShouldEqual, "secret") So(creds.GetAccessToken(), ShouldEqual, "token") So(creds.GetExpiration(), ShouldBeGreaterThanOrEqualTo, time.Now().Unix()) }) Convey("it should then send failure message on failed key verification", func() { authenticator.user = nil challenge := msg.GetServerResponse().GetChallenge().GetChallenge() So(len(challenge), ShouldEqual, 64) format := "test" sig := []byte("ssss") challengeResponseMsg := &protocol.Message{ ServerRequest: &protocol.ServerRequest{ ChallengeResponse: &protocol.SSHChallengeResponse{ Format: &format, Signature: sig, }, }, } testConnection.Write(challengeResponseMsg) credsMsg, err := testConnection.Read() if err != nil { t.Fatal(err) } So(credsMsg, ShouldNotBeNil) So(credsMsg.GetServerResponse(), ShouldNotBeNil) So(credsMsg.GetServerResponse().GetVerificationFailure(), ShouldNotBeNil) }) }) Convey("When a request to add an SSH key comes in", func() { user := "******" password := "******" sshKey := "test" testMessage := &protocol.Message{ ServerRequest: &protocol.ServerRequest{ AddSSHkey: &protocol.AddSSHKey{ Username: &user, Passwordhash: &password, Sshkeybytes: &sshKey, }, }, } testConnection.Write(testMessage) Convey("If this request is valid", func() { msg, err := testConnection.Read() if err != nil { t.Fatal(err) } if msg.GetSuccess() == nil { t.Fail() } Convey("It should add the SSH key to the user.", func() { So(ldap.sshKeys[0], ShouldEqual, sshKey) Convey("If the user tries to add the same SSH key", func() { testConnection.Write(testMessage) Convey("It should not insert the same key twice.", func() { So(len(ldap.sshKeys), ShouldEqual, 1) }) }) }) }) }) }) }