diff --git a/channel.go b/channel.go index a71260b..4a33827 100644 --- a/channel.go +++ b/channel.go @@ -5,13 +5,16 @@ import ( "context" "encoding/binary" "io" + "sync" "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const ( messageHeaderLength = 10 - messageLengthMax = 8 << 10 + messageLengthMax = 4 << 20 ) type messageType uint8 @@ -54,6 +57,8 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error { return err } +var buffers sync.Pool + type channel struct { bw *bufio.Writer br *bufio.Reader @@ -68,21 +73,32 @@ func newChannel(w io.Writer, r io.Reader) *channel { } } -func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) { +// recv a message from the channel. The returned buffer contains the message. +// +// If a valid grpc status is returned, the message header +// returned will be valid and caller should send that along to +// the correct consumer. The bytes on the underlying channel +// will be discarded. +func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { mh, err := readMessageHeader(ch.hrbuf[:], ch.br) if err != nil { - return messageHeader{}, err + return messageHeader{}, nil, err } - if mh.Length > uint32(len(p)) { - return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p)) + if mh.Length > uint32(messageLengthMax) { + if _, err := ch.br.Discard(int(mh.Length)); err != nil { + return mh, nil, errors.Wrapf(err, "failed to discard after receiving oversized message") + } + + return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) } - if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil { - return messageHeader{}, errors.Wrapf(err, "failed reading message") + p := ch.getmbuf(int(mh.Length)) + if _, err := io.ReadFull(ch.br, p); err != nil { + return messageHeader{}, nil, errors.Wrapf(err, "failed reading message") } - return mh, nil + return mh, p, nil } func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { @@ -97,3 +113,23 @@ func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p [ return ch.bw.Flush() } + +func (ch *channel) getmbuf(size int) []byte { + // we can't use the standard New method on pool because we want to allocate + // based on size. + b, ok := buffers.Get().(*[]byte) + if !ok || cap(*b) < size { + // TODO(stevvooe): It may be better to allocate these in fixed length + // buckets to reduce fragmentation but its not clear that would help + // with performance. An ilogb approach or similar would work well. + bb := make([]byte, size) + b = &bb + } else { + *b = (*b)[:size] + } + return *b +} + +func (ch *channel) putmbuf(p []byte) { + buffers.Put(&p) +} diff --git a/channel_test.go b/channel_test.go index bde01cd..121aafe 100644 --- a/channel_test.go +++ b/channel_test.go @@ -9,6 +9,8 @@ import ( "testing" "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestReadWriteMessage(t *testing.T) { @@ -37,8 +39,7 @@ func TestReadWriteMessage(t *testing.T) { ) for { - var p [4096]byte - mh, err := rch.recv(ctx, p[:]) + _, p, err := rch.recv(ctx) if err != nil { if errors.Cause(err) != io.EOF { t.Fatal(err) @@ -46,7 +47,7 @@ func TestReadWriteMessage(t *testing.T) { break } - received = append(received, p[:mh.Length]) + received = append(received, p) } if !reflect.DeepEqual(received, messages) { @@ -54,13 +55,13 @@ func TestReadWriteMessage(t *testing.T) { } } -func TestSmallBuffer(t *testing.T) { +func TestMessageOversize(t *testing.T) { var ( ctx = context.Background() buffer bytes.Buffer w = bufio.NewWriter(&buffer) ch = newChannel(w, nil) - msg = []byte("a message of massive length") + msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) ) if err := ch.send(ctx, 1, 1, msg); err != nil { @@ -69,17 +70,21 @@ func TestSmallBuffer(t *testing.T) { // now, read it off the channel with a small buffer var ( - p = make([]byte, len(msg)-1) r = bufio.NewReader(bytes.NewReader(buffer.Bytes())) rch = newChannel(nil, r) ) - _, err := rch.recv(ctx, p[:]) + _, _, err := rch.recv(ctx) if err == nil { t.Fatalf("error expected reading with small buffer") } - if errors.Cause(err) != io.ErrShortBuffer { - t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer) + status, ok := status.FromError(err) + if !ok { + t.Fatalf("expected grpc status error: %v", err) + } + + if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted) } } diff --git a/client.go b/client.go index 342f7bd..f0fb9ec 100644 --- a/client.go +++ b/client.go @@ -167,12 +167,11 @@ func (c *Client) run() { go func() { // start one more goroutine to recv messages without blocking. for { - var p [messageLengthMax]byte // TODO(stevvooe): Something still isn't quite right with error // handling on the client-side, causing EOFs to come through. We // need other fixes in this changeset, so we'll address this // correctly later. - mh, err := c.channel.recv(context.TODO(), p[:]) + mh, p, err := c.channel.recv(context.TODO()) select { case incoming <- received{ mh: mh, @@ -204,6 +203,7 @@ func (c *Client) run() { waiter.err <- r.err } else { waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message)) + c.channel.putmbuf(r.p) } } else { queued[r.mh.StreamID] = r diff --git a/server.go b/server.go index 845d160..279b5be 100644 --- a/server.go +++ b/server.go @@ -275,7 +275,25 @@ func (c *serverConn) run(sctx context.Context) { go func(recvErr chan error) { defer close(recvErr) - var p [messageLengthMax]byte + sendImmediate := func(id uint32, st *status.Status) bool { + select { + case responses <- response{ + // even though we've had an invalid stream id, we send it + // back on the same stream id so the client knows which + // stream id was bad. + id: id, + resp: &Response{ + Status: st.Proto(), + }, + }: + return true + case <-c.shutdown: + return false + case <-done: + return false + } + } + for { select { case <-c.shutdown: @@ -285,10 +303,21 @@ func (c *serverConn) run(sctx context.Context) { default: // proceed } - mh, err := ch.recv(ctx, p[:]) + mh, p, err := ch.recv(ctx) if err != nil { - recvErr <- err - return + status, ok := status.FromError(err) + if !ok { + recvErr <- err + return + } + + // in this case, we send an error for that particular message + // when the status is defined. + if !sendImmediate(mh.StreamID, status) { + return + } + + continue } if mh.Type != messageTypeRequest { @@ -296,28 +325,9 @@ func (c *serverConn) run(sctx context.Context) { continue } - sendImmediate := func(code codes.Code, msg string, args ...interface{}) bool { - select { - case responses <- response{ - // even though we've had an invalid stream id, we send it - // back on the same stream id so the client knows which - // stream id was bad. - id: mh.StreamID, - resp: &Response{ - Status: status.Newf(code, msg, args...).Proto(), - }, - }: - return true - case <-c.shutdown: - return false - case <-done: - return false - } - } - var req Request - if err := c.server.codec.Unmarshal(p[:mh.Length], &req); err != nil { - if !sendImmediate(codes.InvalidArgument, "unmarshal request error: %v", err) { + if err := c.server.codec.Unmarshal(p, &req); err != nil { + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { return } continue @@ -325,7 +335,7 @@ func (c *serverConn) run(sctx context.Context) { if mh.StreamID%2 != 1 { // enforce odd client initiated identifiers. - if !sendImmediate(codes.InvalidArgument, "StreamID must be odd for client initiated streams") { + if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) { return } continue diff --git a/server_test.go b/server_test.go index b6a0356..8ce6fd4 100644 --- a/server_test.go +++ b/server_test.go @@ -246,6 +246,41 @@ func TestServerClose(t *testing.T) { checkServerShutdown(t, server) } +func TestOversizeCall(t *testing.T) { + var ( + ctx = context.Background() + server = NewServer() + addr, listener = newTestListener(t) + errs = make(chan error, 1) + client, cleanup = newTestClient(t, addr) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(listener) + }() + + registerTestingService(server, &testingServer{}) + + tp := &testPayload{ + Foo: strings.Repeat("a", 1+messageLengthMax), + } + if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { + t.Fatalf("expected error from non-existent service call") + } else if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + if err := <-errs; err != ErrServerClosed { + t.Fatal(err) + } +} + func checkServerShutdown(t *testing.T, server *Server) { t.Helper() server.mu.Lock()