ttrpc: increase maximum message length

This change increases the maximum message size to 4MB, in line with
the grpc default. The buffer management approach has been changed to
use a sync.Pool of byte slices to minimize allocations and keep memory
usage low.

Signed-off-by: Stephen J Day <stephen.day@docker.com>
Author: Stephen J Day <stephen.day@docker.com>
Date:   2017-11-29 13:30:41 -08:00
Parent: ed51a24609
Commit: 2a1ad5f6c7
GPG Key ID: 67B3DED84EDC823F

5 changed files with 131 additions and 45 deletions
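A note for reviewers: sync.Pool's New field takes no arguments, so a pool cannot allocate by size on a miss; the change instead type-asserts the result of Get and falls back to make when the pooled buffer is absent or too small. Below is a minimal standalone sketch of that pattern (getBuf and putBuf are illustrative names, not the commit's getmbuf/putmbuf):

	package main

	import (
		"fmt"
		"sync"
	)

	// pool stores *[]byte rather than []byte so that Put does not
	// allocate a fresh header for the slice on every call.
	var pool sync.Pool

	// getBuf returns a buffer of exactly size bytes, reusing pooled
	// memory when its capacity is sufficient.
	func getBuf(size int) []byte {
		b, ok := pool.Get().(*[]byte)
		if !ok || cap(*b) < size {
			// pool was empty or the pooled buffer is too small
			bb := make([]byte, size)
			b = &bb
		} else {
			// reslice the pooled buffer to the requested length
			*b = (*b)[:size]
		}
		return *b
	}

	// putBuf hands a buffer back for reuse.
	func putBuf(p []byte) {
		pool.Put(&p)
	}

	func main() {
		p := getBuf(1024)
		fmt.Println(len(p), cap(p)) // 1024 1024
		putBuf(p)
		q := getBuf(512) // likely reuses p's backing array
		fmt.Println(len(q), cap(q))
	}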

channel.go

@@ -5,13 +5,16 @@ import (
 	"context"
 	"encoding/binary"
 	"io"
+	"sync"
 
 	"github.com/pkg/errors"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 const (
 	messageHeaderLength = 10
-	messageLengthMax    = 8 << 10
+	messageLengthMax    = 4 << 20
 )
 
 type messageType uint8
@@ -54,6 +57,8 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
 	return err
 }
 
+var buffers sync.Pool
+
 type channel struct {
 	bw *bufio.Writer
 	br *bufio.Reader
@@ -68,21 +73,32 @@ func newChannel(w io.Writer, r io.Reader) *channel {
 	}
 }
 
-func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) {
+// recv a message from the channel. The returned buffer contains the message.
+//
+// If a valid grpc status is returned, the message header
+// returned will be valid and the caller should send that along to
+// the correct consumer. The bytes on the underlying channel
+// will be discarded.
+func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
 	mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
 	if err != nil {
-		return messageHeader{}, err
+		return messageHeader{}, nil, err
 	}
 
-	if mh.Length > uint32(len(p)) {
-		return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p))
+	if mh.Length > uint32(messageLengthMax) {
+		if _, err := ch.br.Discard(int(mh.Length)); err != nil {
+			return mh, nil, errors.Wrapf(err, "failed to discard after receiving oversized message")
+		}
+
+		return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceeds maximum message size of %v", mh.Length, messageLengthMax)
 	}
 
-	if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil {
-		return messageHeader{}, errors.Wrapf(err, "failed reading message")
+	p := ch.getmbuf(int(mh.Length))
+	if _, err := io.ReadFull(ch.br, p); err != nil {
+		return messageHeader{}, nil, errors.Wrapf(err, "failed reading message")
 	}
 
-	return mh, nil
+	return mh, p, nil
 }
 
 func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
@@ -97,3 +113,23 @@ func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
 
 	return ch.bw.Flush()
 }
+
+func (ch *channel) getmbuf(size int) []byte {
+	// we can't use the standard New method on pool because we want to allocate
+	// based on size.
+	b, ok := buffers.Get().(*[]byte)
+	if !ok || cap(*b) < size {
+		// TODO(stevvooe): It may be better to allocate these in fixed length
+		// buckets to reduce fragmentation but it's not clear that would help
+		// with performance. An ilogb approach or similar would work well.
+		bb := make([]byte, size)
+		b = &bb
+	} else {
+		*b = (*b)[:size]
+	}
+
+	return *b
+}
+
+func (ch *channel) putmbuf(p []byte) {
+	buffers.Put(&p)
+}

channel_test.go

@@ -9,6 +9,8 @@ import (
 	"testing"
 
 	"github.com/pkg/errors"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 func TestReadWriteMessage(t *testing.T) {
@@ -37,8 +39,7 @@ func TestReadWriteMessage(t *testing.T) {
 	)
 
 	for {
-		var p [4096]byte
-		mh, err := rch.recv(ctx, p[:])
+		_, p, err := rch.recv(ctx)
 		if err != nil {
 			if errors.Cause(err) != io.EOF {
 				t.Fatal(err)
@@ -46,7 +47,7 @@ func TestReadWriteMessage(t *testing.T) {
 			break
 		}
 
-		received = append(received, p[:mh.Length])
+		received = append(received, p)
 	}
 
 	if !reflect.DeepEqual(received, messages) {
@@ -54,13 +55,13 @@
 	}
 }
 
-func TestSmallBuffer(t *testing.T) {
+func TestMessageOversize(t *testing.T) {
 	var (
 		ctx    = context.Background()
 		buffer bytes.Buffer
 		w      = bufio.NewWriter(&buffer)
 		ch     = newChannel(w, nil)
-		msg    = []byte("a message of massive length")
+		msg    = bytes.Repeat([]byte("a message of massive length"), 512<<10)
 	)
 
 	if err := ch.send(ctx, 1, 1, msg); err != nil {
@@ -69,17 +70,21 @@ func TestSmallBuffer(t *testing.T) {
 
-	// now, read it off the channel with a small buffer
+	// now, read it off the channel; it should be rejected as oversize
 	var (
-		p   = make([]byte, len(msg)-1)
 		r   = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
 		rch = newChannel(nil, r)
 	)
-	_, err := rch.recv(ctx, p[:])
+	_, _, err := rch.recv(ctx)
 	if err == nil {
-		t.Fatalf("error expected reading with small buffer")
+		t.Fatalf("error expected reading oversized message")
 	}
 
-	if errors.Cause(err) != io.ErrShortBuffer {
-		t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer)
+	status, ok := status.FromError(err)
+	if !ok {
+		t.Fatalf("expected grpc status error: %v", err)
+	}
+
+	if status.Code() != codes.ResourceExhausted {
+		t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
 	}
 }
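For reference, the oversize fixture lands well past the new limit: the 27-byte string repeated 512Ki times comes to exactly 13.5 MB against the 4 MB maximum. A quick check:

	package main

	import "fmt"

	func main() {
		unit := len("a message of massive length") // 27 bytes
		total := unit * (512 << 10)                // 27 * 524288 = 14155776
		fmt.Println(total, total > 4<<20)          // 14155776 true (13.5 MB > 4 MB)
	}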

client.go

@@ -167,12 +167,11 @@ func (c *Client) run() {
 	go func() {
 		// start one more goroutine to recv messages without blocking.
 		for {
-			var p [messageLengthMax]byte
-
 			// TODO(stevvooe): Something still isn't quite right with error
 			// handling on the client-side, causing EOFs to come through. We
 			// need other fixes in this changeset, so we'll address this
 			// correctly later.
-			mh, err := c.channel.recv(context.TODO(), p[:])
+			mh, p, err := c.channel.recv(context.TODO())
 			select {
 			case incoming <- received{
 				mh:  mh,
@@ -204,6 +203,7 @@ func (c *Client) run() {
 				waiter.err <- r.err
 			} else {
 				waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
+				c.channel.putmbuf(r.p)
 			}
 		} else {
 			queued[r.mh.StreamID] = r
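Returning the buffer right after proto.Unmarshal is safe because Unmarshal copies field data out of the input slice into the destination message, so nothing retained by the waiter aliases r.p once decoding completes.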

server.go

@@ -275,36 +275,15 @@ func (c *serverConn) run(sctx context.Context) {
 	go func(recvErr chan error) {
 		defer close(recvErr)
-		var p [messageLengthMax]byte
-		for {
-			select {
-			case <-c.shutdown:
-				return
-			case <-done:
-				return
-			default: // proceed
-			}
-
-			mh, err := ch.recv(ctx, p[:])
-			if err != nil {
-				recvErr <- err
-				return
-			}
-
-			if mh.Type != messageTypeRequest {
-				// we must ignore this for future compat.
-				continue
-			}
-
-			sendImmediate := func(code codes.Code, msg string, args ...interface{}) bool {
+		sendImmediate := func(id uint32, st *status.Status) bool {
 			select {
 			case responses <- response{
 				// even though we've had an invalid stream id, we send it
 				// back on the same stream id so the client knows which
 				// stream id was bad.
-				id: mh.StreamID,
+				id: id,
 				resp: &Response{
-					Status: status.Newf(code, msg, args...).Proto(),
+					Status: st.Proto(),
 				},
 			}:
 				return true
@@ -315,9 +294,40 @@ func (c *serverConn) run(sctx context.Context) {
 			}
 		}
 
+		for {
+			select {
+			case <-c.shutdown:
+				return
+			case <-done:
+				return
+			default: // proceed
+			}
+
+			mh, p, err := ch.recv(ctx)
+			if err != nil {
+				status, ok := status.FromError(err)
+				if !ok {
+					recvErr <- err
+					return
+				}
+
+				// in this case, we send an error for that particular message
+				// when the status is defined.
+				if !sendImmediate(mh.StreamID, status) {
+					return
+				}
+
+				continue
+			}
+
+			if mh.Type != messageTypeRequest {
+				// we must ignore this for future compat.
+				continue
+			}
+
 			var req Request
-			if err := c.server.codec.Unmarshal(p[:mh.Length], &req); err != nil {
-				if !sendImmediate(codes.InvalidArgument, "unmarshal request error: %v", err) {
+			if err := c.server.codec.Unmarshal(p, &req); err != nil {
+				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
 					return
 				}
 				continue
@@ -325,7 +335,7 @@ func (c *serverConn) run(sctx context.Context) {
 
 			if mh.StreamID%2 != 1 {
 				// enforce odd client initiated identifiers.
-				if !sendImmediate(codes.InvalidArgument, "StreamID must be odd for client initiated streams") {
+				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
 					return
 				}
 				continue

server_test.go

@@ -246,6 +246,41 @@ func TestServerClose(t *testing.T) {
 	checkServerShutdown(t, server)
 }
 
+func TestOversizeCall(t *testing.T) {
+	var (
+		ctx             = context.Background()
+		server          = NewServer()
+		addr, listener  = newTestListener(t)
+		errs            = make(chan error, 1)
+		client, cleanup = newTestClient(t, addr)
+	)
+	defer cleanup()
+	defer listener.Close()
+	go func() {
+		errs <- server.Serve(listener)
+	}()
+
+	registerTestingService(server, &testingServer{})
+
+	tp := &testPayload{
+		Foo: strings.Repeat("a", 1+messageLengthMax),
+	}
+	if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
+		t.Fatalf("expected error from oversized call")
+	} else if status, ok := status.FromError(err); !ok {
+		t.Fatalf("expected status present in error: %v", err)
+	} else if status.Code() != codes.ResourceExhausted {
+		t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
+	}
+
+	if err := server.Shutdown(ctx); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := <-errs; err != ErrServerClosed {
+		t.Fatal(err)
+	}
+}
+
 func checkServerShutdown(t *testing.T, server *Server) {
 	t.Helper()
 	server.mu.Lock()