Simplify mount cleanup on failure by using defer
This avoids someone adding a new error path and forgetting to call the cleanup function. We prefer to use an explicit flag to gate the clean rather than relying on `err != nil` so we don't have to rely on people never accidentally shadowing the `err` as seen by the closure. Signed-off-by: Ian Campbell <ian.campbell@docker.com>
This commit is contained in:
parent
300f083127
commit
d63d2ecf6c
@ -55,6 +55,8 @@ type initProcess struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newInitProcess(context context.Context, path, namespace string, r *shimapi.CreateTaskRequest) (*initProcess, error) {
|
func newInitProcess(context context.Context, path, namespace string, r *shimapi.CreateTaskRequest) (*initProcess, error) {
|
||||||
|
var success bool
|
||||||
|
|
||||||
if err := identifiers.Validate(r.ID); err != nil {
|
if err := identifiers.Validate(r.ID); err != nil {
|
||||||
return nil, errors.Wrapf(err, "invalid task id")
|
return nil, errors.Wrapf(err, "invalid task id")
|
||||||
}
|
}
|
||||||
@ -72,11 +74,14 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
|
|||||||
// what was actually done rather than what should have been
|
// what was actually done rather than what should have been
|
||||||
// done.
|
// done.
|
||||||
nrRootMounts := 0
|
nrRootMounts := 0
|
||||||
cleanupMounts := func() {
|
defer func() {
|
||||||
|
if success {
|
||||||
|
return
|
||||||
|
}
|
||||||
if err2 := mount.UnmountN(rootfs, 0, nrRootMounts); err2 != nil {
|
if err2 := mount.UnmountN(rootfs, 0, nrRootMounts); err2 != nil {
|
||||||
log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount")
|
log.G(context).WithError(err2).Warn("Failed to cleanup rootfs mount")
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
for _, rm := range r.Rootfs {
|
for _, rm := range r.Rootfs {
|
||||||
m := &mount.Mount{
|
m := &mount.Mount{
|
||||||
Type: rm.Type,
|
Type: rm.Type,
|
||||||
@ -84,7 +89,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
|
|||||||
Options: rm.Options,
|
Options: rm.Options,
|
||||||
}
|
}
|
||||||
if err := m.Mount(rootfs); err != nil {
|
if err := m.Mount(rootfs); err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, errors.Wrapf(err, "failed to mount rootfs component %v", m)
|
return nil, errors.Wrapf(err, "failed to mount rootfs component %v", m)
|
||||||
}
|
}
|
||||||
nrRootMounts++
|
nrRootMounts++
|
||||||
@ -116,14 +120,12 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
|
|||||||
)
|
)
|
||||||
if r.Terminal {
|
if r.Terminal {
|
||||||
if socket, err = runc.NewConsoleSocket(filepath.Join(path, "pty.sock")); err != nil {
|
if socket, err = runc.NewConsoleSocket(filepath.Join(path, "pty.sock")); err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, errors.Wrap(err, "failed to create OCI runtime console socket")
|
return nil, errors.Wrap(err, "failed to create OCI runtime console socket")
|
||||||
}
|
}
|
||||||
defer os.Remove(socket.Path())
|
defer os.Remove(socket.Path())
|
||||||
} else {
|
} else {
|
||||||
// TODO: get uid/gid
|
// TODO: get uid/gid
|
||||||
if io, err = runc.NewPipeIO(0, 0); err != nil {
|
if io, err = runc.NewPipeIO(0, 0); err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, errors.Wrap(err, "failed to create OCI runtime io pipes")
|
return nil, errors.Wrap(err, "failed to create OCI runtime io pipes")
|
||||||
}
|
}
|
||||||
p.io = io
|
p.io = io
|
||||||
@ -143,7 +145,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
|
|||||||
NoSubreaper: true,
|
NoSubreaper: true,
|
||||||
}
|
}
|
||||||
if _, err := p.runtime.Restore(context, r.ID, r.Bundle, opts); err != nil {
|
if _, err := p.runtime.Restore(context, r.ID, r.Bundle, opts); err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, p.runtimeError(err, "OCI runtime restore failed")
|
return nil, p.runtimeError(err, "OCI runtime restore failed")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -157,7 +158,6 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
|
|||||||
opts.ConsoleSocket = socket
|
opts.ConsoleSocket = socket
|
||||||
}
|
}
|
||||||
if err := p.runtime.Create(context, r.ID, r.Bundle, opts); err != nil {
|
if err := p.runtime.Create(context, r.ID, r.Bundle, opts); err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, p.runtimeError(err, "OCI runtime create failed")
|
return nil, p.runtimeError(err, "OCI runtime create failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -173,17 +173,14 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
|
|||||||
if socket != nil {
|
if socket != nil {
|
||||||
console, err := socket.ReceiveMaster()
|
console, err := socket.ReceiveMaster()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, errors.Wrap(err, "failed to retrieve console master")
|
return nil, errors.Wrap(err, "failed to retrieve console master")
|
||||||
}
|
}
|
||||||
p.console = console
|
p.console = console
|
||||||
if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil {
|
if err := copyConsole(context, console, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, errors.Wrap(err, "failed to start console copy")
|
return nil, errors.Wrap(err, "failed to start console copy")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil {
|
if err := copyPipes(context, io, r.Stdin, r.Stdout, r.Stderr, &p.WaitGroup, ©WaitGroup); err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, errors.Wrap(err, "failed to start io pipe copy")
|
return nil, errors.Wrap(err, "failed to start io pipe copy")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -191,10 +188,10 @@ func newInitProcess(context context.Context, path, namespace string, r *shimapi.
|
|||||||
copyWaitGroup.Wait()
|
copyWaitGroup.Wait()
|
||||||
pid, err := runc.ReadPidFile(pidFile)
|
pid, err := runc.ReadPidFile(pidFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanupMounts()
|
|
||||||
return nil, errors.Wrap(err, "failed to retrieve OCI runtime container pid")
|
return nil, errors.Wrap(err, "failed to retrieve OCI runtime container pid")
|
||||||
}
|
}
|
||||||
p.pid = pid
|
p.pid = pid
|
||||||
|
success = true
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user