Add recvClose channel to stream
Prevent panic from closing recv channel, which may be written to after close. Use a separate channel to signal recv has closed and check that channel on read and write. Signed-off-by: Derek McGowan <derek@mcg.dev>
This commit is contained in:
18
client.go
18
client.go
@@ -214,13 +214,20 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
|
|||||||
if cs.remoteClosed {
|
if cs.remoteClosed {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var msg *streamMessage
|
||||||
select {
|
select {
|
||||||
case <-cs.ctx.Done():
|
case <-cs.ctx.Done():
|
||||||
return cs.ctx.Err()
|
return cs.ctx.Err()
|
||||||
case msg, ok := <-cs.s.recv:
|
case <-cs.s.recvClose:
|
||||||
if !ok {
|
// If recv has a pending message, process that first
|
||||||
|
select {
|
||||||
|
case msg = <-cs.s.recv:
|
||||||
|
default:
|
||||||
return cs.s.recvErr
|
return cs.s.recvErr
|
||||||
}
|
}
|
||||||
|
case msg = <-cs.s.recv:
|
||||||
|
}
|
||||||
|
|
||||||
if msg.header.Type == messageTypeResponse {
|
if msg.header.Type == messageTypeResponse {
|
||||||
resp := &Response{}
|
resp := &Response{}
|
||||||
@@ -268,7 +275,6 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
|
|||||||
|
|
||||||
return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
|
return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the ttrpc connection and underlying connection
|
// Close closes the ttrpc connection and underlying connection
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
@@ -482,11 +488,9 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-c.ctx.Done():
|
case <-c.ctx.Done():
|
||||||
return ErrClosed
|
return ErrClosed
|
||||||
case msg, ok := <-s.recv:
|
case <-s.recvClose:
|
||||||
if !ok {
|
|
||||||
return s.recvErr
|
return s.recvErr
|
||||||
}
|
case msg := <-s.recv:
|
||||||
|
|
||||||
if msg.header.Type == messageTypeResponse {
|
if msg.header.Type == messageTypeResponse {
|
||||||
err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
|
err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
13
stream.go
13
stream.go
@@ -35,6 +35,7 @@ type stream struct {
|
|||||||
|
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
recvErr error
|
recvErr error
|
||||||
|
recvClose chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newStream(id streamID, send sender) *stream {
|
func newStream(id streamID, send sender) *stream {
|
||||||
@@ -42,20 +43,18 @@ func newStream(id streamID, send sender) *stream {
|
|||||||
id: id,
|
id: id,
|
||||||
sender: send,
|
sender: send,
|
||||||
recv: make(chan *streamMessage, 1),
|
recv: make(chan *streamMessage, 1),
|
||||||
|
recvClose: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stream) closeWithError(err error) error {
|
func (s *stream) closeWithError(err error) error {
|
||||||
s.closeOnce.Do(func() {
|
s.closeOnce.Do(func() {
|
||||||
if s.recv != nil {
|
|
||||||
close(s.recv)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.recvErr = err
|
s.recvErr = err
|
||||||
} else {
|
} else {
|
||||||
s.recvErr = ErrClosed
|
s.recvErr = ErrClosed
|
||||||
}
|
}
|
||||||
|
close(s.recvClose)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -65,10 +64,14 @@ func (s *stream) send(mt messageType, flags uint8, b []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
|
func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
|
||||||
if s.recvErr != nil {
|
select {
|
||||||
|
case <-s.recvClose:
|
||||||
return s.recvErr
|
return s.recvErr
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
|
case <-s.recvClose:
|
||||||
|
return s.recvErr
|
||||||
case s.recv <- msg:
|
case s.recv <- msg:
|
||||||
return nil
|
return nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|||||||
Reference in New Issue
Block a user