func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool) http.HandlerFunc { // error + stdin + stdout expectedStreams := 3 if !tty { // stderr expectedStreams++ } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error { streamCh <- stream return nil }) // from this point on, we can no longer call methods on w if conn == nil { // The upgrader is responsible for notifying the client of any errors that // occurred during upgrading. All we can do is return here at this point // if we weren't successful in upgrading. return } defer conn.Close() var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 WaitForStreams: for { select { case stream := <-streamCh: streamType := stream.Headers().Get(api.StreamType) switch streamType { case api.StreamTypeError: errorStream = stream receivedStreams++ case api.StreamTypeStdin: stdinStream = stream stdinStream.Close() receivedStreams++ case api.StreamTypeStdout: stdoutStream = stream receivedStreams++ case api.StreamTypeStderr: stderrStream = stream receivedStreams++ default: t.Errorf("%d: unexpected stream type: %q", i, streamType) } defer stream.Reset() if receivedStreams == expectedStreams { break WaitForStreams } } } if len(errorData) > 0 { fmt.Fprint(errorStream, errorData) errorStream.Close() } if len(stdoutData) > 0 { fmt.Fprint(stdoutStream, stdoutData) stdoutStream.Close() } if len(stderrData) > 0 { fmt.Fprint(stderrStream, stderrData) stderrStream.Close() } if len(stdinData) > 0 { data, err := ioutil.ReadAll(stdinStream) if err != nil { t.Errorf("%d: error reading stdin stream: %v", i, err) } if e, a := stdinData, string(data); e != a { t.Errorf("%d: stdin: expected %q, got %q", i, e, a) } } }) }
func (s *Server) createStreams(request *restful.Request, response *restful.Response) (io.Reader, io.WriteCloser, io.WriteCloser, io.WriteCloser, httpstream.Connection, bool, bool) { // start at 1 for error stream expectedStreams := 1 if request.QueryParameter(api.ExecStdinParam) == "1" { expectedStreams++ } if request.QueryParameter(api.ExecStdoutParam) == "1" { expectedStreams++ } tty := request.QueryParameter(api.ExecTTYParam) == "1" if !tty && request.QueryParameter(api.ExecStderrParam) == "1" { expectedStreams++ } if expectedStreams == 1 { response.WriteError(http.StatusBadRequest, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr")) return nil, nil, nil, nil, nil, false, false } streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { streamCh <- stream return nil }) // from this point on, we can no longer call methods on response if conn == nil { // The upgrader is responsible for notifying the client of any errors that // occurred during upgrading. All we can do is return here at this point // if we weren't successful in upgrading. return nil, nil, nil, nil, nil, false, false } conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) // TODO make it configurable? expired := time.NewTimer(streamCreationTimeout) var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 WaitForStreams: for { select { case stream := <-streamCh: streamType := stream.Headers().Get(api.StreamType) switch streamType { case api.StreamTypeError: errorStream = stream defer errorStream.Reset() receivedStreams++ case api.StreamTypeStdin: stdinStream = stream receivedStreams++ case api.StreamTypeStdout: stdoutStream = stream receivedStreams++ case api.StreamTypeStderr: stderrStream = stream receivedStreams++ default: glog.Errorf("Unexpected stream type: '%s'", streamType) } if receivedStreams == expectedStreams { break WaitForStreams } case <-expired.C: // TODO find a way to return the error to the user. Maybe use a separate // stream to report errors? glog.Error("Timed out waiting for client to create streams") return nil, nil, nil, nil, nil, false, false } } if stdinStream != nil { // close our half of the input stream, since we won't be writing to it stdinStream.Close() } return stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, true }
func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int) http.HandlerFunc { // error + stdin + stdout expectedStreams := 3 if !tty { // stderr expectedStreams++ } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name) if err != nil { t.Fatal(err) } if protocol != StreamProtocolV2Name { t.Fatalf("unexpected protocol: %s", protocol) } streamCh := make(chan streamAndReply) upgrader := spdy.NewResponseUpgrader() conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error { streamCh <- streamAndReply{Stream: stream, replySent: replySent} return nil }) // from this point on, we can no longer call methods on w if conn == nil { // The upgrader is responsible for notifying the client of any errors that // occurred during upgrading. All we can do is return here at this point // if we weren't successful in upgrading. return } defer conn.Close() var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 replyChan := make(chan struct{}) stop := make(chan struct{}) defer close(stop) WaitForStreams: for { select { case stream := <-streamCh: streamType := stream.Headers().Get(api.StreamType) switch streamType { case api.StreamTypeError: errorStream = stream go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStdin: stdinStream = stream go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStdout: stdoutStream = stream go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStderr: stderrStream = stream go waitStreamReply(stream.replySent, replyChan, stop) default: t.Errorf("%d: unexpected stream type: %q", i, streamType) } if receivedStreams == expectedStreams { break WaitForStreams } case <-replyChan: receivedStreams++ if receivedStreams == expectedStreams { break WaitForStreams } } } if len(errorData) > 0 { n, err := fmt.Fprint(errorStream, errorData) if err != nil { t.Errorf("%d: error writing to errorStream: %v", i, err) } if e, a := len(errorData), n; e != a { t.Errorf("%d: expected to write %d bytes to errorStream, but only wrote %d", i, e, a) } errorStream.Close() } if len(stdoutData) > 0 { for j := 0; j < messageCount; j++ { n, err := fmt.Fprint(stdoutStream, stdoutData) if err != nil { t.Errorf("%d: error writing to stdoutStream: %v", i, err) } if e, a := len(stdoutData), n; e != a { t.Errorf("%d: expected to write %d bytes to stdoutStream, but only wrote %d", i, e, a) } } stdoutStream.Close() } if len(stderrData) > 0 { for j := 0; j < messageCount; j++ { n, err := fmt.Fprint(stderrStream, stderrData) if err != nil { t.Errorf("%d: error writing to stderrStream: %v", i, err) } if e, a := len(stderrData), n; e != a { t.Errorf("%d: expected to write %d bytes to stderrStream, but only wrote %d", i, e, a) } } stderrStream.Close() } if len(stdinData) > 0 { data := make([]byte, len(stdinData)) for j := 0; j < messageCount; j++ { n, err := io.ReadFull(stdinStream, data) if err != nil { t.Errorf("%d: error reading stdin stream: %v", i, err) } if e, a := len(stdinData), n; e != a { t.Errorf("%d: expected to read %d bytes from stdinStream, but only read %d", i, e, a) } if e, a := stdinData, string(data); e != a { t.Errorf("%d: stdin: expected %q, got %q", i, e, a) } } stdinStream.Close() } }) }
func (e *streamProtocolV2) stream(conn httpstream.Connection) error { var ( err error errorStream, remoteStdin, remoteStdout, remoteStderr httpstream.Stream ) headers := http.Header{} // set up all the streams first // set up error stream errorChan := make(chan error) headers.Set(api.StreamType, api.StreamTypeError) errorStream, err = conn.CreateStream(headers) if err != nil { return err } // set up stdin stream if e.stdin != nil { headers.Set(api.StreamType, api.StreamTypeStdin) remoteStdin, err = conn.CreateStream(headers) if err != nil { return err } } // set up stdout stream if e.stdout != nil { headers.Set(api.StreamType, api.StreamTypeStdout) remoteStdout, err = conn.CreateStream(headers) if err != nil { return err } } // set up stderr stream if e.stderr != nil && !e.tty { headers.Set(api.StreamType, api.StreamTypeStderr) remoteStderr, err = conn.CreateStream(headers) if err != nil { return err } } // now that all the streams have been created, proceed with reading & copying // always read from errorStream go func() { message, err := ioutil.ReadAll(errorStream) switch { case err != nil && err != io.EOF: errorChan <- fmt.Errorf("error reading from error stream: %s", err) case len(message) > 0: errorChan <- fmt.Errorf("error executing remote command: %s", message) default: errorChan <- nil } close(errorChan) }() var wg sync.WaitGroup var once sync.Once if e.stdin != nil { // copy from client's stdin to container's stdin go func() { // if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i <pod> -- cat`, make sure // we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise // the executed command will remain running. defer once.Do(func() { remoteStdin.Close() }) if _, err := io.Copy(remoteStdin, e.stdin); err != nil { runtime.HandleError(err) } }() // read from remoteStdin until the stream is closed. this is essential to // be able to exit interactive sessions cleanly and not leak goroutines or // hang the client's terminal. // // go-dockerclient's current hijack implementation // (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564) // waits for all three streams (stdin/stdout/stderr) to finish copying // before returning. When hijack finishes copying stdout/stderr, it calls // Close() on its side of remoteStdin, which allows this copy to complete. // When that happens, we must Close() on our side of remoteStdin, to // allow the copy in hijack to complete, and hijack to return. go func() { defer once.Do(func() { remoteStdin.Close() }) // this "copy" doesn't actually read anything - it's just here to wait for // the server to close remoteStdin. if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil { runtime.HandleError(err) } }() } if e.stdout != nil { wg.Add(1) go func() { defer wg.Done() if _, err := io.Copy(e.stdout, remoteStdout); err != nil { runtime.HandleError(err) } }() } if e.stderr != nil && !e.tty { wg.Add(1) go func() { defer wg.Done() if _, err := io.Copy(e.stderr, remoteStderr); err != nil { runtime.HandleError(err) } }() } // we're waiting for stdout/stderr to finish copying wg.Wait() // waits for errorStream to finish reading with an error or nil return <-errorChan }