Ejemplo n.º 1
0
// TestAgentForward checks SSH agent forwarding end to end. It adds the DSA
// test key to an in-memory keyring twice (once without constraints, once with
// ConfirmBeforeUse and a one-hour lifetime), forwards that keyring over the
// connection, runs "ssh-add -L" in a remote session, and verifies that the
// key listed on the remote side is the expected DSA public key.
func TestAgentForward(t *testing.T) {
	server := newServer(t)
	defer server.Shutdown()
	conn := server.Dial(clientConfig())
	defer conn.Close()

	keyring := agent.NewKeyring()
	if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil {
		t.Fatalf("Error adding key: %s", err)
	}
	if err := keyring.Add(agent.AddedKey{
		PrivateKey:       testPrivateKeys["dsa"],
		ConfirmBeforeUse: true,
		LifetimeSecs:     3600,
	}); err != nil {
		t.Fatalf("Error adding key with constraints: %s", err)
	}
	pub := testPublicKeys["dsa"]

	sess, err := conn.NewSession()
	if err != nil {
		t.Fatalf("NewSession: %v", err)
	}
	// Release the session channel even when a later check aborts the test
	// (t.Fatalf still runs deferred calls).
	defer sess.Close()
	if err := agent.RequestAgentForwarding(sess); err != nil {
		t.Fatalf("RequestAgentForwarding: %v", err)
	}

	// Set up forwarding to the local keyring before running the remote
	// command that talks to the forwarded agent.
	if err := agent.ForwardToAgent(conn, keyring); err != nil {
		t.Fatalf("ForwardToAgent: %v", err)
	}
	out, err := sess.CombinedOutput("ssh-add -L")
	if err != nil {
		t.Fatalf("running ssh-add: %v, out %s", err, out)
	}
	// ParseAuthorizedKey reads only the first key line of out. Both keyring
	// entries hold the same DSA key, so that first line is enough to check.
	key, _, _, _, err := ssh.ParseAuthorizedKey(out)
	if err != nil {
		t.Fatalf("ParseAuthorizedKey(%q): %v", out, err)
	}

	if !bytes.Equal(key.Marshal(), pub.Marshal()) {
		t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub))
	}
}
Ejemplo n.º 2
0
// newServer returns a new mock ssh server.
//
// It skips the calling test under -short. Otherwise it creates a temporary
// directory containing an sshd_config rendered from configTmpl (with "Dir"
// set to that directory) and, for every entry k in testdata.PEMBytes, the
// files id_<k> and id_<k>.pub. The returned server's cleanup hook removes the
// directory. If setup fails part-way, the directory is removed before the
// test is aborted, so nothing is leaked.
func newServer(t *testing.T) *server {
	if testing.Short() {
		t.Skip("skipping test due to -short")
	}
	dir, err := ioutil.TempDir("", "sshtest")
	if err != nil {
		t.Fatal(err)
	}
	// Until the server (and its cleanup hook) is returned, this function owns
	// dir. t.Fatal stops the test via runtime.Goexit, which still runs
	// deferred calls, so one deferred guard covers every failure path below.
	handedOff := false
	defer func() {
		if !handedOff {
			// Best effort: the test is already failing at this point.
			os.RemoveAll(dir)
		}
	}()

	f, err := os.Create(filepath.Join(dir, "sshd_config"))
	if err != nil {
		t.Fatal(err)
	}
	err = configTmpl.Execute(f, map[string]string{
		"Dir": dir,
	})
	// Close the file whether or not rendering succeeded. Report a Close error
	// only when rendering itself did not already fail.
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		t.Fatal(err)
	}

	for k, v := range testdata.PEMBytes {
		filename := "id_" + k
		writeFile(filepath.Join(dir, filename), v)
		writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
	}

	handedOff = true
	return &server{
		t:          t,
		configfile: f.Name(),
		cleanup: func() {
			if err := os.RemoveAll(dir); err != nil {
				t.Error(err)
			}
		},
	}
}
Ejemplo n.º 3
0
// Handle serves one incoming SSH connection on nConn. When it returns, it
// writes a single JSON-encoded logEntry describing the connection to stdout.
//
// Global requests are recorded in the log entry and refused. Only "session"
// channels are accepted; other channel types are rejected. For each accepted
// session, Handle records the channel request types it receives. It then
// waits until one of these happens first:
//   - the client sends "pty-req" or "shell";
//   - the client's request stream on that channel ends;
//   - 30 seconds pass.
// After that it closes the channel.
func (s *Server) Handle(nConn net.Conn) {
	le := &logEntry{Timestamp: time.Now().Format(time.RFC3339)}

	// le.RequestTypes is appended to from request-handling goroutines. Those
	// goroutines may still be running when the deferred encode below reads le,
	// so both the appends and the encode happen under leMu.
	var leMu sync.Mutex
	addRequestType := func(typ string) {
		leMu.Lock()
		le.RequestTypes = append(le.RequestTypes, typ)
		leMu.Unlock()
	}
	defer func() {
		leMu.Lock()
		defer leMu.Unlock()
		// Best-effort logging; an encode failure is not reported anywhere.
		json.NewEncoder(os.Stdout).Encode(le)
	}()

	conn, chans, reqs, err := ssh.NewServerConn(nConn, s.sshConfig)
	if err != nil {
		le.Error = "Handshake failed: " + err.Error()
		return
	}
	defer func() {
		s.mu.Lock()
		delete(s.sessionInfo, string(conn.SessionID()))
		s.mu.Unlock()
		conn.Close()
	}()
	go func(in <-chan *ssh.Request) {
		for req := range in {
			addRequestType(req.Type)
			if req.WantReply {
				req.Reply(false, nil)
			}
		}
	}(reqs)

	s.mu.RLock()
	si := s.sessionInfo[string(conn.SessionID())]
	s.mu.RUnlock()

	le.Username = conn.User()
	le.ClientVersion = fmt.Sprintf("%x", conn.ClientVersion())
	for _, key := range si.Keys {
		le.KeysOffered = append(le.KeysOffered, string(ssh.MarshalAuthorizedKey(key)))
	}

	for newChannel := range chans {
		le.ChannelTypes = append(le.ChannelTypes, newChannel.ChannelType())

		if newChannel.ChannelType() != "session" {
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}
		channel, requests, err := newChannel.Accept()
		if err != nil {
			le.Error = "Channel accept failed: " + err.Error()
			continue
		}

		// ready is closed exactly once: on "pty-req"/"shell", when the
		// request stream ends, or when the timeout fires.
		ready := make(chan struct{})
		var readyOnce sync.Once
		markReady := func() { readyOnce.Do(func() { close(ready) }) }
		timeout := time.AfterFunc(30*time.Second, markReady)

		go func(in <-chan *ssh.Request) {
			// If the client closes the channel without starting a shell or
			// pty, stop waiting instead of sitting out the full timeout.
			defer markReady()
			for req := range in {
				addRequestType(req.Type)
				ok := false
				switch req.Type {
				case "shell", "pty-req":
					ok = true
					// Agent- and X11-forwarding requests arrive before
					// "pty-req", so everything of interest has been
					// recorded by now.
					markReady()
				}
				if req.WantReply {
					req.Reply(ok, nil)
				}
			}
		}(requests)

		<-ready
		timeout.Stop()
		channel.Close()
	}
}