func (sshClient *sshClient) handleTCPChannel( hostToConnect string, portToConnect int, newChannel ssh.NewChannel) { if !sshClient.isPortForwardPermitted( portToConnect, sshClient.trafficRules.AllowTCPPorts, sshClient.trafficRules.DenyTCPPorts) { sshClient.rejectNewChannel( newChannel, ssh.Prohibited, "port forward not permitted") return } var bytesUp, bytesDown int64 sshClient.openedPortForward(sshClient.tcpTrafficState) defer func() { sshClient.closedPortForward( sshClient.tcpTrafficState, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown)) }() // TOCTOU note: important to increment the port forward count (via // openPortForward) _before_ checking isPortForwardLimitExceeded // otherwise, the client could potentially consume excess resources // by initiating many port forwards concurrently. // TODO: close LRU connection (after successful Dial) instead of // rejecting new connection? if sshClient.isPortForwardLimitExceeded( sshClient.tcpTrafficState, sshClient.trafficRules.MaxTCPPortForwardCount) { // Close the oldest TCP port forward. CloseOldest() closes // the conn and the port forward's goroutine will complete // the cleanup asynchronously. // // Some known limitations: // // - Since CloseOldest() closes the upstream socket but does not // clean up all resources associated with the port forward. These // include the goroutine(s) relaying traffic as well as the SSH // channel. Closing the socket will interrupt the goroutines which // will then complete the cleanup. But, since the full cleanup is // asynchronous, there exists a possibility that a client can consume // more than max port forward resources -- just not upstream sockets. // // - An LRU list entry for this port forward is not added until // after the dial completes, but the port forward is counted // towards max limits. This means many dials in progress will // put established connections in jeopardy. // // - We're closing the oldest open connection _before_ successfully // dialing the new port forward. This means we are potentially // discarding a good connection to make way for a failed connection. // We cannot simply dial first and still maintain a limit on // resources used, so to address this we'd need to add some // accounting for connections still establishing. sshClient.tcpPortForwardLRU.CloseOldest() log.WithContextFields( LogFields{ "maxCount": sshClient.trafficRules.MaxTCPPortForwardCount, }).Debug("closed LRU TCP port forward") } // Dial the target remote address. This is done in a goroutine to // ensure the shutdown signal is handled immediately. remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect) log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing") type dialTcpResult struct { conn net.Conn err error } resultChannel := make(chan *dialTcpResult, 1) go func() { // TODO: on EADDRNOTAVAIL, temporarily suspend new clients // TODO: IPv6 support conn, err := net.DialTimeout( "tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT) resultChannel <- &dialTcpResult{conn, err} }() var result *dialTcpResult select { case result = <-resultChannel: case <-sshClient.stopBroadcast: // Note: may leave dial in progress return } if result.err != nil { sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error()) return } // The upstream TCP port forward connection has been established. Schedule // some cleanup and notify the SSH client that the channel is accepted. fwdConn := result.conn defer fwdConn.Close() lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn) defer lruEntry.Remove() // ActivityMonitoredConn monitors the TCP port forward I/O and updates // its LRU status. ActivityMonitoredConn also times out read on the port // forward if both reads and writes have been idle for the specified // duration. fwdConn = psiphon.NewActivityMonitoredConn( fwdConn, time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond, true, lruEntry) fwdChannel, requests, err := newChannel.Accept() if err != nil { log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed") return } go ssh.DiscardRequests(requests) defer fwdChannel.Close() log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying") // Relay channel to forwarded connection. // TODO: relay errors to fwdChannel.Stderr()? relayWaitGroup := new(sync.WaitGroup) relayWaitGroup.Add(1) go func() { defer relayWaitGroup.Done() // io.Copy allocates a 32K temporary buffer, and each port forward relay uses // two of these buffers; using io.CopyBuffer with a smaller buffer reduces the // overall memory footprint. bytes, err := io.CopyBuffer( fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE)) atomic.AddInt64(&bytesDown, bytes) if err != nil && err != io.EOF { // Debug since errors such as "connection reset by peer" occur during normal operation log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed") } // Interrupt upstream io.Copy when downstream is shutting down. // TODO: this is done to quickly cleanup the port forward when // fwdConn has a read timeout, but is it clean -- upstream may still // be flowing? fwdChannel.Close() }() bytes, err := io.CopyBuffer( fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE)) atomic.AddInt64(&bytesUp, bytes) if err != nil && err != io.EOF { log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed") } // Shutdown special case: fwdChannel will be closed and return EOF when // the SSH connection is closed, but we need to explicitly close fwdConn // to interrupt the downstream io.Copy, which may be blocked on a // fwdConn.Read(). fwdConn.Close() relayWaitGroup.Wait() log.WithContextFields( LogFields{ "remoteAddr": remoteAddr, "bytesUp": atomic.LoadInt64(&bytesUp), "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting") }
func (mux *udpPortForwardMultiplexer) run() { // In a loop, read udpgw messages from the client to this channel. Each message is // a UDP packet to send upstream either via a new port forward, or on an existing // port forward. // // A goroutine is run to read downstream packets for each UDP port forward. All read // packets are encapsulated in udpgw protocol and sent down the channel to the client. // // When the client disconnects or the server shuts down, the channel will close and // readUdpgwMessage will exit with EOF. buffer := make([]byte, udpgwProtocolMaxMessageSize) for { // Note: message.packet points to the reusable memory in "buffer". // Each readUdpgwMessage call will overwrite the last message.packet. message, err := readUdpgwMessage(mux.sshChannel, buffer) if err != nil { if err != io.EOF { log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed") } break } mux.portForwardsMutex.Lock() portForward := mux.portForwards[message.connID] mux.portForwardsMutex.Unlock() if portForward != nil && message.discardExistingConn { // The port forward's goroutine will complete cleanup, including // tallying stats and calling sshClient.closedPortForward. // portForward.conn.Close() will signal this shutdown. // TODO: wait for goroutine to exit before proceeding? portForward.conn.Close() portForward = nil } if portForward != nil { // Verify that portForward remote address matches latest message if 0 != bytes.Compare(portForward.remoteIP, message.remoteIP) || portForward.remotePort != message.remotePort { log.WithContext().Warning("UDP port forward remote address mismatch") continue } } else { // Create a new port forward if !mux.sshClient.isPortForwardPermitted( int(message.remotePort), mux.sshClient.trafficRules.AllowUDPPorts, mux.sshClient.trafficRules.DenyUDPPorts) { // The udpgw protocol has no error response, so // we just discard the message and read another. continue } mux.sshClient.openedPortForward(mux.sshClient.udpTrafficState) // Note: can't defer sshClient.closedPortForward() here // TOCTOU note: important to increment the port forward count (via // openPortForward) _before_ checking isPortForwardLimitExceeded if mux.sshClient.isPortForwardLimitExceeded( mux.sshClient.tcpTrafficState, mux.sshClient.trafficRules.MaxUDPPortForwardCount) { // Close the oldest UDP port forward. CloseOldest() closes // the conn and the port forward's goroutine will complete // the cleanup asynchronously. // // See LRU comment in handleTCPChannel() for a known // limitations regarding CloseOldest(). mux.portForwardLRU.CloseOldest() log.WithContextFields( LogFields{ "maxCount": mux.sshClient.trafficRules.MaxUDPPortForwardCount, }).Debug("closed LRU UDP port forward") } dialIP := net.IP(message.remoteIP) dialPort := int(message.remotePort) // Transparent DNS forwarding if message.forwardDNS { dialIP, dialPort = mux.transparentDNSAddress(dialIP, dialPort) } log.WithContextFields( LogFields{ "remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort), "connID": message.connID}).Debug("dialing") // TODO: on EADDRNOTAVAIL, temporarily suspend new clients udpConn, err := net.DialUDP( "udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort}) if err != nil { mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0) log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed") continue } lruEntry := mux.portForwardLRU.Add(udpConn) // ActivityMonitoredConn monitors the TCP port forward I/O and updates // its LRU status. ActivityMonitoredConn also times out read on the port // forward if both reads and writes have been idle for the specified // duration. conn := psiphon.NewActivityMonitoredConn( udpConn, time.Duration(mux.sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds)*time.Millisecond, true, lruEntry) portForward = &udpPortForward{ connID: message.connID, preambleSize: message.preambleSize, remoteIP: message.remoteIP, remotePort: message.remotePort, conn: conn, lruEntry: lruEntry, bytesUp: 0, bytesDown: 0, mux: mux, } mux.portForwardsMutex.Lock() mux.portForwards[portForward.connID] = portForward mux.portForwardsMutex.Unlock() // relayDownstream will call sshClient.closedPortForward() mux.relayWaitGroup.Add(1) go portForward.relayDownstream() } // Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP) _, err = portForward.conn.Write(message.packet) if err != nil { // Debug since errors such as "write: operation not permitted" occur during normal operation log.WithContextFields(LogFields{"error": err}).Debug("upstream UDP relay failed") // The port forward's goroutine will complete cleanup portForward.conn.Close() } portForward.lruEntry.Touch() atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet))) } // Cleanup all UDP port forward workers when exiting mux.portForwardsMutex.Lock() for _, portForward := range mux.portForwards { // The port forward's goroutine will complete cleanup portForward.conn.Close() } mux.portForwardsMutex.Unlock() mux.relayWaitGroup.Wait() }
func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) { sshServer.registerAcceptedClient(tunnelProtocol) defer sshServer.unregisterAcceptedClient(tunnelProtocol) geoIPData := sshServer.support.GeoIPService.Lookup( psiphon.IPAddressFromAddr(clientConn.RemoteAddr())) // TODO: apply reload of TrafficRulesSet to existing clients sshClient := newSshClient( sshServer, tunnelProtocol, geoIPData, sshServer.support.TrafficRulesSet.GetTrafficRules(geoIPData.Country)) // Wrap the base client connection with an ActivityMonitoredConn which will // terminate the connection if no data is received before the deadline. This // timeout is in effect for the entire duration of the SSH connection. Clients // must actively use the connection or send SSH keep alive requests to keep // the connection active. activityConn := psiphon.NewActivityMonitoredConn( clientConn, SSH_CONNECTION_READ_DEADLINE, false, nil) clientConn = activityConn // Further wrap the connection in a rate limiting ThrottledConn. rateLimits := sshClient.trafficRules.GetRateLimits(tunnelProtocol) clientConn = psiphon.NewThrottledConn( clientConn, rateLimits.DownstreamUnlimitedBytes, int64(rateLimits.DownstreamBytesPerSecond), rateLimits.UpstreamUnlimitedBytes, int64(rateLimits.UpstreamBytesPerSecond)) // Run the initial [obfuscated] SSH handshake in a goroutine so we can both // respect shutdownBroadcast and implement a specific handshake timeout. // The timeout is to reclaim network resources in case the handshake takes // too long. type sshNewServerConnResult struct { conn net.Conn sshConn *ssh.ServerConn channels <-chan ssh.NewChannel requests <-chan *ssh.Request err error } resultChannel := make(chan *sshNewServerConnResult, 2) if SSH_HANDSHAKE_TIMEOUT > 0 { time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() { resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")} }) } go func(conn net.Conn) { sshServerConfig := &ssh.ServerConfig{ PasswordCallback: sshClient.passwordCallback, AuthLogCallback: sshClient.authLogCallback, ServerVersion: sshServer.support.Config.SSHServerVersion, } sshServerConfig.AddHostKey(sshServer.sshHostKey) result := &sshNewServerConnResult{} // Wrap the connection in an SSH deobfuscator when required. if psiphon.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) { // Note: NewObfuscatedSshConn blocks on network I/O // TODO: ensure this won't block shutdown conn, result.err = psiphon.NewObfuscatedSshConn( psiphon.OBFUSCATION_CONN_MODE_SERVER, clientConn, sshServer.support.Config.ObfuscatedSSHKey) if result.err != nil { result.err = psiphon.ContextError(result.err) } } if result.err == nil { result.sshConn, result.channels, result.requests, result.err = ssh.NewServerConn(conn, sshServerConfig) } resultChannel <- result }(clientConn) var result *sshNewServerConnResult select { case result = <-resultChannel: case <-sshServer.shutdownBroadcast: // Close() will interrupt an ongoing handshake // TODO: wait for goroutine to exit before returning? clientConn.Close() return } if result.err != nil { clientConn.Close() // This is a Debug log due to noise. The handshake often fails due to I/O // errors as clients frequently interrupt connections in progress when // client-side load balancing completes a connection to a different server. log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed") return } sshClient.Lock() sshClient.sshConn = result.sshConn sshClient.activityConn = activityConn sshClient.Unlock() clientID, ok := sshServer.registerEstablishedClient(sshClient) if !ok { clientConn.Close() log.WithContext().Warning("register failed") return } defer sshServer.unregisterEstablishedClient(clientID) sshClient.runClient(result.channels, result.requests) // Note: sshServer.unregisterClient calls sshClient.Close(), // which also closes underlying transport Conn. }