ttrpc: increase maximum message length
This change increases the maximum message size to 4MB to be inline with the grpc default. The buffer management approach has been changed to use a pool to minimize allocations and keep memory usage low. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
		
							
								
								
									
										52
									
								
								channel.go
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								channel.go
									
									
									
									
									
								
							| @@ -5,13 +5,16 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/binary" | 	"encoding/binary" | ||||||
| 	"io" | 	"io" | ||||||
|  | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/pkg/errors" | 	"github.com/pkg/errors" | ||||||
|  | 	"google.golang.org/grpc/codes" | ||||||
|  | 	"google.golang.org/grpc/status" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	messageHeaderLength = 10 | 	messageHeaderLength = 10 | ||||||
| 	messageLengthMax    = 8 << 10 | 	messageLengthMax    = 4 << 20 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type messageType uint8 | type messageType uint8 | ||||||
| @@ -54,6 +57,8 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | var buffers sync.Pool | ||||||
|  |  | ||||||
| type channel struct { | type channel struct { | ||||||
| 	bw    *bufio.Writer | 	bw    *bufio.Writer | ||||||
| 	br    *bufio.Reader | 	br    *bufio.Reader | ||||||
| @@ -68,21 +73,32 @@ func newChannel(w io.Writer, r io.Reader) *channel { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) { | // recv a message from the channel. The returned buffer contains the message. | ||||||
|  | // | ||||||
|  | // If a valid grpc status is returned, the message header | ||||||
|  | // returned will be valid and caller should send that along to | ||||||
|  | // the correct consumer. The bytes on the underlying channel | ||||||
|  | // will be discarded. | ||||||
|  | func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { | ||||||
| 	mh, err := readMessageHeader(ch.hrbuf[:], ch.br) | 	mh, err := readMessageHeader(ch.hrbuf[:], ch.br) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return messageHeader{}, err | 		return messageHeader{}, nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if mh.Length > uint32(len(p)) { | 	if mh.Length > uint32(messageLengthMax) { | ||||||
| 		return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p)) | 		if _, err := ch.br.Discard(int(mh.Length)); err != nil { | ||||||
|  | 			return mh, nil, errors.Wrapf(err, "failed to discard after receiving oversized message") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil { | 	p := ch.getmbuf(int(mh.Length)) | ||||||
| 		return messageHeader{}, errors.Wrapf(err, "failed reading message") | 	if _, err := io.ReadFull(ch.br, p); err != nil { | ||||||
|  | 		return messageHeader{}, nil, errors.Wrapf(err, "failed reading message") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return mh, nil | 	return mh, p, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { | func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { | ||||||
| @@ -97,3 +113,23 @@ func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p [ | |||||||
|  |  | ||||||
| 	return ch.bw.Flush() | 	return ch.bw.Flush() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (ch *channel) getmbuf(size int) []byte { | ||||||
|  | 	// we can't use the standard New method on pool because we want to allocate | ||||||
|  | 	// based on size. | ||||||
|  | 	b, ok := buffers.Get().(*[]byte) | ||||||
|  | 	if !ok || cap(*b) < size { | ||||||
|  | 		// TODO(stevvooe): It may be better to allocate these in fixed length | ||||||
|  | 		// buckets to reduce fragmentation but its not clear that would help | ||||||
|  | 		// with performance. An ilogb approach or similar would work well. | ||||||
|  | 		bb := make([]byte, size) | ||||||
|  | 		b = &bb | ||||||
|  | 	} else { | ||||||
|  | 		*b = (*b)[:size] | ||||||
|  | 	} | ||||||
|  | 	return *b | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch *channel) putmbuf(p []byte) { | ||||||
|  | 	buffers.Put(&p) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -9,6 +9,8 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/pkg/errors" | 	"github.com/pkg/errors" | ||||||
|  | 	"google.golang.org/grpc/codes" | ||||||
|  | 	"google.golang.org/grpc/status" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestReadWriteMessage(t *testing.T) { | func TestReadWriteMessage(t *testing.T) { | ||||||
| @@ -37,8 +39,7 @@ func TestReadWriteMessage(t *testing.T) { | |||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		var p [4096]byte | 		_, p, err := rch.recv(ctx) | ||||||
| 		mh, err := rch.recv(ctx, p[:]) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			if errors.Cause(err) != io.EOF { | 			if errors.Cause(err) != io.EOF { | ||||||
| 				t.Fatal(err) | 				t.Fatal(err) | ||||||
| @@ -46,7 +47,7 @@ func TestReadWriteMessage(t *testing.T) { | |||||||
|  |  | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| 		received = append(received, p[:mh.Length]) | 		received = append(received, p) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !reflect.DeepEqual(received, messages) { | 	if !reflect.DeepEqual(received, messages) { | ||||||
| @@ -54,13 +55,13 @@ func TestReadWriteMessage(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestSmallBuffer(t *testing.T) { | func TestMessageOversize(t *testing.T) { | ||||||
| 	var ( | 	var ( | ||||||
| 		ctx    = context.Background() | 		ctx    = context.Background() | ||||||
| 		buffer bytes.Buffer | 		buffer bytes.Buffer | ||||||
| 		w      = bufio.NewWriter(&buffer) | 		w      = bufio.NewWriter(&buffer) | ||||||
| 		ch     = newChannel(w, nil) | 		ch     = newChannel(w, nil) | ||||||
| 		msg    = []byte("a message of massive length") | 		msg    = bytes.Repeat([]byte("a message of massive length"), 512<<10) | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	if err := ch.send(ctx, 1, 1, msg); err != nil { | 	if err := ch.send(ctx, 1, 1, msg); err != nil { | ||||||
| @@ -69,17 +70,21 @@ func TestSmallBuffer(t *testing.T) { | |||||||
|  |  | ||||||
| 	// now, read it off the channel with a small buffer | 	// now, read it off the channel with a small buffer | ||||||
| 	var ( | 	var ( | ||||||
| 		p   = make([]byte, len(msg)-1) |  | ||||||
| 		r   = bufio.NewReader(bytes.NewReader(buffer.Bytes())) | 		r   = bufio.NewReader(bytes.NewReader(buffer.Bytes())) | ||||||
| 		rch = newChannel(nil, r) | 		rch = newChannel(nil, r) | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	_, err := rch.recv(ctx, p[:]) | 	_, _, err := rch.recv(ctx) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		t.Fatalf("error expected reading with small buffer") | 		t.Fatalf("error expected reading with small buffer") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if errors.Cause(err) != io.ErrShortBuffer { | 	status, ok := status.FromError(err) | ||||||
| 		t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer) | 	if !ok { | ||||||
|  | 		t.Fatalf("expected grpc status error: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if status.Code() != codes.ResourceExhausted { | ||||||
|  | 		t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -167,12 +167,11 @@ func (c *Client) run() { | |||||||
| 	go func() { | 	go func() { | ||||||
| 		// start one more goroutine to recv messages without blocking. | 		// start one more goroutine to recv messages without blocking. | ||||||
| 		for { | 		for { | ||||||
| 			var p [messageLengthMax]byte |  | ||||||
| 			// TODO(stevvooe): Something still isn't quite right with error | 			// TODO(stevvooe): Something still isn't quite right with error | ||||||
| 			// handling on the client-side, causing EOFs to come through. We | 			// handling on the client-side, causing EOFs to come through. We | ||||||
| 			// need other fixes in this changeset, so we'll address this | 			// need other fixes in this changeset, so we'll address this | ||||||
| 			// correctly later. | 			// correctly later. | ||||||
| 			mh, err := c.channel.recv(context.TODO(), p[:]) | 			mh, p, err := c.channel.recv(context.TODO()) | ||||||
| 			select { | 			select { | ||||||
| 			case incoming <- received{ | 			case incoming <- received{ | ||||||
| 				mh:  mh, | 				mh:  mh, | ||||||
| @@ -204,6 +203,7 @@ func (c *Client) run() { | |||||||
| 					waiter.err <- r.err | 					waiter.err <- r.err | ||||||
| 				} else { | 				} else { | ||||||
| 					waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message)) | 					waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message)) | ||||||
|  | 					c.channel.putmbuf(r.p) | ||||||
| 				} | 				} | ||||||
| 			} else { | 			} else { | ||||||
| 				queued[r.mh.StreamID] = r | 				queued[r.mh.StreamID] = r | ||||||
|   | |||||||
							
								
								
									
										62
									
								
								server.go
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								server.go
									
									
									
									
									
								
							| @@ -275,7 +275,25 @@ func (c *serverConn) run(sctx context.Context) { | |||||||
|  |  | ||||||
| 	go func(recvErr chan error) { | 	go func(recvErr chan error) { | ||||||
| 		defer close(recvErr) | 		defer close(recvErr) | ||||||
| 		var p [messageLengthMax]byte | 		sendImmediate := func(id uint32, st *status.Status) bool { | ||||||
|  | 			select { | ||||||
|  | 			case responses <- response{ | ||||||
|  | 				// even though we've had an invalid stream id, we send it | ||||||
|  | 				// back on the same stream id so the client knows which | ||||||
|  | 				// stream id was bad. | ||||||
|  | 				id: id, | ||||||
|  | 				resp: &Response{ | ||||||
|  | 					Status: st.Proto(), | ||||||
|  | 				}, | ||||||
|  | 			}: | ||||||
|  | 				return true | ||||||
|  | 			case <-c.shutdown: | ||||||
|  | 				return false | ||||||
|  | 			case <-done: | ||||||
|  | 				return false | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		for { | 		for { | ||||||
| 			select { | 			select { | ||||||
| 			case <-c.shutdown: | 			case <-c.shutdown: | ||||||
| @@ -285,10 +303,21 @@ func (c *serverConn) run(sctx context.Context) { | |||||||
| 			default: // proceed | 			default: // proceed | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			mh, err := ch.recv(ctx, p[:]) | 			mh, p, err := ch.recv(ctx) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				recvErr <- err | 				status, ok := status.FromError(err) | ||||||
| 				return | 				if !ok { | ||||||
|  | 					recvErr <- err | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// in this case, we send an error for that particular message | ||||||
|  | 				// when the status is defined. | ||||||
|  | 				if !sendImmediate(mh.StreamID, status) { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				continue | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if mh.Type != messageTypeRequest { | 			if mh.Type != messageTypeRequest { | ||||||
| @@ -296,28 +325,9 @@ func (c *serverConn) run(sctx context.Context) { | |||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			sendImmediate := func(code codes.Code, msg string, args ...interface{}) bool { |  | ||||||
| 				select { |  | ||||||
| 				case responses <- response{ |  | ||||||
| 					// even though we've had an invalid stream id, we send it |  | ||||||
| 					// back on the same stream id so the client knows which |  | ||||||
| 					// stream id was bad. |  | ||||||
| 					id: mh.StreamID, |  | ||||||
| 					resp: &Response{ |  | ||||||
| 						Status: status.Newf(code, msg, args...).Proto(), |  | ||||||
| 					}, |  | ||||||
| 				}: |  | ||||||
| 					return true |  | ||||||
| 				case <-c.shutdown: |  | ||||||
| 					return false |  | ||||||
| 				case <-done: |  | ||||||
| 					return false |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			var req Request | 			var req Request | ||||||
| 			if err := c.server.codec.Unmarshal(p[:mh.Length], &req); err != nil { | 			if err := c.server.codec.Unmarshal(p, &req); err != nil { | ||||||
| 				if !sendImmediate(codes.InvalidArgument, "unmarshal request error: %v", err) { | 				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { | ||||||
| 					return | 					return | ||||||
| 				} | 				} | ||||||
| 				continue | 				continue | ||||||
| @@ -325,7 +335,7 @@ func (c *serverConn) run(sctx context.Context) { | |||||||
|  |  | ||||||
| 			if mh.StreamID%2 != 1 { | 			if mh.StreamID%2 != 1 { | ||||||
| 				// enforce odd client initiated identifiers. | 				// enforce odd client initiated identifiers. | ||||||
| 				if !sendImmediate(codes.InvalidArgument, "StreamID must be odd for client initiated streams") { | 				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) { | ||||||
| 					return | 					return | ||||||
| 				} | 				} | ||||||
| 				continue | 				continue | ||||||
|   | |||||||
| @@ -246,6 +246,41 @@ func TestServerClose(t *testing.T) { | |||||||
| 	checkServerShutdown(t, server) | 	checkServerShutdown(t, server) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestOversizeCall(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		ctx             = context.Background() | ||||||
|  | 		server          = NewServer() | ||||||
|  | 		addr, listener  = newTestListener(t) | ||||||
|  | 		errs            = make(chan error, 1) | ||||||
|  | 		client, cleanup = newTestClient(t, addr) | ||||||
|  | 	) | ||||||
|  | 	defer cleanup() | ||||||
|  | 	defer listener.Close() | ||||||
|  | 	go func() { | ||||||
|  | 		errs <- server.Serve(listener) | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	registerTestingService(server, &testingServer{}) | ||||||
|  |  | ||||||
|  | 	tp := &testPayload{ | ||||||
|  | 		Foo: strings.Repeat("a", 1+messageLengthMax), | ||||||
|  | 	} | ||||||
|  | 	if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { | ||||||
|  | 		t.Fatalf("expected error from non-existent service call") | ||||||
|  | 	} else if status, ok := status.FromError(err); !ok { | ||||||
|  | 		t.Fatalf("expected status present in error: %v", err) | ||||||
|  | 	} else if status.Code() != codes.ResourceExhausted { | ||||||
|  | 		t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := server.Shutdown(ctx); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if err := <-errs; err != ErrServerClosed { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func checkServerShutdown(t *testing.T, server *Server) { | func checkServerShutdown(t *testing.T, server *Server) { | ||||||
| 	t.Helper() | 	t.Helper() | ||||||
| 	server.mu.Lock() | 	server.mu.Lock() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Stephen J Day
					Stephen J Day