diff --git a/integration/failpoint/cmd/containerd-shim-runc-fp-v1/plugin_linux.go b/integration/failpoint/cmd/containerd-shim-runc-fp-v1/plugin_linux.go index a9a3d3501..c0b72681e 100644 --- a/integration/failpoint/cmd/containerd-shim-runc-fp-v1/plugin_linux.go +++ b/integration/failpoint/cmd/containerd-shim-runc-fp-v1/plugin_linux.go @@ -23,7 +23,7 @@ import ( "path/filepath" "strings" - taskapi "github.com/containerd/containerd/api/runtime/task/v2" + taskapi "github.com/containerd/containerd/api/runtime/task/v3" "github.com/containerd/containerd/oci" "github.com/containerd/containerd/pkg/failpoint" "github.com/containerd/containerd/pkg/shutdown" @@ -79,11 +79,11 @@ var ( type taskServiceWithFp struct { fps map[string]*failpoint.Failpoint - local taskapi.TaskService + local taskapi.TTRPCTaskService } func (s *taskServiceWithFp) RegisterTTRPC(server *ttrpc.Server) error { - taskapi.RegisterTaskService(server, s.local) + taskapi.RegisterTTRPCTaskService(server, s.local) return nil } diff --git a/integration/issue7496_linux_test.go b/integration/issue7496_linux_test.go index 7cca02c98..c77dbcfc1 100644 --- a/integration/issue7496_linux_test.go +++ b/integration/issue7496_linux_test.go @@ -28,7 +28,7 @@ import ( "testing" "time" - apitask "github.com/containerd/containerd/api/runtime/task/v2" + apitask "github.com/containerd/containerd/api/runtime/task/v3" "github.com/containerd/containerd/integration/images" "github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/runtime/v2/shim" @@ -111,7 +111,7 @@ func TestIssue7496(t *testing.T) { // example, umount overlayfs rootfs which doesn't with volatile. 
// // REF: https://man7.org/linux/man-pages/man1/strace.1.html -func injectDelayToUmount2(ctx context.Context, t *testing.T, shimCli apitask.TaskService, delayInSec int) chan struct{} { +func injectDelayToUmount2(ctx context.Context, t *testing.T, shimCli apitask.TTRPCTaskService, delayInSec int) chan struct{} { pid := shimPid(ctx, t, shimCli) doneCh := make(chan struct{}) @@ -153,7 +153,7 @@ func injectDelayToUmount2(ctx context.Context, t *testing.T, shimCli apitask.Tas return doneCh } -func connectToShim(ctx context.Context, t *testing.T, id string) apitask.TaskService { +func connectToShim(ctx context.Context, t *testing.T, id string) apitask.TTRPCTaskService { addr, err := shim.SocketAddress(ctx, containerdEndpoint, id) require.NoError(t, err) addr = strings.TrimPrefix(addr, "unix://") @@ -162,10 +162,10 @@ func connectToShim(ctx context.Context, t *testing.T, id string) apitask.TaskSer require.NoError(t, err) client := ttrpc.NewClient(conn) - return apitask.NewTaskClient(client) + return apitask.NewTTRPCTaskClient(client) } -func shimPid(ctx context.Context, t *testing.T, shimCli apitask.TaskService) uint32 { +func shimPid(ctx context.Context, t *testing.T, shimCli apitask.TTRPCTaskService) uint32 { resp, err := shimCli.Connect(ctx, &apitask.ConnectRequest{}) require.NoError(t, err) return resp.GetShimPid() diff --git a/integration/issue7496_shutdown_linux_test.go b/integration/issue7496_shutdown_linux_test.go index e9e1cdc58..70ade7e41 100644 --- a/integration/issue7496_shutdown_linux_test.go +++ b/integration/issue7496_shutdown_linux_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/require" - apitask "github.com/containerd/containerd/api/runtime/task/v2" + apitask "github.com/containerd/containerd/api/runtime/task/v3" "github.com/containerd/containerd/namespaces" ) diff --git a/runtime/v2/binary.go b/runtime/v2/binary.go index 2c71c843c..288a2f965 100644 --- a/runtime/v2/binary.go +++ b/runtime/v2/binary.go @@ -129,19 +129,25 @@ func (b 
*binary) Start(ctx context.Context, opts *types.Any, onClose func()) (_ return nil, err } - params, err := parseStartResponse(ctx, response) + params, err := parseStartResponse(response) if err != nil { return nil, err } - conn, err := makeConnection(ctx, params, onCloseWithShimLog) + conn, err := makeConnection(ctx, b.bundle.ID, params, onCloseWithShimLog) if err != nil { return nil, err } + // Save bootstrap configuration (so containerd can restore shims after restart). + if err := writeBootstrapParams(filepath.Join(b.bundle.Path, "bootstrap.json"), params); err != nil { + return nil, fmt.Errorf("failed to write bootstrap.json: %w", err) + } + return &shim{ - bundle: b.bundle, - client: conn, + bundle: b.bundle, + client: conn, + version: params.Version, }, nil } diff --git a/runtime/v2/bridge.go b/runtime/v2/bridge.go index 4075d6c2d..6262e0efd 100644 --- a/runtime/v2/bridge.go +++ b/runtime/v2/bridge.go @@ -20,14 +20,38 @@ import ( "context" "fmt" + "github.com/containerd/ttrpc" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" v2 "github.com/containerd/containerd/api/runtime/task/v2" v3 "github.com/containerd/containerd/api/runtime/task/v3" - "github.com/containerd/ttrpc" + + api "github.com/containerd/containerd/api/runtime/task/v3" // Current version used by TaskServiceClient ) +// TaskServiceClient exposes a client interface to shims, which aims to hide +// the underlying complexity and backward compatibility (v2 task service vs v3, TTRPC vs GRPC, etc). 
+type TaskServiceClient interface { + State(context.Context, *api.StateRequest) (*api.StateResponse, error) + Create(context.Context, *api.CreateTaskRequest) (*api.CreateTaskResponse, error) + Start(context.Context, *api.StartRequest) (*api.StartResponse, error) + Delete(context.Context, *api.DeleteRequest) (*api.DeleteResponse, error) + Pids(context.Context, *api.PidsRequest) (*api.PidsResponse, error) + Pause(context.Context, *api.PauseRequest) (*emptypb.Empty, error) + Resume(context.Context, *api.ResumeRequest) (*emptypb.Empty, error) + Checkpoint(context.Context, *api.CheckpointTaskRequest) (*emptypb.Empty, error) + Kill(context.Context, *api.KillRequest) (*emptypb.Empty, error) + Exec(context.Context, *api.ExecProcessRequest) (*emptypb.Empty, error) + ResizePty(context.Context, *api.ResizePtyRequest) (*emptypb.Empty, error) + CloseIO(context.Context, *api.CloseIORequest) (*emptypb.Empty, error) + Update(context.Context, *api.UpdateTaskRequest) (*emptypb.Empty, error) + Wait(context.Context, *api.WaitRequest) (*api.WaitResponse, error) + Stats(context.Context, *api.StatsRequest) (*api.StatsResponse, error) + Connect(context.Context, *api.ConnectRequest) (*api.ConnectResponse, error) + Shutdown(context.Context, *api.ShutdownRequest) (*emptypb.Empty, error) +} + // NewTaskClient returns a new task client interface which handles both GRPC and TTRPC servers depending on the // client object type passed in. // @@ -35,33 +59,47 @@ import ( // - *ttrpc.Client // - grpc.ClientConnInterface // -// In 1.7 we support TaskService v2 (for backward compatibility with existing shims) and GRPC TaskService v3. -// In 2.0 we'll switch to TaskService v3 only for both TTRPC and GRPC, which will remove overhead of mapping v2 structs to v3 structs. 
-func NewTaskClient(client interface{}) (v2.TaskService, error) { +// Currently supported servers: +// - TTRPC v2 (compatibility with shims before 2.0) +// - TTRPC v3 +// - GRPC v3 +func NewTaskClient(client interface{}, version int) (TaskServiceClient, error) { switch c := client.(type) { case *ttrpc.Client: - return v2.NewTaskClient(c), nil + switch version { + case 2: + return &ttrpcV2Bridge{client: v2.NewTaskClient(c)}, nil + case 3: + return v3.NewTTRPCTaskClient(c), nil + default: + return nil, fmt.Errorf("containerd client supports only v2 and v3 TTRPC task client (got %d)", version) + } + case grpc.ClientConnInterface: - return &grpcBridge{v3.NewTaskClient(c)}, nil + if version != 3 { + return nil, fmt.Errorf("containerd client supports only v3 GRPC task service (got %d)", version) + } + + return &grpcV3Bridge{v3.NewTaskClient(c)}, nil default: return nil, fmt.Errorf("unsupported shim client type %T", c) } } -// grpcBridge implements `v2.TaskService` interface for GRPC shim server. -type grpcBridge struct { - client v3.TaskClient +// ttrpcV2Bridge is a bridge from TTRPC v2 task service. 
+type ttrpcV2Bridge struct { + client v2.TaskService } -var _ v2.TaskService = (*grpcBridge)(nil) +var _ TaskServiceClient = (*ttrpcV2Bridge)(nil) -func (g *grpcBridge) State(ctx context.Context, request *v2.StateRequest) (*v2.StateResponse, error) { - resp, err := g.client.State(ctx, &v3.StateRequest{ +func (b *ttrpcV2Bridge) State(ctx context.Context, request *api.StateRequest) (*api.StateResponse, error) { + resp, err := b.client.State(ctx, &v2.StateRequest{ ID: request.GetID(), ExecID: request.GetExecID(), }) - return &v2.StateResponse{ + return &v3.StateResponse{ ID: resp.GetID(), Bundle: resp.GetBundle(), Pid: resp.GetPid(), @@ -76,8 +114,8 @@ func (g *grpcBridge) State(ctx context.Context, request *v2.StateRequest) (*v2.S }, err } -func (g *grpcBridge) Create(ctx context.Context, request *v2.CreateTaskRequest) (*v2.CreateTaskResponse, error) { - resp, err := g.client.Create(ctx, &v3.CreateTaskRequest{ +func (b *ttrpcV2Bridge) Create(ctx context.Context, request *api.CreateTaskRequest) (*api.CreateTaskResponse, error) { + resp, err := b.client.Create(ctx, &v2.CreateTaskRequest{ ID: request.GetID(), Bundle: request.GetBundle(), Rootfs: request.GetRootfs(), @@ -90,54 +128,54 @@ func (g *grpcBridge) Create(ctx context.Context, request *v2.CreateTaskRequest) Options: request.GetOptions(), }) - return &v2.CreateTaskResponse{Pid: resp.GetPid()}, err + return &api.CreateTaskResponse{Pid: resp.GetPid()}, err } -func (g *grpcBridge) Start(ctx context.Context, request *v2.StartRequest) (*v2.StartResponse, error) { - resp, err := g.client.Start(ctx, &v3.StartRequest{ +func (b *ttrpcV2Bridge) Start(ctx context.Context, request *api.StartRequest) (*api.StartResponse, error) { + resp, err := b.client.Start(ctx, &v2.StartRequest{ ID: request.GetID(), ExecID: request.GetExecID(), }) - return &v2.StartResponse{Pid: resp.GetPid()}, err + return &api.StartResponse{Pid: resp.GetPid()}, err } -func (g *grpcBridge) Delete(ctx context.Context, request *v2.DeleteRequest) 
(*v2.DeleteResponse, error) { - resp, err := g.client.Delete(ctx, &v3.DeleteRequest{ +func (b *ttrpcV2Bridge) Delete(ctx context.Context, request *api.DeleteRequest) (*api.DeleteResponse, error) { + resp, err := b.client.Delete(ctx, &v2.DeleteRequest{ ID: request.GetID(), ExecID: request.GetExecID(), }) - return &v2.DeleteResponse{ + return &api.DeleteResponse{ Pid: resp.GetPid(), ExitStatus: resp.GetExitStatus(), ExitedAt: resp.GetExitedAt(), }, err } -func (g *grpcBridge) Pids(ctx context.Context, request *v2.PidsRequest) (*v2.PidsResponse, error) { - resp, err := g.client.Pids(ctx, &v3.PidsRequest{ID: request.GetID()}) - return &v2.PidsResponse{Processes: resp.GetProcesses()}, err +func (b *ttrpcV2Bridge) Pids(ctx context.Context, request *api.PidsRequest) (*api.PidsResponse, error) { + resp, err := b.client.Pids(ctx, &v2.PidsRequest{ID: request.GetID()}) + return &api.PidsResponse{Processes: resp.GetProcesses()}, err } -func (g *grpcBridge) Pause(ctx context.Context, request *v2.PauseRequest) (*emptypb.Empty, error) { - return g.client.Pause(ctx, &v3.PauseRequest{ID: request.GetID()}) +func (b *ttrpcV2Bridge) Pause(ctx context.Context, request *api.PauseRequest) (*emptypb.Empty, error) { + return b.client.Pause(ctx, &v2.PauseRequest{ID: request.GetID()}) } -func (g *grpcBridge) Resume(ctx context.Context, request *v2.ResumeRequest) (*emptypb.Empty, error) { - return g.client.Resume(ctx, &v3.ResumeRequest{ID: request.GetID()}) +func (b *ttrpcV2Bridge) Resume(ctx context.Context, request *api.ResumeRequest) (*emptypb.Empty, error) { + return b.client.Resume(ctx, &v2.ResumeRequest{ID: request.GetID()}) } -func (g *grpcBridge) Checkpoint(ctx context.Context, request *v2.CheckpointTaskRequest) (*emptypb.Empty, error) { - return g.client.Checkpoint(ctx, &v3.CheckpointTaskRequest{ +func (b *ttrpcV2Bridge) Checkpoint(ctx context.Context, request *api.CheckpointTaskRequest) (*emptypb.Empty, error) { + return b.client.Checkpoint(ctx, &v2.CheckpointTaskRequest{ ID: 
request.GetID(), Path: request.GetPath(), Options: request.GetOptions(), }) } -func (g *grpcBridge) Kill(ctx context.Context, request *v2.KillRequest) (*emptypb.Empty, error) { - return g.client.Kill(ctx, &v3.KillRequest{ +func (b *ttrpcV2Bridge) Kill(ctx context.Context, request *api.KillRequest) (*emptypb.Empty, error) { + return b.client.Kill(ctx, &v2.KillRequest{ ID: request.GetID(), ExecID: request.GetExecID(), Signal: request.GetSignal(), @@ -145,8 +183,8 @@ func (g *grpcBridge) Kill(ctx context.Context, request *v2.KillRequest) (*emptyp }) } -func (g *grpcBridge) Exec(ctx context.Context, request *v2.ExecProcessRequest) (*emptypb.Empty, error) { - return g.client.Exec(ctx, &v3.ExecProcessRequest{ +func (b *ttrpcV2Bridge) Exec(ctx context.Context, request *api.ExecProcessRequest) (*emptypb.Empty, error) { + return b.client.Exec(ctx, &v2.ExecProcessRequest{ ID: request.GetID(), ExecID: request.GetExecID(), Terminal: request.GetTerminal(), @@ -157,8 +195,8 @@ func (g *grpcBridge) Exec(ctx context.Context, request *v2.ExecProcessRequest) ( }) } -func (g *grpcBridge) ResizePty(ctx context.Context, request *v2.ResizePtyRequest) (*emptypb.Empty, error) { - return g.client.ResizePty(ctx, &v3.ResizePtyRequest{ +func (b *ttrpcV2Bridge) ResizePty(ctx context.Context, request *api.ResizePtyRequest) (*emptypb.Empty, error) { + return b.client.ResizePty(ctx, &v2.ResizePtyRequest{ ID: request.GetID(), ExecID: request.GetExecID(), Width: request.GetWidth(), @@ -166,52 +204,128 @@ func (g *grpcBridge) ResizePty(ctx context.Context, request *v2.ResizePtyRequest }) } -func (g *grpcBridge) CloseIO(ctx context.Context, request *v2.CloseIORequest) (*emptypb.Empty, error) { - return g.client.CloseIO(ctx, &v3.CloseIORequest{ +func (b *ttrpcV2Bridge) CloseIO(ctx context.Context, request *api.CloseIORequest) (*emptypb.Empty, error) { + return b.client.CloseIO(ctx, &v2.CloseIORequest{ ID: request.GetID(), ExecID: request.GetExecID(), Stdin: request.GetStdin(), }) } -func (g 
*grpcBridge) Update(ctx context.Context, request *v2.UpdateTaskRequest) (*emptypb.Empty, error) { - return g.client.Update(ctx, &v3.UpdateTaskRequest{ +func (b *ttrpcV2Bridge) Update(ctx context.Context, request *api.UpdateTaskRequest) (*emptypb.Empty, error) { + return b.client.Update(ctx, &v2.UpdateTaskRequest{ ID: request.GetID(), Resources: request.GetResources(), Annotations: request.GetAnnotations(), }) } -func (g *grpcBridge) Wait(ctx context.Context, request *v2.WaitRequest) (*v2.WaitResponse, error) { - resp, err := g.client.Wait(ctx, &v3.WaitRequest{ +func (b *ttrpcV2Bridge) Wait(ctx context.Context, request *api.WaitRequest) (*api.WaitResponse, error) { + resp, err := b.client.Wait(ctx, &v2.WaitRequest{ ID: request.GetID(), ExecID: request.GetExecID(), }) - return &v2.WaitResponse{ + return &api.WaitResponse{ ExitStatus: resp.GetExitStatus(), ExitedAt: resp.GetExitedAt(), }, err } -func (g *grpcBridge) Stats(ctx context.Context, request *v2.StatsRequest) (*v2.StatsResponse, error) { - resp, err := g.client.Stats(ctx, &v3.StatsRequest{ID: request.GetID()}) - return &v2.StatsResponse{Stats: resp.GetStats()}, err +func (b *ttrpcV2Bridge) Stats(ctx context.Context, request *api.StatsRequest) (*api.StatsResponse, error) { + resp, err := b.client.Stats(ctx, &v2.StatsRequest{ID: request.GetID()}) + return &api.StatsResponse{Stats: resp.GetStats()}, err } -func (g *grpcBridge) Connect(ctx context.Context, request *v2.ConnectRequest) (*v2.ConnectResponse, error) { - resp, err := g.client.Connect(ctx, &v3.ConnectRequest{ID: request.GetID()}) +func (b *ttrpcV2Bridge) Connect(ctx context.Context, request *api.ConnectRequest) (*api.ConnectResponse, error) { + resp, err := b.client.Connect(ctx, &v2.ConnectRequest{ID: request.GetID()}) - return &v2.ConnectResponse{ + return &api.ConnectResponse{ ShimPid: resp.GetShimPid(), TaskPid: resp.GetTaskPid(), Version: resp.GetVersion(), }, err } -func (g *grpcBridge) Shutdown(ctx context.Context, request *v2.ShutdownRequest) 
(*emptypb.Empty, error) { - return g.client.Shutdown(ctx, &v3.ShutdownRequest{ +func (b *ttrpcV2Bridge) Shutdown(ctx context.Context, request *api.ShutdownRequest) (*emptypb.Empty, error) { + return b.client.Shutdown(ctx, &v2.ShutdownRequest{ ID: request.GetID(), Now: request.GetNow(), }) } + +// grpcV3Bridge implements task service client for v3 GRPC server. +// GRPC uses same request/response structures as TTRPC, so it just wraps GRPC calls. +type grpcV3Bridge struct { + client v3.TaskClient +} + +var _ TaskServiceClient = (*grpcV3Bridge)(nil) + +func (g *grpcV3Bridge) State(ctx context.Context, request *api.StateRequest) (*api.StateResponse, error) { + return g.client.State(ctx, request) +} + +func (g *grpcV3Bridge) Create(ctx context.Context, request *api.CreateTaskRequest) (*api.CreateTaskResponse, error) { + return g.client.Create(ctx, request) +} + +func (g *grpcV3Bridge) Start(ctx context.Context, request *api.StartRequest) (*api.StartResponse, error) { + return g.client.Start(ctx, request) +} + +func (g *grpcV3Bridge) Delete(ctx context.Context, request *api.DeleteRequest) (*api.DeleteResponse, error) { + return g.client.Delete(ctx, request) +} + +func (g *grpcV3Bridge) Pids(ctx context.Context, request *api.PidsRequest) (*api.PidsResponse, error) { + return g.client.Pids(ctx, request) +} + +func (g *grpcV3Bridge) Pause(ctx context.Context, request *api.PauseRequest) (*emptypb.Empty, error) { + return g.client.Pause(ctx, request) +} + +func (g *grpcV3Bridge) Resume(ctx context.Context, request *api.ResumeRequest) (*emptypb.Empty, error) { + return g.client.Resume(ctx, request) +} + +func (g *grpcV3Bridge) Checkpoint(ctx context.Context, request *api.CheckpointTaskRequest) (*emptypb.Empty, error) { + return g.client.Checkpoint(ctx, request) +} + +func (g *grpcV3Bridge) Kill(ctx context.Context, request *api.KillRequest) (*emptypb.Empty, error) { + return g.client.Kill(ctx, request) +} + +func (g *grpcV3Bridge) Exec(ctx context.Context, request 
*api.ExecProcessRequest) (*emptypb.Empty, error) { + return g.client.Exec(ctx, request) +} + +func (g *grpcV3Bridge) ResizePty(ctx context.Context, request *api.ResizePtyRequest) (*emptypb.Empty, error) { + return g.client.ResizePty(ctx, request) +} + +func (g *grpcV3Bridge) CloseIO(ctx context.Context, request *api.CloseIORequest) (*emptypb.Empty, error) { + return g.client.CloseIO(ctx, request) +} + +func (g *grpcV3Bridge) Update(ctx context.Context, request *api.UpdateTaskRequest) (*emptypb.Empty, error) { + return g.client.Update(ctx, request) +} + +func (g *grpcV3Bridge) Wait(ctx context.Context, request *api.WaitRequest) (*api.WaitResponse, error) { + return g.client.Wait(ctx, request) +} + +func (g *grpcV3Bridge) Stats(ctx context.Context, request *api.StatsRequest) (*api.StatsResponse, error) { + return g.client.Stats(ctx, request) +} + +func (g *grpcV3Bridge) Connect(ctx context.Context, request *api.ConnectRequest) (*api.ConnectResponse, error) { + return g.client.Connect(ctx, request) +} + +func (g *grpcV3Bridge) Shutdown(ctx context.Context, request *api.ShutdownRequest) (*emptypb.Empty, error) { + return g.client.Shutdown(ctx, request) +} diff --git a/runtime/v2/example/example.go b/runtime/v2/example/example.go index 68ad833f4..8633aa17e 100644 --- a/runtime/v2/example/example.go +++ b/runtime/v2/example/example.go @@ -65,8 +65,8 @@ func (m manager) Name() string { return m.name } -func (m manager) Start(ctx context.Context, id string, opts shim.StartOpts) (string, error) { - return "", errdefs.ErrNotImplemented +func (m manager) Start(ctx context.Context, id string, opts shim.StartOpts) (shim.BootstrapParams, error) { + return shim.BootstrapParams{}, errdefs.ErrNotImplemented } func (m manager) Stop(ctx context.Context, id string) (shim.StopStatus, error) { diff --git a/runtime/v2/manager.go b/runtime/v2/manager.go index 33803c3d7..923489eed 100644 --- a/runtime/v2/manager.go +++ b/runtime/v2/manager.go @@ -18,6 +18,7 @@ package v2 import ( "context" 
+ "errors" "fmt" "os" "os/exec" @@ -205,14 +206,13 @@ func (m *ShimManager) Start(ctx context.Context, id string, opts runtime.CreateO return nil, err } - address, err := shimbinary.ReadAddress(filepath.Join(m.state, process.Namespace(), opts.SandboxID, "address")) + params, err := restoreBootstrapParams(filepath.Join(m.state, process.Namespace(), opts.SandboxID)) if err != nil { - return nil, fmt.Errorf("failed to get socket address for sandbox %q: %w", opts.SandboxID, err) + return nil, err } - // Use sandbox's socket address to handle task requests for this container. - if err := shimbinary.WriteAddress(filepath.Join(bundle.Path, "address"), address); err != nil { - return nil, err + if err := writeBootstrapParams(filepath.Join(bundle.Path, "bootstrap.json"), params); err != nil { + return nil, fmt.Errorf("failed to write bootstrap.json for bundle %s: %w", bundle.Path, err) } shim, err := loadShim(ctx, bundle, func() {}) @@ -284,6 +284,39 @@ func (m *ShimManager) startShim(ctx context.Context, bundle *Bundle, id string, return shim, nil } +// restoreBootstrapParams reads bootstrap.json to restore shim configuration. +// If it's an old shim, this will perform migration - read address file and write default bootstrap +// configuration (version = 2, protocol = ttrpc, and address). +func restoreBootstrapParams(bundlePath string) (shimbinary.BootstrapParams, error) { + filePath := filepath.Join(bundlePath, "bootstrap.json") + + // Read bootstrap.json if it exists + if _, err := os.Stat(filePath); err == nil { + return readBootstrapParams(filePath) + } else if !errors.Is(err, os.ErrNotExist) { + return shimbinary.BootstrapParams{}, fmt.Errorf("failed to stat %s: %w", filePath, err) + } + + // File not found, likely it's an older shim. Try to migrate. 
+ + address, err := shimbinary.ReadAddress(filepath.Join(bundlePath, "address")) + if err != nil { + return shimbinary.BootstrapParams{}, fmt.Errorf("unable to migrate shim: failed to get socket address for bundle %s: %w", bundlePath, err) + } + + params := shimbinary.BootstrapParams{ + Version: 2, + Address: address, + Protocol: "ttrpc", + } + + if err := writeBootstrapParams(filePath, params); err != nil { + return shimbinary.BootstrapParams{}, fmt.Errorf("unable to migrate: failed to write bootstrap.json file: %w", err) + } + + return params, nil +} + func (m *ShimManager) resolveRuntimePath(runtime string) (string, error) { if runtime == "" { return "", fmt.Errorf("no runtime name") diff --git a/runtime/v2/process.go b/runtime/v2/process.go index e2c9a5c0d..83e18151f 100644 --- a/runtime/v2/process.go +++ b/runtime/v2/process.go @@ -20,7 +20,7 @@ import ( "context" "errors" - "github.com/containerd/containerd/api/runtime/task/v2" + task "github.com/containerd/containerd/api/runtime/task/v3" tasktypes "github.com/containerd/containerd/api/types/task" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/protobuf" diff --git a/runtime/v2/runc/container.go b/runtime/v2/runc/container.go index a401c9450..2f6ebbbd2 100644 --- a/runtime/v2/runc/container.go +++ b/runtime/v2/runc/container.go @@ -30,7 +30,7 @@ import ( "github.com/containerd/cgroups/v3/cgroup1" cgroupsv2 "github.com/containerd/cgroups/v3/cgroup2" "github.com/containerd/console" - "github.com/containerd/containerd/api/runtime/task/v2" + "github.com/containerd/containerd/api/runtime/task/v3" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/mount" "github.com/containerd/containerd/namespaces" diff --git a/runtime/v2/runc/manager/manager_linux.go b/runtime/v2/runc/manager/manager_linux.go index 67d20c08a..eefc32822 100644 --- a/runtime/v2/runc/manager/manager_linux.go +++ b/runtime/v2/runc/manager/manager_linux.go @@ -118,15 +118,19 @@ func (m manager) 
Name() string { return m.name } -func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ string, retErr error) { +func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ shim.BootstrapParams, retErr error) { + var params shim.BootstrapParams + params.Version = 3 + params.Protocol = "ttrpc" + cmd, err := newCommand(ctx, id, opts.Address, opts.TTRPCAddress, opts.Debug) if err != nil { - return "", err + return params, err } grouping := id spec, err := readSpec() if err != nil { - return "", err + return params, err } for _, group := range groupLabels { if groupID, ok := spec.Annotations[group]; ok { @@ -136,7 +140,7 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str } address, err := shim.SocketAddress(ctx, opts.Address, grouping) if err != nil { - return "", err + return params, err } socket, err := shim.NewSocket(address) @@ -146,19 +150,17 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str // grouping functionality where the new process should be run with the same // shim as an existing container if !shim.SocketEaddrinuse(err) { - return "", fmt.Errorf("create new shim socket: %w", err) + return params, fmt.Errorf("create new shim socket: %w", err) } if shim.CanConnect(address) { - if err := shim.WriteAddress("address", address); err != nil { - return "", fmt.Errorf("write existing socket for shim: %w", err) - } - return address, nil + params.Address = address + return params, nil } if err := shim.RemoveSocket(address); err != nil { - return "", fmt.Errorf("remove pre-existing socket: %w", err) + return params, fmt.Errorf("remove pre-existing socket: %w", err) } if socket, err = shim.NewSocket(address); err != nil { - return "", fmt.Errorf("try create new shim socket 2x: %w", err) + return params, fmt.Errorf("try create new shim socket 2x: %w", err) } } defer func() { @@ -168,14 +170,9 @@ func (manager) Start(ctx context.Context, id string, opts 
shim.StartOpts) (_ str } }() - // make sure that reexec shim-v2 binary use the value if need - if err := shim.WriteAddress("address", address); err != nil { - return "", err - } - f, err := socket.File() if err != nil { - return "", err + return params, err } cmd.ExtraFiles = append(cmd.ExtraFiles, f) @@ -183,13 +180,13 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str goruntime.LockOSThread() if os.Getenv("SCHED_CORE") != "" { if err := schedcore.Create(schedcore.ProcessGroup); err != nil { - return "", fmt.Errorf("enable sched core support: %w", err) + return params, fmt.Errorf("enable sched core support: %w", err) } } if err := cmd.Start(); err != nil { f.Close() - return "", err + return params, err } goruntime.UnlockOSThread() @@ -207,27 +204,29 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str if cgroups.Mode() == cgroups.Unified { cg, err := cgroupsv2.Load(opts.ShimCgroup) if err != nil { - return "", fmt.Errorf("failed to load cgroup %s: %w", opts.ShimCgroup, err) + return params, fmt.Errorf("failed to load cgroup %s: %w", opts.ShimCgroup, err) } if err := cg.AddProc(uint64(cmd.Process.Pid)); err != nil { - return "", fmt.Errorf("failed to join cgroup %s: %w", opts.ShimCgroup, err) + return params, fmt.Errorf("failed to join cgroup %s: %w", opts.ShimCgroup, err) } } else { cg, err := cgroup1.Load(cgroup1.StaticPath(opts.ShimCgroup)) if err != nil { - return "", fmt.Errorf("failed to load cgroup %s: %w", opts.ShimCgroup, err) + return params, fmt.Errorf("failed to load cgroup %s: %w", opts.ShimCgroup, err) } if err := cg.AddProc(uint64(cmd.Process.Pid)); err != nil { - return "", fmt.Errorf("failed to join cgroup %s: %w", opts.ShimCgroup, err) + return params, fmt.Errorf("failed to join cgroup %s: %w", opts.ShimCgroup, err) } } } } if err := shim.AdjustOOMScore(cmd.Process.Pid); err != nil { - return "", fmt.Errorf("failed to adjust OOM score for shim: %w", err) + return params, 
fmt.Errorf("failed to adjust OOM score for shim: %w", err) } - return address, nil + + params.Address = address + return params, nil } func (manager) Stop(ctx context.Context, id string) (shim.StopStatus, error) { diff --git a/runtime/v2/runc/task/service.go b/runtime/v2/runc/task/service.go index 4b97c7eb9..6233b1615 100644 --- a/runtime/v2/runc/task/service.go +++ b/runtime/v2/runc/task/service.go @@ -28,7 +28,7 @@ import ( "github.com/containerd/cgroups/v3/cgroup1" cgroupsv2 "github.com/containerd/cgroups/v3/cgroup2" eventstypes "github.com/containerd/containerd/api/events" - taskAPI "github.com/containerd/containerd/api/runtime/task/v2" + taskAPI "github.com/containerd/containerd/api/runtime/task/v3" "github.com/containerd/containerd/api/types/task" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/namespaces" @@ -58,7 +58,7 @@ var ( ) // NewTaskService creates a new instance of a task service -func NewTaskService(ctx context.Context, publisher shim.Publisher, sd shutdown.Service) (taskAPI.TaskService, error) { +func NewTaskService(ctx context.Context, publisher shim.Publisher, sd shutdown.Service) (taskAPI.TTRPCTaskService, error) { var ( ep oom.Watcher err error @@ -252,7 +252,7 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * } func (s *service) RegisterTTRPC(server *ttrpc.Server) error { - taskAPI.RegisterTaskService(server, s) + taskAPI.RegisterTTRPCTaskService(server, s) return nil } diff --git a/runtime/v2/shim.go b/runtime/v2/shim.go index 4b6172602..76ce6d061 100644 --- a/runtime/v2/shim.go +++ b/runtime/v2/shim.go @@ -27,13 +27,14 @@ import ( "strings" "time" + "github.com/containerd/containerd/pkg/atomicfile" "github.com/containerd/ttrpc" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" eventstypes "github.com/containerd/containerd/api/events" - "github.com/containerd/containerd/api/runtime/task/v2" + task 
"github.com/containerd/containerd/api/runtime/task/v3" "github.com/containerd/containerd/api/types" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/events/exchange" @@ -59,20 +60,7 @@ func init() { timeout.Set(shutdownTimeout, 3*time.Second) } -func loadAddress(path string) ([]byte, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - return data, nil -} - func loadShim(ctx context.Context, bundle *Bundle, onClose func()) (_ ShimInstance, retErr error) { - address, err := loadAddress(filepath.Join(bundle.Path, "address")) - if err != nil { - return nil, err - } - shimCtx, cancelShimLog := context.WithCancel(ctx) defer func() { if retErr != nil { @@ -108,14 +96,14 @@ func loadShim(ctx context.Context, bundle *Bundle, onClose func()) (_ ShimInstan f.Close() } - params, err := parseStartResponse(ctx, address) + params, err := restoreBootstrapParams(bundle.Path) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read boostrap.json when restoring bundle %q: %w", bundle.ID, err) } - conn, err := makeConnection(ctx, params, onCloseWithShimLog) + conn, err := makeConnection(ctx, bundle.ID, params, onCloseWithShimLog) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to make connection: %w", err) } defer func() { @@ -125,8 +113,9 @@ func loadShim(ctx context.Context, bundle *Bundle, onClose func()) (_ ShimInstan }() shim := &shim{ - bundle: bundle, - client: conn, + bundle: bundle, + client: conn, + version: params.Version, } return shim, nil @@ -177,6 +166,9 @@ func cleanupAfterDeadShim(ctx context.Context, id string, rt *runtime.NSMap[Shim }) } +// CurrentShimVersion is the latest shim version supported by containerd (e.g. TaskService v3). +const CurrentShimVersion = 3 + // ShimInstance represents running shim process managed by ShimManager. 
type ShimInstance interface { io.Closer @@ -192,31 +184,75 @@ type ShimInstance interface { Client() any // Delete will close the client and remove bundle from disk. Delete(ctx context.Context) error + // Version returns shim's features compatibility version. + Version() int } -func parseStartResponse(ctx context.Context, response []byte) (client.BootstrapParams, error) { +func parseStartResponse(response []byte) (client.BootstrapParams, error) { var params client.BootstrapParams if err := json.Unmarshal(response, &params); err != nil || params.Version < 2 { // Use TTRPC for legacy shims params.Address = string(response) params.Protocol = "ttrpc" + params.Version = 2 } - if params.Version > 2 { + if params.Version > CurrentShimVersion { return client.BootstrapParams{}, fmt.Errorf("unsupported shim version (%d): %w", params.Version, errdefs.ErrNotImplemented) } return params, nil } +// writeBootstrapParams writes shim's bootstrap configuration (e.g. how to connect, version, etc). +func writeBootstrapParams(path string, params client.BootstrapParams) error { + path, err := filepath.Abs(path) + if err != nil { + return err + } + + data, err := json.Marshal(&params) + if err != nil { + return err + } + + f, err := atomicfile.New(path, 0o666) + if err != nil { + return err + } + + _, err = f.Write(data) + if err != nil { + f.Cancel() + return err + } + + return f.Close() +} + +func readBootstrapParams(path string) (client.BootstrapParams, error) { + path, err := filepath.Abs(path) + if err != nil { + return client.BootstrapParams{}, err + } + + data, err := os.ReadFile(path) + if err != nil { + return client.BootstrapParams{}, err + } + + return parseStartResponse(data) +} + // makeConnection creates a new TTRPC or GRPC connection object from address. // address can be either a socket path for TTRPC or JSON serialized BootstrapParams.
-func makeConnection(ctx context.Context, params client.BootstrapParams, onClose func()) (_ io.Closer, retErr error) { +func makeConnection(ctx context.Context, id string, params client.BootstrapParams, onClose func()) (_ io.Closer, retErr error) { log.G(ctx).WithFields(log.Fields{ "address": params.Address, "protocol": params.Protocol, - }).Debug("shim bootstrap parameters") + "version": params.Version, + }).Infof("connecting to shim %s", id) switch strings.ToLower(params.Protocol) { case "ttrpc": @@ -303,8 +339,9 @@ func (gc *grpcConn) UserOnCloseWait(ctx context.Context) error { } type shim struct { - bundle *Bundle - client any + bundle *Bundle + client any + version int } var _ ShimInstance = (*shim)(nil) @@ -314,6 +351,10 @@ func (s *shim) ID() string { return s.bundle.ID } +func (s *shim) Version() int { + return s.version +} + func (s *shim) Namespace() string { return s.bundle.Namespace } @@ -375,11 +416,11 @@ var _ runtime.Task = &shimTask{} // shimTask wraps shim process and adds task service client for compatibility with existing shim manager. 
type shimTask struct { ShimInstance - task task.TaskService + task TaskServiceClient } func newShimTask(shim ShimInstance) (*shimTask, error) { - taskClient, err := NewTaskClient(shim.Client()) + taskClient, err := NewTaskClient(shim.Client(), shim.Version()) if err != nil { return nil, err } diff --git a/runtime/v2/shim/shim.go b/runtime/v2/shim/shim.go index 166f7d5c6..0a4e6ab75 100644 --- a/runtime/v2/shim/shim.go +++ b/runtime/v2/shim/shim.go @@ -18,6 +18,7 @@ package shim import ( "context" + "encoding/json" "errors" "flag" "fmt" @@ -29,7 +30,7 @@ import ( "runtime/debug" "time" - shimapi "github.com/containerd/containerd/api/runtime/task/v2" + shimapi "github.com/containerd/containerd/api/runtime/task/v3" "github.com/containerd/containerd/events" "github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/pkg/shutdown" @@ -76,7 +77,7 @@ type StopStatus struct { // Manager is the interface which manages the shim process type Manager interface { Name() string - Start(ctx context.Context, id string, opts StartOpts) (string, error) + Start(ctx context.Context, id string, opts StartOpts) (BootstrapParams, error) Stop(ctx context.Context, id string) (StopStatus, error) } @@ -268,13 +269,20 @@ func run(ctx context.Context, manager Manager, name string, config Config) error Debug: debugFlag, } - address, err := manager.Start(ctx, id, opts) + params, err := manager.Start(ctx, id, opts) if err != nil { return err } - if _, err := os.Stdout.WriteString(address); err != nil { + + data, err := json.Marshal(&params) + if err != nil { + return fmt.Errorf("failed to marshal bootstrap params to json: %w", err) + } + + if _, err := os.Stdout.Write(data); err != nil { return err } + return nil } diff --git a/runtime/v2/shim/util.go b/runtime/v2/shim/util.go index fce1318a6..3e7d19ec0 100644 --- a/runtime/v2/shim/util.go +++ b/runtime/v2/shim/util.go @@ -138,24 +138,6 @@ func WritePidFile(path string, pid int) error { return f.Close() } -// WriteAddress writes a
address file atomically -func WriteAddress(path, address string) error { - path, err := filepath.Abs(path) - if err != nil { - return err - } - f, err := atomicfile.New(path, 0o666) - if err != nil { - return err - } - _, err = f.Write([]byte(address)) - if err != nil { - f.Cancel() - return err - } - return f.Close() -} - // ErrNoAddress is returned when the address file has no content var ErrNoAddress = errors.New("no shim address") diff --git a/runtime/v2/shim_load.go b/runtime/v2/shim_load.go index 5556e6c4d..980dba264 100644 --- a/runtime/v2/shim_load.go +++ b/runtime/v2/shim_load.go @@ -150,6 +150,7 @@ func (m *ShimManager) loadShims(ctx context.Context) error { m.shims.Delete(ctx, id) }) if err != nil { + log.G(ctx).WithError(err).Errorf("unable to load shim %q", id) cleanupAfterDeadShim(ctx, id, m.shims, m.events, binaryCall) continue } diff --git a/runtime/v2/shim_test.go b/runtime/v2/shim_test.go index 13ca68634..05140fcbb 100644 --- a/runtime/v2/shim_test.go +++ b/runtime/v2/shim_test.go @@ -17,12 +17,14 @@ package v2 import ( - "context" "errors" + "os" + "path/filepath" "testing" "github.com/containerd/containerd/errdefs" client "github.com/containerd/containerd/runtime/v2/shim" + "github.com/stretchr/testify/require" ) func TestParseStartResponse(t *testing.T) { @@ -36,7 +38,7 @@ func TestParseStartResponse(t *testing.T) { Name: "v2 shim", Response: "/somedirectory/somesocket", Expected: client.BootstrapParams{ - Version: 0, + Version: 2, Address: "/somedirectory/somesocket", Protocol: "ttrpc", }, @@ -63,20 +65,20 @@ func TestParseStartResponse(t *testing.T) { Name: "invalid shim v2 response", Response: `{"address":"/somedirectory/somesocket","protocol":"ttrpc"}`, Expected: client.BootstrapParams{ - Version: 0, + Version: 2, Address: `{"address":"/somedirectory/somesocket","protocol":"ttrpc"}`, Protocol: "ttrpc", }, }, { Name: "later unsupported shim", - Response: `{"Version": 3,"Address":"/somedirectory/somesocket","Protocol":"ttrpc"}`, + Response: 
`{"Version": 4,"Address":"/somedirectory/somesocket","Protocol":"ttrpc"}`, Expected: client.BootstrapParams{}, Err: errdefs.ErrNotImplemented, }, } { t.Run(tc.Name, func(t *testing.T) { - params, err := parseStartResponse(context.Background(), []byte(tc.Response)) + params, err := parseStartResponse([]byte(tc.Response)) if err != nil { if !errors.Is(err, tc.Err) { t.Errorf("unexpected error: %v", err) @@ -96,5 +98,27 @@ func TestParseStartResponse(t *testing.T) { } }) } - +} + +func TestRestoreBootstrapParams(t *testing.T) { + bundlePath := t.TempDir() + + err := os.WriteFile(filepath.Join(bundlePath, "address"), []byte("unix://123"), 0o666) + require.NoError(t, err) + + restored, err := restoreBootstrapParams(bundlePath) + require.NoError(t, err) + + expected := client.BootstrapParams{ + Version: 2, + Address: "unix://123", + Protocol: "ttrpc", + } + + require.EqualValues(t, expected, restored) + + loaded, err := readBootstrapParams(filepath.Join(bundlePath, "bootstrap.json")) + + require.NoError(t, err) + require.EqualValues(t, expected, loaded) }