// requestID returns the request id for stream. func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string { requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) if len(requestID) == 0 { glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) // If we get here, it's because the connection came from an older client // that isn't generating the request id header // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) // // This is a best-effort attempt at supporting older clients. // // When there aren't concurrent new forwarded connections, each connection // will have a pair of streams (data, error), and the stream IDs will be // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert // the stream ID into a pseudo-request id by taking the stream type and // using id = stream.Identifier() when the stream type is error, // and id = stream.Identifier() - 2 when it's data. // // NOTE: this only works when there are not concurrent new streams from // multiple forwarded connections; it's a best-effort attempt at supporting // old clients that don't generate request ids. If there are concurrent // new connections, it's possible that 1 connection gets streams whose IDs // are not consecutive (e.g. 5 and 9 instead of 5 and 7). streamType := stream.Headers().Get(api.StreamType) switch streamType { case api.StreamTypeError: requestID = strconv.Itoa(int(stream.Identifier())) case api.StreamTypeData: requestID = strconv.Itoa(int(stream.Identifier()) - 2) } glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) } return requestID }
// add adds the stream to the portForwardStreamPair. If the pair already // contains a stream for the new stream's type, an error is returned. add // returns true if both the data and error streams for this pair have been // received. func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) { p.lock.Lock() defer p.lock.Unlock() switch stream.Headers().Get(api.StreamType) { case api.StreamTypeError: if p.errorStream != nil { return false, errors.New("error stream already assigned") } p.errorStream = stream case api.StreamTypeData: if p.dataStream != nil { return false, errors.New("data stream already assigned") } p.dataStream = stream } complete := p.errorStream != nil && p.dataStream != nil if complete { close(p.complete) } return complete, nil }
// TestServeExecInContainer exercises the kubelet's /exec endpoint for several
// combinations of the stdin/stdout/stderr/tty query parameters.
//
// For each case it installs a fake execFunc on the test server's fakeKubelet.
// The fake validates the pod, container, and command it is called with, and
// exchanges fixed payloads over the streams it is given. The client side POSTs
// to the exec URL and checks the response status against responseStatusCode.
// For upgraded (101) responses it also opens SPDY streams, writes stdin, and
// reads stdout/stderr.
//
// Synchronization for upgraded cases:
//   - execFunc writes and closes stdout/stderr, then blocks on
//     clientStdoutReadDone/clientStderrReadDone until the client has done its
//     Read.
//   - The client waits on execFuncDone before moving to the next case.
//
// NOTE(review): execFunc is invoked by the server under test, presumably on a
// goroutine other than the test goroutine. The testing package requires
// FailNow (and therefore Fatalf) to be called only from the goroutine running
// the test, so t.Errorf followed by an early return would be safer there.
// NOTE(review): the defers inside the loop (response bodies, connections,
// stream resets) run only when this test function returns, not at the end of
// each iteration.
// NOTE(review): no case in the table sets uid, so the UID form of the URL and
// the UID check in execFunc are never exercised.
func TestServeExecInContainer(t *testing.T) {
	tests := []struct {
		stdin              bool
		stdout             bool
		stderr             bool
		tty                bool
		responseStatusCode int
		uid                bool
	}{
		// No stream selected: the request is expected to be rejected.
		{responseStatusCode: http.StatusBadRequest},
		{stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
		{stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
	}

	for i, test := range tests {
		fw := newServerTest()
		// NOTE(review): presumably a zero idle timeout disables idle expiry of
		// the streaming connection — confirm against the server implementation.
		fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
			return 0
		}

		podNamespace := "other"
		podName := "foo"
		expectedPodName := getPodName(podName, podNamespace)
		expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
		expectedContainerName := "baz"
		expectedCommand := "ls -a"
		expectedStdin := "stdin"
		expectedStdout := "stdout"
		expectedStderr := "stderr"
		// execFuncDone is closed when the fake exec returns. The client*ReadDone
		// channels are closed by the client after reading the corresponding
		// stream, which releases the fake exec.
		execFuncDone := make(chan struct{})
		clientStdoutReadDone := make(chan struct{})
		clientStderrReadDone := make(chan struct{})

		fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
			// Signal completion to the client side however this function exits.
			defer close(execFuncDone)
			if podFullName != expectedPodName {
				t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
			}
			if test.uid && string(uid) != expectedUid {
				t.Fatalf("%d: uid: expected %v, got %v", i, expectedUid, uid)
			}
			if containerName != expectedContainerName {
				t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
			}
			// The URL passes the command as repeated command= parameters.
			if strings.Join(cmd, " ") != expectedCommand {
				t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd)
			}

			if test.stdin {
				if in == nil {
					t.Fatalf("%d: stdin: expected non-nil", i)
				}
				// NOTE(review): assumes a single Read returns the whole payload.
				b := make([]byte, 10)
				n, err := in.Read(b)
				if err != nil {
					t.Fatalf("%d: error reading from stdin: %v", i, err)
				}
				if e, a := expectedStdin, string(b[0:n]); e != a {
					t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
				}
			} else if in != nil {
				t.Fatalf("%d: stdin: expected nil: %#v", i, in)
			}

			if test.stdout {
				if out == nil {
					t.Fatalf("%d: stdout: expected non-nil", i)
				}
				_, err := out.Write([]byte(expectedStdout))
				if err != nil {
					t.Fatalf("%d:, error writing to stdout: %v", i, err)
				}
				out.Close()
				// Keep exec running until the client has read stdout.
				<-clientStdoutReadDone
			} else if out != nil {
				t.Fatalf("%d: stdout: expected nil: %#v", i, out)
			}

			// This checks execFunc's own tty argument, not test.tty. In TTY mode
			// no separate stderr writer is expected.
			if tty {
				if stderr != nil {
					t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
				}
			} else if test.stderr {
				if stderr == nil {
					t.Fatalf("%d: stderr: expected non-nil", i)
				}
				_, err := stderr.Write([]byte(expectedStderr))
				if err != nil {
					t.Fatalf("%d:, error writing to stderr: %v", i, err)
				}
				stderr.Close()
				// Keep exec running until the client has read stderr.
				<-clientStderrReadDone
			} else if stderr != nil {
				t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
			}

			return nil
		}

		var url string
		if test.uid {
			url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?command=ls&command=-a"
		} else {
			url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?command=ls&command=-a"
		}

		if test.stdin {
			url += "&" + api.ExecStdinParam + "=1"
		}
		if test.stdout {
			url += "&" + api.ExecStdoutParam + "=1"
		}
		// stderr is only requested when not running with a TTY.
		if test.stderr && !test.tty {
			url += "&" + api.ExecStderrParam + "=1"
		}
		if test.tty {
			url += "&" + api.ExecTTYParam + "=1"
		}

		var (
			resp                *http.Response
			err                 error
			upgradeRoundTripper httpstream.UpgradeRoundTripper
			c                   *http.Client
		)

		// Cases expecting an upgrade use the SPDY round tripper, so that the
		// upgraded connection can be taken over with NewConnection below.
		if test.responseStatusCode != http.StatusSwitchingProtocols {
			c = &http.Client{}
		} else {
			upgradeRoundTripper = spdy.NewRoundTripper(nil)
			c = &http.Client{Transport: upgradeRoundTripper}
		}

		resp, err = c.Post(url, "", nil)
		if err != nil {
			t.Fatalf("%d: Got error POSTing: %v", i, err)
		}
		defer resp.Body.Close()

		_, err = ioutil.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("%d: Error reading response body: %v", i, err)
		}

		if e, a := test.responseStatusCode, resp.StatusCode; e != a {
			t.Fatalf("%d: response status: expected %v, got %v", i, e, a)
		}

		// Non-upgraded cases stop here. execFuncDone is not awaited for them.
		if test.responseStatusCode != http.StatusSwitchingProtocols {
			continue
		}

		conn, err := upgradeRoundTripper.NewConnection(resp)
		if err != nil {
			t.Fatalf("Unexpected error creating streaming connection: %s", err)
		}
		if conn == nil {
			t.Fatalf("%d: unexpected nil conn", i)
		}
		defer conn.Close()

		// The same header value is reused. Only the stream type changes
		// between CreateStream calls.
		h := http.Header{}
		h.Set(api.StreamType, api.StreamTypeError)
		errorStream, err := conn.CreateStream(h)
		if err != nil {
			t.Fatalf("%d: error creating error stream: %v", i, err)
		}
		defer errorStream.Reset()

		if test.stdin {
			h.Set(api.StreamType, api.StreamTypeStdin)
			stream, err := conn.CreateStream(h)
			if err != nil {
				t.Fatalf("%d: error creating stdin stream: %v", i, err)
			}
			defer stream.Reset()
			_, err = stream.Write([]byte(expectedStdin))
			if err != nil {
				t.Fatalf("%d: error writing to stdin stream: %v", i, err)
			}
		}

		var stdoutStream httpstream.Stream
		if test.stdout {
			h.Set(api.StreamType, api.StreamTypeStdout)
			stdoutStream, err = conn.CreateStream(h)
			if err != nil {
				t.Fatalf("%d: error creating stdout stream: %v", i, err)
			}
			defer stdoutStream.Reset()
		}

		var stderrStream httpstream.Stream
		if test.stderr && !test.tty {
			h.Set(api.StreamType, api.StreamTypeStderr)
			stderrStream, err = conn.CreateStream(h)
			if err != nil {
				t.Fatalf("%d: error creating stderr stream: %v", i, err)
			}
			defer stderrStream.Reset()
		}

		if test.stdout {
			output := make([]byte, 10)
			n, err := stdoutStream.Read(output)
			// Release execFunc before checking the result, so a failed check
			// does not leave it blocked.
			close(clientStdoutReadDone)
			if err != nil {
				t.Fatalf("%d: error reading from stdout stream: %v", i, err)
			}
			if e, a := expectedStdout, string(output[0:n]); e != a {
				t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a)
			}
		}

		if test.stderr && !test.tty {
			output := make([]byte, 10)
			n, err := stderrStream.Read(output)
			// Release execFunc before checking the result, as above.
			close(clientStderrReadDone)
			if err != nil {
				t.Fatalf("%d: error reading from stderr stream: %v", i, err)
			}
			if e, a := expectedStderr, string(output[0:n]); e != a {
				t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a)
			}
		}

		// Do not start the next case until the fake exec has returned.
		<-execFuncDone
	}
}
func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int) http.HandlerFunc { // error + stdin + stdout expectedStreams := 3 if !tty { // stderr expectedStreams++ } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name) if err != nil { t.Fatal(err) } if protocol != StreamProtocolV2Name { t.Fatalf("unexpected protocol: %s", protocol) } streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error { streamCh <- stream return nil }) // from this point on, we can no longer call methods on w if conn == nil { // The upgrader is responsible for notifying the client of any errors that // occurred during upgrading. All we can do is return here at this point // if we weren't successful in upgrading. return } defer conn.Close() var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 WaitForStreams: for { select { case stream := <-streamCh: streamType := stream.Headers().Get(api.StreamType) switch streamType { case api.StreamTypeError: errorStream = stream receivedStreams++ case api.StreamTypeStdin: stdinStream = stream receivedStreams++ case api.StreamTypeStdout: stdoutStream = stream receivedStreams++ case api.StreamTypeStderr: stderrStream = stream receivedStreams++ default: t.Errorf("%d: unexpected stream type: %q", i, streamType) } if receivedStreams == expectedStreams { break WaitForStreams } } } if len(errorData) > 0 { n, err := fmt.Fprint(errorStream, errorData) if err != nil { t.Errorf("%d: error writing to errorStream: %v", i, err) } if e, a := len(errorData), n; e != a { t.Errorf("%d: expected to write %d bytes to errorStream, but only wrote %d", i, e, a) } errorStream.Close() } if len(stdoutData) > 0 { for j := 0; j < messageCount; j++ { n, err := 
fmt.Fprint(stdoutStream, stdoutData) if err != nil { t.Errorf("%d: error writing to stdoutStream: %v", i, err) } if e, a := len(stdoutData), n; e != a { t.Errorf("%d: expected to write %d bytes to stdoutStream, but only wrote %d", i, e, a) } } stdoutStream.Close() } if len(stderrData) > 0 { for j := 0; j < messageCount; j++ { n, err := fmt.Fprint(stderrStream, stderrData) if err != nil { t.Errorf("%d: error writing to stderrStream: %v", i, err) } if e, a := len(stderrData), n; e != a { t.Errorf("%d: expected to write %d bytes to stderrStream, but only wrote %d", i, e, a) } } stderrStream.Close() } if len(stdinData) > 0 { data := make([]byte, len(stdinData)) for j := 0; j < messageCount; j++ { n, err := io.ReadFull(stdinStream, data) if err != nil { t.Errorf("%d: error reading stdin stream: %v", i, err) } if e, a := len(stdinData), n; e != a { t.Errorf("%d: expected to read %d bytes from stdinStream, but only read %d", i, e, a) } if e, a := stdinData, string(data); e != a { t.Errorf("%d: stdin: expected %q, got %q", i, e, a) } } stdinStream.Close() } }) }