diff --git a/contrib/nvidia/nvidia.go b/contrib/nvidia/nvidia.go index 05f2e7e5b..bf64d3014 100644 --- a/contrib/nvidia/nvidia.go +++ b/contrib/nvidia/nvidia.go @@ -69,9 +69,12 @@ func WithGPUs(opts ...Opts) oci.SpecOpts { return err } } - path, err := exec.LookPath("containerd") - if err != nil { - return err + if c.OCIHookPath == "" { + path, err := exec.LookPath("containerd") + if err != nil { + return err + } + c.OCIHookPath = path } nvidiaPath, err := exec.LookPath(nvidiaCLI) if err != nil { @@ -81,7 +84,7 @@ func WithGPUs(opts ...Opts) oci.SpecOpts { s.Hooks = &specs.Hooks{} } s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ - Path: path, + Path: c.OCIHookPath, Args: append([]string{ "containerd", "oci-hook", @@ -101,6 +104,7 @@ type config struct { LDCache string LDConfig string Requirements []string + OCIHookPath string } func (c *config) args() []string { @@ -179,3 +183,23 @@ func WithRequiredCUDAVersion(major, minor int) Opts { return nil } } + +// WithOCIHookPath sets the hook path for the binary +func WithOCIHookPath(path string) Opts { + return func(c *config) error { + c.OCIHookPath = path + return nil + } +} + +// WithLookupOCIHookPath sets the hook path for the binary via a binary name +func WithLookupOCIHookPath(name string) Opts { + return func(c *config) error { + path, err := exec.LookPath(name) + if err != nil { + return err + } + c.OCIHookPath = path + return nil + } +}