diff --git a/pkg/os/os.go b/pkg/os/os.go index 128ffb8f9..f0e7a3271 100644 --- a/pkg/os/os.go +++ b/pkg/os/os.go @@ -25,6 +25,7 @@ import ( containerdmount "github.com/containerd/containerd/mount" "github.com/containerd/fifo" "github.com/docker/docker/pkg/mount" + "github.com/docker/docker/pkg/symlink" "golang.org/x/net/context" "golang.org/x/sys/unix" ) @@ -37,6 +38,7 @@ type OS interface { OpenFifo(ctx context.Context, fn string, flag int, perm os.FileMode) (io.ReadWriteCloser, error) Stat(name string) (os.FileInfo, error) ResolveSymbolicLink(name string) (string, error) + FollowSymlinkInScope(path, scope string) (string, error) CopyFile(src, dest string, perm os.FileMode) error WriteFile(filename string, data []byte, perm os.FileMode) error Mount(source string, target string, fstype string, flags uintptr, data string) error @@ -79,6 +81,11 @@ func (RealOS) ResolveSymbolicLink(path string) (string, error) { return filepath.EvalSymlinks(path) } +// FollowSymlinkInScope will call symlink.FollowSymlinkInScope. +func (RealOS) FollowSymlinkInScope(path, scope string) (string, error) { + return symlink.FollowSymlinkInScope(path, scope) +} + // CopyFile will copy src file to dest file func (RealOS) CopyFile(src, dest string, perm os.FileMode) error { in, err := os.Open(src) @@ -107,17 +114,21 @@ func (RealOS) Mount(source string, target string, fstype string, flags uintptr, return unix.Mount(source, target, fstype, flags, data) } -// Unmount will call unix.Unmount to unmount the file. The function doesn't -// return error if target is not mounted. +// Unmount will call Unmount to unmount the file. func (RealOS) Unmount(target string, flags int) error { - // TODO(random-liu): Follow symlink to make sure the result is correct. - if mounted, err := mount.Mounted(target); err != nil || !mounted { - return err - } - return unix.Unmount(target, flags) + return Unmount(target, flags) } // LookupMount gets mount info of a given path. func (RealOS) LookupMount(path string) (containerdmount.Info, error) { return containerdmount.Lookup(path) } + +// Unmount will call unix.Unmount to unmount the file. The function doesn't +// return error if target is not mounted. +func Unmount(target string, flags int) error { + if mounted, err := mount.Mounted(target); err != nil || !mounted { + return err + } + return unix.Unmount(target, flags) +} diff --git a/pkg/os/testing/fake_os.go b/pkg/os/testing/fake_os.go index fdba1d2a7..f33cbbd4f 100644 --- a/pkg/os/testing/fake_os.go +++ b/pkg/os/testing/fake_os.go @@ -40,18 +40,19 @@ type CalledDetail struct { // of the real call. type FakeOS struct { sync.Mutex - MkdirAllFn func(string, os.FileMode) error - RemoveAllFn func(string) error - OpenFifoFn func(context.Context, string, int, os.FileMode) (io.ReadWriteCloser, error) - StatFn func(string) (os.FileInfo, error) - ResolveSymbolicLinkFn func(string) (string, error) - CopyFileFn func(string, string, os.FileMode) error - WriteFileFn func(string, []byte, os.FileMode) error - MountFn func(source string, target string, fstype string, flags uintptr, data string) error - UnmountFn func(target string, flags int) error - LookupMountFn func(path string) (containerdmount.Info, error) - calls []CalledDetail - errors map[string]error + MkdirAllFn func(string, os.FileMode) error + RemoveAllFn func(string) error + OpenFifoFn func(context.Context, string, int, os.FileMode) (io.ReadWriteCloser, error) + StatFn func(string) (os.FileInfo, error) + ResolveSymbolicLinkFn func(string) (string, error) + FollowSymlinkInScopeFn func(string, string) (string, error) + CopyFileFn func(string, string, os.FileMode) error + WriteFileFn func(string, []byte, os.FileMode) error + MountFn func(source string, target string, fstype string, flags uintptr, data string) error + UnmountFn func(target string, flags int) error + LookupMountFn func(path string) (containerdmount.Info, error) + calls []CalledDetail + errors map[string]error } var _ osInterface.OS = &FakeOS{} @@ -176,6 +177,19 @@ func (f *FakeOS) ResolveSymbolicLink(path string) (string, error) { return path, nil } +// FollowSymlinkInScope is a fake call that invokes FollowSymlinkInScope or returns its input +func (f *FakeOS) FollowSymlinkInScope(path, scope string) (string, error) { + f.appendCalls("FollowSymlinkInScope", path, scope) + if err := f.getError("FollowSymlinkInScope"); err != nil { + return "", err + } + + if f.FollowSymlinkInScopeFn != nil { + return f.FollowSymlinkInScopeFn(path, scope) + } + return path, nil +} + // CopyFile is a fake call that invokes CopyFileFn or just return nil. func (f *FakeOS) CopyFile(src, dest string, perm os.FileMode) error { f.appendCalls("CopyFile", src, dest, perm) diff --git a/pkg/server/sandbox_run.go b/pkg/server/sandbox_run.go index de0b42bee..e27ec8bc1 100644 --- a/pkg/server/sandbox_run.go +++ b/pkg/server/sandbox_run.go @@ -491,8 +491,12 @@ func parseDNSOptions(servers, searches, options []string) (string, error) { // 2) The mount point doesn't exist. func (c *criService) unmountSandboxFiles(id string, config *runtime.PodSandboxConfig) error { if config.GetLinux().GetSecurityContext().GetNamespaceOptions().GetIpc() != runtime.NamespaceMode_NODE { - if err := c.os.Unmount(c.getSandboxDevShm(id), unix.MNT_DETACH); err != nil && !os.IsNotExist(err) { - return err + path, err := c.os.FollowSymlinkInScope(c.getSandboxDevShm(id), "/") + if err != nil { + return errors.Wrap(err, "failed to follow symlink") + } + if err := c.os.Unmount(path, unix.MNT_DETACH); err != nil && !os.IsNotExist(err) { + return errors.Wrapf(err, "failed to unmount %q", path) } } return nil diff --git a/pkg/store/sandbox/netns.go b/pkg/store/sandbox/netns.go index 1fc98ee00..8ec4c1de5 100644 --- a/pkg/store/sandbox/netns.go +++ b/pkg/store/sandbox/netns.go @@ -21,10 +21,11 @@ import ( "sync" cnins "github.com/containernetworking/plugins/pkg/ns" - "github.com/docker/docker/pkg/mount" "github.com/docker/docker/pkg/symlink" "github.com/pkg/errors" "golang.org/x/sys/unix" + + osinterface "github.com/containerd/cri/pkg/os" ) // ErrClosedNetNS is the error returned when network namespace is closed. @@ -81,7 +82,6 @@ func (n *NetNS) Remove() error { } if n.restored { path := n.ns.Path() - // TODO(random-liu): Add util function for unmount. // Check netns existence. if _, err := os.Stat(path); err != nil { if os.IsNotExist(err) { @@ -93,15 +93,8 @@ func (n *NetNS) Remove() error { if err != nil { return errors.Wrap(err, "failed to follow symlink") } - mounted, err := mount.Mounted(path) - if err != nil { - return errors.Wrap(err, "failed to check netns mounted") - } - if mounted { - err := unix.Unmount(path, unix.MNT_DETACH) - if err != nil && !os.IsNotExist(err) { - return errors.Wrap(err, "failed to umount netns") - } + if err := osinterface.Unmount(path, unix.MNT_DETACH); err != nil && !os.IsNotExist(err) { + return errors.Wrap(err, "failed to umount netns") } if err := os.RemoveAll(path); err != nil { return errors.Wrap(err, "failed to remove netns")