// dialSsh is a helper that builds the transport layers and establishes the SSH connection.
// When additional dial configuration is used (an upstream proxy and/or a meek
// protocol), TunnelDialStats are recorded and included in the result.
//
// On success, the returned dialResult holds:
//   - dialConn: the base dial conn. This is the value to be removed from
//     pendingConns; additional layering (ActivityMonitoredConn, ThrottledConn
//     and, when selected, obfuscated SSH) is applied on top of it.
//   - monitoredConn: the ActivityMonitoredConn layer wrapping dialConn.
//   - sshClient, sshRequests: the established SSH client and its request channel.
//   - dialStats: nil unless an upstream proxy or meek was used.
//
// On failure, any base conn that was dialed is closed and removed from
// pendingConns, and a nil result is returned along with the error.
func dialSsh(
	config *Config,
	pendingConns *common.Conns,
	serverEntry *protocol.ServerEntry,
	selectedProtocol,
	sessionId string) (*dialResult, error) {

	// The meek protocols tunnel obfuscated SSH. Obfuscated SSH is layered on top of SSH.
	// So depending on which protocol is used, multiple layers are initialized.

	useObfuscatedSsh := false
	var directTCPDialAddress string
	var meekConfig *MeekConfig
	var err error

	switch selectedProtocol {
	case protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH:
		useObfuscatedSsh = true
		directTCPDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)

	case protocol.TUNNEL_PROTOCOL_SSH:
		directTCPDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshPort)

	default:
		// Any other protocol is treated as a meek protocol.
		useObfuscatedSsh = true
		meekConfig, err = initMeekConfig(config, serverEntry, selectedProtocol, sessionId)
		if err != nil {
			return nil, common.ContextError(err)
		}
	}

	NoticeConnectingServer(
		serverEntry.IpAddress,
		serverEntry.Region,
		selectedProtocol,
		directTCPDialAddress,
		meekConfig)

	// Use an asynchronous callback to record the resolved IP address when
	// dialing a domain name. Note that DialMeek doesn't immediately
	// establish any HTTPS connections, so the resolved IP address won't be
	// reported until during/after ssh session establishment (the ssh traffic
	// is meek payload). So don't Load() the IP address value until after that
	// has completed to ensure a result.
	var resolvedIPAddress atomic.Value
	resolvedIPAddress.Store("")
	setResolvedIPAddress := func(IPAddress string) {
		resolvedIPAddress.Store(IPAddress)
	}

	// Create the base transport: meek or direct connection
	dialConfig := &DialConfig{
		UpstreamProxyUrl:              config.UpstreamProxyUrl,
		UpstreamProxyCustomHeaders:    config.UpstreamProxyCustomHeaders,
		ConnectTimeout:                time.Duration(*config.TunnelConnectTimeoutSeconds) * time.Second,
		PendingConns:                  pendingConns,
		DeviceBinder:                  config.DeviceBinder,
		DnsServerGetter:               config.DnsServerGetter,
		UseIndistinguishableTLS:       config.UseIndistinguishableTLS,
		TrustedCACertificatesFilename: config.TrustedCACertificatesFilename,
		DeviceRegion:                  config.DeviceRegion,
		ResolvedIPCallback:            setResolvedIPAddress,
	}
	var dialConn net.Conn
	if meekConfig != nil {
		dialConn, err = DialMeek(meekConfig, dialConfig)
		if err != nil {
			return nil, common.ContextError(err)
		}
	} else {
		dialConn, err = DialTCP(directTCPDialAddress, dialConfig)
		if err != nil {
			return nil, common.ContextError(err)
		}
	}

	// From here on, every error return must release the base conn. The
	// deferred cleanup is disarmed (cleanupConn = nil) just before the
	// successful return.
	cleanupConn := dialConn
	defer func() {
		// Cleanup on error
		if cleanupConn != nil {
			cleanupConn.Close()
			pendingConns.Remove(cleanupConn)
		}
	}()

	// Activity monitoring is used to measure tunnel duration
	monitoredConn, err := common.NewActivityMonitoredConn(dialConn, 0, false, nil, nil)
	if err != nil {
		return nil, common.ContextError(err)
	}

	// Apply throttling (if configured)
	throttledConn := common.NewThrottledConn(monitoredConn, config.RateLimits)

	// Add obfuscated SSH layer
	var sshConn net.Conn = throttledConn
	if useObfuscatedSsh {
		sshConn, err = common.NewObfuscatedSshConn(
			common.OBFUSCATION_CONN_MODE_CLIENT, throttledConn, serverEntry.SshObfuscatedKey)
		if err != nil {
			return nil, common.ContextError(err)
		}
	}

	// Now establish the SSH session over the conn transport
	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
	if err != nil {
		return nil, common.ContextError(err)
	}
	sshCertChecker := &ssh.CertChecker{
		// The server's host key must match the key pinned in the server entry.
		HostKeyFallback: func(addr string, remote net.Addr, publicKey ssh.PublicKey) error {
			if !bytes.Equal(expectedPublicKey, publicKey.Marshal()) {
				return common.ContextError(errors.New("unexpected host public key"))
			}
			return nil
		},
	}

	// The SSH password field carries a JSON payload with the session ID,
	// the server entry's SSH password, and the client capabilities.
	sshPasswordPayload := &protocol.SSHPasswordPayload{
		SessionId:          sessionId,
		SshPassword:        serverEntry.SshPassword,
		ClientCapabilities: []string{protocol.CLIENT_CAPABILITY_SERVER_REQUESTS},
	}

	payload, err := json.Marshal(sshPasswordPayload)
	if err != nil {
		return nil, common.ContextError(err)
	}
	sshClientConfig := &ssh.ClientConfig{
		User: serverEntry.SshUsername,
		Auth: []ssh.AuthMethod{
			ssh.Password(string(payload)),
		},
		HostKeyCallback: sshCertChecker.CheckHostKey,
	}

	// The ssh session establishment (via ssh.NewClientConn) is wrapped
	// in a timeout to ensure it won't hang. We've encountered firewalls
	// that allow the TCP handshake to complete but then send a RST to the
	// server-side and nothing to the client-side, and if that happens
	// while ssh.NewClientConn is reading, it may wait forever. On timeout
	// this function returns an error and the deferred cleanup closes the
	// conn, which interrupts ssh.NewClientConn.
	// Note: TCP handshake timeouts are provided by TCPConn, and session
	// timeouts *after* ssh establishment are provided by the ssh keep alive
	// in operate tunnel.
	// TODO: adjust the timeout to account for time-elapsed-from-start

	type sshNewClientResult struct {
		sshClient   *ssh.Client
		sshRequests <-chan *ssh.Request
		err         error
	}

	// Buffered for both potential senders (timeout and dial goroutine) so
	// that neither blocks after this function has returned.
	resultChannel := make(chan *sshNewClientResult, 2)
	if *config.TunnelConnectTimeoutSeconds > 0 {
		timeoutTimer := time.AfterFunc(
			time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second,
			func() {
				resultChannel <- &sshNewClientResult{nil, nil, errors.New("ssh dial timeout")}
			})
		// Once a result has been received the timer is no longer needed;
		// stop it so it doesn't stay pending and fire after we return.
		defer timeoutTimer.Stop()
	}

	go func() {
		// The following is adapted from ssh.Dial(), here using a custom conn
		// The sshAddress is passed through to host key verification callbacks; we don't use it.
		sshAddress := ""
		sshClientConn, sshChannels, sshRequests, err := ssh.NewClientConn(
			sshConn, sshAddress, sshClientConfig)
		var sshClient *ssh.Client
		if err == nil {
			// The requests channel is returned to the caller instead of
			// being handed to the ssh.Client.
			sshClient = ssh.NewClient(sshClientConn, sshChannels, nil)
		}
		resultChannel <- &sshNewClientResult{sshClient, sshRequests, err}
	}()

	// Whichever of the timeout or the dial goroutine reports first wins.
	result := <-resultChannel
	if result.err != nil {
		return nil, common.ContextError(result.err)
	}

	var dialStats *TunnelDialStats

	if dialConfig.UpstreamProxyUrl != "" || meekConfig != nil {
		dialStats = &TunnelDialStats{}

		if dialConfig.UpstreamProxyUrl != "" {

			// Note: UpstreamProxyUrl should have parsed correctly in the dial,
			// so a parse error here is deliberately ignored and the proxy
			// type is simply left unset.
			proxyURL, err := url.Parse(dialConfig.UpstreamProxyUrl)
			if err == nil {
				dialStats.UpstreamProxyType = proxyURL.Scheme
			}

			// Always a non-nil slice, even when there are no custom headers.
			dialStats.UpstreamProxyCustomHeaderNames = make(
				[]string, 0, len(dialConfig.UpstreamProxyCustomHeaders))
			for name := range dialConfig.UpstreamProxyCustomHeaders {
				dialStats.UpstreamProxyCustomHeaderNames = append(dialStats.UpstreamProxyCustomHeaderNames, name)
			}
		}

		if meekConfig != nil {
			dialStats.MeekDialAddress = meekConfig.DialAddress
			// Safe to Load() now: the SSH session is established, see the
			// resolvedIPAddress comment above.
			dialStats.MeekResolvedIPAddress = resolvedIPAddress.Load().(string)
			dialStats.MeekSNIServerName = meekConfig.SNIServerName
			dialStats.MeekHostHeader = meekConfig.HostHeader
			dialStats.MeekTransformedHostName = meekConfig.TransformedHostName
		}

		NoticeConnectedTunnelDialStats(serverEntry.IpAddress, dialStats)
	}

	// Success: disarm the deferred cleanup so the conn stays open.
	cleanupConn = nil

	// Note: dialConn may be used to close the underlying network connection
	// but should not be used to perform I/O as that would interfere with SSH
	// (and also bypasses throttling).

	return &dialResult{
		dialConn:      dialConn,
		monitoredConn: monitoredConn,
		sshClient:     result.sshClient,
		sshRequests:   result.sshRequests,
		dialStats:     dialStats,
	}, nil
}
// handleTCPChannel services a client request to open a TCP port forward to
// hostToConnect:portToConnect over newChannel.
//
// It applies the web server port forward special case, checks that the port
// forward is permitted, enforces the port forward limit (closing the LRU
// port forward when exceeded), dials the destination, accepts the SSH
// channel, and then relays traffic in both directions until either side
// closes. The channel is rejected when the port forward is not permitted or
// the dial fails. The function blocks for the lifetime of the port forward.
func (sshClient *sshClient) handleTCPChannel(
	hostToConnect string,
	portToConnect int,
	newChannel ssh.NewChannel) {

	// A request for the configured web server port forward address bypasses
	// the permission check below and may be redirected to another address.
	isWebServerPortForward := false
	config := sshClient.sshServer.support.Config
	if config.WebServerPortForwardAddress != "" {
		destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
		if destination == config.WebServerPortForwardAddress {
			isWebServerPortForward = true
			if config.WebServerPortForwardRedirectAddress != "" {
				// Note: redirect format is validated when config is loaded,
				// so the parse errors are deliberately ignored here.
				host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
				port, _ := strconv.Atoi(portStr)
				hostToConnect = host
				portToConnect = port
			}
		}
	}

	if !isWebServerPortForward && !sshClient.isPortForwardPermitted(
		portForwardTypeTCP, hostToConnect, portToConnect) {

		sshClient.rejectNewChannel(
			newChannel, ssh.Prohibited, "port forward not permitted")
		return
	}

	// The byte counters are updated by both relay directions and reported
	// when this port forward is closed.
	var bytesUp, bytesDown int64
	sshClient.openedPortForward(portForwardTypeTCP)
	defer func() {
		sshClient.closedPortForward(
			portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
	}()

	// TOCTOU note: important to increment the port forward count (via
	// openedPortForward) _before_ checking isPortForwardLimitExceeded
	// otherwise, the client could potentially consume excess resources
	// by initiating many port forwards concurrently.
	if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {

		// Close the oldest TCP port forward. CloseOldest() closes
		// the conn and the port forward's goroutine will complete
		// the cleanup asynchronously.
		//
		// Some known limitations:
		//
		// - CloseOldest() closes the upstream socket but does not
		//   clean up all resources associated with the port forward. These
		//   include the goroutine(s) relaying traffic as well as the SSH
		//   channel. Closing the socket will interrupt the goroutines which
		//   will then complete the cleanup. But, since the full cleanup is
		//   asynchronous, there exists a possibility that a client can consume
		//   more than max port forward resources -- just not upstream sockets.
		//
		// - An LRU list entry for this port forward is not added until
		//   after the dial completes, but the port forward is counted
		//   towards max limits. This means many dials in progress will
		//   put established connections in jeopardy.
		//
		// - We're closing the oldest open connection _before_ successfully
		//   dialing the new port forward. This means we are potentially
		//   discarding a good connection to make way for a failed connection.
		//   We cannot simply dial first and still maintain a limit on
		//   resources used, so to address this we'd need to add some
		//   accounting for connections still establishing.

		sshClient.tcpPortForwardLRU.CloseOldest()

		log.WithContextFields(
			LogFields{
				"maxCount": maxCount,
			}).Debug("closed LRU TCP port forward")
	}

	// Dial the target remote address. This is done in a goroutine to
	// ensure the shutdown signal is handled immediately.

	remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)

	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")

	type dialTcpResult struct {
		conn net.Conn
		err  error
	}

	// Buffered so the dial goroutine never blocks on send, even when
	// nobody is receiving any more.
	resultChannel := make(chan *dialTcpResult, 1)

	go func() {
		// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
		// TODO: IPv6 support
		conn, err := net.DialTimeout(
			"tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
		resultChannel <- &dialTcpResult{conn, err}
	}()

	var result *dialTcpResult
	select {
	case result = <-resultChannel:
	case <-sshClient.stopBroadcast:
		// The dial may still be in progress and cannot be cancelled here
		// (TODO: use DialContext to cancel). Reap its result in the
		// background so that a conn established after we return is closed
		// rather than leaked. The dial is bounded by
		// SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT, so this goroutine terminates.
		go func() {
			if late := <-resultChannel; late.err == nil {
				late.conn.Close()
			}
		}()
		return
	}

	if result.err != nil {
		sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error())
		return
	}

	// The upstream TCP port forward connection has been established. Schedule
	// some cleanup and notify the SSH client that the channel is accepted.

	fwdConn := result.conn
	defer fwdConn.Close()

	fwdChannel, requests, err := newChannel.Accept()
	if err != nil {
		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
		return
	}
	go ssh.DiscardRequests(requests)
	defer fwdChannel.Close()

	// ActivityMonitoredConn monitors the TCP port forward I/O and updates
	// its LRU status. ActivityMonitoredConn also times out I/O on the port
	// forward if both reads and writes have been idle for the specified
	// duration.

	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
	defer lruEntry.Remove()

	monitoredConn, err := common.NewActivityMonitoredConn(
		fwdConn,
		sshClient.idleTCPPortForwardTimeout(),
		true,
		lruEntry)
	if err != nil {
		// The deferred calls above close the dialed conn and the channel
		// and remove the LRU entry.
		log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
		return
	}
	// All relay I/O goes through the monitored wrapper from here on.
	fwdConn = monitoredConn

	// Relay channel to forwarded connection.

	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")

	// TODO: relay errors to fwdChannel.Stderr()?
	relayWaitGroup := new(sync.WaitGroup)
	relayWaitGroup.Add(1)
	go func() {
		defer relayWaitGroup.Done()
		// io.Copy allocates a 32K temporary buffer, and each port forward relay uses
		// two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
		// overall memory footprint.
		n, err := io.CopyBuffer(
			fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
		atomic.AddInt64(&bytesDown, n)
		if err != nil && err != io.EOF {
			// Debug since errors such as "connection reset by peer" occur during normal operation
			log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
		}
		// Interrupt upstream io.Copy when downstream is shutting down.
		// TODO: this is done to quickly cleanup the port forward when
		// fwdConn has a read timeout, but is it clean -- upstream may still
		// be flowing?
		fwdChannel.Close()
	}()
	n, err := io.CopyBuffer(
		fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
	atomic.AddInt64(&bytesUp, n)
	if err != nil && err != io.EOF {
		log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
	}
	// Shutdown special case: fwdChannel will be closed and return EOF when
	// the SSH connection is closed, but we need to explicitly close fwdConn
	// to interrupt the downstream io.Copy, which may be blocked on a
	// fwdConn.Read().
	fwdConn.Close()

	relayWaitGroup.Wait()

	log.WithContextFields(
		LogFields{
			"remoteAddr": remoteAddr,
			"bytesUp":    atomic.LoadInt64(&bytesUp),
			"bytesDown":  atomic.LoadInt64(&bytesDown)}).Debug("exiting")
}
// handleClient runs the full lifecycle of one accepted client connection for
// the given tunnel protocol: it wraps clientConn with activity monitoring and
// throttling, performs the [obfuscated] SSH handshake subject to a timeout
// and the server shutdown signal, registers the established client, and then
// blocks in sshClient.runClient until the client is done.
//
// clientConn is closed on every failure path within this function.
func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {

	sshServer.registerAcceptedClient(tunnelProtocol)
	defer sshServer.unregisterAcceptedClient(tunnelProtocol)

	geoIPData := sshServer.support.GeoIPService.Lookup(
		common.IPAddressFromAddr(clientConn.RemoteAddr()))

	sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)

	// Set initial traffic rules, pre-handshake, based on currently known info.
	sshClient.setTrafficRules()

	// Wrap the base client connection with an ActivityMonitoredConn which will
	// terminate the connection if no data is received before the deadline. This
	// timeout is in effect for the entire duration of the SSH connection. Clients
	// must actively use the connection or send SSH keep alive requests to keep
	// the connection active. Writes are not considered reliable activity indicators
	// due to buffering.

	activityConn, err := common.NewActivityMonitoredConn(
		clientConn,
		SSH_CONNECTION_READ_DEADLINE,
		false,
		nil)
	if err != nil {
		clientConn.Close()
		log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
		return
	}
	clientConn = activityConn

	// Further wrap the connection in a rate limiting ThrottledConn.

	throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits())
	clientConn = throttledConn

	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
	// respect shutdownBroadcast and implement a specific handshake timeout.
	// The timeout is to reclaim network resources in case the handshake takes
	// too long.

	type sshNewServerConnResult struct {
		conn     net.Conn
		sshConn  *ssh.ServerConn
		channels <-chan ssh.NewChannel
		requests <-chan *ssh.Request
		err      error
	}

	// Buffered for both potential senders (timeout and handshake goroutine)
	// so that neither blocks when its result is never received.
	resultChannel := make(chan *sshNewServerConnResult, 2)

	var handshakeTimer *time.Timer
	if SSH_HANDSHAKE_TIMEOUT > 0 {
		handshakeTimer = time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
			resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
		})
	}

	go func(conn net.Conn) {
		sshServerConfig := &ssh.ServerConfig{
			PasswordCallback: sshClient.passwordCallback,
			AuthLogCallback:  sshClient.authLogCallback,
			ServerVersion:    sshServer.support.Config.SSHServerVersion,
		}
		sshServerConfig.AddHostKey(sshServer.sshHostKey)

		result := &sshNewServerConnResult{}

		// Wrap the connection in an SSH deobfuscator when required.

		if common.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
			// Note: NewObfuscatedSshConn blocks on network I/O
			// TODO: ensure this won't block shutdown
			conn, result.err = psiphon.NewObfuscatedSshConn(
				psiphon.OBFUSCATION_CONN_MODE_SERVER,
				conn,
				sshServer.support.Config.ObfuscatedSSHKey)
			if result.err != nil {
				result.err = common.ContextError(result.err)
			}
		}

		if result.err == nil {
			result.sshConn, result.channels, result.requests, result.err =
				ssh.NewServerConn(conn, sshServerConfig)
		}

		resultChannel <- result

	}(clientConn)

	// Wait for the first of: handshake result, handshake timeout (also
	// delivered on resultChannel), or server shutdown.
	var result *sshNewServerConnResult
	shuttingDown := false
	select {
	case result = <-resultChannel:
	case <-sshServer.shutdownBroadcast:
		shuttingDown = true
	}

	// The handshake phase is over either way. This function goes on to block
	// for the lifetime of the client, so stop the timer now rather than
	// letting it fire later and send a result nobody will read.
	if handshakeTimer != nil {
		handshakeTimer.Stop()
	}

	if shuttingDown {
		// Close() will interrupt an ongoing handshake
		// TODO: wait for goroutine to exit before returning?
		clientConn.Close()
		return
	}

	if result.err != nil {
		clientConn.Close()
		// This is a Debug log due to noise. The handshake often fails due to I/O
		// errors as clients frequently interrupt connections in progress when
		// client-side load balancing completes a connection to a different server.
		log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed")
		return
	}

	sshClient.Lock()
	sshClient.sshConn = result.sshConn
	sshClient.activityConn = activityConn
	sshClient.throttledConn = throttledConn
	sshClient.Unlock()

	if !sshServer.registerEstablishedClient(sshClient) {
		clientConn.Close()
		log.WithContext().Warning("register failed")
		return
	}
	defer sshServer.unregisterEstablishedClient(sshClient.sessionID)

	// Blocks until the client session ends.
	sshClient.runClient(result.channels, result.requests)

	// Note: unregistering the established client (deferred above) is
	// expected to call sshClient.Close(), which also closes the underlying
	// transport Conn -- TODO confirm against unregisterEstablishedClient.
}
// run is the main loop of the udpgw port forward multiplexer. It reads udpgw
// messages from the client's SSH channel until a read fails, forwarding each
// message's packet upstream on an existing or newly created UDP port forward.
// On exit it closes all remaining port forward conns and waits for their
// relayDownstream goroutines to finish.
func (mux *udpPortForwardMultiplexer) run() {

	// In a loop, read udpgw messages from the client to this channel. Each message is
	// a UDP packet to send upstream either via a new port forward, or on an existing
	// port forward.
	//
	// A goroutine is run to read downstream packets for each UDP port forward. All read
	// packets are encapsulated in udpgw protocol and sent down the channel to the client.
	//
	// When the client disconnects or the server shuts down, the channel will close and
	// readUdpgwMessage will exit with EOF.

	// Recover from and log any unexpected panics caused by udpgw input handling bugs.
	// Note: this covers the run() goroutine only and not relayDownstream() goroutines.
	defer func() {
		if e := recover(); e != nil {
			// %v rather than %s: the recovered value may be of any type.
			err := common.ContextError(
				fmt.Errorf(
					"udpPortForwardMultiplexer panic: %v: %s", e, debug.Stack()))
			log.WithContextFields(LogFields{"error": err}).Warning("run failed")
		}
	}()

	buffer := make([]byte, udpgwProtocolMaxMessageSize)
	for {
		// Note: message.packet points to the reusable memory in "buffer".
		// Each readUdpgwMessage call will overwrite the last message.packet.
		message, err := readUdpgwMessage(mux.sshChannel, buffer)
		if err != nil {
			if err != io.EOF {
				log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed")
			}
			break
		}

		mux.portForwardsMutex.Lock()
		portForward := mux.portForwards[message.connID]
		mux.portForwardsMutex.Unlock()

		if portForward != nil && message.discardExistingConn {
			// The port forward's goroutine will complete cleanup, including
			// tallying stats and calling sshClient.closedPortForward.
			// portForward.conn.Close() will signal this shutdown.
			// TODO: wait for goroutine to exit before proceeding?
			portForward.conn.Close()
			portForward = nil
		}

		if portForward != nil {

			// Verify that portForward remote address matches latest message

			if !bytes.Equal(portForward.remoteIP, message.remoteIP) ||
				portForward.remotePort != message.remotePort {

				log.WithContext().Warning("UDP port forward remote address mismatch")
				continue
			}

		} else {

			// Create a new port forward

			dialIP := net.IP(message.remoteIP)
			dialPort := int(message.remotePort)

			// Transparent DNS forwarding
			if message.forwardDNS {
				dialIP = mux.sshClient.sshServer.support.DNSResolver.Get()
				dialPort = DNS_RESOLVER_PORT
			}

			// NOTE(review): the permission check uses the client-specified
			// remotePort, not dialPort, even when forwardDNS has redirected
			// the dial to DNS_RESOLVER_PORT -- confirm this is intended.
			if !mux.sshClient.isPortForwardPermitted(
				portForwardTypeUDP, dialIP.String(), int(message.remotePort)) {
				// The udpgw protocol has no error response, so
				// we just discard the message and read another.
				continue
			}

			mux.sshClient.openedPortForward(portForwardTypeUDP)
			// Note: can't defer sshClient.closedPortForward() here, so each
			// failure path below calls it explicitly.

			// TOCTOU note: important to increment the port forward count (via
			// openedPortForward) _before_ checking isPortForwardLimitExceeded
			if maxCount, exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {

				// Close the oldest UDP port forward. CloseOldest() closes
				// the conn and the port forward's goroutine will complete
				// the cleanup asynchronously.
				//
				// See the LRU comment in handleTCPChannel() for known
				// limitations regarding CloseOldest().
				mux.portForwardLRU.CloseOldest()

				log.WithContextFields(
					LogFields{
						"maxCount": maxCount,
					}).Debug("closed LRU UDP port forward")
			}

			log.WithContextFields(
				LogFields{
					"remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort),
					"connID":     message.connID}).Debug("dialing")

			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
			udpConn, err := net.DialUDP(
				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
			if err != nil {
				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)
				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
				continue
			}

			// ActivityMonitoredConn monitors the UDP port forward I/O and updates
			// its LRU status. ActivityMonitoredConn also times out I/O on the port
			// forward if both reads and writes have been idle for the specified
			// duration.

			lruEntry := mux.portForwardLRU.Add(udpConn)

			conn, err := common.NewActivityMonitoredConn(
				udpConn,
				mux.sshClient.idleUDPPortForwardTimeout(),
				true,
				lruEntry)
			if err != nil {
				lruEntry.Remove()
				// No relayDownstream goroutine exists yet to clean up this
				// port forward, so close the dialed socket here.
				udpConn.Close()
				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)
				log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
				continue
			}

			portForward = &udpPortForward{
				connID:       message.connID,
				preambleSize: message.preambleSize,
				remoteIP:     message.remoteIP,
				remotePort:   message.remotePort,
				conn:         conn,
				lruEntry:     lruEntry,
				bytesUp:      0,
				bytesDown:    0,
				mux:          mux,
			}
			mux.portForwardsMutex.Lock()
			mux.portForwards[portForward.connID] = portForward
			mux.portForwardsMutex.Unlock()

			// relayDownstream will call sshClient.closedPortForward()
			mux.relayWaitGroup.Add(1)
			go portForward.relayDownstream()
		}

		// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
		_, err = portForward.conn.Write(message.packet)
		if err != nil {
			// Debug since errors such as "write: operation not permitted" occur during normal operation
			log.WithContextFields(LogFields{"error": err}).Debug("upstream UDP relay failed")
			// The port forward's goroutine will complete cleanup
			portForward.conn.Close()
		}

		portForward.lruEntry.Touch()

		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
	}

	// Cleanup all UDP port forward workers when exiting

	mux.portForwardsMutex.Lock()
	for _, portForward := range mux.portForwards {
		// The port forward's goroutine will complete cleanup
		portForward.conn.Close()
	}
	mux.portForwardsMutex.Unlock()

	mux.relayWaitGroup.Wait()
}