Example #1
0
// requestID returns the request id for stream. Normally it is taken verbatim
// from the api.PortForwardRequestIDHeader header; when that header is absent,
// a pseudo request id is synthesized from the stream type and identifier.
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
	if id := stream.Headers().Get(api.PortForwardRequestIDHeader); id != "" {
		return id
	}

	glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
	// Reaching this point means the connection came from an older client
	// that does not generate the request id header
	// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287).
	//
	// Supporting such clients is best-effort only.
	//
	// Without concurrent new forwarded connections, every connection owns a
	// pair of streams (data, error) whose IDs are consecutive odd numbers,
	// e.g. 1 and 3 for the first connection. A pseudo request id is therefore
	// derived from the stream type:
	//   - error stream: id = stream.Identifier()
	//   - data stream:  id = stream.Identifier() - 2
	//
	// NOTE: this breaks down when several forwarded connections open streams
	// concurrently, since one connection may then receive non-consecutive
	// stream IDs (e.g. 5 and 9 instead of 5 and 7). Any other stream type
	// yields an empty request id.
	streamType := stream.Headers().Get(api.StreamType)
	var requestID string
	if streamType == api.StreamTypeError {
		requestID = strconv.Itoa(int(stream.Identifier()))
	} else if streamType == api.StreamTypeData {
		requestID = strconv.Itoa(int(stream.Identifier()) - 2)
	}

	glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
	return requestID
}
Example #2
0
// add adds stream to the portForwardStreamPair according to its
// api.StreamType header. It returns an error if the pair already holds a
// stream of that type, or if the type is neither api.StreamTypeError nor
// api.StreamTypeData. add returns true once both the data and error streams
// for this pair have been received; at that point p.complete is closed.
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
	p.lock.Lock()
	defer p.lock.Unlock()

	switch streamType := stream.Headers().Get(api.StreamType); streamType {
	case api.StreamTypeError:
		if p.errorStream != nil {
			return false, errors.New("error stream already assigned")
		}
		p.errorStream = stream
	case api.StreamTypeData:
		if p.dataStream != nil {
			return false, errors.New("data stream already assigned")
		}
		p.dataStream = stream
	default:
		// Reject unknown types rather than ignoring them. Otherwise such a
		// stream arriving after the pair is already complete would reach the
		// close(p.complete) below a second time, and closing an
		// already-closed channel panics.
		return false, errors.New("unexpected stream type: " + streamType)
	}

	complete := p.errorStream != nil && p.dataStream != nil
	if complete {
		// Only reachable on the add that completes the pair: any later add of
		// an error/data stream returns early above.
		close(p.complete)
	}
	return complete, nil
}
Example #3
0
// waitForPortForwardDataStreamAndRun waits up to streamCreationTimeout for a
// data stream to arrive on dataStreamChan, then forwards the port named in
// that stream's api.PortHeader header to the pod identified by pod and uid
// via host.PortForward. Failures (timeout, unparsable port, forwarding error)
// are reported by writing a message to errorStream, which is always reset on
// return.
func waitForPortForwardDataStreamAndRun(pod string, uid types.UID, errorStream httpstream.Stream, dataStreamChan chan httpstream.Stream, host HostInterface) {
	defer errorStream.Reset()

	var dataStream httpstream.Stream

	select {
	case dataStream = <-dataStreamChan:
	case <-time.After(streamCreationTimeout):
		errorStream.Write([]byte("Timed out waiting for data stream"))
		//TODO delete from dataStreamChans[port]
		return
	}

	portString := dataStream.Headers().Get(api.PortHeader)
	// ParseUint reports a missing/malformed value (returning 0) or an
	// out-of-range value (returning the 16-bit maximum) as an error. Ignoring
	// the error would silently forward to a port the client never asked for.
	port, err := strconv.ParseUint(portString, 10, 16)
	if err != nil {
		msg := fmt.Errorf("Error parsing port %q for pod %s, uid %v: %v", portString, pod, uid, err)
		glog.Error(msg)
		errorStream.Write([]byte(msg.Error()))
		return
	}

	err = host.PortForward(pod, uid, uint16(port), dataStream)
	if err != nil {
		msg := fmt.Errorf("Error forwarding port %d to pod %s, uid %v: %v", port, pod, uid, err)
		glog.Error(msg)
		errorStream.Write([]byte(msg.Error()))
	}
}
Example #4
0
// fakeExecServer returns a handler that plays the server side of an exec
// session for test case i. It upgrades the request to SPDY and waits until
// the client has opened the error, stdin and stdout streams, plus stderr
// when tty is false. It then writes errorData, stdoutData and stderrData to
// the matching streams and verifies that stdinData arrives on stdin. Every
// received stream is reset when the handler returns.
func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool) http.HandlerFunc {
	// error, stdin and stdout are always expected; stderr only without a TTY.
	wantStreams := 3
	if !tty {
		wantStreams = 4
	}

	return func(w http.ResponseWriter, req *http.Request) {
		incoming := make(chan httpstream.Stream)

		upgrader := spdy.NewResponseUpgrader()
		conn := upgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error {
			incoming <- s
			return nil
		})
		// w must not be used once the upgrade has been attempted.
		if conn == nil {
			// The upgrader has already told the client what went wrong;
			// there is nothing more to do here.
			return
		}
		defer conn.Close()

		var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
		for received := 0; received < wantStreams; {
			s := <-incoming
			switch kind := s.Headers().Get(api.StreamType); kind {
			case api.StreamTypeError:
				errorStream = s
				received++
			case api.StreamTypeStdin:
				stdinStream = s
				// The server only reads stdin, so close its write half now.
				stdinStream.Close()
				received++
			case api.StreamTypeStdout:
				stdoutStream = s
				received++
			case api.StreamTypeStderr:
				stderrStream = s
				received++
			default:
				t.Errorf("%d: unexpected stream type: %q", i, kind)
			}

			// Runs when the handler returns, not at the end of this iteration.
			defer s.Reset()
		}

		if errorData != "" {
			fmt.Fprint(errorStream, errorData)
			errorStream.Close()
		}

		if stdoutData != "" {
			fmt.Fprint(stdoutStream, stdoutData)
			stdoutStream.Close()
		}
		if stderrData != "" {
			fmt.Fprint(stderrStream, stderrData)
			stderrStream.Close()
		}
		if stdinData != "" {
			data, err := ioutil.ReadAll(stdinStream)
			if err != nil {
				t.Errorf("%d: error reading stdin stream: %v", i, err)
			}
			if got := string(data); stdinData != got {
				t.Errorf("%d: stdin: expected %q, got %q", i, stdinData, got)
			}
		}
	}
}
Example #5
0
// fakeExecServer returns a handler that plays the server side of a
// StreamProtocolV2Name exec session for test case i.
//
// The handler:
//   - negotiates the protocol and upgrades the request to SPDY;
//   - waits until replies have been sent for the error, stdin and stdout
//     streams, plus stderr when tty is false;
//   - writes errorData to the error stream, and stdoutData / stderrData
//     messageCount times each to stdout / stderr;
//   - checks that stdinData arrives messageCount times on stdin.
//
// The handler runs on the HTTP server's goroutine rather than the test
// goroutine. Failures are therefore reported with t.Errorf followed by
// return; t.Fatal (FailNow) must only be called from the goroutine running
// the test.
func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int) http.HandlerFunc {
	// error + stdin + stdout
	expectedStreams := 3
	if !tty {
		// stderr
		expectedStreams++
	}

	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name)
		if err != nil {
			t.Errorf("%d: protocol handshake failed: %v", i, err)
			return
		}
		if protocol != StreamProtocolV2Name {
			t.Errorf("%d: unexpected protocol: %s", i, protocol)
			return
		}
		streamCh := make(chan streamAndReply)

		upgrader := spdy.NewResponseUpgrader()
		conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
			streamCh <- streamAndReply{Stream: stream, replySent: replySent}
			return nil
		})
		// from this point on, we can no longer call methods on w
		if conn == nil {
			// The upgrader is responsible for notifying the client of any errors that
			// occurred during upgrading. All we can do is return here at this point
			// if we weren't successful in upgrading.
			return
		}
		defer conn.Close()

		var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
		// A stream only counts as received once a value arrives on replyChan;
		// waitStreamReply (defined elsewhere) presumably signals replyChan
		// after the stream's replySent fires, and gives up when stop is closed.
		receivedStreams := 0
		replyChan := make(chan struct{})
		stop := make(chan struct{})
		defer close(stop)
		for receivedStreams < expectedStreams {
			select {
			case stream := <-streamCh:
				streamType := stream.Headers().Get(api.StreamType)
				switch streamType {
				case api.StreamTypeError:
					errorStream = stream
				case api.StreamTypeStdin:
					stdinStream = stream
				case api.StreamTypeStdout:
					stdoutStream = stream
				case api.StreamTypeStderr:
					stderrStream = stream
				default:
					t.Errorf("%d: unexpected stream type: %q", i, streamType)
					continue
				}
				go waitStreamReply(stream.replySent, replyChan, stop)
			case <-replyChan:
				receivedStreams++
			}
		}

		if len(errorData) > 0 {
			n, err := fmt.Fprint(errorStream, errorData)
			if err != nil {
				t.Errorf("%d: error writing to errorStream: %v", i, err)
			}
			if e, a := len(errorData), n; e != a {
				t.Errorf("%d: expected to write %d bytes to errorStream, but only wrote %d", i, e, a)
			}
			errorStream.Close()
		}

		if len(stdoutData) > 0 {
			for j := 0; j < messageCount; j++ {
				n, err := fmt.Fprint(stdoutStream, stdoutData)
				if err != nil {
					t.Errorf("%d: error writing to stdoutStream: %v", i, err)
				}
				if e, a := len(stdoutData), n; e != a {
					t.Errorf("%d: expected to write %d bytes to stdoutStream, but only wrote %d", i, e, a)
				}
			}
			stdoutStream.Close()
		}
		if len(stderrData) > 0 {
			for j := 0; j < messageCount; j++ {
				n, err := fmt.Fprint(stderrStream, stderrData)
				if err != nil {
					t.Errorf("%d: error writing to stderrStream: %v", i, err)
				}
				if e, a := len(stderrData), n; e != a {
					t.Errorf("%d: expected to write %d bytes to stderrStream, but only wrote %d", i, e, a)
				}
			}
			stderrStream.Close()
		}
		if len(stdinData) > 0 {
			data := make([]byte, len(stdinData))
			for j := 0; j < messageCount; j++ {
				n, err := io.ReadFull(stdinStream, data)
				if err != nil {
					t.Errorf("%d: error reading stdin stream: %v", i, err)
				}
				if e, a := len(stdinData), n; e != a {
					t.Errorf("%d: expected to read %d bytes from stdinStream, but only read %d", i, e, a)
				}
				if e, a := stdinData, string(data); e != a {
					t.Errorf("%d: stdin: expected %q, got %q", i, e, a)
				}
			}
			stdinStream.Close()
		}
	})
}
Example #6
0
// TestServeExecInContainer drives the /exec endpoint of the test server for a
// table of stdin/stdout/stderr/tty combinations.
//
// For each case it:
//   - builds the exec URL (with or without the pod UID);
//   - POSTs it and checks the response status;
//   - for 101 Switching Protocols responses, opens the requested SPDY
//     streams, writes to stdin and reads stdout/stderr.
//
// On the server side, the fake kubelet execFunc validates the pod name, UID,
// container, command and stream contents.
func TestServeExecInContainer(t *testing.T) {
	tests := []struct {
		stdin              bool
		stdout             bool
		stderr             bool
		tty                bool
		responseStatusCode int
		uid                bool
	}{
		// No streams requested: the server is expected to reject the request.
		// NOTE(review): no case sets uid: true, so the URL form that includes
		// the pod UID is never exercised.
		{responseStatusCode: http.StatusBadRequest},
		{stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
	}

	for i, test := range tests {
		fw := newServerTest()

		fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
			return 0
		}

		podNamespace := "other"
		podName := "foo"
		expectedPodName := getPodName(podName, podNamespace)
		expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
		expectedContainerName := "baz"
		expectedCommand := "ls -a"
		expectedStdin := "stdin"
		expectedStdout := "stdout"
		expectedStderr := "stderr"
		// Hand-off between execFunc (server side) and the client code below.
		// execFunc closes execFuncDone when it returns. After writing and
		// closing stdout/stderr, it blocks until the client closes the matching
		// *ReadDone channel. Presumably this keeps the streams from being torn
		// down before the client has read the output.
		execFuncDone := make(chan struct{})
		clientStdoutReadDone := make(chan struct{})
		clientStderrReadDone := make(chan struct{})

		// NOTE(review): execFunc is presumably invoked from the HTTP server's
		// handler goroutine rather than the test goroutine. t.Fatalf (FailNow)
		// is only valid on the test goroutine, so these checks may not stop
		// the test as intended. Consider t.Errorf followed by return.
		fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
			defer close(execFuncDone)
			if podFullName != expectedPodName {
				t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
			}
			if test.uid && string(uid) != expectedUid {
				t.Fatalf("%d: uid: expected %v, got %v", i, expectedUid, uid)
			}
			if containerName != expectedContainerName {
				t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
			}
			if strings.Join(cmd, " ") != expectedCommand {
				t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd)
			}

			if test.stdin {
				if in == nil {
					t.Fatalf("%d: stdin: expected non-nil", i)
				}
				b := make([]byte, 10)
				n, err := in.Read(b)
				if err != nil {
					t.Fatalf("%d: error reading from stdin: %v", i, err)
				}
				if e, a := expectedStdin, string(b[0:n]); e != a {
					t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
				}
			} else if in != nil {
				t.Fatalf("%d: stdin: expected nil: %#v", i, in)
			}

			if test.stdout {
				if out == nil {
					t.Fatalf("%d: stdout: expected non-nil", i)
				}
				_, err := out.Write([]byte(expectedStdout))
				if err != nil {
					t.Fatalf("%d:, error writing to stdout: %v", i, err)
				}
				out.Close()
				<-clientStdoutReadDone
			} else if out != nil {
				t.Fatalf("%d: stdout: expected nil: %#v", i, out)
			}

			// This checks the tty argument passed to execFunc, not test.tty.
			if tty {
				if stderr != nil {
					t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
				}
			} else if test.stderr {
				if stderr == nil {
					t.Fatalf("%d: stderr: expected non-nil", i)
				}
				_, err := stderr.Write([]byte(expectedStderr))
				if err != nil {
					t.Fatalf("%d:, error writing to stderr: %v", i, err)
				}
				stderr.Close()
				<-clientStderrReadDone
			} else if stderr != nil {
				t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
			}

			return nil
		}

		var url string
		if test.uid {
			url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?command=ls&command=-a"
		} else {
			url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?command=ls&command=-a"
		}
		if test.stdin {
			url += "&" + api.ExecStdinParam + "=1"
		}
		if test.stdout {
			url += "&" + api.ExecStdoutParam + "=1"
		}
		// stderr is only requested without a TTY, matching the streams the
		// client creates further down.
		if test.stderr && !test.tty {
			url += "&" + api.ExecStderrParam + "=1"
		}
		if test.tty {
			url += "&" + api.ExecTTYParam + "=1"
		}

		var (
			resp                *http.Response
			err                 error
			upgradeRoundTripper httpstream.UpgradeRoundTripper
			c                   *http.Client
		)

		// Only cases expecting an upgrade use the SPDY round tripper; it is
		// also what later yields the streaming connection.
		if test.responseStatusCode != http.StatusSwitchingProtocols {
			c = &http.Client{}
		} else {
			upgradeRoundTripper = spdy.NewRoundTripper(nil)
			c = &http.Client{Transport: upgradeRoundTripper}
		}

		resp, err = c.Post(url, "", nil)
		if err != nil {
			t.Fatalf("%d: Got error POSTing: %v", i, err)
		}
		// NOTE(review): defers inside this loop (here and below) run only when
		// the whole test returns, not at the end of each iteration.
		defer resp.Body.Close()

		_, err = ioutil.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("%d: Error reading response body: %v", i, err)
		}

		if e, a := test.responseStatusCode, resp.StatusCode; e != a {
			t.Fatalf("%d: response status: expected %v, got %v", i, e, a)
		}

		if test.responseStatusCode != http.StatusSwitchingProtocols {
			continue
		}

		conn, err := upgradeRoundTripper.NewConnection(resp)
		if err != nil {
			t.Fatalf("Unexpected error creating streaming connection: %s", err)
		}
		if conn == nil {
			t.Fatalf("%d: unexpected nil conn", i)
		}
		defer conn.Close()

		// The header map h is reused for every stream; only the type changes.
		h := http.Header{}
		h.Set(api.StreamType, api.StreamTypeError)
		errorStream, err := conn.CreateStream(h)
		if err != nil {
			t.Fatalf("%d: error creating error stream: %v", i, err)
		}
		defer errorStream.Reset()

		if test.stdin {
			h.Set(api.StreamType, api.StreamTypeStdin)
			stream, err := conn.CreateStream(h)
			if err != nil {
				t.Fatalf("%d: error creating stdin stream: %v", i, err)
			}
			defer stream.Reset()
			_, err = stream.Write([]byte(expectedStdin))
			if err != nil {
				t.Fatalf("%d: error writing to stdin stream: %v", i, err)
			}
		}

		var stdoutStream httpstream.Stream
		if test.stdout {
			h.Set(api.StreamType, api.StreamTypeStdout)
			stdoutStream, err = conn.CreateStream(h)
			if err != nil {
				t.Fatalf("%d: error creating stdout stream: %v", i, err)
			}
			defer stdoutStream.Reset()
		}

		var stderrStream httpstream.Stream
		if test.stderr && !test.tty {
			h.Set(api.StreamType, api.StreamTypeStderr)
			stderrStream, err = conn.CreateStream(h)
			if err != nil {
				t.Fatalf("%d: error creating stderr stream: %v", i, err)
			}
			defer stderrStream.Reset()
		}

		// Read the output, then release execFunc, which blocks on
		// clientStdoutReadDone / clientStderrReadDone after writing.
		if test.stdout {
			output := make([]byte, 10)
			n, err := stdoutStream.Read(output)
			close(clientStdoutReadDone)
			if err != nil {
				t.Fatalf("%d: error reading from stdout stream: %v", i, err)
			}
			if e, a := expectedStdout, string(output[0:n]); e != a {
				t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a)
			}
		}

		if test.stderr && !test.tty {
			output := make([]byte, 10)
			n, err := stderrStream.Read(output)
			close(clientStderrReadDone)
			if err != nil {
				t.Fatalf("%d: error reading from stderr stream: %v", i, err)
			}
			if e, a := expectedStderr, string(output[0:n]); e != a {
				t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a)
			}
		}

		// Wait for execFunc to return before moving on to the next case.
		<-execFuncDone
	}
}
Example #7
0
// createStreams validates the exec request's stream query parameters,
// upgrades the connection to SPDY and waits up to streamCreationTimeout. It
// expects an error stream plus each requested stdin and stdout stream, and a
// stderr stream unless a TTY is requested.
//
// On success it returns:
//   - the stdin, stdout, stderr and error streams (nil for any stream that
//     was not requested; the server's write half of stdin is already closed);
//   - the connection;
//   - whether a TTY was requested;
//   - true.
//
// Neither the streams nor the connection are closed here; the caller is
// responsible for them. On failure every return value is nil/false. If the
// failure happens after a successful upgrade, the connection has already been
// closed.
func (s *Server) createStreams(request *restful.Request, response *restful.Response) (io.Reader, io.WriteCloser, io.WriteCloser, io.WriteCloser, httpstream.Connection, bool, bool) {
	// start at 1 for error stream
	expectedStreams := 1
	if request.QueryParameter(api.ExecStdinParam) == "1" {
		expectedStreams++
	}
	if request.QueryParameter(api.ExecStdoutParam) == "1" {
		expectedStreams++
	}
	tty := request.QueryParameter(api.ExecTTYParam) == "1"
	if !tty && request.QueryParameter(api.ExecStderrParam) == "1" {
		expectedStreams++
	}

	if expectedStreams == 1 {
		response.WriteError(http.StatusBadRequest, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr"))
		return nil, nil, nil, nil, nil, false, false
	}

	streamCh := make(chan httpstream.Stream)

	upgrader := spdy.NewResponseUpgrader()
	conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error {
		streamCh <- stream
		return nil
	})
	// from this point on, we can no longer call methods on response
	if conn == nil {
		// The upgrader is responsible for notifying the client of any errors that
		// occurred during upgrading. All we can do is return here at this point
		// if we weren't successful in upgrading.
		return nil, nil, nil, nil, nil, false, false
	}

	conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())

	// TODO make it configurable?
	expired := time.NewTimer(streamCreationTimeout)
	defer expired.Stop()

	var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
	receivedStreams := 0
WaitForStreams:
	for {
		select {
		case stream := <-streamCh:
			streamType := stream.Headers().Get(api.StreamType)
			switch streamType {
			case api.StreamTypeError:
				// No deferred Reset here. The error stream is returned to the
				// caller, and a defer would reset it as soon as this function
				// returns, leaving the caller with an unusable stream.
				errorStream = stream
				receivedStreams++
			case api.StreamTypeStdin:
				stdinStream = stream
				receivedStreams++
			case api.StreamTypeStdout:
				stdoutStream = stream
				receivedStreams++
			case api.StreamTypeStderr:
				stderrStream = stream
				receivedStreams++
			default:
				glog.Errorf("Unexpected stream type: '%s'", streamType)
			}
			if receivedStreams == expectedStreams {
				break WaitForStreams
			}
		case <-expired.C:
			// TODO find a way to return the error to the user. Maybe use a separate
			// stream to report errors?
			glog.Error("Timed out waiting for client to create streams")
			// conn is not returned on this path, so nothing else can close it;
			// close it here instead of leaking it until the idle timeout.
			conn.Close()
			return nil, nil, nil, nil, nil, false, false
		}
	}

	if stdinStream != nil {
		// close our half of the input stream, since we won't be writing to it
		stdinStream.Close()
	}

	return stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, true
}
Example #8
0
// stream runs the client side of exec protocol v2 over conn.
//
// Every stream is created before any data is copied:
//   - the error stream, always;
//   - stdin and stdout, when the corresponding local reader/writer is set;
//   - stderr, only when it is set and no TTY is in use.
//
// It then blocks until the stdout/stderr copies have finished. The result is
// whatever was read from the error stream: nil, an error describing a failed
// read, or the remote command's error message.
func (e *streamProtocolV2) stream(conn httpstream.Connection) error {
	// One header map is shared by every stream; only the stream type changes.
	headers := http.Header{}
	create := func(streamType string) (httpstream.Stream, error) {
		headers.Set(api.StreamType, streamType)
		return conn.CreateStream(headers)
	}

	errorStream, err := create(api.StreamTypeError)
	if err != nil {
		return err
	}

	var remoteStdin, remoteStdout, remoteStderr httpstream.Stream
	withStderr := e.stderr != nil && !e.tty

	if e.stdin != nil {
		if remoteStdin, err = create(api.StreamTypeStdin); err != nil {
			return err
		}
	}
	if e.stdout != nil {
		if remoteStdout, err = create(api.StreamTypeStdout); err != nil {
			return err
		}
	}
	if withStderr {
		if remoteStderr, err = create(api.StreamTypeStderr); err != nil {
			return err
		}
	}

	// All streams exist now; start moving data.

	// The error stream is always drained. Exactly one value is delivered on
	// errorChan, after which the channel is closed.
	errorChan := make(chan error)
	go func() {
		defer close(errorChan)
		message, readErr := ioutil.ReadAll(errorStream)
		if readErr != nil && readErr != io.EOF {
			errorChan <- fmt.Errorf("error reading from error stream: %s", readErr)
			return
		}
		if len(message) > 0 {
			errorChan <- fmt.Errorf("error executing remote command: %s", message)
			return
		}
		errorChan <- nil
	}()

	if e.stdin != nil {
		// Both stdin goroutines close remoteStdin on exit; once guarantees it
		// is closed only a single time, by whichever finishes first.
		var closeOnce sync.Once
		closeRemoteStdin := func() {
			closeOnce.Do(func() { remoteStdin.Close() })
		}

		// Copy local stdin to the container. Closing remoteStdin when the copy
		// ends matters for non-interactive input such as
		// `echo abc | kubectl exec -i <pod> -- cat`: otherwise the executed
		// command would keep running.
		go func() {
			defer closeRemoteStdin()
			if _, err := io.Copy(remoteStdin, e.stdin); err != nil {
				runtime.HandleError(err)
			}
		}()

		// Read remoteStdin until the server closes it. This is required to
		// exit interactive sessions cleanly without leaking goroutines or
		// hanging the client's terminal.
		//
		// go-dockerclient's hijack implementation
		// (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564)
		// waits for stdin, stdout and stderr to finish copying before it
		// returns. After it finishes copying stdout/stderr, it calls Close()
		// on its side of remoteStdin, which lets this copy complete. We must
		// then Close() our side of remoteStdin so hijack's own copy, and
		// hijack itself, can return. Nothing is actually read here; the copy
		// only waits for that close.
		go func() {
			defer closeRemoteStdin()
			if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil {
				runtime.HandleError(err)
			}
		}()
	}

	var wg sync.WaitGroup
	copyOut := func(dst io.Writer, src io.Reader) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if _, err := io.Copy(dst, src); err != nil {
				runtime.HandleError(err)
			}
		}()
	}
	if e.stdout != nil {
		copyOut(e.stdout, remoteStdout)
	}
	if withStderr {
		copyOut(e.stderr, remoteStderr)
	}

	// Block until stdout/stderr are fully copied, then for the error stream's
	// verdict (nil or an error).
	wg.Wait()
	return <-errorChan
}
Example #9
0
// stream runs the client side of exec protocol v1 over conn.
//
// It creates the error stream and, as configured, the stdin, stdout and
// stderr streams (stderr only when no TTY is in use). All of them are reset
// when stream returns. It copies data between those streams and
// e.stdin/e.stdout/e.stderr, and returns:
//   - the first error reported on (or hit while reading) the error stream; or
//   - nil once stdout and stderr have been fully copied; or
//   - when neither stdout nor stderr was requested, nil once the error stream
//     has closed without reporting an error.
func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
	doneChan := make(chan struct{}, 2)
	// Buffered so the error-stream reader can always deliver its single
	// result and exit. This holds even if stream has already returned because
	// stdout/stderr finished first; an unbuffered send would then block that
	// goroutine forever.
	errorChan := make(chan error, 1)

	cp := func(s string, dst io.Writer, src io.Reader) {
		glog.V(6).Infof("Copying %s", s)
		defer glog.V(6).Infof("Done copying %s", s)
		if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
			glog.Errorf("Error copying %s: %v", s, err)
		}
		if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
			doneChan <- struct{}{}
		}
	}

	var (
		err                                                  error
		errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream
	)

	// set up all the streams first
	headers := http.Header{}
	headers.Set(api.StreamType, api.StreamTypeError)
	errorStream, err = conn.CreateStream(headers)
	if err != nil {
		return err
	}
	defer errorStream.Reset()

	// Create all the streams first, then start the copy goroutines. The server doesn't start its copy
	// goroutines until it's received all of the streams. If the client creates the stdin stream and
	// immediately begins copying stdin data to the server, it's possible to overwhelm and wedge the
	// spdy frame handler in the server so that it is full of unprocessed frames. The frames aren't
	// getting processed because the server hasn't started its copying, and it won't do that until it
	// gets all the streams. By creating all the streams first, we ensure that the server is ready to
	// process data before the client starts sending any. See https://issues.k8s.io/16373 for more info.
	if e.stdin != nil {
		headers.Set(api.StreamType, api.StreamTypeStdin)
		remoteStdin, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStdin.Reset()
	}

	if e.stdout != nil {
		headers.Set(api.StreamType, api.StreamTypeStdout)
		remoteStdout, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStdout.Reset()
	}

	if e.stderr != nil && !e.tty {
		headers.Set(api.StreamType, api.StreamTypeStderr)
		remoteStderr, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStderr.Reset()
	}

	// now that all the streams have been created, proceed with reading & copying

	// always read from errorStream. Exactly one value is sent: the remote
	// error, or nil when the error stream closed without one.
	go func() {
		message, err := ioutil.ReadAll(errorStream)
		switch {
		case err != nil && err != io.EOF:
			errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
		case len(message) > 0:
			errorChan <- fmt.Errorf("Error executing remote command: %s", message)
		default:
			errorChan <- nil
		}
	}()

	if e.stdin != nil {
		// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
		// because stdin is not closed until the process exits. If we try to call
		// stdin.Close(), it returns no error but doesn't unblock the copy. It will
		// exit when the process exits, instead.
		go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
	}

	waitCount := 0
	completedStreams := 0

	if e.stdout != nil {
		waitCount++
		go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
	}

	if e.stderr != nil && !e.tty {
		waitCount++
		go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
	}

	// Wait for stdout/stderr to finish copying, returning early on a remote
	// error. A clean close of the error stream does not end the wait while
	// output is still being copied. When neither stdout nor stderr was
	// requested, the error stream closing is the only completion signal, so
	// wait for it rather than blocking forever.
	errorStreamDone := false
	for completedStreams < waitCount || (waitCount == 0 && !errorStreamDone) {
		select {
		case <-doneChan:
			completedStreams++
		case err := <-errorChan:
			if err != nil {
				return err
			}
			errorStreamDone = true
		}
	}

	return nil
}
Example #10
0
// stream runs the client side of exec protocol v1 over conn.
//
// It creates the error stream and, as configured, the stdin, stdout and
// stderr streams (stderr only when no TTY is in use). All of them are reset
// when stream returns. It copies data between those streams and
// e.stdin/e.stdout/e.stderr, and returns:
//   - the first error reported on (or hit while reading) the error stream; or
//   - nil once stdout and stderr have been fully copied; or
//   - when neither stdout nor stderr was requested, nil once the error stream
//     has closed without reporting an error.
func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
	doneChan := make(chan struct{}, 2)
	// Buffered so the error-stream reader can always deliver its single
	// result and exit. This holds even if stream has already returned because
	// stdout/stderr finished first; an unbuffered send would then block that
	// goroutine forever.
	errorChan := make(chan error, 1)

	cp := func(s string, dst io.Writer, src io.Reader) {
		glog.V(6).Infof("Copying %s", s)
		defer glog.V(6).Infof("Done copying %s", s)
		if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
			glog.Errorf("Error copying %s: %v", s, err)
		}
		if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
			doneChan <- struct{}{}
		}
	}

	var (
		err                                                  error
		errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream
	)

	// set up all the streams first, before any copying starts
	headers := http.Header{}
	headers.Set(api.StreamType, api.StreamTypeError)
	errorStream, err = conn.CreateStream(headers)
	if err != nil {
		return err
	}
	defer errorStream.Reset()

	if e.stdin != nil {
		headers.Set(api.StreamType, api.StreamTypeStdin)
		remoteStdin, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStdin.Reset()
	}

	if e.stdout != nil {
		headers.Set(api.StreamType, api.StreamTypeStdout)
		remoteStdout, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStdout.Reset()
	}

	if e.stderr != nil && !e.tty {
		headers.Set(api.StreamType, api.StreamTypeStderr)
		remoteStderr, err = conn.CreateStream(headers)
		if err != nil {
			return err
		}
		defer remoteStderr.Reset()
	}

	// now that all the streams have been created, proceed with reading & copying

	// always read from errorStream. Exactly one value is sent: the remote
	// error, or nil when the error stream closed without one.
	go func() {
		message, err := ioutil.ReadAll(errorStream)
		switch {
		case err != nil && err != io.EOF:
			errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
		case len(message) > 0:
			errorChan <- fmt.Errorf("Error executing remote command: %s", message)
		default:
			errorChan <- nil
		}
	}()

	if e.stdin != nil {
		// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
		// because stdin is not closed until the process exits. If we try to call
		// stdin.Close(), it returns no error but doesn't unblock the copy. It will
		// exit when the process exits, instead.
		go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
	}

	waitCount := 0
	completedStreams := 0

	if e.stdout != nil {
		waitCount++
		go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
	}

	if e.stderr != nil && !e.tty {
		waitCount++
		go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
	}

	// Wait for stdout/stderr to finish copying, returning early on a remote
	// error. A clean close of the error stream does not end the wait while
	// output is still being copied. When neither stdout nor stderr was
	// requested, the error stream closing is the only completion signal, so
	// wait for it rather than blocking forever.
	errorStreamDone := false
	for completedStreams < waitCount || (waitCount == 0 && !errorStreamDone) {
		select {
		case <-doneChan:
			completedStreams++
		case err := <-errorChan:
			if err != nil {
				return err
			}
			errorStreamDone = true
		}
	}

	return nil
}