diff --git a/oci/spec_opts.go b/oci/spec_opts.go index dc62ca775..355fd45b5 100644 --- a/oci/spec_opts.go +++ b/oci/spec_opts.go @@ -701,11 +701,8 @@ func WithUIDGID(uid, gid uint32) SpecOpts { func WithUserID(uid uint32) SpecOpts { return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { setProcess(s) - if c.Snapshotter == "" && c.SnapshotKey == "" { - if !isRootfsAbs(s.Root.Path) { - return errors.New("rootfs absolute path is required") - } - user, err := UserFromPath(s.Root.Path, func(u user.User) bool { + setUser := func(root string) error { + user, err := UserFromPath(root, func(u user.User) bool { return u.Uid == int(uid) }) if err != nil { @@ -717,7 +714,12 @@ func WithUserID(uid uint32) SpecOpts { } s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) return nil - + } + if c.Snapshotter == "" && c.SnapshotKey == "" { + if !isRootfsAbs(s.Root.Path) { + return errors.New("rootfs absolute path is required") + } + return setUser(s.Root.Path) } if c.Snapshotter == "" { return errors.New("no snapshotter set for container") @@ -732,20 +734,7 @@ func WithUserID(uid uint32) SpecOpts { } mounts = tryReadonlyMounts(mounts) - return mount.WithTempMount(ctx, mounts, func(root string) error { - user, err := UserFromPath(root, func(u user.User) bool { - return u.Uid == int(uid) - }) - if err != nil { - if os.IsNotExist(err) || err == ErrNoUsersFound { - s.Process.User.UID, s.Process.User.GID = uid, 0 - return nil - } - return err - } - s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) - return nil - }) + return mount.WithTempMount(ctx, mounts, setUser) } } @@ -759,11 +748,8 @@ func WithUsername(username string) SpecOpts { return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { setProcess(s) if s.Linux != nil { - if c.Snapshotter == "" && c.SnapshotKey == "" { - if !isRootfsAbs(s.Root.Path) { - return errors.New("rootfs absolute path is required") - } - user, err := UserFromPath(s.Root.Path, func(u user.User) bool { + setUser := func(root string) error { + user, err := UserFromPath(root, func(u user.User) bool { return u.Name == username }) if err != nil { @@ -772,6 +758,12 @@ func WithUsername(username string) SpecOpts { s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) return nil } + if c.Snapshotter == "" && c.SnapshotKey == "" { + if !isRootfsAbs(s.Root.Path) { + return errors.New("rootfs absolute path is required") + } + return setUser(s.Root.Path) + } if c.Snapshotter == "" { return errors.New("no snapshotter set for container") } @@ -785,16 +777,7 @@ func WithUsername(username string) SpecOpts { } mounts = tryReadonlyMounts(mounts) - return mount.WithTempMount(ctx, mounts, func(root string) error { - user, err := UserFromPath(root, func(u user.User) bool { - return u.Name == username - }) - if err != nil { - return err - } - s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) - return nil - }) + return mount.WithTempMount(ctx, mounts, setUser) } else if s.Windows != nil { s.Process.User.Username = username } else { diff --git a/oci/spec_opts_linux_test.go b/oci/spec_opts_linux_test.go index c027c3517..08cbd71ef 100644 --- a/oci/spec_opts_linux_test.go +++ b/oci/spec_opts_linux_test.go @@ -18,6 +18,7 @@ package oci import ( "context" + "fmt" "os" "path/filepath" "testing" @@ -30,6 +31,123 @@ import ( "golang.org/x/sys/unix" ) +// nolint:gosec +func TestWithUserID(t *testing.T) { + t.Parallel() + + expectedPasswd := `root:x:0:0:root:/root:/bin/ash +guest:x:405:100:guest:/dev/null:/sbin/nologin +` + td := t.TempDir() + apply := fstest.Apply( + fstest.CreateDir("/etc", 0777), + fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777), + ) + if err := apply.Apply(td); err != nil { + t.Fatalf("failed to apply: %v", err) + } + c := containers.Container{ID: t.Name()} + testCases := []struct { + userID uint32 + expectedUID uint32 + expectedGID uint32 + }{ + { + userID: 0, + expectedUID: 0, + expectedGID: 0, + }, + { + userID: 405, + expectedUID: 405, + expectedGID: 100, + }, + { + userID: 1000, + expectedUID: 1000, + expectedGID: 0, + }, + } + for _, testCase := range testCases { + t.Run(fmt.Sprintf("user %d", testCase.userID), func(t *testing.T) { + t.Parallel() + s := Spec{ + Version: specs.Version, + Root: &specs.Root{ + Path: td, + }, + Linux: &specs.Linux{}, + } + err := WithUserID(testCase.userID)(context.Background(), nil, &c, &s) + assert.NoError(t, err) + assert.Equal(t, testCase.expectedUID, s.Process.User.UID) + assert.Equal(t, testCase.expectedGID, s.Process.User.GID) + }) + } +} + +// nolint:gosec +func TestWithUsername(t *testing.T) { + t.Parallel() + + expectedPasswd := `root:x:0:0:root:/root:/bin/ash +guest:x:405:100:guest:/dev/null:/sbin/nologin +` + td := t.TempDir() + apply := fstest.Apply( + fstest.CreateDir("/etc", 0777), + fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777), + ) + if err := apply.Apply(td); err != nil { + t.Fatalf("failed to apply: %v", err) + } + c := containers.Container{ID: t.Name()} + testCases := []struct { + user string + expectedUID uint32 + expectedGID uint32 + err string + }{ + { + user: "root", + expectedUID: 0, + expectedGID: 0, + }, + { + user: "guest", + expectedUID: 405, + expectedGID: 100, + }, + { + user: "1000", + err: "no users found", + }, + { + user: "unknown", + err: "no users found", + }, + } + for _, testCase := range testCases { + t.Run(testCase.user, func(t *testing.T) { + t.Parallel() + s := Spec{ + Version: specs.Version, + Root: &specs.Root{ + Path: td, + }, + Linux: &specs.Linux{}, + } + err := WithUsername(testCase.user)(context.Background(), nil, &c, &s) + if err != nil { + assert.EqualError(t, err, testCase.err) + } + assert.Equal(t, testCase.expectedUID, s.Process.User.UID) + assert.Equal(t, testCase.expectedGID, s.Process.User.GID) + }) + } + +} + // nolint:gosec func TestWithAdditionalGIDs(t *testing.T) { t.Parallel() @@ -54,7 +172,6 @@ sys:x:3:root,bin,adm c := containers.Container{ID: t.Name()} testCases := []struct { - name string user string expected []uint32 }{