Improve shim shutdown logic

Shims no longer call `os.Exit` but close the context on shutdown so that
events and other resources have hit the `defer`s.

Signed-off-by: Michael Crosby <crosbymichael@gmail.com>
This commit is contained in:
Michael Crosby 2019-04-10 14:29:10 -04:00
parent 4ba756edda
commit ae87730ad2
8 changed files with 88 additions and 42 deletions

View File

@ -23,7 +23,6 @@ import (
"os" "os"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/events"
"github.com/containerd/containerd/runtime/v2/shim" "github.com/containerd/containerd/runtime/v2/shim"
taskAPI "github.com/containerd/containerd/runtime/v2/task" taskAPI "github.com/containerd/containerd/runtime/v2/task"
ptypes "github.com/gogo/protobuf/types" ptypes "github.com/gogo/protobuf/types"
@ -37,7 +36,7 @@ var (
) )
// New returns a new shim service // New returns a new shim service
func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, error) { func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func()) (shim.Shim, error) {
return &service{}, nil return &service{}, nil
} }

View File

@ -24,15 +24,15 @@ import (
"github.com/containerd/cgroups" "github.com/containerd/cgroups"
eventstypes "github.com/containerd/containerd/api/events" eventstypes "github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/events"
"github.com/containerd/containerd/runtime" "github.com/containerd/containerd/runtime"
"github.com/containerd/containerd/runtime/v2/shim"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// NewOOMEpoller returns an epoll implementation that listens to OOM events // NewOOMEpoller returns an epoll implementation that listens to OOM events
// from a container's cgroups. // from a container's cgroups.
func NewOOMEpoller(publisher events.Publisher) (*Epoller, error) { func NewOOMEpoller(publisher shim.Publisher) (*Epoller, error) {
fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
if err != nil { if err != nil {
return nil, err return nil, err
@ -49,7 +49,7 @@ type Epoller struct {
mu sync.Mutex mu sync.Mutex
fd int fd int
publisher events.Publisher publisher shim.Publisher
set map[uintptr]*item set map[uintptr]*item
} }

View File

@ -33,7 +33,6 @@ import (
eventstypes "github.com/containerd/containerd/api/events" eventstypes "github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/types/task" "github.com/containerd/containerd/api/types/task"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/events"
"github.com/containerd/containerd/log" "github.com/containerd/containerd/log"
"github.com/containerd/containerd/mount" "github.com/containerd/containerd/mount"
"github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/namespaces"
@ -58,12 +57,11 @@ var (
) )
// New returns a new shim service that can be used via GRPC // New returns a new shim service that can be used via GRPC
func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, error) { func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func()) (shim.Shim, error) {
ep, err := runc.NewOOMEpoller(publisher) ep, err := runc.NewOOMEpoller(publisher)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx, cancel := context.WithCancel(ctx)
go ep.Run(ctx) go ep.Run(ctx)
s := &service{ s := &service{
id: id, id: id,
@ -71,15 +69,15 @@ func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim,
events: make(chan interface{}, 128), events: make(chan interface{}, 128),
ec: shim.Default.Subscribe(), ec: shim.Default.Subscribe(),
ep: ep, ep: ep,
cancel: cancel, cancel: shutdown,
} }
go s.processExits() go s.processExits()
runcC.Monitor = shim.Default runcC.Monitor = shim.Default
if err := s.initPlatform(); err != nil { if err := s.initPlatform(); err != nil {
cancel() shutdown()
return nil, errors.Wrap(err, "failed to initialized platform behavior") return nil, errors.Wrap(err, "failed to initialized platform behavior")
} }
go s.forward(publisher) go s.forward(ctx, publisher)
return s, nil return s, nil
} }
@ -511,7 +509,7 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task
func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*ptypes.Empty, error) { func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*ptypes.Empty, error) {
s.cancel() s.cancel()
os.Exit(0) close(s.events)
return empty, nil return empty, nil
} }
@ -619,15 +617,18 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er
return pids, nil return pids, nil
} }
func (s *service) forward(publisher events.Publisher) { func (s *service) forward(ctx context.Context, publisher shim.Publisher) {
ns, _ := namespaces.Namespace(ctx)
ctx = namespaces.WithNamespace(context.Background(), ns)
for e := range s.events { for e := range s.events {
ctx, cancel := context.WithTimeout(s.context, 5*time.Second) ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
err := publisher.Publish(ctx, runc.GetTopic(e), e) err := publisher.Publish(ctx, runc.GetTopic(e), e)
cancel() cancel()
if err != nil { if err != nil {
logrus.WithError(err).Error("post event") logrus.WithError(err).Error("post event")
} }
} }
publisher.Close()
} }
func (s *service) getContainer() (*runc.Container, error) { func (s *service) getContainer() (*runc.Container, error) {

View File

@ -34,7 +34,6 @@ import (
eventstypes "github.com/containerd/containerd/api/events" eventstypes "github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/types/task" "github.com/containerd/containerd/api/types/task"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/events"
"github.com/containerd/containerd/log" "github.com/containerd/containerd/log"
"github.com/containerd/containerd/mount" "github.com/containerd/containerd/mount"
"github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/namespaces"
@ -71,12 +70,11 @@ type spec struct {
} }
// New returns a new shim service that can be used via GRPC // New returns a new shim service that can be used via GRPC
func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, error) { func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func()) (shim.Shim, error) {
ep, err := runc.NewOOMEpoller(publisher) ep, err := runc.NewOOMEpoller(publisher)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx, cancel := context.WithCancel(ctx)
go ep.Run(ctx) go ep.Run(ctx)
s := &service{ s := &service{
id: id, id: id,
@ -84,16 +82,16 @@ func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim,
events: make(chan interface{}, 128), events: make(chan interface{}, 128),
ec: shim.Default.Subscribe(), ec: shim.Default.Subscribe(),
ep: ep, ep: ep,
cancel: cancel, cancel: shutdown,
containers: make(map[string]*runc.Container), containers: make(map[string]*runc.Container),
} }
go s.processExits() go s.processExits()
runcC.Monitor = shim.Default runcC.Monitor = shim.Default
if err := s.initPlatform(); err != nil { if err := s.initPlatform(); err != nil {
cancel() shutdown()
return nil, errors.Wrap(err, "failed to initialized platform behavior") return nil, errors.Wrap(err, "failed to initialized platform behavior")
} }
go s.forward(publisher) go s.forward(ctx, publisher)
return s, nil return s, nil
} }
@ -570,7 +568,7 @@ func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*pt
return empty, nil return empty, nil
} }
s.cancel() s.cancel()
os.Exit(0) close(s.events)
return empty, nil return empty, nil
} }
@ -689,15 +687,18 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er
return pids, nil return pids, nil
} }
func (s *service) forward(publisher events.Publisher) { func (s *service) forward(ctx context.Context, publisher shim.Publisher) {
ns, _ := namespaces.Namespace(ctx)
ctx = namespaces.WithNamespace(context.Background(), ns)
for e := range s.events { for e := range s.events {
ctx, cancel := context.WithTimeout(s.context, 5*time.Second) ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
err := publisher.Publish(ctx, runc.GetTopic(e), e) err := publisher.Publish(ctx, runc.GetTopic(e), e)
cancel() cancel()
if err != nil { if err != nil {
logrus.WithError(err).Error("post event") logrus.WithError(err).Error("post event")
} }
} }
publisher.Close()
} }
func (s *service) getContainer(id string) (*runc.Container, error) { func (s *service) getContainer(id string) (*runc.Container, error) {

View File

@ -41,7 +41,6 @@ import (
containerd_types "github.com/containerd/containerd/api/types" containerd_types "github.com/containerd/containerd/api/types"
"github.com/containerd/containerd/api/types/task" "github.com/containerd/containerd/api/types/task"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/events"
"github.com/containerd/containerd/log" "github.com/containerd/containerd/log"
"github.com/containerd/containerd/mount" "github.com/containerd/containerd/mount"
"github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/namespaces"
@ -129,12 +128,13 @@ func forwardRunhcsLogs(ctx context.Context, c net.Conn, fields logrus.Fields) {
} }
// New returns a new runhcs shim service that can be used via GRPC // New returns a new runhcs shim service that can be used via GRPC
func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, error) { func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func()) (shim.Shim, error) {
return &service{ return &service{
context: ctx, context: ctx,
id: id, id: id,
processes: make(map[string]*process), processes: make(map[string]*process),
publisher: publisher, publisher: publisher,
shutdown: shutdown,
}, nil }, nil
} }
@ -159,7 +159,8 @@ type service struct {
id string id string
processes map[string]*process processes map[string]*process
publisher events.Publisher publisher shim.Publisher
shutdown func()
} }
func (s *service) newRunhcs() *runhcs.Runhcs { func (s *service) newRunhcs() *runhcs.Runhcs {
@ -1068,7 +1069,8 @@ func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*pt
if s.debugListener != nil { if s.debugListener != nil {
s.debugListener.Close() s.debugListener.Close()
} }
s.publisher.Close()
s.shutdown()
os.Exit(0)
return empty, nil return empty, nil
} }

View File

@ -20,11 +20,13 @@ import (
"context" "context"
"flag" "flag"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync"
"time" "time"
v1 "github.com/containerd/containerd/api/services/ttrpc/events/v1" v1 "github.com/containerd/containerd/api/services/ttrpc/events/v1"
@ -46,8 +48,14 @@ type Client struct {
signals chan os.Signal signals chan os.Signal
} }
// Publisher for events
type Publisher interface {
events.Publisher
io.Closer
}
// Init func for the creation of a shim server // Init func for the creation of a shim server
type Init func(context.Context, string, events.Publisher) (Shim, error) type Init func(context.Context, string, Publisher, func()) (Shim, error)
// Shim server interface // Shim server interface
type Shim interface { type Shim interface {
@ -156,15 +164,18 @@ func run(id string, initFunc Init, config Config) error {
return err return err
} }
} }
address := fmt.Sprintf("%s.ttrpc", addressFlag)
publisher := &remoteEventsPublisher{ conn, err := connect(address, dialer)
address: fmt.Sprintf("%s.ttrpc", addressFlag),
}
conn, err := connect(publisher.address, dialer)
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() publisher := &remoteEventsPublisher{
address: address,
conn: conn,
closed: make(chan struct{}),
}
defer publisher.Close()
publisher.client = v1.NewEventsClient(ttrpc.NewClient(conn)) publisher.client = v1.NewEventsClient(ttrpc.NewClient(conn))
if namespaceFlag == "" { if namespaceFlag == "" {
return fmt.Errorf("shim namespace cannot be empty") return fmt.Errorf("shim namespace cannot be empty")
@ -172,8 +183,9 @@ func run(id string, initFunc Init, config Config) error {
ctx := namespaces.WithNamespace(context.Background(), namespaceFlag) ctx := namespaces.WithNamespace(context.Background(), namespaceFlag)
ctx = context.WithValue(ctx, OptsKey{}, Opts{BundlePath: bundlePath, Debug: debugFlag}) ctx = context.WithValue(ctx, OptsKey{}, Opts{BundlePath: bundlePath, Debug: debugFlag})
ctx = log.WithLogger(ctx, log.G(ctx).WithField("runtime", id)) ctx = log.WithLogger(ctx, log.G(ctx).WithField("runtime", id))
ctx, cancel := context.WithCancel(ctx)
service, err := initFunc(ctx, idFlag, publisher) service, err := initFunc(ctx, idFlag, publisher, cancel)
if err != nil { if err != nil {
return err return err
} }
@ -183,7 +195,7 @@ func run(id string, initFunc Init, config Config) error {
"pid": os.Getpid(), "pid": os.Getpid(),
"namespace": namespaceFlag, "namespace": namespaceFlag,
}) })
go handleSignals(logger, signals) go handleSignals(ctx, logger, signals)
response, err := service.Cleanup(ctx) response, err := service.Cleanup(ctx)
if err != nil { if err != nil {
return err return err
@ -210,7 +222,17 @@ func run(id string, initFunc Init, config Config) error {
return err return err
} }
client := NewShimClient(ctx, service, signals) client := NewShimClient(ctx, service, signals)
return client.Serve() if err := client.Serve(); err != nil {
if err != context.Canceled {
return err
}
}
select {
case <-publisher.Done():
return nil
case <-time.After(5 * time.Second):
return errors.New("publisher not closed")
}
} }
} }
@ -254,7 +276,7 @@ func (s *Client) Serve() error {
dumpStacks(logger) dumpStacks(logger)
} }
}() }()
return handleSignals(logger, s.signals) return handleSignals(s.context, logger, s.signals)
} }
// serve serves the ttrpc API over a unix socket at the provided path // serve serves the ttrpc API over a unix socket at the provided path
@ -291,7 +313,22 @@ func dumpStacks(logger *logrus.Entry) {
type remoteEventsPublisher struct { type remoteEventsPublisher struct {
address string address string
conn net.Conn
client v1.EventsService client v1.EventsService
closed chan struct{}
closer sync.Once
}
func (l *remoteEventsPublisher) Done() <-chan struct{} {
return l.closed
}
func (l *remoteEventsPublisher) Close() (err error) {
l.closer.Do(func() {
err = l.conn.Close()
close(l.closed)
})
return err
} }
func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error { func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error {

View File

@ -71,11 +71,14 @@ func serveListener(path string) (net.Listener, error) {
return l, nil return l, nil
} }
func handleSignals(logger *logrus.Entry, signals chan os.Signal) error { func handleSignals(ctx context.Context, logger *logrus.Entry, signals chan os.Signal) error {
logger.Info("starting signal loop") logger.Info("starting signal loop")
for { for {
for s := range signals { select {
case <-ctx.Done():
return ctx.Err()
case s := <-signals:
switch s { switch s {
case unix.SIGCHLD: case unix.SIGCHLD:
if err := Reap(); err != nil { if err := Reap(); err != nil {

View File

@ -104,11 +104,14 @@ func serveListener(path string) (net.Listener, error) {
return l, nil return l, nil
} }
func handleSignals(logger *logrus.Entry, signals chan os.Signal) error { func handleSignals(ctx context.Context, logger *logrus.Entry, signals chan os.Signal) error {
logger.Info("starting signal loop") logger.Info("starting signal loop")
for { for {
for s := range signals { select {
case <-ctx.Done():
return ctx.Err()
case s := <-signals:
switch s { switch s {
case os.Interrupt: case os.Interrupt:
return nil return nil