refactor: reduce duplicate code

Signed-off-by: Ye Sijun <junnplus@gmail.com>
This commit is contained in:
Ye Sijun 2022-06-24 15:45:37 +08:00
parent 279f4069fe
commit 1ab42be15d
No known key found for this signature in database
GPG Key ID: 0582626C83FA9CD0
2 changed files with 136 additions and 36 deletions

View File

@ -701,11 +701,8 @@ func WithUIDGID(uid, gid uint32) SpecOpts {
func WithUserID(uid uint32) SpecOpts { func WithUserID(uid uint32) SpecOpts {
return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) {
setProcess(s) setProcess(s)
if c.Snapshotter == "" && c.SnapshotKey == "" { setUser := func(root string) error {
if !isRootfsAbs(s.Root.Path) { user, err := UserFromPath(root, func(u user.User) bool {
return errors.New("rootfs absolute path is required")
}
user, err := UserFromPath(s.Root.Path, func(u user.User) bool {
return u.Uid == int(uid) return u.Uid == int(uid)
}) })
if err != nil { if err != nil {
@ -717,7 +714,12 @@ func WithUserID(uid uint32) SpecOpts {
} }
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
return nil return nil
}
if c.Snapshotter == "" && c.SnapshotKey == "" {
if !isRootfsAbs(s.Root.Path) {
return errors.New("rootfs absolute path is required")
}
return setUser(s.Root.Path)
} }
if c.Snapshotter == "" { if c.Snapshotter == "" {
return errors.New("no snapshotter set for container") return errors.New("no snapshotter set for container")
@ -732,20 +734,7 @@ func WithUserID(uid uint32) SpecOpts {
} }
mounts = tryReadonlyMounts(mounts) mounts = tryReadonlyMounts(mounts)
return mount.WithTempMount(ctx, mounts, func(root string) error { return mount.WithTempMount(ctx, mounts, setUser)
user, err := UserFromPath(root, func(u user.User) bool {
return u.Uid == int(uid)
})
if err != nil {
if os.IsNotExist(err) || err == ErrNoUsersFound {
s.Process.User.UID, s.Process.User.GID = uid, 0
return nil
}
return err
}
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
return nil
})
} }
} }
@ -759,11 +748,8 @@ func WithUsername(username string) SpecOpts {
return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) {
setProcess(s) setProcess(s)
if s.Linux != nil { if s.Linux != nil {
if c.Snapshotter == "" && c.SnapshotKey == "" { setUser := func(root string) error {
if !isRootfsAbs(s.Root.Path) { user, err := UserFromPath(root, func(u user.User) bool {
return errors.New("rootfs absolute path is required")
}
user, err := UserFromPath(s.Root.Path, func(u user.User) bool {
return u.Name == username return u.Name == username
}) })
if err != nil { if err != nil {
@ -772,6 +758,12 @@ func WithUsername(username string) SpecOpts {
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
return nil return nil
} }
if c.Snapshotter == "" && c.SnapshotKey == "" {
if !isRootfsAbs(s.Root.Path) {
return errors.New("rootfs absolute path is required")
}
return setUser(s.Root.Path)
}
if c.Snapshotter == "" { if c.Snapshotter == "" {
return errors.New("no snapshotter set for container") return errors.New("no snapshotter set for container")
} }
@ -785,16 +777,7 @@ func WithUsername(username string) SpecOpts {
} }
mounts = tryReadonlyMounts(mounts) mounts = tryReadonlyMounts(mounts)
return mount.WithTempMount(ctx, mounts, func(root string) error { return mount.WithTempMount(ctx, mounts, setUser)
user, err := UserFromPath(root, func(u user.User) bool {
return u.Name == username
})
if err != nil {
return err
}
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
return nil
})
} else if s.Windows != nil { } else if s.Windows != nil {
s.Process.User.Username = username s.Process.User.Username = username
} else { } else {

View File

@ -18,6 +18,7 @@ package oci
import ( import (
"context" "context"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -30,6 +31,123 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// nolint:gosec
func TestWithUserID(t *testing.T) {
t.Parallel()
expectedPasswd := `root:x:0:0:root:/root:/bin/ash
guest:x:405:100:guest:/dev/null:/sbin/nologin
`
td := t.TempDir()
apply := fstest.Apply(
fstest.CreateDir("/etc", 0777),
fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777),
)
if err := apply.Apply(td); err != nil {
t.Fatalf("failed to apply: %v", err)
}
c := containers.Container{ID: t.Name()}
testCases := []struct {
userID uint32
expectedUID uint32
expectedGID uint32
}{
{
userID: 0,
expectedUID: 0,
expectedGID: 0,
},
{
userID: 405,
expectedUID: 405,
expectedGID: 100,
},
{
userID: 1000,
expectedUID: 1000,
expectedGID: 0,
},
}
for _, testCase := range testCases {
t.Run(fmt.Sprintf("user %d", testCase.userID), func(t *testing.T) {
t.Parallel()
s := Spec{
Version: specs.Version,
Root: &specs.Root{
Path: td,
},
Linux: &specs.Linux{},
}
err := WithUserID(testCase.userID)(context.Background(), nil, &c, &s)
assert.NoError(t, err)
assert.Equal(t, testCase.expectedUID, s.Process.User.UID)
assert.Equal(t, testCase.expectedGID, s.Process.User.GID)
})
}
}
// nolint:gosec
func TestWithUsername(t *testing.T) {
t.Parallel()
expectedPasswd := `root:x:0:0:root:/root:/bin/ash
guest:x:405:100:guest:/dev/null:/sbin/nologin
`
td := t.TempDir()
apply := fstest.Apply(
fstest.CreateDir("/etc", 0777),
fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777),
)
if err := apply.Apply(td); err != nil {
t.Fatalf("failed to apply: %v", err)
}
c := containers.Container{ID: t.Name()}
testCases := []struct {
user string
expectedUID uint32
expectedGID uint32
err string
}{
{
user: "root",
expectedUID: 0,
expectedGID: 0,
},
{
user: "guest",
expectedUID: 405,
expectedGID: 100,
},
{
user: "1000",
err: "no users found",
},
{
user: "unknown",
err: "no users found",
},
}
for _, testCase := range testCases {
t.Run(testCase.user, func(t *testing.T) {
t.Parallel()
s := Spec{
Version: specs.Version,
Root: &specs.Root{
Path: td,
},
Linux: &specs.Linux{},
}
err := WithUsername(testCase.user)(context.Background(), nil, &c, &s)
if err != nil {
assert.EqualError(t, err, testCase.err)
}
assert.Equal(t, testCase.expectedUID, s.Process.User.UID)
assert.Equal(t, testCase.expectedGID, s.Process.User.GID)
})
}
}
// nolint:gosec // nolint:gosec
func TestWithAdditionalGIDs(t *testing.T) { func TestWithAdditionalGIDs(t *testing.T) {
t.Parallel() t.Parallel()
@ -54,7 +172,6 @@ sys:x:3:root,bin,adm
c := containers.Container{ID: t.Name()} c := containers.Container{ID: t.Name()}
testCases := []struct { testCases := []struct {
name string
user string user string
expected []uint32 expected []uint32
}{ }{