func TestAgentForward(t *testing.T) { server := newServer(t) defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() keyring := agent.NewKeyring() if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil { t.Fatalf("Error adding key: %s", err) } if err := keyring.Add(agent.AddedKey{ PrivateKey: testPrivateKeys["dsa"], ConfirmBeforeUse: true, LifetimeSecs: 3600, }); err != nil { t.Fatalf("Error adding key with constraints: %s", err) } pub := testPublicKeys["dsa"] sess, err := conn.NewSession() if err != nil { t.Fatalf("NewSession: %v", err) } if err := agent.RequestAgentForwarding(sess); err != nil { t.Fatalf("RequestAgentForwarding: %v", err) } if err := agent.ForwardToAgent(conn, keyring); err != nil { t.Fatalf("SetupForwardKeyring: %v", err) } out, err := sess.CombinedOutput("ssh-add -L") if err != nil { t.Fatalf("running ssh-add: %v, out %s", err, out) } key, _, _, _, err := ssh.ParseAuthorizedKey(out) if err != nil { t.Fatalf("ParseAuthorizedKey(%q): %v", out, err) } if !bytes.Equal(key.Marshal(), pub.Marshal()) { t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub)) } }
// newServer returns a new mock ssh server. func newServer(t *testing.T) *server { if testing.Short() { t.Skip("skipping test due to -short") } dir, err := ioutil.TempDir("", "sshtest") if err != nil { t.Fatal(err) } f, err := os.Create(filepath.Join(dir, "sshd_config")) if err != nil { t.Fatal(err) } err = configTmpl.Execute(f, map[string]string{ "Dir": dir, }) if err != nil { t.Fatal(err) } f.Close() for k, v := range testdata.PEMBytes { filename := "id_" + k writeFile(filepath.Join(dir, filename), v) writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) } return &server{ t: t, configfile: f.Name(), cleanup: func() { if err := os.RemoveAll(dir); err != nil { t.Error(err) } }, } }
func (s *Server) Handle(nConn net.Conn) { le := &logEntry{Timestamp: time.Now().Format(time.RFC3339)} defer json.NewEncoder(os.Stdout).Encode(le) conn, chans, reqs, err := ssh.NewServerConn(nConn, s.sshConfig) if err != nil { le.Error = "Handshake failed: " + err.Error() return } defer func() { s.mu.Lock() delete(s.sessionInfo, string(conn.SessionID())) s.mu.Unlock() conn.Close() }() go func(in <-chan *ssh.Request) { for req := range in { le.RequestTypes = append(le.RequestTypes, req.Type) if req.WantReply { req.Reply(false, nil) } } }(reqs) s.mu.RLock() si := s.sessionInfo[string(conn.SessionID())] s.mu.RUnlock() le.Username = conn.User() le.ClientVersion = fmt.Sprintf("%x", conn.ClientVersion()) for _, key := range si.Keys { le.KeysOffered = append(le.KeysOffered, string(ssh.MarshalAuthorizedKey(key))) } for newChannel := range chans { le.ChannelTypes = append(le.ChannelTypes, newChannel.ChannelType()) if newChannel.ChannelType() != "session" { newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } channel, requests, err := newChannel.Accept() if err != nil { le.Error = "Channel accept failed: " + err.Error() continue } agentFwd, x11 := false, false reqLock := &sync.Mutex{} reqLock.Lock() timeout := time.AfterFunc(30*time.Second, func() { reqLock.Unlock() }) go func(in <-chan *ssh.Request) { for req := range in { le.RequestTypes = append(le.RequestTypes, req.Type) ok := false switch req.Type { case "shell": fallthrough case "pty-req": ok = true // "*****@*****.**" and "x11-req" always arrive // before the "pty-req", so we can go ahead now if timeout.Stop() { reqLock.Unlock() } case "*****@*****.**": agentFwd = true case "x11-req": x11 = true } if req.WantReply { req.Reply(ok, nil) } } }(requests) reqLock.Lock() if err != nil { le.Error = "findUser failed: " + err.Error() channel.Close() continue } channel.Close() } }