diff --git a/remotes/docker/fetcher.go b/remotes/docker/fetcher.go
index 3efe4124e..ec2f9f850 100644
--- a/remotes/docker/fetcher.go
+++ b/remotes/docker/fetcher.go
@@ -17,6 +17,8 @@
 package docker
 
 import (
+	"compress/flate"
+	"compress/gzip"
 	"context"
 	"encoding/json"
 	"errors"
@@ -30,6 +32,7 @@ import (
 	"github.com/containerd/containerd/v2/images"
 	"github.com/containerd/containerd/v2/remotes"
 	"github.com/containerd/log"
+	"github.com/klauspost/compress/zstd"
 	digest "github.com/opencontainers/go-digest"
 	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
 )
@@ -262,6 +265,7 @@ func (r dockerFetcher) open(ctx context.Context, req *request, mediatype string,
 	} else {
 		req.header.Set("Accept", strings.Join([]string{mediatype, `*/*`}, ", "))
 	}
+	req.header.Set("Accept-Encoding", "zstd;q=1.0, gzip;q=0.8, deflate;q=0.5")
 
 	if offset > 0 {
 		// Note: "Accept-Ranges: bytes" cannot be trusted as some endpoints
@@ -320,5 +324,32 @@ func (r dockerFetcher) open(ctx context.Context, req *request, mediatype string,
 		}
 	}
 
-	return resp.Body, nil
+	body := resp.Body
+	encoding := strings.FieldsFunc(resp.Header.Get("Content-Encoding"), func(r rune) bool {
+		return r == ' ' || r == '\t' || r == ','
+	})
+	for i := len(encoding) - 1; i >= 0; i-- {
+		algorithm := strings.ToLower(encoding[i])
+		switch algorithm {
+		case "zstd":
+			r, err := zstd.NewReader(body)
+			if err != nil {
+				return nil, err
+			}
+			body = r.IOReadCloser()
+		case "gzip":
+			body, err = gzip.NewReader(body)
+			if err != nil {
+				return nil, err
+			}
+		case "deflate":
+			body = flate.NewReader(body)
+		case "identity", "":
+			// no content-encoding applied, use raw body
+		default:
+			return nil, errors.New("unsupported Content-Encoding algorithm: " + algorithm)
+		}
+	}
+
+	return body, nil
 }
diff --git a/remotes/docker/fetcher_test.go b/remotes/docker/fetcher_test.go
index 695d6045f..b9a0f5b87 100644
--- a/remotes/docker/fetcher_test.go
+++ b/remotes/docker/fetcher_test.go
@@ -17,6 +17,9 @@
 package docker
 
 import (
+	"bytes"
+	"compress/flate"
+	"compress/gzip"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -28,6 +31,7 @@ import (
 	"strconv"
 	"testing"
 
+	"github.com/klauspost/compress/zstd"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -114,6 +118,152 @@ func TestFetcherOpen(t *testing.T) {
 	}
 }
 
+func TestContentEncoding(t *testing.T) {
+	t.Parallel()
+
+	zstdEncode := func(in []byte) []byte {
+		var b bytes.Buffer
+		zw, err := zstd.NewWriter(&b)
+		if err != nil {
+			t.Fatal(err)
+		}
+		_, err = zw.Write(in)
+		if err != nil {
+			t.Fatal(err)
+		}
+		err = zw.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+		return b.Bytes()
+	}
+	gzipEncode := func(in []byte) []byte {
+		var b bytes.Buffer
+		gw := gzip.NewWriter(&b)
+		_, err := gw.Write(in)
+		if err != nil {
+			t.Fatal(err)
+		}
+		err = gw.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+		return b.Bytes()
+	}
+	flateEncode := func(in []byte) []byte {
+		var b bytes.Buffer
+		dw, err := flate.NewWriter(&b, -1)
+		if err != nil {
+			t.Fatal(err)
+		}
+		_, err = dw.Write(in)
+		if err != nil {
+			t.Fatal(err)
+		}
+		err = dw.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+		return b.Bytes()
+	}
+
+	tests := []struct {
+		encodingFuncs  []func([]byte) []byte
+		encodingHeader string
+	}{
+		{
+			encodingFuncs:  []func([]byte) []byte{},
+			encodingHeader: "",
+		},
+		{
+			encodingFuncs:  []func([]byte) []byte{zstdEncode},
+			encodingHeader: "zstd",
+		},
+		{
+			encodingFuncs:  []func([]byte) []byte{gzipEncode},
+			encodingHeader: "gzip",
+		},
+		{
+			encodingFuncs:  []func([]byte) []byte{flateEncode},
+			encodingHeader: "deflate",
+		},
+		{
+			encodingFuncs:  []func([]byte) []byte{zstdEncode, gzipEncode},
+			encodingHeader: "zstd,gzip",
+		},
+		{
+			encodingFuncs:  []func([]byte) []byte{gzipEncode, flateEncode},
+			encodingHeader: "gzip,deflate",
+		},
+		{
+			encodingFuncs:  []func([]byte) []byte{gzipEncode, zstdEncode},
+			encodingHeader: "gzip,zstd",
+		},
+		{
+			encodingFuncs:  []func([]byte) []byte{gzipEncode, zstdEncode, flateEncode},
+			encodingHeader: "gzip,zstd,deflate",
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.encodingHeader, func(t *testing.T) {
+			t.Parallel()
+			content := make([]byte, 128)
+			rand.New(rand.NewSource(1)).Read(content)
+
+			s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+				compressedContent := content
+				for _, enc := range tc.encodingFuncs {
+					compressedContent = enc(compressedContent)
+				}
+				rw.Header().Set("content-length", fmt.Sprintf("%d", len(compressedContent)))
+				rw.Header().Set("Content-Encoding", tc.encodingHeader)
+				rw.Write(compressedContent)
+			}))
+			defer s.Close()
+
+			u, err := url.Parse(s.URL)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			f := dockerFetcher{&dockerBase{
+				repository: "nonempty",
+			}}
+
+			host := RegistryHost{
+				Client: s.Client(),
+				Host:   u.Host,
+				Scheme: u.Scheme,
+				Path:   u.Path,
+			}
+
+			req := f.request(host, http.MethodGet)
+
+			rc, err := f.open(context.Background(), req, "", 0)
+			if err != nil {
+				t.Fatalf("failed to open for encoding %s: %+v", tc.encodingHeader, err)
+			}
+			b, err := io.ReadAll(rc)
+			if err != nil {
+				t.Fatal(err)
+			}
+			expected := content
+			if len(b) != len(expected) {
+				t.Errorf("unexpected length %d, expected %d", len(b), len(expected))
+				return
+			}
+			for i, c := range expected {
+				if b[i] != c {
+					t.Errorf("unexpected byte %x at %d, expected %x", b[i], i, c)
+					return
+				}
+			}
+		})
+	}
+}
+
 // New set of tests to test new error cases
 func TestDockerFetcherOpen(t *testing.T) {
 	tests := []struct {