Improve shim shutdown logic
Shims no longer call `os.Exit` but close the context on shutdown so that events and other resources have hit the `defer`s. Signed-off-by: Michael Crosby <crosbymichael@gmail.com>
This commit is contained in:
		| @@ -23,7 +23,6 @@ import ( | ||||
| 	"os" | ||||
|  | ||||
| 	"github.com/containerd/containerd/errdefs" | ||||
| 	"github.com/containerd/containerd/events" | ||||
| 	"github.com/containerd/containerd/runtime/v2/shim" | ||||
| 	taskAPI "github.com/containerd/containerd/runtime/v2/task" | ||||
| 	ptypes "github.com/gogo/protobuf/types" | ||||
| @@ -37,7 +36,7 @@ var ( | ||||
| ) | ||||
|  | ||||
| // New returns a new shim service | ||||
| func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, error) { | ||||
| func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func()) (shim.Shim, error) { | ||||
| 	return &service{}, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -24,15 +24,15 @@ import ( | ||||
|  | ||||
| 	"github.com/containerd/cgroups" | ||||
| 	eventstypes "github.com/containerd/containerd/api/events" | ||||
| 	"github.com/containerd/containerd/events" | ||||
| 	"github.com/containerd/containerd/runtime" | ||||
| 	"github.com/containerd/containerd/runtime/v2/shim" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"golang.org/x/sys/unix" | ||||
| ) | ||||
|  | ||||
| // NewOOMEpoller returns an epoll implementation that listens to OOM events | ||||
| // from a container's cgroups. | ||||
| func NewOOMEpoller(publisher events.Publisher) (*Epoller, error) { | ||||
| func NewOOMEpoller(publisher shim.Publisher) (*Epoller, error) { | ||||
| 	fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -49,7 +49,7 @@ type Epoller struct { | ||||
| 	mu sync.Mutex | ||||
|  | ||||
| 	fd        int | ||||
| 	publisher events.Publisher | ||||
| 	publisher shim.Publisher | ||||
| 	set       map[uintptr]*item | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -33,7 +33,6 @@ import ( | ||||
| 	eventstypes "github.com/containerd/containerd/api/events" | ||||
| 	"github.com/containerd/containerd/api/types/task" | ||||
| 	"github.com/containerd/containerd/errdefs" | ||||
| 	"github.com/containerd/containerd/events" | ||||
| 	"github.com/containerd/containerd/log" | ||||
| 	"github.com/containerd/containerd/mount" | ||||
| 	"github.com/containerd/containerd/namespaces" | ||||
| @@ -58,12 +57,11 @@ var ( | ||||
| ) | ||||
|  | ||||
| // New returns a new shim service that can be used via GRPC | ||||
| func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, error) { | ||||
| func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func()) (shim.Shim, error) { | ||||
| 	ep, err := runc.NewOOMEpoller(publisher) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	ctx, cancel := context.WithCancel(ctx) | ||||
| 	go ep.Run(ctx) | ||||
| 	s := &service{ | ||||
| 		id:      id, | ||||
| @@ -71,15 +69,15 @@ func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, | ||||
| 		events:  make(chan interface{}, 128), | ||||
| 		ec:      shim.Default.Subscribe(), | ||||
| 		ep:      ep, | ||||
| 		cancel:  cancel, | ||||
| 		cancel:  shutdown, | ||||
| 	} | ||||
| 	go s.processExits() | ||||
| 	runcC.Monitor = shim.Default | ||||
| 	if err := s.initPlatform(); err != nil { | ||||
| 		cancel() | ||||
| 		shutdown() | ||||
| 		return nil, errors.Wrap(err, "failed to initialized platform behavior") | ||||
| 	} | ||||
| 	go s.forward(publisher) | ||||
| 	go s.forward(ctx, publisher) | ||||
| 	return s, nil | ||||
| } | ||||
|  | ||||
| @@ -511,7 +509,7 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task | ||||
|  | ||||
| func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*ptypes.Empty, error) { | ||||
| 	s.cancel() | ||||
| 	os.Exit(0) | ||||
| 	close(s.events) | ||||
| 	return empty, nil | ||||
| } | ||||
|  | ||||
| @@ -619,15 +617,18 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er | ||||
| 	return pids, nil | ||||
| } | ||||
|  | ||||
| func (s *service) forward(publisher events.Publisher) { | ||||
| func (s *service) forward(ctx context.Context, publisher shim.Publisher) { | ||||
| 	ns, _ := namespaces.Namespace(ctx) | ||||
| 	ctx = namespaces.WithNamespace(context.Background(), ns) | ||||
| 	for e := range s.events { | ||||
| 		ctx, cancel := context.WithTimeout(s.context, 5*time.Second) | ||||
| 		ctx, cancel := context.WithTimeout(ctx, 5*time.Second) | ||||
| 		err := publisher.Publish(ctx, runc.GetTopic(e), e) | ||||
| 		cancel() | ||||
| 		if err != nil { | ||||
| 			logrus.WithError(err).Error("post event") | ||||
| 		} | ||||
| 	} | ||||
| 	publisher.Close() | ||||
| } | ||||
|  | ||||
| func (s *service) getContainer() (*runc.Container, error) { | ||||
|   | ||||
| @@ -34,7 +34,6 @@ import ( | ||||
| 	eventstypes "github.com/containerd/containerd/api/events" | ||||
| 	"github.com/containerd/containerd/api/types/task" | ||||
| 	"github.com/containerd/containerd/errdefs" | ||||
| 	"github.com/containerd/containerd/events" | ||||
| 	"github.com/containerd/containerd/log" | ||||
| 	"github.com/containerd/containerd/mount" | ||||
| 	"github.com/containerd/containerd/namespaces" | ||||
| @@ -71,12 +70,11 @@ type spec struct { | ||||
| } | ||||
|  | ||||
| // New returns a new shim service that can be used via GRPC | ||||
| func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, error) { | ||||
| func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func()) (shim.Shim, error) { | ||||
| 	ep, err := runc.NewOOMEpoller(publisher) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	ctx, cancel := context.WithCancel(ctx) | ||||
| 	go ep.Run(ctx) | ||||
| 	s := &service{ | ||||
| 		id:         id, | ||||
| @@ -84,16 +82,16 @@ func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, | ||||
| 		events:     make(chan interface{}, 128), | ||||
| 		ec:         shim.Default.Subscribe(), | ||||
| 		ep:         ep, | ||||
| 		cancel:     cancel, | ||||
| 		cancel:     shutdown, | ||||
| 		containers: make(map[string]*runc.Container), | ||||
| 	} | ||||
| 	go s.processExits() | ||||
| 	runcC.Monitor = shim.Default | ||||
| 	if err := s.initPlatform(); err != nil { | ||||
| 		cancel() | ||||
| 		shutdown() | ||||
| 		return nil, errors.Wrap(err, "failed to initialized platform behavior") | ||||
| 	} | ||||
| 	go s.forward(publisher) | ||||
| 	go s.forward(ctx, publisher) | ||||
| 	return s, nil | ||||
| } | ||||
|  | ||||
| @@ -570,7 +568,7 @@ func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*pt | ||||
| 		return empty, nil | ||||
| 	} | ||||
| 	s.cancel() | ||||
| 	os.Exit(0) | ||||
| 	close(s.events) | ||||
| 	return empty, nil | ||||
| } | ||||
|  | ||||
| @@ -689,15 +687,18 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er | ||||
| 	return pids, nil | ||||
| } | ||||
|  | ||||
| func (s *service) forward(publisher events.Publisher) { | ||||
| func (s *service) forward(ctx context.Context, publisher shim.Publisher) { | ||||
| 	ns, _ := namespaces.Namespace(ctx) | ||||
| 	ctx = namespaces.WithNamespace(context.Background(), ns) | ||||
| 	for e := range s.events { | ||||
| 		ctx, cancel := context.WithTimeout(s.context, 5*time.Second) | ||||
| 		ctx, cancel := context.WithTimeout(ctx, 5*time.Second) | ||||
| 		err := publisher.Publish(ctx, runc.GetTopic(e), e) | ||||
| 		cancel() | ||||
| 		if err != nil { | ||||
| 			logrus.WithError(err).Error("post event") | ||||
| 		} | ||||
| 	} | ||||
| 	publisher.Close() | ||||
| } | ||||
|  | ||||
| func (s *service) getContainer(id string) (*runc.Container, error) { | ||||
|   | ||||
| @@ -41,7 +41,6 @@ import ( | ||||
| 	containerd_types "github.com/containerd/containerd/api/types" | ||||
| 	"github.com/containerd/containerd/api/types/task" | ||||
| 	"github.com/containerd/containerd/errdefs" | ||||
| 	"github.com/containerd/containerd/events" | ||||
| 	"github.com/containerd/containerd/log" | ||||
| 	"github.com/containerd/containerd/mount" | ||||
| 	"github.com/containerd/containerd/namespaces" | ||||
| @@ -129,12 +128,13 @@ func forwardRunhcsLogs(ctx context.Context, c net.Conn, fields logrus.Fields) { | ||||
| } | ||||
|  | ||||
| // New returns a new runhcs shim service that can be used via GRPC | ||||
| func New(ctx context.Context, id string, publisher events.Publisher) (shim.Shim, error) { | ||||
| func New(ctx context.Context, id string, publisher shim.Publisher, shutdown func()) (shim.Shim, error) { | ||||
| 	return &service{ | ||||
| 		context:   ctx, | ||||
| 		id:        id, | ||||
| 		processes: make(map[string]*process), | ||||
| 		publisher: publisher, | ||||
| 		shutdown:  shutdown, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -159,7 +159,8 @@ type service struct { | ||||
| 	id        string | ||||
| 	processes map[string]*process | ||||
|  | ||||
| 	publisher events.Publisher | ||||
| 	publisher shim.Publisher | ||||
| 	shutdown  func() | ||||
| } | ||||
|  | ||||
| func (s *service) newRunhcs() *runhcs.Runhcs { | ||||
| @@ -1068,7 +1069,8 @@ func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*pt | ||||
| 	if s.debugListener != nil { | ||||
| 		s.debugListener.Close() | ||||
| 	} | ||||
| 	s.publisher.Close() | ||||
| 	s.shutdown() | ||||
|  | ||||
| 	os.Exit(0) | ||||
| 	return empty, nil | ||||
| } | ||||
|   | ||||
| @@ -20,11 +20,13 @@ import ( | ||||
| 	"context" | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"runtime" | ||||
| 	"runtime/debug" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	v1 "github.com/containerd/containerd/api/services/ttrpc/events/v1" | ||||
| @@ -46,8 +48,14 @@ type Client struct { | ||||
| 	signals chan os.Signal | ||||
| } | ||||
|  | ||||
| // Publisher for events | ||||
| type Publisher interface { | ||||
| 	events.Publisher | ||||
| 	io.Closer | ||||
| } | ||||
|  | ||||
| // Init func for the creation of a shim server | ||||
| type Init func(context.Context, string, events.Publisher) (Shim, error) | ||||
| type Init func(context.Context, string, Publisher, func()) (Shim, error) | ||||
|  | ||||
| // Shim server interface | ||||
| type Shim interface { | ||||
| @@ -156,15 +164,18 @@ func run(id string, initFunc Init, config Config) error { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	publisher := &remoteEventsPublisher{ | ||||
| 		address: fmt.Sprintf("%s.ttrpc", addressFlag), | ||||
| 	} | ||||
| 	conn, err := connect(publisher.address, dialer) | ||||
| 	address := fmt.Sprintf("%s.ttrpc", addressFlag) | ||||
| 	conn, err := connect(address, dialer) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 	publisher := &remoteEventsPublisher{ | ||||
| 		address: address, | ||||
| 		conn:    conn, | ||||
| 		closed:  make(chan struct{}), | ||||
| 	} | ||||
| 	defer publisher.Close() | ||||
|  | ||||
| 	publisher.client = v1.NewEventsClient(ttrpc.NewClient(conn)) | ||||
| 	if namespaceFlag == "" { | ||||
| 		return fmt.Errorf("shim namespace cannot be empty") | ||||
| @@ -172,8 +183,9 @@ func run(id string, initFunc Init, config Config) error { | ||||
| 	ctx := namespaces.WithNamespace(context.Background(), namespaceFlag) | ||||
| 	ctx = context.WithValue(ctx, OptsKey{}, Opts{BundlePath: bundlePath, Debug: debugFlag}) | ||||
| 	ctx = log.WithLogger(ctx, log.G(ctx).WithField("runtime", id)) | ||||
| 	ctx, cancel := context.WithCancel(ctx) | ||||
|  | ||||
| 	service, err := initFunc(ctx, idFlag, publisher) | ||||
| 	service, err := initFunc(ctx, idFlag, publisher, cancel) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -183,7 +195,7 @@ func run(id string, initFunc Init, config Config) error { | ||||
| 			"pid":       os.Getpid(), | ||||
| 			"namespace": namespaceFlag, | ||||
| 		}) | ||||
| 		go handleSignals(logger, signals) | ||||
| 		go handleSignals(ctx, logger, signals) | ||||
| 		response, err := service.Cleanup(ctx) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| @@ -210,7 +222,17 @@ func run(id string, initFunc Init, config Config) error { | ||||
| 			return err | ||||
| 		} | ||||
| 		client := NewShimClient(ctx, service, signals) | ||||
| 		return client.Serve() | ||||
| 		if err := client.Serve(); err != nil { | ||||
| 			if err != context.Canceled { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		select { | ||||
| 		case <-publisher.Done(): | ||||
| 			return nil | ||||
| 		case <-time.After(5 * time.Second): | ||||
| 			return errors.New("publisher not closed") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -254,7 +276,7 @@ func (s *Client) Serve() error { | ||||
| 			dumpStacks(logger) | ||||
| 		} | ||||
| 	}() | ||||
| 	return handleSignals(logger, s.signals) | ||||
| 	return handleSignals(s.context, logger, s.signals) | ||||
| } | ||||
|  | ||||
| // serve serves the ttrpc API over a unix socket at the provided path | ||||
| @@ -291,7 +313,22 @@ func dumpStacks(logger *logrus.Entry) { | ||||
|  | ||||
| type remoteEventsPublisher struct { | ||||
| 	address string | ||||
| 	conn    net.Conn | ||||
| 	client  v1.EventsService | ||||
| 	closed  chan struct{} | ||||
| 	closer  sync.Once | ||||
| } | ||||
|  | ||||
| func (l *remoteEventsPublisher) Done() <-chan struct{} { | ||||
| 	return l.closed | ||||
| } | ||||
|  | ||||
| func (l *remoteEventsPublisher) Close() (err error) { | ||||
| 	l.closer.Do(func() { | ||||
| 		err = l.conn.Close() | ||||
| 		close(l.closed) | ||||
| 	}) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error { | ||||
|   | ||||
| @@ -71,11 +71,14 @@ func serveListener(path string) (net.Listener, error) { | ||||
| 	return l, nil | ||||
| } | ||||
|  | ||||
| func handleSignals(logger *logrus.Entry, signals chan os.Signal) error { | ||||
| func handleSignals(ctx context.Context, logger *logrus.Entry, signals chan os.Signal) error { | ||||
| 	logger.Info("starting signal loop") | ||||
|  | ||||
| 	for { | ||||
| 		for s := range signals { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return ctx.Err() | ||||
| 		case s := <-signals: | ||||
| 			switch s { | ||||
| 			case unix.SIGCHLD: | ||||
| 				if err := Reap(); err != nil { | ||||
|   | ||||
| @@ -104,11 +104,14 @@ func serveListener(path string) (net.Listener, error) { | ||||
| 	return l, nil | ||||
| } | ||||
|  | ||||
| func handleSignals(logger *logrus.Entry, signals chan os.Signal) error { | ||||
| func handleSignals(ctx context.Context, logger *logrus.Entry, signals chan os.Signal) error { | ||||
| 	logger.Info("starting signal loop") | ||||
|  | ||||
| 	for { | ||||
| 		for s := range signals { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return ctx.Err() | ||||
| 		case s := <-signals: | ||||
| 			switch s { | ||||
| 			case os.Interrupt: | ||||
| 				return nil | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Michael Crosby
					Michael Crosby