ttrpc: increase maximum message length
This change increases the maximum message size to 4MB to be inline with the grpc default. The buffer management approach has been changed to use a pool to minimize allocations and keep memory usage low. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
parent
ed51a24609
commit
2a1ad5f6c7
52
channel.go
52
channel.go
@ -5,13 +5,16 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
messageHeaderLength = 10
|
messageHeaderLength = 10
|
||||||
messageLengthMax = 8 << 10
|
messageLengthMax = 4 << 20
|
||||||
)
|
)
|
||||||
|
|
||||||
type messageType uint8
|
type messageType uint8
|
||||||
@ -54,6 +57,8 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var buffers sync.Pool
|
||||||
|
|
||||||
type channel struct {
|
type channel struct {
|
||||||
bw *bufio.Writer
|
bw *bufio.Writer
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
@ -68,21 +73,32 @@ func newChannel(w io.Writer, r io.Reader) *channel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) {
|
// recv a message from the channel. The returned buffer contains the message.
|
||||||
|
//
|
||||||
|
// If a valid grpc status is returned, the message header
|
||||||
|
// returned will be valid and caller should send that along to
|
||||||
|
// the correct consumer. The bytes on the underlying channel
|
||||||
|
// will be discarded.
|
||||||
|
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
|
||||||
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
|
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return messageHeader{}, err
|
return messageHeader{}, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if mh.Length > uint32(len(p)) {
|
if mh.Length > uint32(messageLengthMax) {
|
||||||
return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p))
|
if _, err := ch.br.Discard(int(mh.Length)); err != nil {
|
||||||
|
return mh, nil, errors.Wrapf(err, "failed to discard after receiving oversized message")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil {
|
return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
|
||||||
return messageHeader{}, errors.Wrapf(err, "failed reading message")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return mh, nil
|
p := ch.getmbuf(int(mh.Length))
|
||||||
|
if _, err := io.ReadFull(ch.br, p); err != nil {
|
||||||
|
return messageHeader{}, nil, errors.Wrapf(err, "failed reading message")
|
||||||
|
}
|
||||||
|
|
||||||
|
return mh, p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
|
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
|
||||||
@ -97,3 +113,23 @@ func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p [
|
|||||||
|
|
||||||
return ch.bw.Flush()
|
return ch.bw.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ch *channel) getmbuf(size int) []byte {
|
||||||
|
// we can't use the standard New method on pool because we want to allocate
|
||||||
|
// based on size.
|
||||||
|
b, ok := buffers.Get().(*[]byte)
|
||||||
|
if !ok || cap(*b) < size {
|
||||||
|
// TODO(stevvooe): It may be better to allocate these in fixed length
|
||||||
|
// buckets to reduce fragmentation but its not clear that would help
|
||||||
|
// with performance. An ilogb approach or similar would work well.
|
||||||
|
bb := make([]byte, size)
|
||||||
|
b = &bb
|
||||||
|
} else {
|
||||||
|
*b = (*b)[:size]
|
||||||
|
}
|
||||||
|
return *b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *channel) putmbuf(p []byte) {
|
||||||
|
buffers.Put(&p)
|
||||||
|
}
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReadWriteMessage(t *testing.T) {
|
func TestReadWriteMessage(t *testing.T) {
|
||||||
@ -37,8 +39,7 @@ func TestReadWriteMessage(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var p [4096]byte
|
_, p, err := rch.recv(ctx)
|
||||||
mh, err := rch.recv(ctx, p[:])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Cause(err) != io.EOF {
|
if errors.Cause(err) != io.EOF {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -46,7 +47,7 @@ func TestReadWriteMessage(t *testing.T) {
|
|||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
received = append(received, p[:mh.Length])
|
received = append(received, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(received, messages) {
|
if !reflect.DeepEqual(received, messages) {
|
||||||
@ -54,13 +55,13 @@ func TestReadWriteMessage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSmallBuffer(t *testing.T) {
|
func TestMessageOversize(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
buffer bytes.Buffer
|
buffer bytes.Buffer
|
||||||
w = bufio.NewWriter(&buffer)
|
w = bufio.NewWriter(&buffer)
|
||||||
ch = newChannel(w, nil)
|
ch = newChannel(w, nil)
|
||||||
msg = []byte("a message of massive length")
|
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := ch.send(ctx, 1, 1, msg); err != nil {
|
if err := ch.send(ctx, 1, 1, msg); err != nil {
|
||||||
@ -69,17 +70,21 @@ func TestSmallBuffer(t *testing.T) {
|
|||||||
|
|
||||||
// now, read it off the channel with a small buffer
|
// now, read it off the channel with a small buffer
|
||||||
var (
|
var (
|
||||||
p = make([]byte, len(msg)-1)
|
|
||||||
r = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
|
r = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
|
||||||
rch = newChannel(nil, r)
|
rch = newChannel(nil, r)
|
||||||
)
|
)
|
||||||
|
|
||||||
_, err := rch.recv(ctx, p[:])
|
_, _, err := rch.recv(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("error expected reading with small buffer")
|
t.Fatalf("error expected reading with small buffer")
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.Cause(err) != io.ErrShortBuffer {
|
status, ok := status.FromError(err)
|
||||||
t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer)
|
if !ok {
|
||||||
|
t.Fatalf("expected grpc status error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.Code() != codes.ResourceExhausted {
|
||||||
|
t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -167,12 +167,11 @@ func (c *Client) run() {
|
|||||||
go func() {
|
go func() {
|
||||||
// start one more goroutine to recv messages without blocking.
|
// start one more goroutine to recv messages without blocking.
|
||||||
for {
|
for {
|
||||||
var p [messageLengthMax]byte
|
|
||||||
// TODO(stevvooe): Something still isn't quite right with error
|
// TODO(stevvooe): Something still isn't quite right with error
|
||||||
// handling on the client-side, causing EOFs to come through. We
|
// handling on the client-side, causing EOFs to come through. We
|
||||||
// need other fixes in this changeset, so we'll address this
|
// need other fixes in this changeset, so we'll address this
|
||||||
// correctly later.
|
// correctly later.
|
||||||
mh, err := c.channel.recv(context.TODO(), p[:])
|
mh, p, err := c.channel.recv(context.TODO())
|
||||||
select {
|
select {
|
||||||
case incoming <- received{
|
case incoming <- received{
|
||||||
mh: mh,
|
mh: mh,
|
||||||
@ -204,6 +203,7 @@ func (c *Client) run() {
|
|||||||
waiter.err <- r.err
|
waiter.err <- r.err
|
||||||
} else {
|
} else {
|
||||||
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
||||||
|
c.channel.putmbuf(r.p)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
queued[r.mh.StreamID] = r
|
queued[r.mh.StreamID] = r
|
||||||
|
64
server.go
64
server.go
@ -275,36 +275,15 @@ func (c *serverConn) run(sctx context.Context) {
|
|||||||
|
|
||||||
go func(recvErr chan error) {
|
go func(recvErr chan error) {
|
||||||
defer close(recvErr)
|
defer close(recvErr)
|
||||||
var p [messageLengthMax]byte
|
sendImmediate := func(id uint32, st *status.Status) bool {
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.shutdown:
|
|
||||||
return
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default: // proceed
|
|
||||||
}
|
|
||||||
|
|
||||||
mh, err := ch.recv(ctx, p[:])
|
|
||||||
if err != nil {
|
|
||||||
recvErr <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if mh.Type != messageTypeRequest {
|
|
||||||
// we must ignore this for future compat.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
sendImmediate := func(code codes.Code, msg string, args ...interface{}) bool {
|
|
||||||
select {
|
select {
|
||||||
case responses <- response{
|
case responses <- response{
|
||||||
// even though we've had an invalid stream id, we send it
|
// even though we've had an invalid stream id, we send it
|
||||||
// back on the same stream id so the client knows which
|
// back on the same stream id so the client knows which
|
||||||
// stream id was bad.
|
// stream id was bad.
|
||||||
id: mh.StreamID,
|
id: id,
|
||||||
resp: &Response{
|
resp: &Response{
|
||||||
Status: status.Newf(code, msg, args...).Proto(),
|
Status: st.Proto(),
|
||||||
},
|
},
|
||||||
}:
|
}:
|
||||||
return true
|
return true
|
||||||
@ -315,9 +294,40 @@ func (c *serverConn) run(sctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.shutdown:
|
||||||
|
return
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
default: // proceed
|
||||||
|
}
|
||||||
|
|
||||||
|
mh, p, err := ch.recv(ctx)
|
||||||
|
if err != nil {
|
||||||
|
status, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
recvErr <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// in this case, we send an error for that particular message
|
||||||
|
// when the status is defined.
|
||||||
|
if !sendImmediate(mh.StreamID, status) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if mh.Type != messageTypeRequest {
|
||||||
|
// we must ignore this for future compat.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var req Request
|
var req Request
|
||||||
if err := c.server.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
if err := c.server.codec.Unmarshal(p, &req); err != nil {
|
||||||
if !sendImmediate(codes.InvalidArgument, "unmarshal request error: %v", err) {
|
if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@ -325,7 +335,7 @@ func (c *serverConn) run(sctx context.Context) {
|
|||||||
|
|
||||||
if mh.StreamID%2 != 1 {
|
if mh.StreamID%2 != 1 {
|
||||||
// enforce odd client initiated identifiers.
|
// enforce odd client initiated identifiers.
|
||||||
if !sendImmediate(codes.InvalidArgument, "StreamID must be odd for client initiated streams") {
|
if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
@ -246,6 +246,41 @@ func TestServerClose(t *testing.T) {
|
|||||||
checkServerShutdown(t, server)
|
checkServerShutdown(t, server)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOversizeCall(t *testing.T) {
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
server = NewServer()
|
||||||
|
addr, listener = newTestListener(t)
|
||||||
|
errs = make(chan error, 1)
|
||||||
|
client, cleanup = newTestClient(t, addr)
|
||||||
|
)
|
||||||
|
defer cleanup()
|
||||||
|
defer listener.Close()
|
||||||
|
go func() {
|
||||||
|
errs <- server.Serve(listener)
|
||||||
|
}()
|
||||||
|
|
||||||
|
registerTestingService(server, &testingServer{})
|
||||||
|
|
||||||
|
tp := &testPayload{
|
||||||
|
Foo: strings.Repeat("a", 1+messageLengthMax),
|
||||||
|
}
|
||||||
|
if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
|
||||||
|
t.Fatalf("expected error from non-existent service call")
|
||||||
|
} else if status, ok := status.FromError(err); !ok {
|
||||||
|
t.Fatalf("expected status present in error: %v", err)
|
||||||
|
} else if status.Code() != codes.ResourceExhausted {
|
||||||
|
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := server.Shutdown(ctx); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := <-errs; err != ErrServerClosed {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkServerShutdown(t *testing.T, server *Server) {
|
func checkServerShutdown(t *testing.T, server *Server) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
server.mu.Lock()
|
server.mu.Lock()
|
||||||
|
Loading…
Reference in New Issue
Block a user