Merge pull request #171 from klihub/devel/sender-side-oversize-rejection
channel: reject oversized messages on the sender side(, too).
This commit is contained in:
commit
bcc40a4d69
@ -143,10 +143,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
|
func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
|
||||||
// TODO: Error on send rather than on recv
|
if len(p) > messageLengthMax {
|
||||||
//if len(p) > messageLengthMax {
|
return OversizedMessageError(len(p))
|
||||||
// return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax)
|
}
|
||||||
//}
|
|
||||||
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
|
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -89,21 +89,19 @@ func TestReadWriteMessage(t *testing.T) {
|
|||||||
|
|
||||||
func TestMessageOversize(t *testing.T) {
|
func TestMessageOversize(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
w, r = net.Pipe()
|
w, _ = net.Pipe()
|
||||||
wch, rch = newChannel(w), newChannel(r)
|
wch = newChannel(w)
|
||||||
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
|
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
|
||||||
errs = make(chan error, 1)
|
errs = make(chan error, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := wch.send(1, 1, 0, msg); err != nil {
|
errs <- wch.send(1, 1, 0, msg)
|
||||||
errs <- err
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, _, err := rch.recv()
|
err := <-errs
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("error expected reading with small buffer")
|
t.Fatalf("sending oversized message expected to fail")
|
||||||
}
|
}
|
||||||
|
|
||||||
status, ok := status.FromError(err)
|
status, ok := status.FromError(err)
|
||||||
@ -114,12 +112,4 @@ func TestMessageOversize(t *testing.T) {
|
|||||||
if status.Code() != codes.ResourceExhausted {
|
if status.Code() != codes.ResourceExhausted {
|
||||||
t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
|
t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-errs:
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
48
errors.go
48
errors.go
@ -16,7 +16,12 @@
|
|||||||
|
|
||||||
package ttrpc
|
package ttrpc
|
||||||
|
|
||||||
import "errors"
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrProtocol is a general error in the handling the protocol.
|
// ErrProtocol is a general error in the handling the protocol.
|
||||||
@ -32,3 +37,44 @@ var (
|
|||||||
// ErrStreamClosed is when the streaming connection is closed.
|
// ErrStreamClosed is when the streaming connection is closed.
|
||||||
ErrStreamClosed = errors.New("ttrpc: stream closed")
|
ErrStreamClosed = errors.New("ttrpc: stream closed")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OversizedMessageErr is used to indicate refusal to send an oversized message.
|
||||||
|
// It wraps a ResourceExhausted grpc Status together with the offending message
|
||||||
|
// length.
|
||||||
|
type OversizedMessageErr struct {
|
||||||
|
messageLength int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// OversizedMessageError returns an OversizedMessageErr error for the given message
|
||||||
|
// length if it exceeds the allowed maximum. Otherwise a nil error is returned.
|
||||||
|
func OversizedMessageError(messageLength int) error {
|
||||||
|
if messageLength <= messageLengthMax {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OversizedMessageErr{
|
||||||
|
messageLength: messageLength,
|
||||||
|
err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns the error message for the corresponding grpc Status for the error.
|
||||||
|
func (e *OversizedMessageErr) Error() string {
|
||||||
|
return e.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the corresponding error with our grpc status code.
|
||||||
|
func (e *OversizedMessageErr) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RejectedLength retrieves the rejected message length which triggered the error.
|
||||||
|
func (e *OversizedMessageErr) RejectedLength() int {
|
||||||
|
return e.messageLength
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaximumLength retrieves the maximum allowed message length that triggered the error.
|
||||||
|
func (*OversizedMessageErr) MaximumLength() int {
|
||||||
|
return messageLengthMax
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user