Exemplo n.º 1
0
// TestAuth checks that an SSH client authenticating with the signers of the
// agent returned by startAgent can log in to a server that accepts only the
// "rsa" test public key, after the "rsa" test private key has been added to
// that agent.
func TestAuth(t *testing.T) {
	a, b, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}

	defer a.Close()
	defer b.Close()

	agent, _, cleanup := startAgent(t)
	defer cleanup()

	if err := agent.Add(testPrivateKeys["rsa"], nil, "comment"); err != nil {
		// Without the key in the agent the handshake below cannot succeed,
		// so stop here instead of reporting a confusing follow-on failure.
		t.Fatalf("Add: %v", err)
	}

	serverConf := ssh.ServerConfig{}
	serverConf.AddHostKey(testSigners["rsa"])
	serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
		if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
			return nil, nil
		}

		return nil, errors.New("pubkey rejected")
	}

	// t.Fatal* may only be called from the goroutine running the test, so the
	// server goroutine hands its result back over a channel. The buffer lets
	// it finish and exit even if the test goroutine has already stopped early
	// (the deferred Close calls unblock its handshake in that case).
	serverErr := make(chan error, 1)
	go func() {
		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
		if err != nil {
			serverErr <- err
			return
		}
		conn.Close()
		serverErr <- nil
	}()

	conf := ssh.ClientConfig{}
	conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
	conn, _, _, err := ssh.NewClientConn(b, "", &conf)
	if err != nil {
		t.Fatalf("NewClientConn: %v", err)
	}
	conn.Close()

	// Wait for the server side so its outcome is checked and the goroutine
	// does not outlive the test.
	if err := <-serverErr; err != nil {
		t.Fatalf("Server: %v", err)
	}
}
Exemplo n.º 2
0
// findSSHAuth tries to find a working SSH authentication method for every
// target in c.Targets. A target counts as authenticated once its SSHConfig is
// non-nil (presumably set by c.testAndAddAuthentication on success — confirm).
//
// Methods are tried in this order:
//  1. the signers offered by the agent returned by c.sshAgent();
//  2. the keys from c.findSSHKeySigners() whose public key the agent does not
//     already hold (those were already covered by step 1);
//  3. for each target still unauthenticated, an interactive choice between
//     importing a private key, entering a password, or aborting.
//
// It returns nil once every target has either authenticated or had a key
// imported. It returns the prompt or key-import error if either fails, and an
// error when the user aborts or the entered password is rejected.
func (c *SSHCluster) findSSHAuth() error {
	c.base.SendLog("Detecting authentication")

	// testAndAddAuthMethod tries a single auth method against t and logs on success.
	testAndAddAuthMethod := func(t *TargetServer, a ssh.AuthMethod) bool {
		sshConfig := c.sshConfigForAuth(t, []ssh.AuthMethod{a})
		if !c.testAndAddAuthentication(t, sshConfig) {
			return false
		}
		c.base.SendLog(fmt.Sprintf("Verified authentication for %s@%s", t.User, t.IP))
		return true
	}

	testAndAddSigner := func(t *TargetServer, s ssh.Signer) bool {
		return testAndAddAuthMethod(t, ssh.PublicKeys(s))
	}

	testAllAuthenticated := func(targets []*TargetServer) bool {
		for _, t := range targets {
			if t.SSHConfig == nil {
				return false
			}
		}
		return true
	}

	// Step 1: everything the SSH agent offers.
	sshAgent := c.sshAgent()
	sshAgentAuth := ssh.PublicKeysCallback(sshAgent.Signers)
	for _, t := range c.Targets {
		testAndAddAuthMethod(t, sshAgentAuth)
	}

	if testAllAuthenticated(c.Targets) {
		return nil
	}

	// Step 2: local keys the agent does not already hold. A failure to list
	// the agent's keys is deliberately ignored; every local key is tried then.
	var agentKeys [][]byte
	if keys, err := sshAgent.List(); err == nil {
		agentKeys = make([][]byte, len(keys))
		for i, k := range keys {
			agentKeys[i] = k.Marshal()
		}
	}

	// heldByAgent reports whether s's public key is known and held by the agent.
	heldByAgent := func(s privateKeySigner) bool {
		if s.publicKey == nil {
			return false
		}
		pub := s.publicKey.Marshal()
		for _, k := range agentKeys {
			if bytes.Equal(k, pub) {
				return true
			}
		}
		return false
	}

	var signers []privateKeySigner
	for _, s := range c.findSSHKeySigners() {
		if !heldByAgent(s) {
			signers = append(signers, s)
		}
	}

	// Converting a key to an ssh.Signer does not depend on the target, so the
	// result is computed lazily and reused for later targets. An encrypted key
	// is therefore decrypted once, not once per target. Only successes are
	// cached; a failed conversion is retried for the next target as before.
	prepared := make([]ssh.Signer, len(signers))
	signerAt := func(i int) (ssh.Signer, bool) {
		if prepared[i] != nil {
			return prepared[i], true
		}
		s := signers[i]
		var (
			signer ssh.Signer
			err    error
		)
		switch {
		case s.Encrypted && s.publicKey != nil:
			// The public key is already known, so s itself is passed as the signer.
			signer = s
		case s.Encrypted:
			signer, err = s.Decrypt()
		default:
			signer, err = ssh.NewSignerFromKey(s.key)
		}
		if err != nil {
			return nil, false
		}
		prepared[i] = signer
		return signer, true
	}

	for _, t := range c.Targets {
		if t.SSHConfig != nil {
			continue
		}
		for i := range signers {
			if signer, ok := signerAt(i); ok && testAndAddSigner(t, signer) {
				break
			}
		}
	}

	// Step 3: ask the user about every target that is still unauthenticated.
	for _, t := range c.Targets {
		if t.SSHConfig != nil {
			continue
		}
		answer, err := c.base.ChoicePrompt(Choice{
			Message: "No working authentication found.\nPlease choose one of the following options:",
			Options: []ChoiceOption{
				{
					Type:  1,
					Name:  "Private key",
					Value: "1",
				},
				{
					Name:  "Password",
					Value: "2",
				},
				{
					Name:  "Abort",
					Value: "3",
				},
			},
		})
		if err != nil {
			return err
		}
		switch answer {
		case "1":
			if err := c.importSSHKeyPair(t); err != nil {
				return err
			}
			continue
		case "2":
			password := c.base.PromptProtectedInput(fmt.Sprintf("Please enter your password for %s@%s", t.User, t.IP))
			if testAndAddAuthMethod(t, ssh.Password(password)) {
				continue
			}
		}
		// Reached on "Abort", an unknown answer, or a rejected password.
		// NOTE(review): Go error strings are conventionally lowercase; the text
		// is kept as-is because callers may display or compare it.
		return fmt.Errorf("No working authentication found")
	}
	return nil
}