diff --git a/client_test.go b/client_test.go index addd23919..db307a96a 100644 --- a/client_test.go +++ b/client_test.go @@ -98,7 +98,7 @@ func TestMain(m *testing.M) { if err := cmd.Process.Signal(syscall.SIGTERM); err != nil { fmt.Fprintln(os.Stderr, err) } - if _, err := cmd.Process.Wait(); err != nil { + if err := cmd.Wait(); err != nil { fmt.Fprintln(os.Stderr, err) } if err := os.RemoveAll(defaultRoot); err != nil { diff --git a/io.go b/io.go index ea103c354..e4fce203c 100644 --- a/io.go +++ b/io.go @@ -1,6 +1,7 @@ package containerd import ( + "context" "fmt" "io" "io/ioutil" @@ -18,6 +19,13 @@ type IO struct { closer *wgCloser } +func (i *IO) Cancel() { + if i.closer == nil { + return + } + i.closer.Cancel() +} + func (i *IO) Wait() { if i.closer == nil { return @@ -134,9 +142,10 @@ type ioSet struct { } type wgCloser struct { - wg *sync.WaitGroup - dir string - set []io.Closer + wg *sync.WaitGroup + dir string + set []io.Closer + cancel context.CancelFunc } func (g *wgCloser) Wait() { @@ -152,3 +161,7 @@ func (g *wgCloser) Close() error { } return nil } + +func (g *wgCloser) Cancel() { + g.cancel() +} diff --git a/io_unix.go b/io_unix.go index ed8960585..ea05f251a 100644 --- a/io_unix.go +++ b/io_unix.go @@ -13,16 +13,17 @@ import ( func copyIO(fifos *FIFOSet, ioset *ioSet, tty bool) (_ *wgCloser, err error) { var ( - f io.ReadWriteCloser - set []io.Closer - ctx = context.Background() - wg = &sync.WaitGroup{} + f io.ReadWriteCloser + set []io.Closer + ctx, cancel = context.WithCancel(context.Background()) + wg = &sync.WaitGroup{} ) defer func() { if err != nil { for _, f := range set { f.Close() } + cancel() } }() @@ -55,13 +56,14 @@ func copyIO(fifos *FIFOSet, ioset *ioSet, tty bool) (_ *wgCloser, err error) { wg.Add(1) go func(r io.ReadCloser) { io.Copy(ioset.err, r) - wg.Done() r.Close() + wg.Done() }(f) } return &wgCloser{ - wg: wg, - dir: fifos.Dir, - set: set, + wg: wg, + dir: fifos.Dir, + set: set, + cancel: cancel, }, nil } diff --git a/process.go b/process.go index caa332017..a3b4b3565 100644 --- a/process.go +++ b/process.go @@ -45,6 +45,8 @@ func (p *process) Start(ctx context.Context) error { } response, err := p.task.client.TaskService().Exec(ctx, request) if err != nil { + p.io.Cancel() + p.io.Wait() p.io.Close() return err }