func main() { var err error flag.Parse() defer exit() log.Println("Loading NVIDIA management library") assert(nvidia.Init()) defer func() { assert(nvidia.Shutdown()) }() log.Println("Loading NVIDIA unified memory") assert(nvidia.LoadUVM()) log.Println("Discovering GPU devices") Devices, err = nvidia.LookupDevices() assert(err) log.Println("Provisioning volumes at", VolumesPath) Volumes, err = nvidia.LookupVolumes(VolumesPath) assert(err) plugin := NewPluginAPI(SocketPath) remote := NewRemoteAPI(ListenAddr) log.Println("Serving plugin API at", SocketPath) log.Println("Serving remote API at", ListenAddr) p := plugin.Serve() r := remote.Serve() join, joined := make(chan int, 2), 0 L: for { select { case <-p: remote.Stop() p = nil join <- 1 case <-r: plugin.Stop() r = nil join <- 1 case j := <-join: if joined += j; joined == cap(join) { break L } } } assert(plugin.Error()) assert(remote.Error()) log.Println("Successfully terminated") }
func devicesArgs() ([]string, error) { args := []string{"--device=/dev/nvidiactl", "--device=/dev/nvidia-uvm"} // FIXME avoid looking up every devices devs, err := nvidia.LookupDevices() if err != nil { return nil, err } if len(GPU) == 0 { for i := range devs { args = append(args, fmt.Sprintf("--device=%s", devs[i].Path)) } } else { for _, id := range GPU { i, err := strconv.Atoi(id) if err != nil || i < 0 || i >= len(devs) { return nil, fmt.Errorf("invalid device: %s", id) } args = append(args, fmt.Sprintf("--device=%s", devs[i].Path)) } } return args, nil }