diff --git a/container_linux_test.go b/container_linux_test.go index 8cc9555dd..ac1c8c47c 100644 --- a/container_linux_test.go +++ b/container_linux_test.go @@ -133,10 +133,6 @@ func TestShimInCgroup(t *testing.T) { t.Fatal(err) } defer client.Close() - if CheckRuntime(client.runtime, "io.containerd.runc") { - t.Skip() - } - var ( ctx, cancel = testContext() id = t.Name() @@ -160,12 +156,7 @@ func TestShimInCgroup(t *testing.T) { } defer cg.Delete() - task, err := container.NewTask(ctx, empty(), func(_ context.Context, client *Client, r *TaskInfo) error { - r.Options = &runctypes.CreateOptions{ - ShimCgroup: path, - } - return nil - }) + task, err := container.NewTask(ctx, empty(), WithShimCgroup(path)) if err != nil { t.Fatal(err) } diff --git a/runtime/v2/runc/v1/service.go b/runtime/v2/runc/v1/service.go index 269d26471..2125b8ae7 100644 --- a/runtime/v2/runc/v1/service.go +++ b/runtime/v2/runc/v1/service.go @@ -44,6 +44,8 @@ import ( taskAPI "github.com/containerd/containerd/runtime/v2/task" runcC "github.com/containerd/go-runc" "github.com/containerd/typeurl" + "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/types" ptypes "github.com/gogo/protobuf/types" specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" @@ -163,6 +165,31 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container if err := shim.WriteAddress("address", address); err != nil { return "", err } + if data, err := ioutil.ReadAll(os.Stdin); err == nil { + if len(data) > 0 { + var any types.Any + if err := proto.Unmarshal(data, &any); err != nil { + return "", err + } + v, err := typeurl.UnmarshalAny(&any) + if err != nil { + return "", err + } + if opts, ok := v.(*options.Options); ok { + if opts.ShimCgroup != "" { + cg, err := cgroups.Load(cgroups.V1, cgroups.StaticPath(opts.ShimCgroup)) + if err != nil { + return "", errors.Wrapf(err, "failed to load cgroup %s", opts.ShimCgroup) + } + if err := cg.Add(cgroups.Process{ + Pid: cmd.Process.Pid, + }); err != nil { + return "", errors.Wrapf(err, "failed to join cgroup %s", opts.ShimCgroup) + } + } + } + } + } if err := shim.SetScore(cmd.Process.Pid); err != nil { return "", errors.Wrap(err, "failed to set OOM Score on shim") } diff --git a/runtime/v2/runc/v2/service.go b/runtime/v2/runc/v2/service.go index 9fc2d1b9d..4623a78ed 100644 --- a/runtime/v2/runc/v2/service.go +++ b/runtime/v2/runc/v2/service.go @@ -45,6 +45,8 @@ import ( taskAPI "github.com/containerd/containerd/runtime/v2/task" runcC "github.com/containerd/go-runc" "github.com/containerd/typeurl" + "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/types" ptypes "github.com/gogo/protobuf/types" specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" @@ -206,6 +208,31 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container if err := shim.WriteAddress("address", address); err != nil { return "", err } + if data, err := ioutil.ReadAll(os.Stdin); err == nil { + if len(data) > 0 { + var any types.Any + if err := proto.Unmarshal(data, &any); err != nil { + return "", err + } + v, err := typeurl.UnmarshalAny(&any) + if err != nil { + return "", err + } + if opts, ok := v.(*options.Options); ok { + if opts.ShimCgroup != "" { + cg, err := cgroups.Load(cgroups.V1, cgroups.StaticPath(opts.ShimCgroup)) + if err != nil { + return "", errors.Wrapf(err, "failed to load cgroup %s", opts.ShimCgroup) + } + if err := cg.Add(cgroups.Process{ + Pid: cmd.Process.Pid, + }); err != nil { + return "", errors.Wrapf(err, "failed to join cgroup %s", opts.ShimCgroup) + } + } + } + } + } if err := shim.SetScore(cmd.Process.Pid); err != nil { return "", errors.Wrap(err, "failed to set OOM Score on shim") } diff --git a/task_opts_unix.go b/task_opts_unix.go index d3b51a76d..8b498d47e 100644 --- a/task_opts_unix.go +++ b/task_opts_unix.go @@ -77,3 +77,29 @@ func WithNoPivotRoot(_ context.Context, _ *Client, ti *TaskInfo) error { } return nil } + +// WithShimCgroup sets the existing cgroup for the shim +func WithShimCgroup(path string) NewTaskOpts { + return func(ctx context.Context, c *Client, ti *TaskInfo) error { + if CheckRuntime(ti.Runtime(), "io.containerd.runc") { + if ti.Options == nil { + ti.Options = &options.Options{} + } + opts, ok := ti.Options.(*options.Options) + if !ok { + return errors.New("invalid v2 shim create options format") + } + opts.ShimCgroup = path + } else { + if ti.Options == nil { + ti.Options = &runctypes.CreateOptions{} + } + opts, ok := ti.Options.(*runctypes.CreateOptions) + if !ok { + return errors.New("could not cast TaskInfo Options to CreateOptions") + } + opts.ShimCgroup = path + } + return nil + } +}