func (fd *Fd) readFrom(c *net.UnixConn) error { var b []byte oob := make([]byte, 16) _, oobn, _, _, err := c.ReadMsgUnix(b, oob) if err != nil { return err } if oobn == 0 { return errors.New("error reading oob") } oob = oob[:oobn] scms, err := syscall.ParseSocketControlMessage(oob) if err != nil { return err } if len(scms) != 1 { return fmt.Errorf("expected 1 SocketControlMessage, got %d", len(scms)) } scm := scms[0] fds, err := syscall.ParseUnixRights(&scm) if err != nil { return nil } if len(fds) != 1 { return fmt.Errorf("expected 1 fd, got %d", len(fds)) } *fd = Fd(fds[0]) return nil }
func extractFds(oob []byte) (fds []int) { // Grab forklock to make sure no forks accidentally inherit the new // fds before they are made CLOEXEC // There is a slight race condition between ReadMsgUnix returns and // when we grap the lock, so this is not perfect. Unfortunately // There is no way to pass MSG_CMSG_CLOEXEC to recvmsg() nor any // way to implement non-blocking i/o in go, so this is hard to fix. syscall.ForkLock.Lock() defer syscall.ForkLock.Unlock() scms, err := syscall.ParseSocketControlMessage(oob) if err != nil { return } for _, scm := range scms { gotFds, err := syscall.ParseUnixRights(&scm) if err != nil { continue } fds = append(fds, gotFds...) for _, fd := range fds { syscall.CloseOnExec(fd) } } return }
func (r *FDReader) Read(b []byte) (int, error) { oob := make([]byte, 32) n, oobn, _, _, err := r.conn.ReadMsgUnix(b, oob) if err != nil { if n < 0 { n = 0 } return n, err } if oobn > 0 { messages, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { return n, err } for _, m := range messages { fds, err := syscall.ParseUnixRights(&m) if err != nil { return n, err } // Set the CLOEXEC flag on the FDs so they won't be leaked into future forks for _, fd := range fds { if _, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_SETFD, syscall.FD_CLOEXEC); errno != 0 { return n, errno } r.FDs[r.fdCount] = fd r.fdCount++ } } } return n, nil }
func (s *OOBUnixConn) Read(p []byte) (n int, err error) { var oob [OOBMaxLength]byte n, oobn, _, _, err := s.ReadMsgUnix(p, oob[:]) if err == nil && n > 0 && oobn > 0 { scm, err := syscall.ParseSocketControlMessage(oob[0:oobn]) if err != nil { return n, err } s.m.Lock() for _, m := range scm { if m.Header.Level != syscall.SOL_SOCKET { continue } switch m.Header.Type { case syscall.SCM_RIGHTS: if fds, err := syscall.ParseUnixRights(&m); err == nil { for _, fd := range fds { // Note: We wrap the raw FDs inside an os.File just // once, early, to prevent double-free or leaking FDs. f := NewFile(fd) s.recvFiles = append(s.recvFiles, f) } } } } s.m.Unlock() } return n, err }
// mount runs fusermount to mount a FUSE file system at dir and returns the
// /dev/fuse descriptor that fusermount passes back over a socket pair
// (advertised to the helper via _FUSE_COMMFD=3).
//
// ready is closed immediately because mounting on Linux is never delayed.
// errp is unused here.
func mount(dir string, ready chan<- struct{}, errp *error) (fusefd *os.File, err error) {
	// linux mount is never delayed
	close(ready)

	fds, err := syscall.Socketpair(syscall.AF_FILE, syscall.SOCK_STREAM, 0)
	if err != nil {
		return nil, fmt.Errorf("socketpair error: %v", err)
	}
	// Hand both raw descriptors to *os.File wrappers right away. The wrappers
	// then own them and close each exactly once. Closing the raw fds as well
	// would double-close them (possibly hitting a reused descriptor number),
	// and wrapping fds[1] only after fusermount succeeds would leak it on
	// failure.
	writeFile := os.NewFile(uintptr(fds[0]), "fusermount-child-writes")
	defer writeFile.Close()
	readFile := os.NewFile(uintptr(fds[1]), "fusermount-parent-reads")
	defer readFile.Close()

	cmd := exec.Command("fusermount", "--", dir)
	cmd.Env = append(os.Environ(), "_FUSE_COMMFD=3")
	cmd.ExtraFiles = []*os.File{writeFile}

	out, err := cmd.CombinedOutput()
	if len(out) > 0 || err != nil {
		return nil, fmt.Errorf("fusermount: %q, %v", out, err)
	}

	c, err := net.FileConn(readFile)
	if err != nil {
		return nil, fmt.Errorf("FileConn from fusermount socket: %v", err)
	}
	defer c.Close()

	uc, ok := c.(*net.UnixConn)
	if !ok {
		return nil, fmt.Errorf("unexpected FileConn type; expected UnixConn, got %T", c)
	}

	buf := make([]byte, 32) // expect 1 byte
	oob := make([]byte, 32) // expect 24 bytes
	_, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
	if err != nil {
		return nil, fmt.Errorf("ReadMsgUnix: %v", err)
	}
	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, fmt.Errorf("ParseSocketControlMessage: %v", err)
	}
	if len(scms) != 1 {
		return nil, fmt.Errorf("expected 1 SocketControlMessage; got scms = %#v", scms)
	}
	scm := scms[0]
	gotFds, err := syscall.ParseUnixRights(&scm)
	if err != nil {
		return nil, fmt.Errorf("syscall.ParseUnixRights: %v", err)
	}
	if len(gotFds) != 1 {
		return nil, fmt.Errorf("wanted 1 fd; got %#v", gotFds)
	}
	f := os.NewFile(uintptr(gotFds[0]), "/dev/fuse")
	return f, nil
}
func readRequest(c *net.UnixConn) (*Request, error) { var l uint32 err := binary.Read(c, binary.BigEndian, &l) length := int(l) if err != nil { return nil, err } payload := make([]byte, length) n, err := c.Read(payload) if err != nil { return nil, err } else if n != length { return nil, fmt.Errorf("Payload was %d bytes rather than reported size of %d", n, length) } req := &Request{} err = json.Unmarshal(payload, req) if err != nil { return nil, err } if !req.HasFds { return req, nil } payload = make([]byte, 1) // TODO: does this buffer need to be configurable? oob := make([]byte, 8192) n, oobn, _, _, err := c.ReadMsgUnix(payload, oob) if err != nil && err != io.EOF { return nil, err } if n != 1 { return nil, fmt.Errorf("Error reading OOB filedescriptors") } oob = oob[0:oobn] scm, err := syscall.ParseSocketControlMessage(oob) if err != nil { return nil, fmt.Errorf("Error parsing socket control message: %v", err) } var fds []int for i := 0; i < len(scm); i++ { tfds, err := syscall.ParseUnixRights(&scm[i]) if err == syscall.EINVAL { continue // Wasn't a UnixRights Control Message } else if err != nil { return nil, fmt.Errorf("Error parsing unix rights: %v", err) } fds = append(fds, tfds...) } if len(fds) == 0 { return nil, fmt.Errorf("Failed to receive any FDs on a request with HasFds == true") } req.ReceivedFds = fds return req, nil }
func (m *Message) GetFD() uintptr { if m.control_msgs == nil { return 0 } fds, err := syscall.ParseUnixRights(&m.control_msgs[0]) if err != nil { panic("Unable to parse unix rights") } m.control_msgs = append(m.control_msgs[0:], m.control_msgs[1:]...) if len(fds) != 1 { panic("Expected 1 file descriptor, got more") } return uintptr(fds[0]) }
func (c *Conn) ReadMessage() (m *Message, err error) { h, err := c.readHeader() if err != nil { return } log.Printf("ReadMessage: header = %s", h) m = &Message{ object: h.object(), opcode: h.opcode(), } if h.size() == 0 { return } p := make([]byte, h.size()) oob := make([]byte, 32) n, oobn, _, _, err := c.c.ReadMsgUnix(p, oob) if err != nil { return } if uint16(n) != h.size() { err = fmt.Errorf("expected %d bytes, got %d", h.size(), n) return } log.Printf("ReadMessage: n = %d, oobn = %d", n, oobn) m.p = bytes.NewBuffer(p) if oobn == 0 { return } oob = oob[:oobn] scms, err := syscall.ParseSocketControlMessage(oob) if err != nil { return } if len(scms) != 1 { err = fmt.Errorf("expected 1 SocketControlMessage, got %d", len(scms)) return } scm := scms[0] m.fds, err = syscall.ParseUnixRights(&scm) return }
func readData(conn net.Conn) ([]Fd, int, error) { var b [2048]byte var oob [2048]byte var response Response n, oobn, _, _, err := conn.(*net.UnixConn).ReadMsgUnix(b[:], oob[:]) if err != nil { return nil, 0, fmt.Errorf("unix_socket: failed to read unix msg: %s (read: %d, %d)", err, n, oobn) } if n > 0 { err := json.Unmarshal(b[:n], &response) if err != nil { return nil, 0, fmt.Errorf("unix_socket: Unmarshal failed: %s", err) } if response.ErrMessage != "" { return nil, 0, errors.New(response.ErrMessage) } } else { return nil, 0, errors.New("unix_socket: No response received") } scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { return nil, 0, fmt.Errorf("unix_socket: failed to parse socket control message: %s", err) } if len(scms) < 1 { return nil, 0, fmt.Errorf("unix_socket: no socket control messages sent") } scm := scms[0] fds, err := syscall.ParseUnixRights(&scm) if err != nil { return nil, 0, fmt.Errorf("unix_socket: failed to parse unix rights: %s", err) } files := make([]Fd, len(fds)) for i, fd := range fds { files[i] = os.NewFile(uintptr(fd), fmt.Sprintf("/dev/fake-fd-%d", i)) } return files, response.Pid, nil }
func extractFileDescriptorFromOOB(oob []byte) (int, error) { scms, err := syscall.ParseSocketControlMessage(oob) if err != nil { return -1, err } if len(scms) != 1 { return -1, errors.New(fmt.Sprintf("expected 1 SocketControlMessage; got scms = %#v", scms)) } scm := scms[0] gotFds, err := syscall.ParseUnixRights(&scm) if err != nil { return -1, err } if len(gotFds) != 1 { return -1, errors.New(fmt.Sprintf("wanted 1 fd; got %#v", gotFds)) } return gotFds[0], nil }
// TestUnixRightsRoundtrip checks that lists of file descriptors encoded with
// UnixRights come back unchanged after ParseSocketControlMessage and
// ParseUnixRights, including the case of several concatenated messages whose
// final message is not padded.
func TestUnixRightsRoundtrip(t *testing.T) {
	cases := [...][][]int{
		{{42}},
		{{1, 2}},
		{{3, 4, 5}},
		{{}},
		{{1, 2}, {3, 4, 5}, {}, {7}},
	}
	for _, groups := range cases {
		var buf []byte
		end := 0
		for _, group := range groups {
			// Only the end of the last group matters: the buffer is cut at
			// that message's unpadded length.
			end = len(buf) + syscall.CmsgLen(4*len(group))
			buf = append(buf, syscall.UnixRights(group...)...)
		}
		buf = buf[:end]

		msgs, err := syscall.ParseSocketControlMessage(buf)
		if err != nil {
			t.Fatalf("ParseSocketControlMessage: %v", err)
		}
		if len(msgs) != len(groups) {
			t.Fatalf("expected %v SocketControlMessage; got scms = %#v", len(groups), msgs)
		}
		for i := range msgs {
			got, err := syscall.ParseUnixRights(&msgs[i])
			if err != nil {
				t.Fatalf("ParseUnixRights: %v", err)
			}
			want := groups[i]
			if len(got) != len(want) {
				t.Fatalf("expected %v fds, got %#v", len(want), got)
			}
			for j := range got {
				if got[j] != want[j] {
					t.Fatalf("expected fd %v, got %v", want[j], got[j])
				}
			}
		}
	}
}
// Get receives file descriptors from a Unix domain socket. // // Num specifies the expected number of file descriptors in one message. // Internal files' names to be assigned are specified via optional filenames // argument. // // You need to close all files in the returned slice. The slice can be // non-empty even if this function returns an error. // // Use net.FileConn() if you're receiving a network connection. func Get(via *net.UnixConn, num int, filenames []string) ([]*os.File, error) { if num < 1 { return nil, nil } // get the underlying socket viaf, err := via.File() if err != nil { return nil, err } socket := int(viaf.Fd()) defer viaf.Close() // recvmsg buf := make([]byte, syscall.CmsgSpace(num*4)) _, _, _, _, err = syscall.Recvmsg(socket, nil, buf, 0) if err != nil { return nil, err } // parse control msgs var msgs []syscall.SocketControlMessage msgs, err = syscall.ParseSocketControlMessage(buf) // convert fds to files res := make([]*os.File, 0, len(msgs)) for i := 0; i < len(msgs) && err == nil; i++ { var fds []int fds, err = syscall.ParseUnixRights(&msgs[i]) for fi, fd := range fds { var filename string if fi < len(filenames) { filename = filenames[fi] } res = append(res, os.NewFile(uintptr(fd), filename)) } } return res, err }
func ReadFile(c *net.UnixConn, timeout time.Duration) (*os.File, error) { oob := make([]byte, 64) if timeout > 0 { deadline := time.Now().Add(timeout) if err := c.SetReadDeadline(deadline); err != nil { return nil, err } } _, oobn, flags, _, err := c.ReadMsgUnix(nil, oob) if err != nil { return nil, err } if flags != 0 || oobn <= 0 { panic("ReadMsgUnix: flags != 0 || oobn <= 0") } // file descriptors are now open in this process scm, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { return nil, err } if len(scm) != 1 { panic("invalid scm message") } fds, err := syscall.ParseUnixRights(&scm[0]) if err != nil { return nil, err } if len(fds) != 1 { panic("invalid scm message") } return os.NewFile(uintptr(fds[0]), ""), nil }
func (m *Message) parseControlData(data []byte) error { cmsgs, err := syscall.ParseSocketControlMessage(data) if err != nil { return err } for _, cmsg := range cmsgs { switch cmsg.Header.Type { case syscall.SCM_CREDENTIALS: cred, err := syscall.ParseUnixCredentials(&cmsg) if err != nil { return err } m.Ucred = cred case syscall.SCM_RIGHTS: fds, err := syscall.ParseUnixRights(&cmsg) if err != nil { return err } m.Fds = fds } } return nil }
func RecvFd(conn *net.UnixConn) (*os.File, error) { buf := make([]byte, 32) oob := make([]byte, 32) _, oobn, _, _, err := conn.ReadMsgUnix(buf, oob) if err != nil { return nil, fmt.Errorf("recvfd: err %v", err) } scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { return nil, fmt.Errorf("recvfd: ParseSocketControlMessage failed %v", err) } if len(scms) != 1 { return nil, fmt.Errorf("recvfd: SocketControlMessage count not 1: %v", len(scms)) } scm := scms[0] fds, err := syscall.ParseUnixRights(&scm) if err != nil { return nil, fmt.Errorf("recvfd: ParseUnixRights failed %v", err) } if len(fds) != 1 { return nil, fmt.Errorf("recvfd: fd count not 1: %v", len(fds)) } return os.NewFile(uintptr(fds[0]), "passed-fd"), nil }
func mount(dir string, conf *mountConfig, ready chan<- struct{}, errp *error) (fusefd *os.File, err error) { // linux mount is never delayed close(ready) fds, err := syscall.Socketpair(syscall.AF_FILE, syscall.SOCK_STREAM, 0) if err != nil { return nil, fmt.Errorf("socketpair error: %v", err) } writeFile := os.NewFile(uintptr(fds[0]), "fusermount-child-writes") defer writeFile.Close() readFile := os.NewFile(uintptr(fds[1]), "fusermount-parent-reads") defer readFile.Close() cmd := exec.Command( "fusermount", "-o", conf.getOptions(), "--", dir, ) cmd.Env = append(os.Environ(), "_FUSE_COMMFD=3") cmd.ExtraFiles = []*os.File{writeFile} var wg sync.WaitGroup stdout, err := cmd.StdoutPipe() if err != nil { return nil, fmt.Errorf("setting up fusermount stderr: %v", err) } stderr, err := cmd.StderrPipe() if err != nil { return nil, fmt.Errorf("setting up fusermount stderr: %v", err) } if err := cmd.Start(); err != nil { return nil, fmt.Errorf("fusermount: %v", err) } helperErrCh := make(chan error, 1) wg.Add(2) go lineLogger(&wg, "mount helper output", neverIgnoreLine, stdout) go lineLogger(&wg, "mount helper error", handleFusermountStderr(helperErrCh), stderr) wg.Wait() if err := cmd.Wait(); err != nil { // see if we have a better error to report select { case helperErr := <-helperErrCh: // log the Wait error if it's not what we expected if !isBoringFusermountError(err) { log.Printf("mount helper failed: %v", err) } // and now return what we grabbed from stderr as the real // error return nil, helperErr default: // nope, fall back to generic message } return nil, fmt.Errorf("fusermount: %v", err) } c, err := net.FileConn(readFile) if err != nil { return nil, fmt.Errorf("FileConn from fusermount socket: %v", err) } defer c.Close() uc, ok := c.(*net.UnixConn) if !ok { return nil, fmt.Errorf("unexpected FileConn type; expected UnixConn, got %T", c) } buf := make([]byte, 32) // expect 1 byte oob := make([]byte, 32) // expect 24 bytes _, oobn, _, _, err := uc.ReadMsgUnix(buf, 
oob) scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { return nil, fmt.Errorf("ParseSocketControlMessage: %v", err) } if len(scms) != 1 { return nil, fmt.Errorf("expected 1 SocketControlMessage; got scms = %#v", scms) } scm := scms[0] gotFds, err := syscall.ParseUnixRights(&scm) if err != nil { return nil, fmt.Errorf("syscall.ParseUnixRights: %v", err) } if len(gotFds) != 1 { return nil, fmt.Errorf("wanted 1 fd; got %#v", gotFds) } f := os.NewFile(uintptr(gotFds[0]), "/dev/fuse") return f, nil }
func (t *unixTransport) ReadMessage() (*Message, error) { var ( blen, hlen uint32 csheader [16]byte headers []header order binary.ByteOrder unixfds uint32 ) // To be sure that all bytes of out-of-band data are read, we use a special // reader that uses ReadUnix on the underlying connection instead of Read // and gathers the out-of-band data in a buffer. rd := &oobReader{conn: t.UnixConn} // read the first 16 bytes (the part of the header that has a constant size), // from which we can figure out the length of the rest of the message if _, err := io.ReadFull(rd, csheader[:]); err != nil { return nil, err } switch csheader[0] { case 'l': order = binary.LittleEndian case 'B': order = binary.BigEndian default: return nil, InvalidMessageError("invalid byte order") } // csheader[4:8] -> length of message body, csheader[12:16] -> length of // header fields (without alignment) binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen) binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen) if hlen%8 != 0 { hlen += 8 - (hlen % 8) } // decode headers and look for unix fds headerdata := make([]byte, hlen+4) copy(headerdata, csheader[12:]) if _, err := io.ReadFull(t, headerdata[4:]); err != nil { return nil, err } dec := newDecoder(bytes.NewBuffer(headerdata), order) dec.pos = 12 vs, err := dec.Decode(Signature{"a(yv)"}) if err != nil { return nil, err } Store(vs, &headers) for _, v := range headers { if v.Field == byte(FieldUnixFDs) { unixfds, _ = v.Variant.value.(uint32) } } all := make([]byte, 16+hlen+blen) copy(all, csheader[:]) copy(all[16:], headerdata[4:]) if _, err := io.ReadFull(rd, all[16+hlen:]); err != nil { return nil, err } if unixfds != 0 { if !t.hasUnixFDs { return nil, errors.New("dbus: got unix fds on unsupported transport") } // read the fds from the OOB data scms, err := syscall.ParseSocketControlMessage(rd.oob) if err != nil { return nil, err } if len(scms) != 1 { return nil, errors.New("dbus: received more than one socket control message") } fds, err := 
syscall.ParseUnixRights(&scms[0]) if err != nil { return nil, err } msg, err := DecodeMessage(bytes.NewBuffer(all)) if err != nil { return nil, err } // substitute the values in the message body (which are indices for the // array receiver via OOB) with the actual values for i, v := range msg.Body { if j, ok := v.(UnixFDIndex); ok { if uint32(j) >= unixfds { return nil, InvalidMessageError("invalid index for unix fd") } msg.Body[i] = UnixFD(fds[j]) } } return msg, nil } return DecodeMessage(bytes.NewBuffer(all)) }
func Create(socketPath string, stdout io.Writer, stderr io.Writer) (*Link, error) { conn, err := net.Dial("unix", socketPath) if err != nil { return nil, fmt.Errorf("failed to connect to i/o daemon: %s", err) } var b [2048]byte var oob [2048]byte n, oobn, _, _, err := conn.(*net.UnixConn).ReadMsgUnix(b[:], oob[:]) if err != nil { return nil, fmt.Errorf("failed to read unix msg: %s (read: %d, %d)", err, n, oobn) } scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { return nil, fmt.Errorf("failed to parse socket control message: %s", err) } if len(scms) < 1 { return nil, fmt.Errorf("no socket control messages sent") } scm := scms[0] fds, err := syscall.ParseUnixRights(&scm) if err != nil { return nil, fmt.Errorf("failed to parse unix rights: %s", err) } if len(fds) != 3 { return nil, fmt.Errorf("invalid number of fds; need 3, got %d", len(fds)) } lstdout := os.NewFile(uintptr(fds[0]), "stdout") lstderr := os.NewFile(uintptr(fds[1]), "stderr") lstatus := os.NewFile(uintptr(fds[2]), "status") streaming := &sync.WaitGroup{} linkWriter := NewWriter(conn) streaming.Add(1) go func() { io.Copy(stdout, lstdout) lstdout.Close() streaming.Done() }() streaming.Add(1) go func() { io.Copy(stderr, lstderr) lstderr.Close() streaming.Done() }() done := make(chan struct{}) go func() { streaming.Wait() close(done) conn.Close() }() return &Link{ Writer: linkWriter, exitStatus: lstatus, done: done, }, nil }
// TestPassFD tests passing a file descriptor over a Unix socket. // // This test involved both a parent and child process. The parent // process is invoked as a normal test, with "go test", which then // runs the child process by running the current test binary with args // "-test.run=^TestPassFD$" and an environment variable used to signal // that the test should become the child process instead. func TestPassFD(t *testing.T) { switch runtime.GOOS { case "dragonfly": // TODO(jsing): Figure out why sendmsg is returning EINVAL. t.Skip("skipping test on dragonfly") case "solaris": // TODO(aram): Figure out why ReadMsgUnix is returning empty message. t.Skip("skipping test on solaris, see issue 7402") case "darwin": switch runtime.GOARCH { case "arm", "arm64": t.Skipf("skipping test on %d/%s, no fork", runtime.GOOS, runtime.GOARCH) } } if os.Getenv("GO_WANT_HELPER_PROCESS") == "1" { passFDChild() return } tempDir, err := ioutil.TempDir("", "TestPassFD") if err != nil { t.Fatal(err) } defer os.RemoveAll(tempDir) fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0) if err != nil { t.Fatalf("Socketpair: %v", err) } defer syscall.Close(fds[0]) defer syscall.Close(fds[1]) writeFile := os.NewFile(uintptr(fds[0]), "child-writes") readFile := os.NewFile(uintptr(fds[1]), "parent-reads") defer writeFile.Close() defer readFile.Close() cmd := exec.Command(os.Args[0], "-test.run=^TestPassFD$", "--", tempDir) cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1") cmd.ExtraFiles = []*os.File{writeFile} out, err := cmd.CombinedOutput() if len(out) > 0 || err != nil { t.Fatalf("child process: %q, %v", out, err) } c, err := net.FileConn(readFile) if err != nil { t.Fatalf("FileConn: %v", err) } defer c.Close() uc, ok := c.(*net.UnixConn) if !ok { t.Fatalf("unexpected FileConn type; expected UnixConn, got %T", c) } buf := make([]byte, 32) // expect 1 byte oob := make([]byte, 32) // expect 24 bytes closeUnix := time.AfterFunc(5*time.Second, func() { t.Logf("timeout 
reading from unix socket") uc.Close() }) _, oobn, _, _, err := uc.ReadMsgUnix(buf, oob) closeUnix.Stop() scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { t.Fatalf("ParseSocketControlMessage: %v", err) } if len(scms) != 1 { t.Fatalf("expected 1 SocketControlMessage; got scms = %#v", scms) } scm := scms[0] gotFds, err := syscall.ParseUnixRights(&scm) if err != nil { t.Fatalf("syscall.ParseUnixRights: %v", err) } if len(gotFds) != 1 { t.Fatalf("wanted 1 fd; got %#v", gotFds) } f := os.NewFile(uintptr(gotFds[0]), "fd-from-child") defer f.Close() got, err := ioutil.ReadAll(f) want := "Hello from child process!\n" if string(got) != want { t.Errorf("child process ReadAll: %q, %v; want %q", got, err, want) } }
// Begin the process of mounting at the given directory, returning a connection // to the kernel. Mounting continues in the background, and is complete when an // error is written to the supplied channel. The file system may need to // service the connection in order for mounting to complete. func mount( dir string, cfg *MountConfig, ready chan<- error) (dev *os.File, err error) { // On linux, mounting is never delayed. ready <- nil // Create a socket pair. fds, err := syscall.Socketpair(syscall.AF_FILE, syscall.SOCK_STREAM, 0) if err != nil { err = fmt.Errorf("Socketpair: %v", err) return } // Wrap the sockets into os.File objects that we will pass off to fusermount. writeFile := os.NewFile(uintptr(fds[0]), "fusermount-child-writes") defer writeFile.Close() readFile := os.NewFile(uintptr(fds[1]), "fusermount-parent-reads") defer readFile.Close() // Start fusermount, passing it a buffer in which to write stderr. var stderr bytes.Buffer cmd := exec.Command( "fusermount", "-o", cfg.toOptionsString(), "--", dir, ) cmd.Env = append(os.Environ(), "_FUSE_COMMFD=3") cmd.ExtraFiles = []*os.File{writeFile} cmd.Stderr = &stderr // Run the command. err = cmd.Run() if err != nil { err = fmt.Errorf("running fusermount: %v\n\nstderr:\n%s", err, stderr.Bytes()) return } // Wrap the socket file in a connection. c, err := net.FileConn(readFile) if err != nil { err = fmt.Errorf("FileConn: %v", err) return } defer c.Close() // We expect to have a Unix domain socket. uc, ok := c.(*net.UnixConn) if !ok { err = fmt.Errorf("Expected UnixConn, got %T", c) return } // Read a message. buf := make([]byte, 32) // expect 1 byte oob := make([]byte, 32) // expect 24 bytes _, oobn, _, _, err := uc.ReadMsgUnix(buf, oob) if err != nil { err = fmt.Errorf("ReadMsgUnix: %v", err) return } // Parse the message. scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { err = fmt.Errorf("ParseSocketControlMessage: %v", err) return } // We expect one message. 
if len(scms) != 1 { err = fmt.Errorf("expected 1 SocketControlMessage; got scms = %#v", scms) return } scm := scms[0] // Pull out the FD returned by fusermount gotFds, err := syscall.ParseUnixRights(&scm) if err != nil { err = fmt.Errorf("syscall.ParseUnixRights: %v", err) return } if len(gotFds) != 1 { err = fmt.Errorf("wanted 1 fd; got %#v", gotFds) return } // Turn the FD into an os.File. dev = os.NewFile(uintptr(gotFds[0]), "/dev/fuse") return }
// TestPassFD tests passing a file descriptor over a Unix socket. // // This test involved both a parent and child process. The parent // process is invoked as a normal test, with "go test", which then // runs the child process by running the current test binary with args // "-test.run=^TestPassFD$" and an environment variable used to signal // that the test should become the child process instead. func TestPassFD(t *testing.T) { if os.Getenv("GO_WANT_HELPER_PROCESS") == "1" { passFDChild() return } tempDir, err := ioutil.TempDir("", "TestPassFD") if err != nil { t.Fatal(err) } defer os.RemoveAll(tempDir) fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0) if err != nil { t.Fatalf("Socketpair: %v", err) } defer syscall.Close(fds[0]) defer syscall.Close(fds[1]) writeFile := os.NewFile(uintptr(fds[0]), "child-writes") readFile := os.NewFile(uintptr(fds[1]), "parent-reads") defer writeFile.Close() defer readFile.Close() cmd := exec.Command(os.Args[0], "-test.run=^TestPassFD$", "--", tempDir) cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} cmd.ExtraFiles = []*os.File{writeFile} out, err := cmd.CombinedOutput() if len(out) > 0 || err != nil { t.Fatalf("child process: %q, %v", out, err) } c, err := net.FileConn(readFile) if err != nil { t.Fatalf("FileConn: %v", err) } defer c.Close() uc, ok := c.(*net.UnixConn) if !ok { t.Fatalf("unexpected FileConn type; expected UnixConn, got %T", c) } buf := make([]byte, 32) // expect 1 byte oob := make([]byte, 32) // expect 24 bytes closeUnix := time.AfterFunc(5*time.Second, func() { t.Logf("timeout reading from unix socket") uc.Close() }) _, oobn, _, _, err := uc.ReadMsgUnix(buf, oob) closeUnix.Stop() scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { t.Fatalf("ParseSocketControlMessage: %v", err) } if len(scms) != 1 { t.Fatalf("expected 1 SocketControlMessage; got scms = %#v", scms) } scm := scms[0] gotFds, err := syscall.ParseUnixRights(&scm) if err != nil { 
t.Fatalf("syscall.ParseUnixRights: %v", err) } if len(gotFds) != 1 { t.Fatalf("wanted 1 fd; got %#v", gotFds) } f := os.NewFile(uintptr(gotFds[0]), "fd-from-child") defer f.Close() got, err := ioutil.ReadAll(f) want := "Hello from child process!\n" if string(got) != want { t.Errorf("child process ReadAll: %q, %v; want %q", got, err, want) } }
func mount(dir string, conf *mountConfig, ready chan<- struct{}, errp *error) (fusefd *os.File, err error) { // linux mount is never delayed close(ready) fds, err := syscall.Socketpair(syscall.AF_FILE, syscall.SOCK_STREAM, 0) if err != nil { return nil, fmt.Errorf("socketpair error: %v", err) } writeFile := os.NewFile(uintptr(fds[0]), "fusermount-child-writes") defer writeFile.Close() readFile := os.NewFile(uintptr(fds[1]), "fusermount-parent-reads") defer readFile.Close() cmd := exec.Command( "fusermount", "-o", conf.getOptions(), "--", dir, ) cmd.Env = append(os.Environ(), "_FUSE_COMMFD=3") cmd.ExtraFiles = []*os.File{writeFile} var wg sync.WaitGroup stdout, err := cmd.StdoutPipe() if err != nil { return nil, fmt.Errorf("setting up fusermount stderr: %v", err) } stderr, err := cmd.StderrPipe() if err != nil { return nil, fmt.Errorf("setting up fusermount stderr: %v", err) } if err := cmd.Start(); err != nil { return nil, fmt.Errorf("fusermount: %v", err) } wg.Add(2) go lineLogger(&wg, "mount helper output", stdout) go lineLogger(&wg, "mount helper error", stderr) wg.Wait() if err := cmd.Wait(); err != nil { return nil, fmt.Errorf("fusermount: %v", err) } c, err := net.FileConn(readFile) if err != nil { return nil, fmt.Errorf("FileConn from fusermount socket: %v", err) } defer c.Close() uc, ok := c.(*net.UnixConn) if !ok { return nil, fmt.Errorf("unexpected FileConn type; expected UnixConn, got %T", c) } buf := make([]byte, 32) // expect 1 byte oob := make([]byte, 32) // expect 24 bytes _, oobn, _, _, err := uc.ReadMsgUnix(buf, oob) scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { return nil, fmt.Errorf("ParseSocketControlMessage: %v", err) } if len(scms) != 1 { return nil, fmt.Errorf("expected 1 SocketControlMessage; got scms = %#v", scms) } scm := scms[0] gotFds, err := syscall.ParseUnixRights(&scm) if err != nil { return nil, fmt.Errorf("syscall.ParseUnixRights: %v", err) } if len(gotFds) != 1 { return nil, fmt.Errorf("wanted 1 fd; got 
%#v", gotFds) } f := os.NewFile(uintptr(gotFds[0]), "/dev/fuse") return f, nil }