diff --git a/linux/shim/init.go b/linux/shim/init.go index 8d6e1bfb1..f54daaaf7 100644 --- a/linux/shim/init.go +++ b/linux/shim/init.go @@ -55,6 +55,8 @@ type initProcess struct { } func newInitProcess(context context.Context, path, namespace string, r *shimapi.CreateTaskRequest) (*initProcess, error) { + var success bool + if err := identifiers.Validate(r.ID); err != nil { return nil, errors.Wrapf(err, "invalid task id") } @@ -72,11 +74,14 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. // what was actually done rather than what should have been // done. nrRootMounts := 0 - cleanupMounts := func() { + defer func() { + if success { + return + } if err2 := mount.UnmountN(rootfs, 0, nrRootMounts); err2 != nil { log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount") } - } + }() for _, rm := range r.Rootfs { m := &mount.Mount{ Type: rm.Type, @@ -84,7 +89,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. Options: rm.Options, } if err := m.Mount(rootfs); err != nil { - cleanupMounts() return nil, errors.Wrapf(err, "failed to mount rootfs component %v", m) } nrRootMounts++ @@ -116,14 +120,12 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. ) if r.Terminal { if socket, err = runc.NewConsoleSocket(filepath.Join(path, "pty.sock")); err != nil { - cleanupMounts() return nil, errors.Wrap(err, "failed to create OCI runtime console socket") } defer os.Remove(socket.Path()) } else { // TODO: get uid/gid if io, err = runc.NewPipeIO(0, 0); err != nil { - cleanupMounts() return nil, errors.Wrap(err, "failed to create OCI runtime io pipes") } p.io = io @@ -143,7 +145,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. NoSubreaper: true, } if _, err := p.runtime.Restore(context, r.ID, r.Bundle, opts); err != nil { - cleanupMounts() return nil, p.runtimeError(err, "OCI runtime restore failed") } } else { @@ -157,7 +158,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. opts.ConsoleSocket = socket } if err := p.runtime.Create(context, r.ID, r.Bundle, opts); err != nil { - cleanupMounts() return nil, p.runtimeError(err, "OCI runtime create failed") } } @@ -173,17 +173,14 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. if socket != nil { console, err := socket.ReceiveMaster() if err != nil { - cleanupMounts() return nil, errors.Wrap(err, "failed to retrieve console master") } p.console = console if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil { - cleanupMounts() return nil, errors.Wrap(err, "failed to start console copy") } } else { if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil { - cleanupMounts() return nil, errors.Wrap(err, "failed to start io pipe copy") } } @@ -191,10 +188,10 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi. copyWaitGroup.Wait() pid, err := runc.ReadPidFile(pidFile) if err != nil { - cleanupMounts() return nil, errors.Wrap(err, "failed to retrieve OCI runtime container pid") } p.pid = pid + success = true return p, nil }