Add client and server unary interceptors

Signed-off-by: Michael Crosby <crosbymichael@gmail.com>
This commit is contained in:
Michael Crosby 2019-06-06 19:28:01 +00:00
parent a5bd8ce9e4
commit 819653f40c
5 changed files with 93 additions and 19 deletions

View File

@ -42,11 +42,12 @@ type Client struct {
channel *channel
calls chan *callRequest
closed chan struct{}
closeOnce sync.Once
closeFunc func()
done chan struct{}
err error
closed chan struct{}
closeOnce sync.Once
closeFunc func()
done chan struct{}
err error
interceptor UnaryClientInterceptor
}
type ClientOpts func(c *Client)
@ -57,15 +58,22 @@ func WithOnClose(onClose func()) ClientOpts {
}
}
func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
return func(c *Client) {
c.interceptor = i
}
}
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
c := &Client{
codec: codec{},
conn: conn,
channel: newChannel(conn),
calls: make(chan *callRequest),
closed: make(chan struct{}),
done: make(chan struct{}),
closeFunc: func() {},
codec: codec{},
conn: conn,
channel: newChannel(conn),
calls: make(chan *callRequest),
closed: make(chan struct{}),
done: make(chan struct{}),
closeFunc: func() {},
interceptor: defaultClientInterceptor,
}
for _, o := range opts {
@ -107,7 +115,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds()
}
if err := c.dispatch(ctx, creq, cresp); err != nil {
info := &UnaryClientInfo{
FullMethod: fullPath(service, method),
}
if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
return err
}

View File

@ -19,7 +19,8 @@ package ttrpc
import "github.com/pkg/errors"
type serverConfig struct {
handshaker Handshaker
handshaker Handshaker
interceptor UnaryServerInterceptor
}
type ServerOpt func(*serverConfig) error
@ -37,3 +38,13 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt {
return nil
}
}
func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt {
return func(c *serverConfig) error {
if c.interceptor != nil {
return errors.New("only one interceptor allowed per server")
}
c.interceptor = i
return nil
}
}

43
interceptor.go Normal file
View File

@ -0,0 +1,43 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ttrpc
import "context"
type UnaryServerInfo struct {
FullMethod string
}
type UnaryClientInfo struct {
FullMethod string
}
type Unmarshaler func(interface{}) error
type Invoker func(context.Context, *Request, *Response) error
type UnaryServerInterceptor func(context.Context, Unmarshaler, *UnaryServerInfo, Method) (interface{}, error)
type UnaryClientInterceptor func(context.Context, *Request, *Response, *UnaryClientInfo, Invoker) error
func defaultServerInterceptor(ctx context.Context, unmarshal Unmarshaler, info *UnaryServerInfo, method Method) (interface{}, error) {
return method(ctx, unmarshal)
}
func defaultClientInterceptor(ctx context.Context, req *Request, resp *Response, _ *UnaryClientInfo, invoker Invoker) error {
return invoker(ctx, req, resp)
}

View File

@ -53,10 +53,13 @@ func NewServer(opts ...ServerOpt) (*Server, error) {
return nil, err
}
}
if config.interceptor == nil {
config.interceptor = defaultServerInterceptor
}
return &Server{
config: config,
services: newServiceSet(),
services: newServiceSet(config.interceptor),
done: make(chan struct{}),
listeners: make(map[net.Listener]struct{}),
connections: make(map[*serverConn]struct{}),

View File

@ -37,12 +37,14 @@ type ServiceDesc struct {
}
type serviceSet struct {
services map[string]ServiceDesc
services map[string]ServiceDesc
interceptor UnaryServerInterceptor
}
func newServiceSet() *serviceSet {
func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet {
return &serviceSet{
services: make(map[string]ServiceDesc),
services: make(map[string]ServiceDesc),
interceptor: interceptor,
}
}
@ -84,7 +86,11 @@ func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName strin
return nil
}
resp, err := method(ctx, unmarshal)
info := &UnaryServerInfo{
FullMethod: fullPath(serviceName, methodName),
}
resp, err := s.interceptor(ctx, unmarshal, info, method)
if err != nil {
return nil, err
}