ttrpc: refactor client to better handle EOF

The request and response requests opened up a nasty race condition where
waiters could find themselves either blocked or receiving errant errors.
The result was low performance and inadvertent busy waits. This
refactors the client to have a single request into the main client loop,
eliminating the race.

The reason for the original design was to allow a sender to control
request and response individually to make unit testing easier. The unit
test has now been refactored to use a channel to ensure that requests
are serviced on graceful shutdown.

Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
Stephen J Day
2017-11-29 17:22:43 -08:00
parent 78323657aa
commit b774f8872e
3 changed files with 206 additions and 162 deletions

253
client.go
View File

@@ -4,19 +4,18 @@ import (
"context"
"net"
"sync"
"sync/atomic"
"github.com/containerd/containerd/log"
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
"google.golang.org/grpc/status"
)
type Client struct {
codec codec
channel *channel
requestID uint32
sendRequests chan sendRequest
recvRequests chan recvRequest
codec codec
conn net.Conn
channel *channel
calls chan *callRequest
closed chan struct{}
closeOnce sync.Once
@@ -26,58 +25,76 @@ type Client struct {
func NewClient(conn net.Conn) *Client {
c := &Client{
codec: codec{},
requestID: 1,
channel: newChannel(conn, conn),
sendRequests: make(chan sendRequest),
recvRequests: make(chan recvRequest),
closed: make(chan struct{}),
done: make(chan struct{}),
codec: codec{},
conn: conn,
channel: newChannel(conn, conn),
calls: make(chan *callRequest),
closed: make(chan struct{}),
done: make(chan struct{}),
}
go c.run()
return c
}
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
requestID := atomic.AddUint32(&c.requestID, 2)
if err := c.sendRequest(ctx, requestID, service, method, req); err != nil {
return err
}
return c.recvResponse(ctx, requestID, resp)
type callRequest struct {
ctx context.Context
req *Request
resp *Response // response will be written back here
errs chan error // error written here on completion
}
func (c *Client) sendRequest(ctx context.Context, requestID uint32, service, method string, req interface{}) error {
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
payload, err := c.codec.Marshal(req)
if err != nil {
return err
}
request := Request{
Service: service,
Method: method,
Payload: payload,
}
var (
creq = &Request{
Service: service,
Method: method,
Payload: payload,
}
return c.send(ctx, requestID, &request)
}
cresp = &Response{}
)
func (c *Client) recvResponse(ctx context.Context, requestID uint32, resp interface{}) error {
var response Response
if err := c.recv(ctx, requestID, &response); err != nil {
if err := c.dispatch(ctx, creq, cresp); err != nil {
return err
}
if err := c.codec.Unmarshal(response.Payload, resp); err != nil {
if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
return err
}
if response.Status == nil {
if cresp.Status == nil {
return errors.New("no status provided on response")
}
return status.ErrorProto(response.Status)
return status.ErrorProto(cresp.Status)
}
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
errs := make(chan error, 1)
call := &callRequest{
req: req,
resp: resp,
errs: errs,
}
select {
case c.calls <- call:
case <-c.done:
return c.err
}
select {
case err := <-errs:
return err
case <-c.done:
return c.err
}
}
func (c *Client) Close() error {
@@ -88,95 +105,42 @@ func (c *Client) Close() error {
return nil
}
type sendRequest struct {
ctx context.Context
id uint32
msg interface{}
err chan error
}
func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
errs := make(chan error, 1)
select {
case c.sendRequests <- sendRequest{
ctx: ctx,
id: id,
msg: msg,
err: errs,
}:
case <-ctx.Done():
return ctx.Err()
case <-c.done:
return c.err
}
select {
case err := <-errs:
return err
case <-ctx.Done():
return ctx.Err()
case <-c.done:
return c.err
}
}
type recvRequest struct {
id uint32
msg interface{}
err chan error
}
func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
errs := make(chan error, 1)
select {
case c.recvRequests <- recvRequest{
id: id,
msg: msg,
err: errs,
}:
case <-c.done:
return c.err
case <-ctx.Done():
return ctx.Err()
}
select {
case err := <-errs:
return err
case <-c.done:
return c.err
case <-ctx.Done():
return ctx.Err()
}
}
type received struct {
mh messageHeader
type message struct {
messageHeader
p []byte
err error
}
func (c *Client) run() {
defer close(c.done)
var (
waiters = map[uint32]recvRequest{}
queued = map[uint32]received{} // messages unmatched by waiter
incoming = make(chan received)
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
shutdown = make(chan struct{})
shutdownErr error
)
go func() {
defer close(shutdown)
// start one more goroutine to recv messages without blocking.
for {
// TODO(stevvooe): Something still isn't quite right with error
// handling on the client-side, causing EOFs to come through. We
// need other fixes in this changeset, so we'll address this
// correctly later.
mh, p, err := c.channel.recv(context.TODO())
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
shutdownErr = err
return
}
}
select {
case incoming <- received{
mh: mh,
p: p[:mh.Length],
err: err,
case incoming <- &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}:
case <-c.done:
return
@@ -184,32 +148,63 @@ func (c *Client) run() {
}
}()
defer c.conn.Close()
defer close(c.done)
for {
select {
case req := <-c.sendRequests:
if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
req.err <- err
} else {
req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
case call := <-calls:
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err
continue
}
case req := <-c.recvRequests:
if r, ok := queued[req.id]; ok {
req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids
case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
continue
}
waiters[req.id] = req
case r := <-incoming:
if waiter, ok := waiters[r.mh.StreamID]; ok {
if r.err != nil {
waiter.err <- r.err
} else {
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
c.channel.putmbuf(r.p)
}
} else {
queued[r.mh.StreamID] = r
call.errs <- c.recv(call.resp, msg)
delete(waiters, msg.StreamID)
case <-shutdown:
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
c.err = shutdownErr
for _, waiter := range waiters {
waiter.errs <- shutdownErr
}
c.Close()
return
case <-c.closed:
// broadcast the shutdown error to the remaining waiters.
for _, waiter := range waiters {
waiter.errs <- shutdownErr
}
return
}
}
}
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
p, err := c.codec.Marshal(msg)
if err != nil {
return err
}
return c.channel.send(ctx, streamID, mtype, p)
}
func (c *Client) recv(resp *Response, msg *message) error {
if msg.err != nil {
return msg.err
}
if msg.Type != messageTypeResponse {
return errors.New("unkown message type received")
}
return proto.Unmarshal(msg.p, resp)
}