Add recvClose channel to stream

Prevent panic from closing recv channel, which may be written to after
close. Use a separate channel to signal recv has closed and check that
channel on read and write.

Signed-off-by: Derek McGowan <derek@mcg.dev>
This commit is contained in:
Derek McGowan 2023-04-28 09:19:09 -07:00
parent 98b5f64998
commit 471297eed9
2 changed files with 62 additions and 55 deletions

View File

@ -214,60 +214,66 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
if cs.remoteClosed { if cs.remoteClosed {
return io.EOF return io.EOF
} }
var msg *streamMessage
select { select {
case <-cs.ctx.Done(): case <-cs.ctx.Done():
return cs.ctx.Err() return cs.ctx.Err()
case msg, ok := <-cs.s.recv: case <-cs.s.recvClose:
if !ok { // If recv has a pending message, process that first
select {
case msg = <-cs.s.recv:
default:
return cs.s.recvErr return cs.s.recvErr
} }
case msg = <-cs.s.recv:
}
if msg.header.Type == messageTypeResponse { if msg.header.Type == messageTypeResponse {
resp := &Response{} resp := &Response{}
err := proto.Unmarshal(msg.payload[:msg.header.Length], resp) err := proto.Unmarshal(msg.payload[:msg.header.Length], resp)
// return the payload buffer for reuse // return the payload buffer for reuse
cs.c.channel.putmbuf(msg.payload) cs.c.channel.putmbuf(msg.payload)
if err != nil { if err != nil {
return err return err
} }
if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil { if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil {
return err return err
} }
if resp.Status != nil && resp.Status.Code != int32(codes.OK) { if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
return status.ErrorProto(resp.Status) return status.ErrorProto(resp.Status)
} }
cs.c.deleteStream(cs.s)
cs.remoteClosed = true
return nil
} else if msg.header.Type == messageTypeData {
if !cs.desc.StreamingServer {
cs.c.deleteStream(cs.s)
cs.remoteClosed = true
return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol)
}
if msg.header.Flags&flagRemoteClosed == flagRemoteClosed {
cs.c.deleteStream(cs.s) cs.c.deleteStream(cs.s)
cs.remoteClosed = true cs.remoteClosed = true
return nil if msg.header.Flags&flagNoData == flagNoData {
} else if msg.header.Type == messageTypeData { return io.EOF
if !cs.desc.StreamingServer {
cs.c.deleteStream(cs.s)
cs.remoteClosed = true
return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol)
} }
if msg.header.Flags&flagRemoteClosed == flagRemoteClosed {
cs.c.deleteStream(cs.s)
cs.remoteClosed = true
if msg.header.Flags&flagNoData == flagNoData {
return io.EOF
}
}
err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m)
cs.c.channel.putmbuf(msg.payload)
if err != nil {
return err
}
return nil
} }
return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol) err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m)
cs.c.channel.putmbuf(msg.payload)
if err != nil {
return err
}
return nil
} }
return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
} }
// Close closes the ttrpc connection and underlying connection // Close closes the ttrpc connection and underlying connection
@ -482,11 +488,9 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
return ctx.Err() return ctx.Err()
case <-c.ctx.Done(): case <-c.ctx.Done():
return ErrClosed return ErrClosed
case msg, ok := <-s.recv: case <-s.recvClose:
if !ok { return s.recvErr
return s.recvErr case msg := <-s.recv:
}
if msg.header.Type == messageTypeResponse { if msg.header.Type == messageTypeResponse {
err = proto.Unmarshal(msg.payload[:msg.header.Length], resp) err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
} else { } else {

View File

@ -35,27 +35,26 @@ type stream struct {
closeOnce sync.Once closeOnce sync.Once
recvErr error recvErr error
recvClose chan struct{}
} }
func newStream(id streamID, send sender) *stream { func newStream(id streamID, send sender) *stream {
return &stream{ return &stream{
id: id, id: id,
sender: send, sender: send,
recv: make(chan *streamMessage, 1), recv: make(chan *streamMessage, 1),
recvClose: make(chan struct{}),
} }
} }
func (s *stream) closeWithError(err error) error { func (s *stream) closeWithError(err error) error {
s.closeOnce.Do(func() { s.closeOnce.Do(func() {
if s.recv != nil { if err != nil {
close(s.recv) s.recvErr = err
if err != nil { } else {
s.recvErr = err s.recvErr = ErrClosed
} else {
s.recvErr = ErrClosed
}
} }
close(s.recvClose)
}) })
return nil return nil
} }
@ -65,10 +64,14 @@ func (s *stream) send(mt messageType, flags uint8, b []byte) error {
} }
func (s *stream) receive(ctx context.Context, msg *streamMessage) error { func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
if s.recvErr != nil { select {
case <-s.recvClose:
return s.recvErr return s.recvErr
default:
} }
select { select {
case <-s.recvClose:
return s.recvErr
case s.recv <- msg: case s.recv <- msg:
return nil return nil
case <-ctx.Done(): case <-ctx.Done():