diff --git a/task.go b/task.go index d586fb9ff..f7cb4bb4d 100644 --- a/task.go +++ b/task.go @@ -8,6 +8,7 @@ import ( "io" goruntime "runtime" "strings" + "sync" "syscall" eventsapi "github.com/containerd/containerd/api/services/events/v1" @@ -93,6 +94,7 @@ type task struct { id string pid uint32 + mu sync.Mutex deferred *tasks.CreateTaskRequest } @@ -102,9 +104,14 @@ func (t *task) Pid() uint32 { } func (t *task) Start(ctx context.Context) error { - if t.deferred != nil { - response, err := t.client.TaskService().Create(ctx, t.deferred) + t.mu.Lock() + deferred := t.deferred + t.mu.Unlock() + if deferred != nil { + response, err := t.client.TaskService().Create(ctx, deferred) + t.mu.Lock() t.deferred = nil + t.mu.Unlock() if err != nil { t.io.closer.Close() return err @@ -166,13 +173,18 @@ func (t *task) Wait(ctx context.Context) (uint32, error) { if err != nil { return UnknownExitStatus, errdefs.FromGRPC(err) } - // first check if the task has exited - status, err := t.Status(ctx) - if err != nil { - return UnknownExitStatus, errdefs.FromGRPC(err) - } - if status.Status == Stopped { - return status.ExitStatus, nil + t.mu.Lock() + checkpoint := t.deferred != nil + t.mu.Unlock() + if !checkpoint { + // first check if the task has exited + status, err := t.Status(ctx) + if err != nil { + return UnknownExitStatus, errdefs.FromGRPC(err) + } + if status.Status == Stopped { + return status.ExitStatus, nil + } } for { evt, err := eventstream.Recv()