// dialSsh is a helper that builds the transport layers and establishes the
// SSH connection. When an upstream proxy or meek is in use, TunnelDialStats
// are recorded, emitted via NoticeConnectedTunnelDialStats, and returned.
//
// On success, the returned dialResult holds:
//   - dialConn: the base dial conn (meek or direct TCP). This is the value to
//     be removed from pendingConns. It may be used to close the underlying
//     network connection but must not be used for I/O.
//   - monitoredConn: the activity monitored conn wrapping dialConn. Throttling
//     and, when applicable, the obfuscated SSH layer are stacked on top of it
//     before the result is handed to the SSH client.
//   - sshClient, sshRequests: the established SSH client and its incoming
//     request channel.
//   - dialStats: nil unless an upstream proxy or meek was used.
//
// On any error after the base dial succeeds, the dial conn is closed and
// removed from pendingConns before returning.
func dialSsh(
	config *Config,
	pendingConns *common.Conns,
	serverEntry *protocol.ServerEntry,
	selectedProtocol,
	sessionId string) (*dialResult, error) {

	// The meek protocols tunnel obfuscated SSH. Obfuscated SSH is layered on top of SSH.
	// So depending on which protocol is used, multiple layers are initialized.

	useObfuscatedSsh := false
	var directTCPDialAddress string
	var meekConfig *MeekConfig
	var err error

	switch selectedProtocol {
	case protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH:
		useObfuscatedSsh = true
		directTCPDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)

	case protocol.TUNNEL_PROTOCOL_SSH:
		directTCPDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshPort)

	default:
		// Every other protocol is dialed through meek, carrying obfuscated SSH.
		useObfuscatedSsh = true
		meekConfig, err = initMeekConfig(config, serverEntry, selectedProtocol, sessionId)
		if err != nil {
			return nil, common.ContextError(err)
		}
	}

	NoticeConnectingServer(
		serverEntry.IpAddress,
		serverEntry.Region,
		selectedProtocol,
		directTCPDialAddress,
		meekConfig)

	// Use an asynchronous callback to record the resolved IP address when
	// dialing a domain name. Note that DialMeek doesn't immediately
	// establish any HTTPS connections, so the resolved IP address won't be
	// reported until during/after ssh session establishment (the ssh traffic
	// is meek payload). So don't Load() the IP address value until after that
	// has completed to ensure a result.
	var resolvedIPAddress atomic.Value
	resolvedIPAddress.Store("")
	setResolvedIPAddress := func(IPAddress string) {
		resolvedIPAddress.Store(IPAddress)
	}

	// Create the base transport: meek or direct connection
	dialConfig := &DialConfig{
		UpstreamProxyUrl:              config.UpstreamProxyUrl,
		UpstreamProxyCustomHeaders:    config.UpstreamProxyCustomHeaders,
		ConnectTimeout:                time.Duration(*config.TunnelConnectTimeoutSeconds) * time.Second,
		PendingConns:                  pendingConns,
		DeviceBinder:                  config.DeviceBinder,
		DnsServerGetter:               config.DnsServerGetter,
		UseIndistinguishableTLS:       config.UseIndistinguishableTLS,
		TrustedCACertificatesFilename: config.TrustedCACertificatesFilename,
		DeviceRegion:                  config.DeviceRegion,
		ResolvedIPCallback:            setResolvedIPAddress,
	}
	var dialConn net.Conn
	if meekConfig != nil {
		dialConn, err = DialMeek(meekConfig, dialConfig)
	} else {
		dialConn, err = DialTCP(directTCPDialAddress, dialConfig)
	}
	if err != nil {
		return nil, common.ContextError(err)
	}

	// cleanupConn is cleared just before the successful return; every earlier
	// return leaves it set so the deferred function tears down the dial conn.
	cleanupConn := dialConn
	defer func() {
		// Cleanup on error
		if cleanupConn != nil {
			cleanupConn.Close()
			pendingConns.Remove(cleanupConn)
		}
	}()

	// Activity monitoring is used to measure tunnel duration
	monitoredConn, err := common.NewActivityMonitoredConn(dialConn, 0, false, nil, nil)
	if err != nil {
		return nil, common.ContextError(err)
	}

	// Apply throttling (if configured)
	throttledConn := common.NewThrottledConn(monitoredConn, config.RateLimits)

	// Add obfuscated SSH layer
	var sshConn net.Conn = throttledConn
	if useObfuscatedSsh {
		sshConn, err = common.NewObfuscatedSshConn(
			common.OBFUSCATION_CONN_MODE_CLIENT, throttledConn, serverEntry.SshObfuscatedKey)
		if err != nil {
			return nil, common.ContextError(err)
		}
	}

	// Now establish the SSH session over the conn transport. The server is
	// authenticated by comparing its host key with the key in the server entry.
	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
	if err != nil {
		return nil, common.ContextError(err)
	}
	sshCertChecker := &ssh.CertChecker{
		HostKeyFallback: func(addr string, remote net.Addr, publicKey ssh.PublicKey) error {
			if !bytes.Equal(expectedPublicKey, publicKey.Marshal()) {
				return common.ContextError(errors.New("unexpected host public key"))
			}
			return nil
		},
	}

	// The SSH password field carries a JSON payload with the session ID, the
	// actual SSH password, and the client capabilities.
	sshPasswordPayload := &protocol.SSHPasswordPayload{
		SessionId:          sessionId,
		SshPassword:        serverEntry.SshPassword,
		ClientCapabilities: []string{protocol.CLIENT_CAPABILITY_SERVER_REQUESTS},
	}

	payload, err := json.Marshal(sshPasswordPayload)
	if err != nil {
		return nil, common.ContextError(err)
	}
	sshClientConfig := &ssh.ClientConfig{
		User: serverEntry.SshUsername,
		Auth: []ssh.AuthMethod{
			ssh.Password(string(payload)),
		},
		HostKeyCallback: sshCertChecker.CheckHostKey,
	}

	// The ssh session establishment (via ssh.NewClientConn) is wrapped
	// in a timeout to ensure it won't hang. We've encountered firewalls
	// that allow the TCP handshake to complete but then send a RST to the
	// server-side and nothing to the client-side, and if that happens
	// while ssh.NewClientConn is reading, it may wait forever. The timeout
	// closes the conn, which interrupts it.
	// Note: TCP handshake timeouts are provided by TCPConn, and session
	// timeouts *after* ssh establishment are provided by the ssh keep alive
	// in operate tunnel.
	// TODO: adjust the timeout to account for time-elapsed-from-start

	type sshNewClientResult struct {
		sshClient   *ssh.Client
		sshRequests <-chan *ssh.Request
		err         error
	}

	// The buffer of 2 ensures that neither the timeout callback nor the
	// handshake goroutine can block on send, even though only the first
	// result is ever received.
	resultChannel := make(chan *sshNewClientResult, 2)
	if *config.TunnelConnectTimeoutSeconds > 0 {
		timeoutTimer := time.AfterFunc(
			time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second,
			func() {
				resultChannel <- &sshNewClientResult{err: errors.New("ssh dial timeout")}
			})
		// Once a result has been received the timeout is no longer needed;
		// stop the timer so it doesn't linger and fire after the handshake
		// has already completed.
		defer timeoutTimer.Stop()
	}

	go func() {
		// The following is adapted from ssh.Dial(), here using a custom conn
		// The sshAddress is passed through to host key verification callbacks; we don't use it.
		sshAddress := ""
		sshClientConn, sshChannels, sshRequests, err := ssh.NewClientConn(
			sshConn, sshAddress, sshClientConfig)
		var sshClient *ssh.Client
		if err == nil {
			// sshRequests is deliberately not given to the client (nil is
			// passed); it is returned to the caller in the result instead.
			sshClient = ssh.NewClient(sshClientConn, sshChannels, nil)
		}
		resultChannel <- &sshNewClientResult{sshClient, sshRequests, err}
	}()

	// Whichever arrives first wins: the handshake outcome or the timeout. On
	// error, the deferred cleanup closes the dial conn, which interrupts a
	// handshake that is still in progress.
	result := <-resultChannel
	if result.err != nil {
		return nil, common.ContextError(result.err)
	}

	var dialStats *TunnelDialStats

	if dialConfig.UpstreamProxyUrl != "" || meekConfig != nil {
		dialStats = &TunnelDialStats{}

		if dialConfig.UpstreamProxyUrl != "" {

			// Note: UpstreamProxyUrl should have parsed correctly in the dial
			proxyURL, err := url.Parse(dialConfig.UpstreamProxyUrl)
			if err == nil {
				dialStats.UpstreamProxyType = proxyURL.Scheme
			}

			// Only the header names are recorded, not the values. The slice
			// is always non-nil when an upstream proxy is configured.
			dialStats.UpstreamProxyCustomHeaderNames = make(
				[]string, 0, len(dialConfig.UpstreamProxyCustomHeaders))
			for name := range dialConfig.UpstreamProxyCustomHeaders {
				dialStats.UpstreamProxyCustomHeaderNames = append(dialStats.UpstreamProxyCustomHeaderNames, name)
			}
		}

		if meekConfig != nil {
			dialStats.MeekDialAddress = meekConfig.DialAddress
			// Safe to Load() now: ssh session establishment has completed.
			dialStats.MeekResolvedIPAddress = resolvedIPAddress.Load().(string)
			dialStats.MeekSNIServerName = meekConfig.SNIServerName
			dialStats.MeekHostHeader = meekConfig.HostHeader
			dialStats.MeekTransformedHostName = meekConfig.TransformedHostName
		}

		NoticeConnectedTunnelDialStats(serverEntry.IpAddress, dialStats)
	}

	// Success: disarm the deferred cleanup so the conn stays open.
	cleanupConn = nil

	// Note: dialConn may be used to close the underlying network connection
	// but should not be used to perform I/O as that would interfere with SSH
	// (and also bypasses throttling).

	return &dialResult{
		dialConn:      dialConn,
		monitoredConn: monitoredConn,
		sshClient:     result.sshClient,
		sshRequests:   result.sshRequests,
		dialStats:     dialStats,
	}, nil
}
// EstablishTunnel first makes a network transport connection to the
// Psiphon server and then establishes an SSH client session on top of
// that transport. The SSH server is authenticated using the public
// key in the server entry.
//
// Depending on the server's capabilities, the connection may use
// plain SSH over TCP, obfuscated SSH over TCP, or obfuscated SSH over
// HTTP (meek protocol). The protocol is chosen by selectProtocol: when
// a required protocol is specified, that protocol is used; otherwise,
// a random supported protocol is used.
//
// Unless config.DisableApi is set, a server context is created (which
// includes a handshake request) and a failure there fails the whole
// establishment. On success, the dial conn is removed from pendingConns
// and the operateTunnel goroutine is started before the tunnel is
// returned. If establishment fails after the SSH connection is up, the
// SSH client and conn are closed and the dial conn is removed from
// pendingConns.
//
// untunneledDialConfig is used for untunneled final status requests.
func EstablishTunnel(
	config *Config,
	untunneledDialConfig *DialConfig,
	sessionId string,
	pendingConns *common.Conns,
	serverEntry *protocol.ServerEntry,
	adjustedEstablishStartTime monotime.Time,
	tunnelOwner TunnelOwner) (tunnel *Tunnel, err error) {

	selectedProtocol, err := selectProtocol(config, serverEntry)
	if err != nil {
		return nil, common.ContextError(err)
	}

	// Build the transport layers and bring up SSH. Note that the dialConn
	// and monitoredConn fields of the result are the same network connection.
	dialed, err := dialSsh(config, pendingConns, serverEntry, selectedProtocol, sessionId)
	if err != nil {
		return nil, common.ContextError(err)
	}

	// From here on, any error return must release what dialSsh handed over.
	// The named result err is inspected, so every failing return is covered.
	defer func() {
		if err == nil {
			return
		}
		dialed.sshClient.Close()
		dialed.monitoredConn.Close()
		pendingConns.Remove(dialed.dialConn)
	}()

	// The tunnel is now connected
	tunnel = &Tunnel{
		config:               config,
		untunneledDialConfig: untunneledDialConfig,
		serverEntry:          serverEntry,
		protocol:             selectedProtocol,
		dialStats:            dialed.dialStats,

		conn:              dialed.monitoredConn,
		sshClient:         dialed.sshClient,
		sshServerRequests: dialed.sshRequests,

		mutex:                    &sync.Mutex{},
		isClosed:                 false,
		operateWaitGroup:         &sync.WaitGroup{},
		shutdownOperateBroadcast: make(chan struct{}),

		// One slot of buffering lets a sender signal without blocking even
		// when nothing is currently receiving. Senders should not block.
		signalPortForwardFailure: make(chan struct{}, 1),

		// One slot of buffering lets SetClientVerificationPayload submit one
		// new payload without blocking or dropping it.
		newClientVerificationPayload: make(chan string, 1),
	}

	// Unless the API is disabled, create a Psiphon API server context for
	// this tunnel. That performs a handshake request; if the handshake
	// fails, establishment fails.
	if !config.DisableApi {
		NoticeInfo("starting server context for %s", tunnel.serverEntry.IpAddress)

		tunnel.serverContext, err = NewServerContext(tunnel, sessionId)
		if err != nil {
			contextErr := fmt.Errorf(
				"error starting server context for %s: %s",
				tunnel.serverEntry.IpAddress, err)
			return nil, common.ContextError(contextErr)
		}
	}

	// establishDuration spans from the controller starting tunnel
	// establishment to this tunnel being established. The reported value is
	// how long the user waited between starting the client and having a
	// usable tunnel, or between the client detecting an unexpected tunnel
	// disconnect and completing automatic reestablishment.
	//
	// The period may include time spent unsuccessfully connecting to other
	// servers. Time spent waiting for network connectivity is excluded.
	tunnel.establishDuration = monotime.Since(adjustedEstablishStartTime)
	tunnel.establishedTime = monotime.Now()

	// Network operations are complete, so the conn no longer needs to be
	// interruptible via pendingConns.
	pendingConns.Remove(dialed.dialConn)

	// Start operateTunnel, which monitors the tunnel and handles periodic
	// stats updates.
	tunnel.operateWaitGroup.Add(1)
	go tunnel.operateTunnel(tunnelOwner)

	return tunnel, nil
}