diff --git a/cmd/containerd-stress/main.go b/cmd/containerd-stress/main.go index 79c81aab5..ede505d96 100644 --- a/cmd/containerd-stress/main.go +++ b/cmd/containerd-stress/main.go @@ -208,30 +208,22 @@ func (w *worker) runContainer(ctx context.Context, id string) error { return err } defer task.Delete(ctx, containerd.WithProcessKill) - var ( - start sync.WaitGroup - status = make(chan uint32, 1) - ) - start.Add(1) - go func() { - start.Done() - s, err := task.Wait(w.waitContext) - if err != nil { - if err == context.DeadlineExceeded || - err == context.Canceled { - close(status) - return - } - w.failures++ - logrus.WithError(err).Errorf("wait task %s", id) - } - status <- s - }() - start.Wait() + + statusC, err := task.Wait(ctx) + if err != nil { + return err + } + if err := task.Start(ctx); err != nil { return err } - <-status + status := <-statusC + if status.Err != nil { + if status.Err == context.DeadlineExceeded || status.Err == context.Canceled { + return nil + } + w.failures++ + } return nil } diff --git a/cmd/ctr/attach.go b/cmd/ctr/attach.go index c4934fa21..390835c3e 100644 --- a/cmd/ctr/attach.go +++ b/cmd/ctr/attach.go @@ -44,6 +44,12 @@ var taskAttachCommand = cli.Command{ return err } defer task.Delete(ctx) + + statusC, err := task.Wait(ctx) + if err != nil { + return err + } + if tty { if err := handleConsoleResize(ctx, task, con); err != nil { logrus.WithError(err).Error("console resize") @@ -52,12 +58,13 @@ var taskAttachCommand = cli.Command{ sigc := forwardAllSignals(ctx, task) defer stopCatch(sigc) } - status, err := task.Wait(ctx) - if err != nil { + + ec := <-statusC + if ec.Err != nil { return err } - if status != 0 { - return cli.NewExitError("", int(status)) + if ec.Code != 0 { + return cli.NewExitError("", int(ec.Code)) } return nil }, diff --git a/cmd/ctr/exec.go b/cmd/ctr/exec.go index 015836d81..a9fd7530e 100644 --- a/cmd/ctr/exec.go +++ b/cmd/ctr/exec.go @@ -70,14 +70,11 @@ var taskExecCommand = cli.Command{ } defer process.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := process.Wait(ctx) - if err != nil { - logrus.WithError(err).Error("wait process") - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + return err + } + var con console.Console if tty { con = console.Current() @@ -98,8 +95,11 @@ var taskExecCommand = cli.Command{ defer stopCatch(sigc) } status := <-statusC - if status != 0 { - return cli.NewExitError("", int(status)) + if status.Err != nil { + return status.Err + } + if status.Code != 0 { + return cli.NewExitError("", int(status.Code)) } return nil }, diff --git a/cmd/ctr/run.go b/cmd/ctr/run.go index 60f3a916e..12611345c 100644 --- a/cmd/ctr/run.go +++ b/cmd/ctr/run.go @@ -129,14 +129,11 @@ var runCommand = cli.Command{ } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - logrus.WithError(err).Error("wait process") - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + return err + } + var con console.Console if tty { con = console.Current() @@ -158,11 +155,15 @@ var runCommand = cli.Command{ } status := <-statusC + if status.Err != nil { + return status.Err + } + if _, err := task.Delete(ctx); err != nil { return err } - if status != 0 { - return cli.NewExitError("", int(status)) + if status.Code != 0 { + return cli.NewExitError("", int(status.Code)) } return nil }, diff --git a/cmd/ctr/start.go b/cmd/ctr/start.go index 59ebefebd..ff7195886 100644 --- a/cmd/ctr/start.go +++ b/cmd/ctr/start.go @@ -47,14 +47,11 @@ var taskStartCommand = cli.Command{ } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - logrus.WithError(err).Error("wait process") - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + return err + } + var con console.Console if tty { con = console.Current() @@ -76,11 +73,14 @@ var taskStartCommand = cli.Command{ } status := <-statusC + if status.Err != nil { + return err + } if _, err := task.Delete(ctx); err != nil { return err } - if status != 0 { - return cli.NewExitError("", int(status)) + if status.Code != 0 { + return cli.NewExitError("", int(status.Code)) } return nil }, diff --git a/container_checkpoint_test.go b/container_checkpoint_test.go index 4beb11862..ebd762370 100644 --- a/container_checkpoint_test.go +++ b/container_checkpoint_test.go @@ -47,14 +47,11 @@ func TestCheckpointRestore(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -79,13 +76,11 @@ func TestCheckpointRestore(t *testing.T) { } defer task.Delete(ctx) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err = task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -137,14 +132,11 @@ func TestCheckpointRestoreNewContainer(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -177,13 +169,11 @@ func TestCheckpointRestoreNewContainer(t *testing.T) { } defer task.Delete(ctx) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err = task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -240,14 +230,11 @@ func TestCheckpointLeaveRunning(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) diff --git a/container_linux_test.go b/container_linux_test.go index b6746aaf8..4675ac17b 100644 --- a/container_linux_test.go +++ b/container_linux_test.go @@ -57,14 +57,11 @@ func TestContainerUpdate(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } // check that the task has a limit of 32mb cgroup, err := cgroups.Load(cgroups.V1, cgroups.PidPath(int(task.Pid()))) @@ -157,14 +154,12 @@ func TestShimInCgroup(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } + // check to see if the shim is inside the cgroup processes, err := cg.Processes(cgroups.Devices, false) if err != nil { @@ -221,17 +216,11 @@ func TestDaemonRestart(t *testing.T) { } defer task.Delete(ctx) - synC := make(chan struct{}) - statusC := make(chan uint32, 1) - go func() { - synC <- struct{}{} - status, err := task.Wait(ctx) - if err == nil { - t.Errorf(`first task.Wait() should have failed with "transport is closing"`) - } - statusC <- status - }() - <-synC + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -242,7 +231,10 @@ func TestDaemonRestart(t *testing.T) { t.Fatal(err) } - <-statusC + status := <-statusC + if status.Err == nil { + t.Errorf(`first task.Wait() should have failed with "transport is closing"`) + } waitCtx, waitCancel := context.WithTimeout(ctx, 2*time.Second) serving, err := client.IsServing(waitCtx) @@ -251,15 +243,11 @@ func TestDaemonRestart(t *testing.T) { t.Fatalf("containerd did not start within 2s: %v", err) } - go func() { - synC <- struct{}{} - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() - <-synC + statusC, err = task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Kill(ctx, syscall.SIGKILL); err != nil { t.Fatal(err) diff --git a/container_test.go b/container_test.go index 3844c00bc..514c5409e 100644 --- a/container_test.go +++ b/container_test.go @@ -124,14 +124,11 @@ func TestContainerStart(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if pid := task.Pid(); pid <= 0 { t.Errorf("invalid task pid %d", pid) @@ -142,15 +139,21 @@ func TestContainerStart(t *testing.T) { return } status := <-statusC - if status != 7 { - t.Errorf("expected status 7 from wait but received %d", status) - } - if status, err = task.Delete(ctx); err != nil { + if status.Err != nil { t.Error(err) return } - if status != 7 { - t.Errorf("expected status 7 from delete but received %d", status) + if status.Code != 7 { + t.Errorf("expected status 7 from wait but received %d", status.Code) + } + + deleteStatus, err := task.Delete(ctx) + if err != nil { + t.Error(err) + return + } + if deleteStatus != 7 { + t.Errorf("expected status 7 from delete but received %d", deleteStatus) } } @@ -199,14 +202,11 @@ func TestContainerOutput(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -214,7 +214,7 @@ func TestContainerOutput(t *testing.T) { } status := <-statusC - if status != 0 { + if status.Code != 0 { t.Errorf("expected status 0 but received %d", status) } if _, err := task.Delete(ctx); err != nil { @@ -273,13 +273,11 @@ func TestContainerExec(t *testing.T) { } defer task.Delete(ctx) - finished := make(chan struct{}, 1) - go func() { - if _, err := task.Wait(ctx); err != nil { - t.Error(err) - } - close(finished) - }() + finishedC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -295,14 +293,11 @@ func TestContainerExec(t *testing.T) { t.Error(err) return } - processStatusC := make(chan uint32, 1) - go func() { - status, err := process.Wait(ctx) - if err != nil { - t.Error(err) - } - processStatusC <- status - }() + processStatusC, err := process.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := process.Start(ctx); err != nil { t.Error(err) @@ -311,9 +306,13 @@ func TestContainerExec(t *testing.T) { // wait for the exec to return status := <-processStatusC + if status.Err != nil { + t.Error(err) + return + } - if status != 6 { - t.Errorf("expected exec exit code 6 but received %d", status) + if status.Code != 6 { + t.Errorf("expected exec exit code 6 but received %d", status.Code) } deleteStatus, err := process.Delete(ctx) if err != nil { @@ -326,7 +325,7 @@ func TestContainerExec(t *testing.T) { if err := task.Kill(ctx, syscall.SIGKILL); err != nil { t.Error(err) } - <-finished + <-finishedC } func TestContainerPids(t *testing.T) { @@ -372,14 +371,11 @@ func TestContainerPids(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -462,14 +458,11 @@ func TestContainerCloseIO(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -576,14 +569,10 @@ func TestContainerAttach(t *testing.T) { defer task.Delete(ctx) originalIO := task.IO() - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -620,7 +609,11 @@ func TestContainerAttach(t *testing.T) { t.Error(err) } - <-statusC + status := <-statusC + if status.Err != nil { + t.Error(err) + return + } originalIO.Close() if _, err := task.Delete(ctx); err != nil { @@ -681,14 +674,11 @@ func TestDeleteRunningContainer(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -752,14 +742,11 @@ func TestContainerKill(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -878,19 +865,15 @@ func TestContainerExecNoBinaryExists(t *testing.T) { } defer task.Delete(ctx) + finishedC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + } if err := task.Start(ctx); err != nil { t.Error(err) return } - finished := make(chan struct{}, 1) - go func() { - if _, err := task.Wait(ctx); err != nil { - t.Error(err) - } - close(finished) - }() - // start an exec process without running the original container process processSpec := spec.Process processSpec.Args = []string{ @@ -909,7 +892,7 @@ func TestContainerExecNoBinaryExists(t *testing.T) { if err := task.Kill(ctx, syscall.SIGKILL); err != nil { t.Error(err) } - <-finished + <-finishedC } func TestUserNamespaces(t *testing.T) { @@ -962,14 +945,11 @@ func TestUserNamespaces(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if pid := task.Pid(); pid <= 0 { t.Errorf("invalid task pid %d", pid) @@ -980,15 +960,20 @@ func TestUserNamespaces(t *testing.T) { return } status := <-statusC - if status != 7 { - t.Errorf("expected status 7 from wait but received %d", status) - } - if status, err = task.Delete(ctx); err != nil { + if status.Err != nil { t.Error(err) return } - if status != 7 { - t.Errorf("expected status 7 from delete but received %d", status) + if status.Code != 7 { + t.Errorf("expected status 7 from wait but received %d", status.Code) + } + deleteStatus, err := task.Delete(ctx) + if err != nil { + t.Error(err) + return + } + if deleteStatus != 7 { + t.Errorf("expected status 7 from delete but received %d", deleteStatus) } } @@ -1035,14 +1020,11 @@ func TestWaitStoppedTask(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if pid := task.Pid(); pid <= 0 { t.Errorf("invalid task pid %d", pid) @@ -1052,15 +1034,21 @@ func TestWaitStoppedTask(t *testing.T) { task.Delete(ctx) return } + // wait for the task to stop then call wait again <-statusC - status, err := task.Wait(ctx) + statusC, err = task.Wait(ctx) if err != nil { t.Error(err) return } - if status != 7 { - t.Errorf("exit status from stopped task should be 7 but received %d", status) + status := <-statusC + if status.Err != nil { + t.Error(status.Err) + return + } + if status.Code != 7 { + t.Errorf("exit status from stopped task should be 7 but received %d", status.Code) } } @@ -1107,13 +1095,10 @@ func TestWaitStoppedProcess(t *testing.T) { } defer task.Delete(ctx) - finished := make(chan struct{}, 1) - go func() { - if _, err := task.Wait(ctx); err != nil { - t.Error(err) - } - close(finished) - }() + finishedC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -1130,14 +1115,12 @@ func TestWaitStoppedProcess(t *testing.T) { return } defer process.Delete(ctx) - processStatusC := make(chan uint32, 1) - go func() { - status, err := process.Wait(ctx) - if err != nil { - t.Error(err) - } - processStatusC <- status - }() + + statusC, err := process.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := process.Start(ctx); err != nil { t.Error(err) @@ -1145,20 +1128,27 @@ func TestWaitStoppedProcess(t *testing.T) { } // wait for the exec to return - <-processStatusC + <-statusC + // try to wait on the process after it has stopped - status, err := process.Wait(ctx) + statusC, err = process.Wait(ctx) if err != nil { t.Error(err) return } - if status != 6 { - t.Errorf("exit status from stopped process should be 6 but received %d", status) + status := <-statusC + if status.Err != nil { + t.Error(err) + return } + if status.Code != 6 { + t.Errorf("exit status from stopped process should be 6 but received %d", status.Code) + } + if err := task.Kill(ctx, syscall.SIGKILL); err != nil { t.Error(err) } - <-finished + <-finishedC } func TestTaskForceDelete(t *testing.T) { @@ -1256,14 +1246,11 @@ func TestProcessForceDelete(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } // task must be started on windows if err := task.Start(ctx); err != nil { @@ -1344,14 +1331,11 @@ func TestContainerHostname(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + return + } if err := task.Start(ctx); err != nil { t.Error(err) @@ -1359,7 +1343,11 @@ func TestContainerHostname(t *testing.T) { } status := <-statusC - if status != 0 { + if status.Err != nil { + t.Error(status.Err) + return + } + if status.Code != 0 { t.Errorf("expected status 0 but received %d", status) } if _, err := task.Delete(ctx); err != nil { @@ -1420,14 +1408,10 @@ func TestContainerExitedAtSet(t *testing.T) { } defer task.Delete(ctx) - statusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - t.Error(err) - } - statusC <- status - }() + statusC, err := task.Wait(ctx) + if err != nil { + t.Error(err) + } startTime := time.Now() if err := task.Start(ctx); err != nil { @@ -1436,7 +1420,7 @@ func TestContainerExitedAtSet(t *testing.T) { } status := <-statusC - if status != 0 { + if status.Code != 0 { t.Errorf("expected status 0 but received %d", status) } @@ -1495,13 +1479,10 @@ func TestDeleteContainerExecCreated(t *testing.T) { } defer task.Delete(ctx) - finished := make(chan struct{}, 1) - go func() { - if _, err := task.Wait(ctx); err != nil { - t.Error(err) - } - close(finished) - }() + finished, err := task.Wait(ctx) + if err != nil { + t.Error(err) + } if err := task.Start(ctx); err != nil { t.Error(err) diff --git a/docs/getting-started.md b/docs/getting-started.md index 6d226730d..5e0e33aaa 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -163,14 +163,10 @@ You always want to make sure you `Wait` before calling `Start` on a task. This makes sure that you do not encounter any races if the task has a simple program like `/bin/true` that exits promptly after calling start. ```go - exitStatusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - fmt.Println(err) - } - exitStatusC <- status - }() + exitStatusC, err := task.Wait(ctx) + if err != nil { + return err + } if err := task.Start(ctx); err != nil { return err @@ -192,7 +188,10 @@ To do this we will simply call `Kill` on the task after waiting a couple of seco } status := <-exitStatusC - fmt.Printf("redis-server exited with status: %d\n", status) + if status.Err != nil { + return status.Err + } + fmt.Printf("redis-server exited with status: %d\n", status.Code) ``` We wait on our exit status channel that we setup to ensure the task has fully exited and we get the exit status. @@ -271,14 +270,10 @@ func redisExample() error { defer task.Delete(ctx) // make sure we wait before calling start - exitStatusC := make(chan uint32, 1) - go func() { - status, err := task.Wait(ctx) - if err != nil { - fmt.Println(err) - } - exitStatusC <- status - }() + exitStatusC, err := task.Wait(ctx) + if err != nil { + fmt.Println(err) + } // call start on the task to execute the redis server if err := task.Start(ctx); err != nil { @@ -296,7 +291,10 @@ func redisExample() error { // wait for the process to fully exit and print out the exit status status := <-exitStatusC - fmt.Printf("redis-server exited with status: %d\n", status) + if status.Err != nil { + return err + } + fmt.Printf("redis-server exited with status: %d\n", status.Code) return nil } diff --git a/process.go b/process.go index 620d1a225..0facb5f41 100644 --- a/process.go +++ b/process.go @@ -4,6 +4,7 @@ import ( "context" "strings" "syscall" + "time" eventsapi "github.com/containerd/containerd/api/services/events/v1" "github.com/containerd/containerd/api/services/tasks/v1" @@ -24,8 +25,8 @@ type Process interface { Delete(context.Context, ...ProcessDeleteOpts) (uint32, error) // Kill sends the provided signal to the process Kill(context.Context, syscall.Signal) error - // Wait blocks until the process has exited returning the exit status - Wait(context.Context) (uint32, error) + // Wait asynchronously waits for the process to exit, and sends the exit code to the returned channel + Wait(context.Context) (<-chan ExitStatus, error) // CloseIO allows various pipes to be closed on the process CloseIO(context.Context, ...IOCloserOpts) error // Resize changes the width and heigh of the process's terminal @@ -36,6 +37,16 @@ type Process interface { Status(context.Context) (Status, error) } +// ExitStatus encapsulates a process' exit code. +// It is used by `Wait()` to return either a process exit code or an error +// The `Err` field is provided to return an error that may occur while waiting +// `Err` is not used to convey an error with the process itself. +type ExitStatus struct { + Code uint32 + ExitedAt time.Time + Err error +} + type process struct { id string task *task @@ -79,39 +90,55 @@ func (p *process) Kill(ctx context.Context, s syscall.Signal) error { return errdefs.FromGRPC(err) } -func (p *process) Wait(ctx context.Context) (uint32, error) { +func (p *process) Wait(ctx context.Context) (<-chan ExitStatus, error) { cancellable, cancel := context.WithCancel(ctx) - defer cancel() eventstream, err := p.task.client.EventService().Subscribe(cancellable, &eventsapi.SubscribeRequest{ Filters: []string{"topic==" + runtime.TaskExitEventTopic}, }) if err != nil { - return UnknownExitStatus, err + cancel() + return nil, err } // first check if the task has exited status, err := p.Status(ctx) if err != nil { - return UnknownExitStatus, errdefs.FromGRPC(err) + cancel() + return nil, errdefs.FromGRPC(err) } + + chStatus := make(chan ExitStatus, 1) if status.Status == Stopped { - return status.ExitStatus, nil + cancel() + chStatus <- ExitStatus{Code: status.ExitStatus, ExitedAt: status.ExitTime} + return chStatus, nil } - for { - evt, err := eventstream.Recv() - if err != nil { - return UnknownExitStatus, err - } - if typeurl.Is(evt.Event, &eventsapi.TaskExit{}) { - v, err := typeurl.UnmarshalAny(evt.Event) + + go func() { + defer cancel() + chStatus <- ExitStatus{} // signal that the goroutine is running + for { + evt, err := eventstream.Recv() if err != nil { - return UnknownExitStatus, err + chStatus <- ExitStatus{Code: UnknownExitStatus, Err: err} + return } - e := v.(*eventsapi.TaskExit) - if e.ID == p.id && e.ContainerID == p.task.id { - return e.ExitStatus, nil + if typeurl.Is(evt.Event, &eventsapi.TaskExit{}) { + v, err := typeurl.UnmarshalAny(evt.Event) + if err != nil { + chStatus <- ExitStatus{Code: UnknownExitStatus, Err: err} + return + } + e := v.(*eventsapi.TaskExit) + if e.ID == p.id && e.ContainerID == p.task.id { + chStatus <- ExitStatus{Code: e.ExitStatus, ExitedAt: e.ExitedAt} + return + } } } - } + }() + + <-chStatus // wait for the goroutine to be running + return chStatus, nil } func (p *process) CloseIO(ctx context.Context, opts ...IOCloserOpts) error { diff --git a/task.go b/task.go index be580b553..c0c6182f8 100644 --- a/task.go +++ b/task.go @@ -201,15 +201,18 @@ func (t *task) Status(ctx context.Context) (Status, error) { }, nil } -func (t *task) Wait(ctx context.Context) (uint32, error) { +func (t *task) Wait(ctx context.Context) (<-chan ExitStatus, error) { cancellable, cancel := context.WithCancel(ctx) - defer cancel() eventstream, err := t.client.EventService().Subscribe(cancellable, &eventsapi.SubscribeRequest{ Filters: []string{"topic==" + runtime.TaskExitEventTopic}, }) if err != nil { - return UnknownExitStatus, errdefs.FromGRPC(err) + cancel() + return nil, errdefs.FromGRPC(err) } + + chStatus := make(chan ExitStatus, 1) + t.mu.Lock() checkpoint := t.deferred != nil t.mu.Unlock() @@ -217,28 +220,42 @@ func (t *task) Wait(ctx context.Context) (uint32, error) { // first check if the task has exited status, err := t.Status(ctx) if err != nil { - return UnknownExitStatus, errdefs.FromGRPC(err) + cancel() + return nil, errdefs.FromGRPC(err) } if status.Status == Stopped { - return status.ExitStatus, nil + cancel() + chStatus <- ExitStatus{Code: status.ExitStatus, ExitedAt: status.ExitTime} + return chStatus, nil } } - for { - evt, err := eventstream.Recv() - if err != nil { - return UnknownExitStatus, errdefs.FromGRPC(err) - } - if typeurl.Is(evt.Event, &eventsapi.TaskExit{}) { - v, err := typeurl.UnmarshalAny(evt.Event) + + go func() { + defer cancel() + chStatus <- ExitStatus{} // signal that goroutine is running + for { + evt, err := eventstream.Recv() if err != nil { - return UnknownExitStatus, err + chStatus <- ExitStatus{Code: UnknownExitStatus, Err: errdefs.FromGRPC(err)} + return } - e := v.(*eventsapi.TaskExit) - if e.ContainerID == t.id && e.Pid == t.pid { - return e.ExitStatus, nil + if typeurl.Is(evt.Event, &eventsapi.TaskExit{}) { + v, err := typeurl.UnmarshalAny(evt.Event) + if err != nil { + chStatus <- ExitStatus{Code: UnknownExitStatus, Err: err} + return + } + e := v.(*eventsapi.TaskExit) + if e.ContainerID == t.id && e.Pid == t.pid { + chStatus <- ExitStatus{Code: e.ExitStatus, ExitedAt: e.ExitedAt} + return + } } } - } + }() + + <-chStatus // wait for the goroutine to be running + return chStatus, nil } // Delete deletes the task and its runtime state diff --git a/task_opts.go b/task_opts.go index 0e4b75e85..3dcab170f 100644 --- a/task_opts.go +++ b/task_opts.go @@ -33,15 +33,14 @@ type ProcessDeleteOpts func(context.Context, Process) error // WithProcessKill will forcefully kill and delete a process func WithProcessKill(ctx context.Context, p Process) error { - s := make(chan struct{}, 1) ctx, cancel := context.WithCancel(ctx) defer cancel() // ignore errors to wait and kill as we are forcefully killing // the process and don't care about the exit status - go func() { - p.Wait(ctx) - close(s) - }() + s, err := p.Wait(ctx) + if err != nil { + return err + } if err := p.Kill(ctx, syscall.SIGKILL); err != nil { if errdefs.IsFailedPrecondition(err) || errdefs.IsNotFound(err) { return nil