ttrpc: handle concurrent requests and responses
With this changeset, ttrpc can now handle mutliple outstanding requests and responses on the same connection without blocking. On the server-side, we dispatch a goroutine per outstanding reequest. On the client side, a management goroutine dispatches responses to blocked waiters. The protocol has been changed to support this behavior by including a "stream id" that can used to identify which request a response belongs to on the client-side of the connection. With these changes, we should also be able to support streams in the future. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
183
client.go
183
client.go
@@ -3,56 +3,191 @@ package ttrpc
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
channel *channel
|
||||
codec codec
|
||||
channel *channel
|
||||
requestID uint32
|
||||
sendRequests chan sendRequest
|
||||
recvRequests chan recvRequest
|
||||
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func NewClient(conn net.Conn) *Client {
|
||||
return &Client{
|
||||
channel: newChannel(conn),
|
||||
c := &Client{
|
||||
codec: codec{},
|
||||
channel: newChannel(conn, conn),
|
||||
sendRequests: make(chan sendRequest),
|
||||
recvRequests: make(chan recvRequest),
|
||||
closed: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go c.run()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
||||
var payload []byte
|
||||
switch v := req.(type) {
|
||||
case proto.Message:
|
||||
var err error
|
||||
payload, err = proto.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("ttrpc: unknown request type: %T", req)
|
||||
payload, err := c.codec.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestID := atomic.AddUint32(&c.requestID, 1)
|
||||
request := Request{
|
||||
Service: service,
|
||||
Method: method,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
if err := c.channel.send(ctx, &request); err != nil {
|
||||
if err := c.send(ctx, requestID, &request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := c.channel.recv(ctx, &response); err != nil {
|
||||
if err := c.recv(ctx, requestID, &response); err != nil {
|
||||
return err
|
||||
}
|
||||
switch v := resp.(type) {
|
||||
case proto.Message:
|
||||
if err := proto.Unmarshal(response.Payload, v); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("ttrpc: unknown response type: %T", resp)
|
||||
}
|
||||
|
||||
return c.codec.Unmarshal(response.Payload, resp)
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.closed)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type sendRequest struct {
|
||||
ctx context.Context
|
||||
id uint32
|
||||
msg interface{}
|
||||
err chan error
|
||||
}
|
||||
|
||||
func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
|
||||
errs := make(chan error, 1)
|
||||
select {
|
||||
case c.sendRequests <- sendRequest{
|
||||
ctx: ctx,
|
||||
id: id,
|
||||
msg: msg,
|
||||
err: errs,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
}
|
||||
|
||||
type recvRequest struct {
|
||||
id uint32
|
||||
msg interface{}
|
||||
err chan error
|
||||
}
|
||||
|
||||
func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
|
||||
errs := make(chan error, 1)
|
||||
select {
|
||||
case c.recvRequests <- recvRequest{
|
||||
id: id,
|
||||
msg: msg,
|
||||
err: errs,
|
||||
}:
|
||||
case <-c.done:
|
||||
return c.err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
case <-c.done:
|
||||
return c.err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
type received struct {
|
||||
mh messageHeader
|
||||
p []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *Client) run() {
|
||||
defer close(c.done)
|
||||
var (
|
||||
waiters = map[uint32]recvRequest{}
|
||||
queued = map[uint32]received{} // messages unmatched by waiter
|
||||
incoming = make(chan received)
|
||||
)
|
||||
|
||||
go func() {
|
||||
// start one more goroutine to recv messages without blocking.
|
||||
for {
|
||||
var p [messageLengthMax]byte
|
||||
mh, err := c.channel.recv(context.TODO(), p[:])
|
||||
select {
|
||||
case incoming <- received{
|
||||
mh: mh,
|
||||
p: p[:mh.Length],
|
||||
err: err,
|
||||
}:
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-c.sendRequests:
|
||||
if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
|
||||
req.err <- err
|
||||
} else {
|
||||
req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
|
||||
}
|
||||
case req := <-c.recvRequests:
|
||||
if r, ok := queued[req.id]; ok {
|
||||
req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
|
||||
}
|
||||
waiters[req.id] = req
|
||||
case r := <-incoming:
|
||||
if r.err != nil {
|
||||
c.err = r.err
|
||||
return
|
||||
}
|
||||
|
||||
if waiter, ok := waiters[r.mh.StreamID]; ok {
|
||||
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
||||
} else {
|
||||
queued[r.mh.StreamID] = r
|
||||
}
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user