diff --git a/linux/shim/exec.go b/linux/shim/exec.go index 2f8d88c22..68c769504 100644 --- a/linux/shim/exec.go +++ b/linux/shim/exec.go @@ -86,20 +86,22 @@ func newExecProcess(context context.Context, path string, r *shimapi.ExecRequest e.closers = append(e.closers, sc) e.stdin = sc } + var copyWaitGroup sync.WaitGroup if socket != nil { console, err := socket.ReceiveMaster() if err != nil { return nil, err } e.console = console - if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &e.WaitGroup); err != nil { + if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &e.WaitGroup, ©WaitGroup); err != nil { return nil, err } } else { - if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &e.WaitGroup); err != nil { + if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &e.WaitGroup, ©WaitGroup); err != nil { return nil, err } } + copyWaitGroup.Wait() pid, err := runc.ReadPidFile(opts.PidFile) if err != nil { return nil, err diff --git a/linux/shim/init.go b/linux/shim/init.go index b476349ad..0429e6aac 100644 --- a/linux/shim/init.go +++ b/linux/shim/init.go @@ -93,20 +93,22 @@ func newInitProcess(context context.Context, path string, r *shimapi.CreateReque p.stdin = sc p.closers = append(p.closers, sc) } + var copyWaitGroup sync.WaitGroup if socket != nil { console, err := socket.ReceiveMaster() if err != nil { return nil, err } p.console = console - if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup); err != nil { + if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil { return nil, err } } else { - if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup); err != nil { + if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil { return nil, err } } + copyWaitGroup.Wait() pid, err := runc.ReadPidFile(opts.PidFile) if err != nil { return nil, err diff --git a/linux/shim/io.go b/linux/shim/io.go index 2194a99e2..2360491fe 100644 --- a/linux/shim/io.go +++ b/linux/shim/io.go @@ -14,13 +14,17 @@ import ( "github.com/tonistiigi/fifo" ) -func copyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) error { +func copyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg, cwg *sync.WaitGroup) error { if stdin != "" { in, err := fifo.OpenFifo(ctx, stdin, syscall.O_RDONLY, 0) if err != nil { return err } - go io.Copy(console, in) + cwg.Add(1) + go func() { + cwg.Done() + io.Copy(console, in) + }() } outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) if err != nil { @@ -31,7 +35,9 @@ func copyConsole(ctx context.Context, console console.Console, stdin, stdout, st return err } wg.Add(1) + cwg.Add(1) go func() { + cwg.Done() io.Copy(outw, console) console.Close() outr.Close() @@ -41,11 +47,13 @@ func copyConsole(ctx context.Context, console console.Console, stdin, stdout, st return nil } -func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg *sync.WaitGroup) error { +func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg, cwg *sync.WaitGroup) error { for name, dest := range map[string]func(wc io.WriteCloser, rc io.Closer){ stdout: func(wc io.WriteCloser, rc io.Closer) { wg.Add(1) + cwg.Add(1) go func() { + cwg.Done() io.Copy(wc, rio.Stdout()) wg.Done() wc.Close() @@ -54,7 +62,9 @@ func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, w }, stderr: func(wc io.WriteCloser, rc io.Closer) { wg.Add(1) + cwg.Add(1) go func() { + cwg.Done() io.Copy(wc, rio.Stderr()) wg.Done() wc.Close() @@ -79,7 +89,9 @@ func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, w if err != nil { return fmt.Errorf("containerd-shim: opening %s failed: %s", stdin, err) } + cwg.Add(1) go func() { + cwg.Done() io.Copy(rio.Stdin(), f) rio.Stdin().Close() f.Close()