diff --git a/archive/tar.go b/archive/tar.go index 7d7702bee..6246691aa 100644 --- a/archive/tar.go +++ b/archive/tar.go @@ -366,10 +366,11 @@ func createTarFile(ctx context.Context, path, extractDir string, hdr *tar.Header } case tar.TypeLink: - targetPath, err := fs.RootPath(extractDir, hdr.Linkname) + targetPath, err := hardlinkRootPath(extractDir, hdr.Linkname) if err != nil { return err } + if err := os.Link(targetPath, path); err != nil { return err } @@ -648,3 +649,27 @@ func copyBuffered(ctx context.Context, dst io.Writer, src io.Reader) (written in return written, err } + +// hardlinkRootPath returns target linkname, evaluating and bounding any +// symlink to the parent directory. +// +// NOTE: Allow hardlink to the softlink, not the real one. For example, +// +// touch /tmp/zzz +// ln -s /tmp/zzz /tmp/xxx +// ln /tmp/xxx /tmp/yyy +// +// /tmp/yyy should be softlink which be same of /tmp/xxx, not /tmp/zzz. +func hardlinkRootPath(root, linkname string) (string, error) { + ppath, base := filepath.Split(linkname) + ppath, err := fs.RootPath(root, ppath) + if err != nil { + return "", err + } + + targetPath := filepath.Join(ppath, base) + if !strings.HasPrefix(targetPath, root) { + targetPath = root + } + return targetPath, nil +} diff --git a/archive/tar_test.go b/archive/tar_test.go index a7391adae..7486d3793 100644 --- a/archive/tar_test.go +++ b/archive/tar_test.go @@ -196,6 +196,49 @@ func TestBreakouts(t *testing.T) { return nil } errFileDiff := errors.New("files differ") + + isSymlinkFile := func(f string) func(string) error { + return func(root string) error { + fi, err := os.Lstat(filepath.Join(root, f)) + if err != nil { + return err + } + + if got := fi.Mode() & os.ModeSymlink; got != os.ModeSymlink { + return errors.Errorf("%s should be symlink", fi.Name()) + } + return nil + } + } + + sameSymlinkFile := func(f1, f2 string) func(string) error { + checkF1, checkF2 := isSymlinkFile(f1), isSymlinkFile(f2) + return func(root string) error { + if err := checkF1(root); err != nil { + return err + } + + if err := checkF2(root); err != nil { + return err + } + + t1, err := os.Readlink(filepath.Join(root, f1)) + if err != nil { + return err + } + + t2, err := os.Readlink(filepath.Join(root, f2)) + if err != nil { + return err + } + + if t1 != t2 { + return errors.Wrapf(errFileDiff, "%#v and %#v", t1, t2) + } + return nil + } + } + sameFile := func(f1, f2 string) func(string) error { return func(root string) error { p1, err := fs.RootPath(root, f1) @@ -406,6 +449,16 @@ func TestBreakouts(t *testing.T) { ), validator: sameFile("localpasswd", "/etc/passwd"), }, + { + name: "HardlinkSymlinkBeforeCreateTarget", + w: TarAll( + tc.Dir("etc", 0770), + tc.Symlink("/etc/passwd", "localpasswd"), + tc.Link("localpasswd", "localpasswd-dup"), + tc.File("/etc/passwd", []byte("after"), 0644), + ), + validator: sameFile("localpasswd-dup", "/etc/passwd"), + }, { name: "HardlinkSymlinkRelative", w: TarAll( @@ -414,7 +467,10 @@ func TestBreakouts(t *testing.T) { tc.Symlink("../../../../../etc/passwd", "passwdlink"), tc.Link("/passwdlink", "localpasswd"), ), - validator: sameFile("/localpasswd", "/etc/passwd"), + validator: all( + sameSymlinkFile("/localpasswd", "/passwdlink"), + sameFile("/localpasswd", "/etc/passwd"), + ), }, { name: "HardlinkSymlinkAbsolute", @@ -424,7 +480,10 @@ func TestBreakouts(t *testing.T) { tc.Symlink("/etc/passwd", "passwdlink"), tc.Link("/passwdlink", "localpasswd"), ), - validator: sameFile("/localpasswd", "/etc/passwd"), + validator: all( + sameSymlinkFile("/localpasswd", "/passwdlink"), + sameFile("/localpasswd", "/etc/passwd"), + ), }, { name: "SymlinkParentDirectory",