diff --git a/cmd/containerd-shim/main_unix.go b/cmd/containerd-shim/main_unix.go index 16889183b..8a13299e4 100644 --- a/cmd/containerd-shim/main_unix.go +++ b/cmd/containerd-shim/main_unix.go @@ -67,9 +67,6 @@ func main() { if err != nil { return err } - if err := setupRoot(); err != nil { - return err - } path, err := os.Getwd() if err != nil { return err diff --git a/cmd/containerd-shim/shim_linux.go b/cmd/containerd-shim/shim_linux.go index 10d51b78b..43ab3db1b 100644 --- a/cmd/containerd-shim/shim_linux.go +++ b/cmd/containerd-shim/shim_linux.go @@ -10,7 +10,6 @@ import ( "google.golang.org/grpc/credentials" "golang.org/x/net/context" - "golang.org/x/sys/unix" "github.com/containerd/containerd/reaper" "github.com/containerd/containerd/sys" @@ -33,11 +32,6 @@ func setupSignals() (chan os.Signal, error) { return signals, nil } -// setupRoot sets up the root as the shim is started in its own mount namespace -func setupRoot() error { - return unix.Mount("", "/", "", unix.MS_SLAVE|unix.MS_REC, "") -} - func newServer() *grpc.Server { return grpc.NewServer(grpc.Creds(NewUnixSocketCredentils(0, 0))) } diff --git a/cmd/containerd-shim/shim_unix.go b/cmd/containerd-shim/shim_unix.go index 110123d7f..b6bf2a6a0 100644 --- a/cmd/containerd-shim/shim_unix.go +++ b/cmd/containerd-shim/shim_unix.go @@ -23,11 +23,6 @@ func setupSignals() (chan os.Signal, error) { return signals, nil } -// setupRoot is a no op except on Linux -func setupRoot() error { - return nil -} - func newServer() *grpc.Server { return grpc.NewServer() } diff --git a/linux/shim/init.go b/linux/shim/init.go index 41c7d96ad..9e40daaac 100644 --- a/linux/shim/init.go +++ b/linux/shim/init.go @@ -39,17 +39,19 @@ type initProcess struct { // the reaper interface. mu sync.Mutex - id string - bundle string - console console.Console - io runc.IO - runtime *runc.Runc - status int - exited time.Time - pid int - closers []io.Closer - stdin io.Closer - stdio stdio + id string + bundle string + console console.Console + io runc.IO + runtime *runc.Runc + status int + exited time.Time + pid int + closers []io.Closer + stdin io.Closer + stdio stdio + rootfs string + nrRootMounts int // Number of rootfs overmounts } func newInitProcess(context context.Context, path, namespace string, r *shimapi.CreateTaskRequest) (*initProcess, error) { @@ -64,15 +66,27 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. } options = *v.(*runcopts.CreateOptions) } + + rootfs := filepath.Join(path, "rootfs") + // count the number of successful mounts so we can undo + // what was actually done rather than what should have been + // done. + nrRootMounts := 0 for _, rm := range r.Rootfs { m := &mount.Mount{ Type: rm.Type, Source: rm.Source, Options: rm.Options, } - if err := m.Mount(filepath.Join(path, "rootfs")); err != nil { + if err := m.Mount(rootfs); err != nil { return nil, errors.Wrapf(err, "failed to mount rootfs component %v", m) } + nrRootMounts++ + } + cleanupMounts := func() { + if err2 := mount.UnmountN(rootfs, 0, nrRootMounts); err2 != nil { + log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount") + } } runtime := &runc.Runc{ Command: r.Runtime, @@ -91,6 +105,8 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. stderr: r.Stderr, terminal: r.Terminal, }, + rootfs: rootfs, + nrRootMounts: nrRootMounts, } var ( err error @@ -99,12 +115,14 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. ) if r.Terminal { if socket, err = runc.NewConsoleSocket(filepath.Join(path, "pty.sock")); err != nil { + cleanupMounts() return nil, errors.Wrap(err, "failed to create OCI runtime console socket") } defer os.Remove(socket.Path()) } else { // TODO: get uid/gid if io, err = runc.NewPipeIO(0, 0); err != nil { + cleanupMounts() return nil, errors.Wrap(err, "failed to create OCI runtime io pipes") } p.io = io @@ -124,6 +142,7 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. NoSubreaper: true, } if _, err := p.runtime.Restore(context, r.ID, r.Bundle, opts); err != nil { + cleanupMounts() return nil, p.runtimeError(err, "OCI runtime restore failed") } } else { @@ -137,6 +156,7 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. opts.ConsoleSocket = socket } if err := p.runtime.Create(context, r.ID, r.Bundle, opts); err != nil { + cleanupMounts() return nil, p.runtimeError(err, "OCI runtime create failed") } } @@ -152,14 +172,17 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. if socket != nil { console, err := socket.ReceiveMaster() if err != nil { + cleanupMounts() return nil, errors.Wrap(err, "failed to retrieve console master") } p.console = console if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil { + cleanupMounts() return nil, errors.Wrap(err, "failed to start console copy") } } else { if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil { + cleanupMounts() return nil, errors.Wrap(err, "failed to start io pipe copy") } } @@ -167,6 +190,7 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. copyWaitGroup.Wait() pid, err := runc.ReadPidFile(pidFile) if err != nil { + cleanupMounts() return nil, errors.Wrap(err, "failed to retrieve OCI runtime container pid") } p.pid = pid @@ -229,7 +253,16 @@ func (p *initProcess) Delete(context context.Context) error { } p.io.Close() } - return p.runtimeError(err, "OCI runtime delete failed") + err = p.runtimeError(err, "OCI runtime delete failed") + + if err2 := mount.UnmountN(p.rootfs, 0, p.nrRootMounts); err2 != nil { + log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount") + if err == nil { + err = errors.Wrap(err2, "Failed rootfs umount") + } + } + + return err } func (p *initProcess) Resize(ws console.WinSize) error { diff --git a/mount/mount.go b/mount/mount.go index 94086f178..2ef4a0164 100644 --- a/mount/mount.go +++ b/mount/mount.go @@ -22,3 +22,19 @@ func MountAll(mounts []Mount, target string) error { } return nil } + +// UnmountN tries to unmount the given mount point nr times, which is +// useful for undoing a stack of mounts on the same mount +// point. Returns the first error encountered, but always attempts the +// full nr umounts. +func UnmountN(mount string, flags, nr int) error { + var err error + for i := 0; i < nr; i++ { + if err2 := Unmount(mount, flags); err2 != nil { + if err == nil { + err = err2 + } + } + } + return err +}