diff --git a/remotes/docker/resolver.go b/remotes/docker/resolver.go index a04ef4d90..96110a188 100644 --- a/remotes/docker/resolver.go +++ b/remotes/docker/resolver.go @@ -148,6 +148,9 @@ func NewResolver(options ResolverOptions) remotes.Resolver { if options.Headers == nil { options.Headers = make(http.Header) + } else { + // make a copy of the headers to avoid race due to concurrent map write + options.Headers = options.Headers.Clone() } if _, ok := options.Headers["User-Agent"]; !ok { options.Headers.Set("User-Agent", "containerd/"+version.Version) @@ -543,9 +546,10 @@ func (r *request) do(ctx context.Context) (*http.Response, error) { if err != nil { return nil, err } - req.Header = http.Header{} // headers need to be copied to avoid concurrent map access - for k, v := range r.header { - req.Header[k] = v + if r.header == nil { + req.Header = http.Header{} + } else { + req.Header = r.header.Clone() // headers need to be copied to avoid concurrent map access } if r.body != nil { body, err := r.body() diff --git a/remotes/docker/resolver_test.go b/remotes/docker/resolver_test.go index 4719e75fb..c1522d9b0 100644 --- a/remotes/docker/resolver_test.go +++ b/remotes/docker/resolver_test.go @@ -46,7 +46,6 @@ func TestHTTPResolver(t *testing.T) { base := s.URL[7:] // strip "http://" return base, options, s.Close } - runBasicTest(t, "testname", s) } @@ -54,6 +53,30 @@ func TestHTTPSResolver(t *testing.T) { runBasicTest(t, "testname", tlsServer) } +func TestResolverOptionsRace(t *testing.T) { + header := http.Header{} + header.Set("X-Test", "test") + + s := func(h http.Handler) (string, ResolverOptions, func()) { + s := httptest.NewServer(h) + + options := ResolverOptions{ + Headers: header, + } + base := s.URL[7:] // strip "http://" + return base, options, s.Close + } + + for i := 0; i < 5; i++ { + t.Run(fmt.Sprintf("test ResolverOptions race %d", i), func(t *testing.T) { + // parallel sub tests so the race condition (if not handled) can be caught + // by race detector + t.Parallel() + runBasicTest(t, "testname", s) + }) + } +} + func TestBasicResolver(t *testing.T) { basicAuth := func(h http.Handler) (string, ResolverOptions, func()) { // Wrap with basic auth