diff --git a/archive/compression/compression.go b/archive/compression/compression.go index bd64e0353..60c80e98a 100644 --- a/archive/compression/compression.go +++ b/archive/compression/compression.go @@ -92,6 +92,36 @@ func (w *writeCloserWrapper) Close() error { return nil } +type bufferedReader struct { + buf *bufio.Reader +} + +func newBufferedReader(r io.Reader) *bufferedReader { + buf := bufioReader32KPool.Get().(*bufio.Reader) + buf.Reset(r) + return &bufferedReader{buf} +} + +func (r *bufferedReader) Read(p []byte) (n int, err error) { + if r.buf == nil { + return 0, io.EOF + } + n, err = r.buf.Read(p) + if err == io.EOF { + r.buf.Reset(nil) + bufioReader32KPool.Put(r.buf) + r.buf = nil + } + return +} + +func (r *bufferedReader) Peek(n int) ([]byte, error) { + if r.buf == nil { + return nil, io.EOF + } + return r.buf.Peek(n) +} + // DetectCompression detects the compression algorithm of the source. func DetectCompression(source []byte) Compression { for compression, m := range map[Compression][]byte{ @@ -110,8 +140,7 @@ func DetectCompression(source []byte) Compression { // DecompressStream decompresses the archive and returns a ReaderCloser with the decompressed archive. func DecompressStream(archive io.Reader) (DecompressReadCloser, error) { - buf := bufioReader32KPool.Get().(*bufio.Reader) - buf.Reset(archive) + buf := newBufferedReader(archive) bs, err := buf.Peek(10) if err != nil && err != io.EOF { // Note: we'll ignore any io.EOF error because there are some odd @@ -123,15 +152,12 @@ func DecompressStream(archive io.Reader) (DecompressReadCloser, error) { return nil, err } - closer := func() error { - buf.Reset(nil) - bufioReader32KPool.Put(buf) - return nil - } switch compression := DetectCompression(bs); compression { case Uncompressed: - readBufWrapper := &readCloserWrapper{buf, compression, closer} - return readBufWrapper, nil + return &readCloserWrapper{ + Reader: buf, + compression: compression, + }, nil case Gzip: ctx, cancel := context.WithCancel(context.Background()) gzReader, err := gzipDecompress(ctx, buf) @@ -140,12 +166,15 @@ func DecompressStream(archive io.Reader) (DecompressReadCloser, error) { return nil, err } - readBufWrapper := &readCloserWrapper{gzReader, compression, func() error { - cancel() - return closer() - }} + return &readCloserWrapper{ + Reader: gzReader, + compression: compression, + closer: func() error { + cancel() + return gzReader.Close() + }, + }, nil - return readBufWrapper, nil default: return nil, fmt.Errorf("unsupported compression format %s", (&compression).Extension()) }