From 822cc51d840775f0c9aa48e6e93cecc657a90f06 Mon Sep 17 00:00:00 2001 From: Wei Fu Date: Fri, 17 Jun 2022 00:39:13 +0800 Subject: [PATCH] runtime/v2: manager supports server interceptor Currently, the runc shimv2 commandline manager doesn't support ttrpc server's customized option, for example, the ttrpc server interceptor. This commit is to allow the task plugin can return the `UnaryServerInterceptor` option to the manager so that the task plugin can do enhancement before handling the incoming request, like API-level failpoint control. Signed-off-by: Wei Fu --- runtime/v2/shim/shim.go | 16 ++++- runtime/v2/shim/shim_darwin.go | 4 +- runtime/v2/shim/shim_freebsd.go | 4 +- runtime/v2/shim/shim_linux.go | 5 +- runtime/v2/shim/shim_windows.go | 2 +- runtime/v2/shim/util.go | 26 +++++++ runtime/v2/shim/util_test.go | 118 ++++++++++++++++++++++++++++++++ 7 files changed, 167 insertions(+), 8 deletions(-) create mode 100644 runtime/v2/shim/util_test.go diff --git a/runtime/v2/shim/shim.go b/runtime/v2/shim/shim.go index 9fa319367..7487ca043 100644 --- a/runtime/v2/shim/shim.go +++ b/runtime/v2/shim/shim.go @@ -108,6 +108,12 @@ type ttrpcService interface { RegisterTTRPC(*ttrpc.Server) error } +type ttrpcServerOptioner interface { + ttrpcService + + UnaryInterceptor() ttrpc.UnaryServerInterceptor +} + type taskService struct { shimapi.TaskService } @@ -370,6 +376,8 @@ func run(ctx context.Context, manager Manager, initFunc Init, name string, confi var ( initialized = plugin.NewPluginSet() ttrpcServices = []ttrpcService{} + + ttrpcUnaryInterceptors = []ttrpc.UnaryServerInterceptor{} ) plugins := plugin.Graph(func(*plugin.Registration) bool { return false }) for _, p := range plugins { @@ -418,10 +426,16 @@ func run(ctx context.Context, manager Manager, initFunc Init, name string, confi if src, ok := instance.(ttrpcService); ok { logrus.WithField("id", id).Debug("registering ttrpc service") ttrpcServices = append(ttrpcServices, src) + + } + + if src, ok := instance.(ttrpcServerOptioner); ok { + ttrpcUnaryInterceptors = append(ttrpcUnaryInterceptors, src.UnaryInterceptor()) } } - server, err := newServer() + unaryInterceptor := chainUnaryServerInterceptors(ttrpcUnaryInterceptors...) + server, err := newServer(ttrpc.WithUnaryServerInterceptor(unaryInterceptor)) if err != nil { return fmt.Errorf("failed creating server: %w", err) } diff --git a/runtime/v2/shim/shim_darwin.go b/runtime/v2/shim/shim_darwin.go index fe833df01..0bdf289bb 100644 --- a/runtime/v2/shim/shim_darwin.go +++ b/runtime/v2/shim/shim_darwin.go @@ -18,8 +18,8 @@ package shim import "github.com/containerd/ttrpc" -func newServer() (*ttrpc.Server, error) { - return ttrpc.NewServer() +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { + return ttrpc.NewServer(opts...) } func subreaper() error { diff --git a/runtime/v2/shim/shim_freebsd.go b/runtime/v2/shim/shim_freebsd.go index fe833df01..0bdf289bb 100644 --- a/runtime/v2/shim/shim_freebsd.go +++ b/runtime/v2/shim/shim_freebsd.go @@ -18,8 +18,8 @@ package shim import "github.com/containerd/ttrpc" -func newServer() (*ttrpc.Server, error) { - return ttrpc.NewServer() +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { + return ttrpc.NewServer(opts...) } func subreaper() error { diff --git a/runtime/v2/shim/shim_linux.go b/runtime/v2/shim/shim_linux.go index 06266a533..1c05c2c56 100644 --- a/runtime/v2/shim/shim_linux.go +++ b/runtime/v2/shim/shim_linux.go @@ -21,8 +21,9 @@ import ( "github.com/containerd/ttrpc" ) -func newServer() (*ttrpc.Server, error) { - return ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser())) +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { + opts = append(opts, ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser())) + return ttrpc.NewServer(opts...) } func subreaper() error { diff --git a/runtime/v2/shim/shim_windows.go b/runtime/v2/shim/shim_windows.go index 4b098ab16..2add7ac33 100644 --- a/runtime/v2/shim/shim_windows.go +++ b/runtime/v2/shim/shim_windows.go @@ -31,7 +31,7 @@ func setupSignals(config Config) (chan os.Signal, error) { return nil, errors.New("not supported") } -func newServer() (*ttrpc.Server, error) { +func newServer(opts ...ttrpc.ServerOpt) (*ttrpc.Server, error) { return nil, errors.New("not supported") } diff --git a/runtime/v2/shim/util.go b/runtime/v2/shim/util.go index bf78fa58f..d1cd47946 100644 --- a/runtime/v2/shim/util.go +++ b/runtime/v2/shim/util.go @@ -30,6 +30,7 @@ import ( "github.com/containerd/containerd/namespaces" "github.com/containerd/containerd/protobuf/proto" "github.com/containerd/containerd/protobuf/types" + "github.com/containerd/ttrpc" exec "golang.org/x/sys/execabs" ) @@ -167,3 +168,28 @@ func ReadAddress(path string) (string, error) { } return string(data), nil } + +// chainUnaryServerInterceptors creates a single ttrpc server interceptor from +// a chain of many interceptors executed from first to last. +func chainUnaryServerInterceptors(interceptors ...ttrpc.UnaryServerInterceptor) ttrpc.UnaryServerInterceptor { + n := len(interceptors) + + // force to use default interceptor in ttrpc + if n == 0 { + return nil + } + + return func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + currentMethod := method + + for i := n - 1; i > 0; i-- { + interceptor := interceptors[i] + innerMethod := currentMethod + + currentMethod = func(currentCtx context.Context, currentUnmarshal func(interface{}) error) (interface{}, error) { + return interceptor(currentCtx, currentUnmarshal, info, innerMethod) + } + } + return interceptors[0](ctx, unmarshal, info, currentMethod) + } +} diff --git a/runtime/v2/shim/util_test.go b/runtime/v2/shim/util_test.go new file mode 100644 index 000000000..8341bcddb --- /dev/null +++ b/runtime/v2/shim/util_test.go @@ -0,0 +1,118 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package shim + +import ( + "context" + "path/filepath" + "reflect" + "testing" + + "github.com/containerd/ttrpc" +) + +func TestChainUnaryServerInterceptors(t *testing.T) { + methodInfo := &ttrpc.UnaryServerInfo{ + FullMethod: filepath.Join("/", t.Name(), "foo"), + } + + type callKey struct{} + callValue := "init" + callCtx := context.WithValue(context.Background(), callKey{}, callValue) + + verifyCallCtxFn := func(ctx context.Context, key interface{}, expected interface{}) { + got := ctx.Value(key) + if !reflect.DeepEqual(expected, got) { + t.Fatalf("[context(key:%s) expected %v, but got %v", key, expected, got) + } + } + + verifyInfoFn := func(info *ttrpc.UnaryServerInfo) { + if !reflect.DeepEqual(methodInfo, info) { + t.Fatalf("[info] expected %+v, but got %+v", methodInfo, info) + } + } + + origUnmarshaler := func(obj interface{}) error { + v := obj.(*int64) + *v *= 2 + return nil + } + + type firstKey struct{} + firstValue := "from first" + var firstUnmarshaler ttrpc.Unmarshaler + first := func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyInfoFn(info) + + ctx = context.WithValue(ctx, firstKey{}, firstValue) + + firstUnmarshaler = func(obj interface{}) error { + if err := unmarshal(obj); err != nil { + return err + } + + v := obj.(*int64) + *v *= 2 + return nil + } + + return method(ctx, firstUnmarshaler) + } + + type secondKey struct{} + secondValue := "from second" + second := func(ctx context.Context, unmarshal ttrpc.Unmarshaler, info *ttrpc.UnaryServerInfo, method ttrpc.Method) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyCallCtxFn(ctx, firstKey{}, firstValue) + verifyInfoFn(info) + + v := int64(3) // should return 12 + if err := unmarshal(&v); err != nil { + t.Fatalf("unexpected error %v", err) + } + if expected := int64(12); v != expected { + t.Fatalf("expected int64(%v), but got %v", expected, v) + } + + ctx = context.WithValue(ctx, secondKey{}, secondValue) + return method(ctx, unmarshal) + } + + methodFn := func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + verifyCallCtxFn(ctx, callKey{}, callValue) + verifyCallCtxFn(ctx, firstKey{}, firstValue) + verifyCallCtxFn(ctx, secondKey{}, secondValue) + + v := int64(2) + if err := unmarshal(&v); err != nil { + return nil, err + } + return v, nil + } + + interceptor := chainUnaryServerInterceptors(first, second) + v, err := interceptor(callCtx, origUnmarshaler, methodInfo, methodFn) + if err != nil { + t.Fatalf("expected nil, but got %v", err) + } + + if expected := int64(8); v != expected { + t.Fatalf("expected result is int64(%v), but got %v", expected, v) + } +}