vendor: update ttrpc with latest changes
Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
parent
59bd196711
commit
393cf8e8fc
@ -41,4 +41,4 @@ github.com/boltdb/bolt e9cf4fae01b5a8ff89d0ec6b32f0d9c9f79aefdd
|
|||||||
google.golang.org/genproto d80a6e20e776b0b17a324d0ba1ab50a39c8e8944
|
google.golang.org/genproto d80a6e20e776b0b17a324d0ba1ab50a39c8e8944
|
||||||
golang.org/x/text 19e51611da83d6be54ddafce4a4af510cb3e9ea4
|
golang.org/x/text 19e51611da83d6be54ddafce4a4af510cb3e9ea4
|
||||||
github.com/dmcgowan/go-tar go1.10
|
github.com/dmcgowan/go-tar go1.10
|
||||||
github.com/stevvooe/ttrpc bdb2ab7a8169e485e39421e666e15a505e575fd2
|
github.com/stevvooe/ttrpc 8c92e22ce0c492875ccaac3ab06143a77d8ed0c1
|
||||||
|
2
vendor/github.com/stevvooe/ttrpc/README.md
generated
vendored
2
vendor/github.com/stevvooe/ttrpc/README.md
generated
vendored
@ -1,5 +1,7 @@
|
|||||||
# ttrpc
|
# ttrpc
|
||||||
|
|
||||||
|
[](https://travis-ci.org/stevvooe/ttrpc)
|
||||||
|
|
||||||
GRPC for low-memory environments.
|
GRPC for low-memory environments.
|
||||||
|
|
||||||
The existing grpc-go project requires a lot of memory overhead for importing
|
The existing grpc-go project requires a lot of memory overhead for importing
|
||||||
|
52
vendor/github.com/stevvooe/ttrpc/channel.go
generated
vendored
52
vendor/github.com/stevvooe/ttrpc/channel.go
generated
vendored
@ -5,13 +5,16 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
messageHeaderLength = 10
|
messageHeaderLength = 10
|
||||||
messageLengthMax = 8 << 10
|
messageLengthMax = 4 << 20
|
||||||
)
|
)
|
||||||
|
|
||||||
type messageType uint8
|
type messageType uint8
|
||||||
@ -54,6 +57,8 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var buffers sync.Pool
|
||||||
|
|
||||||
type channel struct {
|
type channel struct {
|
||||||
bw *bufio.Writer
|
bw *bufio.Writer
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
@ -68,21 +73,32 @@ func newChannel(w io.Writer, r io.Reader) *channel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) {
|
// recv a message from the channel. The returned buffer contains the message.
|
||||||
|
//
|
||||||
|
// If a valid grpc status is returned, the message header
|
||||||
|
// returned will be valid and caller should send that along to
|
||||||
|
// the correct consumer. The bytes on the underlying channel
|
||||||
|
// will be discarded.
|
||||||
|
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
|
||||||
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
|
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return messageHeader{}, err
|
return messageHeader{}, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if mh.Length > uint32(len(p)) {
|
if mh.Length > uint32(messageLengthMax) {
|
||||||
return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p))
|
if _, err := ch.br.Discard(int(mh.Length)); err != nil {
|
||||||
|
return mh, nil, errors.Wrapf(err, "failed to discard after receiving oversized message")
|
||||||
|
}
|
||||||
|
|
||||||
|
return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil {
|
p := ch.getmbuf(int(mh.Length))
|
||||||
return messageHeader{}, errors.Wrapf(err, "failed reading message")
|
if _, err := io.ReadFull(ch.br, p); err != nil {
|
||||||
|
return messageHeader{}, nil, errors.Wrapf(err, "failed reading message")
|
||||||
}
|
}
|
||||||
|
|
||||||
return mh, nil
|
return mh, p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
|
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
|
||||||
@ -97,3 +113,23 @@ func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p [
|
|||||||
|
|
||||||
return ch.bw.Flush()
|
return ch.bw.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ch *channel) getmbuf(size int) []byte {
|
||||||
|
// we can't use the standard New method on pool because we want to allocate
|
||||||
|
// based on size.
|
||||||
|
b, ok := buffers.Get().(*[]byte)
|
||||||
|
if !ok || cap(*b) < size {
|
||||||
|
// TODO(stevvooe): It may be better to allocate these in fixed length
|
||||||
|
// buckets to reduce fragmentation but its not clear that would help
|
||||||
|
// with performance. An ilogb approach or similar would work well.
|
||||||
|
bb := make([]byte, size)
|
||||||
|
b = &bb
|
||||||
|
} else {
|
||||||
|
*b = (*b)[:size]
|
||||||
|
}
|
||||||
|
return *b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *channel) putmbuf(p []byte) {
|
||||||
|
buffers.Put(&p)
|
||||||
|
}
|
||||||
|
249
vendor/github.com/stevvooe/ttrpc/client.go
generated
vendored
249
vendor/github.com/stevvooe/ttrpc/client.go
generated
vendored
@ -4,19 +4,18 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
|
"github.com/containerd/containerd/log"
|
||||||
"github.com/gogo/protobuf/proto"
|
"github.com/gogo/protobuf/proto"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
codec codec
|
codec codec
|
||||||
channel *channel
|
conn net.Conn
|
||||||
requestID uint32
|
channel *channel
|
||||||
sendRequests chan sendRequest
|
calls chan *callRequest
|
||||||
recvRequests chan recvRequest
|
|
||||||
|
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
@ -26,50 +25,76 @@ type Client struct {
|
|||||||
|
|
||||||
func NewClient(conn net.Conn) *Client {
|
func NewClient(conn net.Conn) *Client {
|
||||||
c := &Client{
|
c := &Client{
|
||||||
codec: codec{},
|
codec: codec{},
|
||||||
requestID: 1,
|
conn: conn,
|
||||||
channel: newChannel(conn, conn),
|
channel: newChannel(conn, conn),
|
||||||
sendRequests: make(chan sendRequest),
|
calls: make(chan *callRequest),
|
||||||
recvRequests: make(chan recvRequest),
|
closed: make(chan struct{}),
|
||||||
closed: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.run()
|
go c.run()
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type callRequest struct {
|
||||||
|
ctx context.Context
|
||||||
|
req *Request
|
||||||
|
resp *Response // response will be written back here
|
||||||
|
errs chan error // error written here on completion
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
||||||
payload, err := c.codec.Marshal(req)
|
payload, err := c.codec.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestID := atomic.AddUint32(&c.requestID, 2)
|
var (
|
||||||
request := Request{
|
creq = &Request{
|
||||||
Service: service,
|
Service: service,
|
||||||
Method: method,
|
Method: method,
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.send(ctx, requestID, &request); err != nil {
|
cresp = &Response{}
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := c.dispatch(ctx, creq, cresp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var response Response
|
if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
|
||||||
if err := c.recv(ctx, requestID, &response); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.codec.Unmarshal(response.Payload, resp); err != nil {
|
if cresp.Status == nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if response.Status == nil {
|
|
||||||
return errors.New("no status provided on response")
|
return errors.New("no status provided on response")
|
||||||
}
|
}
|
||||||
|
|
||||||
return status.ErrorProto(response.Status)
|
return status.ErrorProto(cresp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
|
||||||
|
errs := make(chan error, 1)
|
||||||
|
call := &callRequest{
|
||||||
|
req: req,
|
||||||
|
resp: resp,
|
||||||
|
errs: errs,
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.calls <- call:
|
||||||
|
case <-c.done:
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errs:
|
||||||
|
return err
|
||||||
|
case <-c.done:
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
@ -80,92 +105,42 @@ func (c *Client) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type sendRequest struct {
|
type message struct {
|
||||||
ctx context.Context
|
messageHeader
|
||||||
id uint32
|
|
||||||
msg interface{}
|
|
||||||
err chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
|
|
||||||
errs := make(chan error, 1)
|
|
||||||
select {
|
|
||||||
case c.sendRequests <- sendRequest{
|
|
||||||
ctx: ctx,
|
|
||||||
id: id,
|
|
||||||
msg: msg,
|
|
||||||
err: errs,
|
|
||||||
}:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-c.done:
|
|
||||||
return c.err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-errs:
|
|
||||||
return err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-c.done:
|
|
||||||
return c.err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type recvRequest struct {
|
|
||||||
id uint32
|
|
||||||
msg interface{}
|
|
||||||
err chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
|
|
||||||
errs := make(chan error, 1)
|
|
||||||
select {
|
|
||||||
case c.recvRequests <- recvRequest{
|
|
||||||
id: id,
|
|
||||||
msg: msg,
|
|
||||||
err: errs,
|
|
||||||
}:
|
|
||||||
case <-c.done:
|
|
||||||
return c.err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-errs:
|
|
||||||
return err
|
|
||||||
case <-c.done:
|
|
||||||
return c.err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type received struct {
|
|
||||||
mh messageHeader
|
|
||||||
p []byte
|
p []byte
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) run() {
|
func (c *Client) run() {
|
||||||
defer close(c.done)
|
|
||||||
var (
|
var (
|
||||||
waiters = map[uint32]recvRequest{}
|
streamID uint32 = 1
|
||||||
queued = map[uint32]received{} // messages unmatched by waiter
|
waiters = make(map[uint32]*callRequest)
|
||||||
incoming = make(chan received)
|
calls = c.calls
|
||||||
|
incoming = make(chan *message)
|
||||||
|
shutdown = make(chan struct{})
|
||||||
|
shutdownErr error
|
||||||
)
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(shutdown)
|
||||||
|
|
||||||
// start one more goroutine to recv messages without blocking.
|
// start one more goroutine to recv messages without blocking.
|
||||||
for {
|
for {
|
||||||
var p [messageLengthMax]byte
|
mh, p, err := c.channel.recv(context.TODO())
|
||||||
mh, err := c.channel.recv(context.TODO(), p[:])
|
if err != nil {
|
||||||
|
_, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
// treat all errors that are not an rpc status as terminal.
|
||||||
|
// all others poison the connection.
|
||||||
|
shutdownErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case incoming <- received{
|
case incoming <- &message{
|
||||||
mh: mh,
|
messageHeader: mh,
|
||||||
p: p[:mh.Length],
|
p: p[:mh.Length],
|
||||||
err: err,
|
err: err,
|
||||||
}:
|
}:
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
return
|
return
|
||||||
@ -173,32 +148,64 @@ func (c *Client) run() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
defer c.conn.Close()
|
||||||
|
defer close(c.done)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case req := <-c.sendRequests:
|
case call := <-calls:
|
||||||
if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
|
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
|
||||||
req.err <- err
|
call.errs <- err
|
||||||
} else {
|
continue
|
||||||
req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
|
|
||||||
}
|
|
||||||
case req := <-c.recvRequests:
|
|
||||||
if r, ok := queued[req.id]; ok {
|
|
||||||
req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
|
|
||||||
}
|
|
||||||
waiters[req.id] = req
|
|
||||||
case r := <-incoming:
|
|
||||||
if r.err != nil {
|
|
||||||
c.err = r.err
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if waiter, ok := waiters[r.mh.StreamID]; ok {
|
waiters[streamID] = call
|
||||||
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
streamID += 2 // enforce odd client initiated request ids
|
||||||
} else {
|
case msg := <-incoming:
|
||||||
queued[r.mh.StreamID] = r
|
call, ok := waiters[msg.StreamID]
|
||||||
|
if !ok {
|
||||||
|
log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
call.errs <- c.recv(call.resp, msg)
|
||||||
|
delete(waiters, msg.StreamID)
|
||||||
|
case <-shutdown:
|
||||||
|
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
|
||||||
|
c.err = shutdownErr
|
||||||
|
for _, waiter := range waiters {
|
||||||
|
waiter.errs <- shutdownErr
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
|
// broadcast the shutdown error to the remaining waiters.
|
||||||
|
for _, waiter := range waiters {
|
||||||
|
waiter.errs <- shutdownErr
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
|
||||||
|
p, err := c.codec.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.channel.send(ctx, streamID, mtype, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) recv(resp *Response, msg *message) error {
|
||||||
|
if msg.err != nil {
|
||||||
|
return msg.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Type != messageTypeResponse {
|
||||||
|
return errors.New("unkown message type received")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer c.channel.putmbuf(msg.p)
|
||||||
|
return proto.Unmarshal(msg.p, resp)
|
||||||
|
}
|
||||||
|
345
vendor/github.com/stevvooe/ttrpc/server.go
generated
vendored
345
vendor/github.com/stevvooe/ttrpc/server.go
generated
vendored
@ -2,21 +2,38 @@ package ttrpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/containerd/containerd/log"
|
"github.com/containerd/containerd/log"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrServerClosed = errors.New("ttrpc: server close")
|
||||||
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
services *serviceSet
|
services *serviceSet
|
||||||
codec codec
|
codec codec
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
listeners map[net.Listener]struct{}
|
||||||
|
connections map[*serverConn]struct{} // all connections to current state
|
||||||
|
done chan struct{} // marks point at which we stop serving requests
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
services: newServiceSet(),
|
services: newServiceSet(),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
listeners: make(map[net.Listener]struct{}),
|
||||||
|
connections: make(map[*serverConn]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,28 +41,208 @@ func (s *Server) Register(name string, methods map[string]Method) {
|
|||||||
s.services.register(name, methods)
|
s.services.register(name, methods)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Shutdown(ctx context.Context) error {
|
|
||||||
// TODO(stevvooe): Wait on connection shutdown.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Serve(l net.Listener) error {
|
func (s *Server) Serve(l net.Listener) error {
|
||||||
|
s.addListener(l)
|
||||||
|
defer s.closeListener(l)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
backoff time.Duration
|
||||||
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.L.WithError(err).Error("failed accept")
|
select {
|
||||||
continue
|
case <-s.done:
|
||||||
|
return ErrServerClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if terr, ok := err.(interface {
|
||||||
|
Temporary() bool
|
||||||
|
}); ok && terr.Temporary() {
|
||||||
|
if backoff == 0 {
|
||||||
|
backoff = time.Millisecond
|
||||||
|
} else {
|
||||||
|
backoff *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if max := time.Second; backoff > max {
|
||||||
|
backoff = max
|
||||||
|
}
|
||||||
|
|
||||||
|
sleep := time.Duration(rand.Int63n(int64(backoff)))
|
||||||
|
log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
|
||||||
|
time.Sleep(sleep)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go s.handleConn(conn)
|
backoff = 0
|
||||||
|
sc := s.newConn(conn)
|
||||||
|
go sc.run(ctx)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
lnerr := s.closeListeners()
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
default:
|
||||||
|
// protected by mutex
|
||||||
|
close(s.done)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(200 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
if s.closeIdleConns() {
|
||||||
|
return lnerr
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the server without waiting for active connections.
|
||||||
|
func (s *Server) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
default:
|
||||||
|
// protected by mutex
|
||||||
|
close(s.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.closeListeners()
|
||||||
|
for c := range s.connections {
|
||||||
|
c.close()
|
||||||
|
delete(s.connections, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) addListener(l net.Listener) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.listeners[l] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeListener(l net.Listener) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.closeListenerLocked(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeListenerLocked(l net.Listener) error {
|
||||||
|
defer delete(s.listeners, l)
|
||||||
|
return l.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeListeners() error {
|
||||||
|
var err error
|
||||||
|
for l := range s.listeners {
|
||||||
|
if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
|
||||||
|
err = cerr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) addConnection(c *serverConn) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.connections[c] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeIdleConns() bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
quiescent := true
|
||||||
|
for c := range s.connections {
|
||||||
|
st, ok := c.getState()
|
||||||
|
if !ok || st != connStateIdle {
|
||||||
|
quiescent = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.close()
|
||||||
|
delete(s.connections, c)
|
||||||
|
}
|
||||||
|
return quiescent
|
||||||
|
}
|
||||||
|
|
||||||
|
type connState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
connStateActive = iota + 1 // outstanding requests
|
||||||
|
connStateIdle // no requests
|
||||||
|
connStateClosed // closed connection
|
||||||
|
)
|
||||||
|
|
||||||
|
func (cs connState) String() string {
|
||||||
|
switch cs {
|
||||||
|
case connStateActive:
|
||||||
|
return "active"
|
||||||
|
case connStateIdle:
|
||||||
|
return "idle"
|
||||||
|
case connStateClosed:
|
||||||
|
return "closed"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) newConn(conn net.Conn) *serverConn {
|
||||||
|
c := &serverConn{
|
||||||
|
server: s,
|
||||||
|
conn: conn,
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
}
|
||||||
|
c.setState(connStateIdle)
|
||||||
|
s.addConnection(c)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverConn struct {
|
||||||
|
server *Server
|
||||||
|
conn net.Conn
|
||||||
|
state atomic.Value
|
||||||
|
|
||||||
|
shutdownOnce sync.Once
|
||||||
|
shutdown chan struct{} // forced shutdown, used by close
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) getState() (connState, bool) {
|
||||||
|
cs, ok := c.state.Load().(connState)
|
||||||
|
return cs, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) setState(newstate connState) {
|
||||||
|
c.state.Store(newstate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) close() error {
|
||||||
|
c.shutdownOnce.Do(func() {
|
||||||
|
close(c.shutdown)
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleConn(conn net.Conn) {
|
func (c *serverConn) run(sctx context.Context) {
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
request struct {
|
request struct {
|
||||||
id uint32
|
id uint32
|
||||||
@ -59,25 +256,66 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ch = newChannel(conn, conn)
|
ch = newChannel(c.conn, c.conn)
|
||||||
ctx, cancel = context.WithCancel(context.Background())
|
ctx, cancel = context.WithCancel(sctx)
|
||||||
responses = make(chan response)
|
active int
|
||||||
requests = make(chan request)
|
state connState = connStateIdle
|
||||||
recvErr = make(chan error, 1)
|
responses = make(chan response)
|
||||||
done = make(chan struct{})
|
requests = make(chan request)
|
||||||
|
recvErr = make(chan error, 1)
|
||||||
|
shutdown = c.shutdown
|
||||||
|
done = make(chan struct{})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
defer c.conn.Close()
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
|
||||||
go func() {
|
go func(recvErr chan error) {
|
||||||
defer close(recvErr)
|
defer close(recvErr)
|
||||||
var p [messageLengthMax]byte
|
sendImmediate := func(id uint32, st *status.Status) bool {
|
||||||
|
select {
|
||||||
|
case responses <- response{
|
||||||
|
// even though we've had an invalid stream id, we send it
|
||||||
|
// back on the same stream id so the client knows which
|
||||||
|
// stream id was bad.
|
||||||
|
id: id,
|
||||||
|
resp: &Response{
|
||||||
|
Status: st.Proto(),
|
||||||
|
},
|
||||||
|
}:
|
||||||
|
return true
|
||||||
|
case <-c.shutdown:
|
||||||
|
return false
|
||||||
|
case <-done:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
mh, err := ch.recv(ctx, p[:])
|
select {
|
||||||
if err != nil {
|
case <-c.shutdown:
|
||||||
recvErr <- err
|
|
||||||
return
|
return
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
default: // proceed
|
||||||
|
}
|
||||||
|
|
||||||
|
mh, p, err := ch.recv(ctx)
|
||||||
|
if err != nil {
|
||||||
|
status, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
recvErr <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// in this case, we send an error for that particular message
|
||||||
|
// when the status is defined.
|
||||||
|
if !sendImmediate(mh.StreamID, status) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if mh.Type != messageTypeRequest {
|
if mh.Type != messageTypeRequest {
|
||||||
@ -86,44 +324,57 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var req Request
|
var req Request
|
||||||
if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
if err := c.server.codec.Unmarshal(p, &req); err != nil {
|
||||||
recvErr <- err
|
ch.putmbuf(p)
|
||||||
return
|
if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
ch.putmbuf(p)
|
||||||
|
|
||||||
if mh.StreamID%2 != 1 {
|
if mh.StreamID%2 != 1 {
|
||||||
// enforce odd client initiated identifiers.
|
// enforce odd client initiated identifiers.
|
||||||
select {
|
if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
|
||||||
case responses <- response{
|
return
|
||||||
// even though we've had an invalid stream id, we send it
|
|
||||||
// back on the same stream id so the client knows which
|
|
||||||
// stream id was bad.
|
|
||||||
id: mh.StreamID,
|
|
||||||
resp: &Response{
|
|
||||||
Status: status.New(codes.InvalidArgument, "StreamID must be odd for client initiated streams").Proto(),
|
|
||||||
},
|
|
||||||
}:
|
|
||||||
case <-done:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forward the request to the main loop. We don't wait on s.done
|
||||||
|
// because we have already accepted the client request.
|
||||||
select {
|
select {
|
||||||
case requests <- request{
|
case requests <- request{
|
||||||
id: mh.StreamID,
|
id: mh.StreamID,
|
||||||
req: &req,
|
req: &req,
|
||||||
}:
|
}:
|
||||||
case <-done:
|
case <-done:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}(recvErr)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
newstate := state
|
||||||
|
switch {
|
||||||
|
case active > 0:
|
||||||
|
newstate = connStateActive
|
||||||
|
shutdown = nil
|
||||||
|
case active == 0:
|
||||||
|
newstate = connStateIdle
|
||||||
|
shutdown = c.shutdown // only enable this branch in idle mode
|
||||||
|
}
|
||||||
|
|
||||||
|
if newstate != state {
|
||||||
|
c.setState(newstate)
|
||||||
|
state = newstate
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case request := <-requests:
|
case request := <-requests:
|
||||||
|
active++
|
||||||
go func(id uint32) {
|
go func(id uint32) {
|
||||||
p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
||||||
resp := &Response{
|
resp := &Response{
|
||||||
Status: status.Proto(),
|
Status: status.Proto(),
|
||||||
Payload: p,
|
Payload: p,
|
||||||
@ -138,17 +389,27 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
}(request.id)
|
}(request.id)
|
||||||
case response := <-responses:
|
case response := <-responses:
|
||||||
p, err := s.codec.Marshal(response.resp)
|
p, err := c.server.codec.Marshal(response.resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.L.WithError(err).Error("failed marshaling response")
|
log.L.WithError(err).Error("failed marshaling response")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
|
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
|
||||||
log.L.WithError(err).Error("failed sending message on channel")
|
log.L.WithError(err).Error("failed sending message on channel")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
active--
|
||||||
case err := <-recvErr:
|
case err := <-recvErr:
|
||||||
log.L.WithError(err).Error("error receiving message")
|
// TODO(stevvooe): Not wildly clear what we should do in this
|
||||||
|
// branch. Basically, it means that we are no longer receiving
|
||||||
|
// requests due to a terminal error.
|
||||||
|
recvErr = nil // connection is now "closing"
|
||||||
|
if err != nil {
|
||||||
|
log.L.WithError(err).Error("error receiving message")
|
||||||
|
}
|
||||||
|
case <-shutdown:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user