Merge pull request #171 from klihub/devel/sender-side-oversize-rejection
channel: reject oversized messages on the sender side(, too).
This commit is contained in:
		@@ -143,10 +143,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
 | 
					func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
 | 
				
			||||||
	// TODO: Error on send rather than on recv
 | 
						if len(p) > messageLengthMax {
 | 
				
			||||||
	//if len(p) > messageLengthMax {
 | 
							return OversizedMessageError(len(p))
 | 
				
			||||||
	//	return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax)
 | 
						}
 | 
				
			||||||
	//}
 | 
					
 | 
				
			||||||
	if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
 | 
						if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -89,21 +89,19 @@ func TestReadWriteMessage(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestMessageOversize(t *testing.T) {
 | 
					func TestMessageOversize(t *testing.T) {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		w, r     = net.Pipe()
 | 
							w, _ = net.Pipe()
 | 
				
			||||||
		wch, rch = newChannel(w), newChannel(r)
 | 
							wch  = newChannel(w)
 | 
				
			||||||
		msg  = bytes.Repeat([]byte("a message of massive length"), 512<<10)
 | 
							msg  = bytes.Repeat([]byte("a message of massive length"), 512<<10)
 | 
				
			||||||
		errs = make(chan error, 1)
 | 
							errs = make(chan error, 1)
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		if err := wch.send(1, 1, 0, msg); err != nil {
 | 
							errs <- wch.send(1, 1, 0, msg)
 | 
				
			||||||
			errs <- err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, _, err := rch.recv()
 | 
						err := <-errs
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		t.Fatalf("error expected reading with small buffer")
 | 
							t.Fatalf("sending oversized message expected to fail")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	status, ok := status.FromError(err)
 | 
						status, ok := status.FromError(err)
 | 
				
			||||||
@@ -114,12 +112,4 @@ func TestMessageOversize(t *testing.T) {
 | 
				
			|||||||
	if status.Code() != codes.ResourceExhausted {
 | 
						if status.Code() != codes.ResourceExhausted {
 | 
				
			||||||
		t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
 | 
							t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	select {
 | 
					 | 
				
			||||||
	case err := <-errs:
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	default:
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										48
									
								
								errors.go
									
									
									
									
									
								
							
							
						
						
									
										48
									
								
								errors.go
									
									
									
									
									
								
							@@ -16,7 +16,12 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package ttrpc
 | 
					package ttrpc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import "errors"
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"google.golang.org/grpc/codes"
 | 
				
			||||||
 | 
						"google.golang.org/grpc/status"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	// ErrProtocol is a general error in the handling the protocol.
 | 
						// ErrProtocol is a general error in the handling the protocol.
 | 
				
			||||||
@@ -32,3 +37,44 @@ var (
 | 
				
			|||||||
	// ErrStreamClosed is when the streaming connection is closed.
 | 
						// ErrStreamClosed is when the streaming connection is closed.
 | 
				
			||||||
	ErrStreamClosed = errors.New("ttrpc: stream closed")
 | 
						ErrStreamClosed = errors.New("ttrpc: stream closed")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// OversizedMessageErr is used to indicate refusal to send an oversized message.
 | 
				
			||||||
 | 
					// It wraps a ResourceExhausted grpc Status together with the offending message
 | 
				
			||||||
 | 
					// length.
 | 
				
			||||||
 | 
					type OversizedMessageErr struct {
 | 
				
			||||||
 | 
						messageLength int
 | 
				
			||||||
 | 
						err           error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// OversizedMessageError returns an OversizedMessageErr error for the given message
 | 
				
			||||||
 | 
					// length if it exceeds the allowed maximum. Otherwise a nil error is returned.
 | 
				
			||||||
 | 
					func OversizedMessageError(messageLength int) error {
 | 
				
			||||||
 | 
						if messageLength <= messageLengthMax {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &OversizedMessageErr{
 | 
				
			||||||
 | 
							messageLength: messageLength,
 | 
				
			||||||
 | 
							err:           status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Error returns the error message for the corresponding grpc Status for the error.
 | 
				
			||||||
 | 
					func (e *OversizedMessageErr) Error() string {
 | 
				
			||||||
 | 
						return e.err.Error()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Unwrap returns the corresponding error with our grpc status code.
 | 
				
			||||||
 | 
					func (e *OversizedMessageErr) Unwrap() error {
 | 
				
			||||||
 | 
						return e.err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RejectedLength retrieves the rejected message length which triggered the error.
 | 
				
			||||||
 | 
					func (e *OversizedMessageErr) RejectedLength() int {
 | 
				
			||||||
 | 
						return e.messageLength
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MaximumLength retrieves the maximum allowed message length that triggered the error.
 | 
				
			||||||
 | 
					func (*OversizedMessageErr) MaximumLength() int {
 | 
				
			||||||
 | 
						return messageLengthMax
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user