diff --git a/archive/path.go b/archive/path.go new file mode 100644 index 000000000..fd3d0b108 --- /dev/null +++ b/archive/path.go @@ -0,0 +1,123 @@ +package archive + +import ( + "os" + "path/filepath" + "strings" + + "github.com/pkg/errors" +) + +var ( + errTooManyLinks = errors.New("too many links") +) + +// rootPath joins a path with a root, evaluating and bounding any +// symlink to the root directory. +// TODO(dmcgowan): Expose and move to fs package or continuity path driver +func rootPath(root, path string) (string, error) { + if path == "" { + return root, nil + } + var linksWalked int // to protect against cycles + for { + i := linksWalked + newpath, err := walkLinks(root, path, &linksWalked) + if err != nil { + return "", err + } + path = newpath + if i == linksWalked { + newpath = rootJoin(newpath) + if path == newpath { + return filepath.Join(root, newpath), nil + } + path = newpath + } + } +} + +// rootJoin joins a path with root, cleaning up any links that +// reference above root. +func rootJoin(path string) string { + if filepath.IsAbs(path) { + path = filepath.Clean(path) + } + + // Resolve any ".." or "/.." before joining to root + for !filepath.IsAbs(path) { + path = "/" + path + path = filepath.Clean(path) + } + + return path +} + +func walkLink(root, path string, linksWalked *int) (newpath string, islink bool, err error) { + if *linksWalked > 255 { + return "", false, errTooManyLinks + } + + path = rootJoin(path) + if path == "/" { + return path, false, nil + } + realPath := filepath.Join(root, path) + + fi, err := os.Lstat(realPath) + if err != nil { + // If path does not yet exist, treat as non-symlink + if os.IsNotExist(err) { + return path, false, nil + } + return "", false, err + } + if fi.Mode()&os.ModeSymlink == 0 { + return path, false, nil + } + newpath, err = os.Readlink(realPath) + if err != nil { + return "", false, err + } + if filepath.IsAbs(newpath) && strings.HasPrefix(newpath, root) { + newpath = newpath[:len(root)] + if !strings.HasPrefix(newpath, "/") { + newpath = "/" + newpath + } + } + *linksWalked++ + return newpath, true, nil +} + +func walkLinks(root, path string, linksWalked *int) (string, error) { + switch dir, file := filepath.Split(path); { + case dir == "": + newpath, _, err := walkLink(root, file, linksWalked) + return newpath, err + case file == "": + if os.IsPathSeparator(dir[len(dir)-1]) { + if dir == "/" { + return dir, nil + } + return walkLinks(root, dir[:len(dir)-1], linksWalked) + } + newpath, _, err := walkLink(root, dir, linksWalked) + return newpath, err + default: + newdir, err := walkLinks(root, dir, linksWalked) + if err != nil { + return "", err + } + newpath, islink, err := walkLink(root, filepath.Join(newdir, file), linksWalked) + if err != nil { + return "", err + } + if !islink { + return newpath, nil + } + if filepath.IsAbs(newpath) { + return newpath, nil + } + return filepath.Join(newdir, newpath), nil + } +} diff --git a/archive/path_test.go b/archive/path_test.go new file mode 100644 index 000000000..942ab19c0 --- /dev/null +++ b/archive/path_test.go @@ -0,0 +1,293 @@ +package archive + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/containerd/containerd/fs/fstest" + "github.com/pkg/errors" +) + +type rootCheck struct { + unresolved string + expected string + scope func(string) string + cause error +} + +func TestRootPath(t *testing.T) { + tests := []struct { + name string + apply fstest.Applier + checks []rootCheck + scope func(string) (string, error) + }{ + { + name: "SymlinkAbsolute", + apply: Symlink("/b", "fs/a/d"), + checks: Check("fs/a/d/c/data", "b/c/data"), + }, + { + name: "SymlinkRelativePath", + apply: Symlink("a", "fs/i"), + checks: Check("fs/i", "fs/a"), + }, + { + name: "SymlinkSkipSymlinksOutsideScope", + apply: Symlink("realdir", "linkdir"), + checks: CheckWithScope("foo/bar", "foo/bar", "linkdir"), + }, + { + name: "SymlinkLastLink", + apply: Symlink("/b", "fs/a/d"), + checks: Check("fs/a/d", "b"), + }, + { + name: "SymlinkRelativeLinkChangeScope", + apply: Symlink("../b", "fs/a/e"), + checks: CheckAll( + Check("fs/a/e/c/data", "fs/b/c/data"), + CheckWithScope("e", "b", "fs/a"), // Original return + ), + }, + { + name: "SymlinkDeepRelativeLinkChangeScope", + apply: Symlink("../../../../test", "fs/a/f"), + checks: CheckAll( + Check("fs/a/f", "test"), // Original return + CheckWithScope("a/f", "test", "fs"), // Original return + ), + }, + { + name: "SymlinkRelativeLinkChain", + apply: fstest.Apply( + Symlink("../g", "fs/b/h"), + fstest.Symlink("../../../../../../../../../../../../root", "fs/g"), + ), + checks: Check("fs/b/h", "root"), + }, + { + name: "SymlinkBreakoutPath", + apply: Symlink("../i/a", "fs/j/k"), + checks: CheckWithScope("k", "i/a", "fs/j"), + }, + { + name: "SymlinkToRoot", + apply: Symlink("/", "foo"), + checks: Check("foo", ""), + }, + { + name: "SymlinkSlashDotdot", + apply: Symlink("/../../", "foo"), + checks: Check("foo", ""), + }, + { + name: "SymlinkDotdot", + apply: Symlink("../../", "foo"), + checks: Check("foo", ""), + }, + { + name: "SymlinkRelativePath2", + apply: Symlink("baz/target", "bar/foo"), + checks: Check("bar/foo", "bar/baz/target"), + }, + { + name: "SymlinkScopeLink", + apply: fstest.Apply( + Symlink("root2", "root"), + Symlink("../bar", "root2/foo"), + ), + checks: CheckWithScope("foo", "bar", "root"), + }, + { + name: "SymlinkSelf", + apply: fstest.Apply( + Symlink("foo", "root/foo"), + ), + checks: ErrorWithScope("foo", "root", errTooManyLinks), + }, + { + name: "SymlinkCircular", + apply: fstest.Apply( + Symlink("foo", "bar"), + Symlink("bar", "foo"), + ), + checks: ErrorWithScope("foo", "", errTooManyLinks), //TODO: Test for circular error + }, + { + name: "SymlinkCircularUnderRoot", + apply: fstest.Apply( + Symlink("baz", "root/bar"), + Symlink("../bak", "root/baz"), + Symlink("/bar", "root/bak"), + ), + checks: ErrorWithScope("bar", "root", errTooManyLinks), // TODO: Test for circular error + }, + { + name: "SymlinkComplexChain", + apply: fstest.Apply( + fstest.CreateDir("root2", 0777), + Symlink("root2", "root"), + Symlink("r/s", "root/a"), + Symlink("../root/t", "root/r"), + Symlink("/../u", "root/root/t/s/b"), + Symlink(".", "root/u/c"), + Symlink("../v", "root/u/x/y"), + Symlink("/../w", "root/u/v"), + ), + checks: CheckWithScope("a/b/c/x/y/z", "w/z", "root"), // Original return + }, + { + name: "SymlinkBreakoutNonExistent", + apply: fstest.Apply( + Symlink("/", "root/slash"), + Symlink("/idontexist/../slash", "root/sym"), + ), + checks: CheckWithScope("sym/file", "file", "root"), + }, + { + name: "SymlinkNoLexicalCleaning", + apply: fstest.Apply( + Symlink("/foo/bar", "root/sym"), + Symlink("/sym/../baz", "root/hello"), + ), + checks: CheckWithScope("hello", "foo/baz", "root"), + }, + } + + for _, test := range tests { + t.Run(test.name, makeRootPathTest(t, test.apply, test.checks)) + } + + // Add related tests which are unable to follow same pattern + t.Run("SymlinkRootScope", testRootPathSymlinkRootScope) + t.Run("SymlinkEmpty", testRootPathSymlinkEmpty) +} + +func testRootPathSymlinkRootScope(t *testing.T) { + tmpdir, err := ioutil.TempDir("", "TestFollowSymlinkRootScope") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpdir) + + expected, err := filepath.EvalSymlinks(tmpdir) + if err != nil { + t.Fatal(err) + } + rewrite, err := rootPath("/", tmpdir) + if err != nil { + t.Fatal(err) + } + if rewrite != expected { + t.Fatalf("expected %q got %q", expected, rewrite) + } +} +func testRootPathSymlinkEmpty(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + res, err := rootPath(wd, "") + if err != nil { + t.Fatal(err) + } + if res != wd { + t.Fatalf("expected %q got %q", wd, res) + } +} + +func makeRootPathTest(t *testing.T, apply fstest.Applier, checks []rootCheck) func(t *testing.T) { + return func(t *testing.T) { + applyDir, err := ioutil.TempDir("", "test-root-path-") + if err != nil { + t.Fatalf("Unable to make temp directory: %+v", err) + } + defer os.RemoveAll(applyDir) + + if apply != nil { + if err := apply.Apply(applyDir); err != nil { + t.Fatalf("Apply failed: %+v", err) + } + } + + for i, check := range checks { + root := applyDir + if check.scope != nil { + root = check.scope(root) + } + + actual, err := rootPath(root, check.unresolved) + if check.cause != nil { + if err == nil { + t.Errorf("(Check %d) Expected error %q, %q evaluated as %q", i+1, check.cause.Error(), check.unresolved, actual) + } + if errors.Cause(err) != check.cause { + t.Fatalf("(Check %d) Failed to evaluate root path: %+v", i+1, err) + } + } else { + expected := filepath.Join(root, check.expected) + if err != nil { + t.Fatalf("(Check %d) Failed to evaluate root path: %+v", i+1, err) + } + if actual != expected { + t.Errorf("(Check %d) Unexpected evaluated path %q, expected %q", i+1, actual, expected) + } + } + } + } +} + +func Check(unresolved, expected string) []rootCheck { + return []rootCheck{ + { + unresolved: unresolved, + expected: expected, + }, + } +} + +func CheckWithScope(unresolved, expected, scope string) []rootCheck { + return []rootCheck{ + { + unresolved: unresolved, + expected: expected, + scope: func(root string) string { + return filepath.Join(root, scope) + }, + }, + } +} + +func ErrorWithScope(unresolved, scope string, cause error) []rootCheck { + return []rootCheck{ + { + unresolved: unresolved, + cause: cause, + scope: func(root string) string { + return filepath.Join(root, scope) + }, + }, + } +} + +func CheckAll(checks ...[]rootCheck) []rootCheck { + all := make([]rootCheck, 0, len(checks)) + for _, c := range checks { + all = append(all, c...) + } + return all +} + +func Symlink(oldname, newname string) fstest.Applier { + dir := filepath.Dir(newname) + if dir != "" { + return fstest.Apply( + fstest.CreateDir(dir, 0755), + fstest.Symlink(oldname, newname), + ) + } + return fstest.Symlink(oldname, newname) +} diff --git a/archive/tar.go b/archive/tar.go index 7e9705891..d88ce8ec9 100644 --- a/archive/tar.go +++ b/archive/tar.go @@ -25,8 +25,6 @@ var ( return make([]byte, 32*1024) }, } - - breakoutError = errors.New("file name outside of root") ) // Diff returns a tar stream of the computed filesystem @@ -168,16 +166,11 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) { } } - path := filepath.Join(root, hdr.Name) - rel, err := filepath.Rel(root, path) + path, err := rootPath(root, hdr.Name) if err != nil { - return 0, err + return 0, errors.Wrap(err, "failed to get root path") } - // Note as these operations are platform specific, so must the slash be. - if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { - return 0, errors.Wrapf(breakoutError, "%q is outside of %q", hdr.Name, root) - } base := filepath.Base(path) if strings.HasPrefix(base, whiteoutPrefix) { @@ -467,25 +460,15 @@ func createTarFile(ctx context.Context, path, extractDir string, hdr *tar.Header } case tar.TypeLink: - targetPath := filepath.Join(extractDir, hdr.Linkname) - // check for hardlink breakout - if !strings.HasPrefix(targetPath, extractDir) { - return errors.Wrapf(breakoutError, "invalid hardlink %q -> %q", targetPath, hdr.Linkname) + targetPath, err := rootPath(extractDir, hdr.Linkname) + if err != nil { + return err } if err := os.Link(targetPath, path); err != nil { return err } case tar.TypeSymlink: - // path -> hdr.Linkname = targetPath - // e.g. /extractDir/path/to/symlink -> ../2/file = /extractDir/path/2/file - targetPath := filepath.Join(filepath.Dir(path), hdr.Linkname) - - // the reason we don't need to check symlinks in the path (with FollowSymlinkInScope) is because - // that symlink would first have to be created, which would be caught earlier, at this very check: - if !strings.HasPrefix(targetPath, extractDir) { - return errors.Wrapf(breakoutError, "invalid symlink %q -> %q", path, hdr.Linkname) - } if err := os.Symlink(hdr.Linkname, path); err != nil { return err } diff --git a/archive/tar_test.go b/archive/tar_test.go index 33b83b063..02e0109df 100644 --- a/archive/tar_test.go +++ b/archive/tar_test.go @@ -1,12 +1,16 @@ package archive import ( + "archive/tar" "bytes" "context" + "io" "io/ioutil" "os" "os/exec" + "path/filepath" "testing" + "time" _ "crypto/sha256" @@ -46,6 +50,238 @@ func TestBaseDiff(t *testing.T) { } } +func TestRelativeSymlinks(t *testing.T) { + breakoutLinks := []fstest.Applier{ + fstest.Apply( + baseApplier, + fstest.Symlink("../other", "/home/derek/other"), + fstest.Symlink("../../etc", "/home/derek/etc"), + fstest.Symlink("up/../../other", "/home/derek/updown"), + ), + fstest.Apply( + baseApplier, + fstest.Symlink("../../../breakout", "/home/derek/breakout"), + ), + fstest.Apply( + baseApplier, + fstest.Symlink("../../breakout", "/breakout"), + ), + fstest.Apply( + baseApplier, + fstest.Symlink("etc/../../upandout", "/breakout"), + ), + fstest.Apply( + baseApplier, + fstest.Symlink("derek/../../../downandout", "/home/breakout"), + ), + fstest.Apply( + baseApplier, + fstest.Symlink("/etc", "localetc"), + ), + } + + for _, bo := range breakoutLinks { + if err := testDiffApply(bo); err != nil { + t.Fatalf("Test apply failed: %+v", err) + } + } +} + +func TestBreakouts(t *testing.T) { + tc := TarContext{}.WithUidGid(os.Getuid(), os.Getgid()).WithModTime(time.Now().UTC()) + expected := "unbroken" + unbrokenCheck := func(root string) error { + b, err := ioutil.ReadFile(filepath.Join(root, "etc", "unbroken")) + if err != nil { + return errors.Wrap(err, "failed to read unbroken") + } + if string(b) != expected { + return errors.Errorf("/etc/unbroken: unexpected value %s, expected %s", b, expected) + } + return nil + } + sameFile := func(f1, f2 string) func(string) error { + return func(root string) error { + p1 := filepath.Join(root, f1) + p2 := filepath.Join(root, f2) + s1, err := os.Stat(p1) + if err != nil { + return err + } + s2, err := os.Stat(p2) + if err != nil { + return err + } + if !os.SameFile(s1, s2) { + return errors.Errorf("files differ: %#v and %#v", s1, s2) + } + return nil + } + } + + breakouts := []struct { + name string + w WriterToTar + validator func(string) error + }{ + { + name: "SymlinkAbsolute", + w: TarAll( + tc.Dir("etc", 0755), + tc.Symlink("/etc", "localetc"), + tc.File("/localetc/unbroken", []byte(expected), 0644), + ), + validator: unbrokenCheck, + }, + { + name: "SymlinkUpAndOut", + w: TarAll( + tc.Dir("etc", 0755), + tc.Dir("dummy", 0755), + tc.Symlink("/dummy/../etc", "localetc"), + tc.File("/localetc/unbroken", []byte(expected), 0644), + ), + validator: unbrokenCheck, + }, + { + name: "SymlinkMultipleAbsolute", + w: TarAll( + tc.Dir("etc", 0755), + tc.Dir("dummy", 0755), + tc.Symlink("/etc", "/dummy/etc"), + tc.Symlink("/dummy/etc", "localetc"), + tc.File("/dummy/etc/unbroken", []byte(expected), 0644), + ), + validator: unbrokenCheck, + }, + { + name: "SymlinkMultipleRelative", + w: TarAll( + tc.Dir("etc", 0755), + tc.Dir("dummy", 0755), + tc.Symlink("/etc", "/dummy/etc"), + tc.Symlink("./dummy/etc", "localetc"), + tc.File("/dummy/etc/unbroken", []byte(expected), 0644), + ), + validator: unbrokenCheck, + }, + { + name: "SymlinkEmptyFile", + w: TarAll( + tc.Dir("etc", 0755), + tc.File("etc/emptied", []byte("notempty"), 0644), + tc.Symlink("/etc", "localetc"), + tc.File("/localetc/emptied", []byte{}, 0644), + ), + validator: func(root string) error { + b, err := ioutil.ReadFile(filepath.Join(root, "etc", "emptied")) + if err != nil { + return errors.Wrap(err, "failed to read unbroken") + } + if len(b) > 0 { + return errors.Errorf("/etc/emptied: non-empty") + } + return nil + }, + }, + { + name: "HardlinkRelative", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Dir("breakouts", 0755), + tc.Symlink("../../etc", "breakouts/d1"), + tc.Link("/breakouts/d1/passwd", "breakouts/mypasswd"), + ), + validator: sameFile("/breakouts/mypasswd", "/etc/passwd"), + }, + { + name: "HardlinkDownAndOut", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Dir("breakouts", 0755), + tc.Dir("downandout", 0755), + tc.Symlink("../downandout/../../etc", "breakouts/d1"), + tc.Link("/breakouts/d1/passwd", "breakouts/mypasswd"), + ), + validator: sameFile("/breakouts/mypasswd", "/etc/passwd"), + }, + { + name: "HardlinkAbsolute", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Symlink("/etc", "localetc"), + tc.Link("/localetc/passwd", "localpasswd"), + ), + validator: sameFile("localpasswd", "/etc/passwd"), + }, + { + name: "HardlinkRelativeLong", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Symlink("../../../../../../../etc", "localetc"), + tc.Link("/localetc/passwd", "localpasswd"), + ), + validator: sameFile("localpasswd", "/etc/passwd"), + }, + { + name: "HardlinkRelativeUpAndOut", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Symlink("upandout/../../../etc", "localetc"), + tc.Link("/localetc/passwd", "localpasswd"), + ), + validator: sameFile("localpasswd", "/etc/passwd"), + }, + { + name: "HardlinkDirectRelative", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Link("../../../../../etc/passwd", "localpasswd"), + ), + validator: sameFile("localpasswd", "/etc/passwd"), + }, + { + name: "HardlinkDirectAbsolute", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Link("/etc/passwd", "localpasswd"), + ), + validator: sameFile("localpasswd", "/etc/passwd"), + }, + { + name: "HardlinkSymlinkRelative", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Symlink("../../../../../etc/passwd", "passwdlink"), + tc.Link("/passwdlink", "localpasswd"), + ), + validator: sameFile("/localpasswd", "/etc/passwd"), + }, + { + name: "HardlinkSymlinkAbsolute", + w: TarAll( + tc.Dir("etc", 0770), + tc.File("/etc/passwd", []byte("inside"), 0644), + tc.Symlink("/etc/passwd", "passwdlink"), + tc.Link("/passwdlink", "localpasswd"), + ), + validator: sameFile("/localpasswd", "/etc/passwd"), + }, + } + + for _, bo := range breakouts { + t.Run(bo.name, makeWriterToTarTest(bo.w, bo.validator)) + } +} + func TestDiffApply(t *testing.T) { fstest.FSSuite(t, diffApplier{}) } @@ -118,6 +354,58 @@ func testBaseDiff(a fstest.Applier) error { return fstest.CheckDirectoryEqual(td, dest) } +func testDiffApply(a fstest.Applier) error { + td, err := ioutil.TempDir("", "test-diff-apply-") + if err != nil { + return errors.Wrap(err, "failed to create temp dir") + } + defer os.RemoveAll(td) + dest, err := ioutil.TempDir("", "test-diff-apply-dest-") + if err != nil { + return errors.Wrap(err, "failed to create temp dir") + } + defer os.RemoveAll(dest) + + if err := a.Apply(td); err != nil { + return errors.Wrap(err, "failed to apply filesystem changes") + } + + diffBytes, err := ioutil.ReadAll(Diff(context.Background(), "", td)) + if err != nil { + return errors.Wrap(err, "failed to create diff") + } + + if _, err := Apply(context.Background(), dest, bytes.NewReader(diffBytes)); err != nil { + return errors.Wrap(err, "failed to apply tar stream") + } + + return fstest.CheckDirectoryEqual(td, dest) +} + +func makeWriterToTarTest(wt WriterToTar, validate func(string) error) func(*testing.T) { + return func(t *testing.T) { + td, err := ioutil.TempDir("", "test-writer-to-tar-") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(td) + + tr := TarFromWriterTo(wt) + + if _, err := Apply(context.Background(), td, tr); err != nil { + t.Fatalf("Failed to apply tar: %v", err) + } + + if validate != nil { + if err := validate(td); err != nil { + t.Errorf("Validation failed: %v", err) + } + + } + + } +} + type diffApplier struct{} func (d diffApplier) TestContext(ctx context.Context) (context.Context, func(), error) { @@ -174,3 +462,182 @@ func requireTar(t *testing.T) { t.Skipf("%s not found, skipping", tarCmd) } } + +// WriterToTar is an type which writes to a tar writer +type WriterToTar interface { + WriteTo(*tar.Writer) error +} + +type writerToFn func(*tar.Writer) error + +func (w writerToFn) WriteTo(tw *tar.Writer) error { + return w(tw) +} + +// TarAll creates a WriterToTar which calls all the provided writers +// in the order in which they are provided. +func TarAll(wt ...WriterToTar) WriterToTar { + return writerToFn(func(tw *tar.Writer) error { + for _, w := range wt { + if err := w.WriteTo(tw); err != nil { + return err + } + } + return nil + }) +} + +// TarFromWriterTo is used to create a tar stream from a tar record +// creator. This can be used to manifacture more specific tar records +// which allow testing specific tar cases which may be encountered +// by the untar process. +func TarFromWriterTo(wt WriterToTar) io.ReadCloser { + r, w := io.Pipe() + go func() { + tw := tar.NewWriter(w) + if err := wt.WriteTo(tw); err != nil { + w.CloseWithError(err) + return + } + w.CloseWithError(tw.Close()) + }() + + return r +} + +// TarContext is used to create tar records +type TarContext struct { + Uid int + Gid int + + // ModTime sets the modtimes for all files, if nil the current time + // is used for each file when it was written + ModTime *time.Time + + Xattrs map[string]string +} + +func (tc TarContext) newHeader(mode os.FileMode, name, link string, size int64) *tar.Header { + ti := tarInfo{ + name: name, + mode: mode, + size: size, + modt: tc.ModTime, + hdr: &tar.Header{ + Uid: tc.Uid, + Gid: tc.Gid, + Xattrs: tc.Xattrs, + }, + } + + if mode&os.ModeSymlink == 0 && link != "" { + ti.hdr.Typeflag = tar.TypeLink + ti.hdr.Linkname = link + } + + hdr, err := tar.FileInfoHeader(ti, link) + if err != nil { + // Only returns an error on bad input mode + panic(err) + } + + return hdr +} + +type tarInfo struct { + name string + mode os.FileMode + size int64 + modt *time.Time + hdr *tar.Header +} + +func (ti tarInfo) Name() string { + return ti.name +} + +func (ti tarInfo) Size() int64 { + return ti.size +} +func (ti tarInfo) Mode() os.FileMode { + return ti.mode +} + +func (ti tarInfo) ModTime() time.Time { + if ti.modt != nil { + return *ti.modt + } + return time.Now().UTC() +} + +func (ti tarInfo) IsDir() bool { + return (ti.mode & os.ModeDir) != 0 +} +func (ti tarInfo) Sys() interface{} { + return ti.hdr +} + +func (tc TarContext) WithUidGid(uid, gid int) TarContext { + ntc := tc + ntc.Uid = uid + ntc.Gid = gid + return ntc +} + +func (tc TarContext) WithModTime(modtime time.Time) TarContext { + ntc := tc + ntc.ModTime = &modtime + return ntc +} + +// WithXattrs adds these xattrs to all files, merges with any +// previously added xattrs +func (tc TarContext) WithXattrs(xattrs map[string]string) TarContext { + ntc := tc + if ntc.Xattrs == nil { + ntc.Xattrs = map[string]string{} + } + for k, v := range xattrs { + ntc.Xattrs[k] = v + } + return ntc +} + +func (tc TarContext) File(name string, content []byte, perm os.FileMode) WriterToTar { + return writerToFn(func(tw *tar.Writer) error { + return writeHeaderAndContent(tw, tc.newHeader(perm, name, "", int64(len(content))), content) + }) +} + +func (tc TarContext) Dir(name string, perm os.FileMode) WriterToTar { + return writerToFn(func(tw *tar.Writer) error { + return writeHeaderAndContent(tw, tc.newHeader(perm|os.ModeDir, name, "", 0), nil) + }) +} + +func (tc TarContext) Symlink(oldname, newname string) WriterToTar { + return writerToFn(func(tw *tar.Writer) error { + return writeHeaderAndContent(tw, tc.newHeader(0777|os.ModeSymlink, newname, oldname, 0), nil) + }) +} + +func (tc TarContext) Link(oldname, newname string) WriterToTar { + return writerToFn(func(tw *tar.Writer) error { + return writeHeaderAndContent(tw, tc.newHeader(0777, newname, oldname, 0), nil) + }) +} + +func writeHeaderAndContent(tw *tar.Writer, h *tar.Header, b []byte) error { + if h.Size != int64(len(b)) { + return errors.New("bad content length") + } + if err := tw.WriteHeader(h); err != nil { + return err + } + if len(b) > 0 { + if _, err := tw.Write(b); err != nil { + return err + } + } + return nil +}