diff --git a/cmd/containerd-shim-runc-v2/runc/container.go b/cmd/containerd-shim-runc-v2/runc/container.go index d5b4a1f80..ceef95e39 100644 --- a/cmd/containerd-shim-runc-v2/runc/container.go +++ b/cmd/containerd-shim-runc-v2/runc/container.go @@ -146,24 +146,9 @@ func NewContainer(ctx context.Context, platform stdio.Platform, r *task.CreateTa } pid := p.Pid() if pid > 0 { - var cg interface{} - if cgroups.Mode() == cgroups.Unified { - g, err := cgroupsv2.PidGroupPath(pid) - if err != nil { - log.G(ctx).WithError(err).Errorf("loading cgroup2 for %d", pid) - return container, nil - } - cg, err = cgroupsv2.Load(g) - if err != nil { - log.G(ctx).WithError(err).Errorf("loading cgroup2 for %d", pid) - } - } else { - cg, err = cgroup1.Load(cgroup1.PidPath(pid)) - if err != nil { - log.G(ctx).WithError(err).Errorf("loading cgroup for %d", pid) - } + if cg, err := loadProcessCgroup(ctx, pid); err == nil { + container.cgroup = cg } - container.cgroup = cg } return container, nil } @@ -367,23 +352,9 @@ func (c *Container) Start(ctx context.Context, r *task.StartRequest) (process.Pr return p, err } if c.Cgroup() == nil && p.Pid() > 0 { - var cg interface{} - if cgroups.Mode() == cgroups.Unified { - g, err := cgroupsv2.PidGroupPath(p.Pid()) - if err != nil { - log.G(ctx).WithError(err).Errorf("loading cgroup2 for %d", p.Pid()) - } - cg, err = cgroupsv2.Load(g) - if err != nil { - log.G(ctx).WithError(err).Errorf("loading cgroup2 for %d", p.Pid()) - } - } else { - cg, err = cgroup1.Load(cgroup1.PidPath(p.Pid())) - if err != nil { - log.G(ctx).WithError(err).Errorf("loading cgroup for %d", p.Pid()) - } + if cg, err := loadProcessCgroup(ctx, p.Pid()); err == nil { + c.cgroup = cg } - c.cgroup = cg } return p, nil } @@ -512,3 +483,25 @@ func (c *Container) HasPid(pid int) bool { } return false } + +func loadProcessCgroup(ctx context.Context, pid int) (cg interface{}, err error) { + if cgroups.Mode() == cgroups.Unified { + g, err := cgroupsv2.PidGroupPath(pid) + if err != nil { + log.G(ctx).WithError(err).Errorf("loading cgroup2 for %d", pid) + return nil, err + } + cg, err = cgroupsv2.Load(g) + if err != nil { + log.G(ctx).WithError(err).Errorf("loading cgroup2 for %d", pid) + return nil, err + } + } else { + cg, err = cgroup1.Load(cgroup1.PidPath(pid)) + if err != nil { + log.G(ctx).WithError(err).Errorf("loading cgroup for %d", pid) + return nil, err + } + } + return cg, nil +}