Merge pull request #9233 from mxpv/tasks

Switch runc shim to task service v3 and fix restore
This commit is contained in:
Derek McGowan 2023-10-20 17:26:31 +00:00 committed by GitHub
commit e973109c2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 366 additions and 158 deletions

View File

@ -23,7 +23,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
taskapi "github.com/containerd/containerd/api/runtime/task/v2" taskapi "github.com/containerd/containerd/api/runtime/task/v3"
"github.com/containerd/containerd/oci" "github.com/containerd/containerd/oci"
"github.com/containerd/containerd/pkg/failpoint" "github.com/containerd/containerd/pkg/failpoint"
"github.com/containerd/containerd/pkg/shutdown" "github.com/containerd/containerd/pkg/shutdown"
@ -79,11 +79,11 @@ var (
type taskServiceWithFp struct { type taskServiceWithFp struct {
fps map[string]*failpoint.Failpoint fps map[string]*failpoint.Failpoint
local taskapi.TaskService local taskapi.TTRPCTaskService
} }
func (s *taskServiceWithFp) RegisterTTRPC(server *ttrpc.Server) error { func (s *taskServiceWithFp) RegisterTTRPC(server *ttrpc.Server) error {
taskapi.RegisterTaskService(server, s.local) taskapi.RegisterTTRPCTaskService(server, s.local)
return nil return nil
} }

View File

@ -28,7 +28,7 @@ import (
"testing" "testing"
"time" "time"
apitask "github.com/containerd/containerd/api/runtime/task/v2" apitask "github.com/containerd/containerd/api/runtime/task/v3"
"github.com/containerd/containerd/integration/images" "github.com/containerd/containerd/integration/images"
"github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/namespaces"
"github.com/containerd/containerd/runtime/v2/shim" "github.com/containerd/containerd/runtime/v2/shim"
@ -111,7 +111,7 @@ func TestIssue7496(t *testing.T) {
// example, umount overlayfs rootfs which doesn't with volatile. // example, umount overlayfs rootfs which doesn't with volatile.
// //
// REF: https://man7.org/linux/man-pages/man1/strace.1.html // REF: https://man7.org/linux/man-pages/man1/strace.1.html
func injectDelayToUmount2(ctx context.Context, t *testing.T, shimCli apitask.TaskService, delayInSec int) chan struct{} { func injectDelayToUmount2(ctx context.Context, t *testing.T, shimCli apitask.TTRPCTaskService, delayInSec int) chan struct{} {
pid := shimPid(ctx, t, shimCli) pid := shimPid(ctx, t, shimCli)
doneCh := make(chan struct{}) doneCh := make(chan struct{})
@ -153,7 +153,7 @@ func injectDelayToUmount2(ctx context.Context, t *testing.T, shimCli apitask.Tas
return doneCh return doneCh
} }
func connectToShim(ctx context.Context, t *testing.T, id string) apitask.TaskService { func connectToShim(ctx context.Context, t *testing.T, id string) apitask.TTRPCTaskService {
addr, err := shim.SocketAddress(ctx, containerdEndpoint, id) addr, err := shim.SocketAddress(ctx, containerdEndpoint, id)
require.NoError(t, err) require.NoError(t, err)
addr = strings.TrimPrefix(addr, "unix://") addr = strings.TrimPrefix(addr, "unix://")
@ -162,10 +162,10 @@ func connectToShim(ctx context.Context, t *testing.T, id string) apitask.TaskSer
require.NoError(t, err) require.NoError(t, err)
client := ttrpc.NewClient(conn) client := ttrpc.NewClient(conn)
return apitask.NewTaskClient(client) return apitask.NewTTRPCTaskClient(client)
} }
func shimPid(ctx context.Context, t *testing.T, shimCli apitask.TaskService) uint32 { func shimPid(ctx context.Context, t *testing.T, shimCli apitask.TTRPCTaskService) uint32 {
resp, err := shimCli.Connect(ctx, &apitask.ConnectRequest{}) resp, err := shimCli.Connect(ctx, &apitask.ConnectRequest{})
require.NoError(t, err) require.NoError(t, err)
return resp.GetShimPid() return resp.GetShimPid()

View File

@ -22,7 +22,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
apitask "github.com/containerd/containerd/api/runtime/task/v2" apitask "github.com/containerd/containerd/api/runtime/task/v3"
"github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/namespaces"
) )

View File

@ -129,19 +129,25 @@ func (b *binary) Start(ctx context.Context, opts *types.Any, onClose func()) (_
return nil, err return nil, err
} }
params, err := parseStartResponse(ctx, response) params, err := parseStartResponse(response)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := makeConnection(ctx, params, onCloseWithShimLog) conn, err := makeConnection(ctx, b.bundle.ID, params, onCloseWithShimLog)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Save bootstrap configuration (so containerd can restore shims after restart).
if err := writeBootstrapParams(filepath.Join(b.bundle.Path, "bootstrap.json"), params); err != nil {
return nil, fmt.Errorf("failed to write bootstrap.json: %w", err)
}
return &shim{ return &shim{
bundle: b.bundle, bundle: b.bundle,
client: conn, client: conn,
version: params.Version,
}, nil }, nil
} }

View File

@ -20,14 +20,38 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/containerd/ttrpc"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
v2 "github.com/containerd/containerd/api/runtime/task/v2" v2 "github.com/containerd/containerd/api/runtime/task/v2"
v3 "github.com/containerd/containerd/api/runtime/task/v3" v3 "github.com/containerd/containerd/api/runtime/task/v3"
"github.com/containerd/ttrpc"
api "github.com/containerd/containerd/api/runtime/task/v3" // Current version used by TaskServiceClient
) )
// TaskServiceClient exposes a client interface to shims, which aims to hide
// the underlying complexity and backward compatibility (v2 task service vs v3, TTRPC vs GRPC, etc).
type TaskServiceClient interface {
State(context.Context, *api.StateRequest) (*api.StateResponse, error)
Create(context.Context, *api.CreateTaskRequest) (*api.CreateTaskResponse, error)
Start(context.Context, *api.StartRequest) (*api.StartResponse, error)
Delete(context.Context, *api.DeleteRequest) (*api.DeleteResponse, error)
Pids(context.Context, *api.PidsRequest) (*api.PidsResponse, error)
Pause(context.Context, *api.PauseRequest) (*emptypb.Empty, error)
Resume(context.Context, *api.ResumeRequest) (*emptypb.Empty, error)
Checkpoint(context.Context, *api.CheckpointTaskRequest) (*emptypb.Empty, error)
Kill(context.Context, *api.KillRequest) (*emptypb.Empty, error)
Exec(context.Context, *api.ExecProcessRequest) (*emptypb.Empty, error)
ResizePty(context.Context, *api.ResizePtyRequest) (*emptypb.Empty, error)
CloseIO(context.Context, *api.CloseIORequest) (*emptypb.Empty, error)
Update(context.Context, *api.UpdateTaskRequest) (*emptypb.Empty, error)
Wait(context.Context, *api.WaitRequest) (*api.WaitResponse, error)
Stats(context.Context, *api.StatsRequest) (*api.StatsResponse, error)
Connect(context.Context, *api.ConnectRequest) (*api.ConnectResponse, error)
Shutdown(context.Context, *api.ShutdownRequest) (*emptypb.Empty, error)
}
// NewTaskClient returns a new task client interface which handles both GRPC and TTRPC servers depending on the // NewTaskClient returns a new task client interface which handles both GRPC and TTRPC servers depending on the
// client object type passed in. // client object type passed in.
// //
@ -35,33 +59,47 @@ import (
// - *ttrpc.Client // - *ttrpc.Client
// - grpc.ClientConnInterface // - grpc.ClientConnInterface
// //
// In 1.7 we support TaskService v2 (for backward compatibility with existing shims) and GRPC TaskService v3. // Currently supported servers:
// In 2.0 we'll switch to TaskService v3 only for both TTRPC and GRPC, which will remove overhead of mapping v2 structs to v3 structs. // - TTRPC v2 (compatibility with shims before 2.0)
func NewTaskClient(client interface{}) (v2.TaskService, error) { // - TTRPC v3
// - GRPC v3
func NewTaskClient(client interface{}, version int) (TaskServiceClient, error) {
switch c := client.(type) { switch c := client.(type) {
case *ttrpc.Client: case *ttrpc.Client:
return v2.NewTaskClient(c), nil switch version {
case 2:
return &ttrpcV2Bridge{client: v2.NewTaskClient(c)}, nil
case 3:
return v3.NewTTRPCTaskClient(c), nil
default:
return nil, fmt.Errorf("containerd client supports only v2 and v3 TTRPC task client (got %d)", version)
}
case grpc.ClientConnInterface: case grpc.ClientConnInterface:
return &grpcBridge{v3.NewTaskClient(c)}, nil if version != 3 {
return nil, fmt.Errorf("containerd client supports only v3 GRPC task service (got %d)", version)
}
return &grpcV3Bridge{v3.NewTaskClient(c)}, nil
default: default:
return nil, fmt.Errorf("unsupported shim client type %T", c) return nil, fmt.Errorf("unsupported shim client type %T", c)
} }
} }
// grpcBridge implements `v2.TaskService` interface for GRPC shim server. // ttrpcV2Bridge is a bridge from TTRPC v2 task service.
type grpcBridge struct { type ttrpcV2Bridge struct {
client v3.TaskClient client v2.TaskService
} }
var _ v2.TaskService = (*grpcBridge)(nil) var _ TaskServiceClient = (*ttrpcV2Bridge)(nil)
func (g *grpcBridge) State(ctx context.Context, request *v2.StateRequest) (*v2.StateResponse, error) { func (b *ttrpcV2Bridge) State(ctx context.Context, request *api.StateRequest) (*api.StateResponse, error) {
resp, err := g.client.State(ctx, &v3.StateRequest{ resp, err := b.client.State(ctx, &v2.StateRequest{
ID: request.GetID(), ID: request.GetID(),
ExecID: request.GetExecID(), ExecID: request.GetExecID(),
}) })
return &v2.StateResponse{ return &v3.StateResponse{
ID: resp.GetID(), ID: resp.GetID(),
Bundle: resp.GetBundle(), Bundle: resp.GetBundle(),
Pid: resp.GetPid(), Pid: resp.GetPid(),
@ -76,8 +114,8 @@ func (g *grpcBridge) State(ctx context.Context, request *v2.StateRequest) (*v2.S
}, err }, err
} }
func (g *grpcBridge) Create(ctx context.Context, request *v2.CreateTaskRequest) (*v2.CreateTaskResponse, error) { func (b *ttrpcV2Bridge) Create(ctx context.Context, request *api.CreateTaskRequest) (*api.CreateTaskResponse, error) {
resp, err := g.client.Create(ctx, &v3.CreateTaskRequest{ resp, err := b.client.Create(ctx, &v2.CreateTaskRequest{
ID: request.GetID(), ID: request.GetID(),
Bundle: request.GetBundle(), Bundle: request.GetBundle(),
Rootfs: request.GetRootfs(), Rootfs: request.GetRootfs(),
@ -90,54 +128,54 @@ func (g *grpcBridge) Create(ctx context.Context, request *v2.CreateTaskRequest)
Options: request.GetOptions(), Options: request.GetOptions(),
}) })
return &v2.CreateTaskResponse{Pid: resp.GetPid()}, err return &api.CreateTaskResponse{Pid: resp.GetPid()}, err
} }
func (g *grpcBridge) Start(ctx context.Context, request *v2.StartRequest) (*v2.StartResponse, error) { func (b *ttrpcV2Bridge) Start(ctx context.Context, request *api.StartRequest) (*api.StartResponse, error) {
resp, err := g.client.Start(ctx, &v3.StartRequest{ resp, err := b.client.Start(ctx, &v2.StartRequest{
ID: request.GetID(), ID: request.GetID(),
ExecID: request.GetExecID(), ExecID: request.GetExecID(),
}) })
return &v2.StartResponse{Pid: resp.GetPid()}, err return &api.StartResponse{Pid: resp.GetPid()}, err
} }
func (g *grpcBridge) Delete(ctx context.Context, request *v2.DeleteRequest) (*v2.DeleteResponse, error) { func (b *ttrpcV2Bridge) Delete(ctx context.Context, request *api.DeleteRequest) (*api.DeleteResponse, error) {
resp, err := g.client.Delete(ctx, &v3.DeleteRequest{ resp, err := b.client.Delete(ctx, &v2.DeleteRequest{
ID: request.GetID(), ID: request.GetID(),
ExecID: request.GetExecID(), ExecID: request.GetExecID(),
}) })
return &v2.DeleteResponse{ return &api.DeleteResponse{
Pid: resp.GetPid(), Pid: resp.GetPid(),
ExitStatus: resp.GetExitStatus(), ExitStatus: resp.GetExitStatus(),
ExitedAt: resp.GetExitedAt(), ExitedAt: resp.GetExitedAt(),
}, err }, err
} }
func (g *grpcBridge) Pids(ctx context.Context, request *v2.PidsRequest) (*v2.PidsResponse, error) { func (b *ttrpcV2Bridge) Pids(ctx context.Context, request *api.PidsRequest) (*api.PidsResponse, error) {
resp, err := g.client.Pids(ctx, &v3.PidsRequest{ID: request.GetID()}) resp, err := b.client.Pids(ctx, &v2.PidsRequest{ID: request.GetID()})
return &v2.PidsResponse{Processes: resp.GetProcesses()}, err return &api.PidsResponse{Processes: resp.GetProcesses()}, err
} }
func (g *grpcBridge) Pause(ctx context.Context, request *v2.PauseRequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) Pause(ctx context.Context, request *api.PauseRequest) (*emptypb.Empty, error) {
return g.client.Pause(ctx, &v3.PauseRequest{ID: request.GetID()}) return b.client.Pause(ctx, &v2.PauseRequest{ID: request.GetID()})
} }
func (g *grpcBridge) Resume(ctx context.Context, request *v2.ResumeRequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) Resume(ctx context.Context, request *api.ResumeRequest) (*emptypb.Empty, error) {
return g.client.Resume(ctx, &v3.ResumeRequest{ID: request.GetID()}) return b.client.Resume(ctx, &v2.ResumeRequest{ID: request.GetID()})
} }
func (g *grpcBridge) Checkpoint(ctx context.Context, request *v2.CheckpointTaskRequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) Checkpoint(ctx context.Context, request *api.CheckpointTaskRequest) (*emptypb.Empty, error) {
return g.client.Checkpoint(ctx, &v3.CheckpointTaskRequest{ return b.client.Checkpoint(ctx, &v2.CheckpointTaskRequest{
ID: request.GetID(), ID: request.GetID(),
Path: request.GetPath(), Path: request.GetPath(),
Options: request.GetOptions(), Options: request.GetOptions(),
}) })
} }
func (g *grpcBridge) Kill(ctx context.Context, request *v2.KillRequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) Kill(ctx context.Context, request *api.KillRequest) (*emptypb.Empty, error) {
return g.client.Kill(ctx, &v3.KillRequest{ return b.client.Kill(ctx, &v2.KillRequest{
ID: request.GetID(), ID: request.GetID(),
ExecID: request.GetExecID(), ExecID: request.GetExecID(),
Signal: request.GetSignal(), Signal: request.GetSignal(),
@ -145,8 +183,8 @@ func (g *grpcBridge) Kill(ctx context.Context, request *v2.KillRequest) (*emptyp
}) })
} }
func (g *grpcBridge) Exec(ctx context.Context, request *v2.ExecProcessRequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) Exec(ctx context.Context, request *api.ExecProcessRequest) (*emptypb.Empty, error) {
return g.client.Exec(ctx, &v3.ExecProcessRequest{ return b.client.Exec(ctx, &v2.ExecProcessRequest{
ID: request.GetID(), ID: request.GetID(),
ExecID: request.GetExecID(), ExecID: request.GetExecID(),
Terminal: request.GetTerminal(), Terminal: request.GetTerminal(),
@ -157,8 +195,8 @@ func (g *grpcBridge) Exec(ctx context.Context, request *v2.ExecProcessRequest) (
}) })
} }
func (g *grpcBridge) ResizePty(ctx context.Context, request *v2.ResizePtyRequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) ResizePty(ctx context.Context, request *api.ResizePtyRequest) (*emptypb.Empty, error) {
return g.client.ResizePty(ctx, &v3.ResizePtyRequest{ return b.client.ResizePty(ctx, &v2.ResizePtyRequest{
ID: request.GetID(), ID: request.GetID(),
ExecID: request.GetExecID(), ExecID: request.GetExecID(),
Width: request.GetWidth(), Width: request.GetWidth(),
@ -166,52 +204,128 @@ func (g *grpcBridge) ResizePty(ctx context.Context, request *v2.ResizePtyRequest
}) })
} }
func (g *grpcBridge) CloseIO(ctx context.Context, request *v2.CloseIORequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) CloseIO(ctx context.Context, request *api.CloseIORequest) (*emptypb.Empty, error) {
return g.client.CloseIO(ctx, &v3.CloseIORequest{ return b.client.CloseIO(ctx, &v2.CloseIORequest{
ID: request.GetID(), ID: request.GetID(),
ExecID: request.GetExecID(), ExecID: request.GetExecID(),
Stdin: request.GetStdin(), Stdin: request.GetStdin(),
}) })
} }
func (g *grpcBridge) Update(ctx context.Context, request *v2.UpdateTaskRequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) Update(ctx context.Context, request *api.UpdateTaskRequest) (*emptypb.Empty, error) {
return g.client.Update(ctx, &v3.UpdateTaskRequest{ return b.client.Update(ctx, &v2.UpdateTaskRequest{
ID: request.GetID(), ID: request.GetID(),
Resources: request.GetResources(), Resources: request.GetResources(),
Annotations: request.GetAnnotations(), Annotations: request.GetAnnotations(),
}) })
} }
func (g *grpcBridge) Wait(ctx context.Context, request *v2.WaitRequest) (*v2.WaitResponse, error) { func (b *ttrpcV2Bridge) Wait(ctx context.Context, request *api.WaitRequest) (*api.WaitResponse, error) {
resp, err := g.client.Wait(ctx, &v3.WaitRequest{ resp, err := b.client.Wait(ctx, &v2.WaitRequest{
ID: request.GetID(), ID: request.GetID(),
ExecID: request.GetExecID(), ExecID: request.GetExecID(),
}) })
return &v2.WaitResponse{ return &api.WaitResponse{
ExitStatus: resp.GetExitStatus(), ExitStatus: resp.GetExitStatus(),
ExitedAt: resp.GetExitedAt(), ExitedAt: resp.GetExitedAt(),
}, err }, err
} }
func (g *grpcBridge) Stats(ctx context.Context, request *v2.StatsRequest) (*v2.StatsResponse, error) { func (b *ttrpcV2Bridge) Stats(ctx context.Context, request *api.StatsRequest) (*api.StatsResponse, error) {
resp, err := g.client.Stats(ctx, &v3.StatsRequest{ID: request.GetID()}) resp, err := b.client.Stats(ctx, &v2.StatsRequest{ID: request.GetID()})
return &v2.StatsResponse{Stats: resp.GetStats()}, err return &api.StatsResponse{Stats: resp.GetStats()}, err
} }
func (g *grpcBridge) Connect(ctx context.Context, request *v2.ConnectRequest) (*v2.ConnectResponse, error) { func (b *ttrpcV2Bridge) Connect(ctx context.Context, request *api.ConnectRequest) (*api.ConnectResponse, error) {
resp, err := g.client.Connect(ctx, &v3.ConnectRequest{ID: request.GetID()}) resp, err := b.client.Connect(ctx, &v2.ConnectRequest{ID: request.GetID()})
return &v2.ConnectResponse{ return &api.ConnectResponse{
ShimPid: resp.GetShimPid(), ShimPid: resp.GetShimPid(),
TaskPid: resp.GetTaskPid(), TaskPid: resp.GetTaskPid(),
Version: resp.GetVersion(), Version: resp.GetVersion(),
}, err }, err
} }
func (g *grpcBridge) Shutdown(ctx context.Context, request *v2.ShutdownRequest) (*emptypb.Empty, error) { func (b *ttrpcV2Bridge) Shutdown(ctx context.Context, request *api.ShutdownRequest) (*emptypb.Empty, error) {
return g.client.Shutdown(ctx, &v3.ShutdownRequest{ return b.client.Shutdown(ctx, &v2.ShutdownRequest{
ID: request.GetID(), ID: request.GetID(),
Now: request.GetNow(), Now: request.GetNow(),
}) })
} }
// grpcV3Bridge implements task service client for v3 GRPC server.
// GRPC uses same request/response structures as TTRPC, so it just wraps GRPC calls.
type grpcV3Bridge struct {
client v3.TaskClient
}
var _ TaskServiceClient = (*grpcV3Bridge)(nil)
func (g *grpcV3Bridge) State(ctx context.Context, request *api.StateRequest) (*api.StateResponse, error) {
return g.client.State(ctx, request)
}
func (g *grpcV3Bridge) Create(ctx context.Context, request *api.CreateTaskRequest) (*api.CreateTaskResponse, error) {
return g.client.Create(ctx, request)
}
func (g *grpcV3Bridge) Start(ctx context.Context, request *api.StartRequest) (*api.StartResponse, error) {
return g.client.Start(ctx, request)
}
func (g *grpcV3Bridge) Delete(ctx context.Context, request *api.DeleteRequest) (*api.DeleteResponse, error) {
return g.client.Delete(ctx, request)
}
func (g *grpcV3Bridge) Pids(ctx context.Context, request *api.PidsRequest) (*api.PidsResponse, error) {
return g.client.Pids(ctx, request)
}
func (g *grpcV3Bridge) Pause(ctx context.Context, request *api.PauseRequest) (*emptypb.Empty, error) {
return g.client.Pause(ctx, request)
}
func (g *grpcV3Bridge) Resume(ctx context.Context, request *api.ResumeRequest) (*emptypb.Empty, error) {
return g.client.Resume(ctx, request)
}
func (g *grpcV3Bridge) Checkpoint(ctx context.Context, request *api.CheckpointTaskRequest) (*emptypb.Empty, error) {
return g.client.Checkpoint(ctx, request)
}
func (g *grpcV3Bridge) Kill(ctx context.Context, request *api.KillRequest) (*emptypb.Empty, error) {
return g.client.Kill(ctx, request)
}
func (g *grpcV3Bridge) Exec(ctx context.Context, request *api.ExecProcessRequest) (*emptypb.Empty, error) {
return g.client.Exec(ctx, request)
}
func (g *grpcV3Bridge) ResizePty(ctx context.Context, request *api.ResizePtyRequest) (*emptypb.Empty, error) {
return g.client.ResizePty(ctx, request)
}
func (g *grpcV3Bridge) CloseIO(ctx context.Context, request *api.CloseIORequest) (*emptypb.Empty, error) {
return g.client.CloseIO(ctx, request)
}
func (g *grpcV3Bridge) Update(ctx context.Context, request *api.UpdateTaskRequest) (*emptypb.Empty, error) {
return g.client.Update(ctx, request)
}
func (g *grpcV3Bridge) Wait(ctx context.Context, request *api.WaitRequest) (*api.WaitResponse, error) {
return g.client.Wait(ctx, request)
}
func (g *grpcV3Bridge) Stats(ctx context.Context, request *api.StatsRequest) (*api.StatsResponse, error) {
return g.client.Stats(ctx, request)
}
func (g *grpcV3Bridge) Connect(ctx context.Context, request *api.ConnectRequest) (*api.ConnectResponse, error) {
return g.client.Connect(ctx, request)
}
func (g *grpcV3Bridge) Shutdown(ctx context.Context, request *api.ShutdownRequest) (*emptypb.Empty, error) {
return g.client.Shutdown(ctx, request)
}

View File

@ -65,8 +65,8 @@ func (m manager) Name() string {
return m.name return m.name
} }
func (m manager) Start(ctx context.Context, id string, opts shim.StartOpts) (string, error) { func (m manager) Start(ctx context.Context, id string, opts shim.StartOpts) (shim.BootstrapParams, error) {
return "", errdefs.ErrNotImplemented return shim.BootstrapParams{}, errdefs.ErrNotImplemented
} }
func (m manager) Stop(ctx context.Context, id string) (shim.StopStatus, error) { func (m manager) Stop(ctx context.Context, id string) (shim.StopStatus, error) {

View File

@ -18,6 +18,7 @@ package v2
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
@ -205,14 +206,13 @@ func (m *ShimManager) Start(ctx context.Context, id string, opts runtime.CreateO
return nil, err return nil, err
} }
address, err := shimbinary.ReadAddress(filepath.Join(m.state, process.Namespace(), opts.SandboxID, "address")) params, err := restoreBootstrapParams(filepath.Join(m.state, process.Namespace(), opts.SandboxID))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get socket address for sandbox %q: %w", opts.SandboxID, err) return nil, err
} }
// Use sandbox's socket address to handle task requests for this container. if err := writeBootstrapParams(filepath.Join(bundle.Path, "bootstrap.json"), params); err != nil {
if err := shimbinary.WriteAddress(filepath.Join(bundle.Path, "address"), address); err != nil { return nil, fmt.Errorf("failed to write bootstrap.json for bundle %s: %w", bundle.Path, err)
return nil, err
} }
shim, err := loadShim(ctx, bundle, func() {}) shim, err := loadShim(ctx, bundle, func() {})
@ -284,6 +284,39 @@ func (m *ShimManager) startShim(ctx context.Context, bundle *Bundle, id string,
return shim, nil return shim, nil
} }
// restoreBootstrapParams reads bootstrap.json to restore shim configuration.
// If its an old shim, this will perform migration - read address file and write default bootstrap
// configuration (version = 2, protocol = ttrpc, and address).
func restoreBootstrapParams(bundlePath string) (shimbinary.BootstrapParams, error) {
filePath := filepath.Join(bundlePath, "bootstrap.json")
// Read bootstrap.json if exists
if _, err := os.Stat(filePath); err == nil {
return readBootstrapParams(filePath)
} else if !errors.Is(err, os.ErrNotExist) {
return shimbinary.BootstrapParams{}, fmt.Errorf("failed to stat %s: %w", filePath, err)
}
// File not found, likely its an older shim. Try migrate.
address, err := shimbinary.ReadAddress(filepath.Join(bundlePath, "address"))
if err != nil {
return shimbinary.BootstrapParams{}, fmt.Errorf("unable to migrate shim: failed to get socket address for bundle %s: %w", bundlePath, err)
}
params := shimbinary.BootstrapParams{
Version: 2,
Address: address,
Protocol: "ttrpc",
}
if err := writeBootstrapParams(filePath, params); err != nil {
return shimbinary.BootstrapParams{}, fmt.Errorf("unable to migrate: failed to write bootstrap.json file: %w", err)
}
return params, nil
}
func (m *ShimManager) resolveRuntimePath(runtime string) (string, error) { func (m *ShimManager) resolveRuntimePath(runtime string) (string, error) {
if runtime == "" { if runtime == "" {
return "", fmt.Errorf("no runtime name") return "", fmt.Errorf("no runtime name")

View File

@ -20,7 +20,7 @@ import (
"context" "context"
"errors" "errors"
"github.com/containerd/containerd/api/runtime/task/v2" task "github.com/containerd/containerd/api/runtime/task/v3"
tasktypes "github.com/containerd/containerd/api/types/task" tasktypes "github.com/containerd/containerd/api/types/task"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/protobuf" "github.com/containerd/containerd/protobuf"

View File

@ -30,7 +30,7 @@ import (
"github.com/containerd/cgroups/v3/cgroup1" "github.com/containerd/cgroups/v3/cgroup1"
cgroupsv2 "github.com/containerd/cgroups/v3/cgroup2" cgroupsv2 "github.com/containerd/cgroups/v3/cgroup2"
"github.com/containerd/console" "github.com/containerd/console"
"github.com/containerd/containerd/api/runtime/task/v2" "github.com/containerd/containerd/api/runtime/task/v3"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/mount" "github.com/containerd/containerd/mount"
"github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/namespaces"

View File

@ -118,15 +118,19 @@ func (m manager) Name() string {
return m.name return m.name
} }
func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ string, retErr error) { func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ shim.BootstrapParams, retErr error) {
var params shim.BootstrapParams
params.Version = 3
params.Protocol = "ttrpc"
cmd, err := newCommand(ctx, id, opts.Address, opts.TTRPCAddress, opts.Debug) cmd, err := newCommand(ctx, id, opts.Address, opts.TTRPCAddress, opts.Debug)
if err != nil { if err != nil {
return "", err return params, err
} }
grouping := id grouping := id
spec, err := readSpec() spec, err := readSpec()
if err != nil { if err != nil {
return "", err return params, err
} }
for _, group := range groupLabels { for _, group := range groupLabels {
if groupID, ok := spec.Annotations[group]; ok { if groupID, ok := spec.Annotations[group]; ok {
@ -136,7 +140,7 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str
} }
address, err := shim.SocketAddress(ctx, opts.Address, grouping) address, err := shim.SocketAddress(ctx, opts.Address, grouping)
if err != nil { if err != nil {
return "", err return params, err
} }
socket, err := shim.NewSocket(address) socket, err := shim.NewSocket(address)
@ -146,19 +150,17 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str
// grouping functionality where the new process should be run with the same // grouping functionality where the new process should be run with the same
// shim as an existing container // shim as an existing container
if !shim.SocketEaddrinuse(err) { if !shim.SocketEaddrinuse(err) {
return "", fmt.Errorf("create new shim socket: %w", err) return params, fmt.Errorf("create new shim socket: %w", err)
} }
if shim.CanConnect(address) { if shim.CanConnect(address) {
if err := shim.WriteAddress("address", address); err != nil { params.Address = address
return "", fmt.Errorf("write existing socket for shim: %w", err) return params, nil
}
return address, nil
} }
if err := shim.RemoveSocket(address); err != nil { if err := shim.RemoveSocket(address); err != nil {
return "", fmt.Errorf("remove pre-existing socket: %w", err) return params, fmt.Errorf("remove pre-existing socket: %w", err)
} }
if socket, err = shim.NewSocket(address); err != nil { if socket, err = shim.NewSocket(address); err != nil {
return "", fmt.Errorf("try create new shim socket 2x: %w", err) return params, fmt.Errorf("try create new shim socket 2x: %w", err)
} }
} }
defer func() { defer func() {
@ -168,14 +170,9 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str
} }
}() }()
// make sure that reexec shim-v2 binary use the value if need
if err := shim.WriteAddress("address", address); err != nil {
return "", err
}
f, err := socket.File() f, err := socket.File()
if err != nil { if err != nil {
return "", err return params, err
} }
cmd.ExtraFiles = append(cmd.ExtraFiles, f) cmd.ExtraFiles = append(cmd.ExtraFiles, f)
@ -183,13 +180,13 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str
goruntime.LockOSThread() goruntime.LockOSThread()
if os.Getenv("SCHED_CORE") != "" { if os.Getenv("SCHED_CORE") != "" {
if err := schedcore.Create(schedcore.ProcessGroup); err != nil { if err := schedcore.Create(schedcore.ProcessGroup); err != nil {
return "", fmt.Errorf("enable sched core support: %w", err) return params, fmt.Errorf("enable sched core support: %w", err)
} }
} }
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
f.Close() f.Close()
return "", err return params, err
} }
goruntime.UnlockOSThread() goruntime.UnlockOSThread()
@ -207,27 +204,29 @@ func (manager) Start(ctx context.Context, id string, opts shim.StartOpts) (_ str
if cgroups.Mode() == cgroups.Unified { if cgroups.Mode() == cgroups.Unified {
cg, err := cgroupsv2.Load(opts.ShimCgroup) cg, err := cgroupsv2.Load(opts.ShimCgroup)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to load cgroup %s: %w", opts.ShimCgroup, err) return params, fmt.Errorf("failed to load cgroup %s: %w", opts.ShimCgroup, err)
} }
if err := cg.AddProc(uint64(cmd.Process.Pid)); err != nil { if err := cg.AddProc(uint64(cmd.Process.Pid)); err != nil {
return "", fmt.Errorf("failed to join cgroup %s: %w", opts.ShimCgroup, err) return params, fmt.Errorf("failed to join cgroup %s: %w", opts.ShimCgroup, err)
} }
} else { } else {
cg, err := cgroup1.Load(cgroup1.StaticPath(opts.ShimCgroup)) cg, err := cgroup1.Load(cgroup1.StaticPath(opts.ShimCgroup))
if err != nil { if err != nil {
return "", fmt.Errorf("failed to load cgroup %s: %w", opts.ShimCgroup, err) return params, fmt.Errorf("failed to load cgroup %s: %w", opts.ShimCgroup, err)
} }
if err := cg.AddProc(uint64(cmd.Process.Pid)); err != nil { if err := cg.AddProc(uint64(cmd.Process.Pid)); err != nil {
return "", fmt.Errorf("failed to join cgroup %s: %w", opts.ShimCgroup, err) return params, fmt.Errorf("failed to join cgroup %s: %w", opts.ShimCgroup, err)
} }
} }
} }
} }
if err := shim.AdjustOOMScore(cmd.Process.Pid); err != nil { if err := shim.AdjustOOMScore(cmd.Process.Pid); err != nil {
return "", fmt.Errorf("failed to adjust OOM score for shim: %w", err) return params, fmt.Errorf("failed to adjust OOM score for shim: %w", err)
} }
return address, nil
params.Address = address
return params, nil
} }
func (manager) Stop(ctx context.Context, id string) (shim.StopStatus, error) { func (manager) Stop(ctx context.Context, id string) (shim.StopStatus, error) {

View File

@ -28,7 +28,7 @@ import (
"github.com/containerd/cgroups/v3/cgroup1" "github.com/containerd/cgroups/v3/cgroup1"
cgroupsv2 "github.com/containerd/cgroups/v3/cgroup2" cgroupsv2 "github.com/containerd/cgroups/v3/cgroup2"
eventstypes "github.com/containerd/containerd/api/events" eventstypes "github.com/containerd/containerd/api/events"
taskAPI "github.com/containerd/containerd/api/runtime/task/v2" taskAPI "github.com/containerd/containerd/api/runtime/task/v3"
"github.com/containerd/containerd/api/types/task" "github.com/containerd/containerd/api/types/task"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/namespaces"
@ -58,7 +58,7 @@ var (
) )
// NewTaskService creates a new instance of a task service // NewTaskService creates a new instance of a task service
func NewTaskService(ctx context.Context, publisher shim.Publisher, sd shutdown.Service) (taskAPI.TaskService, error) { func NewTaskService(ctx context.Context, publisher shim.Publisher, sd shutdown.Service) (taskAPI.TTRPCTaskService, error) {
var ( var (
ep oom.Watcher ep oom.Watcher
err error err error
@ -252,7 +252,7 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *
} }
func (s *service) RegisterTTRPC(server *ttrpc.Server) error { func (s *service) RegisterTTRPC(server *ttrpc.Server) error {
taskAPI.RegisterTaskService(server, s) taskAPI.RegisterTTRPCTaskService(server, s)
return nil return nil
} }

View File

@ -27,13 +27,14 @@ import (
"strings" "strings"
"time" "time"
"github.com/containerd/containerd/pkg/atomicfile"
"github.com/containerd/ttrpc" "github.com/containerd/ttrpc"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
eventstypes "github.com/containerd/containerd/api/events" eventstypes "github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/runtime/task/v2" task "github.com/containerd/containerd/api/runtime/task/v3"
"github.com/containerd/containerd/api/types" "github.com/containerd/containerd/api/types"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/events/exchange" "github.com/containerd/containerd/events/exchange"
@ -59,20 +60,7 @@ func init() {
timeout.Set(shutdownTimeout, 3*time.Second) timeout.Set(shutdownTimeout, 3*time.Second)
} }
func loadAddress(path string) ([]byte, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return data, nil
}
func loadShim(ctx context.Context, bundle *Bundle, onClose func()) (_ ShimInstance, retErr error) { func loadShim(ctx context.Context, bundle *Bundle, onClose func()) (_ ShimInstance, retErr error) {
address, err := loadAddress(filepath.Join(bundle.Path, "address"))
if err != nil {
return nil, err
}
shimCtx, cancelShimLog := context.WithCancel(ctx) shimCtx, cancelShimLog := context.WithCancel(ctx)
defer func() { defer func() {
if retErr != nil { if retErr != nil {
@ -108,14 +96,14 @@ func loadShim(ctx context.Context, bundle *Bundle, onClose func()) (_ ShimInstan
f.Close() f.Close()
} }
params, err := parseStartResponse(ctx, address) params, err := restoreBootstrapParams(bundle.Path)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to read boostrap.json when restoring bundle %q: %w", bundle.ID, err)
} }
conn, err := makeConnection(ctx, params, onCloseWithShimLog) conn, err := makeConnection(ctx, bundle.ID, params, onCloseWithShimLog)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("unable to make connection: %w", err)
} }
defer func() { defer func() {
@ -125,8 +113,9 @@ func loadShim(ctx context.Context, bundle *Bundle, onClose func()) (_ ShimInstan
}() }()
shim := &shim{ shim := &shim{
bundle: bundle, bundle: bundle,
client: conn, client: conn,
version: params.Version,
} }
return shim, nil return shim, nil
@ -177,6 +166,9 @@ func cleanupAfterDeadShim(ctx context.Context, id string, rt *runtime.NSMap[Shim
}) })
} }
// CurrentShimVersion is the latest shim version supported by containerd (e.g. TaskService v3).
const CurrentShimVersion = 3
// ShimInstance represents running shim process managed by ShimManager. // ShimInstance represents running shim process managed by ShimManager.
type ShimInstance interface { type ShimInstance interface {
io.Closer io.Closer
@ -192,31 +184,75 @@ type ShimInstance interface {
Client() any Client() any
// Delete will close the client and remove bundle from disk. // Delete will close the client and remove bundle from disk.
Delete(ctx context.Context) error Delete(ctx context.Context) error
// Version returns shim's features compatibility version.
Version() int
} }
func parseStartResponse(ctx context.Context, response []byte) (client.BootstrapParams, error) { func parseStartResponse(response []byte) (client.BootstrapParams, error) {
var params client.BootstrapParams var params client.BootstrapParams
if err := json.Unmarshal(response, &params); err != nil || params.Version < 2 { if err := json.Unmarshal(response, &params); err != nil || params.Version < 2 {
// Use TTRPC for legacy shims // Use TTRPC for legacy shims
params.Address = string(response) params.Address = string(response)
params.Protocol = "ttrpc" params.Protocol = "ttrpc"
params.Version = 2
} }
if params.Version > 2 { if params.Version > CurrentShimVersion {
return client.BootstrapParams{}, fmt.Errorf("unsupported shim version (%d): %w", params.Version, errdefs.ErrNotImplemented) return client.BootstrapParams{}, fmt.Errorf("unsupported shim version (%d): %w", params.Version, errdefs.ErrNotImplemented)
} }
return params, nil return params, nil
} }
// writeBootstrapParams writes shim's bootstrap configuration (e.g. how to connect, version, etc).
func writeBootstrapParams(path string, params client.BootstrapParams) error {
path, err := filepath.Abs(path)
if err != nil {
return err
}
data, err := json.Marshal(&params)
if err != nil {
return err
}
f, err := atomicfile.New(path, 0o666)
if err != nil {
return err
}
_, err = f.Write(data)
if err != nil {
f.Cancel()
return err
}
return f.Close()
}
func readBootstrapParams(path string) (client.BootstrapParams, error) {
path, err := filepath.Abs(path)
if err != nil {
return client.BootstrapParams{}, err
}
data, err := os.ReadFile(path)
if err != nil {
return client.BootstrapParams{}, err
}
return parseStartResponse(data)
}
// makeConnection creates a new TTRPC or GRPC connection object from address. // makeConnection creates a new TTRPC or GRPC connection object from address.
// address can be either a socket path for TTRPC or JSON serialized BootstrapParams. // address can be either a socket path for TTRPC or JSON serialized BootstrapParams.
func makeConnection(ctx context.Context, params client.BootstrapParams, onClose func()) (_ io.Closer, retErr error) { func makeConnection(ctx context.Context, id string, params client.BootstrapParams, onClose func()) (_ io.Closer, retErr error) {
log.G(ctx).WithFields(log.Fields{ log.G(ctx).WithFields(log.Fields{
"address": params.Address, "address": params.Address,
"protocol": params.Protocol, "protocol": params.Protocol,
}).Debug("shim bootstrap parameters") "version": params.Version,
}).Infof("connecting to shim %s", id)
switch strings.ToLower(params.Protocol) { switch strings.ToLower(params.Protocol) {
case "ttrpc": case "ttrpc":
@ -303,8 +339,9 @@ func (gc *grpcConn) UserOnCloseWait(ctx context.Context) error {
} }
type shim struct { type shim struct {
bundle *Bundle bundle *Bundle
client any client any
version int
} }
var _ ShimInstance = (*shim)(nil) var _ ShimInstance = (*shim)(nil)
@ -314,6 +351,10 @@ func (s *shim) ID() string {
return s.bundle.ID return s.bundle.ID
} }
func (s *shim) Version() int {
return s.version
}
func (s *shim) Namespace() string { func (s *shim) Namespace() string {
return s.bundle.Namespace return s.bundle.Namespace
} }
@ -375,11 +416,11 @@ var _ runtime.Task = &shimTask{}
// shimTask wraps shim process and adds task service client for compatibility with existing shim manager. // shimTask wraps shim process and adds task service client for compatibility with existing shim manager.
type shimTask struct { type shimTask struct {
ShimInstance ShimInstance
task task.TaskService task TaskServiceClient
} }
func newShimTask(shim ShimInstance) (*shimTask, error) { func newShimTask(shim ShimInstance) (*shimTask, error) {
taskClient, err := NewTaskClient(shim.Client()) taskClient, err := NewTaskClient(shim.Client(), shim.Version())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -18,6 +18,7 @@ package shim
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
@ -29,7 +30,7 @@ import (
"runtime/debug" "runtime/debug"
"time" "time"
shimapi "github.com/containerd/containerd/api/runtime/task/v2" shimapi "github.com/containerd/containerd/api/runtime/task/v3"
"github.com/containerd/containerd/events" "github.com/containerd/containerd/events"
"github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/namespaces"
"github.com/containerd/containerd/pkg/shutdown" "github.com/containerd/containerd/pkg/shutdown"
@ -76,7 +77,7 @@ type StopStatus struct {
// Manager is the interface which manages the shim process // Manager is the interface which manages the shim process
type Manager interface { type Manager interface {
Name() string Name() string
Start(ctx context.Context, id string, opts StartOpts) (string, error) Start(ctx context.Context, id string, opts StartOpts) (BootstrapParams, error)
Stop(ctx context.Context, id string) (StopStatus, error) Stop(ctx context.Context, id string) (StopStatus, error)
} }
@ -268,13 +269,20 @@ func run(ctx context.Context, manager Manager, name string, config Config) error
Debug: debugFlag, Debug: debugFlag,
} }
address, err := manager.Start(ctx, id, opts) params, err := manager.Start(ctx, id, opts)
if err != nil { if err != nil {
return err return err
} }
if _, err := os.Stdout.WriteString(address); err != nil {
data, err := json.Marshal(&params)
if err != nil {
return fmt.Errorf("failed to marshal bootstrap params to json: %w", err)
}
if _, err := os.Stdout.Write(data); err != nil {
return err return err
} }
return nil return nil
} }

View File

@ -138,24 +138,6 @@ func WritePidFile(path string, pid int) error {
return f.Close() return f.Close()
} }
// WriteAddress writes a address file atomically
func WriteAddress(path, address string) error {
path, err := filepath.Abs(path)
if err != nil {
return err
}
f, err := atomicfile.New(path, 0o666)
if err != nil {
return err
}
_, err = f.Write([]byte(address))
if err != nil {
f.Cancel()
return err
}
return f.Close()
}
// ErrNoAddress is returned when the address file has no content // ErrNoAddress is returned when the address file has no content
var ErrNoAddress = errors.New("no shim address") var ErrNoAddress = errors.New("no shim address")

View File

@ -150,6 +150,7 @@ func (m *ShimManager) loadShims(ctx context.Context) error {
m.shims.Delete(ctx, id) m.shims.Delete(ctx, id)
}) })
if err != nil { if err != nil {
log.G(ctx).WithError(err).Errorf("unable to load shim %q", id)
cleanupAfterDeadShim(ctx, id, m.shims, m.events, binaryCall) cleanupAfterDeadShim(ctx, id, m.shims, m.events, binaryCall)
continue continue
} }

View File

@ -17,12 +17,14 @@
package v2 package v2
import ( import (
"context"
"errors" "errors"
"os"
"path/filepath"
"testing" "testing"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
client "github.com/containerd/containerd/runtime/v2/shim" client "github.com/containerd/containerd/runtime/v2/shim"
"github.com/stretchr/testify/require"
) )
func TestParseStartResponse(t *testing.T) { func TestParseStartResponse(t *testing.T) {
@ -36,7 +38,7 @@ func TestParseStartResponse(t *testing.T) {
Name: "v2 shim", Name: "v2 shim",
Response: "/somedirectory/somesocket", Response: "/somedirectory/somesocket",
Expected: client.BootstrapParams{ Expected: client.BootstrapParams{
Version: 0, Version: 2,
Address: "/somedirectory/somesocket", Address: "/somedirectory/somesocket",
Protocol: "ttrpc", Protocol: "ttrpc",
}, },
@ -63,20 +65,20 @@ func TestParseStartResponse(t *testing.T) {
Name: "invalid shim v2 response", Name: "invalid shim v2 response",
Response: `{"address":"/somedirectory/somesocket","protocol":"ttrpc"}`, Response: `{"address":"/somedirectory/somesocket","protocol":"ttrpc"}`,
Expected: client.BootstrapParams{ Expected: client.BootstrapParams{
Version: 0, Version: 2,
Address: `{"address":"/somedirectory/somesocket","protocol":"ttrpc"}`, Address: `{"address":"/somedirectory/somesocket","protocol":"ttrpc"}`,
Protocol: "ttrpc", Protocol: "ttrpc",
}, },
}, },
{ {
Name: "later unsupported shim", Name: "later unsupported shim",
Response: `{"Version": 3,"Address":"/somedirectory/somesocket","Protocol":"ttrpc"}`, Response: `{"Version": 4,"Address":"/somedirectory/somesocket","Protocol":"ttrpc"}`,
Expected: client.BootstrapParams{}, Expected: client.BootstrapParams{},
Err: errdefs.ErrNotImplemented, Err: errdefs.ErrNotImplemented,
}, },
} { } {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
params, err := parseStartResponse(context.Background(), []byte(tc.Response)) params, err := parseStartResponse([]byte(tc.Response))
if err != nil { if err != nil {
if !errors.Is(err, tc.Err) { if !errors.Is(err, tc.Err) {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
@ -96,5 +98,27 @@ func TestParseStartResponse(t *testing.T) {
} }
}) })
} }
}
func TestRestoreBootstrapParams(t *testing.T) {
bundlePath := t.TempDir()
err := os.WriteFile(filepath.Join(bundlePath, "address"), []byte("unix://123"), 0o666)
require.NoError(t, err)
restored, err := restoreBootstrapParams(bundlePath)
require.NoError(t, err)
expected := client.BootstrapParams{
Version: 2,
Address: "unix://123",
Protocol: "ttrpc",
}
require.EqualValues(t, expected, restored)
loaded, err := readBootstrapParams(filepath.Join(bundlePath, "bootstrap.json"))
require.NoError(t, err)
require.EqualValues(t, expected, loaded)
} }