diff --git a/image.go b/image.go index 216b3adb1..784df5dd9 100644 --- a/image.go +++ b/image.go @@ -28,6 +28,7 @@ import ( "github.com/containerd/containerd/diff" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/images" + "github.com/containerd/containerd/pkg/kmutex" "github.com/containerd/containerd/platforms" "github.com/containerd/containerd/rootfs" "github.com/containerd/containerd/snapshots" @@ -287,6 +288,10 @@ type UnpackConfig struct { // CheckPlatformSupported is whether to validate that a snapshotter // supports an image's platform before unpacking CheckPlatformSupported bool + // DuplicationSuppressor is used to make sure that there is only one + // in-flight fetch request or unpack handler for a given descriptor's + // digest or chain ID. + DuplicationSuppressor kmutex.KeyedLocker } // UnpackOpt provides configuration for unpack @@ -300,6 +305,14 @@ func WithSnapshotterPlatformCheck() UnpackOpt { } } +// WithUnpackDuplicationSuppressor sets `DuplicationSuppressor` on the UnpackConfig. +func WithUnpackDuplicationSuppressor(suppressor kmutex.KeyedLocker) UnpackOpt { + return func(ctx context.Context, uc *UnpackConfig) error { + uc.DuplicationSuppressor = suppressor + return nil + } +} + func (i *image) Unpack(ctx context.Context, snapshotterName string, opts ...UnpackOpt) error { ctx, done, err := i.client.WithLease(ctx) if err != nil { diff --git a/pkg/cri/server/image_pull.go b/pkg/cri/server/image_pull.go index dd3eecab8..46c3ffb9e 100644 --- a/pkg/cri/server/image_pull.go +++ b/pkg/cri/server/image_pull.go @@ -121,6 +121,9 @@ func (c *criService) PullImage(ctx context.Context, r *runtime.PullImageRequest) containerd.WithPullLabel(imageLabelKey, imageLabelValue), containerd.WithMaxConcurrentDownloads(c.config.MaxConcurrentDownloads), containerd.WithImageHandler(imageHandler), + containerd.WithUnpackOpts([]containerd.UnpackOpt{ + containerd.WithUnpackDuplicationSuppressor(c.unpackDuplicationSuppressor), + }), } pullOpts = append(pullOpts, c.encryptedImagesPullOpts()...) diff --git a/pkg/cri/server/service.go b/pkg/cri/server/service.go index b15164fea..2e2cca42b 100644 --- a/pkg/cri/server/service.go +++ b/pkg/cri/server/service.go @@ -29,6 +29,7 @@ import ( "github.com/containerd/containerd" "github.com/containerd/containerd/oci" "github.com/containerd/containerd/pkg/cri/streaming" + "github.com/containerd/containerd/pkg/kmutex" "github.com/containerd/containerd/plugin" cni "github.com/containerd/go-cni" "github.com/sirupsen/logrus" @@ -113,6 +114,10 @@ type criService struct { // allCaps is the list of the capabilities. // When nil, parsed from CapEff of /proc/self/status. allCaps []string // nolint + // unpackDuplicationSuppressor is used to make sure that there is only + // one in-flight fetch request or unpack handler for a given descriptor's + // or chain ID. + unpackDuplicationSuppressor kmutex.KeyedLocker } // NewCRIService returns a new instance of CRIService @@ -120,17 +125,18 @@ func NewCRIService(config criconfig.Config, client *containerd.Client) (CRIServi var err error labels := label.NewStore() c := &criService{ - config: config, - client: client, - os: osinterface.RealOS{}, - sandboxStore: sandboxstore.NewStore(labels), - containerStore: containerstore.NewStore(labels), - imageStore: imagestore.NewStore(client), - snapshotStore: snapshotstore.NewStore(), - sandboxNameIndex: registrar.NewRegistrar(), - containerNameIndex: registrar.NewRegistrar(), - initialized: atomic.NewBool(false), - netPlugin: make(map[string]cni.CNI), + config: config, + client: client, + os: osinterface.RealOS{}, + sandboxStore: sandboxstore.NewStore(labels), + containerStore: containerstore.NewStore(labels), + imageStore: imagestore.NewStore(client), + snapshotStore: snapshotstore.NewStore(), + sandboxNameIndex: registrar.NewRegistrar(), + containerNameIndex: registrar.NewRegistrar(), + initialized: atomic.NewBool(false), + netPlugin: make(map[string]cni.CNI), + unpackDuplicationSuppressor: kmutex.New(), } if client.SnapshotService(c.config.ContainerdConfig.Snapshotter) == nil { diff --git a/pkg/kmutex/kmutex.go b/pkg/kmutex/kmutex.go new file mode 100644 index 000000000..74846c057 --- /dev/null +++ b/pkg/kmutex/kmutex.go @@ -0,0 +1,105 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Package kmutex provides synchronization primitives to lock/unlock resource by unique key. +package kmutex + +import ( + "context" + "fmt" + "sync" + + "golang.org/x/sync/semaphore" +) + +// KeyedLocker is the interface for acquiring locks based on string. +type KeyedLocker interface { + Lock(ctx context.Context, key string) error + Unlock(key string) +} + +func New() KeyedLocker { + return newKeyMutex() +} + +func newKeyMutex() *keyMutex { + return &keyMutex{ + locks: make(map[string]*klock), + } +} + +type keyMutex struct { + mu sync.Mutex + + locks map[string]*klock +} + +type klock struct { + *semaphore.Weighted + ref int +} + +func (km *keyMutex) Lock(ctx context.Context, key string) error { + km.mu.Lock() + + l, ok := km.locks[key] + if !ok { + km.locks[key] = &klock{ + Weighted: semaphore.NewWeighted(1), + } + l = km.locks[key] + } + l.ref++ + km.mu.Unlock() + + if err := l.Acquire(ctx, 1); err != nil { + km.mu.Lock() + defer km.mu.Unlock() + + l.ref-- + + if l.ref < 0 { + panic(fmt.Errorf("kmutex: release of unlocked key %v", key)) + } + + if l.ref == 0 { + delete(km.locks, key) + } + return err + } + return nil +} + +func (km *keyMutex) Unlock(key string) { + km.mu.Lock() + defer km.mu.Unlock() + + l, ok := km.locks[key] + if !ok { + panic(fmt.Errorf("kmutex: unlock of unlocked key %v", key)) + } + l.Release(1) + + l.ref-- + + if l.ref < 0 { + panic(fmt.Errorf("kmutex: released of unlocked key %v", key)) + } + + if l.ref == 0 { + delete(km.locks, key) + } +} diff --git a/pkg/kmutex/kmutex_test.go b/pkg/kmutex/kmutex_test.go new file mode 100644 index 000000000..a6fd751ee --- /dev/null +++ b/pkg/kmutex/kmutex_test.go @@ -0,0 +1,175 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package kmutex + +import ( + "context" + "math/rand" + "runtime" + "strconv" + "sync" + "testing" + "time" + + "github.com/containerd/containerd/pkg/seed" + "github.com/stretchr/testify/assert" +) + +func init() { + seed.WithTimeAndRand() +} + +func TestBasic(t *testing.T) { + t.Parallel() + + km := newKeyMutex() + ctx := context.Background() + + km.Lock(ctx, "c1") + km.Lock(ctx, "c2") + + assert.Equal(t, len(km.locks), 2) + assert.Equal(t, km.locks["c1"].ref, 1) + assert.Equal(t, km.locks["c2"].ref, 1) + + checkWaitFn := func(key string, num int) { + retries := 100 + waitLock := false + + for i := 0; i < retries; i++ { + // prevent from data-race + km.mu.Lock() + ref := km.locks[key].ref + km.mu.Unlock() + + if ref == num { + waitLock = true + break + } + time.Sleep(time.Duration(rand.Int63n(100)) * time.Millisecond) + } + assert.Equal(t, waitLock, true) + } + + // should acquire successfully after release + { + waitCh := make(chan struct{}) + go func() { + defer close(waitCh) + + km.Lock(ctx, "c1") + }() + checkWaitFn("c1", 2) + + km.Unlock("c1") + + <-waitCh + assert.Equal(t, km.locks["c1"].ref, 1) + } + + // failed to acquire if context cancel + { + var errCh = make(chan error, 1) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + errCh <- km.Lock(ctx, "c1") + }() + + checkWaitFn("c1", 2) + + cancel() + assert.Equal(t, <-errCh, context.Canceled) + assert.Equal(t, km.locks["c1"].ref, 1) + } +} + +func TestReleasePanic(t *testing.T) { + t.Parallel() + + km := newKeyMutex() + + defer func() { + if recover() == nil { + t.Fatal("release of unlocked key did not panic") + } + }() + + km.Unlock(t.Name()) +} + +func TestMultileAcquireOnKeys(t *testing.T) { + t.Parallel() + + km := newKeyMutex() + nloops := 10000 + nproc := runtime.GOMAXPROCS(0) + ctx := context.Background() + + var wg sync.WaitGroup + for i := 0; i < nproc; i++ { + wg.Add(1) + + go func(key string) { + defer wg.Done() + + for i := 0; i < nloops; i++ { + km.Lock(ctx, key) + + time.Sleep(time.Duration(rand.Int63n(100)) * time.Nanosecond) + + km.Unlock(key) + } + }("key-" + strconv.Itoa(i)) + } + wg.Wait() +} + +func TestMultiAcquireOnSameKey(t *testing.T) { + t.Parallel() + + km := newKeyMutex() + key := "c1" + ctx := context.Background() + + assert.Nil(t, km.Lock(ctx, key)) + + nproc := runtime.GOMAXPROCS(0) + nloops := 10000 + + var wg sync.WaitGroup + for i := 0; i < nproc; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for i := 0; i < nloops; i++ { + km.Lock(ctx, key) + + time.Sleep(time.Duration(rand.Int63n(100)) * time.Nanosecond) + + km.Unlock(key) + } + }() + } + km.Unlock(key) + wg.Wait() + + // c1 key has been released so the it should not have any klock. + assert.Equal(t, len(km.locks), 0) +} diff --git a/pkg/kmutex/noop.go b/pkg/kmutex/noop.go new file mode 100644 index 000000000..66c46f15a --- /dev/null +++ b/pkg/kmutex/noop.go @@ -0,0 +1,33 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package kmutex + +import "context" + +func NewNoop() KeyedLocker { + return &noopMutex{} +} + +type noopMutex struct { +} + +func (*noopMutex) Lock(_ context.Context, _ string) error { + return nil +} + +func (*noopMutex) Unlock(_ string) { +} diff --git a/unpacker.go b/unpacker.go index 719345a1c..03cf7554e 100644 --- a/unpacker.go +++ b/unpacker.go @@ -32,6 +32,7 @@ import ( "github.com/containerd/containerd/images" "github.com/containerd/containerd/log" "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/pkg/kmutex" "github.com/containerd/containerd/platforms" "github.com/containerd/containerd/snapshots" "github.com/opencontainers/go-digest" @@ -59,7 +60,9 @@ func (c *Client) newUnpacker(ctx context.Context, rCtx *RemoteContext) (*unpacke if err != nil { return nil, err } - var config UnpackConfig + var config = UnpackConfig{ + DuplicationSuppressor: kmutex.NewNoop(), + } for _, o := range rCtx.UnpackOpts { if err := o(ctx, &config); err != nil { return nil, err @@ -127,15 +130,20 @@ func (u *unpacker) unpack( ctx, cancel := context.WithCancel(ctx) defer cancel() -EachLayer: - for i, desc := range layers { + doUnpackFn := func(i int, desc ocispec.Descriptor) error { parent := identity.ChainID(chain) chain = append(chain, diffIDs[i]) - chainID := identity.ChainID(chain).String() + + unlock, err := u.lockSnChainID(ctx, chainID) + if err != nil { + return err + } + defer unlock() + if _, err := sn.Stat(ctx, chainID); err == nil { // no need to handle - continue + return nil } else if !errdefs.IsNotFound(err) { return fmt.Errorf("failed to stat snapshot %s: %w", chainID, err) } @@ -167,7 +175,7 @@ EachLayer: log.G(ctx).WithField("key", key).WithField("chainid", chainID).Debug("extraction snapshot already exists, chain id not found") } else { // no need to handle, snapshot now found with chain id - continue EachLayer + return nil } } else { return fmt.Errorf("failed to prepare extraction snapshot %q: %w", key, err) @@ -227,7 +235,7 @@ EachLayer: if err = sn.Commit(ctx, chainID, key, opts...); err != nil { abort() if errdefs.IsAlreadyExists(err) { - continue + return nil } return fmt.Errorf("failed to commit snapshot %s: %w", key, err) } @@ -243,7 +251,13 @@ EachLayer: if _, err := cs.Update(ctx, cinfo, "labels.containerd.io/uncompressed"); err != nil { return err } + return nil + } + for i, desc := range layers { + if err := doUnpackFn(i, desc); err != nil { + return err + } } chainID := identity.ChainID(chain).String() @@ -271,17 +285,22 @@ func (u *unpacker) fetch(ctx context.Context, h images.Handler, layers []ocispec desc := desc i := i - if u.limiter != nil { - if err := u.limiter.Acquire(ctx, 1); err != nil { - return err - } + if err := u.acquire(ctx); err != nil { + return err } eg.Go(func() error { - _, err := h.Handle(ctx2, desc) - if u.limiter != nil { - u.limiter.Release(1) + unlock, err := u.lockBlobDescriptor(ctx2, desc) + if err != nil { + u.release() + return err } + + _, err = h.Handle(ctx2, desc) + + unlock() + u.release() + if err != nil && !errors.Is(err, images.ErrSkipDesc) { return err } @@ -306,7 +325,13 @@ func (u *unpacker) handlerWrapper( layers = map[digest.Digest][]ocispec.Descriptor{} ) return images.HandlerFunc(func(ctx context.Context, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) { + unlock, err := u.lockBlobDescriptor(ctx, desc) + if err != nil { + return nil, err + } + children, err := f.Handle(ctx, desc) + unlock() if err != nil { return children, err } @@ -349,6 +374,50 @@ func (u *unpacker) handlerWrapper( }, eg } +func (u *unpacker) acquire(ctx context.Context) error { + if u.limiter == nil { + return nil + } + return u.limiter.Acquire(ctx, 1) +} + +func (u *unpacker) release() { + if u.limiter == nil { + return + } + u.limiter.Release(1) +} + +func (u *unpacker) lockSnChainID(ctx context.Context, chainID string) (func(), error) { + key := u.makeChainIDKeyWithSnapshotter(chainID) + + if err := u.config.DuplicationSuppressor.Lock(ctx, key); err != nil { + return nil, err + } + return func() { + u.config.DuplicationSuppressor.Unlock(key) + }, nil +} + +func (u *unpacker) lockBlobDescriptor(ctx context.Context, desc ocispec.Descriptor) (func(), error) { + key := u.makeBlobDescriptorKey(desc) + + if err := u.config.DuplicationSuppressor.Lock(ctx, key); err != nil { + return nil, err + } + return func() { + u.config.DuplicationSuppressor.Unlock(key) + }, nil +} + +func (u *unpacker) makeChainIDKeyWithSnapshotter(chainID string) string { + return fmt.Sprintf("sn://%s/%v", u.snapshotter, chainID) +} + +func (u *unpacker) makeBlobDescriptorKey(desc ocispec.Descriptor) string { + return fmt.Sprintf("blob://%v", desc.Digest) +} + func uniquePart() string { t := time.Now() var b [3]byte