diff --git a/archive/compression/compression.go b/archive/compression/compression.go index bd50f083b..bd64e0353 100644 --- a/archive/compression/compression.go +++ b/archive/compression/compression.go @@ -20,9 +20,15 @@ import ( "bufio" "bytes" "compress/gzip" + "context" "fmt" "io" + "os" + "os/exec" + "strconv" "sync" + + "github.com/containerd/containerd/log" ) type ( @@ -37,6 +43,13 @@ const ( Gzip ) +const disablePigzEnv = "CONTAINERD_DISABLE_PIGZ" + +var ( + initPigz sync.Once + unpigzPath string +) + var ( bufioReader32KPool = &sync.Pool{ New: func() interface{} { return bufio.NewReaderSize(nil, 32*1024) }, @@ -120,11 +133,18 @@ func DecompressStream(archive io.Reader) (DecompressReadCloser, error) { readBufWrapper := &readCloserWrapper{buf, compression, closer} return readBufWrapper, nil case Gzip: - gzReader, err := gzip.NewReader(buf) + ctx, cancel := context.WithCancel(context.Background()) + gzReader, err := gzipDecompress(ctx, buf) if err != nil { + cancel() return nil, err } - readBufWrapper := &readCloserWrapper{gzReader, compression, closer} + + readBufWrapper := &readCloserWrapper{gzReader, compression, func() error { + cancel() + return closer() + }} + return readBufWrapper, nil default: return nil, fmt.Errorf("unsupported compression format %s", (&compression).Extension()) @@ -151,3 +171,67 @@ func (compression *Compression) Extension() string { } return "" } + +func gzipDecompress(ctx context.Context, buf io.Reader) (io.ReadCloser, error) { + initPigz.Do(func() { + if unpigzPath = detectPigz(); unpigzPath != "" { + log.L.Debug("using pigz for decompression") + } + }) + + if unpigzPath == "" { + return gzip.NewReader(buf) + } + + return cmdStream(exec.CommandContext(ctx, unpigzPath, "-d", "-c"), buf) +} + +func cmdStream(cmd *exec.Cmd, in io.Reader) (io.ReadCloser, error) { + reader, writer := io.Pipe() + + cmd.Stdin = in + cmd.Stdout = writer + + var errBuf bytes.Buffer + cmd.Stderr = &errBuf + + if err := cmd.Start(); err != nil { + return nil, err + } + + go func() { + if err := cmd.Wait(); err != nil { + writer.CloseWithError(fmt.Errorf("%s: %s", err, errBuf.String())) + } else { + writer.Close() + } + }() + + return reader, nil +} + +func detectPigz() string { + path, err := exec.LookPath("unpigz") + if err != nil { + log.L.WithError(err).Debug("unpigz not found, falling back to go gzip") + return "" + } + + // Check if pigz disabled via CONTAINERD_DISABLE_PIGZ env variable + value := os.Getenv(disablePigzEnv) + if value == "" { + return path + } + + disable, err := strconv.ParseBool(value) + if err != nil { + log.L.WithError(err).Warnf("could not parse %s: %s", disablePigzEnv, value) + return path + } + + if disable { + return "" + } + + return path +} diff --git a/archive/compression/compression_test.go b/archive/compression/compression_test.go index 9d9092701..7f00e9fcd 100644 --- a/archive/compression/compression_test.go +++ b/archive/compression/compression_test.go @@ -18,11 +18,25 @@ package compression import ( "bytes" + "compress/gzip" + "context" + "io" "io/ioutil" "math/rand" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" "testing" ) +func TestMain(m *testing.M) { + // Force initPigz to be called, so tests start with the same initial state + gzipDecompress(context.Background(), strings.NewReader("")) + os.Exit(m.Run()) +} + // generateData generates data that composed of 2 random parts // and single zero-filled part within them. // Typically, the compression ratio would be about 67%. @@ -42,7 +56,7 @@ func generateData(t *testing.T, size int) []byte { return append(part0Data, append(part1Data, part2Data...)...) } -func testCompressDecompress(t *testing.T, size int, compression Compression) { +func testCompressDecompress(t *testing.T, size int, compression Compression) DecompressReadCloser { orig := generateData(t, size) var b bytes.Buffer compressor, err := CompressStream(&b, compression) @@ -72,12 +86,105 @@ func testCompressDecompress(t *testing.T, size int, compression Compression) { if !bytes.Equal(orig, decompressed) { t.Fatal("strange decompressed data") } + + return decompressor } func TestCompressDecompressGzip(t *testing.T) { - testCompressDecompress(t, 1024*1024, Gzip) + oldUnpigzPath := unpigzPath + unpigzPath = "" + defer func() { unpigzPath = oldUnpigzPath }() + + decompressor := testCompressDecompress(t, 1024*1024, Gzip) + wrapper := decompressor.(*readCloserWrapper) + _, ok := wrapper.Reader.(*gzip.Reader) + if !ok { + t.Fatalf("unexpected compressor type: %T", wrapper.Reader) + } +} + +func TestCompressDecompressPigz(t *testing.T) { + if _, err := exec.LookPath("unpigz"); err != nil { + t.Skip("pigz not installed") + } + + decompressor := testCompressDecompress(t, 1024*1024, Gzip) + wrapper := decompressor.(*readCloserWrapper) + _, ok := wrapper.Reader.(*io.PipeReader) + if !ok { + t.Fatalf("unexpected compressor type: %T", wrapper.Reader) + } } func TestCompressDecompressUncompressed(t *testing.T) { testCompressDecompress(t, 1024*1024, Uncompressed) } + +func TestDetectPigz(t *testing.T) { + // Create fake PATH with unpigz executable, make sure detectPigz can find it + tempPath, err := ioutil.TempDir("", "containerd_temp_") + if err != nil { + t.Fatal(err) + } + + filename := "unpigz" + if runtime.GOOS == "windows" { + filename = "unpigz.exe" + } + + fullPath := filepath.Join(tempPath, filename) + + if err := ioutil.WriteFile(fullPath, []byte(""), 0111); err != nil { + t.Fatal(err) + } + + defer os.RemoveAll(tempPath) + + oldPath := os.Getenv("PATH") + os.Setenv("PATH", tempPath) + defer os.Setenv("PATH", oldPath) + + if pigzPath := detectPigz(); pigzPath == "" { + t.Fatal("failed to detect pigz path") + } else if pigzPath != fullPath { + t.Fatalf("wrong pigz found: %s != %s", pigzPath, fullPath) + } + + os.Setenv(disablePigzEnv, "1") + defer os.Unsetenv(disablePigzEnv) + + if pigzPath := detectPigz(); pigzPath != "" { + t.Fatalf("disable via %s doesn't work", disablePigzEnv) + } +} + +func TestCmdStream(t *testing.T) { + out, err := cmdStream(exec.Command("sh", "-c", "echo hello; exit 0"), nil) + if err != nil { + t.Fatal(err) + } + + buf, err := ioutil.ReadAll(out) + if err != nil { + t.Fatalf("failed to read from stdout: %s", err) + } + + if string(buf) != "hello\n" { + t.Fatalf("unexpected command output ('%s' != '%s')", string(buf), "hello\n") + } +} + +func TestCmdStreamBad(t *testing.T) { + out, err := cmdStream(exec.Command("sh", "-c", "echo hello; echo >&2 bad result; exit 1"), nil) + if err != nil { + t.Fatalf("failed to start command: %v", err) + } + + if buf, err := ioutil.ReadAll(out); err == nil { + t.Fatal("command should have failed") + } else if err.Error() != "exit status 1: bad result\n" { + t.Fatalf("wrong error: %s", err.Error()) + } else if string(buf) != "hello\n" { + t.Fatalf("wrong output: %s", string(buf)) + } +}