Add recvClose channel to stream
Prevent panic from closing recv channel, which may be written to after close. Use a separate channel to signal recv has closed and check that channel on read and write. Signed-off-by: Derek McGowan <derek@mcg.dev>
This commit is contained in:
parent
98b5f64998
commit
471297eed9
18
client.go
18
client.go
@ -214,13 +214,20 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
|
||||
if cs.remoteClosed {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
var msg *streamMessage
|
||||
select {
|
||||
case <-cs.ctx.Done():
|
||||
return cs.ctx.Err()
|
||||
case msg, ok := <-cs.s.recv:
|
||||
if !ok {
|
||||
case <-cs.s.recvClose:
|
||||
// If recv has a pending message, process that first
|
||||
select {
|
||||
case msg = <-cs.s.recv:
|
||||
default:
|
||||
return cs.s.recvErr
|
||||
}
|
||||
case msg = <-cs.s.recv:
|
||||
}
|
||||
|
||||
if msg.header.Type == messageTypeResponse {
|
||||
resp := &Response{}
|
||||
@ -267,7 +274,6 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
|
||||
}
|
||||
|
||||
return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the ttrpc connection and underlying connection
|
||||
@ -482,11 +488,9 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
|
||||
return ctx.Err()
|
||||
case <-c.ctx.Done():
|
||||
return ErrClosed
|
||||
case msg, ok := <-s.recv:
|
||||
if !ok {
|
||||
case <-s.recvClose:
|
||||
return s.recvErr
|
||||
}
|
||||
|
||||
case msg := <-s.recv:
|
||||
if msg.header.Type == messageTypeResponse {
|
||||
err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
|
||||
} else {
|
||||
|
13
stream.go
13
stream.go
@ -35,6 +35,7 @@ type stream struct {
|
||||
|
||||
closeOnce sync.Once
|
||||
recvErr error
|
||||
recvClose chan struct{}
|
||||
}
|
||||
|
||||
func newStream(id streamID, send sender) *stream {
|
||||
@ -42,20 +43,18 @@ func newStream(id streamID, send sender) *stream {
|
||||
id: id,
|
||||
sender: send,
|
||||
recv: make(chan *streamMessage, 1),
|
||||
recvClose: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stream) closeWithError(err error) error {
|
||||
s.closeOnce.Do(func() {
|
||||
if s.recv != nil {
|
||||
close(s.recv)
|
||||
if err != nil {
|
||||
s.recvErr = err
|
||||
} else {
|
||||
s.recvErr = ErrClosed
|
||||
}
|
||||
|
||||
}
|
||||
close(s.recvClose)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@ -65,10 +64,14 @@ func (s *stream) send(mt messageType, flags uint8, b []byte) error {
|
||||
}
|
||||
|
||||
func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
|
||||
if s.recvErr != nil {
|
||||
select {
|
||||
case <-s.recvClose:
|
||||
return s.recvErr
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-s.recvClose:
|
||||
return s.recvErr
|
||||
case s.recv <- msg:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
|
Loading…
Reference in New Issue
Block a user