// dialSsh is a helper that builds the transport layers and establishes the SSH connection. // When additional dial configuration is used, DialStats are recorded and returned. // // The net.Conn return value is the value to be removed from pendingConns; additional // layering (ThrottledConn, ActivityMonitoredConn) is applied, but this return value is the // base dial conn. The *ActivityMonitoredConn return value is the layered conn passed into // the ssh.Client. func dialSsh( config *Config, pendingConns *common.Conns, serverEntry *protocol.ServerEntry, selectedProtocol, sessionId string) (*dialResult, error) { // The meek protocols tunnel obfuscated SSH. Obfuscated SSH is layered on top of SSH. // So depending on which protocol is used, multiple layers are initialized. useObfuscatedSsh := false var directTCPDialAddress string var meekConfig *MeekConfig var err error switch selectedProtocol { case protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH: useObfuscatedSsh = true directTCPDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort) case protocol.TUNNEL_PROTOCOL_SSH: directTCPDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshPort) default: useObfuscatedSsh = true meekConfig, err = initMeekConfig(config, serverEntry, selectedProtocol, sessionId) if err != nil { return nil, common.ContextError(err) } } NoticeConnectingServer( serverEntry.IpAddress, serverEntry.Region, selectedProtocol, directTCPDialAddress, meekConfig) // Use an asynchronous callback to record the resolved IP address when // dialing a domain name. Note that DialMeek doesn't immediately // establish any HTTPS connections, so the resolved IP address won't be // reported until during/after ssh session establishment (the ssh traffic // is meek payload). So don't Load() the IP address value until after that // has completed to ensure a result. var resolvedIPAddress atomic.Value resolvedIPAddress.Store("") setResolvedIPAddress := func(IPAddress string) { resolvedIPAddress.Store(IPAddress) } // Create the base transport: meek or direct connection dialConfig := &DialConfig{ UpstreamProxyUrl: config.UpstreamProxyUrl, UpstreamProxyCustomHeaders: config.UpstreamProxyCustomHeaders, ConnectTimeout: time.Duration(*config.TunnelConnectTimeoutSeconds) * time.Second, PendingConns: pendingConns, DeviceBinder: config.DeviceBinder, DnsServerGetter: config.DnsServerGetter, UseIndistinguishableTLS: config.UseIndistinguishableTLS, TrustedCACertificatesFilename: config.TrustedCACertificatesFilename, DeviceRegion: config.DeviceRegion, ResolvedIPCallback: setResolvedIPAddress, } var dialConn net.Conn if meekConfig != nil { dialConn, err = DialMeek(meekConfig, dialConfig) if err != nil { return nil, common.ContextError(err) } } else { dialConn, err = DialTCP(directTCPDialAddress, dialConfig) if err != nil { return nil, common.ContextError(err) } } cleanupConn := dialConn defer func() { // Cleanup on error if cleanupConn != nil { cleanupConn.Close() pendingConns.Remove(cleanupConn) } }() // Activity monitoring is used to measure tunnel duration monitoredConn, err := common.NewActivityMonitoredConn(dialConn, 0, false, nil, nil) if err != nil { return nil, common.ContextError(err) } // Apply throttling (if configured) throttledConn := common.NewThrottledConn(monitoredConn, config.RateLimits) // Add obfuscated SSH layer var sshConn net.Conn = throttledConn if useObfuscatedSsh { sshConn, err = common.NewObfuscatedSshConn( common.OBFUSCATION_CONN_MODE_CLIENT, throttledConn, serverEntry.SshObfuscatedKey) if err != nil { return nil, common.ContextError(err) } } // Now establish the SSH session over the conn transport expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey) if err != nil { return nil, common.ContextError(err) } sshCertChecker := &ssh.CertChecker{ HostKeyFallback: func(addr string, remote net.Addr, publicKey ssh.PublicKey) error { if !bytes.Equal(expectedPublicKey, publicKey.Marshal()) { return common.ContextError(errors.New("unexpected host public key")) } return nil }, } sshPasswordPayload := &protocol.SSHPasswordPayload{ SessionId: sessionId, SshPassword: serverEntry.SshPassword, ClientCapabilities: []string{protocol.CLIENT_CAPABILITY_SERVER_REQUESTS}, } payload, err := json.Marshal(sshPasswordPayload) if err != nil { return nil, common.ContextError(err) } sshClientConfig := &ssh.ClientConfig{ User: serverEntry.SshUsername, Auth: []ssh.AuthMethod{ ssh.Password(string(payload)), }, HostKeyCallback: sshCertChecker.CheckHostKey, } // The ssh session establishment (via ssh.NewClientConn) is wrapped // in a timeout to ensure it won't hang. We've encountered firewalls // that allow the TCP handshake to complete but then send a RST to the // server-side and nothing to the client-side, and if that happens // while ssh.NewClientConn is reading, it may wait forever. The timeout // closes the conn, which interrupts it. // Note: TCP handshake timeouts are provided by TCPConn, and session // timeouts *after* ssh establishment are provided by the ssh keep alive // in operate tunnel. // TODO: adjust the timeout to account for time-elapsed-from-start type sshNewClientResult struct { sshClient *ssh.Client sshRequests <-chan *ssh.Request err error } resultChannel := make(chan *sshNewClientResult, 2) if *config.TunnelConnectTimeoutSeconds > 0 { time.AfterFunc(time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second, func() { resultChannel <- &sshNewClientResult{nil, nil, errors.New("ssh dial timeout")} }) } go func() { // The following is adapted from ssh.Dial(), here using a custom conn // The sshAddress is passed through to host key verification callbacks; we don't use it. sshAddress := "" sshClientConn, sshChannels, sshRequests, err := ssh.NewClientConn( sshConn, sshAddress, sshClientConfig) var sshClient *ssh.Client if err == nil { sshClient = ssh.NewClient(sshClientConn, sshChannels, nil) } resultChannel <- &sshNewClientResult{sshClient, sshRequests, err} }() result := <-resultChannel if result.err != nil { return nil, common.ContextError(result.err) } var dialStats *TunnelDialStats if dialConfig.UpstreamProxyUrl != "" || meekConfig != nil { dialStats = &TunnelDialStats{} if dialConfig.UpstreamProxyUrl != "" { // Note: UpstreamProxyUrl should have parsed correctly in the dial proxyURL, err := url.Parse(dialConfig.UpstreamProxyUrl) if err == nil { dialStats.UpstreamProxyType = proxyURL.Scheme } dialStats.UpstreamProxyCustomHeaderNames = make([]string, 0) for name, _ := range dialConfig.UpstreamProxyCustomHeaders { dialStats.UpstreamProxyCustomHeaderNames = append(dialStats.UpstreamProxyCustomHeaderNames, name) } } if meekConfig != nil { dialStats.MeekDialAddress = meekConfig.DialAddress dialStats.MeekResolvedIPAddress = resolvedIPAddress.Load().(string) dialStats.MeekSNIServerName = meekConfig.SNIServerName dialStats.MeekHostHeader = meekConfig.HostHeader dialStats.MeekTransformedHostName = meekConfig.TransformedHostName } NoticeConnectedTunnelDialStats(serverEntry.IpAddress, dialStats) } cleanupConn = nil // Note: dialConn may be used to close the underlying network connection // but should not be used to perform I/O as that would interfere with SSH // (and also bypasses throttling). return &dialResult{ dialConn: dialConn, monitoredConn: monitoredConn, sshClient: result.sshClient, sshRequests: result.sshRequests, dialStats: dialStats}, nil }
func (sshClient *sshClient) handleTCPChannel( hostToConnect string, portToConnect int, newChannel ssh.NewChannel) { isWebServerPortForward := false config := sshClient.sshServer.support.Config if config.WebServerPortForwardAddress != "" { destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect)) if destination == config.WebServerPortForwardAddress { isWebServerPortForward = true if config.WebServerPortForwardRedirectAddress != "" { // Note: redirect format is validated when config is loaded host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress) port, _ := strconv.Atoi(portStr) hostToConnect = host portToConnect = port } } } if !isWebServerPortForward && !sshClient.isPortForwardPermitted( portForwardTypeTCP, hostToConnect, portToConnect) { sshClient.rejectNewChannel( newChannel, ssh.Prohibited, "port forward not permitted") return } var bytesUp, bytesDown int64 sshClient.openedPortForward(portForwardTypeTCP) defer func() { sshClient.closedPortForward( portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown)) }() // TOCTOU note: important to increment the port forward count (via // openPortForward) _before_ checking isPortForwardLimitExceeded // otherwise, the client could potentially consume excess resources // by initiating many port forwards concurrently. // TODO: close LRU connection (after successful Dial) instead of // rejecting new connection? if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded { // Close the oldest TCP port forward. CloseOldest() closes // the conn and the port forward's goroutine will complete // the cleanup asynchronously. // // Some known limitations: // // - Since CloseOldest() closes the upstream socket but does not // clean up all resources associated with the port forward. These // include the goroutine(s) relaying traffic as well as the SSH // channel. Closing the socket will interrupt the goroutines which // will then complete the cleanup. But, since the full cleanup is // asynchronous, there exists a possibility that a client can consume // more than max port forward resources -- just not upstream sockets. // // - An LRU list entry for this port forward is not added until // after the dial completes, but the port forward is counted // towards max limits. This means many dials in progress will // put established connections in jeopardy. // // - We're closing the oldest open connection _before_ successfully // dialing the new port forward. This means we are potentially // discarding a good connection to make way for a failed connection. // We cannot simply dial first and still maintain a limit on // resources used, so to address this we'd need to add some // accounting for connections still establishing. sshClient.tcpPortForwardLRU.CloseOldest() log.WithContextFields( LogFields{ "maxCount": maxCount, }).Debug("closed LRU TCP port forward") } // Dial the target remote address. This is done in a goroutine to // ensure the shutdown signal is handled immediately. remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect) log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing") type dialTcpResult struct { conn net.Conn err error } resultChannel := make(chan *dialTcpResult, 1) go func() { // TODO: on EADDRNOTAVAIL, temporarily suspend new clients // TODO: IPv6 support conn, err := net.DialTimeout( "tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT) resultChannel <- &dialTcpResult{conn, err} }() var result *dialTcpResult select { case result = <-resultChannel: case <-sshClient.stopBroadcast: // Note: may leave dial in progress (TODO: use DialContext to cancel) return } if result.err != nil { sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error()) return } // The upstream TCP port forward connection has been established. Schedule // some cleanup and notify the SSH client that the channel is accepted. fwdConn := result.conn defer fwdConn.Close() fwdChannel, requests, err := newChannel.Accept() if err != nil { log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed") return } go ssh.DiscardRequests(requests) defer fwdChannel.Close() // ActivityMonitoredConn monitors the TCP port forward I/O and updates // its LRU status. ActivityMonitoredConn also times out I/O on the port // forward if both reads and writes have been idle for the specified // duration. lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn) defer lruEntry.Remove() fwdConn, err = common.NewActivityMonitoredConn( fwdConn, sshClient.idleTCPPortForwardTimeout(), true, lruEntry) if result.err != nil { log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed") return } // Relay channel to forwarded connection. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying") // TODO: relay errors to fwdChannel.Stderr()? relayWaitGroup := new(sync.WaitGroup) relayWaitGroup.Add(1) go func() { defer relayWaitGroup.Done() // io.Copy allocates a 32K temporary buffer, and each port forward relay uses // two of these buffers; using io.CopyBuffer with a smaller buffer reduces the // overall memory footprint. bytes, err := io.CopyBuffer( fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE)) atomic.AddInt64(&bytesDown, bytes) if err != nil && err != io.EOF { // Debug since errors such as "connection reset by peer" occur during normal operation log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed") } // Interrupt upstream io.Copy when downstream is shutting down. // TODO: this is done to quickly cleanup the port forward when // fwdConn has a read timeout, but is it clean -- upstream may still // be flowing? fwdChannel.Close() }() bytes, err := io.CopyBuffer( fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE)) atomic.AddInt64(&bytesUp, bytes) if err != nil && err != io.EOF { log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed") } // Shutdown special case: fwdChannel will be closed and return EOF when // the SSH connection is closed, but we need to explicitly close fwdConn // to interrupt the downstream io.Copy, which may be blocked on a // fwdConn.Read(). fwdConn.Close() relayWaitGroup.Wait() log.WithContextFields( LogFields{ "remoteAddr": remoteAddr, "bytesUp": atomic.LoadInt64(&bytesUp), "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting") }
func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) { sshServer.registerAcceptedClient(tunnelProtocol) defer sshServer.unregisterAcceptedClient(tunnelProtocol) geoIPData := sshServer.support.GeoIPService.Lookup( common.IPAddressFromAddr(clientConn.RemoteAddr())) sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData) // Set initial traffic rules, pre-handshake, based on currently known info. sshClient.setTrafficRules() // Wrap the base client connection with an ActivityMonitoredConn which will // terminate the connection if no data is received before the deadline. This // timeout is in effect for the entire duration of the SSH connection. Clients // must actively use the connection or send SSH keep alive requests to keep // the connection active. Writes are not considered reliable activity indicators // due to buffering. activityConn, err := common.NewActivityMonitoredConn( clientConn, SSH_CONNECTION_READ_DEADLINE, false, nil) if err != nil { clientConn.Close() log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed") return } clientConn = activityConn // Further wrap the connection in a rate limiting ThrottledConn. throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits()) clientConn = throttledConn // Run the initial [obfuscated] SSH handshake in a goroutine so we can both // respect shutdownBroadcast and implement a specific handshake timeout. // The timeout is to reclaim network resources in case the handshake takes // too long. type sshNewServerConnResult struct { conn net.Conn sshConn *ssh.ServerConn channels <-chan ssh.NewChannel requests <-chan *ssh.Request err error } resultChannel := make(chan *sshNewServerConnResult, 2) if SSH_HANDSHAKE_TIMEOUT > 0 { time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() { resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")} }) } go func(conn net.Conn) { sshServerConfig := &ssh.ServerConfig{ PasswordCallback: sshClient.passwordCallback, AuthLogCallback: sshClient.authLogCallback, ServerVersion: sshServer.support.Config.SSHServerVersion, } sshServerConfig.AddHostKey(sshServer.sshHostKey) result := &sshNewServerConnResult{} // Wrap the connection in an SSH deobfuscator when required. if common.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) { // Note: NewObfuscatedSshConn blocks on network I/O // TODO: ensure this won't block shutdown conn, result.err = psiphon.NewObfuscatedSshConn( psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.support.Config.ObfuscatedSSHKey) if result.err != nil { result.err = common.ContextError(result.err) } } if result.err == nil { result.sshConn, result.channels, result.requests, result.err = ssh.NewServerConn(conn, sshServerConfig) } resultChannel <- result }(clientConn) var result *sshNewServerConnResult select { case result = <-resultChannel: case <-sshServer.shutdownBroadcast: // Close() will interrupt an ongoing handshake // TODO: wait for goroutine to exit before returning? clientConn.Close() return } if result.err != nil { clientConn.Close() // This is a Debug log due to noise. The handshake often fails due to I/O // errors as clients frequently interrupt connections in progress when // client-side load balancing completes a connection to a different server. log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed") return } sshClient.Lock() sshClient.sshConn = result.sshConn sshClient.activityConn = activityConn sshClient.throttledConn = throttledConn sshClient.Unlock() if !sshServer.registerEstablishedClient(sshClient) { clientConn.Close() log.WithContext().Warning("register failed") return } defer sshServer.unregisterEstablishedClient(sshClient.sessionID) sshClient.runClient(result.channels, result.requests) // Note: sshServer.unregisterClient calls sshClient.Close(), // which also closes underlying transport Conn. }
func (mux *udpPortForwardMultiplexer) run() { // In a loop, read udpgw messages from the client to this channel. Each message is // a UDP packet to send upstream either via a new port forward, or on an existing // port forward. // // A goroutine is run to read downstream packets for each UDP port forward. All read // packets are encapsulated in udpgw protocol and sent down the channel to the client. // // When the client disconnects or the server shuts down, the channel will close and // readUdpgwMessage will exit with EOF. // Recover from and log any unexpected panics caused by udpgw input handling bugs. // Note: this covers the run() goroutine only and not relayDownstream() goroutines. defer func() { if e := recover(); e != nil { err := common.ContextError( fmt.Errorf( "udpPortForwardMultiplexer panic: %s: %s", e, debug.Stack())) log.WithContextFields(LogFields{"error": err}).Warning("run failed") } }() buffer := make([]byte, udpgwProtocolMaxMessageSize) for { // Note: message.packet points to the reusable memory in "buffer". // Each readUdpgwMessage call will overwrite the last message.packet. message, err := readUdpgwMessage(mux.sshChannel, buffer) if err != nil { if err != io.EOF { log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed") } break } mux.portForwardsMutex.Lock() portForward := mux.portForwards[message.connID] mux.portForwardsMutex.Unlock() if portForward != nil && message.discardExistingConn { // The port forward's goroutine will complete cleanup, including // tallying stats and calling sshClient.closedPortForward. // portForward.conn.Close() will signal this shutdown. // TODO: wait for goroutine to exit before proceeding? portForward.conn.Close() portForward = nil } if portForward != nil { // Verify that portForward remote address matches latest message if 0 != bytes.Compare(portForward.remoteIP, message.remoteIP) || portForward.remotePort != message.remotePort { log.WithContext().Warning("UDP port forward remote address mismatch") continue } } else { // Create a new port forward dialIP := net.IP(message.remoteIP) dialPort := int(message.remotePort) // Transparent DNS forwarding if message.forwardDNS { dialIP = mux.sshClient.sshServer.support.DNSResolver.Get() dialPort = DNS_RESOLVER_PORT } if !mux.sshClient.isPortForwardPermitted( portForwardTypeUDP, dialIP.String(), int(message.remotePort)) { // The udpgw protocol has no error response, so // we just discard the message and read another. continue } mux.sshClient.openedPortForward(portForwardTypeUDP) // Note: can't defer sshClient.closedPortForward() here // TOCTOU note: important to increment the port forward count (via // openPortForward) _before_ checking isPortForwardLimitExceeded if maxCount, exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded { // Close the oldest UDP port forward. CloseOldest() closes // the conn and the port forward's goroutine will complete // the cleanup asynchronously. // // See LRU comment in handleTCPChannel() for a known // limitations regarding CloseOldest(). mux.portForwardLRU.CloseOldest() log.WithContextFields( LogFields{ "maxCount": maxCount, }).Debug("closed LRU UDP port forward") } log.WithContextFields( LogFields{ "remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort), "connID": message.connID}).Debug("dialing") // TODO: on EADDRNOTAVAIL, temporarily suspend new clients udpConn, err := net.DialUDP( "udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort}) if err != nil { mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0) log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed") continue } // ActivityMonitoredConn monitors the TCP port forward I/O and updates // its LRU status. ActivityMonitoredConn also times out I/O on the port // forward if both reads and writes have been idle for the specified // duration. lruEntry := mux.portForwardLRU.Add(udpConn) conn, err := common.NewActivityMonitoredConn( udpConn, mux.sshClient.idleUDPPortForwardTimeout(), true, lruEntry) if err != nil { lruEntry.Remove() mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0) log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed") continue } portForward = &udpPortForward{ connID: message.connID, preambleSize: message.preambleSize, remoteIP: message.remoteIP, remotePort: message.remotePort, conn: conn, lruEntry: lruEntry, bytesUp: 0, bytesDown: 0, mux: mux, } mux.portForwardsMutex.Lock() mux.portForwards[portForward.connID] = portForward mux.portForwardsMutex.Unlock() // relayDownstream will call sshClient.closedPortForward() mux.relayWaitGroup.Add(1) go portForward.relayDownstream() } // Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP) _, err = portForward.conn.Write(message.packet) if err != nil { // Debug since errors such as "write: operation not permitted" occur during normal operation log.WithContextFields(LogFields{"error": err}).Debug("upstream UDP relay failed") // The port forward's goroutine will complete cleanup portForward.conn.Close() } portForward.lruEntry.Touch() atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet))) } // Cleanup all UDP port forward workers when exiting mux.portForwardsMutex.Lock() for _, portForward := range mux.portForwards { // The port forward's goroutine will complete cleanup portForward.conn.Close() } mux.portForwardsMutex.Unlock() mux.relayWaitGroup.Wait() }