From 57903eacc803ded81f02c9a7469a36324c8e985e Mon Sep 17 00:00:00 2001 From: Derek McGowan Date: Tue, 2 May 2017 21:57:33 -0700 Subject: [PATCH] Add Docker resolver tests Tests resolving, fetching, and using the various authentication methods supported by the Docker registry protocol. Signed-off-by: Derek McGowan --- remotes/docker/resolver_test.go | 388 ++++++++++++++++++++++++++++++++ 1 file changed, 388 insertions(+) create mode 100644 remotes/docker/resolver_test.go diff --git a/remotes/docker/resolver_test.go b/remotes/docker/resolver_test.go new file mode 100644 index 000000000..ee9f15080 --- /dev/null +++ b/remotes/docker/resolver_test.go @@ -0,0 +1,388 @@ +package docker + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/containerd/containerd/remotes" + digest "github.com/opencontainers/go-digest" + specs "github.com/opencontainers/image-spec/specs-go" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/pkg/errors" +) + +func TestHTTPResolver(t *testing.T) { + s := func(h http.Handler) (string, ResolverOptions, func()) { + s := httptest.NewServer(h) + + options := ResolverOptions{ + PlainHTTP: true, + } + base := s.URL[7:] // strip "http://" + return base, options, s.Close + } + + runBasicTest(t, "testname", s) +} + +func TestHTTPSResolver(t *testing.T) { + runBasicTest(t, "testname", tlsServer) +} + +func TestBasicResolver(t *testing.T) { + basicAuth := func(h http.Handler) (string, ResolverOptions, func()) { + // Wrap with basic auth + wrapped := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok || username != "user1" || password != "password1" { + rw.Header().Set("WWW-Authenticate", "Basic realm=localhost") + rw.WriteHeader(http.StatusUnauthorized) + return + } + h.ServeHTTP(rw, r) + }) + + base, options, close := tlsServer(wrapped) + options.Credentials = func(string) (string, string, error) { + return "user1", "password1", nil + } + return base, options, close + } + runBasicTest(t, "testname", basicAuth) +} + +func TestAnonymousTokenResolver(t *testing.T) { + th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + rw.WriteHeader(http.StatusMethodNotAllowed) + return + } + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`)) + }) + + runBasicTest(t, "testname", withTokenServer(th, nil)) +} + +func TestBasicAuthTokenResolver(t *testing.T) { + th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + rw.WriteHeader(http.StatusMethodNotAllowed) + return + } + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + username, password, ok := r.BasicAuth() + if !ok || username != "user1" || password != "password1" { + rw.Write([]byte(`{"access_token":"insufficientscope"}`)) + } else { + rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`)) + } + }) + creds := func(string) (string, string, error) { + return "user1", "password1", nil + } + + runBasicTest(t, "testname", withTokenServer(th, creds)) +} + +func TestRefreshTokenResolver(t *testing.T) { + th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + rw.WriteHeader(http.StatusMethodNotAllowed) + return + } + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + + r.ParseForm() + if r.PostForm.Get("grant_type") != "refresh_token" || r.PostForm.Get("refresh_token") != "somerefreshtoken" { + rw.Write([]byte(`{"access_token":"insufficientscope"}`)) + } else { + rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`)) + } + }) + creds := func(string) (string, string, error) { + return "", "somerefreshtoken", nil + } + + runBasicTest(t, "testname", withTokenServer(th, creds)) +} + +func TestPostBasicAuthTokenResolver(t *testing.T) { + th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + rw.WriteHeader(http.StatusMethodNotAllowed) + return + } + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + + r.ParseForm() + if r.PostForm.Get("grant_type") != "password" || r.PostForm.Get("username") != "user1" || r.PostForm.Get("password") != "password1" { + rw.Write([]byte(`{"access_token":"insufficientscope"}`)) + } else { + rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`)) + } + }) + creds := func(string) (string, string, error) { + return "user1", "password1", nil + } + + runBasicTest(t, "testname", withTokenServer(th, creds)) +} + +func TestBadTokenResolver(t *testing.T) { + th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + rw.WriteHeader(http.StatusMethodNotAllowed) + return + } + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + rw.Write([]byte(`{"access_token":"insufficientscope"}`)) + }) + creds := func(string) (string, string, error) { + return "", "somerefreshtoken", nil + } + + ctx := context.Background() + h := content(ocispec.MediaTypeImageManifest, []byte("not anything parse-able")) + + base, ro, close := withTokenServer(th, creds)(logHandler{t, h}) + defer close() + + resolver := NewResolver(ro) + image := fmt.Sprintf("%s/doesntmatter:sometatg", base) + + _, _, _, err := resolver.Resolve(ctx, image) + if err == nil { + t.Fatal("Expected error getting token with inssufficient scope") + } + if errors.Cause(err) != ErrInvalidAuthorization { + t.Fatal(err) + } +} + +func withTokenServer(th http.Handler, creds func(string) (string, string, error)) func(h http.Handler) (string, ResolverOptions, func()) { + return func(h http.Handler) (string, ResolverOptions, func()) { + s := httptest.NewUnstartedServer(th) + s.StartTLS() + + cert, _ := x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) + tokenBase := s.URL + "/token" + + // Wrap with token auth + wrapped := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + auth := strings.ToLower(r.Header.Get("Authorization")) + if auth != "bearer perfectlyvalidopaquetoken" { + authHeader := fmt.Sprintf("Bearer realm=%q,service=registry,scope=\"repository:testname:pull,pull\"", tokenBase) + if strings.HasPrefix(auth, "bearer ") { + authHeader = authHeader + ",error=" + auth[7:] + } + rw.Header().Set("WWW-Authenticate", authHeader) + rw.WriteHeader(http.StatusUnauthorized) + return + } + h.ServeHTTP(rw, r) + }) + + base, options, close := tlsServer(wrapped) + options.Credentials = creds + options.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs.AddCert(cert) + return base, options, func() { + s.Close() + close() + } + } +} + +func tlsServer(h http.Handler) (string, ResolverOptions, func()) { + s := httptest.NewUnstartedServer(h) + s.StartTLS() + + capool := x509.NewCertPool() + cert, _ := x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) + capool.AddCert(cert) + + options := ResolverOptions{ + Client: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: capool, + }, + }, + }, + } + base := s.URL[8:] // strip "https://" + return base, options, s.Close +} + +type logHandler struct { + t *testing.T + handler http.Handler +} + +func (h logHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + h.t.Logf("%s %s", r.Method, r.URL.String()) + h.handler.ServeHTTP(rw, r) +} + +func runBasicTest(t *testing.T, name string, sf func(h http.Handler) (string, ResolverOptions, func())) { + var ( + ctx = context.Background() + tag = "latest" + r = http.NewServeMux() + ) + + m := newManifest( + content(ocispec.MediaTypeImageConfig, []byte("1")), + content(ocispec.MediaTypeImageLayerGzip, []byte("2")), + ) + mc := content(ocispec.MediaTypeImageManifest, m.OCIManifest()) + m.RegisterHandler(r, name) + r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, tag), mc) + r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, mc.Digest()), mc) + + base, ro, close := sf(logHandler{t, r}) + defer close() + + resolver := NewResolver(ro) + image := fmt.Sprintf("%s/%s:%s", base, name, tag) + + _, d, f, err := resolver.Resolve(ctx, image) + if err != nil { + t.Fatal(err) + } + + refs, err := testocimanifest(ctx, f, d) + if err != nil { + t.Fatal(err) + } + + if len(refs) != 2 { + t.Fatalf("Unexpected number of references: %d, expected 2", len(refs)) + } + + for _, ref := range refs { + if err := testFetch(ctx, f, ref); err != nil { + t.Fatal(err) + } + } +} + +func testFetch(ctx context.Context, f remotes.Fetcher, desc ocispec.Descriptor) error { + r, err := f.Fetch(ctx, desc) + if err != nil { + return err + } + dgstr := desc.Digest.Algorithm().Digester() + io.Copy(dgstr.Hash(), r) + if dgstr.Digest() != desc.Digest { + return errors.Errorf("content mismatch: %s != %s", dgstr.Digest(), desc.Digest) + } + + return nil +} + +func testocimanifest(ctx context.Context, f remotes.Fetcher, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) { + r, err := f.Fetch(ctx, desc) + if err != nil { + return nil, errors.Wrapf(err, "failed to fetch %s", desc.Digest) + } + p, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + if dgst := desc.Digest.Algorithm().FromBytes(p); dgst != desc.Digest { + return nil, errors.Errorf("digest mismatch: %s != %s", dgst, desc.Digest) + } + + var manifest ocispec.Manifest + if err := json.Unmarshal(p, &manifest); err != nil { + return nil, err + } + + var descs []ocispec.Descriptor + + descs = append(descs, manifest.Config) + descs = append(descs, manifest.Layers...) + + return descs, nil +} + +type testContent struct { + mediaType string + content []byte +} + +func content(mediaType string, b []byte) testContent { + return testContent{ + mediaType: mediaType, + content: b, + } +} + +func (tc testContent) Descriptor() ocispec.Descriptor { + return ocispec.Descriptor{ + MediaType: tc.mediaType, + Digest: digest.FromBytes(tc.content), + Size: int64(len(tc.content)), + } +} + +func (tc testContent) Digest() digest.Digest { + return digest.FromBytes(tc.content) +} + +func (tc testContent) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", tc.mediaType) + w.Header().Add("Content-Length", strconv.Itoa(len(tc.content))) + w.Header().Add("Docker-Content-Digest", tc.Digest().String()) + w.WriteHeader(http.StatusOK) + w.Write(tc.content) +} + +type testManifest struct { + config testContent + references []testContent +} + +func newManifest(config testContent, refs ...testContent) testManifest { + return testManifest{ + config: config, + references: refs, + } +} + +func (m testManifest) OCIManifest() []byte { + manifest := ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 1, + }, + Config: m.config.Descriptor(), + Layers: make([]ocispec.Descriptor, len(m.references)), + } + for i, c := range append(m.references) { + manifest.Layers[i] = c.Descriptor() + } + b, _ := json.Marshal(manifest) + return b +} + +func (m testManifest) RegisterHandler(r *http.ServeMux, name string) { + for _, c := range append(m.references, m.config) { + r.Handle(fmt.Sprintf("/v2/%s/blobs/%s", name, c.Digest()), c) + } +}