func waitForPortForwardDataStreamAndRun(pod string, uid types.UID, errorStream httpstream.Stream, dataStreamChan chan httpstream.Stream, host HostInterface) { defer errorStream.Reset() var dataStream httpstream.Stream select { case dataStream = <-dataStreamChan: case <-time.After(streamCreationTimeout): errorStream.Write([]byte("Timed out waiting for data stream")) //TODO delete from dataStreamChans[port] return } portString := dataStream.Headers().Get(api.PortHeader) port, _ := strconv.ParseUint(portString, 10, 16) err := host.PortForward(pod, uid, uint16(port), dataStream) if err != nil { msg := fmt.Errorf("Error forwarding port %d to pod %s, uid %v: %v", port, pod, uid, err) glog.Error(msg) errorStream.Write([]byte(msg.Error())) } }
func TestServeExecInContainer(t *testing.T) { tests := []struct { stdin bool stdout bool stderr bool tty bool responseStatusCode int uid bool }{ {responseStatusCode: http.StatusBadRequest}, {stdin: true, responseStatusCode: http.StatusSwitchingProtocols}, {stdout: true, responseStatusCode: http.StatusSwitchingProtocols}, {stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, {stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, {stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols}, {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, } for i, test := range tests { fw := newServerTest() fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { return 0 } podNamespace := "other" podName := "foo" expectedPodName := getPodName(podName, podNamespace) expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647" expectedContainerName := "baz" expectedCommand := "ls -a" expectedStdin := "stdin" expectedStdout := "stdout" expectedStderr := "stderr" execFuncDone := make(chan struct{}) clientStdoutReadDone := make(chan struct{}) clientStderrReadDone := make(chan struct{}) fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { defer close(execFuncDone) if podFullName != expectedPodName { t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName) } if test.uid && string(uid) != expectedUid { t.Fatalf("%d: uid: expected %v, got %v", i, expectedUid, uid) } if containerName != expectedContainerName { t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName) } if strings.Join(cmd, " ") != expectedCommand { t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd) } if test.stdin { if in == nil { t.Fatalf("%d: stdin: expected non-nil", i) } b := make([]byte, 10) n, err := in.Read(b) if err != nil { t.Fatalf("%d: error reading from stdin: %v", i, err) } if e, a := expectedStdin, string(b[0:n]); e != a { t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a) } } else if in != nil { t.Fatalf("%d: stdin: expected nil: %#v", i, in) } if test.stdout { if out == nil { t.Fatalf("%d: stdout: expected non-nil", i) } _, err := out.Write([]byte(expectedStdout)) if err != nil { t.Fatalf("%d:, error writing to stdout: %v", i, err) } out.Close() <-clientStdoutReadDone } else if out != nil { t.Fatalf("%d: stdout: expected nil: %#v", i, out) } if tty { if stderr != nil { t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr) } } else if test.stderr { if stderr == nil { t.Fatalf("%d: stderr: expected non-nil", i) } _, err := stderr.Write([]byte(expectedStderr)) if err != nil { t.Fatalf("%d:, error writing to stderr: %v", i, err) } stderr.Close() <-clientStderrReadDone } else if stderr != nil { t.Fatalf("%d: stderr: expected nil: %#v", i, stderr) } return nil } var url string if test.uid { url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?command=ls&command=-a" } else { url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?command=ls&command=-a" } if test.stdin { url += "&" + api.ExecStdinParam + "=1" } if test.stdout { url += "&" + api.ExecStdoutParam + "=1" } if test.stderr && !test.tty { url += "&" + api.ExecStderrParam + "=1" } if test.tty { url += "&" + api.ExecTTYParam + "=1" } var ( resp *http.Response err error upgradeRoundTripper httpstream.UpgradeRoundTripper c *http.Client ) if test.responseStatusCode != http.StatusSwitchingProtocols { c = &http.Client{} } else { upgradeRoundTripper = spdy.NewRoundTripper(nil) c = &http.Client{Transport: upgradeRoundTripper} } resp, err = c.Post(url, "", nil) if err != nil { t.Fatalf("%d: Got error POSTing: %v", i, err) } defer resp.Body.Close() _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Errorf("%d: Error reading response body: %v", i, err) } if e, a := test.responseStatusCode, resp.StatusCode; e != a { t.Fatalf("%d: response status: expected %v, got %v", i, e, a) } if test.responseStatusCode != http.StatusSwitchingProtocols { continue } conn, err := upgradeRoundTripper.NewConnection(resp) if err != nil { t.Fatalf("Unexpected error creating streaming connection: %s", err) } if conn == nil { t.Fatalf("%d: unexpected nil conn", i) } defer conn.Close() h := http.Header{} h.Set(api.StreamType, api.StreamTypeError) errorStream, err := conn.CreateStream(h) if err != nil { t.Fatalf("%d: error creating error stream: %v", i, err) } defer errorStream.Reset() if test.stdin { h.Set(api.StreamType, api.StreamTypeStdin) stream, err := conn.CreateStream(h) if err != nil { t.Fatalf("%d: error creating stdin stream: %v", i, err) } defer stream.Reset() _, err = stream.Write([]byte(expectedStdin)) if err != nil { t.Fatalf("%d: error writing to stdin stream: %v", i, err) } } var stdoutStream httpstream.Stream if test.stdout { h.Set(api.StreamType, api.StreamTypeStdout) stdoutStream, err = conn.CreateStream(h) if err != nil { t.Fatalf("%d: error creating stdout stream: %v", i, err) } defer stdoutStream.Reset() } var stderrStream httpstream.Stream if test.stderr && !test.tty { h.Set(api.StreamType, api.StreamTypeStderr) stderrStream, err = conn.CreateStream(h) if err != nil { t.Fatalf("%d: error creating stderr stream: %v", i, err) } defer stderrStream.Reset() } if test.stdout { output := make([]byte, 10) n, err := stdoutStream.Read(output) close(clientStdoutReadDone) if err != nil { t.Fatalf("%d: error reading from stdout stream: %v", i, err) } if e, a := expectedStdout, string(output[0:n]); e != a { t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a) } } if test.stderr && !test.tty { output := make([]byte, 10) n, err := stderrStream.Read(output) close(clientStderrReadDone) if err != nil { t.Fatalf("%d: error reading from stderr stream: %v", i, err) } if e, a := expectedStderr, string(output[0:n]); e != a { t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a) } } <-execFuncDone } }
func (e *streamProtocolV1) stream(conn httpstream.Connection) error { doneChan := make(chan struct{}, 2) errorChan := make(chan error) cp := func(s string, dst io.Writer, src io.Reader) { glog.V(6).Infof("Copying %s", s) defer glog.V(6).Infof("Done copying %s", s) if _, err := io.Copy(dst, src); err != nil && err != io.EOF { glog.Errorf("Error copying %s: %v", s, err) } if s == api.StreamTypeStdout || s == api.StreamTypeStderr { doneChan <- struct{}{} } } var ( err error errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream ) // set up all the streams first headers := http.Header{} headers.Set(api.StreamType, api.StreamTypeError) errorStream, err = conn.CreateStream(headers) if err != nil { return err } defer errorStream.Reset() // Create all the streams first, then start the copy goroutines. The server doesn't start its copy // goroutines until it's received all of the streams. If the client creates the stdin stream and // immediately begins copying stdin data to the server, it's possible to overwhelm and wedge the // spdy frame handler in the server so that it is full of unprocessed frames. The frames aren't // getting processed because the server hasn't started its copying, and it won't do that until it // gets all the streams. By creating all the streams first, we ensure that the server is ready to // process data before the client starts sending any. See https://issues.k8s.io/16373 for more info. if e.stdin != nil { headers.Set(api.StreamType, api.StreamTypeStdin) remoteStdin, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStdin.Reset() } if e.stdout != nil { headers.Set(api.StreamType, api.StreamTypeStdout) remoteStdout, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStdout.Reset() } if e.stderr != nil && !e.tty { headers.Set(api.StreamType, api.StreamTypeStderr) remoteStderr, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStderr.Reset() } // now that all the streams have been created, proceed with reading & copying // always read from errorStream go func() { message, err := ioutil.ReadAll(errorStream) if err != nil && err != io.EOF { errorChan <- fmt.Errorf("Error reading from error stream: %s", err) return } if len(message) > 0 { errorChan <- fmt.Errorf("Error executing remote command: %s", message) return } }() if e.stdin != nil { // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) // because stdin is not closed until the process exits. If we try to call // stdin.Close(), it returns no error but doesn't unblock the copy. It will // exit when the process exits, instead. go cp(api.StreamTypeStdin, remoteStdin, e.stdin) } waitCount := 0 completedStreams := 0 if e.stdout != nil { waitCount++ go cp(api.StreamTypeStdout, e.stdout, remoteStdout) } if e.stderr != nil && !e.tty { waitCount++ go cp(api.StreamTypeStderr, e.stderr, remoteStderr) } Loop: for { select { case <-doneChan: completedStreams++ if completedStreams == waitCount { break Loop } case err := <-errorChan: return err } } return nil }
func (s *Server) createStreams(request *restful.Request, response *restful.Response) (io.Reader, io.WriteCloser, io.WriteCloser, io.WriteCloser, httpstream.Connection, bool, bool) { // start at 1 for error stream expectedStreams := 1 if request.QueryParameter(api.ExecStdinParam) == "1" { expectedStreams++ } if request.QueryParameter(api.ExecStdoutParam) == "1" { expectedStreams++ } tty := request.QueryParameter(api.ExecTTYParam) == "1" if !tty && request.QueryParameter(api.ExecStderrParam) == "1" { expectedStreams++ } if expectedStreams == 1 { response.WriteError(http.StatusBadRequest, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr")) return nil, nil, nil, nil, nil, false, false } streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { streamCh <- stream return nil }) // from this point on, we can no longer call methods on response if conn == nil { // The upgrader is responsible for notifying the client of any errors that // occurred during upgrading. All we can do is return here at this point // if we weren't successful in upgrading. return nil, nil, nil, nil, nil, false, false } conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) // TODO make it configurable? expired := time.NewTimer(streamCreationTimeout) var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 WaitForStreams: for { select { case stream := <-streamCh: streamType := stream.Headers().Get(api.StreamType) switch streamType { case api.StreamTypeError: errorStream = stream defer errorStream.Reset() receivedStreams++ case api.StreamTypeStdin: stdinStream = stream receivedStreams++ case api.StreamTypeStdout: stdoutStream = stream receivedStreams++ case api.StreamTypeStderr: stderrStream = stream receivedStreams++ default: glog.Errorf("Unexpected stream type: '%s'", streamType) } if receivedStreams == expectedStreams { break WaitForStreams } case <-expired.C: // TODO find a way to return the error to the user. Maybe use a separate // stream to report errors? glog.Error("Timed out waiting for client to create streams") return nil, nil, nil, nil, nil, false, false } } if stdinStream != nil { // close our half of the input stream, since we won't be writing to it stdinStream.Close() } return stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, true }
func (e *streamProtocolV1) stream(conn httpstream.Connection) error { doneChan := make(chan struct{}, 2) errorChan := make(chan error) cp := func(s string, dst io.Writer, src io.Reader) { glog.V(6).Infof("Copying %s", s) defer glog.V(6).Infof("Done copying %s", s) if _, err := io.Copy(dst, src); err != nil && err != io.EOF { glog.Errorf("Error copying %s: %v", s, err) } if s == api.StreamTypeStdout || s == api.StreamTypeStderr { doneChan <- struct{}{} } } var ( err error errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream ) // set up all the streams first headers := http.Header{} headers.Set(api.StreamType, api.StreamTypeError) errorStream, err = conn.CreateStream(headers) if err != nil { return err } defer errorStream.Reset() if e.stdin != nil { headers.Set(api.StreamType, api.StreamTypeStdin) remoteStdin, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStdin.Reset() } if e.stdout != nil { headers.Set(api.StreamType, api.StreamTypeStdout) remoteStdout, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStdout.Reset() } if e.stderr != nil && !e.tty { headers.Set(api.StreamType, api.StreamTypeStderr) remoteStderr, err = conn.CreateStream(headers) if err != nil { return err } defer remoteStderr.Reset() } // now that all the streams have been created, proceed with reading & copying // always read from errorStream go func() { message, err := ioutil.ReadAll(errorStream) if err != nil && err != io.EOF { errorChan <- fmt.Errorf("Error reading from error stream: %s", err) return } if len(message) > 0 { errorChan <- fmt.Errorf("Error executing remote command: %s", message) return } }() if e.stdin != nil { // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) // because stdin is not closed until the process exits. If we try to call // stdin.Close(), it returns no error but doesn't unblock the copy. It will // exit when the process exits, instead. go cp(api.StreamTypeStdin, remoteStdin, e.stdin) } waitCount := 0 completedStreams := 0 if e.stdout != nil { waitCount++ go cp(api.StreamTypeStdout, e.stdout, remoteStdout) } if e.stderr != nil && !e.tty { waitCount++ go cp(api.StreamTypeStderr, e.stderr, remoteStderr) } Loop: for { select { case <-doneChan: completedStreams++ if completedStreams == waitCount { break Loop } case err := <-errorChan: return err } } return nil }