ttrpc: handle concurrent requests and responses

With this changeset, ttrpc can now handle mutliple outstanding requests
and responses on the same connection without blocking. On the
server-side, we dispatch a goroutine per outstanding reequest. On the
client side, a management goroutine dispatches responses to blocked
waiters.

The protocol has been changed to support this behavior by including a
"stream id" that can used to identify which request a response belongs
to on the client-side of the connection. With these changes, we should
also be able to support streams in the future.

Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
Stephen J Day 2017-11-21 21:38:38 -08:00
parent 2a81659f49
commit 7f752bf263
No known key found for this signature in database
GPG Key ID: 67B3DED84EDC823F
5 changed files with 333 additions and 179 deletions

View File

@ -5,101 +5,95 @@ import (
"context"
"encoding/binary"
"io"
"net"
"github.com/containerd/containerd/log"
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
)
const maxMessageSize = 8 << 10 // TODO(stevvooe): Cut these down, since they are pre-alloced.
const (
messageHeaderLength = 10
messageLengthMax = 8 << 10
)
type channel struct {
conn net.Conn
bw *bufio.Writer
br *bufio.Reader
type messageType uint8
const (
messageTypeRequest messageType = 0x1
messageTypeResponse messageType = 0x2
)
// messageHeader represents the fixed-length message header of 10 bytes sent
// with every request.
type messageHeader struct {
Length uint32 // length excluding this header. b[:4]
StreamID uint32 // identifies which request stream message is a part of. b[4:8]
Type messageType // message type b[8]
Flags uint8 // reserved b[9]
}
func newChannel(conn net.Conn) *channel {
func readMessageHeader(p []byte, r io.Reader) (messageHeader, error) {
_, err := io.ReadFull(r, p[:messageHeaderLength])
if err != nil {
return messageHeader{}, err
}
return messageHeader{
Length: binary.BigEndian.Uint32(p[:4]),
StreamID: binary.BigEndian.Uint32(p[4:8]),
Type: messageType(p[8]),
Flags: p[9],
}, nil
}
func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
binary.BigEndian.PutUint32(p[:4], mh.Length)
binary.BigEndian.PutUint32(p[4:8], mh.StreamID)
p[8] = byte(mh.Type)
p[9] = mh.Flags
_, err := w.Write(p[:])
return err
}
type channel struct {
bw *bufio.Writer
br *bufio.Reader
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
hwbuf [messageHeaderLength]byte
}
func newChannel(w io.Writer, r io.Reader) *channel {
return &channel{
conn: conn,
bw: bufio.NewWriterSize(conn, maxMessageSize),
br: bufio.NewReaderSize(conn, maxMessageSize),
bw: bufio.NewWriter(w),
br: bufio.NewReader(r),
}
}
func (ch *channel) recv(ctx context.Context, msg interface{}) error {
defer log.G(ctx).WithField("msg", msg).Info("recv")
func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) {
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
if err != nil {
return messageHeader{}, err
}
// TODO(stevvooe): Use `bufio.Reader.Peek` here to remove this allocation.
var p [maxMessageSize]byte
n, err := readmsg(ch.br, p[:])
if mh.Length > uint32(len(p)) {
return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p))
}
if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil {
return messageHeader{}, errors.Wrapf(err, "failed reading message")
}
return mh, nil
}
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
return err
}
_, err := ch.bw.Write(p)
if err != nil {
return err
}
switch msg := msg.(type) {
case proto.Message:
return proto.Unmarshal(p[:n], msg)
default:
return errors.Errorf("unnsupported type in channel: %#v", msg)
}
}
func (ch *channel) send(ctx context.Context, msg interface{}) error {
log.G(ctx).WithField("msg", msg).Info("send")
var p []byte
switch msg := msg.(type) {
case proto.Message:
var err error
// TODO(stevvooe): trickiest allocation of the bunch. This will be hard
// to get rid of without using `MarshalTo` directly.
p, err = proto.Marshal(msg)
if err != nil {
return err
}
default:
return errors.Errorf("unsupported type recv from channel: %#v", msg)
}
return writemsg(ch.bw, p)
}
func readmsg(r *bufio.Reader, p []byte) (int, error) {
mlen, err := binary.ReadVarint(r)
if err != nil {
return 0, errors.Wrapf(err, "failed reading message size")
}
if mlen > int64(len(p)) {
return 0, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mlen, len(p))
}
nn, err := io.ReadFull(r, p[:mlen])
if err != nil {
return 0, errors.Wrapf(err, "failed reading message size")
}
if int64(nn) != mlen {
return 0, errors.Errorf("mismatched read against message length %v != %v", nn, mlen)
}
return int(mlen), nil
}
func writemsg(w *bufio.Writer, p []byte) error {
var (
mlenp [binary.MaxVarintLen64]byte
n = binary.PutVarint(mlenp[:], int64(len(p)))
)
if _, err := w.Write(mlenp[:n]); err != nil {
return errors.Wrapf(err, "failed writing message header")
}
if _, err := w.Write(p); err != nil {
return errors.Wrapf(err, "failed writing message")
}
return w.Flush()
return ch.bw.Flush()
}

View File

@ -3,6 +3,7 @@ package ttrpc
import (
"bufio"
"bytes"
"context"
"io"
"reflect"
"testing"
@ -12,8 +13,10 @@ import (
func TestReadWriteMessage(t *testing.T) {
var (
channel bytes.Buffer
w = bufio.NewWriter(&channel)
ctx = context.Background()
buffer bytes.Buffer
w = bufio.NewWriter(&buffer)
ch = newChannel(w, nil)
messages = [][]byte{
[]byte("hello"),
[]byte("this is a test"),
@ -21,20 +24,21 @@ func TestReadWriteMessage(t *testing.T) {
}
)
for _, msg := range messages {
if err := writemsg(w, msg); err != nil {
for i, msg := range messages {
if err := ch.send(ctx, uint32(i), 1, msg); err != nil {
t.Fatal(err)
}
}
var (
received [][]byte
r = bufio.NewReader(bytes.NewReader(channel.Bytes()))
r = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
rch = newChannel(nil, r)
)
for {
var p [4096]byte
n, err := readmsg(r, p[:])
mh, err := rch.recv(ctx, p[:])
if err != nil {
if errors.Cause(err) != io.EOF {
t.Fatal(err)
@ -42,7 +46,7 @@ func TestReadWriteMessage(t *testing.T) {
break
}
received = append(received, p[:n])
received = append(received, p[:mh.Length])
}
if !reflect.DeepEqual(received, messages) {
@ -52,21 +56,25 @@ func TestReadWriteMessage(t *testing.T) {
func TestSmallBuffer(t *testing.T) {
var (
channel bytes.Buffer
w = bufio.NewWriter(&channel)
msg = []byte("a message of massive length")
ctx = context.Background()
buffer bytes.Buffer
w = bufio.NewWriter(&buffer)
ch = newChannel(w, nil)
msg = []byte("a message of massive length")
)
if err := writemsg(w, msg); err != nil {
if err := ch.send(ctx, 1, 1, msg); err != nil {
t.Fatal(err)
}
// now, read it off the channel with a small buffer
var (
p = make([]byte, len(msg)-1)
r = bufio.NewReader(bytes.NewReader(channel.Bytes()))
p = make([]byte, len(msg)-1)
r = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
rch = newChannel(nil, r)
)
_, err := readmsg(r, p[:])
_, err := rch.recv(ctx, p[:])
if err == nil {
t.Fatalf("error expected reading with small buffer")
}
@ -75,41 +83,3 @@ func TestSmallBuffer(t *testing.T) {
t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer)
}
}
func BenchmarkReadWrite(b *testing.B) {
b.StopTimer()
var (
messages = [][]byte{
[]byte("hello"),
[]byte("this is a test"),
[]byte("of message framing"),
}
total int64
channel bytes.Buffer
w = bufio.NewWriter(&channel)
p [4096]byte
)
b.ReportAllocs()
b.StartTimer()
for i := 0; i < b.N; i++ {
msg := messages[i%len(messages)]
if err := writemsg(w, msg); err != nil {
b.Fatal(err)
}
total += int64(len(msg))
}
b.SetBytes(total)
r := bufio.NewReader(bytes.NewReader(channel.Bytes()))
for i := 0; i < b.N; i++ {
_, err := readmsg(r, p[:])
if err != nil {
if errors.Cause(err) != io.EOF {
b.Fatal(err)
}
break
}
}
}

183
client.go
View File

@ -3,56 +3,191 @@ package ttrpc
import (
"context"
"net"
"sync"
"sync/atomic"
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
)
type Client struct {
channel *channel
codec codec
channel *channel
requestID uint32
sendRequests chan sendRequest
recvRequests chan recvRequest
closed chan struct{}
closeOnce sync.Once
done chan struct{}
err error
}
func NewClient(conn net.Conn) *Client {
return &Client{
channel: newChannel(conn),
c := &Client{
codec: codec{},
channel: newChannel(conn, conn),
sendRequests: make(chan sendRequest),
recvRequests: make(chan recvRequest),
closed: make(chan struct{}),
done: make(chan struct{}),
}
go c.run()
return c
}
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
var payload []byte
switch v := req.(type) {
case proto.Message:
var err error
payload, err = proto.Marshal(v)
if err != nil {
return err
}
default:
return errors.Errorf("ttrpc: unknown request type: %T", req)
payload, err := c.codec.Marshal(req)
if err != nil {
return err
}
requestID := atomic.AddUint32(&c.requestID, 1)
request := Request{
Service: service,
Method: method,
Payload: payload,
}
if err := c.channel.send(ctx, &request); err != nil {
if err := c.send(ctx, requestID, &request); err != nil {
return err
}
var response Response
if err := c.channel.recv(ctx, &response); err != nil {
if err := c.recv(ctx, requestID, &response); err != nil {
return err
}
switch v := resp.(type) {
case proto.Message:
if err := proto.Unmarshal(response.Payload, v); err != nil {
return err
}
default:
return errors.Errorf("ttrpc: unknown response type: %T", resp)
}
return c.codec.Unmarshal(response.Payload, resp)
}
func (c *Client) Close() error {
c.closeOnce.Do(func() {
close(c.closed)
})
return nil
}
type sendRequest struct {
ctx context.Context
id uint32
msg interface{}
err chan error
}
func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
errs := make(chan error, 1)
select {
case c.sendRequests <- sendRequest{
ctx: ctx,
id: id,
msg: msg,
err: errs,
}:
case <-ctx.Done():
return ctx.Err()
case <-c.done:
return c.err
}
select {
case err := <-errs:
return err
case <-ctx.Done():
return ctx.Err()
case <-c.done:
return c.err
}
}
type recvRequest struct {
id uint32
msg interface{}
err chan error
}
func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
errs := make(chan error, 1)
select {
case c.recvRequests <- recvRequest{
id: id,
msg: msg,
err: errs,
}:
case <-c.done:
return c.err
case <-ctx.Done():
return ctx.Err()
}
select {
case err := <-errs:
return err
case <-c.done:
return c.err
case <-ctx.Done():
return ctx.Err()
}
}
type received struct {
mh messageHeader
p []byte
err error
}
func (c *Client) run() {
defer close(c.done)
var (
waiters = map[uint32]recvRequest{}
queued = map[uint32]received{} // messages unmatched by waiter
incoming = make(chan received)
)
go func() {
// start one more goroutine to recv messages without blocking.
for {
var p [messageLengthMax]byte
mh, err := c.channel.recv(context.TODO(), p[:])
select {
case incoming <- received{
mh: mh,
p: p[:mh.Length],
err: err,
}:
case <-c.done:
return
}
}
}()
for {
select {
case req := <-c.sendRequests:
if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
req.err <- err
} else {
req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
}
case req := <-c.recvRequests:
if r, ok := queued[req.id]; ok {
req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
}
waiters[req.id] = req
case r := <-incoming:
if r.err != nil {
c.err = r.err
return
}
if waiter, ok := waiters[r.mh.StreamID]; ok {
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
} else {
queued[r.mh.StreamID] = r
}
case <-c.closed:
return
}
}
}

View File

@ -9,6 +9,7 @@ import (
type Server struct {
services *serviceSet
codec codec
}
func NewServer() *Server {
@ -43,35 +44,91 @@ func (s *Server) Serve(l net.Listener) error {
func (s *Server) handleConn(conn net.Conn) {
defer conn.Close()
type (
request struct {
id uint32
req *Request
}
response struct {
id uint32
resp *Response
}
)
var (
ch = newChannel(conn)
req Request
ch = newChannel(conn, conn)
ctx, cancel = context.WithCancel(context.Background())
responses = make(chan response)
requests = make(chan request)
recvErr = make(chan error, 1)
done = make(chan struct{})
)
defer cancel()
defer close(done)
// TODO(stevvooe): Recover here or in dispatch to handle panics in service
// methods.
go func() {
defer close(recvErr)
var p [messageLengthMax]byte
for {
mh, err := ch.recv(ctx, p[:])
if err != nil {
recvErr <- err
return
}
if mh.Type != messageTypeRequest {
// we must ignore this for future compat.
continue
}
var req Request
if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
recvErr <- err
return
}
select {
case requests <- request{
id: mh.StreamID,
req: &req,
}:
case <-done:
}
}
}()
// every connection is just a simple in/out request loop. No complexity for
// multiplexing streams or dealing with head of line blocking, as this
// isn't necessary for shim control.
for {
if err := ch.recv(ctx, &req); err != nil {
log.L.WithError(err).Error("failed receiving message on channel")
return
}
select {
case request := <-requests:
go func(id uint32) {
p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
resp := &Response{
Status: status.Proto(),
Payload: p,
}
p, status := s.services.call(ctx, req.Service, req.Method, req.Payload)
resp := &Response{
Status: status.Proto(),
Payload: p,
}
if err := ch.send(ctx, resp); err != nil {
log.L.WithError(err).Error("failed sending message on channel")
select {
case responses <- response{
id: id,
resp: resp,
}:
case <-done:
}
}(request.id)
case response := <-responses:
p, err := s.codec.Marshal(response.resp)
if err != nil {
log.L.WithError(err).Error("failed marshaling response")
return
}
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
log.L.WithError(err).Error("failed sending message on channel")
return
}
case err := <-recvErr:
log.L.WithError(err).Error("error receiving message")
return
}
}

View File

@ -6,7 +6,6 @@ import (
"os"
"path"
"github.com/containerd/containerd/log"
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
@ -52,7 +51,6 @@ func (s *serviceSet) call(ctx context.Context, serviceName, methodName string, p
}
func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName string, p []byte) ([]byte, error) {
ctx = log.WithLogger(ctx, log.G(ctx).WithField("method", fullPath(serviceName, methodName)))
method, err := s.resolve(serviceName, methodName)
if err != nil {
return nil, err