Move shim package to pkg

Signed-off-by: Maksym Pavlenko <pavlenko.maksym@gmail.com>
This commit is contained in:
Maksym Pavlenko
2024-03-04 17:23:57 -08:00
parent e53663cca7
commit 6a96e45012
29 changed files with 17 additions and 17 deletions

170
pkg/shim/publisher.go Normal file
View File

@@ -0,0 +1,170 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"context"
"sync"
"time"
v1 "github.com/containerd/containerd/v2/api/services/ttrpc/events/v1"
"github.com/containerd/containerd/v2/api/types"
"github.com/containerd/containerd/v2/core/events"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/containerd/containerd/v2/pkg/ttrpcutil"
"github.com/containerd/containerd/v2/protobuf"
"github.com/containerd/log"
"github.com/containerd/ttrpc"
)
const (
queueSize = 2048
maxRequeue = 5
)
type item struct {
ev *types.Envelope
ctx context.Context
count int
}
// NewPublisher creates a new remote events publisher
func NewPublisher(address string) (*RemoteEventsPublisher, error) {
client, err := ttrpcutil.NewClient(address)
if err != nil {
return nil, err
}
l := &RemoteEventsPublisher{
client: client,
closed: make(chan struct{}),
requeue: make(chan *item, queueSize),
}
go l.processQueue()
return l, nil
}
// RemoteEventsPublisher forwards events to a ttrpc server
type RemoteEventsPublisher struct {
client *ttrpcutil.Client
closed chan struct{}
closer sync.Once
requeue chan *item
}
// Done returns a channel which closes when done
func (l *RemoteEventsPublisher) Done() <-chan struct{} {
return l.closed
}
// Close closes the remote connection and closes the done channel
func (l *RemoteEventsPublisher) Close() (err error) {
err = l.client.Close()
l.closer.Do(func() {
close(l.closed)
})
return err
}
func (l *RemoteEventsPublisher) processQueue() {
for i := range l.requeue {
if i.count > maxRequeue {
log.L.Errorf("evicting %s from queue because of retry count", i.ev.Topic)
// drop the event
continue
}
if err := l.forwardRequest(i.ctx, &v1.ForwardRequest{Envelope: i.ev}); err != nil {
log.L.WithError(err).Error("forward event")
l.queue(i)
}
}
}
func (l *RemoteEventsPublisher) queue(i *item) {
go func() {
i.count++
// re-queue after a short delay
time.Sleep(time.Duration(1*i.count) * time.Second)
l.requeue <- i
}()
}
// Publish publishes the event by forwarding it to the configured ttrpc server
func (l *RemoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error {
ns, err := namespaces.NamespaceRequired(ctx)
if err != nil {
return err
}
evt, err := protobuf.MarshalAnyToProto(event)
if err != nil {
return err
}
i := &item{
ev: &types.Envelope{
Timestamp: protobuf.ToTimestamp(time.Now()),
Namespace: ns,
Topic: topic,
Event: evt,
},
ctx: ctx,
}
if err := l.forwardRequest(i.ctx, &v1.ForwardRequest{Envelope: i.ev}); err != nil {
l.queue(i)
return err
}
return nil
}
func (l *RemoteEventsPublisher) forwardRequest(ctx context.Context, req *v1.ForwardRequest) error {
service, err := l.client.EventsService()
if err == nil {
fCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
_, err = service.Forward(fCtx, req)
cancel()
if err == nil {
return nil
}
}
if err != ttrpc.ErrClosed {
return err
}
// Reconnect and retry request
if err = l.client.Reconnect(); err != nil {
return err
}
service, err = l.client.EventsService()
if err != nil {
return err
}
// try again with a fresh context, otherwise we may get a context timeout unexpectedly.
fCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
_, err = service.Forward(fCtx, req)
cancel()
if err != nil {
return err
}
return nil
}

481
pkg/shim/shim.go Normal file
View File

@@ -0,0 +1,481 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"net"
"os"
"path/filepath"
"runtime"
"runtime/debug"
"time"
shimapi "github.com/containerd/containerd/v2/api/runtime/task/v3"
"github.com/containerd/containerd/v2/api/types"
"github.com/containerd/containerd/v2/core/events"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/containerd/containerd/v2/pkg/shutdown"
"github.com/containerd/containerd/v2/plugins"
"github.com/containerd/containerd/v2/protobuf"
"github.com/containerd/containerd/v2/protobuf/proto"
"github.com/containerd/containerd/v2/version"
"github.com/containerd/log"
"github.com/containerd/plugin"
"github.com/containerd/plugin/registry"
"github.com/containerd/ttrpc"
"github.com/sirupsen/logrus"
)
// Publisher for events
type Publisher interface {
events.Publisher
io.Closer
}
// StartOpts describes shim start configuration received from containerd
type StartOpts struct {
Address string
TTRPCAddress string
Debug bool
}
// BootstrapParams is a JSON payload returned in stdout from shim.Start call.
type BootstrapParams struct {
// Version is the version of shim parameters (expected 2 for shim v2)
Version int `json:"version"`
// Address is a address containerd should use to connect to shim.
Address string `json:"address"`
// Protocol is either TTRPC or GRPC.
Protocol string `json:"protocol"`
}
type StopStatus struct {
Pid int
ExitStatus int
ExitedAt time.Time
}
// Manager is the interface which manages the shim process
type Manager interface {
Name() string
Start(ctx context.Context, id string, opts StartOpts) (BootstrapParams, error)
Stop(ctx context.Context, id string) (StopStatus, error)
Info(ctx context.Context, optionsR io.Reader) (*types.RuntimeInfo, error)
}
// OptsKey is the context key for the Opts value.
type OptsKey struct{}
// Opts are context options associated with the shim invocation.
type Opts struct {
BundlePath string
Debug bool
}
// BinaryOpts allows the configuration of a shims binary setup
type BinaryOpts func(*Config)
// Config of shim binary options provided by shim implementations
type Config struct {
// NoSubreaper disables setting the shim as a child subreaper
NoSubreaper bool
// NoReaper disables the shim binary from reaping any child process implicitly
NoReaper bool
// NoSetupLogger disables automatic configuration of logrus to use the shim FIFO
NoSetupLogger bool
}
type TTRPCService interface {
RegisterTTRPC(*ttrpc.Server) error
}
type TTRPCServerOptioner interface {
TTRPCService
UnaryInterceptor() ttrpc.UnaryServerInterceptor
}
var (
debugFlag bool
versionFlag bool
infoFlag bool
id string
namespaceFlag string
socketFlag string
bundlePath string
addressFlag string
containerdBinaryFlag string
action string
)
const (
ttrpcAddressEnv = "TTRPC_ADDRESS"
grpcAddressEnv = "GRPC_ADDRESS"
namespaceEnv = "NAMESPACE"
maxVersionEnv = "MAX_SHIM_VERSION"
)
func parseFlags() {
flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
flag.BoolVar(&versionFlag, "v", false, "show the shim version and exit")
// "info" is not a subcommand, because old shims produce very confusing errors for unknown subcommands
// https://github.com/containerd/containerd/pull/8509#discussion_r1210021403
flag.BoolVar(&infoFlag, "info", false, "get the option protobuf from stdin, print the shim info protobuf to stdout, and exit")
flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
flag.StringVar(&id, "id", "", "id of the task")
flag.StringVar(&socketFlag, "socket", "", "socket path to serve")
flag.StringVar(&bundlePath, "bundle", "", "path to the bundle if not workdir")
flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
flag.StringVar(&containerdBinaryFlag, "publish-binary", "",
fmt.Sprintf("path to publish binary (used for publishing events), but %s will ignore this flag, please use the %s env", os.Args[0], ttrpcAddressEnv),
)
flag.Parse()
action = flag.Arg(0)
}
func setRuntime() {
debug.SetGCPercent(40)
go func() {
for range time.Tick(30 * time.Second) {
debug.FreeOSMemory()
}
}()
if os.Getenv("GOMAXPROCS") == "" {
// If GOMAXPROCS hasn't been set, we default to a value of 2 to reduce
// the number of Go stacks present in the shim.
runtime.GOMAXPROCS(2)
}
}
func setLogger(ctx context.Context, id string) (context.Context, error) {
l := log.G(ctx)
l.Logger.SetFormatter(&logrus.TextFormatter{
TimestampFormat: log.RFC3339NanoFixed,
FullTimestamp: true,
})
if debugFlag {
l.Logger.SetLevel(log.DebugLevel)
}
f, err := openLog(ctx, id)
if err != nil {
return ctx, err
}
l.Logger.SetOutput(f)
return log.WithLogger(ctx, l), nil
}
// Run initializes and runs a shim server.
func Run(ctx context.Context, manager Manager, opts ...BinaryOpts) {
var config Config
for _, o := range opts {
o(&config)
}
ctx = log.WithLogger(ctx, log.G(ctx).WithField("runtime", manager.Name()))
if err := run(ctx, manager, config); err != nil {
fmt.Fprintf(os.Stderr, "%s: %s", manager.Name(), err)
os.Exit(1)
}
}
func runInfo(ctx context.Context, manager Manager) error {
info, err := manager.Info(ctx, os.Stdin)
if err != nil {
return err
}
infoB, err := proto.Marshal(info)
if err != nil {
return err
}
_, err = os.Stdout.Write(infoB)
return err
}
func run(ctx context.Context, manager Manager, config Config) error {
parseFlags()
if versionFlag {
fmt.Printf("%s:\n", filepath.Base(os.Args[0]))
fmt.Println(" Version: ", version.Version)
fmt.Println(" Revision:", version.Revision)
fmt.Println(" Go version:", version.GoVersion)
fmt.Println("")
return nil
}
if infoFlag {
return runInfo(ctx, manager)
}
if namespaceFlag == "" {
return fmt.Errorf("shim namespace cannot be empty")
}
setRuntime()
signals, err := setupSignals(config)
if err != nil {
return err
}
if !config.NoSubreaper {
if err := subreaper(); err != nil {
return err
}
}
ttrpcAddress := os.Getenv(ttrpcAddressEnv)
publisher, err := NewPublisher(ttrpcAddress)
if err != nil {
return err
}
defer publisher.Close()
ctx = namespaces.WithNamespace(ctx, namespaceFlag)
ctx = context.WithValue(ctx, OptsKey{}, Opts{BundlePath: bundlePath, Debug: debugFlag})
ctx, sd := shutdown.WithShutdown(ctx)
defer sd.Shutdown()
// Handle explicit actions
switch action {
case "delete":
logger := log.G(ctx).WithFields(log.Fields{
"pid": os.Getpid(),
"namespace": namespaceFlag,
})
if debugFlag {
logger.Logger.SetLevel(log.DebugLevel)
}
go reap(ctx, logger, signals)
ss, err := manager.Stop(ctx, id)
if err != nil {
return err
}
data, err := proto.Marshal(&shimapi.DeleteResponse{
Pid: uint32(ss.Pid),
ExitStatus: uint32(ss.ExitStatus),
ExitedAt: protobuf.ToTimestamp(ss.ExitedAt),
})
if err != nil {
return err
}
if _, err := os.Stdout.Write(data); err != nil {
return err
}
return nil
case "start":
opts := StartOpts{
Address: addressFlag,
TTRPCAddress: ttrpcAddress,
Debug: debugFlag,
}
params, err := manager.Start(ctx, id, opts)
if err != nil {
return err
}
data, err := json.Marshal(&params)
if err != nil {
return fmt.Errorf("failed to marshal bootstrap params to json: %w", err)
}
if _, err := os.Stdout.Write(data); err != nil {
return err
}
return nil
}
if !config.NoSetupLogger {
ctx, err = setLogger(ctx, id)
if err != nil {
return err
}
}
registry.Register(&plugin.Registration{
Type: plugins.InternalPlugin,
ID: "shutdown",
InitFn: func(ic *plugin.InitContext) (interface{}, error) {
return sd, nil
},
})
// Register event plugin
registry.Register(&plugin.Registration{
Type: plugins.EventPlugin,
ID: "publisher",
InitFn: func(ic *plugin.InitContext) (interface{}, error) {
return publisher, nil
},
})
var (
initialized = plugin.NewPluginSet()
ttrpcServices = []TTRPCService{}
ttrpcUnaryInterceptors = []ttrpc.UnaryServerInterceptor{}
)
for _, p := range registry.Graph(func(*plugin.Registration) bool { return false }) {
pID := p.URI()
log.G(ctx).WithFields(log.Fields{"id": pID, "type": p.Type}).Debug("loading plugin")
initContext := plugin.NewContext(
ctx,
initialized,
map[string]string{
// NOTE: Root is empty since the shim does not support persistent storage,
// shim plugins should make use state directory for writing files to disk.
// The state directory will be destroyed when the shim if cleaned up or
// on reboot
plugins.PropertyStateDir: filepath.Join(bundlePath, p.URI()),
plugins.PropertyGRPCAddress: addressFlag,
plugins.PropertyTTRPCAddress: ttrpcAddress,
},
)
// load the plugin specific configuration if it is provided
// TODO: Read configuration passed into shim, or from state directory?
// if p.Config != nil {
// pc, err := config.Decode(p)
// if err != nil {
// return nil, err
// }
// initContext.Config = pc
// }
result := p.Init(initContext)
if err := initialized.Add(result); err != nil {
return fmt.Errorf("could not add plugin result to plugin set: %w", err)
}
instance, err := result.Instance()
if err != nil {
if plugin.IsSkipPlugin(err) {
log.G(ctx).WithFields(log.Fields{"id": pID, "type": p.Type, "error": err}).Info("skip loading plugin")
continue
}
return fmt.Errorf("failed to load plugin %s: %w", pID, err)
}
if src, ok := instance.(TTRPCService); ok {
log.G(ctx).WithField("id", pID).Debug("registering ttrpc service")
ttrpcServices = append(ttrpcServices, src)
}
if src, ok := instance.(TTRPCServerOptioner); ok {
ttrpcUnaryInterceptors = append(ttrpcUnaryInterceptors, src.UnaryInterceptor())
}
}
if len(ttrpcServices) == 0 {
return fmt.Errorf("required that ttrpc service")
}
unaryInterceptor := chainUnaryServerInterceptors(ttrpcUnaryInterceptors...)
server, err := newServer(ttrpc.WithUnaryServerInterceptor(unaryInterceptor))
if err != nil {
return fmt.Errorf("failed creating server: %w", err)
}
for _, srv := range ttrpcServices {
if err := srv.RegisterTTRPC(server); err != nil {
return fmt.Errorf("failed to register service: %w", err)
}
}
if err := serve(ctx, server, signals, sd.Shutdown); err != nil {
if !errors.Is(err, shutdown.ErrShutdown) {
return err
}
}
// NOTE: If the shim server is down(like oom killer), the address
// socket might be leaking.
if address, err := ReadAddress("address"); err == nil {
_ = RemoveSocket(address)
}
select {
case <-sd.Done():
return nil
case <-time.After(5 * time.Second):
return errors.New("shim shutdown timeout")
}
}
// serve serves the ttrpc API over a unix socket in the current working directory
// and blocks until the context is canceled
func serve(ctx context.Context, server *ttrpc.Server, signals chan os.Signal, shutdown func()) error {
dump := make(chan os.Signal, 32)
setupDumpStacks(dump)
path, err := os.Getwd()
if err != nil {
return err
}
l, err := serveListener(socketFlag)
if err != nil {
return err
}
go func() {
defer l.Close()
if err := server.Serve(ctx, l); err != nil && !errors.Is(err, net.ErrClosed) {
log.G(ctx).WithError(err).Fatal("containerd-shim: ttrpc server failure")
}
}()
logger := log.G(ctx).WithFields(log.Fields{
"pid": os.Getpid(),
"path": path,
"namespace": namespaceFlag,
})
go func() {
for range dump {
dumpStacks(logger)
}
}()
go handleExitSignals(ctx, logger, shutdown)
return reap(ctx, logger, signals)
}
func dumpStacks(logger *log.Entry) {
var (
buf []byte
stackSize int
)
bufferLen := 16384
for stackSize == len(buf) {
buf = make([]byte, bufferLen)
stackSize = runtime.Stack(buf, true)
bufferLen *= 2
}
buf = buf[:stackSize]
logger.Infof("=== BEGIN goroutine stack dump ===\n%s\n=== END goroutine stack dump ===", buf)
}

27
pkg/shim/shim_darwin.go Normal file
View File

@@ -0,0 +1,27 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import "github.com/containerd/ttrpc"
func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) {
return ttrpc.NewServer(opts...)
}
func subreaper() error {
return nil
}

27
pkg/shim/shim_freebsd.go Normal file
View File

@@ -0,0 +1,27 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import "github.com/containerd/ttrpc"
func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) {
return ttrpc.NewServer(opts...)
}
func subreaper() error {
return nil
}

31
pkg/shim/shim_linux.go Normal file
View File

@@ -0,0 +1,31 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"github.com/containerd/containerd/v2/pkg/sys/reaper"
"github.com/containerd/ttrpc"
)
func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) {
opts = append(opts, ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser()))
return ttrpc.NewServer(opts...)
}
func subreaper() error {
return reaper.SetSubreaper(1)
}

62
pkg/shim/shim_test.go Normal file
View File

@@ -0,0 +1,62 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"context"
"runtime"
"testing"
)
func TestRuntimeWithEmptyMaxEnvProcs(t *testing.T) {
var oldGoMaxProcs = runtime.GOMAXPROCS(0)
defer runtime.GOMAXPROCS(oldGoMaxProcs)
t.Setenv("GOMAXPROCS", "")
setRuntime()
var currentGoMaxProcs = runtime.GOMAXPROCS(0)
if currentGoMaxProcs != 2 {
t.Fatal("the max number of procs should be 2")
}
}
func TestRuntimeWithNonEmptyMaxEnvProcs(t *testing.T) {
t.Setenv("GOMAXPROCS", "not_empty")
setRuntime()
var oldGoMaxProcs2 = runtime.GOMAXPROCS(0)
if oldGoMaxProcs2 != runtime.NumCPU() {
t.Fatal("the max number CPU should be equal to available CPUs")
}
}
func TestShimOptWithValue(t *testing.T) {
ctx := context.TODO()
ctx = context.WithValue(ctx, OptsKey{}, Opts{Debug: true})
o := ctx.Value(OptsKey{})
if o == nil {
t.Fatal("opts nil")
}
op, ok := o.(Opts)
if !ok {
t.Fatal("opts not of type Opts")
}
if !op.Debug {
t.Fatal("opts.Debug should be true")
}
}

113
pkg/shim/shim_unix.go Normal file
View File

@@ -0,0 +1,113 @@
//go:build !windows
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"context"
"fmt"
"io"
"net"
"os"
"os/signal"
"syscall"
"github.com/containerd/containerd/v2/pkg/sys/reaper"
"github.com/containerd/fifo"
"github.com/containerd/log"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// setupSignals creates a new signal handler for all signals and sets the shim as a
// sub-reaper so that the container processes are reparented
func setupSignals(config Config) (chan os.Signal, error) {
signals := make(chan os.Signal, 32)
smp := []os.Signal{unix.SIGTERM, unix.SIGINT, unix.SIGPIPE}
if !config.NoReaper {
smp = append(smp, unix.SIGCHLD)
}
signal.Notify(signals, smp...)
return signals, nil
}
func setupDumpStacks(dump chan<- os.Signal) {
signal.Notify(dump, syscall.SIGUSR1)
}
func serveListener(path string) (net.Listener, error) {
var (
l net.Listener
err error
)
if path == "" {
l, err = net.FileListener(os.NewFile(3, "socket"))
path = "[inherited from parent]"
} else {
if len(path) > socketPathLimit {
return nil, fmt.Errorf("%q: unix socket path too long (> %d)", path, socketPathLimit)
}
l, err = net.Listen("unix", path)
}
if err != nil {
return nil, err
}
log.L.WithField("socket", path).Debug("serving api on socket")
return l, nil
}
func reap(ctx context.Context, logger *logrus.Entry, signals chan os.Signal) error {
logger.Debug("starting signal loop")
for {
select {
case <-ctx.Done():
return ctx.Err()
case s := <-signals:
// Exit signals are handled separately from this loop
// They get registered with this channel so that we can ignore such signals for short-running actions (e.g. `delete`)
switch s {
case unix.SIGCHLD:
if err := reaper.Reap(); err != nil {
logger.WithError(err).Error("reap exit status")
}
case unix.SIGPIPE:
}
}
}
}
func handleExitSignals(ctx context.Context, logger *logrus.Entry, cancel context.CancelFunc) {
ch := make(chan os.Signal, 32)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
for {
select {
case s := <-ch:
logger.WithField("signal", s).Debugf("Caught exit signal")
cancel()
return
case <-ctx.Done():
return
}
}
}
func openLog(ctx context.Context, _ string) (io.Writer, error) {
return fifo.OpenFifoDup2(ctx, "log", unix.O_WRONLY, 0700, int(os.Stderr.Fd()))
}

58
pkg/shim/shim_windows.go Normal file
View File

@@ -0,0 +1,58 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"context"
"io"
"net"
"os"
"github.com/containerd/errdefs"
"github.com/containerd/ttrpc"
"github.com/sirupsen/logrus"
)
func setupSignals(config Config) (chan os.Signal, error) {
return nil, errdefs.ErrNotImplemented
}
func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) {
return nil, errdefs.ErrNotImplemented
}
func subreaper() error {
return errdefs.ErrNotImplemented
}
func setupDumpStacks(dump chan<- os.Signal) {
}
func serveListener(path string) (net.Listener, error) {
return nil, errdefs.ErrNotImplemented
}
func reap(ctx context.Context, logger *logrus.Entry, signals chan os.Signal) error {
return errdefs.ErrNotImplemented
}
func handleExitSignals(ctx context.Context, logger *logrus.Entry, cancel context.CancelFunc) {
}
func openLog(ctx context.Context, _ string) (io.Writer, error) {
return nil, errdefs.ErrNotImplemented
}

218
pkg/shim/util.go Normal file
View File

@@ -0,0 +1,218 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/containerd/ttrpc"
"github.com/containerd/typeurl/v2"
"github.com/containerd/containerd/v2/pkg/atomicfile"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/containerd/containerd/v2/protobuf/proto"
"github.com/containerd/containerd/v2/protobuf/types"
"github.com/containerd/errdefs"
)
type CommandConfig struct {
Runtime string
Address string
TTRPCAddress string
Path string
SchedCore bool
Args []string
Opts *types.Any
}
// Command returns the shim command with the provided args and configuration
func Command(ctx context.Context, config *CommandConfig) (*exec.Cmd, error) {
ns, err := namespaces.NamespaceRequired(ctx)
if err != nil {
return nil, err
}
self, err := os.Executable()
if err != nil {
return nil, err
}
args := []string{
"-namespace", ns,
"-address", config.Address,
"-publish-binary", self,
}
args = append(args, config.Args...)
cmd := exec.CommandContext(ctx, config.Runtime, args...)
cmd.Dir = config.Path
cmd.Env = append(
os.Environ(),
"GOMAXPROCS=2",
fmt.Sprintf("%s=2", maxVersionEnv),
fmt.Sprintf("%s=%s", ttrpcAddressEnv, config.TTRPCAddress),
fmt.Sprintf("%s=%s", grpcAddressEnv, config.Address),
fmt.Sprintf("%s=%s", namespaceEnv, ns),
)
if config.SchedCore {
cmd.Env = append(cmd.Env, "SCHED_CORE=1")
}
cmd.SysProcAttr = getSysProcAttr()
if config.Opts != nil {
d, err := proto.Marshal(config.Opts)
if err != nil {
return nil, err
}
cmd.Stdin = bytes.NewReader(d)
}
return cmd, nil
}
// BinaryName returns the shim binary name from the runtime name,
// empty string returns means runtime name is invalid
func BinaryName(runtime string) string {
// runtime name should format like $prefix.name.version
parts := strings.Split(runtime, ".")
if len(parts) < 2 || parts[0] == "" {
return ""
}
return fmt.Sprintf(shimBinaryFormat, parts[len(parts)-2], parts[len(parts)-1])
}
// BinaryPath returns the full path for the shim binary from the runtime name,
// empty string returns means runtime name is invalid
func BinaryPath(runtime string) string {
dir := filepath.Dir(runtime)
binary := BinaryName(runtime)
path, err := filepath.Abs(filepath.Join(dir, binary))
if err != nil {
return ""
}
return path
}
// Connect to the provided address
func Connect(address string, d func(string, time.Duration) (net.Conn, error)) (net.Conn, error) {
return d(address, 100*time.Second)
}
// WritePidFile writes a pid file atomically
func WritePidFile(path string, pid int) error {
path, err := filepath.Abs(path)
if err != nil {
return err
}
f, err := atomicfile.New(path, 0o644)
if err != nil {
return err
}
_, err = fmt.Fprintf(f, "%d", pid)
if err != nil {
f.Cancel()
return err
}
return f.Close()
}
// ErrNoAddress is returned when the address file has no content
var ErrNoAddress = errors.New("no shim address")
// ReadAddress returns the shim's socket address from the path
func ReadAddress(path string) (string, error) {
path, err := filepath.Abs(path)
if err != nil {
return "", err
}
data, err := os.ReadFile(path)
if err != nil {
return "", err
}
if len(data) == 0 {
return "", ErrNoAddress
}
return string(data), nil
}
// ReadRuntimeOptions reads config bytes from io.Reader and unmarshals it into the provided type.
// The type must be registered with typeurl.
//
// The function will return ErrNotFound, if the config is not provided.
// And ErrInvalidArgument, if unable to cast the config to the provided type T.
func ReadRuntimeOptions[T any](reader io.Reader) (T, error) {
var config T
data, err := io.ReadAll(reader)
if err != nil {
return config, fmt.Errorf("failed to read config bytes from stdin: %w", err)
}
if len(data) == 0 {
return config, errdefs.ErrNotFound
}
var any types.Any
if err := proto.Unmarshal(data, &any); err != nil {
return config, err
}
v, err := typeurl.UnmarshalAny(&any)
if err != nil {
return config, err
}
config, ok := v.(T)
if !ok {
return config, fmt.Errorf("invalid type %T: %w", v, errdefs.ErrInvalidArgument)
}
return config, nil
}
// chainUnaryServerInterceptors creates a single ttrpc server interceptor from
// a chain of many interceptors executed from first to last.
func chainUnaryServerInterceptors(interceptors ...ttrpc.UnaryServerInterceptor) ttrpc.UnaryServerInterceptor {
n := len(interceptors)
// force to use default interceptor in ttrpc
if n == 0 {
return nil
}
return func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) {
currentMethod := method
for i := n - 1; i > 0; i-- {
interceptor := interceptors[i]
innerMethod := currentMethod
currentMethod = func(currentCtx context.Context, currentUnmarshal func(interface{}) error) (interface{}, error) {
return interceptor(currentCtx, currentUnmarshal, info, innerMethod)
}
}
return interceptors[0](ctx, unmarshal, info, currentMethod)
}
}

118
pkg/shim/util_test.go Normal file
View File

@@ -0,0 +1,118 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"context"
"path/filepath"
"reflect"
"testing"
"github.com/containerd/ttrpc"
)
func TestChainUnaryServerInterceptors(t *testing.T) {
methodInfo := &ttrpc.UnaryServerInfo{
FullMethod: filepath.Join("/", t.Name(), "foo"),
}
type callKey struct{}
callValue := "init"
callCtx := context.WithValue(context.Background(), callKey{}, callValue)
verifyCallCtxFn := func(ctx context.Context, key interface{}, expected interface{}) {
got := ctx.Value(key)
if !reflect.DeepEqual(expected, got) {
t.Fatalf("[context(key:%s) expected %v, but got %v", key, expected, got)
}
}
verifyInfoFn := func(info *ttrpc.UnaryServerInfo) {
if !reflect.DeepEqual(methodInfo, info) {
t.Fatalf("[info] expected %+v, but got %+v", methodInfo, info)
}
}
origUnmarshaler := func(obj interface{}) error {
v := obj.(*int64)
*v *= 2
return nil
}
type firstKey struct{}
firstValue := "from first"
var firstUnmarshaler ttrpc.Unmarshaler
first := func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) {
verifyCallCtxFn(ctx, callKey{}, callValue)
verifyInfoFn(info)
ctx = context.WithValue(ctx, firstKey{}, firstValue)
firstUnmarshaler = func(obj interface{}) error {
if err := unmarshal(obj); err != nil {
return err
}
v := obj.(*int64)
*v *= 2
return nil
}
return method(ctx, firstUnmarshaler)
}
type secondKey struct{}
secondValue := "from second"
second := func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) {
verifyCallCtxFn(ctx, callKey{}, callValue)
verifyCallCtxFn(ctx, firstKey{}, firstValue)
verifyInfoFn(info)
v := int64(3) // should return 12
if err := unmarshal(&v); err != nil {
t.Fatalf("unexpected error %v", err)
}
if expected := int64(12); v != expected {
t.Fatalf("expected int64(%v), but got %v", expected, v)
}
ctx = context.WithValue(ctx, secondKey{}, secondValue)
return method(ctx, unmarshal)
}
methodFn := func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
verifyCallCtxFn(ctx, callKey{}, callValue)
verifyCallCtxFn(ctx, firstKey{}, firstValue)
verifyCallCtxFn(ctx, secondKey{}, secondValue)
v := int64(2)
if err := unmarshal(&v); err != nil {
return nil, err
}
return v, nil
}
interceptor := chainUnaryServerInterceptors(first, second)
v, err := interceptor(callCtx, origUnmarshaler, methodInfo, methodFn)
if err != nil {
t.Fatalf("expected nil, but got %v", err)
}
if expected := int64(8); v != expected {
t.Fatalf("expected result is int64(%v), but got %v", expected, v)
}
}

288
pkg/shim/util_unix.go Normal file
View File

@@ -0,0 +1,288 @@
//go:build !windows
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"bufio"
"context"
"crypto/sha256"
"fmt"
"io"
"math"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/containerd/log"
"github.com/mdlayher/vsock"
"github.com/containerd/containerd/v2/defaults"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/containerd/containerd/v2/pkg/sys"
)
const (
shimBinaryFormat = "containerd-shim-%s-%s"
socketPathLimit = 106
protoVsock = "vsock"
protoHybridVsock = "hvsock"
protoUnix = "unix"
)
func getSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
}
}
// AdjustOOMScore sets the OOM score for the process to the parents OOM score +1
// to ensure that they parent has a lower* score than the shim
// if not already at the maximum OOM Score
func AdjustOOMScore(pid int) error {
parent := os.Getppid()
score, err := sys.GetOOMScoreAdj(parent)
if err != nil {
return fmt.Errorf("get parent OOM score: %w", err)
}
shimScore := score + 1
if err := sys.AdjustOOMScore(pid, shimScore); err != nil {
return fmt.Errorf("set shim OOM score: %w", err)
}
return nil
}
const socketRoot = defaults.DefaultStateDir
// SocketAddress returns a socket address
func SocketAddress(ctx context.Context, socketPath, id string) (string, error) {
ns, err := namespaces.NamespaceRequired(ctx)
if err != nil {
return "", err
}
d := sha256.Sum256([]byte(filepath.Join(socketPath, ns, id)))
return fmt.Sprintf("unix://%s/%x", filepath.Join(socketRoot, "s"), d), nil
}
// AnonDialer returns a dialer for a socket
func AnonDialer(address string, timeout time.Duration) (net.Conn, error) {
proto, addr, ok := strings.Cut(address, "://")
if !ok {
return net.DialTimeout("unix", socket(address).path(), timeout)
}
switch proto {
case protoVsock:
// vsock dialer can not set timeout
return dialVsock(addr)
case protoHybridVsock:
return dialHybridVsock(addr, timeout)
case protoUnix:
return net.DialTimeout("unix", socket(address).path(), timeout)
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
}
// AnonReconnectDialer returns a dialer for an existing socket on reconnection
func AnonReconnectDialer(address string, timeout time.Duration) (net.Conn, error) {
return AnonDialer(address, timeout)
}
// NewSocket returns a new socket
func NewSocket(address string) (*net.UnixListener, error) {
var (
sock = socket(address)
path = sock.path()
isAbstract = sock.isAbstract()
perm = os.FileMode(0600)
)
// Darwin needs +x to access socket, otherwise it'll fail with "bind: permission denied" when running as non-root.
if runtime.GOOS == "darwin" {
perm = 0700
}
if !isAbstract {
if err := os.MkdirAll(filepath.Dir(path), perm); err != nil {
return nil, fmt.Errorf("mkdir failed for %s: %w", path, err)
}
}
l, err := net.Listen("unix", path)
if err != nil {
return nil, err
}
if !isAbstract {
if err := os.Chmod(path, perm); err != nil {
os.Remove(sock.path())
l.Close()
return nil, fmt.Errorf("chmod failed for %s: %w", path, err)
}
}
return l.(*net.UnixListener), nil
}
const abstractSocketPrefix = "\x00"
type socket string
func (s socket) isAbstract() bool {
return !strings.HasPrefix(string(s), "unix://")
}
func (s socket) path() string {
path := strings.TrimPrefix(string(s), "unix://")
// if there was no trim performed, we assume an abstract socket
if len(path) == len(s) {
path = abstractSocketPrefix + path
}
return path
}
// RemoveSocket removes the socket at the specified address if
// it exists on the filesystem
func RemoveSocket(address string) error {
sock := socket(address)
if !sock.isAbstract() {
return os.Remove(sock.path())
}
return nil
}
// SocketEaddrinuse returns true if the provided error is caused by the
// EADDRINUSE error number
func SocketEaddrinuse(err error) bool {
netErr, ok := err.(*net.OpError)
if !ok {
return false
}
if netErr.Op != "listen" {
return false
}
syscallErr, ok := netErr.Err.(*os.SyscallError)
if !ok {
return false
}
errno, ok := syscallErr.Err.(syscall.Errno)
if !ok {
return false
}
return errno == syscall.EADDRINUSE
}
// CanConnect returns true if the socket provided at the address
// is accepting new connections
func CanConnect(address string) bool {
conn, err := AnonDialer(address, 100*time.Millisecond)
if err != nil {
return false
}
conn.Close()
return true
}
func hybridVsockDialer(addr string, port uint64, timeout time.Duration) (net.Conn, error) {
timeoutCh := time.After(timeout)
// Do 10 retries before timeout
retryInterval := timeout / 10
for {
conn, err := net.DialTimeout("unix", addr, timeout)
if err != nil {
return nil, err
}
if _, err = conn.Write([]byte(fmt.Sprintf("CONNECT %d\n", port))); err != nil {
conn.Close()
return nil, err
}
errChan := make(chan error, 1)
go func() {
reader := bufio.NewReader(conn)
response, err := reader.ReadString('\n')
if err != nil {
errChan <- err
return
}
if strings.Contains(response, "OK") {
errChan <- nil
} else {
errChan <- fmt.Errorf("hybrid vsock handshake response error: %s", response)
}
}()
select {
case err = <-errChan:
if err != nil {
conn.Close()
// When it is EOF, maybe the server side is not ready.
if err == io.EOF {
log.G(context.Background()).Warnf("Read hybrid vsock got EOF, server may not ready")
time.Sleep(retryInterval)
continue
}
return nil, err
}
return conn, nil
case <-timeoutCh:
conn.Close()
return nil, fmt.Errorf("timeout waiting for hybrid vsocket handshake of %s:%d", addr, port)
}
}
}
func dialVsock(address string) (net.Conn, error) {
contextIDString, portString, ok := strings.Cut(address, ":")
if !ok {
return nil, fmt.Errorf("invalid vsock address %s", address)
}
contextID, err := strconv.ParseUint(contextIDString, 10, 0)
if err != nil {
return nil, fmt.Errorf("failed to parse vsock context id %s, %v", contextIDString, err)
}
if contextID > math.MaxUint32 {
return nil, fmt.Errorf("vsock context id %d is invalid", contextID)
}
port, err := strconv.ParseUint(portString, 10, 0)
if err != nil {
return nil, fmt.Errorf("failed to parse vsock port %s, %v", portString, err)
}
if port > math.MaxUint32 {
return nil, fmt.Errorf("vsock port %d is invalid", port)
}
return vsock.Dial(uint32(contextID), uint32(port), &vsock.Config{})
}
func dialHybridVsock(address string, timeout time.Duration) (net.Conn, error) {
addr, portString, ok := strings.Cut(address, ":")
if !ok {
return nil, fmt.Errorf("invalid hybrid vsock address %s", address)
}
port, err := strconv.ParseUint(portString, 10, 0)
if err != nil {
return nil, fmt.Errorf("failed to parse hybrid vsock port %s, %v", portString, err)
}
if port > math.MaxUint32 {
return nil, fmt.Errorf("hybrid vsock port %d is invalid", port)
}
return hybridVsockDialer(addr, port, timeout)
}

87
pkg/shim/util_windows.go Normal file
View File

@@ -0,0 +1,87 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package shim
import (
"context"
"fmt"
"net"
"os"
"syscall"
"time"
winio "github.com/Microsoft/go-winio"
)
const shimBinaryFormat = "containerd-shim-%s-%s.exe"
func getSysProcAttr() *syscall.SysProcAttr {
return nil
}
// AnonReconnectDialer returns a dialer for an existing npipe on containerd reconnection
func AnonReconnectDialer(address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
c, err := winio.DialPipeContext(ctx, address)
if os.IsNotExist(err) {
return nil, fmt.Errorf("npipe not found on reconnect: %w", os.ErrNotExist)
} else if err == context.DeadlineExceeded {
return nil, fmt.Errorf("timed out waiting for npipe %s: %w", address, err)
} else if err != nil {
return nil, err
}
return c, nil
}
// AnonDialer returns a dialer for a npipe
func AnonDialer(address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// If there is nobody serving the pipe we limit the timeout for this case to
// 5 seconds because any shim that would serve this endpoint should serve it
// within 5 seconds.
serveTimer := time.NewTimer(5 * time.Second)
defer serveTimer.Stop()
for {
c, err := winio.DialPipeContext(ctx, address)
if err != nil {
if os.IsNotExist(err) {
select {
case <-serveTimer.C:
return nil, fmt.Errorf("pipe not found before timeout: %w", os.ErrNotExist)
default:
// Wait 10ms for the shim to serve and try again.
time.Sleep(10 * time.Millisecond)
continue
}
} else if err == context.DeadlineExceeded {
return nil, fmt.Errorf("timed out waiting for npipe %s: %w", address, err)
}
return nil, err
}
return c, nil
}
}
// RemoveSocket removes the socket at the specified address if
// it exists on the filesystem
func RemoveSocket(address string) error {
return nil
}