diff --git a/archive/tar.go b/archive/tar.go index 99cfc35fc..0e32d7105 100644 --- a/archive/tar.go +++ b/archive/tar.go @@ -26,6 +26,8 @@ var bufferPool = &sync.Pool{ }, } +var errInvalidArchive = errors.New("invalid archive") + // Diff returns a tar stream of the computed filesystem // difference between the provided directories. // @@ -231,6 +233,15 @@ func applyNaive(ctx context.Context, root string, tr *tar.Reader, options ApplyO originalBase := base[len(whiteoutPrefix):] originalPath := filepath.Join(dir, originalBase) + + // Ensure originalPath is under dir + if dir[len(dir)-1] != filepath.Separator { + dir += string(filepath.Separator) + } + if !strings.HasPrefix(originalPath, dir) { + return 0, errors.Wrapf(errInvalidArchive, "invalid whiteout name: %v", base) + } + if err := os.RemoveAll(originalPath); err != nil { return 0, err } diff --git a/archive/tar_test.go b/archive/tar_test.go index d5e10a6c6..d9c1dda65 100644 --- a/archive/tar_test.go +++ b/archive/tar_test.go @@ -217,12 +217,47 @@ func TestBreakouts(t *testing.T) { return nil } } + fileValue := func(f1 string, content []byte) func(string) error { + return func(root string) error { + b, err := ioutil.ReadFile(filepath.Join(root, f1)) + if err != nil { + return err + } + if bytes.Compare(b, content) != 0 { + return errors.Errorf("content differs: expected %v, got %v", content, b) + } + return nil + } + } + fileNotExists := func(f1 string) func(string) error { + return func(root string) error { + _, err := os.Lstat(filepath.Join(root, f1)) + if err == nil { + return errors.New("file exists") + } else if !os.IsNotExist(err) { + return err + } + return nil + } + + } + all := func(funcs ...func(string) error) func(string) error { + return func(root string) error { + for _, f := range funcs { + if err := f(root); err != nil { + return err + } + } + return nil + } + } breakouts := []struct { name string w WriterToTar apply fstest.Applier validator func(string) error + err error }{ { name: "SymlinkAbsolute", @@ -468,10 +503,129 @@ func TestBreakouts(t *testing.T) { ), validator: notSameFile("/localetc/localpasswd", "/etc/passwd"), }, + { + name: "WhiteoutRootParent", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("inside"), 0644), + ), + w: TarAll( + tc.File(".wh...", []byte{}, 0644), // Should wipe out whole directory + ), + err: errInvalidArchive, + }, + { + name: "WhiteoutParent", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("inside"), 0644), + ), + w: TarAll( + tc.File("etc/.wh...", []byte{}, 0644), + ), + err: errInvalidArchive, + }, + { + name: "WhiteoutRoot", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("inside"), 0644), + ), + w: TarAll( + tc.File(".wh..", []byte{}, 0644), + ), + err: errInvalidArchive, + }, + { + name: "WhiteoutCurrentDirectory", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("inside"), 0644), + ), + w: TarAll( + tc.File("etc/.wh..", []byte{}, 0644), // Should wipe out whole directory + ), + err: errInvalidArchive, + }, + { + name: "WhiteoutSymlink", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("all users"), 0644), + fstest.Symlink("/etc", "localetc"), + ), + w: TarAll( + tc.File(".wh.localetc", []byte{}, 0644), // Should wipe out whole directory + ), + validator: all( + fileValue("etc/passwd", []byte("all users")), + fileNotExists("localetc"), + ), + }, + { + // TODO: This test should change once archive apply is disallowing + // symlinks as parents in the name + name: "WhiteoutSymlinkPath", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("all users"), 0644), + fstest.CreateFile("/etc/whitedout", []byte("ahhhh whiteout"), 0644), + fstest.Symlink("/etc", "localetc"), + ), + w: TarAll( + tc.File("localetc/.wh.whitedout", []byte{}, 0644), + ), + validator: all( + fileValue("etc/passwd", []byte("all users")), + fileNotExists("etc/whitedout"), + ), + }, + { + name: "WhiteoutDirectoryName", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("all users"), 0644), + fstest.CreateFile("/etc/whitedout", []byte("ahhhh whiteout"), 0644), + fstest.Symlink("/etc", "localetc"), + ), + w: TarAll( + tc.File(".wh.etc/somefile", []byte("non-empty"), 0644), + ), + validator: all( + fileValue("etc/passwd", []byte("all users")), + fileValue(".wh.etc/somefile", []byte("non-empty")), + ), + }, + { + name: "WhiteoutDeadSymlinkParent", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("all users"), 0644), + fstest.Symlink("/dne", "localetc"), + ), + w: TarAll( + tc.File("localetc/.wh.etc", []byte{}, 0644), + ), + // no-op, remove does not + validator: fileValue("etc/passwd", []byte("all users")), + }, + { + name: "WhiteoutRelativePath", + apply: fstest.Apply( + fstest.CreateDir("/etc/", 0755), + fstest.CreateFile("/etc/passwd", []byte("all users"), 0644), + fstest.Symlink("/dne", "localetc"), + ), + w: TarAll( + tc.File("dne/../.wh.etc", []byte{}, 0644), + ), + // resolution ends up just removing etc + validator: fileNotExists("etc/passwd"), + }, } for _, bo := range breakouts { - t.Run(bo.name, makeWriterToTarTest(bo.w, bo.apply, bo.validator)) + t.Run(bo.name, makeWriterToTarTest(bo.w, bo.apply, bo.validator, bo.err)) } } @@ -501,6 +655,7 @@ func TestApplyTar(t *testing.T) { w WriterToTar apply fstest.Applier validator func(string) error + err error }{ { name: "DirectoryCreation", @@ -525,7 +680,7 @@ func TestApplyTar(t *testing.T) { } for _, at := range tests { - t.Run(at.name, makeWriterToTarTest(at.w, at.apply, at.validator)) + t.Run(at.name, makeWriterToTarTest(at.w, at.apply, at.validator, at.err)) } } @@ -636,7 +791,7 @@ func testDiffApply(appliers ...fstest.Applier) error { return fstest.CheckDirectoryEqual(td, dest) } -func makeWriterToTarTest(wt WriterToTar, a fstest.Applier, validate func(string) error) func(*testing.T) { +func makeWriterToTarTest(wt WriterToTar, a fstest.Applier, validate func(string) error, applyErr error) func(*testing.T) { return func(t *testing.T) { td, err := ioutil.TempDir("", "test-writer-to-tar-") if err != nil { @@ -653,7 +808,14 @@ func makeWriterToTarTest(wt WriterToTar, a fstest.Applier, validate func(string) tr := TarFromWriterTo(wt) if _, err := Apply(context.Background(), td, tr); err != nil { - t.Fatalf("Failed to apply tar: %v", err) + if applyErr == nil { + t.Fatalf("Failed to apply tar: %v", err) + } else if errors.Cause(err) != applyErr { + t.Fatalf("Unexpected apply error: %v, expected %v", err, applyErr) + } + return + } else if applyErr != nil { + t.Fatalf("Expected apply error, got none: %v", applyErr) } if validate != nil {