Simplify mount cleanup on failure by using defer

This avoids someone adding a new error path and forgetting to call the cleanup
function.

We prefer to use an explicit flag to gate the clean rather than relying on `err
!= nil` so we don't have to rely on people never accidentally shadowing the
`err` as seen by the closure.

Signed-off-by: Ian Campbell <ian.campbell@docker.com>
This commit is contained in:
Ian Campbell 2017-07-12 16:11:38 +01:00
parent 300f083127
commit d63d2ecf6c

View File

@ -55,6 +55,8 @@ type initProcess struct {
} }
func newInitProcess(context context.Context, path, namespace string, r *shimapi.CreateTaskRequest) (*initProcess, error) { func newInitProcess(context context.Context, path, namespace string, r *shimapi.CreateTaskRequest) (*initProcess, error) {
var success bool
if err := identifiers.Validate(r.ID); err != nil { if err := identifiers.Validate(r.ID); err != nil {
return nil, errors.Wrapf(err, "invalid task id") return nil, errors.Wrapf(err, "invalid task id")
} }
@ -72,11 +74,14 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
// what was actually done rather than what should have been // what was actually done rather than what should have been
// done. // done.
nrRootMounts := 0 nrRootMounts := 0
cleanupMounts := func() { defer func() {
if success {
return
}
if err2 := mount.UnmountN(rootfs, 0, nrRootMounts); err2 != nil { if err2 := mount.UnmountN(rootfs, 0, nrRootMounts); err2 != nil {
log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount") log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount")
} }
} }()
for _, rm := range r.Rootfs { for _, rm := range r.Rootfs {
m := &mount.Mount{ m := &mount.Mount{
Type: rm.Type, Type: rm.Type,
@ -84,7 +89,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
Options: rm.Options, Options: rm.Options,
} }
if err := m.Mount(rootfs); err != nil { if err := m.Mount(rootfs); err != nil {
cleanupMounts()
return nil, errors.Wrapf(err, "failed to mount rootfs component %v", m) return nil, errors.Wrapf(err, "failed to mount rootfs component %v", m)
} }
nrRootMounts++ nrRootMounts++
@ -116,14 +120,12 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
) )
if r.Terminal { if r.Terminal {
if socket, err = runc.NewConsoleSocket(filepath.Join(path, "pty.sock")); err != nil { if socket, err = runc.NewConsoleSocket(filepath.Join(path, "pty.sock")); err != nil {
cleanupMounts()
return nil, errors.Wrap(err, "failed to create OCI runtime console socket") return nil, errors.Wrap(err, "failed to create OCI runtime console socket")
} }
defer os.Remove(socket.Path()) defer os.Remove(socket.Path())
} else { } else {
// TODO: get uid/gid // TODO: get uid/gid
if io, err = runc.NewPipeIO(0, 0); err != nil { if io, err = runc.NewPipeIO(0, 0); err != nil {
cleanupMounts()
return nil, errors.Wrap(err, "failed to create OCI runtime io pipes") return nil, errors.Wrap(err, "failed to create OCI runtime io pipes")
} }
p.io = io p.io = io
@ -143,7 +145,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
NoSubreaper: true, NoSubreaper: true,
} }
if _, err := p.runtime.Restore(context, r.ID, r.Bundle, opts); err != nil { if _, err := p.runtime.Restore(context, r.ID, r.Bundle, opts); err != nil {
cleanupMounts()
return nil, p.runtimeError(err, "OCI runtime restore failed") return nil, p.runtimeError(err, "OCI runtime restore failed")
} }
} else { } else {
@ -157,7 +158,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
opts.ConsoleSocket = socket opts.ConsoleSocket = socket
} }
if err := p.runtime.Create(context, r.ID, r.Bundle, opts); err != nil { if err := p.runtime.Create(context, r.ID, r.Bundle, opts); err != nil {
cleanupMounts()
return nil, p.runtimeError(err, "OCI runtime create failed") return nil, p.runtimeError(err, "OCI runtime create failed")
} }
} }
@ -173,17 +173,14 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
if socket != nil { if socket != nil {
console, err := socket.ReceiveMaster() console, err := socket.ReceiveMaster()
if err != nil { if err != nil {
cleanupMounts()
return nil, errors.Wrap(err, "failed to retrieve console master") return nil, errors.Wrap(err, "failed to retrieve console master")
} }
p.console = console p.console = console
if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, &copyWaitGroup); err != nil { if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, &copyWaitGroup); err != nil {
cleanupMounts()
return nil, errors.Wrap(err, "failed to start console copy") return nil, errors.Wrap(err, "failed to start console copy")
} }
} else { } else {
if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, &copyWaitGroup); err != nil { if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, &copyWaitGroup); err != nil {
cleanupMounts()
return nil, errors.Wrap(err, "failed to start io pipe copy") return nil, errors.Wrap(err, "failed to start io pipe copy")
} }
} }
@ -191,10 +188,10 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
copyWaitGroup.Wait() copyWaitGroup.Wait()
pid, err := runc.ReadPidFile(pidFile) pid, err := runc.ReadPidFile(pidFile)
if err != nil { if err != nil {
cleanupMounts()
return nil, errors.Wrap(err, "failed to retrieve OCI runtime container pid") return nil, errors.Wrap(err, "failed to retrieve OCI runtime container pid")
} }
p.pid = pid p.pid = pid
success = true
return p, nil return p, nil
} }