diff --git a/remotes/docker/fetcher.go b/remotes/docker/fetcher.go index 55c01beaf..cd0168be5 100644 --- a/remotes/docker/fetcher.go +++ b/remotes/docker/fetcher.go @@ -98,6 +98,9 @@ func (r dockerFetcher) Fetch(ctx context.Context, desc ocispec.Descriptor) (io.R var firstErr error for _, host := range r.hosts { req := r.request(host, http.MethodGet, "manifests", desc.Digest.String()) + if err := req.addNamespace(r.refspec.Hostname()); err != nil { + return nil, err + } rc, err := r.open(ctx, req, desc.MediaType, offset) if err != nil { @@ -118,6 +121,9 @@ func (r dockerFetcher) Fetch(ctx context.Context, desc ocispec.Descriptor) (io.R var firstErr error for _, host := range r.hosts { req := r.request(host, http.MethodGet, "blobs", desc.Digest.String()) + if err := req.addNamespace(r.refspec.Hostname()); err != nil { + return nil, err + } rc, err := r.open(ctx, req, desc.MediaType, offset) if err != nil { diff --git a/remotes/docker/fetcher_test.go b/remotes/docker/fetcher_test.go index 5930ec614..a9d0e39f8 100644 --- a/remotes/docker/fetcher_test.go +++ b/remotes/docker/fetcher_test.go @@ -52,7 +52,7 @@ func TestFetcherOpen(t *testing.T) { } f := dockerFetcher{&dockerBase{ - namespace: "nonempty", + repository: "nonempty", }} host := RegistryHost{ @@ -182,7 +182,7 @@ func TestDockerFetcherOpen(t *testing.T) { } f := dockerFetcher{&dockerBase{ - namespace: "ns", + repository: "ns", }} host := RegistryHost{ diff --git a/remotes/docker/registry.go b/remotes/docker/registry.go index ffc939b40..7c231d928 100644 --- a/remotes/docker/registry.go +++ b/remotes/docker/registry.go @@ -73,6 +73,15 @@ type RegistryHost struct { Header http.Header } +func (h RegistryHost) isProxy(refhost string) bool { + if refhost != h.Host { + if refhost != "docker.io" || h.Host != "registry-1.docker.io" { + return true + } + } + return false +} + // RegistryHosts fetches the registry hosts for a given namespace, // provided by the host component of an distribution image reference. type RegistryHosts func(string) ([]RegistryHost, error) diff --git a/remotes/docker/resolver.go b/remotes/docker/resolver.go index 32b6abd90..53e42ecc5 100644 --- a/remotes/docker/resolver.go +++ b/remotes/docker/resolver.go @@ -22,6 +22,7 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "path" "strings" @@ -276,6 +277,10 @@ func (r *dockerResolver) Resolve(ctx context.Context, ref string) (string, ocisp ctx := log.WithLogger(ctx, log.G(ctx).WithField("host", host.Host)) req := base.request(host, http.MethodHead, u...) + if err := req.addNamespace(base.refspec.Hostname()); err != nil { + return "", ocispec.Descriptor{}, err + } + for key, value := range r.resolveHeader { req.header[key] = append(req.header[key], value...) } @@ -323,6 +328,10 @@ func (r *dockerResolver) Resolve(ctx context.Context, ref string) (string, ocisp log.G(ctx).Debug("no Docker-Content-Digest header, fetching manifest instead") req = base.request(host, http.MethodGet, u...) + if err := req.addNamespace(base.refspec.Hostname()); err != nil { + return "", ocispec.Descriptor{}, err + } + for key, value := range r.resolveHeader { req.header[key] = append(req.header[key], value...) } @@ -416,10 +425,10 @@ func (r *dockerResolver) Pusher(ctx context.Context, ref string) (remotes.Pusher } type dockerBase struct { - refspec reference.Spec - namespace string - hosts []RegistryHost - header http.Header + refspec reference.Spec + repository string + hosts []RegistryHost + header http.Header } func (r *dockerResolver) base(refspec reference.Spec) (*dockerBase, error) { @@ -429,10 +438,10 @@ func (r *dockerResolver) base(refspec reference.Spec) (*dockerBase, error) { return nil, err } return &dockerBase{ - refspec: refspec, - namespace: strings.TrimPrefix(refspec.Locator, host+"/"), - hosts: hosts, - header: r.header, + refspec: refspec, + repository: strings.TrimPrefix(refspec.Locator, host+"/"), + hosts: hosts, + header: r.header, }, nil } @@ -453,7 +462,7 @@ func (r *dockerBase) request(host RegistryHost, method string, ps ...string) *re for key, value := range host.Header { header[key] = append(header[key], value...) } - parts := append([]string{"/", host.Path, r.namespace}, ps...) + parts := append([]string{"/", host.Path, r.repository}, ps...) p := path.Join(parts...) // Join strips trailing slash, re-add ending "/" if included if len(parts) > 0 && strings.HasSuffix(parts[len(parts)-1], "/") { @@ -478,6 +487,29 @@ func (r *request) authorize(ctx context.Context, req *http.Request) error { return nil } +func (r *request) addNamespace(ns string) (err error) { + if !r.host.isProxy(ns) { + return nil + } + var q url.Values + // Parse query + if i := strings.IndexByte(r.path, '?'); i > 0 { + r.path = r.path[:i+1] + q, err = url.ParseQuery(r.path[i+1:]) + if err != nil { + return + } + } else { + r.path = r.path + "?" + q = url.Values{} + } + q.Add("ns", ns) + + r.path = r.path + q.Encode() + + return +} + type request struct { method string path string diff --git a/remotes/docker/resolver_test.go b/remotes/docker/resolver_test.go index e34337db6..1f2ff8c01 100644 --- a/remotes/docker/resolver_test.go +++ b/remotes/docker/resolver_test.go @@ -280,6 +280,151 @@ func TestHostTLSFailureFallbackResolver(t *testing.T) { runBasicTest(t, "testname", sf) } +func TestResolveProxy(t *testing.T) { + var ( + ctx = context.Background() + tag = "latest" + r = http.NewServeMux() + name = "testname" + ns = "upstream.example.com" + ) + + m := newManifest( + newContent(ocispec.MediaTypeImageConfig, []byte("1")), + newContent(ocispec.MediaTypeImageLayerGzip, []byte("2")), + ) + mc := newContent(ocispec.MediaTypeImageManifest, m.OCIManifest()) + m.RegisterHandler(r, name) + r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, tag), mc) + r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, mc.Digest()), mc) + + nr := namespaceRouter{ + "upstream.example.com": r, + } + + base, ro, close := tlsServer(logHandler{t, nr}) + defer close() + + ro.Hosts = func(host string) ([]RegistryHost, error) { + return []RegistryHost{{ + Client: ro.Client, + Host: base, + Scheme: "https", + Path: "/v2", + Capabilities: HostCapabilityPull | HostCapabilityResolve, + }}, nil + } + + resolver := NewResolver(ro) + image := fmt.Sprintf("%s/%s:%s", ns, name, tag) + + _, d, err := resolver.Resolve(ctx, image) + if err != nil { + t.Fatal(err) + } + f, err := resolver.Fetcher(ctx, image) + if err != nil { + t.Fatal(err) + } + + refs, err := testocimanifest(ctx, f, d) + if err != nil { + t.Fatal(err) + } + + if len(refs) != 2 { + t.Fatalf("Unexpected number of references: %d, expected 2", len(refs)) + } + + for _, ref := range refs { + if err := testFetch(ctx, f, ref); err != nil { + t.Fatal(err) + } + } +} + +func TestResolveProxyFallback(t *testing.T) { + var ( + ctx = context.Background() + tag = "latest" + r = http.NewServeMux() + name = "testname" + ) + + m := newManifest( + newContent(ocispec.MediaTypeImageConfig, []byte("1")), + newContent(ocispec.MediaTypeImageLayerGzip, []byte("2")), + ) + mc := newContent(ocispec.MediaTypeImageManifest, m.OCIManifest()) + m.RegisterHandler(r, name) + r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, tag), mc) + r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, mc.Digest()), mc) + + nr := namespaceRouter{ + "": r, + } + s := httptest.NewServer(logHandler{t, nr}) + defer s.Close() + + base := s.URL[7:] // strip "http://" + + ro := ResolverOptions{ + Hosts: func(host string) ([]RegistryHost, error) { + return []RegistryHost{ + { + Host: flipLocalhost(host), + Scheme: "http", + Path: "/v2", + Capabilities: HostCapabilityPull | HostCapabilityResolve, + }, + { + Host: host, + Scheme: "http", + Path: "/v2", + Capabilities: HostCapabilityPull | HostCapabilityResolve | HostCapabilityPush, + }, + }, nil + }, + } + + resolver := NewResolver(ro) + image := fmt.Sprintf("%s/%s:%s", base, name, tag) + + _, d, err := resolver.Resolve(ctx, image) + if err != nil { + t.Fatal(err) + } + f, err := resolver.Fetcher(ctx, image) + if err != nil { + t.Fatal(err) + } + + refs, err := testocimanifest(ctx, f, d) + if err != nil { + t.Fatal(err) + } + + if len(refs) != 2 { + t.Fatalf("Unexpected number of references: %d, expected 2", len(refs)) + } + + for _, ref := range refs { + if err := testFetch(ctx, f, ref); err != nil { + t.Fatal(err) + } + } +} + +func flipLocalhost(host string) string { + if strings.HasPrefix(host, "127.0.0.1") { + return "localhost" + host[9:] + + } else if strings.HasPrefix(host, "localhost") { + return "127.0.0.1" + host[9:] + } + return host +} + func withTokenServer(th http.Handler, creds func(string) (string, string, error)) func(h http.Handler) (string, ResolverOptions, func()) { return func(h http.Handler) (string, ResolverOptions, func()) { s := httptest.NewUnstartedServer(th) @@ -352,6 +497,17 @@ func (h logHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { h.handler.ServeHTTP(rw, r) } +type namespaceRouter map[string]http.Handler + +func (nr namespaceRouter) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + h, ok := nr[r.URL.Query().Get("ns")] + if !ok { + rw.WriteHeader(http.StatusNotFound) + return + } + h.ServeHTTP(rw, r) +} + func runBasicTest(t *testing.T, name string, sf func(h http.Handler) (string, ResolverOptions, func())) { var ( ctx = context.Background()