func (c *Client) hijack(method, path string, hijackOptions hijackOptions) error { if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil { err := c.checkAPIVersion() if err != nil { return err } } var params io.Reader if hijackOptions.data != nil { buf, err := json.Marshal(hijackOptions.data) if err != nil { return err } params = bytes.NewBuffer(buf) } if hijackOptions.stdout == nil { hijackOptions.stdout = ioutil.Discard } if hijackOptions.stderr == nil { hijackOptions.stderr = ioutil.Discard } req, err := http.NewRequest(method, c.getURL(path), params) if err != nil { return err } req.Header.Set("Content-Type", "plain/text") req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "tcp") protocol := c.endpointURL.Scheme address := c.endpointURL.Path if protocol != "unix" { protocol = "tcp" address = c.endpointURL.Host } var dial net.Conn if c.TLSConfig != nil && protocol != "unix" { dial, err = tlsDial(protocol, address, c.TLSConfig) if err != nil { return err } } else { dial, err = net.Dial(protocol, address) if err != nil { return err } } clientconn := httputil.NewClientConn(dial, nil) defer clientconn.Close() clientconn.Do(req) if hijackOptions.success != nil { hijackOptions.success <- struct{}{} <-hijackOptions.success } rwc, br := clientconn.Hijack() defer rwc.Close() errChanOut := make(chan error, 1) errChanIn := make(chan error, 1) go func() { defer func() { if hijackOptions.in != nil { if closer, ok := hijackOptions.in.(io.Closer); ok { closer.Close() } } }() var err error if hijackOptions.setRawTerminal { _, err = io.Copy(hijackOptions.stdout, br) } else { _, err = stdcopy.StdCopy(hijackOptions.stdout, hijackOptions.stderr, br) } errChanOut <- err }() go func() { var err error if hijackOptions.in != nil { _, err = io.Copy(rwc, hijackOptions.in) } errChanIn <- err rwc.(interface { CloseWrite() error }).CloseWrite() }() errIn := <-errChanIn errOut := <-errChanOut if errIn != nil { return errIn } return errOut }
func (c *Client) stream(method, path string, streamOptions streamOptions) error { if (method == "POST" || method == "PUT") && streamOptions.in == nil { streamOptions.in = bytes.NewReader(nil) } if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil { err := c.checkAPIVersion() if err != nil { return err } } req, err := http.NewRequest(method, c.getURL(path), streamOptions.in) if err != nil { return err } req.Header.Set("User-Agent", userAgent) if method == "POST" { req.Header.Set("Content-Type", "plain/text") } for key, val := range streamOptions.headers { req.Header.Set(key, val) } var resp *http.Response protocol := c.endpointURL.Scheme address := c.endpointURL.Path if streamOptions.stdout == nil { streamOptions.stdout = ioutil.Discard } else if t, ok := streamOptions.stdout.(io.Closer); ok { defer t.Close() } if streamOptions.stderr == nil { streamOptions.stderr = ioutil.Discard } else if t, ok := streamOptions.stderr.(io.Closer); ok { defer t.Close() } if protocol == "unix" { dial, err := net.Dial(protocol, address) if err != nil { return err } defer dial.Close() breader := bufio.NewReader(dial) err = req.Write(dial) if err != nil { return err } // ReadResponse may hang if server does not replay if streamOptions.timeout > 0 { dial.SetDeadline(time.Now().Add(streamOptions.timeout)) } if resp, err = http.ReadResponse(breader, req); err != nil { // Cancel timeout for future I/O operations if streamOptions.timeout > 0 { dial.SetDeadline(time.Time{}) } if strings.Contains(err.Error(), "connection refused") { return ErrConnectionRefused } return err } } else { if resp, err = c.HTTPClient.Do(req); err != nil { if strings.Contains(err.Error(), "connection refused") { return ErrConnectionRefused } return err } } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 400 { body, err := ioutil.ReadAll(resp.Body) if err != nil { return err } return newError(resp.StatusCode, body) } if streamOptions.useJSONDecoder || resp.Header.Get("Content-Type") == "application/json" { // if we want to get raw json stream, just copy it back to output // without decoding it if streamOptions.rawJSONStream { _, err = io.Copy(streamOptions.stdout, resp.Body) return err } dec := json.NewDecoder(resp.Body) for { var m jsonMessage if err := dec.Decode(&m); err == io.EOF { break } else if err != nil { return err } if m.Stream != "" { fmt.Fprint(streamOptions.stdout, m.Stream) } else if m.Progress != "" { fmt.Fprintf(streamOptions.stdout, "%s %s\r", m.Status, m.Progress) } else if m.Error != "" { return errors.New(m.Error) } if m.Status != "" { fmt.Fprintln(streamOptions.stdout, m.Status) } } } else { if streamOptions.setRawTerminal { _, err = io.Copy(streamOptions.stdout, resp.Body) } else { _, err = stdcopy.StdCopy(streamOptions.stdout, streamOptions.stderr, resp.Body) } return err } return nil }