예제 #1
0
// hijack sends an HTTP request carrying "Connection: Upgrade" and
// "Upgrade: tcp" headers over a freshly dialed connection, then takes over
// that connection and copies data in both directions until both copies end.
//
// Unless path is "/version", SkipServerVersionCheck is set, or
// expectedAPIVersion is already populated, checkAPIVersion is called first.
//
// hijackOptions.data, when non-nil, is JSON-encoded and used as the request
// body. Data read from the connection goes to hijackOptions.stdout (and to
// hijackOptions.stderr through stdcopy.StdCopy when setRawTerminal is false);
// nil writers are replaced by ioutil.Discard. hijackOptions.in, when non-nil,
// is copied to the connection.
//
// The returned error is the input-copy error when there is one, otherwise the
// output-copy error. Errors from the version check, JSON encoding, request
// construction and dialing are returned directly.
func (c *Client) hijack(method, path string, hijackOptions hijackOptions) error {
	if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil {
		err := c.checkAPIVersion()
		if err != nil {
			return err
		}
	}

	var params io.Reader
	if hijackOptions.data != nil {
		buf, err := json.Marshal(hijackOptions.data)
		if err != nil {
			return err
		}
		params = bytes.NewBuffer(buf)
	}

	if hijackOptions.stdout == nil {
		hijackOptions.stdout = ioutil.Discard
	}
	if hijackOptions.stderr == nil {
		hijackOptions.stderr = ioutil.Discard
	}
	req, err := http.NewRequest(method, c.getURL(path), params)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "plain/text")
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "tcp")

	// Unix endpoints are dialed by socket path; every other scheme is dialed
	// as plain TCP against the endpoint host.
	protocol := c.endpointURL.Scheme
	address := c.endpointURL.Path
	if protocol != "unix" {
		protocol = "tcp"
		address = c.endpointURL.Host
	}
	var dial net.Conn
	if c.TLSConfig != nil && protocol != "unix" {
		dial, err = tlsDialWithDialer(c.Dialer, protocol, address, c.TLSConfig)
		if err != nil {
			return err
		}
	} else {
		dial, err = c.Dialer.Dial(protocol, address)
		if err != nil {
			return err
		}
	}
	clientconn := httputil.NewClientConn(dial, nil)
	defer clientconn.Close()
	// The response and error of Do are not inspected here.
	// NOTE(review): presumably deliberate, since the connection is hijacked
	// right after and a failed request surfaces as a copy error — confirm.
	clientconn.Do(req)
	if hijackOptions.success != nil {
		// Handshake with the caller: announce that the request was sent, then
		// wait for the caller to answer on the same channel before streaming.
		hijackOptions.success <- struct{}{}
		<-hijackOptions.success
	}
	rwc, br := clientconn.Hijack()
	defer rwc.Close()
	// Buffered so neither goroutine blocks on its send.
	errChanOut := make(chan error, 1)
	errChanIn := make(chan error, 1)
	go func() {
		// Once output ends, close the input if it is closable.
		// NOTE(review): presumably this unblocks the input copy below — confirm.
		defer func() {
			if hijackOptions.in != nil {
				if closer, ok := hijackOptions.in.(io.Closer); ok {
					closer.Close()
				}
			}
		}()
		var err error
		if hijackOptions.setRawTerminal {
			_, err = io.Copy(hijackOptions.stdout, br)
		} else {
			_, err = stdcopy.StdCopy(hijackOptions.stdout, hijackOptions.stderr, br)
		}
		errChanOut <- err
	}()
	go func() {
		var err error
		if hijackOptions.in != nil {
			_, err = io.Copy(rwc, hijackOptions.in)
		}
		errChanIn <- err
		// Half-close the write side once input is exhausted. net.Conn does not
		// guarantee a CloseWrite method, so probe for it instead of using an
		// unchecked assertion that would panic inside this goroutine.
		if cw, ok := rwc.(interface {
			CloseWrite() error
		}); ok {
			// Best effort: nothing useful can be done with a failure here.
			cw.CloseWrite()
		}
	}()
	errIn := <-errChanIn
	errOut := <-errChanOut
	if errIn != nil {
		return errIn
	}
	return errOut
}
예제 #2
0
// stream performs an HTTP request against the endpoint and copies the
// response body to streamOptions.stdout / streamOptions.stderr.
//
// Request:
//   - POST and PUT requests with a nil streamOptions.in get an empty body.
//   - Unless path is "/version", SkipServerVersionCheck is set, or
//     expectedAPIVersion is already populated, checkAPIVersion is called first.
//   - streamOptions.headers are applied last and override the defaults.
//
// Transport: for the "unix" scheme the request is written directly to a
// dialed socket and the response is parsed from it; streamOptions.timeout,
// when positive, bounds only the wait for the response head. Every other
// scheme goes through c.HTTPClient.
//
// Output: nil writers are replaced by ioutil.Discard; non-nil writers that
// implement io.Closer are closed when stream returns. A response whose
// Content-Type is exactly "application/json" (or any response when
// useJSONDecoder is set) is either copied verbatim (rawJSONStream) or decoded
// as a sequence of jsonMessage values and printed. Other responses are copied
// raw (setRawTerminal) or passed through stdcopy.StdCopy.
//
// It returns ErrConnectionRefused when the transport error text contains
// "connection refused", newError(resp) for status codes below 200 or at/above
// 400, an error built from a jsonMessage's Error field, or any I/O or decode
// error encountered.
func (c *Client) stream(method, path string, streamOptions streamOptions) error {
	if (method == "POST" || method == "PUT") && streamOptions.in == nil {
		streamOptions.in = bytes.NewReader(nil)
	}
	if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil {
		err := c.checkAPIVersion()
		if err != nil {
			return err
		}
	}
	req, err := http.NewRequest(method, c.getURL(path), streamOptions.in)
	if err != nil {
		return err
	}
	req.Header.Set("User-Agent", userAgent)
	if method == "POST" {
		req.Header.Set("Content-Type", "plain/text")
	}
	for key, val := range streamOptions.headers {
		req.Header.Set(key, val)
	}
	var resp *http.Response
	protocol := c.endpointURL.Scheme
	address := c.endpointURL.Path
	if streamOptions.stdout == nil {
		streamOptions.stdout = ioutil.Discard
	} else if t, ok := streamOptions.stdout.(io.Closer); ok {
		defer t.Close()
	}
	if streamOptions.stderr == nil {
		streamOptions.stderr = ioutil.Discard
	} else if t, ok := streamOptions.stderr.(io.Closer); ok {
		defer t.Close()
	}
	if protocol == "unix" {
		dial, err := c.Dialer.Dial(protocol, address)
		if err != nil {
			return err
		}
		defer dial.Close()
		breader := bufio.NewReader(dial)
		err = req.Write(dial)
		if err != nil {
			return err
		}

		// ReadResponse may hang if the server does not reply, so bound it
		// with a deadline when a timeout was requested.
		if streamOptions.timeout > 0 {
			dial.SetDeadline(time.Now().Add(streamOptions.timeout))
		}

		resp, err = http.ReadResponse(breader, req)

		// Cancel the deadline for future I/O operations on both the success
		// and the failure path. Leaving it armed after a successful read
		// would make the body copy below fail once the timeout elapses.
		if streamOptions.timeout > 0 {
			dial.SetDeadline(time.Time{})
		}
		if err != nil {
			if strings.Contains(err.Error(), "connection refused") {
				return ErrConnectionRefused
			}
			return err
		}
	} else {
		if resp, err = c.HTTPClient.Do(req); err != nil {
			if strings.Contains(err.Error(), "connection refused") {
				return ErrConnectionRefused
			}
			return err
		}
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
		return newError(resp)
	}
	if streamOptions.useJSONDecoder || resp.Header.Get("Content-Type") == "application/json" {
		// If the caller wants the raw JSON stream, copy it to the output
		// without decoding it.
		if streamOptions.rawJSONStream {
			_, err = io.Copy(streamOptions.stdout, resp.Body)
			return err
		}
		dec := json.NewDecoder(resp.Body)
		for {
			var m jsonMessage
			if err := dec.Decode(&m); err == io.EOF {
				break
			} else if err != nil {
				return err
			}
			// Stream text wins over progress, which wins over an error; an
			// error message is only acted on when neither of the others is set.
			if m.Stream != "" {
				fmt.Fprint(streamOptions.stdout, m.Stream)
			} else if m.Progress != "" {
				// Trailing "\r" (no newline) is written after each progress line.
				fmt.Fprintf(streamOptions.stdout, "%s %s\r", m.Status, m.Progress)
			} else if m.Error != "" {
				return errors.New(m.Error)
			}
			if m.Status != "" {
				fmt.Fprintln(streamOptions.stdout, m.Status)
			}
		}
	} else {
		if streamOptions.setRawTerminal {
			_, err = io.Copy(streamOptions.stdout, resp.Body)
		} else {
			_, err = stdcopy.StdCopy(streamOptions.stdout, streamOptions.stderr, resp.Body)
		}
		return err
	}
	return nil
}