vendor: update ttrpc with latest changes

Signed-off-by: Stephen J Day <stephen.day@docker.com>
Author: Stephen J Day
Date:   2017-11-29 14:13:01 -08:00
parent 59bd196711
commit 393cf8e8fc
5 changed files with 478 additions and 172 deletions

vendor.conf

@@ -41,4 +41,4 @@ github.com/boltdb/bolt e9cf4fae01b5a8ff89d0ec6b32f0d9c9f79aefdd
 google.golang.org/genproto d80a6e20e776b0b17a324d0ba1ab50a39c8e8944
 golang.org/x/text 19e51611da83d6be54ddafce4a4af510cb3e9ea4
 github.com/dmcgowan/go-tar go1.10
-github.com/stevvooe/ttrpc bdb2ab7a8169e485e39421e666e15a505e575fd2
+github.com/stevvooe/ttrpc 8c92e22ce0c492875ccaac3ab06143a77d8ed0c1

vendor/github.com/stevvooe/ttrpc/README.md

@@ -1,5 +1,7 @@
 # ttrpc
 
+[![Build Status](https://travis-ci.org/stevvooe/ttrpc.svg?branch=master)](https://travis-ci.org/stevvooe/ttrpc)
+
 GRPC for low-memory environments.
 
 The existing grpc-go project requires a lot of memory overhead for importing

vendor/github.com/stevvooe/ttrpc/channel.go

@@ -5,13 +5,16 @@ import (
 	"context"
 	"encoding/binary"
 	"io"
+	"sync"
 
 	"github.com/pkg/errors"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 const (
 	messageHeaderLength = 10
-	messageLengthMax    = 8 << 10
+	messageLengthMax    = 4 << 20
 )
 
 type messageType uint8
@@ -54,6 +57,8 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
 	return err
 }
 
+var buffers sync.Pool
+
 type channel struct {
 	bw    *bufio.Writer
 	br    *bufio.Reader
@@ -68,21 +73,32 @@ func newChannel(w io.Writer, r io.Reader) *channel {
 	}
 }
 
-func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) {
+// recv a message from the channel. The returned buffer contains the message.
+//
+// If a valid grpc status is returned, the message header
+// returned will be valid and caller should send that along to
+// the correct consumer. The bytes on the underlying channel
+// will be discarded.
+func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
 	mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
 	if err != nil {
-		return messageHeader{}, err
+		return messageHeader{}, nil, err
 	}
 
-	if mh.Length > uint32(len(p)) {
-		return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p))
+	if mh.Length > uint32(messageLengthMax) {
+		if _, err := ch.br.Discard(int(mh.Length)); err != nil {
+			return mh, nil, errors.Wrapf(err, "failed to discard after receiving oversized message")
+		}
+
+		return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
 	}
 
-	if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil {
-		return messageHeader{}, errors.Wrapf(err, "failed reading message")
+	p := ch.getmbuf(int(mh.Length))
+	if _, err := io.ReadFull(ch.br, p); err != nil {
+		return messageHeader{}, nil, errors.Wrapf(err, "failed reading message")
	}
 
-	return mh, nil
+	return mh, p, nil
 }
 
 func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
@@ -97,3 +113,23 @@ func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
 
 	return ch.bw.Flush()
 }
+
+func (ch *channel) getmbuf(size int) []byte {
+	// we can't use the standard New method on pool because we want to allocate
+	// based on size.
+	b, ok := buffers.Get().(*[]byte)
+	if !ok || cap(*b) < size {
+		// TODO(stevvooe): It may be better to allocate these in fixed length
+		// buckets to reduce fragmentation but its not clear that would help
+		// with performance. An ilogb approach or similar would work well.
+		bb := make([]byte, size)
+		b = &bb
+	} else {
+		*b = (*b)[:size]
+	}
+
+	return *b
+}
+
+func (ch *channel) putmbuf(p []byte) {
+	buffers.Put(&p)
+}

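The channel.go changes above do two things: the caller-supplied receive buffer is replaced by a size-aware `sync.Pool` (`getmbuf`/`putmbuf`), and a message larger than the new 4 MB `messageLengthMax` is discarded and reported as a `codes.ResourceExhausted` grpc status instead of poisoning the connection. Below is a minimal, self-contained sketch of the same pooling pattern; the `pool`/`getBuf`/`putBuf` names are illustrative only and are not part of the ttrpc API.

```go
package main

import (
	"fmt"
	"sync"
)

// pool stores *[]byte values; keeping a pointer in the pool avoids an extra
// allocation when the slice header would otherwise be boxed into an interface.
var pool sync.Pool

// getBuf returns a slice of exactly size bytes, reusing a pooled slice when
// its capacity is large enough, mirroring channel.getmbuf above.
func getBuf(size int) []byte {
	b, ok := pool.Get().(*[]byte)
	if !ok || cap(*b) < size {
		// nothing suitable pooled; allocate fresh
		return make([]byte, size)
	}
	return (*b)[:size]
}

// putBuf hands a slice back for possible reuse, mirroring channel.putmbuf.
func putBuf(p []byte) {
	pool.Put(&p)
}

func main() {
	b := getBuf(1024)
	copy(b, "payload")
	fmt.Println(len(b), cap(b)) // 1024 1024
	putBuf(b)

	// A later, smaller request may now be satisfied from the pooled slice
	// (sync.Pool gives no guarantee that it will be).
	c := getBuf(16)
	fmt.Println(len(c), cap(c) >= 16) // 16 true
	putBuf(c)
}
```

Like the vendored code, this trades exact-size allocations for reuse of whatever capacity happens to be pooled; the TODO in getmbuf notes that fixed-size buckets could reduce fragmentation further.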
vendor/github.com/stevvooe/ttrpc/client.go

@@ -4,19 +4,18 @@ import (
 	"context"
 	"net"
 	"sync"
-	"sync/atomic"
 
+	"github.com/containerd/containerd/log"
 	"github.com/gogo/protobuf/proto"
 	"github.com/pkg/errors"
 	"google.golang.org/grpc/status"
 )
 
 type Client struct {
 	codec   codec
-	channel      *channel
-	requestID    uint32
-	sendRequests chan sendRequest
-	recvRequests chan recvRequest
+	conn    net.Conn
+	channel *channel
+	calls   chan *callRequest
 
 	closed    chan struct{}
 	closeOnce sync.Once
@@ -26,50 +25,76 @@ type Client struct {
 func NewClient(conn net.Conn) *Client {
 	c := &Client{
 		codec:   codec{},
-		requestID:    1,
+		conn:    conn,
 		channel: newChannel(conn, conn),
-		sendRequests: make(chan sendRequest),
-		recvRequests: make(chan recvRequest),
+		calls:   make(chan *callRequest),
 		closed:  make(chan struct{}),
 		done:    make(chan struct{}),
 	}
 
 	go c.run()
 	return c
 }
 
+type callRequest struct {
+	ctx  context.Context
+	req  *Request
+	resp *Response  // response will be written back here
+	errs chan error // error written here on completion
+}
+
 func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
 	payload, err := c.codec.Marshal(req)
 	if err != nil {
 		return err
 	}
 
-	requestID := atomic.AddUint32(&c.requestID, 2)
-	request := Request{
-		Service: service,
-		Method:  method,
-		Payload: payload,
-	}
+	var (
+		creq = &Request{
+			Service: service,
+			Method:  method,
+			Payload: payload,
+		}
 
-	if err := c.send(ctx, requestID, &request); err != nil {
-		return err
-	}
+		cresp = &Response{}
+	)
 
-	var response Response
-	if err := c.recv(ctx, requestID, &response); err != nil {
+	if err := c.dispatch(ctx, creq, cresp); err != nil {
 		return err
 	}
 
-	if err := c.codec.Unmarshal(response.Payload, resp); err != nil {
+	if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
 		return err
 	}
 
-	if response.Status == nil {
+	if cresp.Status == nil {
 		return errors.New("no status provided on response")
 	}
 
-	return status.ErrorProto(response.Status)
+	return status.ErrorProto(cresp.Status)
+}
+
+func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
+	errs := make(chan error, 1)
+	call := &callRequest{
+		req:  req,
+		resp: resp,
+		errs: errs,
+	}
+
+	select {
+	case c.calls <- call:
+	case <-c.done:
+		return c.err
+	}
+
+	select {
+	case err := <-errs:
+		return err
+	case <-c.done:
+		return c.err
+	}
 }
 
 func (c *Client) Close() error {
@@ -80,92 +105,42 @@ func (c *Client) Close() error {
 	return nil
 }
 
-type sendRequest struct {
-	ctx context.Context
-	id  uint32
-	msg interface{}
-	err chan error
-}
-
-func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
-	errs := make(chan error, 1)
-	select {
-	case c.sendRequests <- sendRequest{
-		ctx: ctx,
-		id:  id,
-		msg: msg,
-		err: errs,
-	}:
-	case <-ctx.Done():
-		return ctx.Err()
-	case <-c.done:
-		return c.err
-	}
-
-	select {
-	case err := <-errs:
-		return err
-	case <-ctx.Done():
-		return ctx.Err()
-	case <-c.done:
-		return c.err
-	}
-}
-
-type recvRequest struct {
-	id  uint32
-	msg interface{}
-	err chan error
-}
-
-func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
-	errs := make(chan error, 1)
-	select {
-	case c.recvRequests <- recvRequest{
-		id:  id,
-		msg: msg,
-		err: errs,
-	}:
-	case <-c.done:
-		return c.err
-	case <-ctx.Done():
-		return ctx.Err()
-	}
-
-	select {
-	case err := <-errs:
-		return err
-	case <-c.done:
-		return c.err
-	case <-ctx.Done():
-		return ctx.Err()
-	}
-}
-
-type received struct {
-	mh  messageHeader
+type message struct {
+	messageHeader
 	p   []byte
 	err error
 }
 
 func (c *Client) run() {
-	defer close(c.done)
-
 	var (
-		waiters  = map[uint32]recvRequest{}
-		queued   = map[uint32]received{} // messages unmatched by waiter
-		incoming = make(chan received)
+		streamID    uint32 = 1
+		waiters     = make(map[uint32]*callRequest)
+		calls       = c.calls
+		incoming    = make(chan *message)
+		shutdown    = make(chan struct{})
+		shutdownErr error
 	)
 
 	go func() {
+		defer close(shutdown)
+
 		// start one more goroutine to recv messages without blocking.
 		for {
-			var p [messageLengthMax]byte
-			mh, err := c.channel.recv(context.TODO(), p[:])
+			mh, p, err := c.channel.recv(context.TODO())
+			if err != nil {
+				_, ok := status.FromError(err)
+				if !ok {
+					// treat all errors that are not an rpc status as terminal.
+					// all others poison the connection.
+					shutdownErr = err
+					return
+				}
+			}
+
 			select {
-			case incoming <- received{
-				mh:  mh,
+			case incoming <- &message{
+				messageHeader: mh,
 				p:   p[:mh.Length],
 				err: err,
 			}:
 			case <-c.done:
 				return
@@ -173,32 +148,64 @@ func (c *Client) run() {
 		}
 	}()
 
+	defer c.conn.Close()
+	defer close(c.done)
+
 	for {
 		select {
-		case req := <-c.sendRequests:
-			if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
-				req.err <- err
-			} else {
-				req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
-			}
-		case req := <-c.recvRequests:
-			if r, ok := queued[req.id]; ok {
-				req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
-			}
-			waiters[req.id] = req
-		case r := <-incoming:
-			if r.err != nil {
-				c.err = r.err
-				return
+		case call := <-calls:
+			if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
+				call.errs <- err
+				continue
 			}
 
-			if waiter, ok := waiters[r.mh.StreamID]; ok {
-				waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
-			} else {
-				queued[r.mh.StreamID] = r
+			waiters[streamID] = call
+			streamID += 2 // enforce odd client initiated request ids
+		case msg := <-incoming:
+			call, ok := waiters[msg.StreamID]
+			if !ok {
+				log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
+				continue
 			}
+
+			call.errs <- c.recv(call.resp, msg)
+			delete(waiters, msg.StreamID)
+		case <-shutdown:
+			shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
+			c.err = shutdownErr
+			for _, waiter := range waiters {
+				waiter.errs <- shutdownErr
+			}
+			c.Close()
+			return
 		case <-c.closed:
+			// broadcast the shutdown error to the remaining waiters.
+			for _, waiter := range waiters {
+				waiter.errs <- shutdownErr
+			}
 			return
 		}
 	}
 }
+
+func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
+	p, err := c.codec.Marshal(msg)
+	if err != nil {
+		return err
+	}
+
+	return c.channel.send(ctx, streamID, mtype, p)
+}
+
+func (c *Client) recv(resp *Response, msg *message) error {
+	if msg.err != nil {
+		return msg.err
+	}
+
+	if msg.Type != messageTypeResponse {
+		return errors.New("unkown message type received")
+	}
+
+	defer c.channel.putmbuf(msg.p)
+	return proto.Unmarshal(msg.p, resp)
+}

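client.go replaces the separate send/recv request channels with a single dispatch path: Call marshals the request into a callRequest, the run() goroutine assigns the next odd stream ID, writes the frame, and completes the waiting call when a response with the same stream ID arrives. A rough usage sketch against this surface follows; the socket path and the example.v1.Pinger/Ping service and method names are invented for illustration, and the request/response types reuse gogo/protobuf's Empty so the sketch stays self-contained.

```go
package main

import (
	"context"
	"log"
	"net"
	"time"

	"github.com/gogo/protobuf/types"
	"github.com/stevvooe/ttrpc"
)

func main() {
	// Dial the unix socket a ttrpc server is assumed to listen on.
	conn, err := net.DialTimeout("unix", "/run/example/ttrpc.sock", 10*time.Second)
	if err != nil {
		log.Fatal(err)
	}

	client := ttrpc.NewClient(conn)
	defer client.Close()

	// Call marshals the request, queues it on the client's calls channel via
	// dispatch, and blocks until run() matches the response stream ID.
	var resp types.Empty
	if err := client.Call(context.Background(), "example.v1.Pinger", "Ping", &types.Empty{}, &resp); err != nil {
		log.Fatal(err)
	}
	log.Println("ping ok")
}
```

Because stream IDs are handed out inside run(), concurrent Call invocations are serialized through the calls channel rather than coordinated with sync/atomic as before, and a terminal receive error is broadcast to every outstanding waiter when the connection shuts down.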
vendor/github.com/stevvooe/ttrpc/server.go

@@ -2,21 +2,38 @@ package ttrpc
 
 import (
 	"context"
+	"math/rand"
 	"net"
+	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/containerd/containerd/log"
+	"github.com/pkg/errors"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
+var (
+	ErrServerClosed = errors.New("ttrpc: server close")
+)
+
 type Server struct {
 	services *serviceSet
 	codec    codec
+
+	mu          sync.Mutex
+	listeners   map[net.Listener]struct{}
+	connections map[*serverConn]struct{} // all connections to current state
+	done        chan struct{}            // marks point at which we stop serving requests
 }
 
 func NewServer() *Server {
 	return &Server{
 		services: newServiceSet(),
+		done:        make(chan struct{}),
+		listeners:   make(map[net.Listener]struct{}),
+		connections: make(map[*serverConn]struct{}),
 	}
 }
@@ -24,28 +41,208 @@ func (s *Server) Register(name string, methods map[string]Method) {
 	s.services.register(name, methods)
 }
 
-func (s *Server) Shutdown(ctx context.Context) error {
-	// TODO(stevvooe): Wait on connection shutdown.
-	return nil
-}
-
 func (s *Server) Serve(l net.Listener) error {
+	s.addListener(l)
+	defer s.closeListener(l)
+
+	var (
+		ctx     = context.Background()
+		backoff time.Duration
+	)
+
 	for {
 		conn, err := l.Accept()
 		if err != nil {
-			log.L.WithError(err).Error("failed accept")
-			continue
+			select {
+			case <-s.done:
+				return ErrServerClosed
+			default:
+			}
+
+			if terr, ok := err.(interface {
+				Temporary() bool
+			}); ok && terr.Temporary() {
+				if backoff == 0 {
+					backoff = time.Millisecond
+				} else {
+					backoff *= 2
+				}
+
+				if max := time.Second; backoff > max {
+					backoff = max
+				}
+
+				sleep := time.Duration(rand.Int63n(int64(backoff)))
+				log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
+				time.Sleep(sleep)
+				continue
+			}
+
+			return err
 		}
 
-		go s.handleConn(conn)
+		backoff = 0
+		sc := s.newConn(conn)
+		go sc.run(ctx)
 	}
 }
 
+func (s *Server) Shutdown(ctx context.Context) error {
+	s.mu.Lock()
+	lnerr := s.closeListeners()
+	select {
+	case <-s.done:
+	default:
+		// protected by mutex
+		close(s.done)
+	}
+	s.mu.Unlock()
+
+	ticker := time.NewTicker(200 * time.Millisecond)
+	defer ticker.Stop()
+	for {
+		if s.closeIdleConns() {
+			return lnerr
+		}
+
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+		}
+	}
+}
+
+// Close the server without waiting for active connections.
+func (s *Server) Close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	select {
+	case <-s.done:
+	default:
+		// protected by mutex
+		close(s.done)
+	}
+
+	err := s.closeListeners()
+	for c := range s.connections {
+		c.close()
+		delete(s.connections, c)
+	}
+
+	return err
+}
+
+func (s *Server) addListener(l net.Listener) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.listeners[l] = struct{}{}
+}
+
+func (s *Server) closeListener(l net.Listener) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return s.closeListenerLocked(l)
+}
+
+func (s *Server) closeListenerLocked(l net.Listener) error {
+	defer delete(s.listeners, l)
+	return l.Close()
+}
+
+func (s *Server) closeListeners() error {
+	var err error
+	for l := range s.listeners {
+		if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
+			err = cerr
+		}
+	}
+
+	return err
+}
+
+func (s *Server) addConnection(c *serverConn) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.connections[c] = struct{}{}
+}
+
+func (s *Server) closeIdleConns() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	quiescent := true
+	for c := range s.connections {
+		st, ok := c.getState()
+		if !ok || st != connStateIdle {
+			quiescent = false
+			continue
+		}
+		c.close()
+		delete(s.connections, c)
+	}
+	return quiescent
+}
+
+type connState int
+
+const (
+	connStateActive = iota + 1 // outstanding requests
+	connStateIdle              // no requests
+	connStateClosed            // closed connection
+)
+
+func (cs connState) String() string {
+	switch cs {
+	case connStateActive:
+		return "active"
+	case connStateIdle:
+		return "idle"
+	case connStateClosed:
+		return "closed"
+	default:
+		return "unknown"
+	}
+}
+
+func (s *Server) newConn(conn net.Conn) *serverConn {
+	c := &serverConn{
+		server:   s,
+		conn:     conn,
+		shutdown: make(chan struct{}),
+	}
+
+	c.setState(connStateIdle)
+	s.addConnection(c)
+	return c
+}
+
+type serverConn struct {
+	server *Server
+	conn   net.Conn
+	state  atomic.Value
+
+	shutdownOnce sync.Once
+	shutdown     chan struct{} // forced shutdown, used by close
+}
+
+func (c *serverConn) getState() (connState, bool) {
+	cs, ok := c.state.Load().(connState)
+	return cs, ok
+}
+
+func (c *serverConn) setState(newstate connState) {
+	c.state.Store(newstate)
+}
+
+func (c *serverConn) close() error {
+	c.shutdownOnce.Do(func() {
+		close(c.shutdown)
	})
+
+	return nil
+}
+
-func (s *Server) handleConn(conn net.Conn) {
-	defer conn.Close()
-
+func (c *serverConn) run(sctx context.Context) {
 	type (
 		request struct {
 			id  uint32
@@ -59,25 +256,66 @@ func (s *Server) handleConn(conn net.Conn) {
 	)
 
 	var (
-		ch          = newChannel(conn, conn)
-		ctx, cancel = context.WithCancel(context.Background())
-		responses   = make(chan response)
-		requests    = make(chan request)
-		recvErr     = make(chan error, 1)
-		done        = make(chan struct{})
+		ch          = newChannel(c.conn, c.conn)
+		ctx, cancel = context.WithCancel(sctx)
+		active      int
+		state       connState = connStateIdle
+		responses   = make(chan response)
+		requests    = make(chan request)
+		recvErr     = make(chan error, 1)
+		shutdown    = c.shutdown
+		done        = make(chan struct{})
 	)
 
+	defer c.conn.Close()
 	defer cancel()
 	defer close(done)
 
-	go func() {
+	go func(recvErr chan error) {
 		defer close(recvErr)
-		var p [messageLengthMax]byte
+		sendImmediate := func(id uint32, st *status.Status) bool {
+			select {
+			case responses <- response{
+				// even though we've had an invalid stream id, we send it
+				// back on the same stream id so the client knows which
+				// stream id was bad.
+				id: id,
+				resp: &Response{
+					Status: st.Proto(),
+				},
+			}:
+				return true
+			case <-c.shutdown:
+				return false
+			case <-done:
+				return false
+			}
+		}
+
 		for {
-			mh, err := ch.recv(ctx, p[:])
-			if err != nil {
-				recvErr <- err
+			select {
+			case <-c.shutdown:
 				return
+			case <-done:
+				return
+			default: // proceed
+			}
+
+			mh, p, err := ch.recv(ctx)
+			if err != nil {
+				status, ok := status.FromError(err)
+				if !ok {
+					recvErr <- err
+					return
+				}
+
+				// in this case, we send an error for that particular message
+				// when the status is defined.
+				if !sendImmediate(mh.StreamID, status) {
+					return
+				}
+
+				continue
 			}
 
 			if mh.Type != messageTypeRequest {
@@ -86,44 +324,57 @@ func (s *Server) handleConn(conn net.Conn) {
 			}
 
 			var req Request
-			if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
-				recvErr <- err
-				return
+			if err := c.server.codec.Unmarshal(p, &req); err != nil {
+				ch.putmbuf(p)
+				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
+					return
+				}
+				continue
 			}
+			ch.putmbuf(p)
 
 			if mh.StreamID%2 != 1 {
 				// enforce odd client initiated identifiers.
-				select {
-				case responses <- response{
-					// even though we've had an invalid stream id, we send it
-					// back on the same stream id so the client knows which
-					// stream id was bad.
-					id: mh.StreamID,
-					resp: &Response{
-						Status: status.New(codes.InvalidArgument, "StreamID must be odd for client initiated streams").Proto(),
-					},
-				}:
-				case <-done:
+				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
+					return
 				}
 
 				continue
 			}
 
+			// Forward the request to the main loop. We don't wait on s.done
+			// because we have already accepted the client request.
 			select {
 			case requests <- request{
 				id:  mh.StreamID,
 				req: &req,
 			}:
 			case <-done:
+				return
 			}
 		}
-	}()
+	}(recvErr)
 
 	for {
+		newstate := state
+		switch {
+		case active > 0:
+			newstate = connStateActive
+			shutdown = nil
+		case active == 0:
+			newstate = connStateIdle
+			shutdown = c.shutdown // only enable this branch in idle mode
+		}
+
+		if newstate != state {
+			c.setState(newstate)
+			state = newstate
+		}
+
 		select {
 		case request := <-requests:
+			active++
 			go func(id uint32) {
-				p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
+				p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
 				resp := &Response{
 					Status:  status.Proto(),
 					Payload: p,
@@ -138,17 +389,27 @@ func (s *Server) handleConn(conn net.Conn) {
 				}
 			}(request.id)
 		case response := <-responses:
-			p, err := s.codec.Marshal(response.resp)
+			p, err := c.server.codec.Marshal(response.resp)
 			if err != nil {
 				log.L.WithError(err).Error("failed marshaling response")
 				return
 			}
 
 			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
 				log.L.WithError(err).Error("failed sending message on channel")
 				return
 			}
+
+			active--
 		case err := <-recvErr:
-			log.L.WithError(err).Error("error receiving message")
+			// TODO(stevvooe): Not wildly clear what we should do in this
+			// branch. Basically, it means that we are no longer receiving
+			// requests due to a terminal error.
+			recvErr = nil // connection is now "closing"
+			if err != nil {
+				log.L.WithError(err).Error("error receiving message")
+			}
+		case <-shutdown:
 			return
 		}
 	}
 }
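The server side gains listener and connection tracking, per-connection idle/active state, accept backoff with jitter, ErrServerClosed, and a graceful Shutdown alongside an immediate Close. The sketch below wires those pieces into a typical lifecycle; the socket path is illustrative, and service registration via Register is left out because the Method handler signature is not part of this diff.

```go
package main

import (
	"context"
	"log"
	"net"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/stevvooe/ttrpc"
)

func main() {
	l, err := net.Listen("unix", "/run/example/ttrpc.sock")
	if err != nil {
		log.Fatal(err)
	}

	s := ttrpc.NewServer()
	// s.Register("example.v1.Pinger", methods) would go here.

	errs := make(chan error, 1)
	go func() {
		// Serve retries temporary accept errors with jittered backoff and
		// returns ttrpc.ErrServerClosed once Shutdown or Close is called.
		errs <- s.Serve(l)
	}()

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

	select {
	case err := <-errs:
		log.Fatal(err)
	case <-sigs:
		// Shutdown closes the listeners, then polls every 200ms until all
		// tracked connections report idle; bound it and fall back to Close.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := s.Shutdown(ctx); err != nil {
			s.Close()
		}
	}
}
```

Close tears down the listeners and every tracked connection immediately, which is why it is only the fallback once the Shutdown deadline expires.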