func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) { // http://tools.ietf.org/html/rfc4254#section-7.2 var directTcpipExtraData struct { HostToConnect string PortToConnect uint32 OriginatorIPAddress string OriginatorPort uint32 } err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData) if err != nil { sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data") return } targetAddr := fmt.Sprintf("%s:%d", directTcpipExtraData.HostToConnect, directTcpipExtraData.PortToConnect) log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing") // TODO: port forward dial timeout // TODO: report ssh.ResourceShortage when appropriate // TODO: IPv6 support fwdConn, err := net.Dial("tcp4", targetAddr) if err != nil { sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error()) return } defer fwdConn.Close() fwdChannel, requests, err := newChannel.Accept() if err != nil { log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed") return } sshClient.Lock() sshClient.portForwardCount += 1 sshClient.concurrentPortForwardCount += 1 if sshClient.concurrentPortForwardCount > sshClient.peakConcurrentPortForwardCount { sshClient.peakConcurrentPortForwardCount = sshClient.concurrentPortForwardCount } sshClient.Unlock() log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying") go ssh.DiscardRequests(requests) defer fwdChannel.Close() // When idle port forward traffic rules are in place, wrap fwdConn // in an IdleTimeoutConn configured to reset idle on writes as well // as read. This ensures the port forward idle timeout only happens // when both upstream and downstream directions are are idle. if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 { fwdConn = psiphon.NewIdleTimeoutConn( fwdConn, time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond, true) } // relay channel to forwarded connection // TODO: relay errors to fwdChannel.Stderr()? var bytesUp, bytesDown int64 relayWaitGroup := new(sync.WaitGroup) relayWaitGroup.Add(1) go func() { defer relayWaitGroup.Done() var err error bytesUp, err = copyWithThrottle( fwdConn, fwdChannel, sshClient.trafficRules.ThrottleUpstreamSleepMilliseconds) if err != nil { log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed") } }() bytesDown, err = copyWithThrottle( fwdChannel, fwdConn, sshClient.trafficRules.ThrottleDownstreamSleepMilliseconds) if err != nil { log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed") } fwdChannel.CloseWrite() relayWaitGroup.Wait() sshClient.Lock() sshClient.concurrentPortForwardCount -= 1 sshClient.bytesUp += bytesUp sshClient.bytesDown += bytesDown sshClient.Unlock() log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting") }
func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) { sshClient := &sshClient{ sshServer: sshServer, startTime: time.Now(), geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())), } sshClient.trafficRules = sshServer.config.GetTrafficRules(sshClient.geoIPData.Country) // Wrap the base TCP connection with an IdleTimeoutConn which will terminate // the connection if no data is received before the deadline. This timeout is // in effect for the entire duration of the SSH connection. Clients must actively // use the connection or send SSH keep alive requests to keep the connection // active. conn := psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false) // Run the initial [obfuscated] SSH handshake in a goroutine so we can both // respect shutdownBroadcast and implement a specific handshake timeout. // The timeout is to reclaim network resources in case the handshake takes // too long. type sshNewServerConnResult struct { conn net.Conn sshConn *ssh.ServerConn channels <-chan ssh.NewChannel requests <-chan *ssh.Request err error } resultChannel := make(chan *sshNewServerConnResult, 2) if SSH_HANDSHAKE_TIMEOUT > 0 { time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() { resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")} }) } go func() { result := &sshNewServerConnResult{} if sshServer.useObfuscation { result.conn, result.err = psiphon.NewObfuscatedSshConn( psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey) } else { result.conn = conn } if result.err == nil { sshServerConfig := &ssh.ServerConfig{ PasswordCallback: sshClient.passwordCallback, AuthLogCallback: sshClient.authLogCallback, ServerVersion: sshServer.config.SSHServerVersion, } sshServerConfig.AddHostKey(sshServer.sshHostKey) result.sshConn, result.channels, result.requests, result.err = ssh.NewServerConn(result.conn, sshServerConfig) } resultChannel <- result }() var result *sshNewServerConnResult select { case result = <-resultChannel: case <-sshServer.shutdownBroadcast: // Close() will interrupt an ongoing handshake // TODO: wait for goroutine to exit before returning? conn.Close() return } if result.err != nil { conn.Close() log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed") return } sshClient.Lock() sshClient.sshConn = result.sshConn sshClient.Unlock() clientID, ok := sshServer.registerClient(sshClient) if !ok { conn.Close() log.WithContext().Warning("register failed") return } defer sshServer.unregisterClient(clientID) go ssh.DiscardRequests(result.requests) sshClient.handleChannels(result.channels) }