diff --git a/mount/mount_idmapped_linux.go b/mount/mount_idmapped_linux.go index 92208771e..930070e7f 100644 --- a/mount/mount_idmapped_linux.go +++ b/mount/mount_idmapped_linux.go @@ -19,15 +19,15 @@ package mount import ( "fmt" "os" + "runtime" "strconv" "strings" + "sync" "syscall" - "unsafe" "golang.org/x/sys/unix" "github.com/containerd/containerd/sys" - "github.com/sirupsen/logrus" ) // TODO: Support multiple mappings in future @@ -36,21 +36,26 @@ func parseIDMapping(mapping string) ([]syscall.SysProcIDMap, error) { if len(parts) != 3 { return nil, fmt.Errorf("user namespace mappings require the format `container-id:host-id:size`") } + cID, err := strconv.Atoi(parts[0]) if err != nil { return nil, fmt.Errorf("invalid container id for user namespace remapping, %w", err) } + hID, err := strconv.Atoi(parts[1]) if err != nil { return nil, fmt.Errorf("invalid host id for user namespace remapping, %w", err) } + size, err := strconv.Atoi(parts[2]) if err != nil { return nil, fmt.Errorf("invalid size for user namespace remapping, %w", err) } - if cID != 0 || hID < 0 || size < 0 { - return nil, fmt.Errorf("invalid mapping %s, all IDs and size must be positive integers (container ID of 0 is only supported)", mapping) + + if cID < 0 || hID < 0 || size < 0 { + return nil, fmt.Errorf("invalid mapping %s, all IDs and size must be positive integers", mapping) } + return []syscall.SysProcIDMap{ { ContainerID: cID, @@ -87,80 +92,121 @@ func IDMapMount(source, target string, usernsFd int) (err error) { return nil } -// GetUsernsFD forks the current process and creates a user namespace using the specified -// mappings. -// -// It returns: -// 1. The file descriptor of the /proc/[pid]/ns/user of the newly -// created mapping. -// 2. "Clean up" function that should be called once user namespace -// file descriptor is no longer needed. -// 3. Usual error. 
-func GetUsernsFD(uidmap, gidmap string) (_ int, _ func(), err error) { - var ( - usernsFile *os.File - pipeMap [2]int - pid uintptr - errno syscall.Errno - uidMaps, gidMaps []syscall.SysProcIDMap - ) - - if uidMaps, err = parseIDMapping(uidmap); err != nil { - return -1, nil, err - } - if gidMaps, err = parseIDMapping(gidmap); err != nil { - return -1, nil, err +// GetUsernsFD forks the current process and creates a user namespace using +// the specified mappings. +func GetUsernsFD(uidmap, gidmap string) (_usernsFD *os.File, _ error) { + uidMaps, err := parseIDMapping(uidmap) + if err != nil { + return nil, err } - syscall.ForkLock.Lock() - if err = syscall.Pipe2(pipeMap[:], syscall.O_CLOEXEC); err != nil { - syscall.ForkLock.Unlock() - return -1, nil, err + gidMaps, err := parseIDMapping(gidmap) + if err != nil { + return nil, err } + return getUsernsFD(uidMaps, gidMaps) +} - pid, errno = sys.ForkUserns(pipeMap) - syscall.ForkLock.Unlock() +func getUsernsFD(uidMaps, gidMaps []syscall.SysProcIDMap) (_usernsFD *os.File, retErr error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + pid, pidfd, errno := sys.ForkUserns() if errno != 0 { - syscall.Close(pipeMap[0]) - syscall.Close(pipeMap[1]) - return -1, nil, errno + return nil, errno } - syscall.Close(pipeMap[0]) + pidFD := os.NewFile(pidfd, "pidfd") + defer func() { + unix.PidfdSendSignal(int(pidFD.Fd()), unix.SIGKILL, nil, 0) - writeMappings := func(fname string, idmap []syscall.SysProcIDMap) error { + pidfdWaitid(pidFD) + + pidFD.Close() + }() + + // NOTE: + // + // The usernsFD will hold the userns reference in kernel. Even if the + // child process is reaped, the usernsFD is still valid. 
+ usernsFD, err := os.Open(fmt.Sprintf("/proc/%d/ns/user", pid)) + if err != nil { + return nil, fmt.Errorf("failed to get userns file descriptor for /proc/%d/ns/user: %w", pid, err) + } + defer func() { + if retErr != nil { + usernsFD.Close() + } + }() + + uidmapFile, err := os.OpenFile(fmt.Sprintf("/proc/%d/%s", pid, "uid_map"), os.O_WRONLY, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open /proc/%d/uid_map: %w", pid, err) + } + defer uidmapFile.Close() + + gidmapFile, err := os.OpenFile(fmt.Sprintf("/proc/%d/%s", pid, "gid_map"), os.O_WRONLY, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open /proc/%d/gid_map: %w", pid, err) + } + defer gidmapFile.Close() + + testHookKillChildBeforePidfdSendSignal(pid, pidFD) + + // Ensure the child process is still alive. If the err is ESRCH, we + // should return error because we can't guarantee the usernsFD and + // u[g]idmapFile are valid. It's safe to return error and retry. 
+ writeMappings := func(f *os.File, idmap []syscall.SysProcIDMap) error { mappings := "" for _, m := range idmap { - mappings = fmt.Sprintf("%d %d %d\n", m.ContainerID, m.HostID, m.Size) + mappings = fmt.Sprintf("%s%d %d %d\n", mappings, m.ContainerID, m.HostID, m.Size) } - return os.WriteFile(fmt.Sprintf("/proc/%d/%s", pid, fname), []byte(mappings), 0600) - } - cleanUpChild := func() { - sync := sys.ProcSyncExit - if _, _, errno := syscall.Syscall6(syscall.SYS_WRITE, uintptr(pipeMap[1]), uintptr(unsafe.Pointer(&sync)), unsafe.Sizeof(sync), 0, 0, 0); errno != 0 { - logrus.WithError(errno).Warnf("failed to sync with child (ProcSyncExit)") + _, err := f.Write([]byte(mappings)) + if err1 := f.Close(); err1 != nil && err == nil { + err = err1 } - syscall.Close(pipeMap[1]) - - if _, err := unix.Wait4(int(pid), nil, 0, nil); err != nil { - logrus.WithError(err).Warnf("failed to wait for child process; the SIGHLD might be received by shim reaper") - } - } - defer cleanUpChild() - - if err := writeMappings("uid_map", uidMaps); err != nil { - return -1, nil, err - } - if err := writeMappings("gid_map", gidMaps); err != nil { - return -1, nil, err + return err } - if usernsFile, err = os.Open(fmt.Sprintf("/proc/%d/ns/user", pid)); err != nil { - return -1, nil, fmt.Errorf("failed to get user ns file descriptor for - /proc/%d/user/ns: %w", pid, err) + if err := writeMappings(uidmapFile, uidMaps); err != nil { + return nil, fmt.Errorf("failed to write uid_map: %w", err) } - return int(usernsFile.Fd()), func() { - usernsFile.Close() - }, nil + if err := writeMappings(gidmapFile, gidMaps); err != nil { + return nil, fmt.Errorf("failed to write gid_map: %w", err) + } + return usernsFD, nil } + +func pidfdWaitid(pidFD *os.File) error { + // https://elixir.bootlin.com/linux/v5.4.258/source/include/uapi/linux/wait.h#L20 + const PPidFD = 3 + + var e syscall.Errno + for e = syscall.EINTR; e == syscall.EINTR; { // retry while the wait is interrupted + _, _, e = syscall.Syscall6(syscall.SYS_WAITID, PPidFD, pidFD.Fd(), 0, syscall.WEXITED, 0, 0) + } + if e != 0 { + return e + } + return nil // a zero Errno wrapped in the error interface would compare != nil +} + 
+var ( + testHookLock sync.Mutex + + testHookKillChildBeforePidfdSendSignal = func(_pid uintptr, _pidFD *os.File) {} + + testHookKillChildAfterPidfdSendSignal = func(_pid uintptr, _pidFD *os.File) {} +) diff --git a/mount/mount_idmapped_linux_test.go b/mount/mount_idmapped_linux_test.go new file mode 100644 index 000000000..1b21b9492 --- /dev/null +++ b/mount/mount_idmapped_linux_test.go @@ -0,0 +1,235 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package mount + +import ( + "fmt" + "os" + "path/filepath" + "syscall" + "testing" + + "github.com/containerd/continuity/testutil" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +var ( + testUIDMaps = []syscall.SysProcIDMap{ + {ContainerID: 1000, HostID: 0, Size: 100}, + {ContainerID: 5000, HostID: 2000, Size: 100}, + {ContainerID: 10000, HostID: 3000, Size: 100}, + } + + testGIDMaps = []syscall.SysProcIDMap{ + {ContainerID: 1000, HostID: 0, Size: 100}, + {ContainerID: 5000, HostID: 2000, Size: 100}, + {ContainerID: 10000, HostID: 3000, Size: 100}, + } +) + +func TestGetUsernsFD(t *testing.T) { + testutil.RequiresRoot(t) + + t.Run("basic", testGetUsernsFDBasic) + + t.Run("when kill child process before write u[g]id maps", testGetUsernsFDKillChildWhenWriteUGIDMaps) + + t.Run("when kill child process after open u[g]id_map file", testGetUsernsFDKillChildAfterOpenUGIDMapFiles) + +} + +func testGetUsernsFDBasic(t *testing.T) { + for idx, tc := range []struct { + uidMaps string + gidMaps string + hasErr bool + }{ + { + uidMaps: "0:1000:100", + gidMaps: "0:1000:100", + hasErr: false, + }, + { + uidMaps: "100:1000:100", + gidMaps: "0:-1:100", + hasErr: true, + }, + { + uidMaps: "100:1000:100", + gidMaps: "-1:1000:100", + hasErr: true, + }, + { + uidMaps: "100:1000:100", + gidMaps: "0:1000:-1", + hasErr: true, + }, + } { + t.Run(fmt.Sprintf("#%v", idx), func(t *testing.T) { + _, err := GetUsernsFD(tc.uidMaps, tc.gidMaps) + if tc.hasErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func testGetUsernsFDKillChildWhenWriteUGIDMaps(t *testing.T) { + hookFunc := func(reap bool) func(uintptr, *os.File) { + return func(_pid uintptr, pidFD *os.File) { + err := unix.PidfdSendSignal(int(pidFD.Fd()), unix.SIGKILL, nil, 0) + require.NoError(t, err) + + if reap { + pidfdWaitid(pidFD) + } + } + } + + for _, tcReap := range []bool{true, false} { + t.Run(fmt.Sprintf("#reap=%v", tcReap), func(t *testing.T) { + 
updateTestHookKillForGetUsernsFD(t, nil, hookFunc(tcReap)) + + usernsFD, err := getUsernsFD(testUIDMaps, testGIDMaps) + require.NoError(t, err) + defer usernsFD.Close() + + srcDir, checkFunc := initIDMappedChecker(t, testUIDMaps, testGIDMaps) + destDir := t.TempDir() + defer func() { + require.NoError(t, UnmountAll(destDir, 0)) + }() + + err = IDMapMount(srcDir, destDir, int(usernsFD.Fd())) + usernsFD.Close() + require.NoError(t, err) + checkFunc(destDir) + }) + } + +} + +func testGetUsernsFDKillChildAfterOpenUGIDMapFiles(t *testing.T) { + hookFunc := func(reap bool) func(uintptr, *os.File) { + return func(_pid uintptr, pidFD *os.File) { + err := unix.PidfdSendSignal(int(pidFD.Fd()), unix.SIGKILL, nil, 0) + require.NoError(t, err) + + if reap { + pidfdWaitid(pidFD) + } + } + } + + for _, tc := range []struct { + reap bool + expected error + }{ + { + reap: false, + expected: nil, + }, + { + reap: true, + expected: syscall.ESRCH, + }, + } { + t.Run(fmt.Sprintf("#reap=%v", tc.reap), func(t *testing.T) { + updateTestHookKillForGetUsernsFD(t, hookFunc(tc.reap), nil) + + usernsFD, err := getUsernsFD(testUIDMaps, testGIDMaps) + if tc.expected != nil { + require.ErrorIs(t, err, tc.expected) // assert on the returned err, not on the expectation itself + return + } + + require.NoError(t, err) + defer usernsFD.Close() + + srcDir, checkFunc := initIDMappedChecker(t, testUIDMaps, testGIDMaps) + destDir := t.TempDir() + defer func() { + require.NoError(t, UnmountAll(destDir, 0)) + }() + + err = IDMapMount(srcDir, destDir, int(usernsFD.Fd())) + usernsFD.Close() + require.NoError(t, err) + checkFunc(destDir) + }) + } +} + +func updateTestHookKillForGetUsernsFD(t *testing.T, newBeforeFunc, newAfterFunc func(uintptr, *os.File)) { + testHookLock.Lock() + + oldBefore := testHookKillChildBeforePidfdSendSignal + oldAfter := testHookKillChildAfterPidfdSendSignal + t.Cleanup(func() { + testHookKillChildBeforePidfdSendSignal = oldBefore + testHookKillChildAfterPidfdSendSignal = oldAfter + testHookLock.Unlock() + }) + if newBeforeFunc != nil { + 
testHookKillChildBeforePidfdSendSignal = newBeforeFunc + } + if newAfterFunc != nil { + testHookKillChildAfterPidfdSendSignal = newAfterFunc + } +} + +func initIDMappedChecker(t *testing.T, uidMaps, gidMaps []syscall.SysProcIDMap) (_srcDir string, _verifyFunc func(destDir string)) { + testutil.RequiresRoot(t) + + srcDir := t.TempDir() + + require.Equal(t, len(uidMaps), len(gidMaps)) + for idx := range uidMaps { + file := filepath.Join(srcDir, fmt.Sprintf("%v", idx)) + + f, err := os.Create(file) + require.NoError(t, err, fmt.Sprintf("create file %s", file)) + defer f.Close() + + uid, gid := uidMaps[idx].ContainerID, gidMaps[idx].ContainerID + err = f.Chown(uid, gid) + require.NoError(t, err, fmt.Sprintf("chown %v:%v for file %s", uid, gid, file)) + } + + return srcDir, func(destDir string) { + for idx := range uidMaps { + file := filepath.Join(destDir, fmt.Sprintf("%v", idx)) + + f, err := os.Open(file) + require.NoError(t, err, fmt.Sprintf("open file %s", file)) + defer f.Close() + + stat, err := f.Stat() + require.NoError(t, err, fmt.Sprintf("stat file %s", file)) + + sysStat := stat.Sys().(*syscall.Stat_t) + + uid, gid := uidMaps[idx].HostID, gidMaps[idx].HostID + require.Equal(t, uint32(uid), sysStat.Uid, fmt.Sprintf("check file %s uid", file)) + require.Equal(t, uint32(gid), sysStat.Gid, fmt.Sprintf("check file %s gid", file)) + t.Logf("IDMapped File %s uid=%v, gid=%v", file, uid, gid) + } + } +} diff --git a/mount/mount_linux.go b/mount/mount_linux.go index 837d9b802..69a032b9f 100644 --- a/mount/mount_linux.go +++ b/mount/mount_linux.go @@ -93,26 +93,23 @@ func (m *Mount) mount(target string) (err error) { var ( chdir string recalcOpt bool - usernsFd int + usernsFd *os.File options = m.Options ) opt := parseMountOptions(options) // The only remapping of both GID and UID is supported if opt.uidmap != "" && opt.gidmap != "" { - var ( - childProcCleanUp func() - ) - if usernsFd, childProcCleanUp, err = GetUsernsFD(opt.uidmap, opt.gidmap); err != nil { + if 
usernsFd, err = GetUsernsFD(opt.uidmap, opt.gidmap); err != nil { return err } - defer childProcCleanUp() + defer usernsFd.Close() // overlay expects lowerdir's to be remapped instead if m.Type == "overlay" { var ( userNsCleanUp func() ) - options, userNsCleanUp, err = prepareIDMappedOverlay(usernsFd, options) + options, userNsCleanUp, err = prepareIDMappedOverlay(int(usernsFd.Fd()), options) defer userNsCleanUp() if err != nil { @@ -196,7 +193,7 @@ func (m *Mount) mount(target string) (err error) { // remap non-overlay mount point if opt.uidmap != "" && opt.gidmap != "" && m.Type != "overlay" { - if err := IDMapMount(target, target, usernsFd); err != nil { + if err := IDMapMount(target, target, int(usernsFd.Fd())); err != nil { return err } } diff --git a/snapshots/overlay/overlayutils/check.go b/snapshots/overlay/overlayutils/check.go index e94c95a88..678a7e1cb 100644 --- a/snapshots/overlay/overlayutils/check.go +++ b/snapshots/overlay/overlayutils/check.go @@ -255,13 +255,13 @@ func SupportsIDMappedMounts() (bool, error) { uidmap := fmt.Sprintf("%d:%d:%d", uidMap.ContainerID, uidMap.HostID, uidMap.Size) gidmap := fmt.Sprintf("%d:%d:%d", gidMap.ContainerID, gidMap.HostID, gidMap.Size) - usernsFd, childProcCleanUp, err := mount.GetUsernsFD(uidmap, gidmap) + usernsFd, err := mount.GetUsernsFD(uidmap, gidmap) if err != nil { return false, err } - defer childProcCleanUp() + defer usernsFd.Close() - if err = mount.IDMapMount(lowerDir, lowerDir, usernsFd); err != nil { + if err = mount.IDMapMount(lowerDir, lowerDir, int(usernsFd.Fd())); err != nil { return false, fmt.Errorf("failed to remap lowerdir %s: %w", lowerDir, err) } defer func() { diff --git a/sys/userns_unsafe_linux.go b/sys/userns_unsafe_linux.go index bedf8943c..4ebe46c32 100644 --- a/sys/userns_unsafe_linux.go +++ b/sys/userns_unsafe_linux.go @@ -22,44 +22,73 @@ import ( "unsafe" ) -// ProcSyncType is used for synchronization -// between parent and child processes. 
-type ProcSyncType uint8 - -const ( - // ProcSyncExit tells child "it's time to exit". - ProcSyncExit ProcSyncType = 0x1 -) - +// ForkUserns forks a child process with a new user namespace. It returns the +// child's pid and a pidfd that refers to the child process. +// +// Precondition: The caller's OS thread must be locked, which is a Go runtime +// requirement. +// +// Besides this, the child process sets PR_SET_PDEATHSIG with SIGKILL, so +// the parent process's OS thread must be locked. Otherwise, the exit event of +// the parent's OS thread will send the kill signal to the child process, +// even if the parent process is still running. +// //go:norace //go:noinline -func ForkUserns(pipeMap [2]int) (pid uintptr, errno syscall.Errno) { - var sync ProcSyncType +func ForkUserns() (_pid uintptr, _pidfd uintptr, _ syscall.Errno) { + var ( + pidfd uintptr + pid, ppid uintptr + err syscall.Errno + ) + + ppid, _, err = syscall.RawSyscall(uintptr(syscall.SYS_GETPID), 0, 0, 0) + if err != 0 { + return 0, 0, err + } beforeFork() if runtime.GOARCH == "s390x" { - pid, _, errno = syscall.RawSyscall6(uintptr(syscall.SYS_CLONE), 0, syscall.CLONE_NEWUSER|uintptr(syscall.SIGCHLD), 0, 0, 0, 0) + // NOTE: + // + // On the s390 architectures, the order of the first two + // arguments is reversed. 
+ // + // REF: https://man7.org/linux/man-pages/man2/clone.2.html + pid, _, err = syscall.RawSyscall(syscall.SYS_CLONE, + 0, + uintptr(syscall.CLONE_NEWUSER|syscall.SIGCHLD|syscall.CLONE_PIDFD), + uintptr(unsafe.Pointer(&pidfd)), + ) } else { - pid, _, errno = syscall.RawSyscall6(uintptr(syscall.SYS_CLONE), syscall.CLONE_NEWUSER|uintptr(syscall.SIGCHLD), 0, 0, 0, 0, 0) + pid, _, err = syscall.RawSyscall(syscall.SYS_CLONE, + uintptr(syscall.CLONE_NEWUSER|syscall.SIGCHLD|syscall.CLONE_PIDFD), + 0, + uintptr(unsafe.Pointer(&pidfd)), + ) } - if errno != 0 || pid != 0 { + if err != 0 || pid != 0 { afterFork() - return pid, errno + return pid, pidfd, err } - afterForkInChild() - if _, _, errno = syscall.RawSyscall(syscall.SYS_CLOSE, uintptr(pipeMap[1]), 0, 0); errno != 0 { - goto err - } - if _, _, errno = syscall.RawSyscall6(syscall.SYS_PRCTL, syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGKILL), 0, 0, 0, 0); errno != 0 { - goto err - } - // wait for parent's signal - if _, _, errno = syscall.RawSyscall6(syscall.SYS_READ, uintptr(pipeMap[0]), uintptr(unsafe.Pointer(&sync)), unsafe.Sizeof(sync), 0, 0, 0); errno != 0 || sync != ProcSyncExit { + + if _, _, err = syscall.RawSyscall(syscall.SYS_PRCTL, syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGKILL), 0); err != 0 { goto err } + pid, _, err = syscall.RawSyscall(syscall.SYS_GETPPID, 0, 0, 0) + if err != 0 { + goto err + } + + // exit if re-parent + if pid != ppid { + goto err + } + + _, _, err = syscall.RawSyscall(syscall.SYS_PPOLL, 0, 0, 0) err: - syscall.RawSyscall6(syscall.SYS_EXIT, uintptr(errno), 0, 0, 0, 0, 0) + syscall.RawSyscall(syscall.SYS_EXIT, uintptr(err), 0, 0) panic("unreachable") }