diff --git a/cmd/containerd-shim/main_unix.go b/cmd/containerd-shim/main_unix.go index 16889183b..8a13299e4 100644 --- a/cmd/containerd-shim/main_unix.go +++ b/cmd/containerd-shim/main_unix.go @@ -67,9 +67,6 @@ func main() { if err != nil { return err } - if err := setupRoot(); err != nil { - return err - } path, err := os.Getwd() if err != nil { return err diff --git a/cmd/containerd-shim/shim_linux.go b/cmd/containerd-shim/shim_linux.go index 10d51b78b..43ab3db1b 100644 --- a/cmd/containerd-shim/shim_linux.go +++ b/cmd/containerd-shim/shim_linux.go @@ -10,7 +10,6 @@ import ( "google.golang.org/grpc/credentials" "golang.org/x/net/context" - "golang.org/x/sys/unix" "github.com/containerd/containerd/reaper" "github.com/containerd/containerd/sys" @@ -33,11 +32,6 @@ func setupSignals() (chan os.Signal, error) { return signals, nil } -// setupRoot sets up the root as the shim is started in its own mount namespace -func setupRoot() error { - return unix.Mount("", "/", "", unix.MS_SLAVE|unix.MS_REC, "") -} - func newServer() *grpc.Server { return grpc.NewServer(grpc.Creds(NewUnixSocketCredentils(0, 0))) } diff --git a/cmd/containerd-shim/shim_unix.go b/cmd/containerd-shim/shim_unix.go index 110123d7f..b6bf2a6a0 100644 --- a/cmd/containerd-shim/shim_unix.go +++ b/cmd/containerd-shim/shim_unix.go @@ -23,11 +23,6 @@ func setupSignals() (chan os.Signal, error) { return signals, nil } -// setupRoot is a no op except on Linux -func setupRoot() error { - return nil -} - func newServer() *grpc.Server { return grpc.NewServer() } diff --git a/linux/shim/init.go b/linux/shim/init.go index 41c7d96ad..f34defe70 100644 --- a/linux/shim/init.go +++ b/linux/shim/init.go @@ -50,9 +50,12 @@ type initProcess struct { closers []io.Closer stdin io.Closer stdio stdio + rootfs string } func newInitProcess(context context.Context, path, namespace string, r *shimapi.CreateTaskRequest) (*initProcess, error) { + var success bool + if err := identifiers.Validate(r.ID); err != nil { return nil, errors.Wrapf(err, "invalid task id") } @@ -64,13 +67,26 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. } options = *v.(*runcopts.CreateOptions) } + + rootfs := filepath.Join(path, "rootfs") + // count the number of successful mounts so we can undo + // what was actually done rather than what should have been + // done. + defer func() { + if success { + return + } + if err2 := mount.UnmountAll(rootfs, 0); err2 != nil { + log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount") + } + }() for _, rm := range r.Rootfs { m := &mount.Mount{ Type: rm.Type, Source: rm.Source, Options: rm.Options, } - if err := m.Mount(filepath.Join(path, "rootfs")); err != nil { + if err := m.Mount(rootfs); err != nil { return nil, errors.Wrapf(err, "failed to mount rootfs component %v", m) } } @@ -91,6 +107,7 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. stderr: r.Stderr, terminal: r.Terminal, }, + rootfs: rootfs, } var ( err error @@ -170,6 +187,7 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. return nil, errors.Wrap(err, "failed to retrieve OCI runtime container pid") } p.pid = pid + success = true return p, nil } @@ -229,7 +247,16 @@ func (p *initProcess) Delete(context context.Context) error { } p.io.Close() } - return p.runtimeError(err, "OCI runtime delete failed") + err = p.runtimeError(err, "OCI runtime delete failed") + + if err2 := mount.UnmountAll(p.rootfs, 0); err2 != nil { + log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount") + if err == nil { + err = errors.Wrap(err2, "Failed rootfs umount") + } + } + + return err } func (p *initProcess) Resize(ws console.WinSize) error { diff --git a/mount/mount_linux.go b/mount/mount_linux.go index 86df8bbc1..5995051b8 100644 --- a/mount/mount_linux.go +++ b/mount/mount_linux.go @@ -15,6 +15,25 @@ func Unmount(mount string, flags int) error { return unix.Unmount(mount, flags) } +// UnmountAll repeatedly unmounts the given mount point until there +// are no mounts remaining (EINVAL is returned by mount), which is +// useful for undoing a stack of mounts on the same mount point. +func UnmountAll(mount string, flags int) error { + for { + if err := Unmount(mount, flags); err != nil { + // EINVAL is returned if the target is not a + // mount point, indicating that we are + // done. It can also indicate a few other + // things (such as invalid flags) which we + // unfortunately end up squelching here too. + if err == unix.EINVAL { + return nil + } + return err + } + } +} + // parseMountOptions takes fstab style mount options and parses them for // use with a standard mount() syscall func parseMountOptions(options []string) (int, string) { diff --git a/mount/mount_unix.go b/mount/mount_unix.go index 32ea2691a..23467a8cc 100644 --- a/mount/mount_unix.go +++ b/mount/mount_unix.go @@ -15,3 +15,7 @@ func (m *Mount) Mount(target string) error { func Unmount(mount string, flags int) error { return ErrNotImplementOnUnix } + +func UnmountAll(mount string, flags int) error { + return ErrNotImplementOnUnix +} diff --git a/mount/mount_windows.go b/mount/mount_windows.go index 35dead411..8eeca6817 100644 --- a/mount/mount_windows.go +++ b/mount/mount_windows.go @@ -13,3 +13,7 @@ func (m *Mount) Mount(target string) error { func Unmount(mount string, flags int) error { return ErrNotImplementOnWindows } + +func UnmountAll(mount string, flags int) error { + return ErrNotImplementOnWindows +}