diff --git a/contrib/nvidia/nvidia.go b/contrib/nvidia/nvidia.go index 72c6f63d4..b520ab79e 100644 --- a/contrib/nvidia/nvidia.go +++ b/contrib/nvidia/nvidia.go @@ -60,9 +60,12 @@ func WithGPUs(opts ...Opts) oci.SpecOpts { return err } } - path, err := exec.LookPath("containerd") - if err != nil { - return err + if c.OCIHookPath == "" { + path, err := exec.LookPath("containerd") + if err != nil { + return err + } + c.OCIHookPath = path } nvidiaPath, err := exec.LookPath(nvidiaCLI) if err != nil { @@ -72,7 +75,7 @@ func WithGPUs(opts ...Opts) oci.SpecOpts { s.Hooks = &specs.Hooks{} } s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ - Path: path, + Path: c.OCIHookPath, Args: append([]string{ "containerd", "oci-hook", @@ -92,6 +95,7 @@ type config struct { LDCache string LDConfig string Requirements []string + OCIHookPath string } func (c *config) args() []string { @@ -181,3 +185,23 @@ func WithRequiredCUDAVersion(major, minor int) Opts { return nil } } + +// WithOCIHookPath sets the hook path for the binary +func WithOCIHookPath(path string) Opts { + return func(c *config) error { + c.OCIHookPath = path + return nil + } +} + +// WithLookupOCIHookPath sets the hook path for the binary via a binary name +func WithLookupOCIHookPath(name string) Opts { + return func(c *config) error { + path, err := exec.LookPath(name) + if err != nil { + return err + } + c.OCIHookPath = path + return nil + } +}