Ejemplo n.º 1
0
// hijack sends a single HTTP request to the configured endpoint over a freshly
// dialed connection and then takes over ("hijacks") that connection so raw
// bytes can be streamed in both directions.
//
// Parameters:
//   - method, path: HTTP method and API path of the request; the URL is built
//     with c.getURL(path).
//   - hijackOptions: data (JSON-encoded into the request body when non-nil),
//     stdout/stderr writers (default to ioutil.Discard when nil), an optional
//     in reader copied to the connection, setRawTerminal (selects a plain
//     io.Copy instead of stdcopy.StdCopy for the output), and an optional
//     success channel used as a handshake once the request has been sent.
//
// It returns the first error encountered while preparing, dialing or sending
// the request. Once streaming has started it returns the output-copy error if
// there is one; otherwise it returns the input-copy result when that copy has
// already finished, and nil when it has not.
func (c *Client) hijack(method, path string, hijackOptions hijackOptions) error {
	// Verify the server API version once, except when asking for the version
	// itself or when the caller opted out of the check.
	if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil {
		if err := c.checkAPIVersion(); err != nil {
			return err
		}
	}

	var params io.Reader
	if hijackOptions.data != nil {
		buf, err := json.Marshal(hijackOptions.data)
		if err != nil {
			return err
		}
		params = bytes.NewBuffer(buf)
	}

	if hijackOptions.stdout == nil {
		hijackOptions.stdout = ioutil.Discard
	}
	if hijackOptions.stderr == nil {
		hijackOptions.stderr = ioutil.Discard
	}
	req, err := http.NewRequest(method, c.getURL(path), params)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "plain/text")

	// A "unix" endpoint is dialed at its path; every other scheme is dialed
	// as TCP at the endpoint host.
	protocol := c.endpointURL.Scheme
	address := c.endpointURL.Path
	if protocol != "unix" {
		protocol = "tcp"
		address = c.endpointURL.Host
	}
	var dial net.Conn
	if c.TLSConfig != nil && protocol != "unix" {
		dial, err = tlsDial(protocol, address, c.TLSConfig)
	} else {
		dial, err = net.Dial(protocol, address)
	}
	if err != nil {
		return err
	}

	clientconn := httputil.NewClientConn(dial, nil)
	defer clientconn.Close()
	// ErrPersistEOF accompanies a valid response and only means the server
	// will not serve further requests on this connection, which is expected
	// here because the connection is hijacked right after. Any other error
	// means the request was not delivered or the response was unreadable.
	if _, err := clientconn.Do(req); err != nil && err != httputil.ErrPersistEOF {
		return err
	}

	// Handshake with the caller: signal that the request went out, then wait
	// for the caller to send a value back before streaming starts.
	if hijackOptions.success != nil {
		hijackOptions.success <- struct{}{}
		<-hijackOptions.success
	}

	rwc, br := clientconn.Hijack()
	defer rwc.Close()
	errChanOut := make(chan error, 1)
	errChanIn := make(chan error, 1)
	exit := make(chan bool)

	// Output side: copy everything the server sends until the stream ends.
	go func() {
		defer close(exit)
		defer close(errChanOut)
		var err error
		if hijackOptions.setRawTerminal {
			// When TTY is ON, use regular copy
			_, err = io.Copy(hijackOptions.stdout, br)
		} else {
			_, err = stdcopy.StdCopy(hijackOptions.stdout, hijackOptions.stderr, br)
		}
		errChanOut <- err
	}()

	// Input side: forward the caller's input (if any), then half-close the
	// connection for writing.
	go func() {
		if hijackOptions.in != nil {
			_, err := io.Copy(rwc, hijackOptions.in)
			errChanIn <- err
		}
		// Not every connection type supports a write-side half close; skip
		// it instead of panicking inside this goroutine when it is missing.
		if cw, ok := rwc.(interface {
			CloseWrite() error
		}); ok {
			// Best effort: there is nobody left to report this error to.
			cw.CloseWrite()
		}
	}()

	<-exit
	// The output goroutine always stores its result before closing exit, so
	// this receive never blocks. Check it first so a finished (nil) input
	// copy can never hide an output error.
	if err := <-errChanOut; err != nil {
		return err
	}
	// The input copy may still be blocked reading from the caller's reader;
	// report its result only if it is already available.
	select {
	case err := <-errChanIn:
		return err
	default:
		return nil
	}
}
Ejemplo n.º 2
0
// stream performs an HTTP request against the configured endpoint and copies
// the response body to the writers in streamOptions.
//
// Parameters:
//   - method, path: HTTP method and API path; the URL is built with
//     c.getURL(path).
//   - streamOptions: in (request body; an empty reader is substituted for
//     POST/PUT when nil), headers (extra request headers, applied last),
//     stdout/stderr writers (default to ioutil.Discard when nil),
//     rawJSONStream and setRawTerminal (select how the body is copied).
//
// Response handling:
//   - a status outside [200, 400) is turned into an error via newError with
//     the response body;
//   - an "application/json" body is either copied verbatim (rawJSONStream) or
//     decoded message by message and rendered to stdout; a message carrying
//     only an error field aborts with that error;
//   - any other body is copied with io.Copy (setRawTerminal) or handed to
//     stdcopy.StdCopy together with both writers.
//
// It returns ErrConnectionRefused when the request error text contains
// "connection refused", and otherwise the first error encountered.
func (c *Client) stream(method, path string, streamOptions streamOptions) error {
	if (method == "POST" || method == "PUT") && streamOptions.in == nil {
		streamOptions.in = bytes.NewReader(nil)
	}
	// Verify the server API version once, except when asking for the version
	// itself or when the caller opted out of the check.
	if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil {
		if err := c.checkAPIVersion(); err != nil {
			return err
		}
	}
	req, err := http.NewRequest(method, c.getURL(path), streamOptions.in)
	if err != nil {
		return err
	}
	req.Header.Set("User-Agent", userAgent)
	if method == "POST" {
		req.Header.Set("Content-Type", "plain/text")
	}
	// Caller-supplied headers are applied last so they can override the
	// defaults set above.
	for key, val := range streamOptions.headers {
		req.Header.Set(key, val)
	}
	if streamOptions.stdout == nil {
		streamOptions.stdout = ioutil.Discard
	}
	if streamOptions.stderr == nil {
		streamOptions.stderr = ioutil.Discard
	}

	var resp *http.Response
	protocol := c.endpointURL.Scheme
	address := c.endpointURL.Path
	if protocol == "unix" {
		// Use a distinct name for the dial error: declaring a new err here
		// would shadow the outer one, so a failed Do below would go unnoticed
		// and resp would be dereferenced while nil.
		dial, dialErr := net.Dial(protocol, address)
		if dialErr != nil {
			return dialErr
		}
		clientconn := httputil.NewClientConn(dial, nil)
		defer clientconn.Close()
		resp, err = clientconn.Do(req)
		// ErrPersistEOF comes together with a valid response and only means
		// the server will not serve another request on this connection.
		if err == httputil.ErrPersistEOF && resp != nil {
			err = nil
		}
	} else {
		resp, err = c.HTTPClient.Do(req)
	}
	if err != nil {
		if strings.Contains(err.Error(), "connection refused") {
			return ErrConnectionRefused
		}
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
		body, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		return newError(resp.StatusCode, body)
	}

	if resp.Header.Get("Content-Type") != "application/json" {
		if streamOptions.setRawTerminal {
			_, err = io.Copy(streamOptions.stdout, resp.Body)
		} else {
			_, err = stdcopy.StdCopy(streamOptions.stdout, streamOptions.stderr, resp.Body)
		}
		return err
	}

	// if we want to get raw json stream, just copy it back to output
	// without decoding it
	if streamOptions.rawJSONStream {
		_, err = io.Copy(streamOptions.stdout, resp.Body)
		return err
	}

	dec := json.NewDecoder(resp.Body)
	for {
		var m jsonMessage
		if err := dec.Decode(&m); err == io.EOF {
			break
		} else if err != nil {
			return err
		}
		if m.Stream != "" {
			fmt.Fprint(streamOptions.stdout, m.Stream)
		} else if m.Progress != "" {
			// The trailing carriage return leaves the cursor at the start of
			// the line so the next write lands over this progress text.
			fmt.Fprintf(streamOptions.stdout, "%s %s\r", m.Status, m.Progress)
		} else if m.Error != "" {
			return errors.New(m.Error)
		}
		if m.Status != "" {
			fmt.Fprintln(streamOptions.stdout, m.Status)
		}
	}
	return nil
}