ttrpc updates for interceptors, close, and metadata

Signed-off-by: Michael Crosby <crosbymichael@gmail.com>
This commit is contained in:
Michael Crosby 2019-06-13 19:09:07 +00:00
parent 40b17e97f6
commit 0b7abc02b2
10 changed files with 268 additions and 116 deletions

View File

@ -30,7 +30,7 @@ const (
func withTTRPCNamespaceHeader(ctx context.Context, namespace string) context.Context { func withTTRPCNamespaceHeader(ctx context.Context, namespace string) context.Context {
md, ok := ttrpc.GetMetadata(ctx) md, ok := ttrpc.GetMetadata(ctx)
if !ok { if !ok {
md = ttrpc.Metadata{} md = ttrpc.MD{}
} }
md.Set(TTRPCHeader, namespace) md.Set(TTRPCHeader, namespace)
return ttrpc.WithMetadata(ctx, md) return ttrpc.WithMetadata(ctx, md)

View File

@ -37,7 +37,7 @@ github.com/Microsoft/go-winio 84b4ab48a50763fe7b3abcef38e5205c12027fac
github.com/Microsoft/hcsshim 8abdbb8205e4192c68b5f84c31197156f31be517 github.com/Microsoft/hcsshim 8abdbb8205e4192c68b5f84c31197156f31be517
google.golang.org/genproto d80a6e20e776b0b17a324d0ba1ab50a39c8e8944 google.golang.org/genproto d80a6e20e776b0b17a324d0ba1ab50a39c8e8944
golang.org/x/text 19e51611da83d6be54ddafce4a4af510cb3e9ea4 golang.org/x/text 19e51611da83d6be54ddafce4a4af510cb3e9ea4
github.com/containerd/ttrpc a5bd8ce9e40bc7c065a11c6936f4d032ce6bfa2b github.com/containerd/ttrpc 1fb3814edf44a76e0ccf503decf726d994919a9a
github.com/syndtr/gocapability d98352740cb2c55f81556b63d4a1ec64c5a319c2 github.com/syndtr/gocapability d98352740cb2c55f81556b63d4a1ec64c5a319c2
gotest.tools v2.3.0 gotest.tools v2.3.0
github.com/google/go-cmp v0.2.0 github.com/google/go-cmp v0.2.0

View File

@ -18,7 +18,6 @@ package ttrpc
import ( import (
"bufio" "bufio"
"context"
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
// returned will be valid and caller should send that along to // returned will be valid and caller should send that along to
// the correct consumer. The bytes on the underlying channel // the correct consumer. The bytes on the underlying channel
// will be discarded. // will be discarded.
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { func (ch *channel) recv() (messageHeader, []byte, error) {
mh, err := readMessageHeader(ch.hrbuf[:], ch.br) mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
if err != nil { if err != nil {
return messageHeader{}, nil, err return messageHeader{}, nil, err
@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
return mh, p, nil return mh, p, nil
} }
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
return err return err
} }

View File

@ -36,36 +36,52 @@ import (
// closed. // closed.
var ErrClosed = errors.New("ttrpc: closed") var ErrClosed = errors.New("ttrpc: closed")
// Client for a ttrpc server
type Client struct { type Client struct {
codec codec codec codec
conn net.Conn conn net.Conn
channel *channel channel *channel
calls chan *callRequest calls chan *callRequest
closed chan struct{} ctx context.Context
closeOnce sync.Once closed func()
closeFunc func()
done chan struct{} closeOnce sync.Once
err error userCloseFunc func()
errOnce sync.Once
err error
interceptor UnaryClientInterceptor
} }
// ClientOpts configures a client
type ClientOpts func(c *Client) type ClientOpts func(c *Client)
// WithOnClose sets the close func whenever the client's Close() method is called
func WithOnClose(onClose func()) ClientOpts { func WithOnClose(onClose func()) ClientOpts {
return func(c *Client) { return func(c *Client) {
c.closeFunc = onClose c.userCloseFunc = onClose
}
}
// WithUnaryClientInterceptor sets the provided client interceptor
func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
return func(c *Client) {
c.interceptor = i
} }
} }
func NewClient(conn net.Conn, opts ...ClientOpts) *Client { func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
ctx, cancel := context.WithCancel(context.Background())
c := &Client{ c := &Client{
codec: codec{}, codec: codec{},
conn: conn, conn: conn,
channel: newChannel(conn), channel: newChannel(conn),
calls: make(chan *callRequest), calls: make(chan *callRequest),
closed: make(chan struct{}), closed: cancel,
done: make(chan struct{}), ctx: ctx,
closeFunc: func() {}, userCloseFunc: func() {},
interceptor: defaultClientInterceptor,
} }
for _, o := range opts { for _, o := range opts {
@ -100,14 +116,17 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
) )
if metadata, ok := GetMetadata(ctx); ok { if metadata, ok := GetMetadata(ctx); ok {
creq.Metadata = metadata metadata.setRequest(creq)
} }
if dl, ok := ctx.Deadline(); ok { if dl, ok := ctx.Deadline(); ok {
creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds() creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds()
} }
if err := c.dispatch(ctx, creq, cresp); err != nil { info := &UnaryClientInfo{
FullMethod: fullPath(service, method),
}
if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
return err return err
} }
@ -135,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case c.calls <- call: case c.calls <- call:
case <-c.done: case <-c.ctx.Done():
return c.err return c.error()
} }
select { select {
@ -144,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
return ctx.Err() return ctx.Err()
case err := <-errs: case err := <-errs:
return filterCloseErr(err) return filterCloseErr(err)
case <-c.done: case <-c.ctx.Done():
return c.err return c.error()
} }
} }
func (c *Client) Close() error { func (c *Client) Close() error {
c.closeOnce.Do(func() { c.closeOnce.Do(func() {
close(c.closed) c.closed()
}) })
return nil return nil
} }
@ -163,51 +181,82 @@ type message struct {
err error err error
} }
func (c *Client) run() { type receiver struct {
var ( wg *sync.WaitGroup
streamID uint32 = 1 messages chan *message
waiters = make(map[uint32]*callRequest) err error
calls = c.calls }
incoming = make(chan *message)
shutdown = make(chan struct{})
shutdownErr error
)
go func() { func (r *receiver) run(ctx context.Context, c *channel) {
defer close(shutdown) defer r.wg.Done()
// start one more goroutine to recv messages without blocking. for {
for { select {
mh, p, err := c.channel.recv(context.TODO()) case <-ctx.Done():
r.err = ctx.Err()
return
default:
mh, p, err := c.recv()
if err != nil { if err != nil {
_, ok := status.FromError(err) _, ok := status.FromError(err)
if !ok { if !ok {
// treat all errors that are not an rpc status as terminal. // treat all errors that are not an rpc status as terminal.
// all others poison the connection. // all others poison the connection.
shutdownErr = err r.err = filterCloseErr(err)
return return
} }
} }
select { select {
case incoming <- &message{ case r.messages <- &message{
messageHeader: mh, messageHeader: mh,
p: p[:mh.Length], p: p[:mh.Length],
err: err, err: err,
}: }:
case <-c.done: case <-ctx.Done():
r.err = ctx.Err()
return return
} }
} }
}() }
}
defer c.conn.Close() func (c *Client) run() {
defer close(c.done) var (
defer c.closeFunc() streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
)
// broadcast the shutdown error to the remaining waiters.
abortWaiters := func(wErr error) {
for _, waiter := range waiters {
waiter.errs <- wErr
}
}
recv := &receiver{
wg: &wg,
messages: incoming,
}
wg.Add(1)
go func() {
wg.Wait()
close(receiversDone)
}()
go recv.run(c.ctx, c.channel)
defer func() {
c.conn.Close()
c.userCloseFunc()
}()
for { for {
select { select {
case call := <-calls: case call := <-calls:
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil { if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err call.errs <- err
continue continue
} }
@ -223,41 +272,42 @@ func (c *Client) run() {
call.errs <- c.recv(call.resp, msg) call.errs <- c.recv(call.resp, msg)
delete(waiters, msg.StreamID) delete(waiters, msg.StreamID)
case <-shutdown: case <-receiversDone:
if shutdownErr != nil { // all the receivers have exited
shutdownErr = filterCloseErr(shutdownErr) if recv.err != nil {
} else { c.setError(recv.err)
shutdownErr = ErrClosed
}
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
c.err = shutdownErr
for _, waiter := range waiters {
waiter.errs <- shutdownErr
} }
// don't return out, let the close of the context trigger the abort of waiters
c.Close() c.Close()
return case <-c.ctx.Done():
case <-c.closed: abortWaiters(c.error())
if c.err == nil {
c.err = ErrClosed
}
// broadcast the shutdown error to the remaining waiters.
for _, waiter := range waiters {
waiter.errs <- c.err
}
return return
} }
} }
} }
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error { func (c *Client) error() error {
c.errOnce.Do(func() {
if c.err == nil {
c.err = ErrClosed
}
})
return c.err
}
func (c *Client) setError(err error) {
c.errOnce.Do(func() {
c.err = err
})
}
func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
p, err := c.codec.Marshal(msg) p, err := c.codec.Marshal(msg)
if err != nil { if err != nil {
return err return err
} }
return c.channel.send(ctx, streamID, mtype, p) return c.channel.send(streamID, mtype, p)
} }
func (c *Client) recv(resp *Response, msg *message) error { func (c *Client) recv(resp *Response, msg *message) error {
@ -278,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
// //
// This purposely ignores errors with a wrapped cause. // This purposely ignores errors with a wrapped cause.
func filterCloseErr(err error) error { func filterCloseErr(err error) error {
if err == nil { switch {
case err == nil:
return nil return nil
} case err == io.EOF:
if err == io.EOF {
return ErrClosed return ErrClosed
} case errors.Cause(err) == io.EOF:
if strings.Contains(err.Error(), "use of closed network connection") {
return ErrClosed return ErrClosed
} case strings.Contains(err.Error(), "use of closed network connection"):
return ErrClosed
// if we have an epipe on a write, we cast to errclosed default:
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" { // if we have an epipe on a write, we cast to errclosed
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE { if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
return ErrClosed if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
return ErrClosed
}
} }
} }

View File

@ -19,9 +19,11 @@ package ttrpc
import "github.com/pkg/errors" import "github.com/pkg/errors"
type serverConfig struct { type serverConfig struct {
handshaker Handshaker handshaker Handshaker
interceptor UnaryServerInterceptor
} }
// ServerOpt for configuring a ttrpc server
type ServerOpt func(*serverConfig) error type ServerOpt func(*serverConfig) error
// WithServerHandshaker can be passed to NewServer to ensure that the // WithServerHandshaker can be passed to NewServer to ensure that the
@ -37,3 +39,14 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt {
return nil return nil
} }
} }
// WithUnaryServerInterceptor sets the provided interceptor on the server
func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt {
return func(c *serverConfig) error {
if c.interceptor != nil {
return errors.New("only one interceptor allowed per server")
}
c.interceptor = i
return nil
}
}

50
vendor/github.com/containerd/ttrpc/interceptor.go generated vendored Normal file
View File

@ -0,0 +1,50 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ttrpc
import "context"
// UnaryServerInfo provides information about the server request
type UnaryServerInfo struct {
FullMethod string
}
// UnaryClientInfo provides information about the client request
type UnaryClientInfo struct {
FullMethod string
}
// Unmarshaler contains the server request data and allows it to be unmarshaled
// into a concrete type
type Unmarshaler func(interface{}) error
// Invoker invokes the client's request and response from the ttrpc server
type Invoker func(context.Context, *Request, *Response) error
// UnaryServerInterceptor specifies the interceptor function for server request/response
type UnaryServerInterceptor func(context.Context, Unmarshaler, *UnaryServerInfo, Method) (interface{}, error)
// UnaryClientInterceptor specifies the interceptor function for client request/response
type UnaryClientInterceptor func(context.Context, *Request, *Response, *UnaryClientInfo, Invoker) error
func defaultServerInterceptor(ctx context.Context, unmarshal Unmarshaler, info *UnaryServerInfo, method Method) (interface{}, error) {
return method(ctx, unmarshal)
}
func defaultClientInterceptor(ctx context.Context, req *Request, resp *Response, _ *UnaryClientInfo, invoker Invoker) error {
return invoker(ctx, req, resp)
}

View File

@ -16,53 +16,74 @@
package ttrpc package ttrpc
import "context" import (
"context"
"strings"
)
// Metadata represents the key-value pairs (similar to http.Header) to be passed to ttrpc server from a client. // MD is the user type for ttrpc metadata
type Metadata map[string]StringList type MD map[string][]string
// Get returns the metadata for a given key when they exist. // Get returns the metadata for a given key when they exist.
// If there is no metadata, a nil slice and false are returned. // If there is no metadata, a nil slice and false are returned.
func (m Metadata) Get(key string) ([]string, bool) { func (m MD) Get(key string) ([]string, bool) {
key = strings.ToLower(key)
list, ok := m[key] list, ok := m[key]
if !ok || len(list.List) == 0 { if !ok || len(list) == 0 {
return nil, false return nil, false
} }
return list.List, true return list, true
} }
// Set sets the provided values for a given key. // Set sets the provided values for a given key.
// The values will overwrite any existing values. // The values will overwrite any existing values.
// If no values provided, a key will be deleted. // If no values provided, a key will be deleted.
func (m Metadata) Set(key string, values ...string) { func (m MD) Set(key string, values ...string) {
key = strings.ToLower(key)
if len(values) == 0 { if len(values) == 0 {
delete(m, key) delete(m, key)
return return
} }
m[key] = values
m[key] = StringList{List: values}
} }
// Append appends additional values to the given key. // Append appends additional values to the given key.
func (m Metadata) Append(key string, values ...string) { func (m MD) Append(key string, values ...string) {
key = strings.ToLower(key)
if len(values) == 0 { if len(values) == 0 {
return return
} }
current, ok := m[key]
list, ok := m[key]
if ok { if ok {
m.Set(key, append(list.List, values...)...) m.Set(key, append(current, values...)...)
} else { } else {
m.Set(key, values...) m.Set(key, values...)
} }
} }
func (m MD) setRequest(r *Request) {
for k, values := range m {
for _, v := range values {
r.Metadata = append(r.Metadata, &KeyValue{
Key: k,
Value: v,
})
}
}
}
func (m MD) fromRequest(r *Request) {
for _, kv := range r.Metadata {
m[kv.Key] = append(m[kv.Key], kv.Value)
}
}
type metadataKey struct{} type metadataKey struct{}
// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata) // GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata)
func GetMetadata(ctx context.Context) (Metadata, bool) { func GetMetadata(ctx context.Context) (MD, bool) {
metadata, ok := ctx.Value(metadataKey{}).(Metadata) metadata, ok := ctx.Value(metadataKey{}).(MD)
return metadata, ok return metadata, ok
} }
@ -81,6 +102,6 @@ func GetMetadataValue(ctx context.Context, name string) (string, bool) {
} }
// WithMetadata attaches metadata map to a context.Context // WithMetadata attaches metadata map to a context.Context
func WithMetadata(ctx context.Context, headers Metadata) context.Context { func WithMetadata(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, metadataKey{}, headers) return context.WithValue(ctx, metadataKey{}, md)
} }

View File

@ -53,10 +53,13 @@ func NewServer(opts ...ServerOpt) (*Server, error) {
return nil, err return nil, err
} }
} }
if config.interceptor == nil {
config.interceptor = defaultServerInterceptor
}
return &Server{ return &Server{
config: config, config: config,
services: newServiceSet(), services: newServiceSet(config.interceptor),
done: make(chan struct{}), done: make(chan struct{}),
listeners: make(map[net.Listener]struct{}), listeners: make(map[net.Listener]struct{}),
connections: make(map[*serverConn]struct{}), connections: make(map[*serverConn]struct{}),
@ -341,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
default: // proceed default: // proceed
} }
mh, p, err := ch.recv(ctx) mh, p, err := ch.recv()
if err != nil { if err != nil {
status, ok := status.FromError(err) status, ok := status.FromError(err)
if !ok { if !ok {
@ -438,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
return return
} }
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { if err := ch.send(response.id, messageTypeResponse, p); err != nil {
logrus.WithError(err).Error("failed sending message on channel") logrus.WithError(err).Error("failed sending message on channel")
return return
} }
@ -466,8 +469,10 @@ func (c *serverConn) run(sctx context.Context) {
var noopFunc = func() {} var noopFunc = func() {}
func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) { func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
if req.Metadata != nil { if len(req.Metadata) > 0 {
ctx = WithMetadata(ctx, req.Metadata) md := MD{}
md.fromRequest(req)
ctx = WithMetadata(ctx, md)
} }
cancel = noopFunc cancel = noopFunc

View File

@ -37,12 +37,14 @@ type ServiceDesc struct {
} }
type serviceSet struct { type serviceSet struct {
services map[string]ServiceDesc services map[string]ServiceDesc
interceptor UnaryServerInterceptor
} }
func newServiceSet() *serviceSet { func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet {
return &serviceSet{ return &serviceSet{
services: make(map[string]ServiceDesc), services: make(map[string]ServiceDesc),
interceptor: interceptor,
} }
} }
@ -84,7 +86,11 @@ func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName strin
return nil return nil
} }
resp, err := method(ctx, unmarshal) info := &UnaryServerInfo{
FullMethod: fullPath(serviceName, methodName),
}
resp, err := s.interceptor(ctx, unmarshal, info, method)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -23,11 +23,11 @@ import (
) )
type Request struct { type Request struct {
Service string `protobuf:"bytes,1,opt,name=service,proto3"` Service string `protobuf:"bytes,1,opt,name=service,proto3"`
Method string `protobuf:"bytes,2,opt,name=method,proto3"` Method string `protobuf:"bytes,2,opt,name=method,proto3"`
Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"`
TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"`
Metadata Metadata `protobuf:"bytes,5,opt,name=metadata,proto3" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` Metadata []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"`
} }
func (r *Request) Reset() { *r = Request{} } func (r *Request) Reset() { *r = Request{} }
@ -52,3 +52,12 @@ func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) }
func (r *StringList) ProtoMessage() {} func (r *StringList) ProtoMessage() {}
func makeStringList(item ...string) StringList { return StringList{List: item} } func makeStringList(item ...string) StringList { return StringList{List: item} }
type KeyValue struct {
Key string `protobuf:"bytes,1,opt,name=key,proto3"`
Value string `protobuf:"bytes,2,opt,name=value,proto3"`
}
func (m *KeyValue) Reset() { *m = KeyValue{} }
func (*KeyValue) ProtoMessage() {}
func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) }