func TestRoundTripAndNewConnection(t *testing.T) { localhostPool := x509.NewCertPool() if !localhostPool.AppendCertsFromPEM(localhostCert) { t.Errorf("error setting up localhostCert pool") } httpsServerInvalidHostname := func(h http.Handler) *httptest.Server { cert, err := tls.X509KeyPair(exampleCert, exampleKey) if err != nil { t.Errorf("https (invalid hostname): proxy_test: %v", err) } ts := httptest.NewUnstartedServer(h) ts.TLS = &tls.Config{ Certificates: []tls.Certificate{cert}, } ts.StartTLS() return ts } httpsServerValidHostname := func(h http.Handler) *httptest.Server { cert, err := tls.X509KeyPair(localhostCert, localhostKey) if err != nil { t.Errorf("https (valid hostname): proxy_test: %v", err) } ts := httptest.NewUnstartedServer(h) ts.TLS = &tls.Config{ Certificates: []tls.Certificate{cert}, } ts.StartTLS() return ts } testCases := map[string]struct { serverFunc func(http.Handler) *httptest.Server proxyServerFunc func(http.Handler) *httptest.Server clientTLS *tls.Config serverConnectionHeader string serverUpgradeHeader string serverStatusCode int shouldError bool }{ "no headers": { serverFunc: httptest.NewServer, serverConnectionHeader: "", serverUpgradeHeader: "", serverStatusCode: http.StatusSwitchingProtocols, shouldError: true, }, "no upgrade header": { serverFunc: httptest.NewServer, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "", serverStatusCode: http.StatusSwitchingProtocols, shouldError: true, }, "no connection header": { serverFunc: httptest.NewServer, serverConnectionHeader: "", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: true, }, "no switching protocol status code": { serverFunc: httptest.NewServer, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusForbidden, shouldError: true, }, "http": { serverFunc: httptest.NewServer, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: false, }, "https (invalid hostname + InsecureSkipVerify)": { serverFunc: httpsServerInvalidHostname, clientTLS: &tls.Config{InsecureSkipVerify: true}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: false, }, "https (invalid hostname + hostname verification)": { serverFunc: httpsServerInvalidHostname, clientTLS: &tls.Config{InsecureSkipVerify: false}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: true, }, "https (valid hostname + RootCAs)": { serverFunc: httpsServerValidHostname, clientTLS: &tls.Config{RootCAs: localhostPool}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: false, }, "proxied http->http": { serverFunc: httptest.NewServer, proxyServerFunc: httptest.NewServer, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: false, }, "proxied https (invalid hostname + InsecureSkipVerify) -> http": { serverFunc: httptest.NewServer, proxyServerFunc: httpsServerInvalidHostname, clientTLS: &tls.Config{InsecureSkipVerify: true}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: false, }, "proxied https (invalid hostname + hostname verification) -> http": { serverFunc: httptest.NewServer, proxyServerFunc: httpsServerInvalidHostname, clientTLS: &tls.Config{InsecureSkipVerify: false}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: true, // fails because the client doesn't trust the proxy }, "proxied https (valid hostname + RootCAs) -> http": { serverFunc: httptest.NewServer, proxyServerFunc: httpsServerValidHostname, clientTLS: &tls.Config{RootCAs: localhostPool}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: false, }, "proxied https (invalid hostname + InsecureSkipVerify) -> https (invalid hostname)": { serverFunc: httpsServerInvalidHostname, proxyServerFunc: httpsServerInvalidHostname, clientTLS: &tls.Config{InsecureSkipVerify: true}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: false, // works because the test proxy ignores TLS errors }, "proxied https (invalid hostname + hostname verification) -> https (invalid hostname)": { serverFunc: httpsServerInvalidHostname, proxyServerFunc: httpsServerInvalidHostname, clientTLS: &tls.Config{InsecureSkipVerify: false}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: true, // fails because the client doesn't trust the proxy }, "proxied https (valid hostname + RootCAs) -> https (valid hostname + RootCAs)": { serverFunc: httpsServerValidHostname, proxyServerFunc: httpsServerValidHostname, clientTLS: &tls.Config{RootCAs: localhostPool}, serverConnectionHeader: "Upgrade", serverUpgradeHeader: "SPDY/3.1", serverStatusCode: http.StatusSwitchingProtocols, shouldError: false, }, } for k, testCase := range testCases { server := testCase.serverFunc(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if testCase.shouldError { if e, a := httpstream.HeaderUpgrade, req.Header.Get(httpstream.HeaderConnection); e != a { t.Fatalf("%s: Expected connection=upgrade header, got '%s", k, a) } w.Header().Set(httpstream.HeaderConnection, testCase.serverConnectionHeader) w.Header().Set(httpstream.HeaderUpgrade, testCase.serverUpgradeHeader) w.WriteHeader(testCase.serverStatusCode) return } streamCh := make(chan httpstream.Stream) responseUpgrader := NewResponseUpgrader() spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream, replySent <-chan struct{}) error { streamCh <- s return nil }) if spdyConn == nil { t.Fatalf("%s: unexpected nil spdyConn", k) } defer spdyConn.Close() stream := <-streamCh io.Copy(stream, stream) })) // TODO: Uncomment when fix #19254 // defer server.Close() serverURL, err := url.Parse(server.URL) if err != nil { t.Fatalf("%s: Error creating request: %s", k, err) } req, err := http.NewRequest("GET", server.URL, nil) if err != nil { t.Fatalf("%s: Error creating request: %s", k, err) } spdyTransport := NewSpdyRoundTripper(testCase.clientTLS) var proxierCalled bool var proxyCalledWithHost string if testCase.proxyServerFunc != nil { proxyHandler := goproxy.NewProxyHttpServer() proxyHandler.OnRequest().HandleConnectFunc(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { proxyCalledWithHost = host return goproxy.OkConnect, host }) proxy := testCase.proxyServerFunc(proxyHandler) spdyTransport.proxier = func(proxierReq *http.Request) (*url.URL, error) { proxyURL, err := url.Parse(proxy.URL) if err != nil { return nil, err } proxierCalled = true return proxyURL, nil } // TODO: Uncomment when fix #19254 // defer proxy.Close() } client := &http.Client{Transport: spdyTransport} resp, err := client.Do(req) var conn httpstream.Connection if err == nil { conn, err = spdyTransport.NewConnection(resp) } haveErr := err != nil if e, a := testCase.shouldError, haveErr; e != a { t.Fatalf("%s: shouldError=%t, got %t: %v", k, e, a, err) } if testCase.shouldError { continue } defer conn.Close() if resp.StatusCode != http.StatusSwitchingProtocols { t.Fatalf("%s: expected http 101 switching protocols, got %d", k, resp.StatusCode) } stream, err := conn.CreateStream(http.Header{}) if err != nil { t.Fatalf("%s: error creating client stream: %s", k, err) } n, err := stream.Write([]byte("hello")) if err != nil { t.Fatalf("%s: error writing to stream: %s", k, err) } if n != 5 { t.Fatalf("%s: Expected to write 5 bytes, but actually wrote %d", k, n) } b := make([]byte, 5) n, err = stream.Read(b) if err != nil { t.Fatalf("%s: error reading from stream: %s", k, err) } if n != 5 { t.Fatalf("%s: Expected to read 5 bytes, but actually read %d", k, n) } if e, a := "hello", string(b[0:n]); e != a { t.Fatalf("%s: expected '%s', got '%s'", k, e, a) } if testCase.proxyServerFunc != nil { if !proxierCalled { t.Fatalf("%s: Expected to use a proxy but proxier in SpdyRoundTripper wasn't called", k) } if proxyCalledWithHost != serverURL.Host { t.Fatalf("%s: Expected to see a call to the proxy for backend %q, got %q", k, serverURL.Host, proxyCalledWithHost) } } } }
func (e *streamProtocolV1) stream(conn httpstream.Connection) error { doneChan := make(chan struct{}, 2) errorChan := make(chan error) cp := func(s string, dst io.Writer, src io.Reader) { glog.V(6).Infof("Copying %s", s) defer glog.V(6).Infof("Done copying %s", s) if _, err := io.Copy(dst, src); err != nil && err != io.EOF { glog.Errorf("Error copying %s: %v", s, err) } if s == api.StreamTypeStdout || s == api.StreamTypeStderr { doneChan <- struct{}{} } } headers := http.Header{} headers.Set(api.StreamType, api.StreamTypeError) errorStream, err := conn.CreateStream(headers) if err != nil { return err } go func() { message, err := ioutil.ReadAll(errorStream) if err != nil && err != io.EOF { errorChan <- fmt.Errorf("Error reading from error stream: %s", err) return } if len(message) > 0 { errorChan <- fmt.Errorf("Error executing remote command: %s", message) return } }() defer errorStream.Reset() if e.stdin != nil { headers.Set(api.StreamType, api.StreamTypeStdin) remoteStdin, err := conn.CreateStream(headers) if err != nil { return err } defer remoteStdin.Reset() // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) // because stdin is not closed until the process exits. If we try to call // stdin.Close(), it returns no error but doesn't unblock the copy. It will // exit when the process exits, instead. go cp(api.StreamTypeStdin, remoteStdin, e.stdin) } waitCount := 0 completedStreams := 0 if e.stdout != nil { waitCount++ headers.Set(api.StreamType, api.StreamTypeStdout) remoteStdout, err := conn.CreateStream(headers) if err != nil { return err } defer remoteStdout.Reset() go cp(api.StreamTypeStdout, e.stdout, remoteStdout) } if e.stderr != nil && !e.tty { waitCount++ headers.Set(api.StreamType, api.StreamTypeStderr) remoteStderr, err := conn.CreateStream(headers) if err != nil { return err } defer remoteStderr.Reset() go cp(api.StreamTypeStderr, e.stderr, remoteStderr) } Loop: for { select { case <-doneChan: completedStreams++ if completedStreams == waitCount { break Loop } case err := <-errorChan: return err } } return nil }
func (e *streamProtocolV2) stream(conn httpstream.Connection) error { var ( err error errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream ) headers := http.Header{} // set up all the streams first // set up error stream errorChan := make(chan error) headers.Set(api.StreamType, api.StreamTypeError) errorStream, err = conn.CreateStream(headers) if err != nil { return err } // set up stdin stream if e.stdin != nil { headers.Set(api.StreamType, api.StreamTypeStdin) remoteStdin, err = conn.CreateStream(headers) if err != nil { return err } } // set up stdout stream if e.stdout != nil { headers.Set(api.StreamType, api.StreamTypeStdout) remoteStdout, err = conn.CreateStream(headers) if err != nil { return err } } // set up stderr stream if e.stderr != nil && !e.tty { headers.Set(api.StreamType, api.StreamTypeStderr) remoteStderr, err = conn.CreateStream(headers) if err != nil { return err } } // now that all the streams have been created, proceed with reading & copying // always read from errorStream go func() { message, err := ioutil.ReadAll(errorStream) switch { case err != nil && err != io.EOF: errorChan <- fmt.Errorf("error reading from error stream: %s", err) case len(message) > 0: errorChan <- fmt.Errorf("error executing remote command: %s", message) default: errorChan <- nil } close(errorChan) }() var wg sync.WaitGroup var once sync.Once if e.stdin != nil { // copy from client's stdin to container's stdin go func() { // if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i <pod> -- cat`, make sure // we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise // the executed command will remain running. defer once.Do(func() { remoteStdin.Close() }) if _, err := io.Copy(remoteStdin, e.stdin); err != nil { runtime.HandleError(err) } }() // read from remoteStdin until the stream is closed. this is essential to // be able to exit interactive sessions cleanly and not leak goroutines or // hang the client's terminal. // // go-dockerclient's current hijack implementation // (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564) // waits for all three streams (stdin/stdout/stderr) to finish copying // before returning. When hijack finishes copying stdout/stderr, it calls // Close() on its side of remoteStdin, which allows this copy to complete. // When that happens, we must Close() on our side of remoteStdin, to // allow the copy in hijack to complete, and hijack to return. go func() { defer once.Do(func() { remoteStdin.Close() }) // this "copy" doesn't actually read anything - it's just here to wait for // the server to close remoteStdin. if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil { runtime.HandleError(err) } }() } if e.stdout != nil { wg.Add(1) go func() { defer wg.Done() if _, err := io.Copy(e.stdout, remoteStdout); err != nil { runtime.HandleError(err) } }() } if e.stderr != nil && !e.tty { wg.Add(1) go func() { defer wg.Done() if _, err := io.Copy(e.stderr, remoteStderr); err != nil { runtime.HandleError(err) } }() } // we're waiting for stdout/stderr to finish copying wg.Wait() // waits for errorStream to finish reading with an error or nil return <-errorChan }
func (e *streamProtocolV1) stream(conn httpstream.Connection) error { doneChan := make(chan struct{}, 2) errorChan := make(chan error) cp := func(s string, dst io.Writer, src io.Reader) { glog.V(6).Infof("Copying %s", s) defer glog.V(6).Infof("Done copying %s", s) if _, err := io.Copy(dst, src); err != nil && err != io.EOF { glog.Errorf("Error copying %s: %v", s, err) } if s == api.StreamTypeStdout || s == api.StreamTypeStderr { doneChan <- struct{}{} } } var ( err error errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream ) // set up all the streams first headers := http.Header{} headers.Set(api.StreamType, api.StreamTypeError) errorStream, err = conn.CreateStream(headers) if err != nil { return err } defer errorStream.Reset() // Create all the streams first, then start the copy goroutines. The server doesn't start its copy // goroutines until it's received all of the streams. If the client creates the stdin stream and // immediately begins copying stdin data to the server, it's possible to overwhelm and wedge the // spdy frame handler in the server so that it is full of unprocessed frames. The frames aren't // getting processed because the server hasn't started its copying, and it won't do that until it // gets all the streams. By creating all the streams first, we ensure that the server is ready to // process data before the client starts sending any. See https://issues.k8s.io/16373 for more info. if e.stdin != nil { headers.Set(api.StreamType, api.StreamTypeStdin) remoteStdin, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStdin.Reset() } if e.stdout != nil { headers.Set(api.StreamType, api.StreamTypeStdout) remoteStdout, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStdout.Reset() } if e.stderr != nil && !e.tty { headers.Set(api.StreamType, api.StreamTypeStderr) remoteStderr, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStderr.Reset() } // now that all the streams have been created, proceed with reading & copying // always read from errorStream go func() { message, err := ioutil.ReadAll(errorStream) if err != nil && err != io.EOF { errorChan <- fmt.Errorf("Error reading from error stream: %s", err) return } if len(message) > 0 { errorChan <- fmt.Errorf("Error executing remote command: %s", message) return } }() if e.stdin != nil { // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) // because stdin is not closed until the process exits. If we try to call // stdin.Close(), it returns no error but doesn't unblock the copy. It will // exit when the process exits, instead. go cp(api.StreamTypeStdin, remoteStdin, e.stdin) } waitCount := 0 completedStreams := 0 if e.stdout != nil { waitCount++ go cp(api.StreamTypeStdout, e.stdout, remoteStdout) } if e.stderr != nil && !e.tty { waitCount++ go cp(api.StreamTypeStderr, e.stderr, remoteStderr) } Loop: for { select { case <-doneChan: completedStreams++ if completedStreams == waitCount { break Loop } case err := <-errorChan: return err } } return nil }