diff --git a/pkg/server/image_pull.go b/pkg/server/image_pull.go index db6ec46ec..b545405f4 100644 --- a/pkg/server/image_pull.go +++ b/pkg/server/image_pull.go @@ -21,9 +21,11 @@ import ( "crypto/x509" "encoding/base64" "io/ioutil" + "net" "net/http" "net/url" "strings" + "time" "github.com/containerd/containerd" "github.com/containerd/containerd/errdefs" @@ -251,8 +253,8 @@ func (c *criService) credentials(auth *runtime.AuthConfig) func(string) (string, } } -// getRegistryTLSTransport returns a http.Transport configured with a CA/Cert/Key specified by registryTLSConfig -func (c *criService) getRegistryTLSTransport(ctx context.Context, registryTLSConfig config.TLSConfig) (*http.Transport, error) { +// getTLSConfig returns a TLSConfig configured with a CA/Cert/Key specified by registryTLSConfig +func (c *criService) getTLSConfig(registryTLSConfig config.TLSConfig) (*tls.Config, error) { cert, err := tls.LoadX509KeyPair(registryTLSConfig.CertFile, registryTLSConfig.KeyFile) if err != nil { return nil, errors.Wrap(err, "failed to load cert file") @@ -273,9 +275,7 @@ func (c *criService) getRegistryTLSTransport(ctx context.Context, registryTLSCon RootCAs: caCertPool, } tlsConfig.BuildNameToCertificate() - transport := &http.Transport{TLSClientConfig: tlsConfig} - - return transport, nil + return tlsConfig, nil } // getResolver tries registry mirrors and the default registry, and returns the resolver and descriptor @@ -286,7 +286,10 @@ func (c *criService) getResolver(ctx context.Context, ref string, cred func(stri return nil, imagespec.Descriptor{}, errors.Wrap(err, "parse image reference") } - httpClient := http.DefaultClient + var ( + transport = newTransport() + httpClient = &http.Client{Transport: transport} + ) // Try mirrors in order first, and then try default host name. for _, e := range c.config.Registry.Mirrors[refspec.Hostname()].Endpoints { @@ -296,15 +299,14 @@ func (c *criService) getResolver(ctx context.Context, ref string, cred func(stri } if registryTLSConfig, ok := c.config.Registry.TLSConfigs[u.Host]; ok { - tlsTransport, err := c.getRegistryTLSTransport(ctx, registryTLSConfig) + transport.TLSClientConfig, err = c.getTLSConfig(registryTLSConfig) if err != nil { - return nil, imagespec.Descriptor{}, errors.Wrapf(err, "get tlsTransport for registry %q", refspec.Hostname()) + return nil, imagespec.Descriptor{}, errors.Wrapf(err, "get TLSConfig for registry %q", refspec.Hostname()) } - httpClient = &http.Client{Transport: tlsTransport} } resolver := docker.NewResolver(docker.ResolverOptions{ - Authorizer: docker.NewAuthorizer(http.DefaultClient, cred), + Authorizer: docker.NewAuthorizer(httpClient, cred), Client: httpClient, Host: func(string) (string, error) { return u.Host, nil }, // By default use "https". @@ -323,11 +325,10 @@ func (c *criService) getResolver(ctx context.Context, ref string, cred func(stri return nil, imagespec.Descriptor{}, errors.Wrapf(err, "get host for refspec %q", refspec.Hostname()) } if registryTLSConfig, ok := c.config.Registry.TLSConfigs[hostname]; ok { - tlsTransport, err := c.getRegistryTLSTransport(ctx, registryTLSConfig) + transport.TLSClientConfig, err = c.getTLSConfig(registryTLSConfig) if err != nil { - return nil, imagespec.Descriptor{}, errors.Wrapf(err, "get tlsTransport for registry %q", refspec.Hostname()) + return nil, imagespec.Descriptor{}, errors.Wrapf(err, "get TLSConfig for registry %q", refspec.Hostname()) } - httpClient = &http.Client{Transport: tlsTransport} } resolver := docker.NewResolver(docker.ResolverOptions{ @@ -340,3 +341,20 @@ func (c *criService) getResolver(ctx context.Context, ref string, cred func(stri } return resolver, desc, nil } + +// newTransport returns a new HTTP transport used to pull image. +// TODO(random-liu): Create a library and share this code with `ctr`. +func newTransport() *http.Transport { + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 5 * time.Second, + } +}