/* SSHChallenge performs the challenge-response process to authenticate a connecting client to its SSH keys. */ func (sm *server) SSHChallenge(m protocol.MessageReadWriteCloser) (*User, error) { for { challenge := make([]byte, 64) for i := 0; i < len(challenge); i++ { challenge[i] = byte(rand.Int() % 256) } response := &protocol.Message{ ServerResponse: &protocol.ServerResponse{ Challenge: &protocol.SSHChallenge{ Challenge: challenge, }, }, } err := m.Write(response) if err != nil { return nil, err } challengeResponseMessage, err := m.Read() if err != nil { return nil, err } r := challengeResponseMessage.GetServerRequest() if r == nil { return nil, errors.New("not a server request") } cr := r.GetChallengeResponse() if cr == nil { return nil, errors.New("not a server request") } // Compose this into the proper format for Authenticate. sig := &ssh.Signature{ Format: cr.GetFormat(), Blob: cr.GetSignature(), } verifiedUser, err := sm.authenticator.Authenticate("derp", challenge, sig) if err != nil { return nil, err } if verifiedUser != nil { log.Debug("Verification completed for user %s!", verifiedUser.Username) return verifiedUser, nil } // continue around the loop, letting the client try another key verificationFailure := &protocol.Message{ ServerResponse: &protocol.ServerResponse{ VerificationFailure: &protocol.SSHVerificationFailure{}, }, } err = m.Write(verificationFailure) if err != nil { return nil, err } } }
/* PingHandler returns the correct response for a ping. */ func (sm *server) HandlePing(m protocol.MessageReadWriteCloser, p *protocol.Ping) { log.Debug("Handling a ping request.") sm.stats.Counter(1.0, "messages.ping", 1) pingType := protocol.Ping_RESPONSE pingMsg := &protocol.Message{ Ping: &protocol.Ping{ Type: &pingType, }, } m.Write(pingMsg) }
func testHandler(msc protocol.MessageReadWriteCloser) { for { msg, _ := msc.Read() if pingReq := msg.GetPing(); pingReq != nil { pingResp := protocol.Ping_RESPONSE msc.Write(&protocol.Message{ Ping: &protocol.Ping{ Type: &pingResp, }, }) } } }
func DummyServer(c protocol.MessageReadWriteCloser) { for { msg, err := c.Read() if err != nil { return } if msg.GetServerRequest() != nil { serverRequest := msg.GetServerRequest() accessKey := "access" secret := "secret" token := "token" exp := int64(0) if serverRequest.GetAssumeRole() != nil { challenge := &protocol.Message{ ServerResponse: &protocol.ServerResponse{ Challenge: &protocol.SSHChallenge{ Challenge: []byte("foo"), }, }, } err = c.Write(challenge) } else if serverRequest.GetChallengeResponse() != nil { creds := &protocol.Message{ ServerResponse: &protocol.ServerResponse{ Credentials: &protocol.STSCredentials{ AccessKeyId: &accessKey, SecretAccessKey: &secret, AccessToken: &token, Expiration: &exp, }, }, } err = c.Write(creds) } } } }
/* HandleServerRequest handles the flow for messages that this server accepts from clients. */ func (sm *server) HandleServerRequest(m protocol.MessageReadWriteCloser, r *protocol.ServerRequest) { if assumeRoleMsg := r.GetAssumeRole(); assumeRoleMsg != nil { log.Debug("Handling an assumeRole request.") sm.stats.Counter(1.0, "messages.assumeRole", 1) role := assumeRoleMsg.GetRole() user, err := sm.SSHChallenge(m) if err != nil { log.Errorf("Error trying to handle AssumeRole: %s", err.Error()) m.Close() return } if user != nil { creds, err := sm.credentials.AssumeRole(user, role, sm.enableLDAPRoles) if err != nil { // error message from Amazon, so forward that on to the client errStr := err.Error() errMsg := &protocol.Message{ Error: &errStr, } log.Errorf("Error from AWS for AssumeRole: %s", err.Error()) m.Write(errMsg) sm.stats.Counter(1.0, "errors.assumeRole", 1) //m.Close() return } m.Write(makeCredsResponse(creds)) return } } else if getUserCredentialsMsg := r.GetGetUserCredentials(); getUserCredentialsMsg != nil { sm.stats.Counter(1.0, "messages.getUserCredentialsMsg", 1) user, err := sm.SSHChallenge(m) if err != nil { log.Errorf("Error trying to handle GetUserCredentials: %s", err.Error()) m.Close() return } if user != nil { creds, err := sm.credentials.AssumeRole(user, sm.DefaultRole, sm.enableLDAPRoles) if err != nil { log.Errorf("Error trying to handle GetUserCredentials: %s", err.Error()) m.Close() return } m.Write(makeCredsResponse(creds)) return } } else if addSSHKeyMsg := r.GetAddSSHkey(); addSSHKeyMsg != nil { sm.stats.Counter(1.0, "messages.addSSHKeyMsg", 1) // Search for the user specified in this request. sr := ldap.NewSearchRequest( sm.baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, fmt.Sprintf("(%s=%s)", sm.userAttr, addSSHKeyMsg.GetUsername()), []string{"sshPublicKey", sm.userAttr, "userPassword"}, nil) user, err := sm.ldapServer.Search(sr) if err != nil { log.Errorf("Error trying to handle addSSHKeyMsg: %s", err.Error()) return } if len(user.Entries) == 0 { log.Errorf("User %s not found!", addSSHKeyMsg.GetUsername()) return } // Check their password. password := user.Entries[0].GetAttributeValue("userPassword") if password != addSSHKeyMsg.GetPasswordhash() { log.Errorf("Provided password for user %s does not match %s!", addSSHKeyMsg.GetUsername(), password) return } // Check to see if this SSH key already exists. for _, k := range user.Entries[0].GetAttributeValues("sshPublicKey") { if k == addSSHKeyMsg.GetSshkeybytes() { log.Warning("User %s already has this SSH key. Doing nothing.", addSSHKeyMsg.GetUsername()) successMsg := &protocol.Message{Success: &protocol.Success{}} m.Write(successMsg) return } } mr := ldap.NewModifyRequest(user.Entries[0].DN) mr.Add("sshPublicKey", []string{addSSHKeyMsg.GetSshkeybytes()}) err = sm.ldapServer.Modify(mr) if err != nil { log.Errorf("Could not modify LDAP user: %s", err.Error()) return } successMsg := &protocol.Message{Success: &protocol.Success{}} m.Write(successMsg) return } }
func (h *cliHandler) HandleConnection(c protocol.MessageReadWriteCloser) { for { msg, err := c.Read() if err != nil { return } if msg.GetAgentRequest() != nil { dr := msg.GetAgentRequest() var ( sshAgentSock string sshKeyBytes []byte ) sshAgentSock = dr.GetSshAgentSock() if sshAgentSock != "" { log.Debug("SSH_AUTH_SOCK included in this request: %s", sshAgentSock) } sshKeyBytes = dr.GetSshKeyFile() if sshKeyBytes != nil { log.Debug("SSH keyfile included in this request.") } SSHSetAgentSock(sshAgentSock, sshKeyBytes) if dr.GetAssumeRole() != nil { log.Debug("Handling AssumeRole request.") assumeRole := dr.GetAssumeRole() err := h.client.AssumeRole(assumeRole.GetRole()) var agentResponse protocol.AgentResponse if err == nil { agentResponse.Success = &protocol.Success{} } else { log.Errorf(err.Error()) e := err.Error() agentResponse.Failure = &protocol.Failure{ ErrorMessage: &e, } } msg = &protocol.Message{ AgentResponse: &agentResponse, } err = c.Write(msg) if err != nil { return } } else if dr.GetGetUserCredentials() != nil { log.Debug("Handling GetSessionToken request.") err := h.client.GetUserCredentials() var agentResponse protocol.AgentResponse if err == nil { agentResponse.Success = &protocol.Success{} } else { log.Errorf(err.Error()) e := err.Error() agentResponse.Failure = &protocol.Failure{ ErrorMessage: &e, } } msg = &protocol.Message{ AgentResponse: &agentResponse, } err = c.Write(msg) if err != nil { return } } else { log.Errorf("Unexpected agent request: %s", dr) c.Close() return } } else { log.Errorf("Unexpected message: %s", msg) c.Close() return } } }