diff --git a/config.go b/config.go index 0974196..f401f67 100644 --- a/config.go +++ b/config.go @@ -16,7 +16,10 @@ package ttrpc -import "errors" +import ( + "context" + "errors" +) type serverConfig struct { handshaker Handshaker @@ -44,9 +47,40 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt { func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt { return func(c *serverConfig) error { if c.interceptor != nil { - return errors.New("only one interceptor allowed per server") + return errors.New("only one unchained interceptor allowed per server") } c.interceptor = i return nil } } + +// WithChainUnaryServerInterceptor sets the provided chain of server interceptors +func WithChainUnaryServerInterceptor(interceptors ...UnaryServerInterceptor) ServerOpt { + return func(c *serverConfig) error { + if len(interceptors) == 0 { + return nil + } + if c.interceptor != nil { + interceptors = append([]UnaryServerInterceptor{c.interceptor}, interceptors...) + } + c.interceptor = func( + ctx context.Context, + unmarshal Unmarshaler, + info *UnaryServerInfo, + method Method) (interface{}, error) { + return interceptors[0](ctx, unmarshal, info, + chainUnaryServerInterceptors(info, method, interceptors[1:])) + } + return nil + } +} + +func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, interceptors []UnaryServerInterceptor) Method { + if len(interceptors) == 0 { + return method + } + return func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + return interceptors[0](ctx, unmarshal, info, + chainUnaryServerInterceptors(info, method, interceptors[1:])) + } +} diff --git a/interceptor_test.go b/interceptor_test.go index 39d82f8..47fa28f 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -133,3 +133,113 @@ func TestChainUnaryClientInterceptor(t *testing.T) { t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply) } } + +func TestUnaryServerInterceptor(t *testing.T) { + var ( + intercepted = false + interceptor = func(ctx context.Context, unmarshal Unmarshaler, _ *UnaryServerInfo, method Method) (interface{}, error) { + intercepted = true + return method(ctx, unmarshal) + } + + ctx = context.Background() + server = mustServer(t)(NewServer(WithUnaryServerInterceptor(interceptor))) + testImpl = &testingServer{} + addr, listener = newTestListener(t) + client, cleanup = newTestClient(t, addr) + message = strings.Repeat("a", 16) + reply = strings.Repeat(message, 2) + ) + + defer listener.Close() + defer cleanup() + + registerTestingService(server, testImpl) + + go server.Serve(ctx, listener) + defer server.Shutdown(ctx) + + request := &internal.TestPayload{ + Foo: message, + } + response := &internal.TestPayload{} + if err := client.Call(ctx, serviceName, "Test", request, response); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !intercepted { + t.Fatalf("ttrpc server call not intercepted") + } + + if response.Foo != reply { + t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply) + } +} + +func TestChainUnaryServerInterceptor(t *testing.T) { + var ( + orderIdx = 0 + recorded = []string{} + intercept = func(idx int, tag string) UnaryServerInterceptor { + return func(ctx context.Context, unmarshal Unmarshaler, _ *UnaryServerInfo, method Method) (interface{}, error) { + if orderIdx != idx { + t.Fatalf("unexpected interceptor invocation order (%d != %d)", orderIdx, idx) + } + recorded = append(recorded, tag) + orderIdx++ + return method(ctx, unmarshal) + } + } + + ctx = context.Background() + server = mustServer(t)(NewServer( + WithUnaryServerInterceptor( + intercept(0, "seen it"), + ), + WithChainUnaryServerInterceptor( + intercept(1, "been"), + intercept(2, "there"), + intercept(3, "done"), + intercept(4, "that"), + ), + )) + expected = []string{ + "seen it", + "been", + "there", + "done", + "that", + } + testImpl = &testingServer{} + addr, listener = newTestListener(t) + client, cleanup = newTestClient(t, addr) + message = strings.Repeat("a", 16) + reply = strings.Repeat(message, 2) + ) + + defer listener.Close() + defer cleanup() + + registerTestingService(server, testImpl) + + go server.Serve(ctx, listener) + defer server.Shutdown(ctx) + + request := &internal.TestPayload{ + Foo: message, + } + response := &internal.TestPayload{} + + if err := client.Call(ctx, serviceName, "Test", request, response); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(recorded, expected) { + t.Fatalf("unexpected ttrpc chained server unary interceptor order (%s != %s)", + strings.Join(recorded, " "), strings.Join(expected, " ")) + } + + if response.Foo != reply { + t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply) + } +}