diff --git a/container_linux_test.go b/container_linux_test.go index 0fcdd5f6f..e7cd5fc8b 100644 --- a/container_linux_test.go +++ b/container_linux_test.go @@ -1040,7 +1040,7 @@ func TestContainerRuntimeOptionsv2(t *testing.T) { } } -func TestContainerKillInitPidHost(t *testing.T) { +func initContainerAndCheckChildrenDieOnKill(t *testing.T, opts ...oci.SpecOpts) { client, err := newClient(t, address) if err != nil { t.Fatal(err) @@ -1059,12 +1059,12 @@ func TestContainerKillInitPidHost(t *testing.T) { t.Fatal(err) } + opts = append(opts, oci.WithImageConfig(image)) + opts = append(opts, withProcessArgs("sh", "-c", "sleep 42; echo hi")) + container, err := client.NewContainer(ctx, id, WithNewSnapshot(id, image), - WithNewSpec(oci.WithImageConfig(image), - withProcessArgs("sh", "-c", "sleep 42; echo hi"), - oci.WithHostNamespace(specs.PIDNamespace), - ), + WithNewSpec(opts...), ) if err != nil { t.Fatal(err) @@ -1111,6 +1111,14 @@ func TestContainerKillInitPidHost(t *testing.T) { } } +func TestContainerKillInitPidHost(t *testing.T) { + initContainerAndCheckChildrenDieOnKill(t, oci.WithHostNamespace(specs.PIDNamespace)) +} + +func TestContainerKillInitKillsChildWhenNotHostPid(t *testing.T) { + initContainerAndCheckChildrenDieOnKill(t) +} + func TestUserNamespaces(t *testing.T) { t.Parallel() t.Run("WritableRootFS", func(t *testing.T) { testUserNamespaces(t, false) }) diff --git a/runtime/v1/linux/proc/init.go b/runtime/v1/linux/proc/init.go index fe11285c7..d37f45031 100644 --- a/runtime/v1/linux/proc/init.go +++ b/runtime/v1/linux/proc/init.go @@ -249,7 +249,6 @@ func (p *Init) setExited(status int) { } func (p *Init) delete(context context.Context) error { - p.KillAll(context) p.wg.Wait() err := p.runtime.Delete(context, p.id, nil) // ignore errors if a runtime has already deleted the process diff --git a/runtime/v1/shim/service.go b/runtime/v1/shim/service.go index c0e7c868a..afb860cb0 100644 --- a/runtime/v1/shim/service.go +++ b/runtime/v1/shim/service.go @@ -20,7 +20,9 @@ package shim import ( "context" + "encoding/json" "fmt" + "io/ioutil" "os" "path/filepath" "sync" @@ -41,6 +43,7 @@ import ( runc "github.com/containerd/go-runc" "github.com/containerd/typeurl" ptypes "github.com/gogo/protobuf/types" + specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" @@ -507,13 +510,22 @@ func (s *Service) processExits() { func (s *Service) checkProcesses(e runc.Exit) { s.mu.Lock() defer s.mu.Unlock() + + shouldKillAll, err := shouldKillAllOnExit(s.bundle) + if err != nil { + log.G(s.context).WithError(err).Error("failed to check shouldKillAll") + } + for _, p := range s.processes { if p.Pid() == e.Pid { - if ip, ok := p.(*proc.Init); ok { - // Ensure all children are killed - if err := ip.KillAll(s.context); err != nil { - log.G(s.context).WithError(err).WithField("id", ip.ID()). - Error("failed to kill init's children") + + if shouldKillAll { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). + Error("failed to kill init's children") + } } } p.SetExited(e.Status) @@ -529,6 +541,25 @@ func (s *Service) checkProcesses(e runc.Exit) { } } +func shouldKillAllOnExit(bundlePath string) (bool, error) { + var bundleSpec specs.Spec + bundleConfigContents, err := ioutil.ReadFile(filepath.Join(bundlePath, "config.json")) + if err != nil { + return false, err + } + json.Unmarshal(bundleConfigContents, &bundleSpec) + + if bundleSpec.Linux != nil { + for _, ns := range bundleSpec.Linux.Namespaces { + if ns.Type == specs.PIDNamespace { + return false, nil + } + } + } + + return true, nil +} + func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { s.mu.Lock() defer s.mu.Unlock() diff --git a/runtime/v2/runc/service.go b/runtime/v2/runc/service.go index aeb48cfe5..e37fb2976 100644 --- a/runtime/v2/runc/service.go +++ b/runtime/v2/runc/service.go @@ -20,6 +20,7 @@ package runc import ( "context" + "encoding/json" "io/ioutil" "os" "os/exec" @@ -34,6 +35,7 @@ import ( "github.com/containerd/containerd/api/types/task" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/events" + "github.com/containerd/containerd/log" "github.com/containerd/containerd/mount" "github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/runtime" @@ -45,6 +47,7 @@ import ( runcC "github.com/containerd/go-runc" "github.com/containerd/typeurl" ptypes "github.com/gogo/protobuf/types" + specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" @@ -638,13 +641,20 @@ func (s *service) processExits() { } func (s *service) checkProcesses(e runcC.Exit) { + shouldKillAll, err := shouldKillAllOnExit(s.bundle) + if err != nil { + log.G(s.context).WithError(err).Error("failed to check shouldKillAll") + } + for _, p := range s.allProcesses() { if p.Pid() == e.Pid { - if ip, ok := p.(*proc.Init); ok { - // Ensure all children are killed - if err := ip.KillAll(s.context); err != nil { - logrus.WithError(err).WithField("id", ip.ID()). - Error("failed to kill init's children") + if shouldKillAll { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed + if err := ip.KillAll(s.context); err != nil { + logrus.WithError(err).WithField("id", ip.ID()). + Error("failed to kill init's children") + } } } p.SetExited(e.Status) @@ -660,6 +670,25 @@ func (s *service) checkProcesses(e runcC.Exit) { } } +func shouldKillAllOnExit(bundlePath string) (bool, error) { + var bundleSpec specs.Spec + bundleConfigContents, err := ioutil.ReadFile(filepath.Join(bundlePath, "config.json")) + if err != nil { + return false, err + } + json.Unmarshal(bundleConfigContents, &bundleSpec) + + if bundleSpec.Linux != nil { + for _, ns := range bundleSpec.Linux.Namespaces { + if ns.Type == specs.PIDNamespace { + return false, nil + } + } + } + + return true, nil +} + func (s *service) allProcesses() (o []rproc.Process) { s.mu.Lock() defer s.mu.Unlock()