From 393cf8e8fc08156bbde50c83d762aa338c6f8951 Mon Sep 17 00:00:00 2001 From: Stephen J Day Date: Wed, 29 Nov 2017 14:13:01 -0800 Subject: [PATCH] vendor: update ttrpc with latest changes Signed-off-by: Stephen J Day --- vendor.conf | 2 +- vendor/github.com/stevvooe/ttrpc/README.md | 2 + vendor/github.com/stevvooe/ttrpc/channel.go | 52 ++- vendor/github.com/stevvooe/ttrpc/client.go | 249 +++++++------- vendor/github.com/stevvooe/ttrpc/server.go | 345 +++++++++++++++++--- 5 files changed, 478 insertions(+), 172 deletions(-) diff --git a/vendor.conf b/vendor.conf index 0316ebb83..8f164c84c 100644 --- a/vendor.conf +++ b/vendor.conf @@ -41,4 +41,4 @@ github.com/boltdb/bolt e9cf4fae01b5a8ff89d0ec6b32f0d9c9f79aefdd google.golang.org/genproto d80a6e20e776b0b17a324d0ba1ab50a39c8e8944 golang.org/x/text 19e51611da83d6be54ddafce4a4af510cb3e9ea4 github.com/dmcgowan/go-tar go1.10 -github.com/stevvooe/ttrpc bdb2ab7a8169e485e39421e666e15a505e575fd2 +github.com/stevvooe/ttrpc 8c92e22ce0c492875ccaac3ab06143a77d8ed0c1 diff --git a/vendor/github.com/stevvooe/ttrpc/README.md b/vendor/github.com/stevvooe/ttrpc/README.md index 6e2159a84..b246e578b 100644 --- a/vendor/github.com/stevvooe/ttrpc/README.md +++ b/vendor/github.com/stevvooe/ttrpc/README.md @@ -1,5 +1,7 @@ # ttrpc +[![Build Status](https://travis-ci.org/stevvooe/ttrpc.svg?branch=master)](https://travis-ci.org/stevvooe/ttrpc) + GRPC for low-memory environments. The existing grpc-go project requires a lot of memory overhead for importing diff --git a/vendor/github.com/stevvooe/ttrpc/channel.go b/vendor/github.com/stevvooe/ttrpc/channel.go index a71260bcc..4a33827a4 100644 --- a/vendor/github.com/stevvooe/ttrpc/channel.go +++ b/vendor/github.com/stevvooe/ttrpc/channel.go @@ -5,13 +5,16 @@ import ( "context" "encoding/binary" "io" + "sync" "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const ( messageHeaderLength = 10 - messageLengthMax = 8 << 10 + messageLengthMax = 4 << 20 ) type messageType uint8 @@ -54,6 +57,8 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error { return err } +var buffers sync.Pool + type channel struct { bw *bufio.Writer br *bufio.Reader @@ -68,21 +73,32 @@ func newChannel(w io.Writer, r io.Reader) *channel { } } -func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) { +// recv a message from the channel. The returned buffer contains the message. +// +// If a valid grpc status is returned, the message header +// returned will be valid and caller should send that along to +// the correct consumer. The bytes on the underlying channel +// will be discarded. 
+func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { mh, err := readMessageHeader(ch.hrbuf[:], ch.br) if err != nil { - return messageHeader{}, err + return messageHeader{}, nil, err } - if mh.Length > uint32(len(p)) { - return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p)) + if mh.Length > uint32(messageLengthMax) { + if _, err := ch.br.Discard(int(mh.Length)); err != nil { + return mh, nil, errors.Wrapf(err, "failed to discard after receiving oversized message") + } + + return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) } - if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil { - return messageHeader{}, errors.Wrapf(err, "failed reading message") + p := ch.getmbuf(int(mh.Length)) + if _, err := io.ReadFull(ch.br, p); err != nil { + return messageHeader{}, nil, errors.Wrapf(err, "failed reading message") } - return mh, nil + return mh, p, nil } func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { @@ -97,3 +113,23 @@ func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p [ return ch.bw.Flush() } + +func (ch *channel) getmbuf(size int) []byte { + // we can't use the standard New method on pool because we want to allocate + // based on size. + b, ok := buffers.Get().(*[]byte) + if !ok || cap(*b) < size { + // TODO(stevvooe): It may be better to allocate these in fixed length + // buckets to reduce fragmentation but its not clear that would help + // with performance. An ilogb approach or similar would work well. + bb := make([]byte, size) + b = &bb + } else { + *b = (*b)[:size] + } + return *b +} + +func (ch *channel) putmbuf(p []byte) { + buffers.Put(&p) +} diff --git a/vendor/github.com/stevvooe/ttrpc/client.go b/vendor/github.com/stevvooe/ttrpc/client.go index 9dbb3d3b4..ca76afe19 100644 --- a/vendor/github.com/stevvooe/ttrpc/client.go +++ b/vendor/github.com/stevvooe/ttrpc/client.go @@ -4,19 +4,18 @@ import ( "context" "net" "sync" - "sync/atomic" + "github.com/containerd/containerd/log" "github.com/gogo/protobuf/proto" "github.com/pkg/errors" "google.golang.org/grpc/status" ) type Client struct { - codec codec - channel *channel - requestID uint32 - sendRequests chan sendRequest - recvRequests chan recvRequest + codec codec + conn net.Conn + channel *channel + calls chan *callRequest closed chan struct{} closeOnce sync.Once @@ -26,50 +25,76 @@ type Client struct { func NewClient(conn net.Conn) *Client { c := &Client{ - codec: codec{}, - requestID: 1, - channel: newChannel(conn, conn), - sendRequests: make(chan sendRequest), - recvRequests: make(chan recvRequest), - closed: make(chan struct{}), - done: make(chan struct{}), + codec: codec{}, + conn: conn, + channel: newChannel(conn, conn), + calls: make(chan *callRequest), + closed: make(chan struct{}), + done: make(chan struct{}), } go c.run() return c } +type callRequest struct { + ctx context.Context + req *Request + resp *Response // response will be written back here + errs chan error // error written here on completion +} + func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { payload, err := c.codec.Marshal(req) if err != nil { return err } - requestID := atomic.AddUint32(&c.requestID, 2) - request := Request{ - Service: service, - Method: method, - Payload: payload, - } + var ( + creq = &Request{ + Service: service, + Method: method, + Payload: 
payload, + } - if err := c.send(ctx, requestID, &request); err != nil { + cresp = &Response{} + ) + + if err := c.dispatch(ctx, creq, cresp); err != nil { return err } - var response Response - if err := c.recv(ctx, requestID, &response); err != nil { + if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil { return err } - if err := c.codec.Unmarshal(response.Payload, resp); err != nil { - return err - } - - if response.Status == nil { + if cresp.Status == nil { return errors.New("no status provided on response") } - return status.ErrorProto(response.Status) + return status.ErrorProto(cresp.Status) +} + +func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { + errs := make(chan error, 1) + call := &callRequest{ + req: req, + resp: resp, + errs: errs, + } + + select { + case c.calls <- call: + case <-c.done: + return c.err + } + + select { + case err := <-errs: + return err + case <-c.done: + return c.err + } } func (c *Client) Close() error { @@ -80,92 +105,42 @@ func (c *Client) Close() error { return nil } -type sendRequest struct { - ctx context.Context - id uint32 - msg interface{} - err chan error -} - -func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error { - errs := make(chan error, 1) - select { - case c.sendRequests <- sendRequest{ - ctx: ctx, - id: id, - msg: msg, - err: errs, - }: - case <-ctx.Done(): - return ctx.Err() - case <-c.done: - return c.err - } - - select { - case err := <-errs: - return err - case <-ctx.Done(): - return ctx.Err() - case <-c.done: - return c.err - } -} - -type recvRequest struct { - id uint32 - msg interface{} - err chan error -} - -func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error { - errs := make(chan error, 1) - select { - case c.recvRequests <- recvRequest{ - id: id, - msg: msg, - err: errs, - }: - case <-c.done: - return c.err - case <-ctx.Done(): - return ctx.Err() - } - - select { - case err := <-errs: - return err - case <-c.done: - return c.err - case <-ctx.Done(): - return ctx.Err() - } -} - -type received struct { - mh messageHeader +type message struct { + messageHeader p []byte err error } func (c *Client) run() { - defer close(c.done) var ( - waiters = map[uint32]recvRequest{} - queued = map[uint32]received{} // messages unmatched by waiter - incoming = make(chan received) + streamID uint32 = 1 + waiters = make(map[uint32]*callRequest) + calls = c.calls + incoming = make(chan *message) + shutdown = make(chan struct{}) + shutdownErr error ) go func() { + defer close(shutdown) + // start one more goroutine to recv messages without blocking. for { - var p [messageLengthMax]byte - mh, err := c.channel.recv(context.TODO(), p[:]) + mh, p, err := c.channel.recv(context.TODO()) + if err != nil { + _, ok := status.FromError(err) + if !ok { + // treat all errors that are not an rpc status as terminal. + // all others poison the connection. 
+ shutdownErr = err + return + } + } select { - case incoming <- received{ - mh: mh, - p: p[:mh.Length], - err: err, + case incoming <- &message{ + messageHeader: mh, + p: p[:mh.Length], + err: err, }: case <-c.done: return @@ -173,32 +148,64 @@ func (c *Client) run() { } }() + defer c.conn.Close() + defer close(c.done) + for { select { - case req := <-c.sendRequests: - if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil { - req.err <- err - } else { - req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p) - } - case req := <-c.recvRequests: - if r, ok := queued[req.id]; ok { - req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message)) - } - waiters[req.id] = req - case r := <-incoming: - if r.err != nil { - c.err = r.err - return + case call := <-calls: + if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil { + call.errs <- err + continue } - if waiter, ok := waiters[r.mh.StreamID]; ok { - waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message)) - } else { - queued[r.mh.StreamID] = r + waiters[streamID] = call + streamID += 2 // enforce odd client initiated request ids + case msg := <-incoming: + call, ok := waiters[msg.StreamID] + if !ok { + log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID) + continue } + + call.errs <- c.recv(call.resp, msg) + delete(waiters, msg.StreamID) + case <-shutdown: + shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down") + c.err = shutdownErr + for _, waiter := range waiters { + waiter.errs <- shutdownErr + } + c.Close() + return case <-c.closed: + // broadcast the shutdown error to the remaining waiters. + for _, waiter := range waiters { + waiter.errs <- shutdownErr + } return } } } + +func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error { + p, err := c.codec.Marshal(msg) + if err != nil { + return err + } + + return c.channel.send(ctx, streamID, mtype, p) +} + +func (c *Client) recv(resp *Response, msg *message) error { + if msg.err != nil { + return msg.err + } + + if msg.Type != messageTypeResponse { + return errors.New("unkown message type received") + } + + defer c.channel.putmbuf(msg.p) + return proto.Unmarshal(msg.p, resp) +} diff --git a/vendor/github.com/stevvooe/ttrpc/server.go b/vendor/github.com/stevvooe/ttrpc/server.go index 407068fc3..ed2d14cf7 100644 --- a/vendor/github.com/stevvooe/ttrpc/server.go +++ b/vendor/github.com/stevvooe/ttrpc/server.go @@ -2,21 +2,38 @@ package ttrpc import ( "context" + "math/rand" "net" + "sync" + "sync/atomic" + "time" "github.com/containerd/containerd/log" + "github.com/pkg/errors" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) +var ( + ErrServerClosed = errors.New("ttrpc: server close") +) + type Server struct { services *serviceSet codec codec + + mu sync.Mutex + listeners map[net.Listener]struct{} + connections map[*serverConn]struct{} // all connections to current state + done chan struct{} // marks point at which we stop serving requests } func NewServer() *Server { return &Server{ - services: newServiceSet(), + services: newServiceSet(), + done: make(chan struct{}), + listeners: make(map[net.Listener]struct{}), + connections: make(map[*serverConn]struct{}), } } @@ -24,28 +41,208 @@ func (s *Server) Register(name string, methods map[string]Method) { s.services.register(name, methods) } -func (s *Server) Shutdown(ctx context.Context) error { - // TODO(stevvooe): Wait on connection shutdown. 
- return nil -} - func (s *Server) Serve(l net.Listener) error { + s.addListener(l) + defer s.closeListener(l) + + var ( + ctx = context.Background() + backoff time.Duration + ) + for { conn, err := l.Accept() if err != nil { - log.L.WithError(err).Error("failed accept") - continue + select { + case <-s.done: + return ErrServerClosed + default: + } + + if terr, ok := err.(interface { + Temporary() bool + }); ok && terr.Temporary() { + if backoff == 0 { + backoff = time.Millisecond + } else { + backoff *= 2 + } + + if max := time.Second; backoff > max { + backoff = max + } + + sleep := time.Duration(rand.Int63n(int64(backoff))) + log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep) + time.Sleep(sleep) + continue + } + + return err } - go s.handleConn(conn) + backoff = 0 + sc := s.newConn(conn) + go sc.run(ctx) } +} + +func (s *Server) Shutdown(ctx context.Context) error { + s.mu.Lock() + lnerr := s.closeListeners() + select { + case <-s.done: + default: + // protected by mutex + close(s.done) + } + s.mu.Unlock() + + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + for { + if s.closeIdleConns() { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// Close the server without waiting for active connections. +func (s *Server) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + select { + case <-s.done: + default: + // protected by mutex + close(s.done) + } + + err := s.closeListeners() + for c := range s.connections { + c.close() + delete(s.connections, c) + } + + return err +} + +func (s *Server) addListener(l net.Listener) { + s.mu.Lock() + defer s.mu.Unlock() + s.listeners[l] = struct{}{} +} + +func (s *Server) closeListener(l net.Listener) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.closeListenerLocked(l) +} + +func (s *Server) closeListenerLocked(l net.Listener) error { + defer delete(s.listeners, l) + return l.Close() +} + +func (s *Server) closeListeners() error { + var err error + for l := range s.listeners { + if cerr := s.closeListenerLocked(l); cerr != nil && err == nil { + err = cerr + } + } + return err +} + +func (s *Server) addConnection(c *serverConn) { + s.mu.Lock() + defer s.mu.Unlock() + + s.connections[c] = struct{}{} +} + +func (s *Server) closeIdleConns() bool { + s.mu.Lock() + defer s.mu.Unlock() + quiescent := true + for c := range s.connections { + st, ok := c.getState() + if !ok || st != connStateIdle { + quiescent = false + continue + } + c.close() + delete(s.connections, c) + } + return quiescent +} + +type connState int + +const ( + connStateActive = iota + 1 // outstanding requests + connStateIdle // no requests + connStateClosed // closed connection +) + +func (cs connState) String() string { + switch cs { + case connStateActive: + return "active" + case connStateIdle: + return "idle" + case connStateClosed: + return "closed" + default: + return "unknown" + } +} + +func (s *Server) newConn(conn net.Conn) *serverConn { + c := &serverConn{ + server: s, + conn: conn, + shutdown: make(chan struct{}), + } + c.setState(connStateIdle) + s.addConnection(c) + return c +} + +type serverConn struct { + server *Server + conn net.Conn + state atomic.Value + + shutdownOnce sync.Once + shutdown chan struct{} // forced shutdown, used by close +} + +func (c *serverConn) getState() (connState, bool) { + cs, ok := c.state.Load().(connState) + return cs, ok +} + +func (c *serverConn) setState(newstate connState) { + c.state.Store(newstate) +} + +func (c *serverConn) close() 
error { + c.shutdownOnce.Do(func() { + close(c.shutdown) + }) return nil } -func (s *Server) handleConn(conn net.Conn) { - defer conn.Close() - +func (c *serverConn) run(sctx context.Context) { type ( request struct { id uint32 @@ -59,25 +256,66 @@ func (s *Server) handleConn(conn net.Conn) { ) var ( - ch = newChannel(conn, conn) - ctx, cancel = context.WithCancel(context.Background()) - responses = make(chan response) - requests = make(chan request) - recvErr = make(chan error, 1) - done = make(chan struct{}) + ch = newChannel(c.conn, c.conn) + ctx, cancel = context.WithCancel(sctx) + active int + state connState = connStateIdle + responses = make(chan response) + requests = make(chan request) + recvErr = make(chan error, 1) + shutdown = c.shutdown + done = make(chan struct{}) ) + defer c.conn.Close() defer cancel() defer close(done) - go func() { + go func(recvErr chan error) { defer close(recvErr) - var p [messageLengthMax]byte + sendImmediate := func(id uint32, st *status.Status) bool { + select { + case responses <- response{ + // even though we've had an invalid stream id, we send it + // back on the same stream id so the client knows which + // stream id was bad. + id: id, + resp: &Response{ + Status: st.Proto(), + }, + }: + return true + case <-c.shutdown: + return false + case <-done: + return false + } + } + for { - mh, err := ch.recv(ctx, p[:]) - if err != nil { - recvErr <- err + select { + case <-c.shutdown: return + case <-done: + return + default: // proceed + } + + mh, p, err := ch.recv(ctx) + if err != nil { + status, ok := status.FromError(err) + if !ok { + recvErr <- err + return + } + + // in this case, we send an error for that particular message + // when the status is defined. + if !sendImmediate(mh.StreamID, status) { + return + } + + continue } if mh.Type != messageTypeRequest { @@ -86,44 +324,57 @@ func (s *Server) handleConn(conn net.Conn) { } var req Request - if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil { - recvErr <- err - return + if err := c.server.codec.Unmarshal(p, &req); err != nil { + ch.putmbuf(p) + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { + return + } + continue } + ch.putmbuf(p) if mh.StreamID%2 != 1 { // enforce odd client initiated identifiers. - select { - case responses <- response{ - // even though we've had an invalid stream id, we send it - // back on the same stream id so the client knows which - // stream id was bad. - id: mh.StreamID, - resp: &Response{ - Status: status.New(codes.InvalidArgument, "StreamID must be odd for client initiated streams").Proto(), - }, - }: - case <-done: + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) { + return } - continue } + // Forward the request to the main loop. We don't wait on s.done + // because we have already accepted the client request. 
select { case requests <- request{ id: mh.StreamID, req: &req, }: case <-done: + return } } - }() + }(recvErr) for { + newstate := state + switch { + case active > 0: + newstate = connStateActive + shutdown = nil + case active == 0: + newstate = connStateIdle + shutdown = c.shutdown // only enable this branch in idle mode + } + + if newstate != state { + c.setState(newstate) + state = newstate + } + select { case request := <-requests: + active++ go func(id uint32) { - p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) + p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) resp := &Response{ Status: status.Proto(), Payload: p, @@ -138,17 +389,27 @@ func (s *Server) handleConn(conn net.Conn) { } }(request.id) case response := <-responses: - p, err := s.codec.Marshal(response.resp) + p, err := c.server.codec.Marshal(response.resp) if err != nil { log.L.WithError(err).Error("failed marshaling response") return } + if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { log.L.WithError(err).Error("failed sending message on channel") return } + + active-- case err := <-recvErr: - log.L.WithError(err).Error("error receiving message") + // TODO(stevvooe): Not wildly clear what we should do in this + // branch. Basically, it means that we are no longer receiving + // requests due to a terminal error. + recvErr = nil // connection is now "closing" + if err != nil { + log.L.WithError(err).Error("error receiving message") + } + case <-shutdown: return } }
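
The rewritten client above multiplexes all calls over a single connection: Call hands a callRequest to one dispatch loop, which assigns the next odd stream ID (incremented by two per request) and routes the matching response back to the waiting caller. The sketch below is not part of the vendored code; it only illustrates how that surface might be driven. The socket path, the "example.v1.Service"/"Get" names, and the pb.GetRequest/pb.GetResponse types are hypothetical placeholders for whatever protobuf-generated types a real caller would use.

package main

import (
	"context"
	"log"
	"net"
	"sync"

	"github.com/stevvooe/ttrpc"

	pb "example.com/project/api" // hypothetical generated protobuf package
)

func main() {
	// Placeholder socket path; any net.Conn works, the client only needs a
	// byte stream it can frame messages over.
	conn, err := net.Dial("unix", "/run/example/ttrpc.sock")
	if err != nil {
		log.Fatal(err)
	}

	client := ttrpc.NewClient(conn)
	defer client.Close()

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			req := &pb.GetRequest{} // hypothetical request type
			var resp pb.GetResponse // hypothetical response type

			// Each Call is queued on the dispatch loop, given its own odd
			// stream ID, and completed when the matching response arrives.
			// A non-OK status on the response surfaces here as a grpc
			// status error.
			if err := client.Call(context.Background(), "example.v1.Service", "Get", req, &resp); err != nil {
				log.Printf("call %d failed: %v", i, err)
			}
		}(i)
	}
	wg.Wait()
}

Note that this update also raises messageLengthMax from 8 KiB to 4 MiB, and an oversized message is now discarded and answered on the same stream with a ResourceExhausted status instead of poisoning the connection, so a caller should see that from Call as a status error (status.FromError(err) reporting codes.ResourceExhausted) rather than a broken connection.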
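On the server side, the stubbed-out Shutdown is replaced with real listener and connection tracking: Serve backs off on temporary accept errors and returns ErrServerClosed once the server is stopped, Shutdown closes the listeners and then polls every 200ms for connections to go idle, and Close tears everything down without waiting. The lifecycle sketch below is likewise illustrative rather than part of the patch; the socket path is a placeholder, and the Register call is left commented out because the Method type and service definitions live outside this diff.

package main

import (
	"context"
	"log"
	"net"
	"time"

	"github.com/stevvooe/ttrpc"
)

func main() {
	// Placeholder socket path.
	l, err := net.Listen("unix", "/run/example/ttrpc.sock")
	if err != nil {
		log.Fatal(err)
	}

	s := ttrpc.NewServer()
	// s.Register("example.v1.Service", methods) // service wiring omitted here

	errs := make(chan error, 1)
	go func() {
		// Serve blocks until the listener fails permanently or the server is
		// stopped, in which case it returns ttrpc.ErrServerClosed.
		errs <- s.Serve(l)
	}()

	// ... run for a while ...
	time.Sleep(time.Second)

	// Graceful stop: close the listeners, then wait (bounded by the context)
	// for every tracked connection to report the idle state.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := s.Shutdown(ctx); err != nil {
		// Deadline expired with requests still in flight: fall back to the
		// hard stop that closes connections without waiting.
		s.Close()
	}

	if err := <-errs; err != nil && err != ttrpc.ErrServerClosed {
		log.Printf("serve exited: %v", err)
	}
}

Each accepted connection runs in its own goroutine with a small active/idle/closed state machine, and Shutdown's closeIdleConns only reaps connections reporting idle, which is why long-running requests keep a graceful shutdown from completing until the context expires.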