diff --git a/container_opts.go b/container_opts.go index ca960302e..f005fe1c7 100644 --- a/container_opts.go +++ b/container_opts.go @@ -22,12 +22,10 @@ import ( "errors" "fmt" - "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/containerd/containerd/containers" "github.com/containerd/containerd/content" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/images" - "github.com/containerd/containerd/log" "github.com/containerd/containerd/oci" "github.com/containerd/containerd/protobuf" "github.com/containerd/containerd/snapshots" @@ -326,37 +324,3 @@ func WithSpec(s *oci.Spec, opts ...oci.SpecOpts) NewContainerOpts { func WithoutRefreshedMetadata(i *InfoConfig) { i.Refresh = false } - -// WithCDI updates OCI spec with CDI content -func WithCDI(s *oci.Spec, annotations map[string]string, cdiSpecDirs []string) NewContainerOpts { - return func(ctx context.Context, _ *Client, c *containers.Container) error { - // TODO: Once CRI is extended with native CDI support this will need to be updated... - _, cdiDevices, err := cdi.ParseAnnotations(annotations) - if err != nil { - return fmt.Errorf("failed to parse CDI device annotations: %w", err) - } - if cdiDevices == nil { - return nil - } - - registry := cdi.GetRegistry(cdi.WithSpecDirs(cdiSpecDirs...)) - if err = registry.Refresh(); err != nil { - // We don't consider registry refresh failure a fatal error. - // For instance, a dynamically generated invalid CDI Spec file for - // any particular vendor shouldn't prevent injection of devices of - // different vendors. CDI itself knows better and it will fail the - // injection if necessary. - log.G(ctx).Warnf("CDI registry refresh failed: %v", err) - } - - if _, err := registry.InjectDevices(s, cdiDevices...); err != nil { - return fmt.Errorf("CDI device injection failed: %w", err) - } - - // One crucial thing to keep in mind is that CDI device injection - // might add OCI Spec environment variables, hooks, and mounts as - // well. Therefore it is important that none of the corresponding - // OCI Spec fields are reset up in the call stack once we return. - return nil - } -} diff --git a/oci/spec_opts_linux.go b/oci/spec_opts_linux.go index 90c4887a4..e5032391b 100644 --- a/oci/spec_opts_linux.go +++ b/oci/spec_opts_linux.go @@ -18,8 +18,11 @@ package oci import ( "context" + "fmt" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/containerd/containerd/containers" + "github.com/containerd/containerd/log" "github.com/containerd/containerd/pkg/cap" specs "github.com/opencontainers/runtime-spec/specs-go" ) @@ -157,3 +160,37 @@ func WithRdt(closID, l3CacheSchema, memBwSchema string) SpecOpts { func escapeAndCombineArgs(args []string) string { panic("not supported") } + +// WithCDI updates OCI spec with CDI content +func WithCDI(annotations map[string]string, cdiSpecDirs []string) SpecOpts { + return func(ctx context.Context, _ Client, c *containers.Container, s *Spec) error { + // TODO: Once CRI is extended with native CDI support this will need to be updated... + _, cdiDevices, err := cdi.ParseAnnotations(annotations) + if err != nil { + return fmt.Errorf("failed to parse CDI device annotations: %w", err) + } + if cdiDevices == nil { + return nil + } + + registry := cdi.GetRegistry(cdi.WithSpecDirs(cdiSpecDirs...)) + if err = registry.Refresh(); err != nil { + // We don't consider registry refresh failure a fatal error. + // For instance, a dynamically generated invalid CDI Spec file for + // any particular vendor shouldn't prevent injection of devices of + // different vendors. CDI itself knows better and it will fail the + // injection if necessary. + log.G(ctx).Warnf("CDI registry refresh failed: %v", err) + } + + if _, err := registry.InjectDevices(s, cdiDevices...); err != nil { + return fmt.Errorf("CDI device injection failed: %w", err) + } + + // One crucial thing to keep in mind is that CDI device injection + // might add OCI Spec environment variables, hooks, and mounts as + // well. Therefore it is important that none of the corresponding + // OCI Spec fields are reset up in the call stack once we return. + return nil + } +} diff --git a/pkg/cri/server/container_create.go b/pkg/cri/server/container_create.go index 4d84a582b..0978951bf 100644 --- a/pkg/cri/server/container_create.go +++ b/pkg/cri/server/container_create.go @@ -239,10 +239,6 @@ func (c *criService) CreateContainer(ctx context.Context, r *runtime.CreateConta return nil, fmt.Errorf("failed to get runtime options: %w", err) } - if c.config.EnableCDI { - opts = append(opts, containerd.WithCDI(spec, config.Annotations, c.config.CDISpecDirs)) - } - opts = append(opts, containerd.WithSpec(spec, specOpts...), containerd.WithRuntime(sandboxInfo.Runtime.Name, runtimeOptions), diff --git a/pkg/cri/server/container_create_linux.go b/pkg/cri/server/container_create_linux.go index 8fb41e210..47d349251 100644 --- a/pkg/cri/server/container_create_linux.go +++ b/pkg/cri/server/container_create_linux.go @@ -386,6 +386,9 @@ func (c *criService) containerSpecOpts(config *runtime.ContainerConfig, imageCon if seccompSpecOpts != nil { specOpts = append(specOpts, seccompSpecOpts) } + if c.config.EnableCDI { + specOpts = append(specOpts, oci.WithCDI(config.Annotations, c.config.CDISpecDirs)) + } return specOpts, nil } diff --git a/pkg/cri/server/container_create_linux_test.go b/pkg/cri/server/container_create_linux_test.go index ff725df99..d7a1fda36 100644 --- a/pkg/cri/server/container_create_linux_test.go +++ b/pkg/cri/server/container_create_linux_test.go @@ -28,7 +28,6 @@ import ( "testing" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" - "github.com/containerd/containerd" "github.com/containerd/containerd/containers" "github.com/containerd/containerd/contrib/apparmor" "github.com/containerd/containerd/contrib/seccomp" @@ -1619,8 +1618,8 @@ containerEdits: } require.NoError(t, err) - injectFun := containerd.WithCDI(spec, test.annotations, []string{cdiDir}) - err = injectFun(nil, nil, nil) + injectFun := oci.WithCDI(test.annotations, []string{cdiDir}) + err = injectFun(nil, nil, nil, spec) assert.Equal(t, test.expectError, err != nil) if err != nil {