From bd908acabd1a31c8329570b5283e8fdca0b39906 Mon Sep 17 00:00:00 2001 From: Michael Crosby Date: Wed, 24 Jun 2020 15:13:21 -0400 Subject: [PATCH] Use path based unix socket for shims This allows filesystem based ACLs for configuring access to the socket of a shim. Co-authored-by: Samuel Karp Signed-off-by: Samuel Karp Signed-off-by: Michael Crosby Signed-off-by: Michael Crosby --- cmd/ctr/commands/shim/shim.go | 8 ++- runtime/v2/runc/v1/service.go | 18 ++++-- runtime/v2/runc/v2/service.go | 43 +++++++++++---- runtime/v2/shim/shim.go | 9 ++- runtime/v2/shim/shim_unix.go | 8 +-- runtime/v2/shim/util.go | 2 +- runtime/v2/shim/util_unix.go | 98 +++++++++++++++++++++++++++++---- runtime/v2/shim/util_windows.go | 6 ++ 8 files changed, 155 insertions(+), 37 deletions(-) diff --git a/cmd/ctr/commands/shim/shim.go b/cmd/ctr/commands/shim/shim.go index a5caeae2d..c210dbc6c 100644 --- a/cmd/ctr/commands/shim/shim.go +++ b/cmd/ctr/commands/shim/shim.go @@ -24,6 +24,7 @@ import ( "io/ioutil" "net" "path/filepath" + "strings" "github.com/containerd/console" "github.com/containerd/containerd/cmd/ctr/commands" @@ -240,10 +241,11 @@ func getTaskService(context *cli.Context) (task.TaskService, error) { s1 := filepath.Join(string(filepath.Separator), "containerd-shim", ns, id, "shim.sock") // this should not error, ctr always get a default ns ctx := namespaces.WithNamespace(gocontext.Background(), ns) - s2, _ := shim.SocketAddress(ctx, id) + s2, _ := shim.SocketAddress(ctx, context.GlobalString("address"), id) + s2 = strings.TrimPrefix(s2, "unix://") - for _, socket := range []string{s1, s2} { - conn, err := net.Dial("unix", "\x00"+socket) + for _, socket := range []string{s2, "\x00" + s1} { + conn, err := net.Dial("unix", socket) if err == nil { client := ttrpc.NewClient(conn) diff --git a/runtime/v2/runc/v1/service.go b/runtime/v2/runc/v1/service.go index e8ef09c8e..6d0140a8d 100644 --- a/runtime/v2/runc/v1/service.go +++ b/runtime/v2/runc/v1/service.go @@ -131,20 +131,26 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container if err != nil { return "", err } - address, err := shim.SocketAddress(ctx, id) + address, err := shim.SocketAddress(ctx, containerdAddress, id) if err != nil { return "", err } socket, err := shim.NewSocket(address) if err != nil { - return "", err + if !shim.SocketEaddrinuse(err) { + return "", err + } + if err := shim.RemoveSocket(address); err != nil { + return "", errors.Wrap(err, "remove already used socket") + } + if socket, err = shim.NewSocket(address); err != nil { + return "", err + } } - defer socket.Close() f, err := socket.File() if err != nil { return "", err } - defer f.Close() cmd.ExtraFiles = append(cmd.ExtraFiles, f) @@ -153,6 +159,7 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container } defer func() { if err != nil { + _ = shim.RemoveSocket(address) cmd.Process.Kill() } }() @@ -551,6 +558,9 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*ptypes.Empty, error) { s.cancel() close(s.events) + if address, err := shim.ReadAddress("address"); err == nil { + _ = shim.RemoveSocket(address) + } return empty, nil } diff --git a/runtime/v2/runc/v2/service.go b/runtime/v2/runc/v2/service.go index d3ea1e8ff..7f15ee89b 100644 --- a/runtime/v2/runc/v2/service.go +++ b/runtime/v2/runc/v2/service.go @@ -25,7 +25,6 @@ import ( "os" "os/exec" "path/filepath" - "strings" "sync" "syscall" "time" @@ -105,6 +104,10 @@ func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func return nil, errors.Wrap(err, "failed to initialized platform behavior") } go s.forward(ctx, publisher) + + if address, err := shim.ReadAddress("address"); err == nil { + s.shimAddress = address + } return s, nil } @@ -124,7 +127,8 @@ type service struct { containers map[string]*runc.Container - cancel func() + shimAddress string + cancel func() } func newCommand(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (*exec.Cmd, error) { @@ -183,30 +187,48 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container break } } - address, err := shim.SocketAddress(ctx, grouping) + address, err := shim.SocketAddress(ctx, containerdAddress, grouping) if err != nil { return "", err } + socket, err := shim.NewSocket(address) if err != nil { - if strings.Contains(err.Error(), "address already in use") { + // the only time where this would happen is if there is a bug and the socket + // was not cleaned up in the cleanup method of the shim or we are using the + // grouping functionality where the new process should be run with the same + // shim as an existing container + if !shim.SocketEaddrinuse(err) { + return "", errors.Wrap(err, "create new shim socket") + } + if shim.CanConnect(address) { if err := shim.WriteAddress("address", address); err != nil { - return "", err + return "", errors.Wrap(err, "write existing socket for shim") } return address, nil } - return "", err + if err := shim.RemoveSocket(address); err != nil { + return "", errors.Wrap(err, "remove pre-existing socket") + } + if socket, err = shim.NewSocket(address); err != nil { + return "", errors.Wrap(err, "try create new shim socket 2x") + } } - defer socket.Close() + defer func() { + if retErr != nil { + socket.Close() + _ = shim.RemoveSocket(address) + } + }() f, err := socket.File() if err != nil { return "", err } - defer f.Close() cmd.ExtraFiles = append(cmd.ExtraFiles, f) if err := cmd.Start(); err != nil { + f.Close() return "", err } defer func() { @@ -273,7 +295,6 @@ func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) if err != nil { return nil, err } - runtime, err := runc.ReadRuntime(path) if err != nil { return nil, err @@ -652,7 +673,9 @@ func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*pt if s.platform != nil { s.platform.Close() } - + if s.shimAddress != "" { + _ = shim.RemoveSocket(s.shimAddress) + } return empty, nil } diff --git a/runtime/v2/shim/shim.go b/runtime/v2/shim/shim.go index 026ebb4e2..2f62b57c9 100644 --- a/runtime/v2/shim/shim.go +++ b/runtime/v2/shim/shim.go @@ -104,7 +104,7 @@ func parseFlags() { flag.BoolVar(&versionFlag, "v", false, "show the shim version and exit") flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim") flag.StringVar(&idFlag, "id", "", "id of the task") - flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve") + flag.StringVar(&socketFlag, "socket", "", "socket path to serve") flag.StringVar(&bundlePath, "bundle", "", "path to the bundle if not workdir") flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") @@ -195,7 +195,6 @@ func run(id string, initFunc Init, config Config) error { ctx = context.WithValue(ctx, OptsKey{}, Opts{BundlePath: bundlePath, Debug: debugFlag}) ctx = log.WithLogger(ctx, log.G(ctx).WithField("runtime", id)) ctx, cancel := context.WithCancel(ctx) - service, err := initFunc(ctx, idFlag, publisher, cancel) if err != nil { return err @@ -300,11 +299,15 @@ func serve(ctx context.Context, server *ttrpc.Server, path string) error { return err } go func() { - defer l.Close() if err := server.Serve(ctx, l); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { logrus.WithError(err).Fatal("containerd-shim: ttrpc server failure") } + l.Close() + if address, err := ReadAddress("address"); err == nil { + _ = RemoveSocket(address) + } + }() return nil } diff --git a/runtime/v2/shim/shim_unix.go b/runtime/v2/shim/shim_unix.go index e6dc3e02f..a712dc7a5 100644 --- a/runtime/v2/shim/shim_unix.go +++ b/runtime/v2/shim/shim_unix.go @@ -58,15 +58,15 @@ func serveListener(path string) (net.Listener, error) { l, err = net.FileListener(os.NewFile(3, "socket")) path = "[inherited from parent]" } else { - if len(path) > 106 { - return nil, errors.Errorf("%q: unix socket path too long (> 106)", path) + if len(path) > socketPathLimit { + return nil, errors.Errorf("%q: unix socket path too long (> %d)", path, socketPathLimit) } - l, err = net.Listen("unix", "\x00"+path) + l, err = net.Listen("unix", path) } if err != nil { return nil, err } - logrus.WithField("socket", path).Debug("serving api on abstract socket") + logrus.WithField("socket", path).Debug("serving api on socket") return l, nil } diff --git a/runtime/v2/shim/util.go b/runtime/v2/shim/util.go index c8efd0dac..2bb786d90 100644 --- a/runtime/v2/shim/util.go +++ b/runtime/v2/shim/util.go @@ -169,7 +169,7 @@ func WriteAddress(path, address string) error { // ErrNoAddress is returned when the address file has no content var ErrNoAddress = errors.New("no shim address") -// ReadAddress returns the shim's abstract socket address from the path +// ReadAddress returns the shim's socket address from the path func ReadAddress(path string) (string, error) { path, err := filepath.Abs(path) if err != nil { diff --git a/runtime/v2/shim/util_unix.go b/runtime/v2/shim/util_unix.go index 093a66239..2b0d0ada3 100644 --- a/runtime/v2/shim/util_unix.go +++ b/runtime/v2/shim/util_unix.go @@ -35,7 +35,10 @@ import ( "github.com/pkg/errors" ) -const shimBinaryFormat = "containerd-shim-%s-%s" +const ( + shimBinaryFormat = "containerd-shim-%s-%s" + socketPathLimit = 106 +) func getSysProcAttr() *syscall.SysProcAttr { return &syscall.SysProcAttr{ @@ -63,20 +66,21 @@ func AdjustOOMScore(pid int) error { return nil } -// SocketAddress returns an abstract socket address -func SocketAddress(ctx context.Context, id string) (string, error) { +const socketRoot = "/run/containerd" + +// SocketAddress returns a socket address +func SocketAddress(ctx context.Context, socketPath, id string) (string, error) { ns, err := namespaces.NamespaceRequired(ctx) if err != nil { return "", err } - d := sha256.Sum256([]byte(filepath.Join(ns, id))) - return filepath.Join(string(filepath.Separator), "containerd-shim", fmt.Sprintf("%x.sock", d)), nil + d := sha256.Sum256([]byte(filepath.Join(socketPath, ns, id))) + return fmt.Sprintf("unix://%s/%x", filepath.Join(socketRoot, "s"), d), nil } -// AnonDialer returns a dialer for an abstract socket +// AnonDialer returns a dialer for a socket func AnonDialer(address string, timeout time.Duration) (net.Conn, error) { - address = strings.TrimPrefix(address, "unix://") - return dialer.Dialer("\x00"+address, timeout) + return dialer.Dialer(socket(address).path(), timeout) } func AnonReconnectDialer(address string, timeout time.Duration) (net.Conn, error) { @@ -85,12 +89,82 @@ func AnonReconnectDialer(address string, timeout time.Duration) (net.Conn, error // NewSocket returns a new socket func NewSocket(address string) (*net.UnixListener, error) { - if len(address) > 106 { - return nil, errors.Errorf("%q: unix socket path too long (> 106)", address) + var ( + sock = socket(address) + path = sock.path() + ) + if !sock.isAbstract() { + if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil { + return nil, errors.Wrapf(err, "%s", path) + } } - l, err := net.Listen("unix", "\x00"+address) + l, err := net.Listen("unix", path) if err != nil { - return nil, errors.Wrapf(err, "failed to listen to abstract unix socket %q", address) + return nil, err + } + if err := os.Chmod(path, 0600); err != nil { + os.Remove(sock.path()) + l.Close() + return nil, err } return l.(*net.UnixListener), nil } + +const abstractSocketPrefix = "\x00" + +type socket string + +func (s socket) isAbstract() bool { + return !strings.HasPrefix(string(s), "unix://") +} + +func (s socket) path() string { + path := strings.TrimPrefix(string(s), "unix://") + // if there was no trim performed, we assume an abstract socket + if len(path) == len(s) { + path = abstractSocketPrefix + path + } + return path +} + +// RemoveSocket removes the socket at the specified address if +// it exists on the filesystem +func RemoveSocket(address string) error { + sock := socket(address) + if !sock.isAbstract() { + return os.Remove(sock.path()) + } + return nil +} + +// SocketEaddrinuse returns true if the provided error is caused by the +// EADDRINUSE error number +func SocketEaddrinuse(err error) bool { + netErr, ok := err.(*net.OpError) + if !ok { + return false + } + if netErr.Op != "listen" { + return false + } + syscallErr, ok := netErr.Err.(*os.SyscallError) + if !ok { + return false + } + errno, ok := syscallErr.Err.(syscall.Errno) + if !ok { + return false + } + return errno == syscall.EADDRINUSE +} + +// CanConnect returns true if the socket provided at the address +// is accepting new connections +func CanConnect(address string) bool { + conn, err := AnonDialer(address, 100*time.Millisecond) + if err != nil { + return false + } + conn.Close() + return true +} diff --git a/runtime/v2/shim/util_windows.go b/runtime/v2/shim/util_windows.go index a94cdf250..325c29004 100644 --- a/runtime/v2/shim/util_windows.go +++ b/runtime/v2/shim/util_windows.go @@ -79,3 +79,9 @@ func AnonDialer(address string, timeout time.Duration) (net.Conn, error) { return c, nil } } + +// RemoveSocket removes the socket at the specified address if +// it exists on the filesystem +func RemoveSocket(address string) error { + return nil +}