Merge pull request #1208 from dmcgowan/tar-test
archive: add link breakout checks and tests
This commit is contained in:
commit
0600753bd8
123
archive/path.go
Normal file
123
archive/path.go
Normal file
@ -0,0 +1,123 @@
|
||||
package archive
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
errTooManyLinks = errors.New("too many links")
|
||||
)
|
||||
|
||||
// rootPath joins a path with a root, evaluating and bounding any
|
||||
// symlink to the root directory.
|
||||
// TODO(dmcgowan): Expose and move to fs package or continuity path driver
|
||||
func rootPath(root, path string) (string, error) {
|
||||
if path == "" {
|
||||
return root, nil
|
||||
}
|
||||
var linksWalked int // to protect against cycles
|
||||
for {
|
||||
i := linksWalked
|
||||
newpath, err := walkLinks(root, path, &linksWalked)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
path = newpath
|
||||
if i == linksWalked {
|
||||
newpath = rootJoin(newpath)
|
||||
if path == newpath {
|
||||
return filepath.Join(root, newpath), nil
|
||||
}
|
||||
path = newpath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rootJoin joins a path with root, cleaning up any links that
|
||||
// reference above root.
|
||||
func rootJoin(path string) string {
|
||||
if filepath.IsAbs(path) {
|
||||
path = filepath.Clean(path)
|
||||
}
|
||||
|
||||
// Resolve any ".." or "/.." before joining to root
|
||||
for !filepath.IsAbs(path) {
|
||||
path = "/" + path
|
||||
path = filepath.Clean(path)
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
func walkLink(root, path string, linksWalked *int) (newpath string, islink bool, err error) {
|
||||
if *linksWalked > 255 {
|
||||
return "", false, errTooManyLinks
|
||||
}
|
||||
|
||||
path = rootJoin(path)
|
||||
if path == "/" {
|
||||
return path, false, nil
|
||||
}
|
||||
realPath := filepath.Join(root, path)
|
||||
|
||||
fi, err := os.Lstat(realPath)
|
||||
if err != nil {
|
||||
// If path does not yet exist, treat as non-symlink
|
||||
if os.IsNotExist(err) {
|
||||
return path, false, nil
|
||||
}
|
||||
return "", false, err
|
||||
}
|
||||
if fi.Mode()&os.ModeSymlink == 0 {
|
||||
return path, false, nil
|
||||
}
|
||||
newpath, err = os.Readlink(realPath)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
if filepath.IsAbs(newpath) && strings.HasPrefix(newpath, root) {
|
||||
newpath = newpath[:len(root)]
|
||||
if !strings.HasPrefix(newpath, "/") {
|
||||
newpath = "/" + newpath
|
||||
}
|
||||
}
|
||||
*linksWalked++
|
||||
return newpath, true, nil
|
||||
}
|
||||
|
||||
func walkLinks(root, path string, linksWalked *int) (string, error) {
|
||||
switch dir, file := filepath.Split(path); {
|
||||
case dir == "":
|
||||
newpath, _, err := walkLink(root, file, linksWalked)
|
||||
return newpath, err
|
||||
case file == "":
|
||||
if os.IsPathSeparator(dir[len(dir)-1]) {
|
||||
if dir == "/" {
|
||||
return dir, nil
|
||||
}
|
||||
return walkLinks(root, dir[:len(dir)-1], linksWalked)
|
||||
}
|
||||
newpath, _, err := walkLink(root, dir, linksWalked)
|
||||
return newpath, err
|
||||
default:
|
||||
newdir, err := walkLinks(root, dir, linksWalked)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newpath, islink, err := walkLink(root, filepath.Join(newdir, file), linksWalked)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !islink {
|
||||
return newpath, nil
|
||||
}
|
||||
if filepath.IsAbs(newpath) {
|
||||
return newpath, nil
|
||||
}
|
||||
return filepath.Join(newdir, newpath), nil
|
||||
}
|
||||
}
|
293
archive/path_test.go
Normal file
293
archive/path_test.go
Normal file
@ -0,0 +1,293 @@
|
||||
package archive
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/containerd/containerd/fs/fstest"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type rootCheck struct {
|
||||
unresolved string
|
||||
expected string
|
||||
scope func(string) string
|
||||
cause error
|
||||
}
|
||||
|
||||
func TestRootPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apply fstest.Applier
|
||||
checks []rootCheck
|
||||
scope func(string) (string, error)
|
||||
}{
|
||||
{
|
||||
name: "SymlinkAbsolute",
|
||||
apply: Symlink("/b", "fs/a/d"),
|
||||
checks: Check("fs/a/d/c/data", "b/c/data"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkRelativePath",
|
||||
apply: Symlink("a", "fs/i"),
|
||||
checks: Check("fs/i", "fs/a"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkSkipSymlinksOutsideScope",
|
||||
apply: Symlink("realdir", "linkdir"),
|
||||
checks: CheckWithScope("foo/bar", "foo/bar", "linkdir"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkLastLink",
|
||||
apply: Symlink("/b", "fs/a/d"),
|
||||
checks: Check("fs/a/d", "b"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkRelativeLinkChangeScope",
|
||||
apply: Symlink("../b", "fs/a/e"),
|
||||
checks: CheckAll(
|
||||
Check("fs/a/e/c/data", "fs/b/c/data"),
|
||||
CheckWithScope("e", "b", "fs/a"), // Original return
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "SymlinkDeepRelativeLinkChangeScope",
|
||||
apply: Symlink("../../../../test", "fs/a/f"),
|
||||
checks: CheckAll(
|
||||
Check("fs/a/f", "test"), // Original return
|
||||
CheckWithScope("a/f", "test", "fs"), // Original return
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "SymlinkRelativeLinkChain",
|
||||
apply: fstest.Apply(
|
||||
Symlink("../g", "fs/b/h"),
|
||||
fstest.Symlink("../../../../../../../../../../../../root", "fs/g"),
|
||||
),
|
||||
checks: Check("fs/b/h", "root"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkBreakoutPath",
|
||||
apply: Symlink("../i/a", "fs/j/k"),
|
||||
checks: CheckWithScope("k", "i/a", "fs/j"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkToRoot",
|
||||
apply: Symlink("/", "foo"),
|
||||
checks: Check("foo", ""),
|
||||
},
|
||||
{
|
||||
name: "SymlinkSlashDotdot",
|
||||
apply: Symlink("/../../", "foo"),
|
||||
checks: Check("foo", ""),
|
||||
},
|
||||
{
|
||||
name: "SymlinkDotdot",
|
||||
apply: Symlink("../../", "foo"),
|
||||
checks: Check("foo", ""),
|
||||
},
|
||||
{
|
||||
name: "SymlinkRelativePath2",
|
||||
apply: Symlink("baz/target", "bar/foo"),
|
||||
checks: Check("bar/foo", "bar/baz/target"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkScopeLink",
|
||||
apply: fstest.Apply(
|
||||
Symlink("root2", "root"),
|
||||
Symlink("../bar", "root2/foo"),
|
||||
),
|
||||
checks: CheckWithScope("foo", "bar", "root"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkSelf",
|
||||
apply: fstest.Apply(
|
||||
Symlink("foo", "root/foo"),
|
||||
),
|
||||
checks: ErrorWithScope("foo", "root", errTooManyLinks),
|
||||
},
|
||||
{
|
||||
name: "SymlinkCircular",
|
||||
apply: fstest.Apply(
|
||||
Symlink("foo", "bar"),
|
||||
Symlink("bar", "foo"),
|
||||
),
|
||||
checks: ErrorWithScope("foo", "", errTooManyLinks), //TODO: Test for circular error
|
||||
},
|
||||
{
|
||||
name: "SymlinkCircularUnderRoot",
|
||||
apply: fstest.Apply(
|
||||
Symlink("baz", "root/bar"),
|
||||
Symlink("../bak", "root/baz"),
|
||||
Symlink("/bar", "root/bak"),
|
||||
),
|
||||
checks: ErrorWithScope("bar", "root", errTooManyLinks), // TODO: Test for circular error
|
||||
},
|
||||
{
|
||||
name: "SymlinkComplexChain",
|
||||
apply: fstest.Apply(
|
||||
fstest.CreateDir("root2", 0777),
|
||||
Symlink("root2", "root"),
|
||||
Symlink("r/s", "root/a"),
|
||||
Symlink("../root/t", "root/r"),
|
||||
Symlink("/../u", "root/root/t/s/b"),
|
||||
Symlink(".", "root/u/c"),
|
||||
Symlink("../v", "root/u/x/y"),
|
||||
Symlink("/../w", "root/u/v"),
|
||||
),
|
||||
checks: CheckWithScope("a/b/c/x/y/z", "w/z", "root"), // Original return
|
||||
},
|
||||
{
|
||||
name: "SymlinkBreakoutNonExistent",
|
||||
apply: fstest.Apply(
|
||||
Symlink("/", "root/slash"),
|
||||
Symlink("/idontexist/../slash", "root/sym"),
|
||||
),
|
||||
checks: CheckWithScope("sym/file", "file", "root"),
|
||||
},
|
||||
{
|
||||
name: "SymlinkNoLexicalCleaning",
|
||||
apply: fstest.Apply(
|
||||
Symlink("/foo/bar", "root/sym"),
|
||||
Symlink("/sym/../baz", "root/hello"),
|
||||
),
|
||||
checks: CheckWithScope("hello", "foo/baz", "root"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, makeRootPathTest(t, test.apply, test.checks))
|
||||
}
|
||||
|
||||
// Add related tests which are unable to follow same pattern
|
||||
t.Run("SymlinkRootScope", testRootPathSymlinkRootScope)
|
||||
t.Run("SymlinkEmpty", testRootPathSymlinkEmpty)
|
||||
}
|
||||
|
||||
func testRootPathSymlinkRootScope(t *testing.T) {
|
||||
tmpdir, err := ioutil.TempDir("", "TestFollowSymlinkRootScope")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
expected, err := filepath.EvalSymlinks(tmpdir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rewrite, err := rootPath("/", tmpdir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rewrite != expected {
|
||||
t.Fatalf("expected %q got %q", expected, rewrite)
|
||||
}
|
||||
}
|
||||
func testRootPathSymlinkEmpty(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res, err := rootPath(wd, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if res != wd {
|
||||
t.Fatalf("expected %q got %q", wd, res)
|
||||
}
|
||||
}
|
||||
|
||||
func makeRootPathTest(t *testing.T, apply fstest.Applier, checks []rootCheck) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
applyDir, err := ioutil.TempDir("", "test-root-path-")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to make temp directory: %+v", err)
|
||||
}
|
||||
defer os.RemoveAll(applyDir)
|
||||
|
||||
if apply != nil {
|
||||
if err := apply.Apply(applyDir); err != nil {
|
||||
t.Fatalf("Apply failed: %+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for i, check := range checks {
|
||||
root := applyDir
|
||||
if check.scope != nil {
|
||||
root = check.scope(root)
|
||||
}
|
||||
|
||||
actual, err := rootPath(root, check.unresolved)
|
||||
if check.cause != nil {
|
||||
if err == nil {
|
||||
t.Errorf("(Check %d) Expected error %q, %q evaluated as %q", i+1, check.cause.Error(), check.unresolved, actual)
|
||||
}
|
||||
if errors.Cause(err) != check.cause {
|
||||
t.Fatalf("(Check %d) Failed to evaluate root path: %+v", i+1, err)
|
||||
}
|
||||
} else {
|
||||
expected := filepath.Join(root, check.expected)
|
||||
if err != nil {
|
||||
t.Fatalf("(Check %d) Failed to evaluate root path: %+v", i+1, err)
|
||||
}
|
||||
if actual != expected {
|
||||
t.Errorf("(Check %d) Unexpected evaluated path %q, expected %q", i+1, actual, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Check(unresolved, expected string) []rootCheck {
|
||||
return []rootCheck{
|
||||
{
|
||||
unresolved: unresolved,
|
||||
expected: expected,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CheckWithScope(unresolved, expected, scope string) []rootCheck {
|
||||
return []rootCheck{
|
||||
{
|
||||
unresolved: unresolved,
|
||||
expected: expected,
|
||||
scope: func(root string) string {
|
||||
return filepath.Join(root, scope)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorWithScope(unresolved, scope string, cause error) []rootCheck {
|
||||
return []rootCheck{
|
||||
{
|
||||
unresolved: unresolved,
|
||||
cause: cause,
|
||||
scope: func(root string) string {
|
||||
return filepath.Join(root, scope)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CheckAll(checks ...[]rootCheck) []rootCheck {
|
||||
all := make([]rootCheck, 0, len(checks))
|
||||
for _, c := range checks {
|
||||
all = append(all, c...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
func Symlink(oldname, newname string) fstest.Applier {
|
||||
dir := filepath.Dir(newname)
|
||||
if dir != "" {
|
||||
return fstest.Apply(
|
||||
fstest.CreateDir(dir, 0755),
|
||||
fstest.Symlink(oldname, newname),
|
||||
)
|
||||
}
|
||||
return fstest.Symlink(oldname, newname)
|
||||
}
|
@ -25,8 +25,6 @@ var (
|
||||
return make([]byte, 32*1024)
|
||||
},
|
||||
}
|
||||
|
||||
breakoutError = errors.New("file name outside of root")
|
||||
)
|
||||
|
||||
// Diff returns a tar stream of the computed filesystem
|
||||
@ -134,7 +132,10 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
|
||||
// This happened in some tests where an image had a tarfile without any
|
||||
// parent directories.
|
||||
parent := filepath.Dir(hdr.Name)
|
||||
parentPath := filepath.Join(root, parent)
|
||||
parentPath, err := rootPath(root, parent)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) {
|
||||
err = mkdirAll(parentPath, 0600)
|
||||
@ -158,7 +159,11 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
|
||||
}
|
||||
defer os.RemoveAll(aufsTempdir)
|
||||
}
|
||||
if err := createTarFile(ctx, filepath.Join(aufsTempdir, basename), root, hdr, tr); err != nil {
|
||||
p, err := rootPath(aufsTempdir, basename)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := createTarFile(ctx, p, root, hdr, tr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
@ -168,16 +173,11 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
path := filepath.Join(root, hdr.Name)
|
||||
rel, err := filepath.Rel(root, path)
|
||||
path, err := rootPath(root, hdr.Name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, errors.Wrap(err, "failed to get root path")
|
||||
}
|
||||
|
||||
// Note as these operations are platform specific, so must the slash be.
|
||||
if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
|
||||
return 0, errors.Wrapf(breakoutError, "%q is outside of %q", hdr.Name, root)
|
||||
}
|
||||
base := filepath.Base(path)
|
||||
|
||||
if strings.HasPrefix(base, whiteoutPrefix) {
|
||||
@ -239,7 +239,11 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
|
||||
if srcHdr == nil {
|
||||
return 0, fmt.Errorf("Invalid aufs hardlink")
|
||||
}
|
||||
tmpFile, err := os.Open(filepath.Join(aufsTempdir, linkBasename))
|
||||
p, err := rootPath(aufsTempdir, linkBasename)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tmpFile, err := os.Open(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -260,7 +264,10 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
|
||||
}
|
||||
|
||||
for _, hdr := range dirs {
|
||||
path := filepath.Join(root, hdr.Name)
|
||||
path, err := rootPath(root, hdr.Name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := chtimes(path, boundTime(latestTime(hdr.AccessTime, hdr.ModTime)), boundTime(hdr.ModTime)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -467,25 +474,15 @@ func createTarFile(ctx context.Context, path, extractDir string, hdr *tar.Header
|
||||
}
|
||||
|
||||
case tar.TypeLink:
|
||||
targetPath := filepath.Join(extractDir, hdr.Linkname)
|
||||
// check for hardlink breakout
|
||||
if !strings.HasPrefix(targetPath, extractDir) {
|
||||
return errors.Wrapf(breakoutError, "invalid hardlink %q -> %q", targetPath, hdr.Linkname)
|
||||
targetPath, err := rootPath(extractDir, hdr.Linkname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Link(targetPath, path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case tar.TypeSymlink:
|
||||
// path -> hdr.Linkname = targetPath
|
||||
// e.g. /extractDir/path/to/symlink -> ../2/file = /extractDir/path/2/file
|
||||
targetPath := filepath.Join(filepath.Dir(path), hdr.Linkname)
|
||||
|
||||
// the reason we don't need to check symlinks in the path (with FollowSymlinkInScope) is because
|
||||
// that symlink would first have to be created, which would be caught earlier, at this very check:
|
||||
if !strings.HasPrefix(targetPath, extractDir) {
|
||||
return errors.Wrapf(breakoutError, "invalid symlink %q -> %q", path, hdr.Linkname)
|
||||
}
|
||||
if err := os.Symlink(hdr.Linkname, path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1,12 +1,16 @@
|
||||
package archive
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "crypto/sha256"
|
||||
|
||||
@ -46,6 +50,238 @@ func TestBaseDiff(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelativeSymlinks(t *testing.T) {
|
||||
breakoutLinks := []fstest.Applier{
|
||||
fstest.Apply(
|
||||
baseApplier,
|
||||
fstest.Symlink("../other", "/home/derek/other"),
|
||||
fstest.Symlink("../../etc", "/home/derek/etc"),
|
||||
fstest.Symlink("up/../../other", "/home/derek/updown"),
|
||||
),
|
||||
fstest.Apply(
|
||||
baseApplier,
|
||||
fstest.Symlink("../../../breakout", "/home/derek/breakout"),
|
||||
),
|
||||
fstest.Apply(
|
||||
baseApplier,
|
||||
fstest.Symlink("../../breakout", "/breakout"),
|
||||
),
|
||||
fstest.Apply(
|
||||
baseApplier,
|
||||
fstest.Symlink("etc/../../upandout", "/breakout"),
|
||||
),
|
||||
fstest.Apply(
|
||||
baseApplier,
|
||||
fstest.Symlink("derek/../../../downandout", "/home/breakout"),
|
||||
),
|
||||
fstest.Apply(
|
||||
baseApplier,
|
||||
fstest.Symlink("/etc", "localetc"),
|
||||
),
|
||||
}
|
||||
|
||||
for _, bo := range breakoutLinks {
|
||||
if err := testDiffApply(bo); err != nil {
|
||||
t.Fatalf("Test apply failed: %+v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreakouts(t *testing.T) {
|
||||
tc := TarContext{}.WithUidGid(os.Getuid(), os.Getgid()).WithModTime(time.Now().UTC())
|
||||
expected := "unbroken"
|
||||
unbrokenCheck := func(root string) error {
|
||||
b, err := ioutil.ReadFile(filepath.Join(root, "etc", "unbroken"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read unbroken")
|
||||
}
|
||||
if string(b) != expected {
|
||||
return errors.Errorf("/etc/unbroken: unexpected value %s, expected %s", b, expected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
sameFile := func(f1, f2 string) func(string) error {
|
||||
return func(root string) error {
|
||||
p1 := filepath.Join(root, f1)
|
||||
p2 := filepath.Join(root, f2)
|
||||
s1, err := os.Stat(p1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s2, err := os.Stat(p2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !os.SameFile(s1, s2) {
|
||||
return errors.Errorf("files differ: %#v and %#v", s1, s2)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
breakouts := []struct {
|
||||
name string
|
||||
w WriterToTar
|
||||
validator func(string) error
|
||||
}{
|
||||
{
|
||||
name: "SymlinkAbsolute",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0755),
|
||||
tc.Symlink("/etc", "localetc"),
|
||||
tc.File("/localetc/unbroken", []byte(expected), 0644),
|
||||
),
|
||||
validator: unbrokenCheck,
|
||||
},
|
||||
{
|
||||
name: "SymlinkUpAndOut",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0755),
|
||||
tc.Dir("dummy", 0755),
|
||||
tc.Symlink("/dummy/../etc", "localetc"),
|
||||
tc.File("/localetc/unbroken", []byte(expected), 0644),
|
||||
),
|
||||
validator: unbrokenCheck,
|
||||
},
|
||||
{
|
||||
name: "SymlinkMultipleAbsolute",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0755),
|
||||
tc.Dir("dummy", 0755),
|
||||
tc.Symlink("/etc", "/dummy/etc"),
|
||||
tc.Symlink("/dummy/etc", "localetc"),
|
||||
tc.File("/dummy/etc/unbroken", []byte(expected), 0644),
|
||||
),
|
||||
validator: unbrokenCheck,
|
||||
},
|
||||
{
|
||||
name: "SymlinkMultipleRelative",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0755),
|
||||
tc.Dir("dummy", 0755),
|
||||
tc.Symlink("/etc", "/dummy/etc"),
|
||||
tc.Symlink("./dummy/etc", "localetc"),
|
||||
tc.File("/dummy/etc/unbroken", []byte(expected), 0644),
|
||||
),
|
||||
validator: unbrokenCheck,
|
||||
},
|
||||
{
|
||||
name: "SymlinkEmptyFile",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0755),
|
||||
tc.File("etc/emptied", []byte("notempty"), 0644),
|
||||
tc.Symlink("/etc", "localetc"),
|
||||
tc.File("/localetc/emptied", []byte{}, 0644),
|
||||
),
|
||||
validator: func(root string) error {
|
||||
b, err := ioutil.ReadFile(filepath.Join(root, "etc", "emptied"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read unbroken")
|
||||
}
|
||||
if len(b) > 0 {
|
||||
return errors.Errorf("/etc/emptied: non-empty")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "HardlinkRelative",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Dir("breakouts", 0755),
|
||||
tc.Symlink("../../etc", "breakouts/d1"),
|
||||
tc.Link("/breakouts/d1/passwd", "breakouts/mypasswd"),
|
||||
),
|
||||
validator: sameFile("/breakouts/mypasswd", "/etc/passwd"),
|
||||
},
|
||||
{
|
||||
name: "HardlinkDownAndOut",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Dir("breakouts", 0755),
|
||||
tc.Dir("downandout", 0755),
|
||||
tc.Symlink("../downandout/../../etc", "breakouts/d1"),
|
||||
tc.Link("/breakouts/d1/passwd", "breakouts/mypasswd"),
|
||||
),
|
||||
validator: sameFile("/breakouts/mypasswd", "/etc/passwd"),
|
||||
},
|
||||
{
|
||||
name: "HardlinkAbsolute",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Symlink("/etc", "localetc"),
|
||||
tc.Link("/localetc/passwd", "localpasswd"),
|
||||
),
|
||||
validator: sameFile("localpasswd", "/etc/passwd"),
|
||||
},
|
||||
{
|
||||
name: "HardlinkRelativeLong",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Symlink("../../../../../../../etc", "localetc"),
|
||||
tc.Link("/localetc/passwd", "localpasswd"),
|
||||
),
|
||||
validator: sameFile("localpasswd", "/etc/passwd"),
|
||||
},
|
||||
{
|
||||
name: "HardlinkRelativeUpAndOut",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Symlink("upandout/../../../etc", "localetc"),
|
||||
tc.Link("/localetc/passwd", "localpasswd"),
|
||||
),
|
||||
validator: sameFile("localpasswd", "/etc/passwd"),
|
||||
},
|
||||
{
|
||||
name: "HardlinkDirectRelative",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Link("../../../../../etc/passwd", "localpasswd"),
|
||||
),
|
||||
validator: sameFile("localpasswd", "/etc/passwd"),
|
||||
},
|
||||
{
|
||||
name: "HardlinkDirectAbsolute",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Link("/etc/passwd", "localpasswd"),
|
||||
),
|
||||
validator: sameFile("localpasswd", "/etc/passwd"),
|
||||
},
|
||||
{
|
||||
name: "HardlinkSymlinkRelative",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Symlink("../../../../../etc/passwd", "passwdlink"),
|
||||
tc.Link("/passwdlink", "localpasswd"),
|
||||
),
|
||||
validator: sameFile("/localpasswd", "/etc/passwd"),
|
||||
},
|
||||
{
|
||||
name: "HardlinkSymlinkAbsolute",
|
||||
w: TarAll(
|
||||
tc.Dir("etc", 0770),
|
||||
tc.File("/etc/passwd", []byte("inside"), 0644),
|
||||
tc.Symlink("/etc/passwd", "passwdlink"),
|
||||
tc.Link("/passwdlink", "localpasswd"),
|
||||
),
|
||||
validator: sameFile("/localpasswd", "/etc/passwd"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, bo := range breakouts {
|
||||
t.Run(bo.name, makeWriterToTarTest(bo.w, bo.validator))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffApply(t *testing.T) {
|
||||
fstest.FSSuite(t, diffApplier{})
|
||||
}
|
||||
@ -118,6 +354,58 @@ func testBaseDiff(a fstest.Applier) error {
|
||||
return fstest.CheckDirectoryEqual(td, dest)
|
||||
}
|
||||
|
||||
func testDiffApply(a fstest.Applier) error {
|
||||
td, err := ioutil.TempDir("", "test-diff-apply-")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create temp dir")
|
||||
}
|
||||
defer os.RemoveAll(td)
|
||||
dest, err := ioutil.TempDir("", "test-diff-apply-dest-")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create temp dir")
|
||||
}
|
||||
defer os.RemoveAll(dest)
|
||||
|
||||
if err := a.Apply(td); err != nil {
|
||||
return errors.Wrap(err, "failed to apply filesystem changes")
|
||||
}
|
||||
|
||||
diffBytes, err := ioutil.ReadAll(Diff(context.Background(), "", td))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create diff")
|
||||
}
|
||||
|
||||
if _, err := Apply(context.Background(), dest, bytes.NewReader(diffBytes)); err != nil {
|
||||
return errors.Wrap(err, "failed to apply tar stream")
|
||||
}
|
||||
|
||||
return fstest.CheckDirectoryEqual(td, dest)
|
||||
}
|
||||
|
||||
func makeWriterToTarTest(wt WriterToTar, validate func(string) error) func(*testing.T) {
|
||||
return func(t *testing.T) {
|
||||
td, err := ioutil.TempDir("", "test-writer-to-tar-")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(td)
|
||||
|
||||
tr := TarFromWriterTo(wt)
|
||||
|
||||
if _, err := Apply(context.Background(), td, tr); err != nil {
|
||||
t.Fatalf("Failed to apply tar: %v", err)
|
||||
}
|
||||
|
||||
if validate != nil {
|
||||
if err := validate(td); err != nil {
|
||||
t.Errorf("Validation failed: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
type diffApplier struct{}
|
||||
|
||||
func (d diffApplier) TestContext(ctx context.Context) (context.Context, func(), error) {
|
||||
@ -174,3 +462,182 @@ func requireTar(t *testing.T) {
|
||||
t.Skipf("%s not found, skipping", tarCmd)
|
||||
}
|
||||
}
|
||||
|
||||
// WriterToTar is an type which writes to a tar writer
|
||||
type WriterToTar interface {
|
||||
WriteTo(*tar.Writer) error
|
||||
}
|
||||
|
||||
type writerToFn func(*tar.Writer) error
|
||||
|
||||
func (w writerToFn) WriteTo(tw *tar.Writer) error {
|
||||
return w(tw)
|
||||
}
|
||||
|
||||
// TarAll creates a WriterToTar which calls all the provided writers
|
||||
// in the order in which they are provided.
|
||||
func TarAll(wt ...WriterToTar) WriterToTar {
|
||||
return writerToFn(func(tw *tar.Writer) error {
|
||||
for _, w := range wt {
|
||||
if err := w.WriteTo(tw); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// TarFromWriterTo is used to create a tar stream from a tar record
|
||||
// creator. This can be used to manifacture more specific tar records
|
||||
// which allow testing specific tar cases which may be encountered
|
||||
// by the untar process.
|
||||
func TarFromWriterTo(wt WriterToTar) io.ReadCloser {
|
||||
r, w := io.Pipe()
|
||||
go func() {
|
||||
tw := tar.NewWriter(w)
|
||||
if err := wt.WriteTo(tw); err != nil {
|
||||
w.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
w.CloseWithError(tw.Close())
|
||||
}()
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// TarContext is used to create tar records
|
||||
type TarContext struct {
|
||||
Uid int
|
||||
Gid int
|
||||
|
||||
// ModTime sets the modtimes for all files, if nil the current time
|
||||
// is used for each file when it was written
|
||||
ModTime *time.Time
|
||||
|
||||
Xattrs map[string]string
|
||||
}
|
||||
|
||||
func (tc TarContext) newHeader(mode os.FileMode, name, link string, size int64) *tar.Header {
|
||||
ti := tarInfo{
|
||||
name: name,
|
||||
mode: mode,
|
||||
size: size,
|
||||
modt: tc.ModTime,
|
||||
hdr: &tar.Header{
|
||||
Uid: tc.Uid,
|
||||
Gid: tc.Gid,
|
||||
Xattrs: tc.Xattrs,
|
||||
},
|
||||
}
|
||||
|
||||
if mode&os.ModeSymlink == 0 && link != "" {
|
||||
ti.hdr.Typeflag = tar.TypeLink
|
||||
ti.hdr.Linkname = link
|
||||
}
|
||||
|
||||
hdr, err := tar.FileInfoHeader(ti, link)
|
||||
if err != nil {
|
||||
// Only returns an error on bad input mode
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return hdr
|
||||
}
|
||||
|
||||
type tarInfo struct {
|
||||
name string
|
||||
mode os.FileMode
|
||||
size int64
|
||||
modt *time.Time
|
||||
hdr *tar.Header
|
||||
}
|
||||
|
||||
func (ti tarInfo) Name() string {
|
||||
return ti.name
|
||||
}
|
||||
|
||||
func (ti tarInfo) Size() int64 {
|
||||
return ti.size
|
||||
}
|
||||
func (ti tarInfo) Mode() os.FileMode {
|
||||
return ti.mode
|
||||
}
|
||||
|
||||
func (ti tarInfo) ModTime() time.Time {
|
||||
if ti.modt != nil {
|
||||
return *ti.modt
|
||||
}
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func (ti tarInfo) IsDir() bool {
|
||||
return (ti.mode & os.ModeDir) != 0
|
||||
}
|
||||
func (ti tarInfo) Sys() interface{} {
|
||||
return ti.hdr
|
||||
}
|
||||
|
||||
func (tc TarContext) WithUidGid(uid, gid int) TarContext {
|
||||
ntc := tc
|
||||
ntc.Uid = uid
|
||||
ntc.Gid = gid
|
||||
return ntc
|
||||
}
|
||||
|
||||
func (tc TarContext) WithModTime(modtime time.Time) TarContext {
|
||||
ntc := tc
|
||||
ntc.ModTime = &modtime
|
||||
return ntc
|
||||
}
|
||||
|
||||
// WithXattrs adds these xattrs to all files, merges with any
|
||||
// previously added xattrs
|
||||
func (tc TarContext) WithXattrs(xattrs map[string]string) TarContext {
|
||||
ntc := tc
|
||||
if ntc.Xattrs == nil {
|
||||
ntc.Xattrs = map[string]string{}
|
||||
}
|
||||
for k, v := range xattrs {
|
||||
ntc.Xattrs[k] = v
|
||||
}
|
||||
return ntc
|
||||
}
|
||||
|
||||
func (tc TarContext) File(name string, content []byte, perm os.FileMode) WriterToTar {
|
||||
return writerToFn(func(tw *tar.Writer) error {
|
||||
return writeHeaderAndContent(tw, tc.newHeader(perm, name, "", int64(len(content))), content)
|
||||
})
|
||||
}
|
||||
|
||||
func (tc TarContext) Dir(name string, perm os.FileMode) WriterToTar {
|
||||
return writerToFn(func(tw *tar.Writer) error {
|
||||
return writeHeaderAndContent(tw, tc.newHeader(perm|os.ModeDir, name, "", 0), nil)
|
||||
})
|
||||
}
|
||||
|
||||
func (tc TarContext) Symlink(oldname, newname string) WriterToTar {
|
||||
return writerToFn(func(tw *tar.Writer) error {
|
||||
return writeHeaderAndContent(tw, tc.newHeader(0777|os.ModeSymlink, newname, oldname, 0), nil)
|
||||
})
|
||||
}
|
||||
|
||||
func (tc TarContext) Link(oldname, newname string) WriterToTar {
|
||||
return writerToFn(func(tw *tar.Writer) error {
|
||||
return writeHeaderAndContent(tw, tc.newHeader(0777, newname, oldname, 0), nil)
|
||||
})
|
||||
}
|
||||
|
||||
func writeHeaderAndContent(tw *tar.Writer, h *tar.Header, b []byte) error {
|
||||
if h.Size != int64(len(b)) {
|
||||
return errors.New("bad content length")
|
||||
}
|
||||
if err := tw.WriteHeader(h); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(b) > 0 {
|
||||
if _, err := tw.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user