Ejemplo n.º 1
0
// hijack sends an HTTP upgrade request for path over a freshly dialed
// connection, takes over the raw connection and, from a background goroutine,
// copies data between that connection and the streams in hijackOptions.
//
// When hijackOptions.data is non-nil it is JSON-encoded and used as the request
// body. The connection is dialed with TLS when c.TLSConfig is set and the
// endpoint scheme is not "unix".
//
// An error is returned directly only when the API version check, the request
// construction or the dial fails. Otherwise the returned value combines a
// closer, which closes the quit channel so the background goroutine stops
// waiting for the copies, and a waiter, which receives that goroutine's final
// error (the stdin copy error if non-nil, else the output copy error).
func (c *Client) hijack(method, path string, hijackOptions hijackOptions) (CloseWaiter, error) {
	if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil {
		err := c.checkAPIVersion()
		if err != nil {
			return nil, err
		}
	}
	var params io.Reader
	if hijackOptions.data != nil {
		buf, err := json.Marshal(hijackOptions.data)
		if err != nil {
			return nil, err
		}
		params = bytes.NewBuffer(buf)
	}
	req, err := http.NewRequest(method, c.getURL(path), params)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "tcp")

	// Anything that is not a unix socket is dialed as plain TCP against the
	// endpoint host; unix sockets are dialed against the endpoint path.
	protocol := c.endpointURL.Scheme
	address := c.endpointURL.Path
	if protocol != "unix" {
		protocol = "tcp"
		address = c.endpointURL.Host
	}
	var dial net.Conn
	if c.TLSConfig != nil && protocol != "unix" {
		dial, err = tlsDialWithDialer(c.Dialer, protocol, address, c.TLSConfig)
		if err != nil {
			return nil, err
		}
	} else {
		dial, err = c.Dialer.Dial(protocol, address)
		if err != nil {
			return nil, err
		}
	}

	// errs is buffered so the goroutine below can always deliver its result
	// and exit (running its deferred Close calls) even if nobody ever waits.
	errs := make(chan error, 1)
	quit := make(chan struct{})
	go func() {
		clientconn := httputil.NewClientConn(dial, nil)
		defer clientconn.Close()
		// NOTE(review): the response and error from Do are discarded; the
		// connection is hijacked regardless of the outcome.
		clientconn.Do(req)
		if hijackOptions.success != nil {
			// Signal that the request was sent, then wait for the receiver
			// to signal back before taking over the connection.
			hijackOptions.success <- struct{}{}
			<-hijackOptions.success
		}
		rwc, br := clientconn.Hijack()
		defer rwc.Close()

		errChanOut := make(chan error, 1)
		// Two goroutines may send on errChanIn (the stdin copier and the
		// output copier's deferred func) while at most one value is received,
		// so leave room for both sends to complete without blocking.
		errChanIn := make(chan error, 2)
		if hijackOptions.stdout == nil && hijackOptions.stderr == nil {
			close(errChanOut)
		} else {
			// Only copy if hijackOptions.stdout and/or hijackOptions.stderr is actually set.
			// Otherwise, if the only stream you care about is stdin, your attach session
			// will "hang" until the container terminates, even though you're not reading
			// stdout/stderr
			if hijackOptions.stdout == nil {
				hijackOptions.stdout = ioutil.Discard
			}
			if hijackOptions.stderr == nil {
				hijackOptions.stderr = ioutil.Discard
			}

			go func() {
				defer func() {
					// Once output is finished, close the input (when it is
					// closable) and unblock the wait on the input side.
					if hijackOptions.in != nil {
						if closer, ok := hijackOptions.in.(io.Closer); ok {
							closer.Close()
						}
						errChanIn <- nil
					}
				}()

				var err error
				if hijackOptions.setRawTerminal {
					_, err = io.Copy(hijackOptions.stdout, br)
				} else {
					_, err = stdcopy.StdCopy(hijackOptions.stdout, hijackOptions.stderr, br)
				}
				errChanOut <- err
			}()
		}

		go func() {
			var err error
			if hijackOptions.in != nil {
				_, err = io.Copy(rwc, hijackOptions.in)
			}
			errChanIn <- err
			// Half-close the write side when the connection supports it. The
			// assertion is checked so a connection type without CloseWrite
			// does not panic this goroutine.
			if halfCloser, ok := rwc.(interface {
				CloseWrite() error
			}); ok {
				halfCloser.CloseWrite()
			}
		}()

		// On quit, stop waiting but still fall through and report a result so
		// that a waiter is released instead of blocking forever.
		var errIn error
		select {
		case errIn = <-errChanIn:
		case <-quit:
		}

		var errOut error
		select {
		case errOut = <-errChanOut:
		case <-quit:
		}

		if errIn != nil {
			errs <- errIn
		} else {
			errs <- errOut
		}
	}()

	return struct {
		closerFunc
		waiterFunc
	}{
		closerFunc(func() error { close(quit); return nil }),
		waiterFunc(func() error { return <-errs }),
	}, nil
}
Ejemplo n.º 2
0
// stream sends an HTTP request for path and copies the response body to the
// writers in streamOptions, returning the first error encountered.
//
// For POST and PUT requests with no input an empty body is used. Unless path
// is "/version", version checking is enabled and no expected version is set,
// the API version is checked first. Headers from streamOptions.headers are
// applied last and therefore override the defaults set here. Nil stdout and
// stderr writers are replaced with ioutil.Discard.
//
// When the endpoint scheme is "unix" the request is written directly to a
// dialed connection; otherwise c.HTTPClient performs it. A status code outside
// [200, 400) is converted to an error with newError.
//
// The body is handled in one of three ways:
//   - JSON (useJSONDecoder set or Content-Type "application/json") with
//     rawJSONStream: copied verbatim to stdout.
//   - JSON without rawJSONStream: decoded message by message, printing the
//     Stream, Progress and Status fields to stdout and returning the Error
//     field as an error.
//   - Anything else: copied raw to stdout when setRawTerminal is set, else
//     demultiplexed into stdout and stderr with stdcopy.StdCopy.
func (c *Client) stream(method, path string, streamOptions streamOptions) error {
	if (method == "POST" || method == "PUT") && streamOptions.in == nil {
		streamOptions.in = bytes.NewReader(nil)
	}
	if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil {
		err := c.checkAPIVersion()
		if err != nil {
			return err
		}
	}
	req, err := http.NewRequest(method, c.getURL(path), streamOptions.in)
	if err != nil {
		return err
	}
	req.Header.Set("User-Agent", userAgent)
	if method == "POST" {
		// NOTE(review): "plain/text" is not the registered media type
		// ("text/plain" is); left unchanged to keep the bytes on the wire
		// identical — confirm before altering.
		req.Header.Set("Content-Type", "plain/text")
	}
	for key, val := range streamOptions.headers {
		req.Header.Set(key, val)
	}
	var resp *http.Response
	protocol := c.endpointURL.Scheme
	address := c.endpointURL.Path
	if streamOptions.stdout == nil {
		streamOptions.stdout = ioutil.Discard
	}
	if streamOptions.stderr == nil {
		streamOptions.stderr = ioutil.Discard
	}
	if protocol == "unix" {
		dial, err := c.Dialer.Dial(protocol, address)
		if err != nil {
			return err
		}
		defer dial.Close()
		breader := bufio.NewReader(dial)
		err = req.Write(dial)
		if err != nil {
			return err
		}

		// ReadResponse may hang if the server does not reply, so bound it
		// with a deadline when a timeout was requested.
		if streamOptions.timeout > 0 {
			dial.SetDeadline(time.Now().Add(streamOptions.timeout))
		}
		resp, err = http.ReadResponse(breader, req)
		// Clear the deadline whether or not ReadResponse succeeded: it is
		// only meant to bound waiting for the response headers. Leaving it
		// armed on success would abort the body copy below once the timeout
		// elapses.
		if streamOptions.timeout > 0 {
			dial.SetDeadline(time.Time{})
		}
		if err != nil {
			if strings.Contains(err.Error(), "connection refused") {
				return ErrConnectionRefused
			}
			return err
		}
	} else {
		if resp, err = c.HTTPClient.Do(req); err != nil {
			if strings.Contains(err.Error(), "connection refused") {
				return ErrConnectionRefused
			}
			return err
		}
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
		return newError(resp)
	}
	if streamOptions.useJSONDecoder || resp.Header.Get("Content-Type") == "application/json" {
		// if we want to get raw json stream, just copy it back to output
		// without decoding it
		if streamOptions.rawJSONStream {
			_, err = io.Copy(streamOptions.stdout, resp.Body)
			return err
		}
		dec := json.NewDecoder(resp.Body)
		for {
			var m jsonMessage
			if err := dec.Decode(&m); err == io.EOF {
				break
			} else if err != nil {
				return err
			}
			// Stream, Progress and Error are checked in that order and only
			// the first non-empty one is acted on; Status is printed on its
			// own line afterwards whenever it is set.
			if m.Stream != "" {
				fmt.Fprint(streamOptions.stdout, m.Stream)
			} else if m.Progress != "" {
				fmt.Fprintf(streamOptions.stdout, "%s %s\r", m.Status, m.Progress)
			} else if m.Error != "" {
				return errors.New(m.Error)
			}
			if m.Status != "" {
				fmt.Fprintln(streamOptions.stdout, m.Status)
			}
		}
	} else {
		if streamOptions.setRawTerminal {
			_, err = io.Copy(streamOptions.stdout, resp.Body)
		} else {
			_, err = stdcopy.StdCopy(streamOptions.stdout, streamOptions.stderr, resp.Body)
		}
		return err
	}
	return nil
}