diff --git a/runtime/v2/shim/shim.go b/runtime/v2/shim/shim.go index 9fa319367..7487ca043 100644 --- a/runtime/v2/shim/shim.go +++ b/runtime/v2/shim/shim.go @@ -108,6 +108,12 @@ type ttrpcService interface { RegisterTTRPC(*ttrpc.Server) error } +type ttrpcServerOptioner interface { + ttrpcService + + UnaryInterceptor() ttrpc.UnaryServerInterceptor +} + type taskService struct { shimapi.TaskService } @@ -370,6 +376,8 @@ func run(ctx context.Context, manager Manager, initFunc Init, name string, confi var ( initialized = plugin.NewPluginSet() ttrpcServices = []ttrpcService{} + + ttrpcUnaryInterceptors = []ttrpc.UnaryServerInterceptor{} ) plugins := plugin.Graph(func(*plugin.Registration) bool { return false }) for _, p := range plugins { @@ -418,10 +426,16 @@ func run(ctx context.Context, manager Manager, initFunc Init, name string, confi if src, ok := instance.(ttrpcService); ok { logrus.WithField("id", id).Debug("registering ttrpc service") ttrpcServices = append(ttrpcServices, src) + + } + + if src, ok := instance.(ttrpcServerOptioner); ok { + ttrpcUnaryInterceptors = append(ttrpcUnaryInterceptors, src.UnaryInterceptor()) } } - server, err := newServer() + unaryInterceptor := chainUnaryServerInterceptors(ttrpcUnaryInterceptors...) + server, err := newServer(ttrpc.WithUnaryServerInterceptor(unaryInterceptor)) if err != nil { return fmt.Errorf("failed creating server: %w", err) } diff --git a/runtime/v2/shim/shim_darwin.go b/runtime/v2/shim/shim_darwin.go index fe833df01..0bdf289bb 100644 --- a/runtime/v2/shim/shim_darwin.go +++ b/runtime/v2/shim/shim_darwin.go @@ -18,8 +18,8 @@ package shim import "github.com/containerd/ttrpc" -func newServer() (*ttrpc.Server, error) { - return ttrpc.NewServer() +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { + return ttrpc.NewServer(opts...) } func subreaper() error { diff --git a/runtime/v2/shim/shim_freebsd.go b/runtime/v2/shim/shim_freebsd.go index fe833df01..0bdf289bb 100644 --- a/runtime/v2/shim/shim_freebsd.go +++ b/runtime/v2/shim/shim_freebsd.go @@ -18,8 +18,8 @@ package shim import "github.com/containerd/ttrpc" -func newServer() (*ttrpc.Server, error) { - return ttrpc.NewServer() +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { + return ttrpc.NewServer(opts...) } func subreaper() error { diff --git a/runtime/v2/shim/shim_linux.go b/runtime/v2/shim/shim_linux.go index 06266a533..1c05c2c56 100644 --- a/runtime/v2/shim/shim_linux.go +++ b/runtime/v2/shim/shim_linux.go @@ -21,8 +21,9 @@ import ( "github.com/containerd/ttrpc" ) -func newServer() (*ttrpc.Server, error) { - return ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser())) +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { + opts = append(opts, ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser())) + return ttrpc.NewServer(opts...) } func subreaper() error { diff --git a/runtime/v2/shim/shim_windows.go b/runtime/v2/shim/shim_windows.go index 4b098ab16..2add7ac33 100644 --- a/runtime/v2/shim/shim_windows.go +++ b/runtime/v2/shim/shim_windows.go @@ -31,7 +31,7 @@ func setupSignals(config Config) (chan os.Signal, error) { return nil, errors.New("not supported") } -func newServer() (*ttrpc.Server, error) { +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { return nil, errors.New("not supported") } diff --git a/runtime/v2/shim/util.go b/runtime/v2/shim/util.go index bf78fa58f..d1cd47946 100644 --- a/runtime/v2/shim/util.go +++ b/runtime/v2/shim/util.go @@ -30,6 +30,7 @@ import ( "github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/protobuf/proto" "github.com/containerd/containerd/protobuf/types" + "github.com/containerd/ttrpc" exec "golang.org/x/sys/execabs" ) @@ -167,3 +168,28 @@ func ReadAddress(path string) (string, error) { } return string(data), nil } + +// chainUnaryServerInterceptors creates a single ttrpc server interceptor from +// a chain of many interceptors executed from first to last. +func chainUnaryServerInterceptors(interceptors ...ttrpc.UnaryServerInterceptor) ttrpc.UnaryServerInterceptor { + n := len(interceptors) + + // force to use default interceptor in ttrpc + if n == 0 { + return nil + } + + return func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + currentMethod := method + + for i := n - 1; i > 0; i-- { + interceptor := interceptors[i] + innerMethod := currentMethod + + currentMethod = func(currentCtx context.Context, currentUnmarshal func(interface{}) error) (interface{}, error) { + return interceptor(currentCtx, currentUnmarshal, info, innerMethod) + } + } + return interceptors[0](ctx, unmarshal, info, currentMethod) + } +} diff --git a/runtime/v2/shim/util_test.go b/runtime/v2/shim/util_test.go new file mode 100644 index 000000000..8341bcddb --- /dev/null +++ b/runtime/v2/shim/util_test.go @@ -0,0 +1,118 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "path/filepath" + "reflect" + "testing" + + "github.com/containerd/ttrpc" +) + +func TestChainUnaryServerInterceptors(t *testing.T) { + methodInfo := &ttrpc.UnaryServerInfo{ + FullMethod: filepath.Join("/", t.Name(), "foo"), + } + + type callKey struct{} + callValue := "init" + callCtx := context.WithValue(context.Background(), callKey{}, callValue) + + verifyCallCtxFn := func(ctx context.Context, key interface{}, expected interface{}) { + got := ctx.Value(key) + if !reflect.DeepEqual(expected, got) { + t.Fatalf("[context(key:%s) expected %v, but got %v", key, expected, got) + } + } + + verifyInfoFn := func(info *ttrpc.UnaryServerInfo) { + if !reflect.DeepEqual(methodInfo, info) { + t.Fatalf("[info] expected %+v, but got %+v", methodInfo, info) + } + } + + origUnmarshaler := func(obj interface{}) error { + v := obj.(*int64) + *v *= 2 + return nil + } + + type firstKey struct{} + firstValue := "from first" + var firstUnmarshaler ttrpc.Unmarshaler + first := func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyInfoFn(info) + + ctx = context.WithValue(ctx, firstKey{}, firstValue) + + firstUnmarshaler = func(obj interface{}) error { + if err := unmarshal(obj); err != nil { + return err + } + + v := obj.(*int64) + *v *= 2 + return nil + } + + return method(ctx, firstUnmarshaler) + } + + type secondKey struct{} + secondValue := "from second" + second := func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyCallCtxFn(ctx, firstKey{}, firstValue) + verifyInfoFn(info) + + v := int64(3) // should return 12 + if err := unmarshal(&v); err != nil { + t.Fatalf("unexpected error %v", err) + } + if expected := int64(12); v != expected { + t.Fatalf("expected int64(%v), but got %v", expected, v) + } + + ctx = context.WithValue(ctx, secondKey{}, secondValue) + return method(ctx, unmarshal) + } + + methodFn := func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyCallCtxFn(ctx, firstKey{}, firstValue) + verifyCallCtxFn(ctx, secondKey{}, secondValue) + + v := int64(2) + if err := unmarshal(&v); err != nil { + return nil, err + } + return v, nil + } + + interceptor := chainUnaryServerInterceptors(first, second) + v, err := interceptor(callCtx, origUnmarshaler, methodInfo, methodFn) + if err != nil { + t.Fatalf("expected nil, but got %v", err) + } + + if expected := int64(8); v != expected { + t.Fatalf("expected result is int64(%v), but got %v", expected, v) + } +}