diff --git a/archive/path.go b/archive/path.go index 32374459c..0f6cfa32e 100644 --- a/archive/path.go +++ b/archive/path.go @@ -1,107 +1 @@ package archive - -import ( - "os" - "path/filepath" - "strings" - - "github.com/pkg/errors" -) - -var ( - errTooManyLinks = errors.New("too many links") -) - -// rootPath joins a path with a root, evaluating and bounding any -// symlink to the root directory. -// TODO(dmcgowan): Expose and move to fs package or continuity path driver -func rootPath(root, path string) (string, error) { - if path == "" { - return root, nil - } - var linksWalked int // to protect against cycles - for { - i := linksWalked - newpath, err := walkLinks(root, path, &linksWalked) - if err != nil { - return "", err - } - path = newpath - if i == linksWalked { - newpath = filepath.Join("/", newpath) - if path == newpath { - return filepath.Join(root, newpath), nil - } - path = newpath - } - } -} - -func walkLink(root, path string, linksWalked *int) (newpath string, islink bool, err error) { - if *linksWalked > 255 { - return "", false, errTooManyLinks - } - - path = filepath.Join("/", path) - if path == "/" { - return path, false, nil - } - realPath := filepath.Join(root, path) - - fi, err := os.Lstat(realPath) - if err != nil { - // If path does not yet exist, treat as non-symlink - if os.IsNotExist(err) { - return path, false, nil - } - return "", false, err - } - if fi.Mode()&os.ModeSymlink == 0 { - return path, false, nil - } - newpath, err = os.Readlink(realPath) - if err != nil { - return "", false, err - } - if filepath.IsAbs(newpath) && strings.HasPrefix(newpath, root) { - newpath = newpath[:len(root)] - if !strings.HasPrefix(newpath, "/") { - newpath = "/" + newpath - } - } - *linksWalked++ - return newpath, true, nil -} - -func walkLinks(root, path string, linksWalked *int) (string, error) { - switch dir, file := filepath.Split(path); { - case dir == "": - newpath, _, err := walkLink(root, file, linksWalked) - return newpath, err - case file == "": - if os.IsPathSeparator(dir[len(dir)-1]) { - if dir == "/" { - return dir, nil - } - return walkLinks(root, dir[:len(dir)-1], linksWalked) - } - newpath, _, err := walkLink(root, dir, linksWalked) - return newpath, err - default: - newdir, err := walkLinks(root, dir, linksWalked) - if err != nil { - return "", err - } - newpath, islink, err := walkLink(root, filepath.Join(newdir, file), linksWalked) - if err != nil { - return "", err - } - if !islink { - return newpath, nil - } - if filepath.IsAbs(newpath) { - return newpath, nil - } - return filepath.Join(newdir, newpath), nil - } -} diff --git a/archive/tar.go b/archive/tar.go index b29699e10..843234c0a 100644 --- a/archive/tar.go +++ b/archive/tar.go @@ -128,7 +128,7 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) { // Split name and resolve symlinks for root directory. ppath, base := filepath.Split(hdr.Name) - ppath, err = rootPath(root, ppath) + ppath, err = fs.RootPath(root, ppath) if err != nil { return 0, errors.Wrap(err, "failed to get root path") } @@ -170,7 +170,7 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) { } defer os.RemoveAll(aufsTempdir) } - p, err := rootPath(aufsTempdir, basename) + p, err := fs.RootPath(aufsTempdir, basename) if err != nil { return 0, err } @@ -243,7 +243,7 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) { if srcHdr == nil { return 0, fmt.Errorf("Invalid aufs hardlink") } - p, err := rootPath(aufsTempdir, linkBasename) + p, err := fs.RootPath(aufsTempdir, linkBasename) if err != nil { return 0, err } @@ -268,7 +268,7 @@ func Apply(ctx context.Context, root string, r io.Reader) (int64, error) { } for _, hdr := range dirs { - path, err := rootPath(root, hdr.Name) + path, err := fs.RootPath(root, hdr.Name) if err != nil { return 0, err } @@ -478,7 +478,7 @@ func createTarFile(ctx context.Context, path, extractDir string, hdr *tar.Header } case tar.TypeLink: - targetPath, err := rootPath(extractDir, hdr.Linkname) + targetPath, err := fs.RootPath(extractDir, hdr.Linkname) if err != nil { return err } diff --git a/archive/tar_test.go b/archive/tar_test.go index 68cfeebee..67dc8a02a 100644 --- a/archive/tar_test.go +++ b/archive/tar_test.go @@ -182,11 +182,11 @@ func TestBreakouts(t *testing.T) { errFileDiff := errors.New("files differ") sameFile := func(f1, f2 string) func(string) error { return func(root string) error { - p1, err := rootPath(root, f1) + p1, err := fs.RootPath(root, f1) if err != nil { return err } - p2, err := rootPath(root, f2) + p2, err := fs.RootPath(root, f2) if err != nil { return err } @@ -484,7 +484,7 @@ func TestApplyTar(t *testing.T) { directoriesExist := func(dirs ...string) func(string) error { return func(root string) error { for _, d := range dirs { - p, err := rootPath(root, d) + p, err := fs.RootPath(root, d) if err != nil { return err } diff --git a/fs/path.go b/fs/path.go index a46d0fcbd..644b1ee2e 100644 --- a/fs/path.go +++ b/fs/path.go @@ -7,6 +7,12 @@ import ( "os" "path/filepath" "strings" + + "github.com/pkg/errors" +) + +var ( + errTooManyLinks = errors.New("too many links") ) type currentPath struct { @@ -160,3 +166,96 @@ func nextPath(ctx context.Context, pathC <-chan *currentPath) (*currentPath, err return p, nil } } + +// RootPath joins a path with a root, evaluating and bounding any +// symlink to the root directory. +func RootPath(root, path string) (string, error) { + if path == "" { + return root, nil + } + var linksWalked int // to protect against cycles + for { + i := linksWalked + newpath, err := walkLinks(root, path, &linksWalked) + if err != nil { + return "", err + } + path = newpath + if i == linksWalked { + newpath = filepath.Join("/", newpath) + if path == newpath { + return filepath.Join(root, newpath), nil + } + path = newpath + } + } +} + +func walkLink(root, path string, linksWalked *int) (newpath string, islink bool, err error) { + if *linksWalked > 255 { + return "", false, errTooManyLinks + } + + path = filepath.Join("/", path) + if path == "/" { + return path, false, nil + } + realPath := filepath.Join(root, path) + + fi, err := os.Lstat(realPath) + if err != nil { + // If path does not yet exist, treat as non-symlink + if os.IsNotExist(err) { + return path, false, nil + } + return "", false, err + } + if fi.Mode()&os.ModeSymlink == 0 { + return path, false, nil + } + newpath, err = os.Readlink(realPath) + if err != nil { + return "", false, err + } + if filepath.IsAbs(newpath) && strings.HasPrefix(newpath, root) { + newpath = newpath[:len(root)] + if !strings.HasPrefix(newpath, "/") { + newpath = "/" + newpath + } + } + *linksWalked++ + return newpath, true, nil +} + +func walkLinks(root, path string, linksWalked *int) (string, error) { + switch dir, file := filepath.Split(path); { + case dir == "": + newpath, _, err := walkLink(root, file, linksWalked) + return newpath, err + case file == "": + if os.IsPathSeparator(dir[len(dir)-1]) { + if dir == "/" { + return dir, nil + } + return walkLinks(root, dir[:len(dir)-1], linksWalked) + } + newpath, _, err := walkLink(root, dir, linksWalked) + return newpath, err + default: + newdir, err := walkLinks(root, dir, linksWalked) + if err != nil { + return "", err + } + newpath, islink, err := walkLink(root, filepath.Join(newdir, file), linksWalked) + if err != nil { + return "", err + } + if !islink { + return newpath, nil + } + if filepath.IsAbs(newpath) { + return newpath, nil + } + return filepath.Join(newdir, newpath), nil + } +} diff --git a/archive/path_test.go b/fs/path_test.go similarity index 98% rename from archive/path_test.go rename to fs/path_test.go index 6cb83ef62..b09090800 100644 --- a/archive/path_test.go +++ b/fs/path_test.go @@ -1,6 +1,6 @@ // +build !windows -package archive +package fs import ( "io/ioutil" @@ -179,7 +179,7 @@ func testRootPathSymlinkRootScope(t *testing.T) { if err != nil { t.Fatal(err) } - rewrite, err := rootPath("/", tmpdir) + rewrite, err := RootPath("/", tmpdir) if err != nil { t.Fatal(err) } @@ -192,7 +192,7 @@ func testRootPathSymlinkEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - res, err := rootPath(wd, "") + res, err := RootPath(wd, "") if err != nil { t.Fatal(err) } @@ -221,7 +221,7 @@ func makeRootPathTest(t *testing.T, apply fstest.Applier, checks []rootCheck) fu root = check.scope(root) } - actual, err := rootPath(root, check.unresolved) + actual, err := RootPath(root, check.unresolved) if check.cause != nil { if err == nil { t.Errorf("(Check %d) Expected error %q, %q evaluated as %q", i+1, check.cause.Error(), check.unresolved, actual)