diff --git a/contrib/nvidia/nvidia.go b/contrib/nvidia/nvidia.go index 296d85b64..72c6f63d4 100644 --- a/contrib/nvidia/nvidia.go +++ b/contrib/nvidia/nvidia.go @@ -86,8 +86,7 @@ func WithGPUs(opts ...Opts) oci.SpecOpts { } type config struct { - Devices []int - DeviceUUID string + Devices []string Capabilities []Capability LoadKmods bool LDCache string @@ -108,10 +107,7 @@ func (c *config) args() []string { "configure", ) if len(c.Devices) > 0 { - args = append(args, fmt.Sprintf("--device=%s", strings.Join(toStrings(c.Devices), ","))) - } - if c.DeviceUUID != "" { - args = append(args, fmt.Sprintf("--device=%s", c.DeviceUUID)) + args = append(args, fmt.Sprintf("--device=%s", strings.Join(c.Devices, ","))) } for _, c := range c.Capabilities { args = append(args, fmt.Sprintf("--%s", capFlags[c])) @@ -135,36 +131,30 @@ var capFlags = map[Capability]string{ Display: "display", } -func toStrings(ints []int) []string { - var s []string - for _, i := range ints { - s = append(s, strconv.Itoa(i)) - } - return s -} - // Opts are options for configuring gpu support type Opts func(*config) error // WithDevices adds the provided device indexes to the container func WithDevices(ids ...int) Opts { return func(c *config) error { - c.Devices = ids + for _, i := range ids { + c.Devices = append(c.Devices, strconv.Itoa(i)) + } return nil } } -// WithDeviceUUID adds the specific device UUID to the container -func WithDeviceUUID(guid string) Opts { +// WithDeviceUUIDs adds the specific device UUID to the container +func WithDeviceUUIDs(uuids ...string) Opts { return func(c *config) error { - c.DeviceUUID = guid + c.Devices = append(c.Devices, uuids...) return nil } } // WithAllDevices adds all gpus to the container func WithAllDevices(c *config) error { - c.DeviceUUID = "all" + c.Devices = []string{"all"} return nil } @@ -176,6 +166,14 @@ func WithAllCapabilities(c *config) error { return nil } +// WithCapabilities adds the specified capabilities to the container for the gpus +func WithCapabilities(caps ...Capability) Opts { + return func(c *config) error { + c.Capabilities = append(c.Capabilities, caps...) + return nil + } +} + // WithRequiredCUDAVersion sets the required cuda version func WithRequiredCUDAVersion(major, minor int) Opts { return func(c *config) error {