ttrpc: handle concurrent requests and responses
With this changeset, ttrpc can now handle mutliple outstanding requests and responses on the same connection without blocking. On the server-side, we dispatch a goroutine per outstanding reequest. On the client side, a management goroutine dispatches responses to blocked waiters. The protocol has been changed to support this behavior by including a "stream id" that can used to identify which request a response belongs to on the client-side of the connection. With these changes, we should also be able to support streams in the future. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
parent
2a81659f49
commit
7f752bf263
156
channel.go
156
channel.go
@ -5,101 +5,95 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/containerd/containerd/log"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const maxMessageSize = 8 << 10 // TODO(stevvooe): Cut these down, since they are pre-alloced.
|
||||
const (
|
||||
messageHeaderLength = 10
|
||||
messageLengthMax = 8 << 10
|
||||
)
|
||||
|
||||
type channel struct {
|
||||
conn net.Conn
|
||||
bw *bufio.Writer
|
||||
br *bufio.Reader
|
||||
type messageType uint8
|
||||
|
||||
const (
|
||||
messageTypeRequest messageType = 0x1
|
||||
messageTypeResponse messageType = 0x2
|
||||
)
|
||||
|
||||
// messageHeader represents the fixed-length message header of 10 bytes sent
|
||||
// with every request.
|
||||
type messageHeader struct {
|
||||
Length uint32 // length excluding this header. b[:4]
|
||||
StreamID uint32 // identifies which request stream message is a part of. b[4:8]
|
||||
Type messageType // message type b[8]
|
||||
Flags uint8 // reserved b[9]
|
||||
}
|
||||
|
||||
func newChannel(conn net.Conn) *channel {
|
||||
func readMessageHeader(p []byte, r io.Reader) (messageHeader, error) {
|
||||
_, err := io.ReadFull(r, p[:messageHeaderLength])
|
||||
if err != nil {
|
||||
return messageHeader{}, err
|
||||
}
|
||||
|
||||
return messageHeader{
|
||||
Length: binary.BigEndian.Uint32(p[:4]),
|
||||
StreamID: binary.BigEndian.Uint32(p[4:8]),
|
||||
Type: messageType(p[8]),
|
||||
Flags: p[9],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
|
||||
binary.BigEndian.PutUint32(p[:4], mh.Length)
|
||||
binary.BigEndian.PutUint32(p[4:8], mh.StreamID)
|
||||
p[8] = byte(mh.Type)
|
||||
p[9] = mh.Flags
|
||||
|
||||
_, err := w.Write(p[:])
|
||||
return err
|
||||
}
|
||||
|
||||
type channel struct {
|
||||
bw *bufio.Writer
|
||||
br *bufio.Reader
|
||||
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
|
||||
hwbuf [messageHeaderLength]byte
|
||||
}
|
||||
|
||||
func newChannel(w io.Writer, r io.Reader) *channel {
|
||||
return &channel{
|
||||
conn: conn,
|
||||
bw: bufio.NewWriterSize(conn, maxMessageSize),
|
||||
br: bufio.NewReaderSize(conn, maxMessageSize),
|
||||
bw: bufio.NewWriter(w),
|
||||
br: bufio.NewReader(r),
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *channel) recv(ctx context.Context, msg interface{}) error {
|
||||
defer log.G(ctx).WithField("msg", msg).Info("recv")
|
||||
func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) {
|
||||
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
|
||||
if err != nil {
|
||||
return messageHeader{}, err
|
||||
}
|
||||
|
||||
// TODO(stevvooe): Use `bufio.Reader.Peek` here to remove this allocation.
|
||||
var p [maxMessageSize]byte
|
||||
n, err := readmsg(ch.br, p[:])
|
||||
if mh.Length > uint32(len(p)) {
|
||||
return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p))
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil {
|
||||
return messageHeader{}, errors.Wrapf(err, "failed reading message")
|
||||
}
|
||||
|
||||
return mh, nil
|
||||
}
|
||||
|
||||
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
|
||||
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := ch.bw.Write(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case proto.Message:
|
||||
return proto.Unmarshal(p[:n], msg)
|
||||
default:
|
||||
return errors.Errorf("unnsupported type in channel: %#v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *channel) send(ctx context.Context, msg interface{}) error {
|
||||
log.G(ctx).WithField("msg", msg).Info("send")
|
||||
var p []byte
|
||||
switch msg := msg.(type) {
|
||||
case proto.Message:
|
||||
var err error
|
||||
// TODO(stevvooe): trickiest allocation of the bunch. This will be hard
|
||||
// to get rid of without using `MarshalTo` directly.
|
||||
p, err = proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("unsupported type recv from channel: %#v", msg)
|
||||
}
|
||||
|
||||
return writemsg(ch.bw, p)
|
||||
}
|
||||
|
||||
func readmsg(r *bufio.Reader, p []byte) (int, error) {
|
||||
mlen, err := binary.ReadVarint(r)
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(err, "failed reading message size")
|
||||
}
|
||||
|
||||
if mlen > int64(len(p)) {
|
||||
return 0, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mlen, len(p))
|
||||
}
|
||||
|
||||
nn, err := io.ReadFull(r, p[:mlen])
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(err, "failed reading message size")
|
||||
}
|
||||
|
||||
if int64(nn) != mlen {
|
||||
return 0, errors.Errorf("mismatched read against message length %v != %v", nn, mlen)
|
||||
}
|
||||
|
||||
return int(mlen), nil
|
||||
}
|
||||
|
||||
func writemsg(w *bufio.Writer, p []byte) error {
|
||||
var (
|
||||
mlenp [binary.MaxVarintLen64]byte
|
||||
n = binary.PutVarint(mlenp[:], int64(len(p)))
|
||||
)
|
||||
|
||||
if _, err := w.Write(mlenp[:n]); err != nil {
|
||||
return errors.Wrapf(err, "failed writing message header")
|
||||
}
|
||||
|
||||
if _, err := w.Write(p); err != nil {
|
||||
return errors.Wrapf(err, "failed writing message")
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
return ch.bw.Flush()
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package ttrpc
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
@ -12,8 +13,10 @@ import (
|
||||
|
||||
func TestReadWriteMessage(t *testing.T) {
|
||||
var (
|
||||
channel bytes.Buffer
|
||||
w = bufio.NewWriter(&channel)
|
||||
ctx = context.Background()
|
||||
buffer bytes.Buffer
|
||||
w = bufio.NewWriter(&buffer)
|
||||
ch = newChannel(w, nil)
|
||||
messages = [][]byte{
|
||||
[]byte("hello"),
|
||||
[]byte("this is a test"),
|
||||
@ -21,20 +24,21 @@ func TestReadWriteMessage(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
for _, msg := range messages {
|
||||
if err := writemsg(w, msg); err != nil {
|
||||
for i, msg := range messages {
|
||||
if err := ch.send(ctx, uint32(i), 1, msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
received [][]byte
|
||||
r = bufio.NewReader(bytes.NewReader(channel.Bytes()))
|
||||
r = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
|
||||
rch = newChannel(nil, r)
|
||||
)
|
||||
|
||||
for {
|
||||
var p [4096]byte
|
||||
n, err := readmsg(r, p[:])
|
||||
mh, err := rch.recv(ctx, p[:])
|
||||
if err != nil {
|
||||
if errors.Cause(err) != io.EOF {
|
||||
t.Fatal(err)
|
||||
@ -42,7 +46,7 @@ func TestReadWriteMessage(t *testing.T) {
|
||||
|
||||
break
|
||||
}
|
||||
received = append(received, p[:n])
|
||||
received = append(received, p[:mh.Length])
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(received, messages) {
|
||||
@ -52,21 +56,25 @@ func TestReadWriteMessage(t *testing.T) {
|
||||
|
||||
func TestSmallBuffer(t *testing.T) {
|
||||
var (
|
||||
channel bytes.Buffer
|
||||
w = bufio.NewWriter(&channel)
|
||||
msg = []byte("a message of massive length")
|
||||
ctx = context.Background()
|
||||
buffer bytes.Buffer
|
||||
w = bufio.NewWriter(&buffer)
|
||||
ch = newChannel(w, nil)
|
||||
msg = []byte("a message of massive length")
|
||||
)
|
||||
|
||||
if err := writemsg(w, msg); err != nil {
|
||||
if err := ch.send(ctx, 1, 1, msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// now, read it off the channel with a small buffer
|
||||
var (
|
||||
p = make([]byte, len(msg)-1)
|
||||
r = bufio.NewReader(bytes.NewReader(channel.Bytes()))
|
||||
p = make([]byte, len(msg)-1)
|
||||
r = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
|
||||
rch = newChannel(nil, r)
|
||||
)
|
||||
_, err := readmsg(r, p[:])
|
||||
|
||||
_, err := rch.recv(ctx, p[:])
|
||||
if err == nil {
|
||||
t.Fatalf("error expected reading with small buffer")
|
||||
}
|
||||
@ -75,41 +83,3 @@ func TestSmallBuffer(t *testing.T) {
|
||||
t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReadWrite(b *testing.B) {
|
||||
b.StopTimer()
|
||||
var (
|
||||
messages = [][]byte{
|
||||
[]byte("hello"),
|
||||
[]byte("this is a test"),
|
||||
[]byte("of message framing"),
|
||||
}
|
||||
total int64
|
||||
channel bytes.Buffer
|
||||
w = bufio.NewWriter(&channel)
|
||||
p [4096]byte
|
||||
)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
msg := messages[i%len(messages)]
|
||||
if err := writemsg(w, msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
total += int64(len(msg))
|
||||
}
|
||||
b.SetBytes(total)
|
||||
|
||||
r := bufio.NewReader(bytes.NewReader(channel.Bytes()))
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := readmsg(r, p[:])
|
||||
if err != nil {
|
||||
if errors.Cause(err) != io.EOF {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
183
client.go
183
client.go
@ -3,56 +3,191 @@ package ttrpc
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
channel *channel
|
||||
codec codec
|
||||
channel *channel
|
||||
requestID uint32
|
||||
sendRequests chan sendRequest
|
||||
recvRequests chan recvRequest
|
||||
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func NewClient(conn net.Conn) *Client {
|
||||
return &Client{
|
||||
channel: newChannel(conn),
|
||||
c := &Client{
|
||||
codec: codec{},
|
||||
channel: newChannel(conn, conn),
|
||||
sendRequests: make(chan sendRequest),
|
||||
recvRequests: make(chan recvRequest),
|
||||
closed: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go c.run()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
||||
var payload []byte
|
||||
switch v := req.(type) {
|
||||
case proto.Message:
|
||||
var err error
|
||||
payload, err = proto.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("ttrpc: unknown request type: %T", req)
|
||||
payload, err := c.codec.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestID := atomic.AddUint32(&c.requestID, 1)
|
||||
request := Request{
|
||||
Service: service,
|
||||
Method: method,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
if err := c.channel.send(ctx, &request); err != nil {
|
||||
if err := c.send(ctx, requestID, &request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := c.channel.recv(ctx, &response); err != nil {
|
||||
if err := c.recv(ctx, requestID, &response); err != nil {
|
||||
return err
|
||||
}
|
||||
switch v := resp.(type) {
|
||||
case proto.Message:
|
||||
if err := proto.Unmarshal(response.Payload, v); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("ttrpc: unknown response type: %T", resp)
|
||||
}
|
||||
|
||||
return c.codec.Unmarshal(response.Payload, resp)
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.closed)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type sendRequest struct {
|
||||
ctx context.Context
|
||||
id uint32
|
||||
msg interface{}
|
||||
err chan error
|
||||
}
|
||||
|
||||
func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
|
||||
errs := make(chan error, 1)
|
||||
select {
|
||||
case c.sendRequests <- sendRequest{
|
||||
ctx: ctx,
|
||||
id: id,
|
||||
msg: msg,
|
||||
err: errs,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
}
|
||||
|
||||
type recvRequest struct {
|
||||
id uint32
|
||||
msg interface{}
|
||||
err chan error
|
||||
}
|
||||
|
||||
func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
|
||||
errs := make(chan error, 1)
|
||||
select {
|
||||
case c.recvRequests <- recvRequest{
|
||||
id: id,
|
||||
msg: msg,
|
||||
err: errs,
|
||||
}:
|
||||
case <-c.done:
|
||||
return c.err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
case <-c.done:
|
||||
return c.err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
type received struct {
|
||||
mh messageHeader
|
||||
p []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *Client) run() {
|
||||
defer close(c.done)
|
||||
var (
|
||||
waiters = map[uint32]recvRequest{}
|
||||
queued = map[uint32]received{} // messages unmatched by waiter
|
||||
incoming = make(chan received)
|
||||
)
|
||||
|
||||
go func() {
|
||||
// start one more goroutine to recv messages without blocking.
|
||||
for {
|
||||
var p [messageLengthMax]byte
|
||||
mh, err := c.channel.recv(context.TODO(), p[:])
|
||||
select {
|
||||
case incoming <- received{
|
||||
mh: mh,
|
||||
p: p[:mh.Length],
|
||||
err: err,
|
||||
}:
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-c.sendRequests:
|
||||
if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
|
||||
req.err <- err
|
||||
} else {
|
||||
req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
|
||||
}
|
||||
case req := <-c.recvRequests:
|
||||
if r, ok := queued[req.id]; ok {
|
||||
req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
|
||||
}
|
||||
waiters[req.id] = req
|
||||
case r := <-incoming:
|
||||
if r.err != nil {
|
||||
c.err = r.err
|
||||
return
|
||||
}
|
||||
|
||||
if waiter, ok := waiters[r.mh.StreamID]; ok {
|
||||
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
||||
} else {
|
||||
queued[r.mh.StreamID] = r
|
||||
}
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
97
server.go
97
server.go
@ -9,6 +9,7 @@ import (
|
||||
|
||||
type Server struct {
|
||||
services *serviceSet
|
||||
codec codec
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
@ -43,35 +44,91 @@ func (s *Server) Serve(l net.Listener) error {
|
||||
func (s *Server) handleConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
type (
|
||||
request struct {
|
||||
id uint32
|
||||
req *Request
|
||||
}
|
||||
|
||||
response struct {
|
||||
id uint32
|
||||
resp *Response
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ch = newChannel(conn)
|
||||
req Request
|
||||
ch = newChannel(conn, conn)
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
responses = make(chan response)
|
||||
requests = make(chan request)
|
||||
recvErr = make(chan error, 1)
|
||||
done = make(chan struct{})
|
||||
)
|
||||
|
||||
defer cancel()
|
||||
defer close(done)
|
||||
|
||||
// TODO(stevvooe): Recover here or in dispatch to handle panics in service
|
||||
// methods.
|
||||
go func() {
|
||||
defer close(recvErr)
|
||||
var p [messageLengthMax]byte
|
||||
for {
|
||||
mh, err := ch.recv(ctx, p[:])
|
||||
if err != nil {
|
||||
recvErr <- err
|
||||
return
|
||||
}
|
||||
|
||||
if mh.Type != messageTypeRequest {
|
||||
// we must ignore this for future compat.
|
||||
continue
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
||||
recvErr <- err
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case requests <- request{
|
||||
id: mh.StreamID,
|
||||
req: &req,
|
||||
}:
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// every connection is just a simple in/out request loop. No complexity for
|
||||
// multiplexing streams or dealing with head of line blocking, as this
|
||||
// isn't necessary for shim control.
|
||||
for {
|
||||
if err := ch.recv(ctx, &req); err != nil {
|
||||
log.L.WithError(err).Error("failed receiving message on channel")
|
||||
return
|
||||
}
|
||||
select {
|
||||
case request := <-requests:
|
||||
go func(id uint32) {
|
||||
p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
||||
resp := &Response{
|
||||
Status: status.Proto(),
|
||||
Payload: p,
|
||||
}
|
||||
|
||||
p, status := s.services.call(ctx, req.Service, req.Method, req.Payload)
|
||||
|
||||
resp := &Response{
|
||||
Status: status.Proto(),
|
||||
Payload: p,
|
||||
}
|
||||
|
||||
if err := ch.send(ctx, resp); err != nil {
|
||||
log.L.WithError(err).Error("failed sending message on channel")
|
||||
select {
|
||||
case responses <- response{
|
||||
id: id,
|
||||
resp: resp,
|
||||
}:
|
||||
case <-done:
|
||||
}
|
||||
}(request.id)
|
||||
case response := <-responses:
|
||||
p, err := s.codec.Marshal(response.resp)
|
||||
if err != nil {
|
||||
log.L.WithError(err).Error("failed marshaling response")
|
||||
return
|
||||
}
|
||||
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
|
||||
log.L.WithError(err).Error("failed sending message on channel")
|
||||
return
|
||||
}
|
||||
case err := <-recvErr:
|
||||
log.L.WithError(err).Error("error receiving message")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/containerd/containerd/log"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -52,7 +51,6 @@ func (s *serviceSet) call(ctx context.Context, serviceName, methodName string, p
|
||||
}
|
||||
|
||||
func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName string, p []byte) ([]byte, error) {
|
||||
ctx = log.WithLogger(ctx, log.G(ctx).WithField("method", fullPath(serviceName, methodName)))
|
||||
method, err := s.resolve(serviceName, methodName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
Loading…
Reference in New Issue
Block a user