예제 #1
0
파일: types.go 프로젝트: vasiliyl/playwand
// readFrom receives exactly one file descriptor sent as SCM_RIGHTS ancillary
// data over c and stores it in *fd.
//
// It returns an error when the read fails, when no ancillary data arrived,
// when the control data cannot be parsed, or when the control data is not
// exactly one control message carrying exactly one descriptor. *fd is only
// written on success.
func (fd *Fd) readFrom(c *net.UnixConn) error {
	var b []byte
	// Size the buffer for one control message holding a single 4-byte
	// descriptor. A fixed 16 bytes is too small on platforms where the
	// control message header itself is 16 bytes long.
	oob := make([]byte, syscall.CmsgSpace(4))
	_, oobn, _, _, err := c.ReadMsgUnix(b, oob)
	if err != nil {
		return err
	}
	if oobn == 0 {
		return errors.New("error reading oob")
	}
	oob = oob[:oobn]

	scms, err := syscall.ParseSocketControlMessage(oob)
	if err != nil {
		return err
	}
	if len(scms) != 1 {
		return fmt.Errorf("expected 1 SocketControlMessage, got %d", len(scms))
	}
	fds, err := syscall.ParseUnixRights(&scms[0])
	if err != nil {
		// Report the failure; returning nil here would claim success while
		// leaving *fd untouched.
		return err
	}
	if len(fds) != 1 {
		return fmt.Errorf("expected 1 fd, got %d", len(fds))
	}
	*fd = Fd(fds[0])
	return nil
}
예제 #2
0
파일: unix.go 프로젝트: ChaosCloud/docker
// extractFds parses oob as a sequence of socket control messages and returns
// every file descriptor carried by the SCM_RIGHTS messages in it, marking
// each returned descriptor close-on-exec.
//
// Control messages that ParseUnixRights rejects are skipped. If oob cannot be
// parsed as control messages at all, the result is nil.
func extractFds(oob []byte) (fds []int) {
	// Grab ForkLock to make sure no forks accidentally inherit the new
	// fds before they are made CLOEXEC.
	// There is a slight race between ReadMsgUnix returning and the moment
	// we grab the lock, so this is not perfect. Unfortunately there is no
	// way to pass MSG_CMSG_CLOEXEC to recvmsg() nor any way to implement
	// non-blocking i/o in go, so this is hard to fix.
	syscall.ForkLock.Lock()
	defer syscall.ForkLock.Unlock()

	scms, err := syscall.ParseSocketControlMessage(oob)
	if err != nil {
		return nil
	}
	for i := range scms {
		gotFds, err := syscall.ParseUnixRights(&scms[i])
		if err != nil {
			continue
		}
		// Flag only the descriptors of this message; descriptors from
		// earlier messages were already flagged in their own iteration.
		for _, fd := range gotFds {
			syscall.CloseOnExec(fd)
		}
		fds = append(fds, gotFds...)
	}
	return fds
}
예제 #3
0
파일: client.go 프로젝트: devick/flynn
// Read reads one message from r.conn into b and returns the number of payload
// bytes read.
//
// Any file descriptors that arrive as control messages alongside the payload
// are flagged close-on-exec and stored in r.FDs at index r.fdCount, which is
// advanced for each descriptor. A failure to read, to parse the control data,
// or to set the close-on-exec flag is returned together with the byte count.
func (r *FDReader) Read(b []byte) (int, error) {
	control := make([]byte, 32)
	n, controlLen, _, _, err := r.conn.ReadMsgUnix(b, control)
	if err != nil {
		// Never hand a negative count back to the caller.
		if n < 0 {
			n = 0
		}
		return n, err
	}
	if controlLen <= 0 {
		return n, nil
	}

	scms, err := syscall.ParseSocketControlMessage(control[:controlLen])
	if err != nil {
		return n, err
	}
	for i := range scms {
		received, err := syscall.ParseUnixRights(&scms[i])
		if err != nil {
			return n, err
		}
		for _, fd := range received {
			// Set the CLOEXEC flag so the descriptor is not leaked into
			// future forks.
			_, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_SETFD, syscall.FD_CLOEXEC)
			if errno != 0 {
				return n, errno
			}

			r.FDs[r.fdCount] = fd
			r.fdCount++
		}
	}

	return n, nil
}
예제 #4
0
// Read reads one message into p and returns the number of payload bytes read.
//
// When the read succeeds with both payload and control data, every
// descriptor found in a SOL_SOCKET/SCM_RIGHTS control message is wrapped with
// NewFile and appended to s.recvFiles while s.m is held. Control messages of
// other kinds, and rights messages that fail to parse, are ignored.
func (s *OOBUnixConn) Read(p []byte) (n int, err error) {
	var ctl [OOBMaxLength]byte
	n, ctlLen, _, _, err := s.ReadMsgUnix(p, ctl[:])
	if err != nil || n <= 0 || ctlLen <= 0 {
		return n, err
	}

	msgs, parseErr := syscall.ParseSocketControlMessage(ctl[0:ctlLen])
	if parseErr != nil {
		return n, parseErr
	}

	s.m.Lock()
	for i := range msgs {
		hdr := msgs[i].Header
		if hdr.Level != syscall.SOL_SOCKET || hdr.Type != syscall.SCM_RIGHTS {
			continue
		}
		fds, rightsErr := syscall.ParseUnixRights(&msgs[i])
		if rightsErr != nil {
			continue
		}
		for _, fd := range fds {
			// Note: the raw FDs are wrapped just once, early, to prevent
			// double-free or leaking FDs.
			s.recvFiles = append(s.recvFiles, NewFile(fd))
		}
	}
	s.m.Unlock()

	return n, nil
}
예제 #5
0
// mount runs the fusermount helper for dir and returns the file descriptor
// the helper passes back, wrapped as an *os.File named "/dev/fuse".
//
// ready is closed immediately. errp is not used by this implementation.
// The helper is given one end of a socketpair as file descriptor 3 (announced
// through _FUSE_COMMFD=3) and is expected to send exactly one descriptor in a
// single SCM_RIGHTS control message over it. Any helper output or failure,
// and any deviation from that protocol, is returned as an error.
func mount(dir string, ready chan<- struct{}, errp *error) (fusefd *os.File, err error) {
	// linux mount is never delayed
	close(ready)

	fds, err := syscall.Socketpair(syscall.AF_FILE, syscall.SOCK_STREAM, 0)
	if err != nil {
		return nil, fmt.Errorf("socketpair error: %v", err)
	}

	// The os.File wrappers own the descriptors from here on. Closing the
	// raw numbers as well would close them twice and could hit an unrelated
	// descriptor that reused the number in between.
	writeFile := os.NewFile(uintptr(fds[0]), "fusermount-child-writes")
	defer writeFile.Close()
	readFile := os.NewFile(uintptr(fds[1]), "fusermount-parent-reads")
	defer readFile.Close()

	cmd := exec.Command("fusermount", "--", dir)
	cmd.Env = append(os.Environ(), "_FUSE_COMMFD=3")
	cmd.ExtraFiles = []*os.File{writeFile}

	out, err := cmd.CombinedOutput()
	if len(out) > 0 || err != nil {
		return nil, fmt.Errorf("fusermount: %q, %v", out, err)
	}

	c, err := net.FileConn(readFile)
	if err != nil {
		return nil, fmt.Errorf("FileConn from fusermount socket: %v", err)
	}
	defer c.Close()

	uc, ok := c.(*net.UnixConn)
	if !ok {
		return nil, fmt.Errorf("unexpected FileConn type; expected UnixConn, got %T", c)
	}

	buf := make([]byte, 32) // expect 1 byte
	oob := make([]byte, 32) // expect 24 bytes
	_, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
	if err != nil {
		// Report the read failure itself rather than a misleading
		// "expected 1 SocketControlMessage" from parsing an empty buffer.
		return nil, fmt.Errorf("ReadMsgUnix: %v", err)
	}
	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, fmt.Errorf("ParseSocketControlMessage: %v", err)
	}
	if len(scms) != 1 {
		return nil, fmt.Errorf("expected 1 SocketControlMessage; got scms = %#v", scms)
	}
	gotFds, err := syscall.ParseUnixRights(&scms[0])
	if err != nil {
		return nil, fmt.Errorf("syscall.ParseUnixRights: %v", err)
	}
	if len(gotFds) != 1 {
		return nil, fmt.Errorf("wanted 1 fd; got %#v", gotFds)
	}
	return os.NewFile(uintptr(gotFds[0]), "/dev/fuse"), nil
}
예제 #6
0
// readRequest reads one request from c.
//
// The wire format read here is a big-endian uint32 length followed by that
// many bytes of JSON, which are decoded into a Request. If the decoded
// request has HasFds set, one further message with a 1-byte payload is read
// and the descriptors from its SCM_RIGHTS control messages are stored in
// ReceivedFds; receiving none is an error.
func readRequest(c *net.UnixConn) (*Request, error) {
	var l uint32
	if err := binary.Read(c, binary.BigEndian, &l); err != nil {
		return nil, err
	}
	length := int(l)

	// A single Read on a stream socket may return fewer bytes than asked
	// for even though the rest is still in flight, so keep reading until the
	// announced length has arrived.
	payload := make([]byte, length)
	n, err := io.ReadFull(c, payload)
	if err == io.ErrUnexpectedEOF {
		return nil, fmt.Errorf("Payload was %d bytes rather than reported size of %d", n, length)
	}
	if err != nil {
		return nil, err
	}

	req := &Request{}
	if err := json.Unmarshal(payload, req); err != nil {
		return nil, err
	}
	if !req.HasFds {
		return req, nil
	}

	payload = make([]byte, 1)
	// TODO: does this buffer need to be configurable?
	oob := make([]byte, 8192)
	n, oobn, _, _, err := c.ReadMsgUnix(payload, oob)
	if err != nil && err != io.EOF {
		return nil, err
	}
	if n != 1 {
		return nil, fmt.Errorf("Error reading OOB filedescriptors")
	}
	oob = oob[0:oobn]
	scm, err := syscall.ParseSocketControlMessage(oob)
	if err != nil {
		return nil, fmt.Errorf("Error parsing socket control message: %v", err)
	}
	var fds []int
	for i := 0; i < len(scm); i++ {
		tfds, err := syscall.ParseUnixRights(&scm[i])
		if err == syscall.EINVAL {
			continue // Wasn't a UnixRights Control Message
		} else if err != nil {
			return nil, fmt.Errorf("Error parsing unix rights: %v", err)
		}
		fds = append(fds, tfds...)
	}
	if len(fds) == 0 {
		return nil, fmt.Errorf("Failed to receive any FDs on a request with HasFds == true")
	}
	req.ReceivedFds = fds
	return req, nil
}
예제 #7
0
// GetFD consumes the first pending control message of m and returns the
// single file descriptor it carries.
//
// It returns 0 when no control messages are pending. It panics when the
// first control message cannot be parsed as unix rights or does not carry
// exactly one descriptor.
func (m *Message) GetFD() uintptr {
	// len covers both a nil and an empty slice; indexing an empty slice
	// below would panic.
	if len(m.control_msgs) == 0 {
		return 0
	}
	fds, err := syscall.ParseUnixRights(&m.control_msgs[0])
	if err != nil {
		panic("Unable to parse unix rights")
	}
	// Drop the consumed message so the next call sees the following one.
	m.control_msgs = m.control_msgs[1:]
	if len(fds) != 1 {
		panic("Expected 1 file descriptor, got more")
	}
	return uintptr(fds[0])
}
예제 #8
0
파일: conn.go 프로젝트: vasiliyl/playwand
// ReadMessage reads the next message from the connection.
//
// The header is read first and determines the object, opcode and payload
// size. A message with a zero-sized payload is returned as soon as the header
// is read. Otherwise the payload is read in a single ReadMsgUnix call, and a
// byte count that differs from the header's size is an error. If control
// data accompanied the payload, it must be exactly one control message, whose
// descriptors are stored in the message.
//
// When the header cannot be read the returned message is nil; on later
// failures the partially filled message is returned together with the error.
func (c *Conn) ReadMessage() (m *Message, err error) {
	hdr, err := c.readHeader()
	if err != nil {
		return nil, err
	}

	log.Printf("ReadMessage: header = %s", hdr)

	m = &Message{
		object: hdr.object(),
		opcode: hdr.opcode(),
	}
	if hdr.size() == 0 {
		return m, nil
	}

	payload := make([]byte, hdr.size())
	control := make([]byte, 32)
	n, controlLen, _, _, err := c.c.ReadMsgUnix(payload, control)
	if err != nil {
		return m, err
	}
	if uint16(n) != hdr.size() {
		return m, fmt.Errorf("expected %d bytes, got %d", hdr.size(), n)
	}

	log.Printf("ReadMessage: n = %d, oobn = %d", n, controlLen)

	m.p = bytes.NewBuffer(payload)
	if controlLen == 0 {
		return m, nil
	}

	scms, err := syscall.ParseSocketControlMessage(control[:controlLen])
	if err != nil {
		return m, err
	}
	if len(scms) != 1 {
		return m, fmt.Errorf("expected 1 SocketControlMessage, got %d", len(scms))
	}
	m.fds, err = syscall.ParseUnixRights(&scms[0])
	return m, err
}
예제 #9
0
// readData reads one message from conn, which must be a *net.UnixConn, and
// returns the files passed with it plus the Pid from the JSON response.
//
// The payload is decoded as a Response; an empty payload or a non-empty
// ErrMessage in the response is an error. The descriptors are taken from the
// first control message only, and each is wrapped with os.NewFile under the
// name "/dev/fake-fd-<index>".
func readData(conn net.Conn) ([]Fd, int, error) {
	var (
		payload  [2048]byte
		control  [2048]byte
		response Response
	)

	n, oobn, _, _, err := conn.(*net.UnixConn).ReadMsgUnix(payload[:], control[:])
	if err != nil {
		return nil, 0, fmt.Errorf("unix_socket: failed to read unix msg: %s (read: %d, %d)", err, n, oobn)
	}
	if n <= 0 {
		return nil, 0, errors.New("unix_socket: No response received")
	}

	if err := json.Unmarshal(payload[:n], &response); err != nil {
		return nil, 0, fmt.Errorf("unix_socket: Unmarshal failed: %s", err)
	}
	if response.ErrMessage != "" {
		return nil, 0, errors.New(response.ErrMessage)
	}

	scms, err := syscall.ParseSocketControlMessage(control[:oobn])
	if err != nil {
		return nil, 0, fmt.Errorf("unix_socket: failed to parse socket control message: %s", err)
	}
	if len(scms) < 1 {
		return nil, 0, fmt.Errorf("unix_socket: no socket control messages sent")
	}

	fds, err := syscall.ParseUnixRights(&scms[0])
	if err != nil {
		return nil, 0, fmt.Errorf("unix_socket: failed to parse unix rights: %s", err)
	}

	files := make([]Fd, 0, len(fds))
	for i, fd := range fds {
		files = append(files, os.NewFile(uintptr(fd), fmt.Sprintf("/dev/fake-fd-%d", i)))
	}

	return files, response.Pid, nil
}
예제 #10
0
// extractFileDescriptorFromOOB returns the single file descriptor carried in
// oob, which must hold exactly one control message with exactly one
// descriptor. On any failure the returned descriptor is -1.
func extractFileDescriptorFromOOB(oob []byte) (int, error) {
	msgs, err := syscall.ParseSocketControlMessage(oob)
	switch {
	case err != nil:
		return -1, err
	case len(msgs) != 1:
		return -1, errors.New(fmt.Sprintf("expected 1 SocketControlMessage; got scms = %#v", msgs))
	}

	rights, err := syscall.ParseUnixRights(&msgs[0])
	switch {
	case err != nil:
		return -1, err
	case len(rights) != 1:
		return -1, errors.New(fmt.Sprintf("wanted 1 fd; got %#v", rights))
	}
	return rights[0], nil
}
예제 #11
0
// TestUnixRightsRoundtrip checks that descriptor lists encoded with
// UnixRights come back unchanged through ParseSocketControlMessage and
// ParseUnixRights, including empty lists and several messages in one buffer.
func TestUnixRightsRoundtrip(t *testing.T) {
	testCases := [...][][]int{
		{{42}},
		{{1, 2}},
		{{3, 4, 5}},
		{{}},
		{{1, 2}, {3, 4, 5}, {}, {7}},
	}
	for _, groups := range testCases {
		var encoded []byte
		end := 0
		for _, fds := range groups {
			// Remember where the last message ends without its trailing
			// padding; only the final iteration's value is used.
			end = len(encoded) + syscall.CmsgLen(4*len(fds))
			encoded = append(encoded, syscall.UnixRights(fds...)...)
		}
		// Cut off the padding after the last message.
		encoded = encoded[:end]

		msgs, err := syscall.ParseSocketControlMessage(encoded)
		if err != nil {
			t.Fatalf("ParseSocketControlMessage: %v", err)
		}
		if want := len(groups); len(msgs) != want {
			t.Fatalf("expected %v SocketControlMessage; got scms = %#v", want, msgs)
		}
		for i := range msgs {
			gotFds, err := syscall.ParseUnixRights(&msgs[i])
			if err != nil {
				t.Fatalf("ParseUnixRights: %v", err)
			}
			wantFds := groups[i]
			if len(gotFds) != len(wantFds) {
				t.Fatalf("expected %v fds, got %#v", len(wantFds), gotFds)
			}
			for j := range gotFds {
				if gotFds[j] != wantFds[j] {
					t.Fatalf("expected fd %v, got %v", wantFds[j], gotFds[j])
				}
			}
		}
	}
}
예제 #12
0
파일: fd.go 프로젝트: ftrvxmtrx/fd
// Get receives file descriptors from a Unix domain socket.
//
// Num specifies the expected number of file descriptors in one message; it
// sizes the receive buffer, and a message may carry fewer. If num is less
// than 1, Get returns nil, nil without reading. Internal files' names to be
// assigned are specified via optional filenames argument; names are matched
// by position within each control message.
//
// You need to close all files in the returned slice. The slice can be
// non-empty even if this function returns an error.
//
// Use net.FileConn() if you're receiving a network connection.
func Get(via *net.UnixConn, num int, filenames []string) ([]*os.File, error) {
	if num < 1 {
		return nil, nil
	}

	// get the underlying socket
	viaf, err := via.File()
	if err != nil {
		return nil, err
	}
	defer viaf.Close()
	socket := int(viaf.Fd())

	// recvmsg
	buf := make([]byte, syscall.CmsgSpace(num*4))
	_, oobn, _, _, err := syscall.Recvmsg(socket, nil, buf, 0)
	if err != nil {
		return nil, err
	}

	// Parse only the control bytes that were actually received. The unused
	// tail of buf is zero-filled and would be rejected as a malformed
	// control message, hiding descriptors that did arrive.
	msgs, err := syscall.ParseSocketControlMessage(buf[:oobn])

	// convert fds to files
	res := make([]*os.File, 0, len(msgs))
	for i := 0; i < len(msgs) && err == nil; i++ {
		var fds []int
		fds, err = syscall.ParseUnixRights(&msgs[i])

		for fi, fd := range fds {
			var filename string
			if fi < len(filenames) {
				filename = filenames[fi]
			}

			res = append(res, os.NewFile(uintptr(fd), filename))
		}
	}

	return res, err
}
예제 #13
0
// ReadFile receives exactly one file descriptor over c and returns it wrapped
// as an *os.File with an empty name.
//
// If timeout is positive, a read deadline of now+timeout is set on c before
// reading; the deadline is not cleared afterwards.
//
// An error is returned when the read fails, when the message carries no
// control data (EBADMSG), when the control data was truncated (EMSGSIZE), or
// when it is not exactly one rights message holding exactly one descriptor
// (EBADMSG). Descriptors received in a rejected message are closed so that
// they do not leak.
func ReadFile(c *net.UnixConn, timeout time.Duration) (*os.File, error) {
	oob := make([]byte, 64)

	if timeout > 0 {
		if err := c.SetReadDeadline(time.Now().Add(timeout)); err != nil {
			return nil, err
		}
	}

	_, oobn, flags, _, err := c.ReadMsgUnix(nil, oob)
	if err != nil {
		return nil, err
	}
	if oobn <= 0 {
		// What the peer sends is outside our control, so report it as an
		// error instead of panicking.
		return nil, os.NewSyscallError("recvmsg", syscall.EBADMSG)
	}

	// file descriptors are now open in this process

	scm, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, err
	}

	// closeRights closes every descriptor found in the received control
	// messages; used when the message is rejected.
	closeRights := func() {
		for i := range scm {
			fds, err := syscall.ParseUnixRights(&scm[i])
			if err != nil {
				continue
			}
			for _, fd := range fds {
				syscall.Close(fd)
			}
		}
	}

	// Only truncation of the control data is treated as fatal; other bits
	// in the returned flags do not affect the descriptor itself.
	if flags&syscall.MSG_CTRUNC != 0 {
		closeRights()
		return nil, os.NewSyscallError("recvmsg", syscall.EMSGSIZE)
	}
	if len(scm) != 1 {
		closeRights()
		return nil, os.NewSyscallError("recvmsg", syscall.EBADMSG)
	}

	fds, err := syscall.ParseUnixRights(&scm[0])
	if err != nil {
		return nil, err
	}
	if len(fds) != 1 {
		closeRights()
		return nil, os.NewSyscallError("recvmsg", syscall.EBADMSG)
	}

	return os.NewFile(uintptr(fds[0]), ""), nil
}
예제 #14
0
파일: message.go 프로젝트: Safe3/oz
// parseControlData parses data as socket control messages and records their
// contents on m: SCM_CREDENTIALS messages set m.Ucred and SCM_RIGHTS messages
// set m.Fds. Later messages of the same type overwrite earlier ones, and
// messages of other types are ignored. The first parse failure is returned.
func (m *Message) parseControlData(data []byte) error {
	msgs, err := syscall.ParseSocketControlMessage(data)
	if err != nil {
		return err
	}
	for i := range msgs {
		cm := &msgs[i]
		if cm.Header.Type == syscall.SCM_CREDENTIALS {
			cred, err := syscall.ParseUnixCredentials(cm)
			if err != nil {
				return err
			}
			m.Ucred = cred
			continue
		}
		if cm.Header.Type == syscall.SCM_RIGHTS {
			fds, err := syscall.ParseUnixRights(cm)
			if err != nil {
				return err
			}
			m.Fds = fds
		}
	}
	return nil
}
예제 #15
0
파일: fdpass.go 프로젝트: kingpro/vitess
// RecvFd reads one message from conn and returns the single file descriptor
// passed with it, wrapped as an *os.File named "passed-fd".
//
// The message must carry exactly one control message holding exactly one
// descriptor; anything else, as well as read or parse failures, is returned
// as an error prefixed with "recvfd:".
func RecvFd(conn *net.UnixConn) (*os.File, error) {
	var (
		data    = make([]byte, 32)
		control = make([]byte, 32)
	)
	_, controlLen, _, _, err := conn.ReadMsgUnix(data, control)
	if err != nil {
		return nil, fmt.Errorf("recvfd: err %v", err)
	}

	msgs, err := syscall.ParseSocketControlMessage(control[:controlLen])
	if err != nil {
		return nil, fmt.Errorf("recvfd: ParseSocketControlMessage failed %v", err)
	}
	if count := len(msgs); count != 1 {
		return nil, fmt.Errorf("recvfd: SocketControlMessage count not 1: %v", count)
	}

	fds, err := syscall.ParseUnixRights(&msgs[0])
	if err != nil {
		return nil, fmt.Errorf("recvfd: ParseUnixRights failed %v", err)
	}
	if count := len(fds); count != 1 {
		return nil, fmt.Errorf("recvfd: fd count not 1: %v", count)
	}
	return os.NewFile(uintptr(fds[0]), "passed-fd"), nil
}
예제 #16
0
파일: mount_linux.go 프로젝트: nyaxt/fuse
// mount runs the fusermount helper for dir with the options from conf and
// returns the file descriptor the helper passes back, wrapped as an *os.File
// named "/dev/fuse".
//
// ready is closed immediately. errp is not used by this implementation.
// The helper is given one end of a socketpair as file descriptor 3 (announced
// through _FUSE_COMMFD=3); its stdout and stderr are consumed by lineLogger
// goroutines. When the helper exits with an error and an error was delivered
// on the stderr handler's channel, that error is returned instead of the
// generic one. The helper is expected to send exactly one descriptor in a
// single SCM_RIGHTS control message.
func mount(dir string, conf *mountConfig, ready chan<- struct{}, errp *error) (fusefd *os.File, err error) {
	// linux mount is never delayed
	close(ready)

	fds, err := syscall.Socketpair(syscall.AF_FILE, syscall.SOCK_STREAM, 0)
	if err != nil {
		return nil, fmt.Errorf("socketpair error: %v", err)
	}

	writeFile := os.NewFile(uintptr(fds[0]), "fusermount-child-writes")
	defer writeFile.Close()

	readFile := os.NewFile(uintptr(fds[1]), "fusermount-parent-reads")
	defer readFile.Close()

	cmd := exec.Command(
		"fusermount",
		"-o", conf.getOptions(),
		"--",
		dir,
	)
	cmd.Env = append(os.Environ(), "_FUSE_COMMFD=3")

	cmd.ExtraFiles = []*os.File{writeFile}

	var wg sync.WaitGroup
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("setting up fusermount stdout: %v", err)
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("setting up fusermount stderr: %v", err)
	}

	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("fusermount: %v", err)
	}
	helperErrCh := make(chan error, 1)
	wg.Add(2)
	go lineLogger(&wg, "mount helper output", neverIgnoreLine, stdout)
	go lineLogger(&wg, "mount helper error", handleFusermountStderr(helperErrCh), stderr)
	// Both pipes must be drained before Wait, which closes them.
	wg.Wait()
	if err := cmd.Wait(); err != nil {
		// see if we have a better error to report
		select {
		case helperErr := <-helperErrCh:
			// log the Wait error if it's not what we expected
			if !isBoringFusermountError(err) {
				log.Printf("mount helper failed: %v", err)
			}
			// and now return what we grabbed from stderr as the real
			// error
			return nil, helperErr
		default:
			// nope, fall back to generic message
		}

		return nil, fmt.Errorf("fusermount: %v", err)
	}

	c, err := net.FileConn(readFile)
	if err != nil {
		return nil, fmt.Errorf("FileConn from fusermount socket: %v", err)
	}
	defer c.Close()

	uc, ok := c.(*net.UnixConn)
	if !ok {
		return nil, fmt.Errorf("unexpected FileConn type; expected UnixConn, got %T", c)
	}

	buf := make([]byte, 32) // expect 1 byte
	oob := make([]byte, 32) // expect 24 bytes
	_, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
	if err != nil {
		// Report the read failure itself rather than a misleading
		// "expected 1 SocketControlMessage" from parsing an empty buffer.
		return nil, fmt.Errorf("ReadMsgUnix: %v", err)
	}
	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, fmt.Errorf("ParseSocketControlMessage: %v", err)
	}
	if len(scms) != 1 {
		return nil, fmt.Errorf("expected 1 SocketControlMessage; got scms = %#v", scms)
	}
	scm := scms[0]
	gotFds, err := syscall.ParseUnixRights(&scm)
	if err != nil {
		return nil, fmt.Errorf("syscall.ParseUnixRights: %v", err)
	}
	if len(gotFds) != 1 {
		return nil, fmt.Errorf("wanted 1 fd; got %#v", gotFds)
	}
	f := os.NewFile(uintptr(gotFds[0]), "/dev/fuse")
	return f, nil
}
예제 #17
0
// ReadMessage reads one message from the transport and decodes it.
//
// The fixed 16-byte part of the header is read first; it gives the byte
// order, the body length and the length of the header fields. The header
// fields are then decoded to find out whether the message announces unix
// file descriptors. If it does, the descriptors are taken from the
// out-of-band data gathered while reading, and every UnixFDIndex value in the
// message body is replaced by the corresponding UnixFD.
//
// It returns an error when reading or decoding fails, when the byte order
// marker is invalid, when descriptors are announced on a transport without
// unix fd support, when the out-of-band data is not exactly one control
// message, or when a body index does not refer to a received descriptor.
func (t *unixTransport) ReadMessage() (*Message, error) {
	var (
		blen, hlen uint32
		csheader   [16]byte
		headers    []header
		order      binary.ByteOrder
		unixfds    uint32
	)
	// To be sure that all bytes of out-of-band data are read, we use a special
	// reader that uses ReadUnix on the underlying connection instead of Read
	// and gathers the out-of-band data in a buffer.
	rd := &oobReader{conn: t.UnixConn}
	// read the first 16 bytes (the part of the header that has a constant size),
	// from which we can figure out the length of the rest of the message
	if _, err := io.ReadFull(rd, csheader[:]); err != nil {
		return nil, err
	}
	switch csheader[0] {
	case 'l':
		order = binary.LittleEndian
	case 'B':
		order = binary.BigEndian
	default:
		return nil, InvalidMessageError("invalid byte order")
	}
	// csheader[4:8] -> length of message body, csheader[12:16] -> length of
	// header fields (without alignment)
	binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
	binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
	// Round the header field length up to the next multiple of 8.
	if hlen%8 != 0 {
		hlen += 8 - (hlen % 8)
	}

	// decode headers and look for unix fds
	headerdata := make([]byte, hlen+4)
	copy(headerdata, csheader[12:])
	if _, err := io.ReadFull(t, headerdata[4:]); err != nil {
		return nil, err
	}
	dec := newDecoder(bytes.NewBuffer(headerdata), order)
	dec.pos = 12
	vs, err := dec.Decode(Signature{"a(yv)"})
	if err != nil {
		return nil, err
	}
	Store(vs, &headers)
	for _, v := range headers {
		if v.Field == byte(FieldUnixFDs) {
			unixfds, _ = v.Variant.value.(uint32)
		}
	}
	// Reassemble the complete message: fixed header, header fields, body.
	all := make([]byte, 16+hlen+blen)
	copy(all, csheader[:])
	copy(all[16:], headerdata[4:])
	if _, err := io.ReadFull(rd, all[16+hlen:]); err != nil {
		return nil, err
	}
	if unixfds != 0 {
		if !t.hasUnixFDs {
			return nil, errors.New("dbus: got unix fds on unsupported transport")
		}
		// read the fds from the OOB data
		scms, err := syscall.ParseSocketControlMessage(rd.oob)
		if err != nil {
			return nil, err
		}
		if len(scms) != 1 {
			return nil, errors.New("dbus: received more than one socket control message")
		}
		fds, err := syscall.ParseUnixRights(&scms[0])
		if err != nil {
			return nil, err
		}
		msg, err := DecodeMessage(bytes.NewBuffer(all))
		if err != nil {
			return nil, err
		}
		// substitute the values in the message body (which are indices for the
		// array receiver via OOB) with the actual values
		for i, v := range msg.Body {
			if j, ok := v.(UnixFDIndex); ok {
				// The index must be below the count announced in the
				// header and below the number of descriptors actually
				// received; otherwise fds[j] would panic on a message
				// that announces more descriptors than it carries.
				if uint32(j) >= unixfds || uint64(j) >= uint64(len(fds)) {
					return nil, InvalidMessageError("invalid index for unix fd")
				}
				msg.Body[i] = UnixFD(fds[j])
			}
		}
		return msg, nil
	}
	return DecodeMessage(bytes.NewBuffer(all))
}
예제 #18
0
파일: link.go 프로젝트: glyn/pango
// Create connects to the unix socket at socketPath, receives three file
// descriptors (stdout, stderr and status, in that order) in the first control
// message of the first message read, and returns a Link built around them.
//
// Data from the received stdout and stderr descriptors is copied to the
// stdout and stderr writers by two goroutines. Once both copies finish, the
// Link's done channel is closed and the connection is closed. The status
// descriptor is handed to the Link as exitStatus.
//
// On any error the connection is closed before returning, and descriptors
// received in an unexpected quantity are closed as well.
func Create(socketPath string, stdout io.Writer, stderr io.Writer) (*Link, error) {
	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to i/o daemon: %s", err)
	}

	// Until the Link is handed to the caller nothing else owns conn, so
	// every early return must close it.
	established := false
	defer func() {
		if !established {
			conn.Close()
		}
	}()

	var b [2048]byte
	var oob [2048]byte

	n, oobn, _, _, err := conn.(*net.UnixConn).ReadMsgUnix(b[:], oob[:])
	if err != nil {
		return nil, fmt.Errorf("failed to read unix msg: %s (read: %d, %d)", err, n, oobn)
	}

	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, fmt.Errorf("failed to parse socket control message: %s", err)
	}

	if len(scms) < 1 {
		return nil, fmt.Errorf("no socket control messages sent")
	}

	scm := scms[0]

	fds, err := syscall.ParseUnixRights(&scm)
	if err != nil {
		return nil, fmt.Errorf("failed to parse unix rights: %s", err)
	}

	if len(fds) != 3 {
		// The descriptors are already open in this process; close them
		// since nobody will use them.
		for _, fd := range fds {
			syscall.Close(fd)
		}
		return nil, fmt.Errorf("invalid number of fds; need 3, got %d", len(fds))
	}

	lstdout := os.NewFile(uintptr(fds[0]), "stdout")
	lstderr := os.NewFile(uintptr(fds[1]), "stderr")
	lstatus := os.NewFile(uintptr(fds[2]), "status")

	streaming := &sync.WaitGroup{}

	linkWriter := NewWriter(conn)

	streaming.Add(1)
	go func() {
		io.Copy(stdout, lstdout)
		lstdout.Close()
		streaming.Done()
	}()

	streaming.Add(1)
	go func() {
		io.Copy(stderr, lstderr)
		lstderr.Close()
		streaming.Done()
	}()

	done := make(chan struct{})
	go func() {
		streaming.Wait()
		close(done)
		conn.Close()
	}()

	established = true
	return &Link{
		Writer: linkWriter,

		exitStatus: lstatus,
		done:       done,
	}, nil
}
예제 #19
0
// TestPassFD tests passing a file descriptor over a Unix socket.
//
// This test involves both a parent and child process. The parent
// process is invoked as a normal test, with "go test", which then
// runs the child process by running the current test binary with args
// "-test.run=^TestPassFD$" and an environment variable used to signal
// that the test should become the child process instead.
func TestPassFD(t *testing.T) {
	switch runtime.GOOS {
	case "dragonfly":
		// TODO(jsing): Figure out why sendmsg is returning EINVAL.
		t.Skip("skipping test on dragonfly")
	case "solaris":
		// TODO(aram): Figure out why ReadMsgUnix is returning empty message.
		t.Skip("skipping test on solaris, see issue 7402")
	case "darwin":
		switch runtime.GOARCH {
		case "arm", "arm64":
			t.Skipf("skipping test on %s/%s, no fork", runtime.GOOS, runtime.GOARCH)
		}
	}
	if os.Getenv("GO_WANT_HELPER_PROCESS") == "1" {
		passFDChild()
		return
	}

	tempDir, err := ioutil.TempDir("", "TestPassFD")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tempDir)

	fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
	if err != nil {
		t.Fatalf("Socketpair: %v", err)
	}
	// The os.File wrappers own the descriptors from here on. Closing the
	// raw numbers as well would close them twice.
	writeFile := os.NewFile(uintptr(fds[0]), "child-writes")
	readFile := os.NewFile(uintptr(fds[1]), "parent-reads")
	defer writeFile.Close()
	defer readFile.Close()

	cmd := exec.Command(os.Args[0], "-test.run=^TestPassFD$", "--", tempDir)
	cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1")
	cmd.ExtraFiles = []*os.File{writeFile}

	out, err := cmd.CombinedOutput()
	if len(out) > 0 || err != nil {
		t.Fatalf("child process: %q, %v", out, err)
	}

	c, err := net.FileConn(readFile)
	if err != nil {
		t.Fatalf("FileConn: %v", err)
	}
	defer c.Close()

	uc, ok := c.(*net.UnixConn)
	if !ok {
		t.Fatalf("unexpected FileConn type; expected UnixConn, got %T", c)
	}

	buf := make([]byte, 32) // expect 1 byte
	oob := make([]byte, 32) // expect 24 bytes
	// Closing the connection unblocks the read if the child never sent
	// anything.
	closeUnix := time.AfterFunc(5*time.Second, func() {
		t.Logf("timeout reading from unix socket")
		uc.Close()
	})
	_, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
	closeUnix.Stop()
	if err != nil {
		t.Fatalf("ReadMsgUnix: %v", err)
	}

	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		t.Fatalf("ParseSocketControlMessage: %v", err)
	}
	if len(scms) != 1 {
		t.Fatalf("expected 1 SocketControlMessage; got scms = %#v", scms)
	}
	scm := scms[0]
	gotFds, err := syscall.ParseUnixRights(&scm)
	if err != nil {
		t.Fatalf("syscall.ParseUnixRights: %v", err)
	}
	if len(gotFds) != 1 {
		t.Fatalf("wanted 1 fd; got %#v", gotFds)
	}

	f := os.NewFile(uintptr(gotFds[0]), "fd-from-child")
	defer f.Close()

	got, err := ioutil.ReadAll(f)
	want := "Hello from child process!\n"
	if string(got) != want {
		t.Errorf("child process ReadAll: %q, %v; want %q", got, err, want)
	}
}
예제 #20
0
// mount runs fusermount for dir with the options from cfg and returns the
// file descriptor the helper passes back, wrapped as an *os.File named
// "/dev/fuse".
//
// A nil is sent on ready before anything else happens. The helper receives
// one end of a socketpair as file descriptor 3 (announced through
// _FUSE_COMMFD=3) and its stderr is captured for the error message. It is
// expected to send back exactly one descriptor in a single control message.
func mount(
	dir string,
	cfg *MountConfig,
	ready chan<- error) (dev *os.File, err error) {
	// On linux, mounting is never delayed.
	ready <- nil

	pair, err := syscall.Socketpair(syscall.AF_FILE, syscall.SOCK_STREAM, 0)
	if err != nil {
		return nil, fmt.Errorf("Socketpair: %v", err)
	}

	// Both ends are wrapped in os.File objects; the child end is handed to
	// fusermount and the parent end is read below.
	childEnd := os.NewFile(uintptr(pair[0]), "fusermount-child-writes")
	defer childEnd.Close()

	parentEnd := os.NewFile(uintptr(pair[1]), "fusermount-parent-reads")
	defer parentEnd.Close()

	// fusermount's stderr is collected in memory.
	var stderr bytes.Buffer

	cmd := exec.Command("fusermount", "-o", cfg.toOptionsString(), "--", dir)
	cmd.Env = append(os.Environ(), "_FUSE_COMMFD=3")
	cmd.ExtraFiles = []*os.File{childEnd}
	cmd.Stderr = &stderr

	if err = cmd.Run(); err != nil {
		return nil, fmt.Errorf("running fusermount: %v\n\nstderr:\n%s", err, stderr.Bytes())
	}

	// Turn the parent end into a connection; it has to be a Unix domain
	// socket.
	c, err := net.FileConn(parentEnd)
	if err != nil {
		return nil, fmt.Errorf("FileConn: %v", err)
	}
	defer c.Close()

	uc, ok := c.(*net.UnixConn)
	if !ok {
		return nil, fmt.Errorf("Expected UnixConn, got %T", c)
	}

	// Read the helper's message.
	data := make([]byte, 32)    // expect 1 byte
	control := make([]byte, 32) // expect 24 bytes
	_, controlLen, _, _, err := uc.ReadMsgUnix(data, control)
	if err != nil {
		return nil, fmt.Errorf("ReadMsgUnix: %v", err)
	}

	// Exactly one control message is expected.
	scms, err := syscall.ParseSocketControlMessage(control[:controlLen])
	if err != nil {
		return nil, fmt.Errorf("ParseSocketControlMessage: %v", err)
	}
	if len(scms) != 1 {
		return nil, fmt.Errorf("expected 1 SocketControlMessage; got scms = %#v", scms)
	}

	// It must carry exactly one descriptor.
	gotFds, err := syscall.ParseUnixRights(&scms[0])
	if err != nil {
		return nil, fmt.Errorf("syscall.ParseUnixRights: %v", err)
	}
	if len(gotFds) != 1 {
		return nil, fmt.Errorf("wanted 1 fd; got %#v", gotFds)
	}

	return os.NewFile(uintptr(gotFds[0]), "/dev/fuse"), nil
}
예제 #21
0
// TestPassFD tests passing a file descriptor over a Unix socket.
//
// This test involves both a parent and child process. The parent
// process is invoked as a normal test, with "go test", which then
// runs the child process by running the current test binary with args
// "-test.run=^TestPassFD$" and an environment variable used to signal
// that the test should become the child process instead.
func TestPassFD(t *testing.T) {
	if os.Getenv("GO_WANT_HELPER_PROCESS") == "1" {
		passFDChild()
		return
	}

	tempDir, err := ioutil.TempDir("", "TestPassFD")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tempDir)

	fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
	if err != nil {
		t.Fatalf("Socketpair: %v", err)
	}
	// The os.File wrappers own the descriptors from here on. Closing the
	// raw numbers as well would close them twice.
	writeFile := os.NewFile(uintptr(fds[0]), "child-writes")
	readFile := os.NewFile(uintptr(fds[1]), "parent-reads")
	defer writeFile.Close()
	defer readFile.Close()

	cmd := exec.Command(os.Args[0], "-test.run=^TestPassFD$", "--", tempDir)
	cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"}
	cmd.ExtraFiles = []*os.File{writeFile}

	out, err := cmd.CombinedOutput()
	if len(out) > 0 || err != nil {
		t.Fatalf("child process: %q, %v", out, err)
	}

	c, err := net.FileConn(readFile)
	if err != nil {
		t.Fatalf("FileConn: %v", err)
	}
	defer c.Close()

	uc, ok := c.(*net.UnixConn)
	if !ok {
		t.Fatalf("unexpected FileConn type; expected UnixConn, got %T", c)
	}

	buf := make([]byte, 32) // expect 1 byte
	oob := make([]byte, 32) // expect 24 bytes
	// Closing the connection unblocks the read if the child never sent
	// anything.
	closeUnix := time.AfterFunc(5*time.Second, func() {
		t.Logf("timeout reading from unix socket")
		uc.Close()
	})
	_, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
	closeUnix.Stop()
	if err != nil {
		t.Fatalf("ReadMsgUnix: %v", err)
	}

	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		t.Fatalf("ParseSocketControlMessage: %v", err)
	}
	if len(scms) != 1 {
		t.Fatalf("expected 1 SocketControlMessage; got scms = %#v", scms)
	}
	scm := scms[0]
	gotFds, err := syscall.ParseUnixRights(&scm)
	if err != nil {
		t.Fatalf("syscall.ParseUnixRights: %v", err)
	}
	if len(gotFds) != 1 {
		t.Fatalf("wanted 1 fd; got %#v", gotFds)
	}

	f := os.NewFile(uintptr(gotFds[0]), "fd-from-child")
	defer f.Close()

	got, err := ioutil.ReadAll(f)
	want := "Hello from child process!\n"
	if string(got) != want {
		t.Errorf("child process ReadAll: %q, %v; want %q", got, err, want)
	}
}
예제 #22
0
// mount runs the fusermount helper for dir with the options from conf and
// returns the file descriptor the helper passes back, wrapped as an *os.File
// named "/dev/fuse".
//
// ready is closed immediately. errp is not used by this implementation.
// The helper is given one end of a socketpair as file descriptor 3 (announced
// through _FUSE_COMMFD=3); its stdout and stderr are consumed by lineLogger
// goroutines. The helper is expected to send exactly one descriptor in a
// single SCM_RIGHTS control message.
func mount(dir string, conf *mountConfig, ready chan<- struct{}, errp *error) (fusefd *os.File, err error) {
	// linux mount is never delayed
	close(ready)

	fds, err := syscall.Socketpair(syscall.AF_FILE, syscall.SOCK_STREAM, 0)
	if err != nil {
		return nil, fmt.Errorf("socketpair error: %v", err)
	}

	writeFile := os.NewFile(uintptr(fds[0]), "fusermount-child-writes")
	defer writeFile.Close()

	readFile := os.NewFile(uintptr(fds[1]), "fusermount-parent-reads")
	defer readFile.Close()

	cmd := exec.Command(
		"fusermount",
		"-o", conf.getOptions(),
		"--",
		dir,
	)
	cmd.Env = append(os.Environ(), "_FUSE_COMMFD=3")

	cmd.ExtraFiles = []*os.File{writeFile}

	var wg sync.WaitGroup
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("setting up fusermount stdout: %v", err)
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("setting up fusermount stderr: %v", err)
	}

	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("fusermount: %v", err)
	}
	wg.Add(2)
	go lineLogger(&wg, "mount helper output", stdout)
	go lineLogger(&wg, "mount helper error", stderr)
	// Both pipes must be drained before Wait, which closes them.
	wg.Wait()
	if err := cmd.Wait(); err != nil {
		return nil, fmt.Errorf("fusermount: %v", err)
	}

	c, err := net.FileConn(readFile)
	if err != nil {
		return nil, fmt.Errorf("FileConn from fusermount socket: %v", err)
	}
	defer c.Close()

	uc, ok := c.(*net.UnixConn)
	if !ok {
		return nil, fmt.Errorf("unexpected FileConn type; expected UnixConn, got %T", c)
	}

	buf := make([]byte, 32) // expect 1 byte
	oob := make([]byte, 32) // expect 24 bytes
	_, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
	if err != nil {
		// Report the read failure itself rather than a misleading
		// "expected 1 SocketControlMessage" from parsing an empty buffer.
		return nil, fmt.Errorf("ReadMsgUnix: %v", err)
	}
	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, fmt.Errorf("ParseSocketControlMessage: %v", err)
	}
	if len(scms) != 1 {
		return nil, fmt.Errorf("expected 1 SocketControlMessage; got scms = %#v", scms)
	}
	scm := scms[0]
	gotFds, err := syscall.ParseUnixRights(&scm)
	if err != nil {
		return nil, fmt.Errorf("syscall.ParseUnixRights: %v", err)
	}
	if len(gotFds) != 1 {
		return nil, fmt.Errorf("wanted 1 fd; got %#v", gotFds)
	}
	f := os.NewFile(uintptr(gotFds[0]), "/dev/fuse")
	return f, nil
}