const (
	// zstdMagicSkippableStart is the lowest of the sixteen magic numbers
	// (0x184D2A50 through 0x184D2A5F) that introduce a skippable frame.
	zstdMagicSkippableStart = 0x184D2A50
	// zstdMagicSkippableMask clears the low nibble of a magic number so that
	// every value in the skippable range compares equal to
	// zstdMagicSkippableStart.
	zstdMagicSkippableMask = 0xFFFFFFF0
)

var (
	// gzipMagic is the byte prefix used to recognize gzip data.
	gzipMagic = []byte{0x1F, 0x8B, 0x08}
	// zstdMagic is the magic number of a Zstandard frame (0xFD2FB528) in the
	// little-endian byte order in which it appears at the start of the data.
	zstdMagic = []byte{0x28, 0xb5, 0x2f, 0xfd}
)

// matcher reports whether a buffer starts with the signature of a particular
// compression format.
type matcher = func([]byte) bool

// magicNumberMatcher returns a matcher that accepts any source beginning with
// the bytes in m.
func magicNumberMatcher(m []byte) matcher {
	return func(source []byte) bool {
		return bytes.HasPrefix(source, m)
	}
}

// zstdMatcher returns a matcher that detects zstd-compressed data.
//
// Zstandard defines two frame formats, Zstandard frames and skippable frames,
// and the data may begin with either one. See
// https://tools.ietf.org/id/draft-kucherawy-dispatch-zstd-00.html#rfc.section.2
// (the current published revision is RFC 8878, section 3.1) for more details.
func zstdMatcher() matcher {
	return func(source []byte) bool {
		// Zstandard frame.
		if bytes.HasPrefix(source, zstdMagic) {
			return true
		}
		// Skippable frame: the header is a 4-byte little-endian magic number
		// followed by a 4-byte frame size, so anything shorter than 8 bytes
		// is rejected.
		if len(source) < 8 {
			return false
		}
		// Masking off the low nibble matches every magic number from
		// 0x184D2A50 to 0x184D2A5F.
		magic := binary.LittleEndian.Uint32(source[:4])
		return magic&zstdMagicSkippableMask == zstdMagicSkippableStart
	}
}

// DetectCompression detects the compression algorithm of the source.
func DetectCompression(source []byte) Compression { - for compression, m := range map[Compression][]byte{ - Gzip: {0x1F, 0x8B, 0x08}, - Zstd: {0x28, 0xb5, 0x2f, 0xfd}, + for compression, fn := range map[Compression]matcher{ + Gzip: magicNumberMatcher(gzipMagic), + Zstd: zstdMatcher(), } { - if len(source) < len(m) { - // Len too short - continue - } - if bytes.Equal(m, source[:len(m)]) { + if fn(source) { return compression } } diff --git a/archive/compression/compression_test.go b/archive/compression/compression_test.go index 5b16fcf2d..59fc0898d 100644 --- a/archive/compression/compression_test.go +++ b/archive/compression/compression_test.go @@ -189,3 +189,39 @@ func TestCmdStreamBad(t *testing.T) { t.Fatalf("wrong output: %s", string(buf)) } } + +func TestDetectCompressionZstd(t *testing.T) { + for _, tc := range []struct { + source []byte + expected Compression + }{ + { + // test zstd compression without skippable frames. + source: []byte{ + 0x28, 0xb5, 0x2f, 0xfd, // magic number of Zstandard frame: 0xFD2FB528 + 0x04, 0x00, 0x31, 0x00, 0x00, // frame header + 0x64, 0x6f, 0x63, 0x6b, 0x65, 0x72, // data block "docker" + 0x16, 0x0e, 0x21, 0xc3, // content checksum + }, + expected: Zstd, + }, + { + // test zstd compression with skippable frames. + source: []byte{ + 0x50, 0x2a, 0x4d, 0x18, // magic number of skippable frame: 0x184D2A50 to 0x184D2A5F + 0x04, 0x00, 0x00, 0x00, // frame size + 0x5d, 0x00, 0x00, 0x00, // user data + 0x28, 0xb5, 0x2f, 0xfd, // magic number of Zstandard frame: 0xFD2FB528 + 0x04, 0x00, 0x31, 0x00, 0x00, // frame header + 0x64, 0x6f, 0x63, 0x6b, 0x65, 0x72, // data block "docker" + 0x16, 0x0e, 0x21, 0xc3, // content checksum + }, + expected: Zstd, + }, + } { + compression := DetectCompression(tc.source) + if compression != tc.expected { + t.Fatalf("Unexpected compression %v, expected %v", compression, tc.expected) + } + } +}