From 126b35ca433b18b236db85ee7f87831e1dce56c9 Mon Sep 17 00:00:00 2001 From: Samuel Karp Date: Wed, 7 Oct 2020 22:28:19 -0700 Subject: [PATCH] containerd-shim: use path-based unix socket This allows filesystem-based ACLs for configuring access to the socket of a shim. Ported from Michael Crosby's similar patch for v2 shims. Signed-off-by: Samuel Karp --- cmd/containerd-shim/main_unix.go | 16 ++++-- runtime/v1/linux/bundle.go | 15 +++-- runtime/v1/shim/client/client.go | 94 ++++++++++++++++++++++++++++---- 3 files changed, 106 insertions(+), 19 deletions(-) diff --git a/cmd/containerd-shim/main_unix.go b/cmd/containerd-shim/main_unix.go index 49f16e6ca..43bf71d4d 100644 --- a/cmd/containerd-shim/main_unix.go +++ b/cmd/containerd-shim/main_unix.go @@ -71,7 +71,7 @@ var ( func init() { flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs") flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim") - flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve") + flag.StringVar(&socketFlag, "socket", "", "socket path to serve") flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data") flag.StringVar(&runtimeRootFlag, "runtime-root", process.RuncRoot, "root directory for the runtime") @@ -202,10 +202,18 @@ func serve(ctx context.Context, server *ttrpc.Server, path string) error { f.Close() path = "[inherited from parent]" } else { - if len(path) > 106 { - return errors.Errorf("%q: unix socket path too long (> 106)", path) + const ( + abstractSocketPrefix = "\x00" + socketPathLimit = 106 + ) + p := strings.TrimPrefix(path, "unix://") + if len(p) == len(path) { + p = abstractSocketPrefix + p } - l, err = net.Listen("unix", "\x00"+path) + if len(p) > socketPathLimit { + return errors.Errorf("%q: unix socket path too long (> %d)", p, socketPathLimit) + } + l, err = net.Listen("unix", p) } if err != nil { return err diff --git a/runtime/v1/linux/bundle.go b/runtime/v1/linux/bundle.go index e8b629b79..9d0a6c447 100644 --- a/runtime/v1/linux/bundle.go +++ b/runtime/v1/linux/bundle.go @@ -91,7 +91,7 @@ func ShimRemote(c *Config, daemonAddress, cgroup string, exitHandler func()) Shi return func(b *bundle, ns string, ropts *runctypes.RuncOptions) (shim.Config, client.Opt) { config := b.shimConfig(ns, c, ropts) return config, - client.WithStart(c.Shim, b.shimAddress(ns), daemonAddress, cgroup, c.ShimDebug, exitHandler) + client.WithStart(c.Shim, b.shimAddress(ns, daemonAddress), daemonAddress, cgroup, c.ShimDebug, exitHandler) } } @@ -117,6 +117,11 @@ func (b *bundle) NewShimClient(ctx context.Context, namespace string, getClientO // Delete deletes the bundle from disk func (b *bundle) Delete() error { + address, _ := b.loadAddress() + if address != "" { + // we don't care about errors here + client.RemoveSocket(address) + } err := atomicDelete(b.path) if err == nil { return atomicDelete(b.workDir) @@ -133,9 +138,11 @@ func (b *bundle) legacyShimAddress(namespace string) string { return filepath.Join(string(filepath.Separator), "containerd-shim", namespace, b.id, "shim.sock") } -func (b *bundle) shimAddress(namespace string) string { - d := sha256.Sum256([]byte(filepath.Join(namespace, b.id))) - return filepath.Join(string(filepath.Separator), "containerd-shim", fmt.Sprintf("%x.sock", d)) +const socketRoot = "/run/containerd" + +func (b *bundle) shimAddress(namespace, socketPath string) string { + d := sha256.Sum256([]byte(filepath.Join(socketPath, namespace, b.id))) + return fmt.Sprintf("unix://%s/%x", filepath.Join(socketRoot, "s"), d) } func (b *bundle) loadAddress() (string, error) { diff --git a/runtime/v1/shim/client/client.go b/runtime/v1/shim/client/client.go index 9653454af..e35dafec3 100644 --- a/runtime/v1/shim/client/client.go +++ b/runtime/v1/shim/client/client.go @@ -59,9 +59,17 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa return func(ctx context.Context, config shim.Config) (_ shimapi.ShimService, _ io.Closer, err error) { socket, err := newSocket(address) if err != nil { - return nil, nil, err + if !eaddrinuse(err) { + return nil, nil, err + } + if err := RemoveSocket(address); err != nil { + return nil, nil, errors.Wrap(err, "remove already used socket") + } + if socket, err = newSocket(address); err != nil { + return nil, nil, err + } } - defer socket.Close() + f, err := socket.File() if err != nil { return nil, nil, errors.Wrapf(err, "failed to get fd for socket %s", address) @@ -108,6 +116,8 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa if stderrLog != nil { stderrLog.Close() } + socket.Close() + RemoveSocket(address) }() log.G(ctx).WithFields(logrus.Fields{ "pid": cmd.Process.Pid, @@ -142,6 +152,26 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa } } +func eaddrinuse(err error) bool { + cause := errors.Cause(err) + netErr, ok := cause.(*net.OpError) + if !ok { + return false + } + if netErr.Op != "listen" { + return false + } + syscallErr, ok := netErr.Err.(*os.SyscallError) + if !ok { + return false + } + errno, ok := syscallErr.Err.(syscall.Errno) + if !ok { + return false + } + return errno == syscall.EADDRINUSE +} + // setupOOMScore gets containerd's oom score and adds +1 to it // to ensure a shim has a lower* score than the daemons func setupOOMScore(shimPid int) error { @@ -214,31 +244,73 @@ func writeFile(path, address string) error { return os.Rename(tempPath, path) } -func newSocket(address string) (*net.UnixListener, error) { - if len(address) > 106 { - return nil, errors.Errorf("%q: unix socket path too long (> 106)", address) +const ( + abstractSocketPrefix = "\x00" + socketPathLimit = 106 +) + +type socket string + +func (s socket) isAbstract() bool { + return !strings.HasPrefix(string(s), "unix://") +} + +func (s socket) path() string { + path := strings.TrimPrefix(string(s), "unix://") + // if there was no trim performed, we assume an abstract socket + if len(path) == len(s) { + path = abstractSocketPrefix + path } - l, err := net.Listen("unix", "\x00"+address) + return path +} + +func newSocket(address string) (*net.UnixListener, error) { + if len(address) > socketPathLimit { + return nil, errors.Errorf("%q: unix socket path too long (> %d)", address, socketPathLimit) + } + var ( + sock = socket(address) + path = sock.path() + ) + if !sock.isAbstract() { + if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil { + return nil, errors.Wrapf(err, "%s", path) + } + } + l, err := net.Listen("unix", path) if err != nil { - return nil, errors.Wrapf(err, "failed to listen to abstract unix socket %q", address) + return nil, errors.Wrapf(err, "failed to listen to unix socket %q (abstract: %t)", address, sock.isAbstract()) + } + if err := os.Chmod(path, 0600); err != nil { + l.Close() + return nil, err } return l.(*net.UnixListener), nil } +// RemoveSocket removes the socket at the specified address if +// it exists on the filesystem +func RemoveSocket(address string) error { + sock := socket(address) + if !sock.isAbstract() { + return os.Remove(sock.path()) + } + return nil +} + func connect(address string, d func(string, time.Duration) (net.Conn, error)) (net.Conn, error) { return d(address, 100*time.Second) } -func annonDialer(address string, timeout time.Duration) (net.Conn, error) { - address = strings.TrimPrefix(address, "unix://") - return dialer.Dialer("\x00"+address, timeout) +func anonDialer(address string, timeout time.Duration) (net.Conn, error) { + return dialer.Dialer(socket(address).path(), timeout) } // WithConnect connects to an existing shim func WithConnect(address string, onClose func()) Opt { return func(ctx context.Context, config shim.Config) (shimapi.ShimService, io.Closer, error) { - conn, err := connect(address, annonDialer) + conn, err := connect(address, anonDialer) if err != nil { return nil, nil, err }