// TestRoundTripAndNewConnection exercises SpdyRoundTripper end to end. For each
// case it starts a backend (plain HTTP or TLS), optionally reached through a
// goproxy CONNECT proxy, and performs the upgrade request with client.Do and
// NewConnection.
//
// Cases with shouldError=true make the backend reply with the configured
// Connection/Upgrade headers and status code, or use a TLS setup the client is
// not configured to trust. They only assert that the round trip or
// NewConnection fails.
//
// The remaining cases check the 101 status and echo "hello" over a SPDY stream.
// When a proxy is configured, they also check that the proxier was used and
// that the proxy saw a CONNECT for the backend's host.
func TestRoundTripAndNewConnection(t *testing.T) {
	localhostPool := x509.NewCertPool()
	if !localhostPool.AppendCertsFromPEM(localhostCert) {
		// Every RootCAs case depends on this pool, so there is no point continuing.
		t.Fatalf("error setting up localhostCert pool")
	}

	// Parse the key pairs once, on the test goroutine (where t.Fatalf is allowed),
	// so a broken fixture stops the test instead of starting TLS servers with an
	// empty certificate.
	exampleKeyPair, err := tls.X509KeyPair(exampleCert, exampleKey)
	if err != nil {
		t.Fatalf("https (invalid hostname): proxy_test: %v", err)
	}
	localhostKeyPair, err := tls.X509KeyPair(localhostCert, localhostKey)
	if err != nil {
		t.Fatalf("https (valid hostname): proxy_test: %v", err)
	}

	// newTLSServer returns a server constructor that serves its handler over TLS
	// using cert.
	newTLSServer := func(cert tls.Certificate) func(http.Handler) *httptest.Server {
		return func(h http.Handler) *httptest.Server {
			ts := httptest.NewUnstartedServer(h)
			ts.TLS = &tls.Config{
				Certificates: []tls.Certificate{cert},
			}
			ts.StartTLS()
			return ts
		}
	}
	// Serves exampleCert. The test cases treat it as not matching the server's
	// hostname (see the "invalid hostname" case names).
	httpsServerInvalidHostname := newTLSServer(exampleKeyPair)
	// Serves localhostCert, which localhostPool trusts.
	httpsServerValidHostname := newTLSServer(localhostKeyPair)

	testCases := map[string]struct {
		serverFunc             func(http.Handler) *httptest.Server
		proxyServerFunc        func(http.Handler) *httptest.Server
		clientTLS              *tls.Config
		serverConnectionHeader string
		serverUpgradeHeader    string
		serverStatusCode       int
		shouldError            bool
	}{
		"no headers": {
			serverFunc:             httptest.NewServer,
			serverConnectionHeader: "",
			serverUpgradeHeader:    "",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            true,
		},
		"no upgrade header": {
			serverFunc:             httptest.NewServer,
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            true,
		},
		"no connection header": {
			serverFunc:             httptest.NewServer,
			serverConnectionHeader: "",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            true,
		},
		"no switching protocol status code": {
			serverFunc:             httptest.NewServer,
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusForbidden,
			shouldError:            true,
		},
		"http": {
			serverFunc:             httptest.NewServer,
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            false,
		},
		"https (invalid hostname + InsecureSkipVerify)": {
			serverFunc:             httpsServerInvalidHostname,
			clientTLS:              &tls.Config{InsecureSkipVerify: true},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            false,
		},
		"https (invalid hostname + hostname verification)": {
			serverFunc:             httpsServerInvalidHostname,
			clientTLS:              &tls.Config{InsecureSkipVerify: false},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            true,
		},
		"https (valid hostname + RootCAs)": {
			serverFunc:             httpsServerValidHostname,
			clientTLS:              &tls.Config{RootCAs: localhostPool},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            false,
		},
		"proxied http->http": {
			serverFunc:             httptest.NewServer,
			proxyServerFunc:        httptest.NewServer,
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            false,
		},
		"proxied https (invalid hostname + InsecureSkipVerify) -> http": {
			serverFunc:             httptest.NewServer,
			proxyServerFunc:        httpsServerInvalidHostname,
			clientTLS:              &tls.Config{InsecureSkipVerify: true},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            false,
		},
		"proxied https (invalid hostname + hostname verification) -> http": {
			serverFunc:             httptest.NewServer,
			proxyServerFunc:        httpsServerInvalidHostname,
			clientTLS:              &tls.Config{InsecureSkipVerify: false},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            true, // fails because the client doesn't trust the proxy
		},
		"proxied https (valid hostname + RootCAs) -> http": {
			serverFunc:             httptest.NewServer,
			proxyServerFunc:        httpsServerValidHostname,
			clientTLS:              &tls.Config{RootCAs: localhostPool},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            false,
		},
		"proxied https (invalid hostname + InsecureSkipVerify) -> https (invalid hostname)": {
			serverFunc:             httpsServerInvalidHostname,
			proxyServerFunc:        httpsServerInvalidHostname,
			clientTLS:              &tls.Config{InsecureSkipVerify: true},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            false, // works because the test proxy ignores TLS errors
		},
		"proxied https (invalid hostname + hostname verification) -> https (invalid hostname)": {
			serverFunc:             httpsServerInvalidHostname,
			proxyServerFunc:        httpsServerInvalidHostname,
			clientTLS:              &tls.Config{InsecureSkipVerify: false},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            true, // fails because the client doesn't trust the proxy
		},
		"proxied https (valid hostname + RootCAs) -> https (valid hostname + RootCAs)": {
			serverFunc:             httpsServerValidHostname,
			proxyServerFunc:        httpsServerValidHostname,
			clientTLS:              &tls.Config{RootCAs: localhostPool},
			serverConnectionHeader: "Upgrade",
			serverUpgradeHeader:    "SPDY/3.1",
			serverStatusCode:       http.StatusSwitchingProtocols,
			shouldError:            false,
		},
	}

	for k, testCase := range testCases {
		k, testCase := k, testCase // per-iteration copies for the closures below (pre-Go 1.22 loop semantics)
		// Each case runs as a subtest so its deferred conn.Close fires when the case
		// ends, rather than when the whole test returns.
		t.Run(k, func(t *testing.T) {
			server := testCase.serverFunc(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				// This handler runs on a server goroutine, where t.Fatalf (FailNow)
				// must not be called. Report with t.Errorf and return instead.
				if testCase.shouldError {
					if e, a := httpstream.HeaderUpgrade, req.Header.Get(httpstream.HeaderConnection); e != a {
						t.Errorf("%s: Expected connection=upgrade header, got '%s'", k, a)
						return
					}

					w.Header().Set(httpstream.HeaderConnection, testCase.serverConnectionHeader)
					w.Header().Set(httpstream.HeaderUpgrade, testCase.serverUpgradeHeader)
					w.WriteHeader(testCase.serverStatusCode)

					return
				}

				streamCh := make(chan httpstream.Stream)

				responseUpgrader := NewResponseUpgrader()
				spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream, replySent <-chan struct{}) error {
					streamCh <- s
					return nil
				})
				if spdyConn == nil {
					t.Errorf("%s: unexpected nil spdyConn", k)
					return
				}
				defer spdyConn.Close()

				// Echo everything the client writes on the first stream back to it.
				stream := <-streamCh
				io.Copy(stream, stream)
			}))
			// TODO: Uncomment when fix #19254
			// defer server.Close()

			serverURL, err := url.Parse(server.URL)
			if err != nil {
				t.Fatalf("%s: Error parsing server URL: %s", k, err)
			}
			req, err := http.NewRequest("GET", server.URL, nil)
			if err != nil {
				t.Fatalf("%s: Error creating request: %s", k, err)
			}

			spdyTransport := NewSpdyRoundTripper(testCase.clientTLS)

			var proxierCalled bool
			var proxyCalledWithHost string
			if testCase.proxyServerFunc != nil {
				proxyHandler := goproxy.NewProxyHttpServer()
				proxyHandler.OnRequest().HandleConnectFunc(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
					proxyCalledWithHost = host
					return goproxy.OkConnect, host
				})
				proxy := testCase.proxyServerFunc(proxyHandler)

				spdyTransport.proxier = func(proxierReq *http.Request) (*url.URL, error) {
					proxyURL, err := url.Parse(proxy.URL)
					if err != nil {
						return nil, err
					}
					proxierCalled = true
					return proxyURL, nil
				}
				// TODO: Uncomment when fix #19254
				// defer proxy.Close()
			}

			client := &http.Client{Transport: spdyTransport}

			resp, err := client.Do(req)
			var conn httpstream.Connection
			if err == nil {
				conn, err = spdyTransport.NewConnection(resp)
			}
			haveErr := err != nil
			if e, a := testCase.shouldError, haveErr; e != a {
				t.Fatalf("%s: shouldError=%t, got %t: %v", k, e, a, err)
			}
			if testCase.shouldError {
				return
			}
			defer conn.Close()

			if resp.StatusCode != http.StatusSwitchingProtocols {
				t.Fatalf("%s: expected http 101 switching protocols, got %d", k, resp.StatusCode)
			}

			stream, err := conn.CreateStream(http.Header{})
			if err != nil {
				t.Fatalf("%s: error creating client stream: %s", k, err)
			}

			n, err := stream.Write([]byte("hello"))
			if err != nil {
				t.Fatalf("%s: error writing to stream: %s", k, err)
			}
			if n != 5 {
				t.Fatalf("%s: Expected to write 5 bytes, but actually wrote %d", k, n)
			}

			// A single Read may legally return fewer bytes than requested;
			// io.ReadFull waits for the whole echo.
			b := make([]byte, 5)
			n, err = io.ReadFull(stream, b)
			if err != nil {
				t.Fatalf("%s: error reading from stream: %s", k, err)
			}
			if n != 5 {
				t.Fatalf("%s: Expected to read 5 bytes, but actually read %d", k, n)
			}
			if e, a := "hello", string(b[0:n]); e != a {
				t.Fatalf("%s: expected '%s', got '%s'", k, e, a)
			}

			if testCase.proxyServerFunc != nil {
				if !proxierCalled {
					t.Fatalf("%s: Expected to use a proxy but proxier in SpdyRoundTripper wasn't called", k)
				}
				if proxyCalledWithHost != serverURL.Host {
					t.Fatalf("%s: Expected to see a call to the proxy for backend %q, got %q", k, serverURL.Host, proxyCalledWithHost)
				}
			}
		})
	}
}
Example #2
0
// stream runs a remote command over conn using the V1 stream protocol.
//
// Streams opened:
//   - always an error stream;
//   - a stdin stream when e.stdin is set;
//   - a stdout stream when e.stdout is set;
//   - a stderr stream when e.stderr is set and e.tty is false.
//
// Data is copied between each local reader/writer and its remote stream.
//
// Return:
//   - nil once every stdout/stderr copy has finished;
//   - nil once the error stream closes cleanly, when no output stream was
//     requested;
//   - an error as soon as the error stream carries a message or cannot be read.
//
// Every created stream is Reset on return.
func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
	doneChan := make(chan struct{}, 2)
	// Buffered so the error-stream reader can always deliver its single result and
	// exit, even after this function has returned because the output copies
	// finished first. An unbuffered channel would leak that goroutine.
	errorChan := make(chan error, 1)

	cp := func(s string, dst io.Writer, src io.Reader) {
		glog.V(6).Infof("Copying %s", s)
		defer glog.V(6).Infof("Done copying %s", s)
		if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
			glog.Errorf("Error copying %s: %v", s, err)
		}
		// Only the output copies are waited on; the stdin copy may never finish.
		if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
			doneChan <- struct{}{}
		}
	}

	var (
		err                                                  error
		errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream
	)

	headers := http.Header{}
	headers.Set(api.StreamType, api.StreamTypeError)
	errorStream, err = conn.CreateStream(headers)
	if err != nil {
		return err
	}
	defer errorStream.Reset()

	// Create every stream before starting any copy goroutine. The server does not
	// start copying until it has received all streams. If stdin data is sent
	// earlier, it can wedge the server's spdy frame handler.
	// See https://issues.k8s.io/16373.
	if e.stdin != nil {
		headers.Set(api.StreamType, api.StreamTypeStdin)
		remoteStdin, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStdin.Reset()
	}

	if e.stdout != nil {
		headers.Set(api.StreamType, api.StreamTypeStdout)
		remoteStdout, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStdout.Reset()
	}

	if e.stderr != nil && !e.tty {
		headers.Set(api.StreamType, api.StreamTypeStderr)
		remoteStderr, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStderr.Reset()
	}

	// Read the error stream to completion and report its outcome exactly once:
	// - a read failure;
	// - the remote error message;
	// - nil when it closed without data.
	go func() {
		message, err := ioutil.ReadAll(errorStream)
		switch {
		case err != nil && err != io.EOF:
			errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
		case len(message) > 0:
			errorChan <- fmt.Errorf("Error executing remote command: %s", message)
		default:
			errorChan <- nil
		}
	}()

	if e.stdin != nil {
		// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
		// because stdin is not closed until the process exits. If we try to call
		// stdin.Close(), it returns no error but doesn't unblock the copy. It will
		// exit when the process exits, instead.
		go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
	}

	waitCount := 0

	if e.stdout != nil {
		waitCount++
		go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
	}

	if e.stderr != nil && !e.tty {
		waitCount++
		go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
	}

	// errCh is set to nil after the error stream closes cleanly, so the select
	// stops polling a channel that will never deliver again.
	errCh := errorChan
	completedStreams := 0
	for {
		select {
		case <-doneChan:
			completedStreams++
			if completedStreams == waitCount {
				return nil
			}
		case err := <-errCh:
			if err != nil {
				return err
			}
			// Clean close of the error stream. If no output streams are being
			// copied, nothing will ever arrive on doneChan, so finish now instead
			// of blocking forever.
			if waitCount == 0 {
				return nil
			}
			errCh = nil
		}
	}
}
Example #3
0
// stream runs a remote command over conn using the V2 stream protocol.
//
// Streams are created in this order, before any data is copied:
//   - the error stream;
//   - stdin, when e.stdin is set;
//   - stdout, when e.stdout is set;
//   - stderr, when e.stderr is set and e.tty is false.
//
// It then copies data between the local and remote ends and blocks until the
// stdout/stderr copies are done. It returns what the error stream reported:
//   - nil;
//   - a read failure;
//   - the remote command's error message.
func (e *streamProtocolV2) stream(conn httpstream.Connection) error {
	headers := http.Header{}

	// open tags the shared header set with streamType and creates a stream of
	// that type on conn.
	open := func(streamType string) (httpstream.Stream, error) {
		headers.Set(api.StreamType, streamType)
		return conn.CreateStream(headers)
	}

	var remoteStdin, remoteStdout, remoteStderr httpstream.Stream

	// Phase 1: create every stream up front.
	errorStream, err := open(api.StreamTypeError)
	if err != nil {
		return err
	}
	if e.stdin != nil {
		if remoteStdin, err = open(api.StreamTypeStdin); err != nil {
			return err
		}
	}
	if e.stdout != nil {
		if remoteStdout, err = open(api.StreamTypeStdout); err != nil {
			return err
		}
	}
	wantStderr := e.stderr != nil && !e.tty
	if wantStderr {
		if remoteStderr, err = open(api.StreamTypeStderr); err != nil {
			return err
		}
	}

	// Phase 2: start reading and copying.

	// The error stream is always drained. Its single result (possibly nil) is
	// sent, and then the channel is closed.
	errorChan := make(chan error)
	go func() {
		defer close(errorChan)
		message, readErr := ioutil.ReadAll(errorStream)
		var result error
		if readErr != nil && readErr != io.EOF {
			result = fmt.Errorf("error reading from error stream: %s", readErr)
		} else if len(message) > 0 {
			result = fmt.Errorf("error executing remote command: %s", message)
		}
		errorChan <- result
	}()

	if e.stdin != nil {
		// Whichever of the two goroutines below finishes first closes our side of
		// remoteStdin; the sync.Once keeps that close from happening twice.
		var closeOnce sync.Once
		closeStdin := func() { closeOnce.Do(func() { remoteStdin.Close() }) }

		// Forward the client's stdin. For non-interactive input (for example
		// `echo abc | kubectl exec -i <pod> -- cat`), closing remoteStdin once the
		// input is exhausted lets the remote command finish instead of running on.
		go func() {
			defer closeStdin()
			if _, err := io.Copy(remoteStdin, e.stdin); err != nil {
				runtime.HandleError(err)
			}
		}()

		// Nothing real is read here; this only waits for the server to close its
		// side of remoteStdin. go-dockerclient's hijack
		// (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564)
		// waits for stdin/stdout/stderr to all finish before returning. It closes
		// its side of remoteStdin after stdout/stderr are done, which unblocks this
		// copy. We must then close our side so hijack can return as well.
		// Without this, interactive sessions cannot exit cleanly, goroutines leak,
		// and the client's terminal can hang.
		go func() {
			defer closeStdin()
			if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil {
				runtime.HandleError(err)
			}
		}()
	}

	// drain copies src into dst in a goroutine tracked by wg.
	var wg sync.WaitGroup
	drain := func(dst io.Writer, src io.Reader) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if _, err := io.Copy(dst, src); err != nil {
				runtime.HandleError(err)
			}
		}()
	}
	if e.stdout != nil {
		drain(e.stdout, remoteStdout)
	}
	if wantStderr {
		drain(e.stderr, remoteStderr)
	}

	// Output first, then the error stream's verdict.
	wg.Wait()
	return <-errorChan
}
Example #4
0
// stream runs a remote command over conn using the V1 stream protocol.
//
// Streams opened:
//   - always an error stream;
//   - a stdin stream when e.stdin is set;
//   - a stdout stream when e.stdout is set;
//   - a stderr stream when e.stderr is set and e.tty is false.
//
// Data is copied between each local reader/writer and its remote stream.
//
// Return:
//   - nil once every stdout/stderr copy has finished;
//   - nil once the error stream closes cleanly, when no output stream was
//     requested;
//   - an error as soon as the error stream carries a message or cannot be read.
//
// Every created stream is Reset on return.
func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
	doneChan := make(chan struct{}, 2)
	// Buffered so the error-stream reader can always deliver its single result and
	// exit, even after this function has returned because the output copies
	// finished first. An unbuffered channel would leak that goroutine.
	errorChan := make(chan error, 1)

	cp := func(s string, dst io.Writer, src io.Reader) {
		glog.V(6).Infof("Copying %s", s)
		defer glog.V(6).Infof("Done copying %s", s)
		if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
			glog.Errorf("Error copying %s: %v", s, err)
		}
		// Only the output copies are waited on; the stdin copy may never finish.
		if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
			doneChan <- struct{}{}
		}
	}

	var (
		err                                                  error
		errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream
	)

	// set up all the streams first
	headers := http.Header{}
	headers.Set(api.StreamType, api.StreamTypeError)
	errorStream, err = conn.CreateStream(headers)
	if err != nil {
		return err
	}
	defer errorStream.Reset()

	// Create all the streams first, then start the copy goroutines. The server doesn't start its copy
	// goroutines until it's received all of the streams. If the client creates the stdin stream and
	// immediately begins copying stdin data to the server, it's possible to overwhelm and wedge the
	// spdy frame handler in the server so that it is full of unprocessed frames. The frames aren't
	// getting processed because the server hasn't started its copying, and it won't do that until it
	// gets all the streams. By creating all the streams first, we ensure that the server is ready to
	// process data before the client starts sending any. See https://issues.k8s.io/16373 for more info.
	if e.stdin != nil {
		headers.Set(api.StreamType, api.StreamTypeStdin)
		remoteStdin, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStdin.Reset()
	}

	if e.stdout != nil {
		headers.Set(api.StreamType, api.StreamTypeStdout)
		remoteStdout, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStdout.Reset()
	}

	if e.stderr != nil && !e.tty {
		headers.Set(api.StreamType, api.StreamTypeStderr)
		remoteStderr, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStderr.Reset()
	}

	// now that all the streams have been created, proceed with reading & copying

	// Read the error stream to completion and report its outcome exactly once:
	// - a read failure;
	// - the remote error message;
	// - nil when it closed without data.
	go func() {
		message, err := ioutil.ReadAll(errorStream)
		switch {
		case err != nil && err != io.EOF:
			errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
		case len(message) > 0:
			errorChan <- fmt.Errorf("Error executing remote command: %s", message)
		default:
			errorChan <- nil
		}
	}()

	if e.stdin != nil {
		// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
		// because stdin is not closed until the process exits. If we try to call
		// stdin.Close(), it returns no error but doesn't unblock the copy. It will
		// exit when the process exits, instead.
		go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
	}

	waitCount := 0

	if e.stdout != nil {
		waitCount++
		go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
	}

	if e.stderr != nil && !e.tty {
		waitCount++
		go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
	}

	// errCh is set to nil after the error stream closes cleanly, so the select
	// stops polling a channel that will never deliver again.
	errCh := errorChan
	completedStreams := 0
	for {
		select {
		case <-doneChan:
			completedStreams++
			if completedStreams == waitCount {
				return nil
			}
		case err := <-errCh:
			if err != nil {
				return err
			}
			// Clean close of the error stream. If no output streams are being
			// copied, nothing will ever arrive on doneChan, so finish now instead
			// of blocking forever.
			if waitCount == 0 {
				return nil
			}
			errCh = nil
		}
	}
}