Merge pull request #1208 from dmcgowan/tar-test

archive: add link breakout checks and tests
This commit is contained in:
Phil Estes 2017-07-19 09:47:17 -04:00 committed by GitHub
commit 0600753bd8
4 changed files with 906 additions and 26 deletions

123
archive/path.go Normal file
View File

@ -0,0 +1,123 @@
package archive
import (
"os"
"path/filepath"
"strings"
"github.com/pkg/errors"
)
var (
errTooManyLinks = errors.New("too many links")
)
// rootPath joins a path with a root, evaluating and bounding any
// symlink to the root directory.
// TODO(dmcgowan): Expose and move to fs package or continuity path driver
func rootPath(root, path string) (string, error) {
if path == "" {
return root, nil
}
var linksWalked int // to protect against cycles
for {
i := linksWalked
newpath, err := walkLinks(root, path, &linksWalked)
if err != nil {
return "", err
}
path = newpath
if i == linksWalked {
newpath = rootJoin(newpath)
if path == newpath {
return filepath.Join(root, newpath), nil
}
path = newpath
}
}
}
// rootJoin joins a path with root, cleaning up any links that
// reference above root.
func rootJoin(path string) string {
if filepath.IsAbs(path) {
path = filepath.Clean(path)
}
// Resolve any ".." or "/.." before joining to root
for !filepath.IsAbs(path) {
path = "/" + path
path = filepath.Clean(path)
}
return path
}
func walkLink(root, path string, linksWalked *int) (newpath string, islink bool, err error) {
if *linksWalked > 255 {
return "", false, errTooManyLinks
}
path = rootJoin(path)
if path == "/" {
return path, false, nil
}
realPath := filepath.Join(root, path)
fi, err := os.Lstat(realPath)
if err != nil {
// If path does not yet exist, treat as non-symlink
if os.IsNotExist(err) {
return path, false, nil
}
return "", false, err
}
if fi.Mode()&os.ModeSymlink == 0 {
return path, false, nil
}
newpath, err = os.Readlink(realPath)
if err != nil {
return "", false, err
}
if filepath.IsAbs(newpath) && strings.HasPrefix(newpath, root) {
newpath = newpath[:len(root)]
if !strings.HasPrefix(newpath, "/") {
newpath = "/" + newpath
}
}
*linksWalked++
return newpath, true, nil
}
func walkLinks(root, path string, linksWalked *int) (string, error) {
switch dir, file := filepath.Split(path); {
case dir == "":
newpath, _, err := walkLink(root, file, linksWalked)
return newpath, err
case file == "":
if os.IsPathSeparator(dir[len(dir)-1]) {
if dir == "/" {
return dir, nil
}
return walkLinks(root, dir[:len(dir)-1], linksWalked)
}
newpath, _, err := walkLink(root, dir, linksWalked)
return newpath, err
default:
newdir, err := walkLinks(root, dir, linksWalked)
if err != nil {
return "", err
}
newpath, islink, err := walkLink(root, filepath.Join(newdir, file), linksWalked)
if err != nil {
return "", err
}
if !islink {
return newpath, nil
}
if filepath.IsAbs(newpath) {
return newpath, nil
}
return filepath.Join(newdir, newpath), nil
}
}

293
archive/path_test.go Normal file
View File

@ -0,0 +1,293 @@
package archive
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"github.com/containerd/containerd/fs/fstest"
"github.com/pkg/errors"
)
type rootCheck struct {
unresolved string
expected string
scope func(string) string
cause error
}
func TestRootPath(t *testing.T) {
tests := []struct {
name string
apply fstest.Applier
checks []rootCheck
scope func(string) (string, error)
}{
{
name: "SymlinkAbsolute",
apply: Symlink("/b", "fs/a/d"),
checks: Check("fs/a/d/c/data", "b/c/data"),
},
{
name: "SymlinkRelativePath",
apply: Symlink("a", "fs/i"),
checks: Check("fs/i", "fs/a"),
},
{
name: "SymlinkSkipSymlinksOutsideScope",
apply: Symlink("realdir", "linkdir"),
checks: CheckWithScope("foo/bar", "foo/bar", "linkdir"),
},
{
name: "SymlinkLastLink",
apply: Symlink("/b", "fs/a/d"),
checks: Check("fs/a/d", "b"),
},
{
name: "SymlinkRelativeLinkChangeScope",
apply: Symlink("../b", "fs/a/e"),
checks: CheckAll(
Check("fs/a/e/c/data", "fs/b/c/data"),
CheckWithScope("e", "b", "fs/a"), // Original return
),
},
{
name: "SymlinkDeepRelativeLinkChangeScope",
apply: Symlink("../../../../test", "fs/a/f"),
checks: CheckAll(
Check("fs/a/f", "test"), // Original return
CheckWithScope("a/f", "test", "fs"), // Original return
),
},
{
name: "SymlinkRelativeLinkChain",
apply: fstest.Apply(
Symlink("../g", "fs/b/h"),
fstest.Symlink("../../../../../../../../../../../../root", "fs/g"),
),
checks: Check("fs/b/h", "root"),
},
{
name: "SymlinkBreakoutPath",
apply: Symlink("../i/a", "fs/j/k"),
checks: CheckWithScope("k", "i/a", "fs/j"),
},
{
name: "SymlinkToRoot",
apply: Symlink("/", "foo"),
checks: Check("foo", ""),
},
{
name: "SymlinkSlashDotdot",
apply: Symlink("/../../", "foo"),
checks: Check("foo", ""),
},
{
name: "SymlinkDotdot",
apply: Symlink("../../", "foo"),
checks: Check("foo", ""),
},
{
name: "SymlinkRelativePath2",
apply: Symlink("baz/target", "bar/foo"),
checks: Check("bar/foo", "bar/baz/target"),
},
{
name: "SymlinkScopeLink",
apply: fstest.Apply(
Symlink("root2", "root"),
Symlink("../bar", "root2/foo"),
),
checks: CheckWithScope("foo", "bar", "root"),
},
{
name: "SymlinkSelf",
apply: fstest.Apply(
Symlink("foo", "root/foo"),
),
checks: ErrorWithScope("foo", "root", errTooManyLinks),
},
{
name: "SymlinkCircular",
apply: fstest.Apply(
Symlink("foo", "bar"),
Symlink("bar", "foo"),
),
checks: ErrorWithScope("foo", "", errTooManyLinks), //TODO: Test for circular error
},
{
name: "SymlinkCircularUnderRoot",
apply: fstest.Apply(
Symlink("baz", "root/bar"),
Symlink("../bak", "root/baz"),
Symlink("/bar", "root/bak"),
),
checks: ErrorWithScope("bar", "root", errTooManyLinks), // TODO: Test for circular error
},
{
name: "SymlinkComplexChain",
apply: fstest.Apply(
fstest.CreateDir("root2", 0777),
Symlink("root2", "root"),
Symlink("r/s", "root/a"),
Symlink("../root/t", "root/r"),
Symlink("/../u", "root/root/t/s/b"),
Symlink(".", "root/u/c"),
Symlink("../v", "root/u/x/y"),
Symlink("/../w", "root/u/v"),
),
checks: CheckWithScope("a/b/c/x/y/z", "w/z", "root"), // Original return
},
{
name: "SymlinkBreakoutNonExistent",
apply: fstest.Apply(
Symlink("/", "root/slash"),
Symlink("/idontexist/../slash", "root/sym"),
),
checks: CheckWithScope("sym/file", "file", "root"),
},
{
name: "SymlinkNoLexicalCleaning",
apply: fstest.Apply(
Symlink("/foo/bar", "root/sym"),
Symlink("/sym/../baz", "root/hello"),
),
checks: CheckWithScope("hello", "foo/baz", "root"),
},
}
for _, test := range tests {
t.Run(test.name, makeRootPathTest(t, test.apply, test.checks))
}
// Add related tests which are unable to follow same pattern
t.Run("SymlinkRootScope", testRootPathSymlinkRootScope)
t.Run("SymlinkEmpty", testRootPathSymlinkEmpty)
}
func testRootPathSymlinkRootScope(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "TestFollowSymlinkRootScope")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
expected, err := filepath.EvalSymlinks(tmpdir)
if err != nil {
t.Fatal(err)
}
rewrite, err := rootPath("/", tmpdir)
if err != nil {
t.Fatal(err)
}
if rewrite != expected {
t.Fatalf("expected %q got %q", expected, rewrite)
}
}
func testRootPathSymlinkEmpty(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
res, err := rootPath(wd, "")
if err != nil {
t.Fatal(err)
}
if res != wd {
t.Fatalf("expected %q got %q", wd, res)
}
}
func makeRootPathTest(t *testing.T, apply fstest.Applier, checks []rootCheck) func(t *testing.T) {
return func(t *testing.T) {
applyDir, err := ioutil.TempDir("", "test-root-path-")
if err != nil {
t.Fatalf("Unable to make temp directory: %+v", err)
}
defer os.RemoveAll(applyDir)
if apply != nil {
if err := apply.Apply(applyDir); err != nil {
t.Fatalf("Apply failed: %+v", err)
}
}
for i, check := range checks {
root := applyDir
if check.scope != nil {
root = check.scope(root)
}
actual, err := rootPath(root, check.unresolved)
if check.cause != nil {
if err == nil {
t.Errorf("(Check %d) Expected error %q, %q evaluated as %q", i+1, check.cause.Error(), check.unresolved, actual)
}
if errors.Cause(err) != check.cause {
t.Fatalf("(Check %d) Failed to evaluate root path: %+v", i+1, err)
}
} else {
expected := filepath.Join(root, check.expected)
if err != nil {
t.Fatalf("(Check %d) Failed to evaluate root path: %+v", i+1, err)
}
if actual != expected {
t.Errorf("(Check %d) Unexpected evaluated path %q, expected %q", i+1, actual, expected)
}
}
}
}
}
func Check(unresolved, expected string) []rootCheck {
return []rootCheck{
{
unresolved: unresolved,
expected: expected,
},
}
}
func CheckWithScope(unresolved, expected, scope string) []rootCheck {
return []rootCheck{
{
unresolved: unresolved,
expected: expected,
scope: func(root string) string {
return filepath.Join(root, scope)
},
},
}
}
func ErrorWithScope(unresolved, scope string, cause error) []rootCheck {
return []rootCheck{
{
unresolved: unresolved,
cause: cause,
scope: func(root string) string {
return filepath.Join(root, scope)
},
},
}
}
func CheckAll(checks ...[]rootCheck) []rootCheck {
all := make([]rootCheck, 0, len(checks))
for _, c := range checks {
all = append(all, c...)
}
return all
}
func Symlink(oldname, newname string) fstest.Applier {
dir := filepath.Dir(newname)
if dir != "" {
return fstest.Apply(
fstest.CreateDir(dir, 0755),
fstest.Symlink(oldname, newname),
)
}
return fstest.Symlink(oldname, newname)
}

View File

@ -25,8 +25,6 @@ var (
return make([]byte, 32*1024) return make([]byte, 32*1024)
}, },
} }
breakoutError = errors.New("file name outside of root")
) )
// Diff returns a tar stream of the computed filesystem // Diff returns a tar stream of the computed filesystem
@ -134,7 +132,10 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
// This happened in some tests where an image had a tarfile without any // This happened in some tests where an image had a tarfile without any
// parent directories. // parent directories.
parent := filepath.Dir(hdr.Name) parent := filepath.Dir(hdr.Name)
parentPath := filepath.Join(root, parent) parentPath, err := rootPath(root, parent)
if err != nil {
return 0, err
}
if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) { if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) {
err = mkdirAll(parentPath, 0600) err = mkdirAll(parentPath, 0600)
@ -158,7 +159,11 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
} }
defer os.RemoveAll(aufsTempdir) defer os.RemoveAll(aufsTempdir)
} }
if err := createTarFile(ctx, filepath.Join(aufsTempdir, basename), root, hdr, tr); err != nil { p, err := rootPath(aufsTempdir, basename)
if err != nil {
return 0, err
}
if err := createTarFile(ctx, p, root, hdr, tr); err != nil {
return 0, err return 0, err
} }
} }
@ -168,16 +173,11 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
} }
} }
path := filepath.Join(root, hdr.Name) path, err := rootPath(root, hdr.Name)
rel, err := filepath.Rel(root, path)
if err != nil { if err != nil {
return 0, err return 0, errors.Wrap(err, "failed to get root path")
} }
// Note as these operations are platform specific, so must the slash be.
if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return 0, errors.Wrapf(breakoutError, "%q is outside of %q", hdr.Name, root)
}
base := filepath.Base(path) base := filepath.Base(path)
if strings.HasPrefix(base, whiteoutPrefix) { if strings.HasPrefix(base, whiteoutPrefix) {
@ -239,7 +239,11 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
if srcHdr == nil { if srcHdr == nil {
return 0, fmt.Errorf("Invalid aufs hardlink") return 0, fmt.Errorf("Invalid aufs hardlink")
} }
tmpFile, err := os.Open(filepath.Join(aufsTempdir, linkBasename)) p, err := rootPath(aufsTempdir, linkBasename)
if err != nil {
return 0, err
}
tmpFile, err := os.Open(p)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -260,7 +264,10 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) {
} }
for _, hdr := range dirs { for _, hdr := range dirs {
path := filepath.Join(root, hdr.Name) path, err := rootPath(root, hdr.Name)
if err != nil {
return 0, err
}
if err := chtimes(path, boundTime(latestTime(hdr.AccessTime, hdr.ModTime)), boundTime(hdr.ModTime)); err != nil { if err := chtimes(path, boundTime(latestTime(hdr.AccessTime, hdr.ModTime)), boundTime(hdr.ModTime)); err != nil {
return 0, err return 0, err
} }
@ -467,25 +474,15 @@ func createTarFile(ctx context.Context, path, extractDir string, hdr *tar.Header
} }
case tar.TypeLink: case tar.TypeLink:
targetPath := filepath.Join(extractDir, hdr.Linkname) targetPath, err := rootPath(extractDir, hdr.Linkname)
// check for hardlink breakout if err != nil {
if !strings.HasPrefix(targetPath, extractDir) { return err
return errors.Wrapf(breakoutError, "invalid hardlink %q -> %q", targetPath, hdr.Linkname)
} }
if err := os.Link(targetPath, path); err != nil { if err := os.Link(targetPath, path); err != nil {
return err return err
} }
case tar.TypeSymlink: case tar.TypeSymlink:
// path -> hdr.Linkname = targetPath
// e.g. /extractDir/path/to/symlink -> ../2/file = /extractDir/path/2/file
targetPath := filepath.Join(filepath.Dir(path), hdr.Linkname)
// the reason we don't need to check symlinks in the path (with FollowSymlinkInScope) is because
// that symlink would first have to be created, which would be caught earlier, at this very check:
if !strings.HasPrefix(targetPath, extractDir) {
return errors.Wrapf(breakoutError, "invalid symlink %q -> %q", path, hdr.Linkname)
}
if err := os.Symlink(hdr.Linkname, path); err != nil { if err := os.Symlink(hdr.Linkname, path); err != nil {
return err return err
} }

View File

@ -1,12 +1,16 @@
package archive package archive
import ( import (
"archive/tar"
"bytes" "bytes"
"context" "context"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"testing" "testing"
"time"
_ "crypto/sha256" _ "crypto/sha256"
@ -46,6 +50,238 @@ func TestBaseDiff(t *testing.T) {
} }
} }
func TestRelativeSymlinks(t *testing.T) {
breakoutLinks := []fstest.Applier{
fstest.Apply(
baseApplier,
fstest.Symlink("../other", "/home/derek/other"),
fstest.Symlink("../../etc", "/home/derek/etc"),
fstest.Symlink("up/../../other", "/home/derek/updown"),
),
fstest.Apply(
baseApplier,
fstest.Symlink("../../../breakout", "/home/derek/breakout"),
),
fstest.Apply(
baseApplier,
fstest.Symlink("../../breakout", "/breakout"),
),
fstest.Apply(
baseApplier,
fstest.Symlink("etc/../../upandout", "/breakout"),
),
fstest.Apply(
baseApplier,
fstest.Symlink("derek/../../../downandout", "/home/breakout"),
),
fstest.Apply(
baseApplier,
fstest.Symlink("/etc", "localetc"),
),
}
for _, bo := range breakoutLinks {
if err := testDiffApply(bo); err != nil {
t.Fatalf("Test apply failed: %+v", err)
}
}
}
func TestBreakouts(t *testing.T) {
tc := TarContext{}.WithUidGid(os.Getuid(), os.Getgid()).WithModTime(time.Now().UTC())
expected := "unbroken"
unbrokenCheck := func(root string) error {
b, err := ioutil.ReadFile(filepath.Join(root, "etc", "unbroken"))
if err != nil {
return errors.Wrap(err, "failed to read unbroken")
}
if string(b) != expected {
return errors.Errorf("/etc/unbroken: unexpected value %s, expected %s", b, expected)
}
return nil
}
sameFile := func(f1, f2 string) func(string) error {
return func(root string) error {
p1 := filepath.Join(root, f1)
p2 := filepath.Join(root, f2)
s1, err := os.Stat(p1)
if err != nil {
return err
}
s2, err := os.Stat(p2)
if err != nil {
return err
}
if !os.SameFile(s1, s2) {
return errors.Errorf("files differ: %#v and %#v", s1, s2)
}
return nil
}
}
breakouts := []struct {
name string
w WriterToTar
validator func(string) error
}{
{
name: "SymlinkAbsolute",
w: TarAll(
tc.Dir("etc", 0755),
tc.Symlink("/etc", "localetc"),
tc.File("/localetc/unbroken", []byte(expected), 0644),
),
validator: unbrokenCheck,
},
{
name: "SymlinkUpAndOut",
w: TarAll(
tc.Dir("etc", 0755),
tc.Dir("dummy", 0755),
tc.Symlink("/dummy/../etc", "localetc"),
tc.File("/localetc/unbroken", []byte(expected), 0644),
),
validator: unbrokenCheck,
},
{
name: "SymlinkMultipleAbsolute",
w: TarAll(
tc.Dir("etc", 0755),
tc.Dir("dummy", 0755),
tc.Symlink("/etc", "/dummy/etc"),
tc.Symlink("/dummy/etc", "localetc"),
tc.File("/dummy/etc/unbroken", []byte(expected), 0644),
),
validator: unbrokenCheck,
},
{
name: "SymlinkMultipleRelative",
w: TarAll(
tc.Dir("etc", 0755),
tc.Dir("dummy", 0755),
tc.Symlink("/etc", "/dummy/etc"),
tc.Symlink("./dummy/etc", "localetc"),
tc.File("/dummy/etc/unbroken", []byte(expected), 0644),
),
validator: unbrokenCheck,
},
{
name: "SymlinkEmptyFile",
w: TarAll(
tc.Dir("etc", 0755),
tc.File("etc/emptied", []byte("notempty"), 0644),
tc.Symlink("/etc", "localetc"),
tc.File("/localetc/emptied", []byte{}, 0644),
),
validator: func(root string) error {
b, err := ioutil.ReadFile(filepath.Join(root, "etc", "emptied"))
if err != nil {
return errors.Wrap(err, "failed to read unbroken")
}
if len(b) > 0 {
return errors.Errorf("/etc/emptied: non-empty")
}
return nil
},
},
{
name: "HardlinkRelative",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Dir("breakouts", 0755),
tc.Symlink("../../etc", "breakouts/d1"),
tc.Link("/breakouts/d1/passwd", "breakouts/mypasswd"),
),
validator: sameFile("/breakouts/mypasswd", "/etc/passwd"),
},
{
name: "HardlinkDownAndOut",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Dir("breakouts", 0755),
tc.Dir("downandout", 0755),
tc.Symlink("../downandout/../../etc", "breakouts/d1"),
tc.Link("/breakouts/d1/passwd", "breakouts/mypasswd"),
),
validator: sameFile("/breakouts/mypasswd", "/etc/passwd"),
},
{
name: "HardlinkAbsolute",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Symlink("/etc", "localetc"),
tc.Link("/localetc/passwd", "localpasswd"),
),
validator: sameFile("localpasswd", "/etc/passwd"),
},
{
name: "HardlinkRelativeLong",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Symlink("../../../../../../../etc", "localetc"),
tc.Link("/localetc/passwd", "localpasswd"),
),
validator: sameFile("localpasswd", "/etc/passwd"),
},
{
name: "HardlinkRelativeUpAndOut",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Symlink("upandout/../../../etc", "localetc"),
tc.Link("/localetc/passwd", "localpasswd"),
),
validator: sameFile("localpasswd", "/etc/passwd"),
},
{
name: "HardlinkDirectRelative",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Link("../../../../../etc/passwd", "localpasswd"),
),
validator: sameFile("localpasswd", "/etc/passwd"),
},
{
name: "HardlinkDirectAbsolute",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Link("/etc/passwd", "localpasswd"),
),
validator: sameFile("localpasswd", "/etc/passwd"),
},
{
name: "HardlinkSymlinkRelative",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Symlink("../../../../../etc/passwd", "passwdlink"),
tc.Link("/passwdlink", "localpasswd"),
),
validator: sameFile("/localpasswd", "/etc/passwd"),
},
{
name: "HardlinkSymlinkAbsolute",
w: TarAll(
tc.Dir("etc", 0770),
tc.File("/etc/passwd", []byte("inside"), 0644),
tc.Symlink("/etc/passwd", "passwdlink"),
tc.Link("/passwdlink", "localpasswd"),
),
validator: sameFile("/localpasswd", "/etc/passwd"),
},
}
for _, bo := range breakouts {
t.Run(bo.name, makeWriterToTarTest(bo.w, bo.validator))
}
}
func TestDiffApply(t *testing.T) { func TestDiffApply(t *testing.T) {
fstest.FSSuite(t, diffApplier{}) fstest.FSSuite(t, diffApplier{})
} }
@ -118,6 +354,58 @@ func testBaseDiff(a fstest.Applier) error {
return fstest.CheckDirectoryEqual(td, dest) return fstest.CheckDirectoryEqual(td, dest)
} }
func testDiffApply(a fstest.Applier) error {
td, err := ioutil.TempDir("", "test-diff-apply-")
if err != nil {
return errors.Wrap(err, "failed to create temp dir")
}
defer os.RemoveAll(td)
dest, err := ioutil.TempDir("", "test-diff-apply-dest-")
if err != nil {
return errors.Wrap(err, "failed to create temp dir")
}
defer os.RemoveAll(dest)
if err := a.Apply(td); err != nil {
return errors.Wrap(err, "failed to apply filesystem changes")
}
diffBytes, err := ioutil.ReadAll(Diff(context.Background(), "", td))
if err != nil {
return errors.Wrap(err, "failed to create diff")
}
if _, err := Apply(context.Background(), dest, bytes.NewReader(diffBytes)); err != nil {
return errors.Wrap(err, "failed to apply tar stream")
}
return fstest.CheckDirectoryEqual(td, dest)
}
func makeWriterToTarTest(wt WriterToTar, validate func(string) error) func(*testing.T) {
return func(t *testing.T) {
td, err := ioutil.TempDir("", "test-writer-to-tar-")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(td)
tr := TarFromWriterTo(wt)
if _, err := Apply(context.Background(), td, tr); err != nil {
t.Fatalf("Failed to apply tar: %v", err)
}
if validate != nil {
if err := validate(td); err != nil {
t.Errorf("Validation failed: %v", err)
}
}
}
}
type diffApplier struct{} type diffApplier struct{}
func (d diffApplier) TestContext(ctx context.Context) (context.Context, func(), error) { func (d diffApplier) TestContext(ctx context.Context) (context.Context, func(), error) {
@ -174,3 +462,182 @@ func requireTar(t *testing.T) {
t.Skipf("%s not found, skipping", tarCmd) t.Skipf("%s not found, skipping", tarCmd)
} }
} }
// WriterToTar is an type which writes to a tar writer
type WriterToTar interface {
WriteTo(*tar.Writer) error
}
type writerToFn func(*tar.Writer) error
func (w writerToFn) WriteTo(tw *tar.Writer) error {
return w(tw)
}
// TarAll creates a WriterToTar which calls all the provided writers
// in the order in which they are provided.
func TarAll(wt ...WriterToTar) WriterToTar {
return writerToFn(func(tw *tar.Writer) error {
for _, w := range wt {
if err := w.WriteTo(tw); err != nil {
return err
}
}
return nil
})
}
// TarFromWriterTo is used to create a tar stream from a tar record
// creator. This can be used to manifacture more specific tar records
// which allow testing specific tar cases which may be encountered
// by the untar process.
func TarFromWriterTo(wt WriterToTar) io.ReadCloser {
r, w := io.Pipe()
go func() {
tw := tar.NewWriter(w)
if err := wt.WriteTo(tw); err != nil {
w.CloseWithError(err)
return
}
w.CloseWithError(tw.Close())
}()
return r
}
// TarContext is used to create tar records
type TarContext struct {
Uid int
Gid int
// ModTime sets the modtimes for all files, if nil the current time
// is used for each file when it was written
ModTime *time.Time
Xattrs map[string]string
}
func (tc TarContext) newHeader(mode os.FileMode, name, link string, size int64) *tar.Header {
ti := tarInfo{
name: name,
mode: mode,
size: size,
modt: tc.ModTime,
hdr: &tar.Header{
Uid: tc.Uid,
Gid: tc.Gid,
Xattrs: tc.Xattrs,
},
}
if mode&os.ModeSymlink == 0 && link != "" {
ti.hdr.Typeflag = tar.TypeLink
ti.hdr.Linkname = link
}
hdr, err := tar.FileInfoHeader(ti, link)
if err != nil {
// Only returns an error on bad input mode
panic(err)
}
return hdr
}
type tarInfo struct {
name string
mode os.FileMode
size int64
modt *time.Time
hdr *tar.Header
}
func (ti tarInfo) Name() string {
return ti.name
}
func (ti tarInfo) Size() int64 {
return ti.size
}
func (ti tarInfo) Mode() os.FileMode {
return ti.mode
}
func (ti tarInfo) ModTime() time.Time {
if ti.modt != nil {
return *ti.modt
}
return time.Now().UTC()
}
func (ti tarInfo) IsDir() bool {
return (ti.mode & os.ModeDir) != 0
}
func (ti tarInfo) Sys() interface{} {
return ti.hdr
}
func (tc TarContext) WithUidGid(uid, gid int) TarContext {
ntc := tc
ntc.Uid = uid
ntc.Gid = gid
return ntc
}
func (tc TarContext) WithModTime(modtime time.Time) TarContext {
ntc := tc
ntc.ModTime = &modtime
return ntc
}
// WithXattrs adds these xattrs to all files, merges with any
// previously added xattrs
func (tc TarContext) WithXattrs(xattrs map[string]string) TarContext {
ntc := tc
if ntc.Xattrs == nil {
ntc.Xattrs = map[string]string{}
}
for k, v := range xattrs {
ntc.Xattrs[k] = v
}
return ntc
}
func (tc TarContext) File(name string, content []byte, perm os.FileMode) WriterToTar {
return writerToFn(func(tw *tar.Writer) error {
return writeHeaderAndContent(tw, tc.newHeader(perm, name, "", int64(len(content))), content)
})
}
func (tc TarContext) Dir(name string, perm os.FileMode) WriterToTar {
return writerToFn(func(tw *tar.Writer) error {
return writeHeaderAndContent(tw, tc.newHeader(perm|os.ModeDir, name, "", 0), nil)
})
}
func (tc TarContext) Symlink(oldname, newname string) WriterToTar {
return writerToFn(func(tw *tar.Writer) error {
return writeHeaderAndContent(tw, tc.newHeader(0777|os.ModeSymlink, newname, oldname, 0), nil)
})
}
func (tc TarContext) Link(oldname, newname string) WriterToTar {
return writerToFn(func(tw *tar.Writer) error {
return writeHeaderAndContent(tw, tc.newHeader(0777, newname, oldname, 0), nil)
})
}
func writeHeaderAndContent(tw *tar.Writer, h *tar.Header, b []byte) error {
if h.Size != int64(len(b)) {
return errors.New("bad content length")
}
if err := tw.WriteHeader(h); err != nil {
return err
}
if len(b) > 0 {
if _, err := tw.Write(b); err != nil {
return err
}
}
return nil
}