refactor: reduce duplicate code
Signed-off-by: Ye Sijun <junnplus@gmail.com>
This commit is contained in:
		| @@ -701,11 +701,8 @@ func WithUIDGID(uid, gid uint32) SpecOpts { | |||||||
| func WithUserID(uid uint32) SpecOpts { | func WithUserID(uid uint32) SpecOpts { | ||||||
| 	return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { | 	return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { | ||||||
| 		setProcess(s) | 		setProcess(s) | ||||||
| 		if c.Snapshotter == "" && c.SnapshotKey == "" { | 		setUser := func(root string) error { | ||||||
| 			if !isRootfsAbs(s.Root.Path) { | 			user, err := UserFromPath(root, func(u user.User) bool { | ||||||
| 				return errors.New("rootfs absolute path is required") |  | ||||||
| 			} |  | ||||||
| 			user, err := UserFromPath(s.Root.Path, func(u user.User) bool { |  | ||||||
| 				return u.Uid == int(uid) | 				return u.Uid == int(uid) | ||||||
| 			}) | 			}) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -717,7 +714,12 @@ func WithUserID(uid uint32) SpecOpts { | |||||||
| 			} | 			} | ||||||
| 			s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) | 			s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) | ||||||
| 			return nil | 			return nil | ||||||
|  | 		} | ||||||
|  | 		if c.Snapshotter == "" && c.SnapshotKey == "" { | ||||||
|  | 			if !isRootfsAbs(s.Root.Path) { | ||||||
|  | 				return errors.New("rootfs absolute path is required") | ||||||
|  | 			} | ||||||
|  | 			return setUser(s.Root.Path) | ||||||
| 		} | 		} | ||||||
| 		if c.Snapshotter == "" { | 		if c.Snapshotter == "" { | ||||||
| 			return errors.New("no snapshotter set for container") | 			return errors.New("no snapshotter set for container") | ||||||
| @@ -732,20 +734,7 @@ func WithUserID(uid uint32) SpecOpts { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		mounts = tryReadonlyMounts(mounts) | 		mounts = tryReadonlyMounts(mounts) | ||||||
| 		return mount.WithTempMount(ctx, mounts, func(root string) error { | 		return mount.WithTempMount(ctx, mounts, setUser) | ||||||
| 			user, err := UserFromPath(root, func(u user.User) bool { |  | ||||||
| 				return u.Uid == int(uid) |  | ||||||
| 			}) |  | ||||||
| 			if err != nil { |  | ||||||
| 				if os.IsNotExist(err) || err == ErrNoUsersFound { |  | ||||||
| 					s.Process.User.UID, s.Process.User.GID = uid, 0 |  | ||||||
| 					return nil |  | ||||||
| 				} |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) |  | ||||||
| 			return nil |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -759,11 +748,8 @@ func WithUsername(username string) SpecOpts { | |||||||
| 	return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { | 	return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { | ||||||
| 		setProcess(s) | 		setProcess(s) | ||||||
| 		if s.Linux != nil { | 		if s.Linux != nil { | ||||||
| 			if c.Snapshotter == "" && c.SnapshotKey == "" { | 			setUser := func(root string) error { | ||||||
| 				if !isRootfsAbs(s.Root.Path) { | 				user, err := UserFromPath(root, func(u user.User) bool { | ||||||
| 					return errors.New("rootfs absolute path is required") |  | ||||||
| 				} |  | ||||||
| 				user, err := UserFromPath(s.Root.Path, func(u user.User) bool { |  | ||||||
| 					return u.Name == username | 					return u.Name == username | ||||||
| 				}) | 				}) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| @@ -772,6 +758,12 @@ func WithUsername(username string) SpecOpts { | |||||||
| 				s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) | 				s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) | ||||||
| 				return nil | 				return nil | ||||||
| 			} | 			} | ||||||
|  | 			if c.Snapshotter == "" && c.SnapshotKey == "" { | ||||||
|  | 				if !isRootfsAbs(s.Root.Path) { | ||||||
|  | 					return errors.New("rootfs absolute path is required") | ||||||
|  | 				} | ||||||
|  | 				return setUser(s.Root.Path) | ||||||
|  | 			} | ||||||
| 			if c.Snapshotter == "" { | 			if c.Snapshotter == "" { | ||||||
| 				return errors.New("no snapshotter set for container") | 				return errors.New("no snapshotter set for container") | ||||||
| 			} | 			} | ||||||
| @@ -785,16 +777,7 @@ func WithUsername(username string) SpecOpts { | |||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			mounts = tryReadonlyMounts(mounts) | 			mounts = tryReadonlyMounts(mounts) | ||||||
| 			return mount.WithTempMount(ctx, mounts, func(root string) error { | 			return mount.WithTempMount(ctx, mounts, setUser) | ||||||
| 				user, err := UserFromPath(root, func(u user.User) bool { |  | ||||||
| 					return u.Name == username |  | ||||||
| 				}) |  | ||||||
| 				if err != nil { |  | ||||||
| 					return err |  | ||||||
| 				} |  | ||||||
| 				s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) |  | ||||||
| 				return nil |  | ||||||
| 			}) |  | ||||||
| 		} else if s.Windows != nil { | 		} else if s.Windows != nil { | ||||||
| 			s.Process.User.Username = username | 			s.Process.User.Username = username | ||||||
| 		} else { | 		} else { | ||||||
|   | |||||||
| @@ -18,6 +18,7 @@ package oci | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 	"testing" | 	"testing" | ||||||
| @@ -30,6 +31,123 @@ import ( | |||||||
| 	"golang.org/x/sys/unix" | 	"golang.org/x/sys/unix" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // nolint:gosec | ||||||
|  | func TestWithUserID(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  |  | ||||||
|  | 	expectedPasswd := `root:x:0:0:root:/root:/bin/ash | ||||||
|  | guest:x:405:100:guest:/dev/null:/sbin/nologin | ||||||
|  | ` | ||||||
|  | 	td := t.TempDir() | ||||||
|  | 	apply := fstest.Apply( | ||||||
|  | 		fstest.CreateDir("/etc", 0777), | ||||||
|  | 		fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777), | ||||||
|  | 	) | ||||||
|  | 	if err := apply.Apply(td); err != nil { | ||||||
|  | 		t.Fatalf("failed to apply: %v", err) | ||||||
|  | 	} | ||||||
|  | 	c := containers.Container{ID: t.Name()} | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		userID      uint32 | ||||||
|  | 		expectedUID uint32 | ||||||
|  | 		expectedGID uint32 | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			userID:      0, | ||||||
|  | 			expectedUID: 0, | ||||||
|  | 			expectedGID: 0, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			userID:      405, | ||||||
|  | 			expectedUID: 405, | ||||||
|  | 			expectedGID: 100, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			userID:      1000, | ||||||
|  | 			expectedUID: 1000, | ||||||
|  | 			expectedGID: 0, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, testCase := range testCases { | ||||||
|  | 		t.Run(fmt.Sprintf("user %d", testCase.userID), func(t *testing.T) { | ||||||
|  | 			t.Parallel() | ||||||
|  | 			s := Spec{ | ||||||
|  | 				Version: specs.Version, | ||||||
|  | 				Root: &specs.Root{ | ||||||
|  | 					Path: td, | ||||||
|  | 				}, | ||||||
|  | 				Linux: &specs.Linux{}, | ||||||
|  | 			} | ||||||
|  | 			err := WithUserID(testCase.userID)(context.Background(), nil, &c, &s) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 			assert.Equal(t, testCase.expectedUID, s.Process.User.UID) | ||||||
|  | 			assert.Equal(t, testCase.expectedGID, s.Process.User.GID) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // nolint:gosec | ||||||
|  | func TestWithUsername(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  |  | ||||||
|  | 	expectedPasswd := `root:x:0:0:root:/root:/bin/ash | ||||||
|  | guest:x:405:100:guest:/dev/null:/sbin/nologin | ||||||
|  | ` | ||||||
|  | 	td := t.TempDir() | ||||||
|  | 	apply := fstest.Apply( | ||||||
|  | 		fstest.CreateDir("/etc", 0777), | ||||||
|  | 		fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777), | ||||||
|  | 	) | ||||||
|  | 	if err := apply.Apply(td); err != nil { | ||||||
|  | 		t.Fatalf("failed to apply: %v", err) | ||||||
|  | 	} | ||||||
|  | 	c := containers.Container{ID: t.Name()} | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		user        string | ||||||
|  | 		expectedUID uint32 | ||||||
|  | 		expectedGID uint32 | ||||||
|  | 		err         string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			user:        "root", | ||||||
|  | 			expectedUID: 0, | ||||||
|  | 			expectedGID: 0, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			user:        "guest", | ||||||
|  | 			expectedUID: 405, | ||||||
|  | 			expectedGID: 100, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			user: "1000", | ||||||
|  | 			err:  "no users found", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			user: "unknown", | ||||||
|  | 			err:  "no users found", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, testCase := range testCases { | ||||||
|  | 		t.Run(testCase.user, func(t *testing.T) { | ||||||
|  | 			t.Parallel() | ||||||
|  | 			s := Spec{ | ||||||
|  | 				Version: specs.Version, | ||||||
|  | 				Root: &specs.Root{ | ||||||
|  | 					Path: td, | ||||||
|  | 				}, | ||||||
|  | 				Linux: &specs.Linux{}, | ||||||
|  | 			} | ||||||
|  | 			err := WithUsername(testCase.user)(context.Background(), nil, &c, &s) | ||||||
|  | 			if err != nil { | ||||||
|  | 				assert.EqualError(t, err, testCase.err) | ||||||
|  | 			} | ||||||
|  | 			assert.Equal(t, testCase.expectedUID, s.Process.User.UID) | ||||||
|  | 			assert.Equal(t, testCase.expectedGID, s.Process.User.GID) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
| // nolint:gosec | // nolint:gosec | ||||||
| func TestWithAdditionalGIDs(t *testing.T) { | func TestWithAdditionalGIDs(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
| @@ -54,7 +172,6 @@ sys:x:3:root,bin,adm | |||||||
| 	c := containers.Container{ID: t.Name()} | 	c := containers.Container{ID: t.Name()} | ||||||
|  |  | ||||||
| 	testCases := []struct { | 	testCases := []struct { | ||||||
| 		name     string |  | ||||||
| 		user     string | 		user     string | ||||||
| 		expected []uint32 | 		expected []uint32 | ||||||
| 	}{ | 	}{ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Ye Sijun
					Ye Sijun