Add recvClose channel to stream

Prevent panic from closing recv channel, which may be written to after
close. Use a separate channel to signal recv has closed and check that
channel on read and write.

Signed-off-by: Derek McGowan <derek@mcg.dev>
This commit is contained in:
Derek McGowan 2023-04-28 09:19:09 -07:00
parent 98b5f64998
commit 471297eed9
2 changed files with 62 additions and 55 deletions

View File

@ -214,13 +214,20 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
if cs.remoteClosed {
return io.EOF
}
var msg *streamMessage
select {
case <-cs.ctx.Done():
return cs.ctx.Err()
case msg, ok := <-cs.s.recv:
if !ok {
case <-cs.s.recvClose:
// If recv has a pending message, process that first
select {
case msg = <-cs.s.recv:
default:
return cs.s.recvErr
}
case msg = <-cs.s.recv:
}
if msg.header.Type == messageTypeResponse {
resp := &Response{}
@ -268,7 +275,6 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
}
}
// Close closes the ttrpc connection and underlying connection
func (c *Client) Close() error {
@ -482,11 +488,9 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
return ctx.Err()
case <-c.ctx.Done():
return ErrClosed
case msg, ok := <-s.recv:
if !ok {
case <-s.recvClose:
return s.recvErr
}
case msg := <-s.recv:
if msg.header.Type == messageTypeResponse {
err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
} else {

View File

@ -35,6 +35,7 @@ type stream struct {
closeOnce sync.Once
recvErr error
recvClose chan struct{}
}
func newStream(id streamID, send sender) *stream {
@ -42,20 +43,18 @@ func newStream(id streamID, send sender) *stream {
id: id,
sender: send,
recv: make(chan *streamMessage, 1),
recvClose: make(chan struct{}),
}
}
func (s *stream) closeWithError(err error) error {
s.closeOnce.Do(func() {
if s.recv != nil {
close(s.recv)
if err != nil {
s.recvErr = err
} else {
s.recvErr = ErrClosed
}
}
close(s.recvClose)
})
return nil
}
@ -65,10 +64,14 @@ func (s *stream) send(mt messageType, flags uint8, b []byte) error {
}
func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
if s.recvErr != nil {
select {
case <-s.recvClose:
return s.recvErr
default:
}
select {
case <-s.recvClose:
return s.recvErr
case s.recv <- msg:
return nil
case <-ctx.Done():