diff --git a/client.go b/client.go index 0abc702..b73116b 100644 --- a/client.go +++ b/client.go @@ -214,60 +214,66 @@ func (cs *clientStream) RecvMsg(m interface{}) error { if cs.remoteClosed { return io.EOF } + + var msg *streamMessage select { case <-cs.ctx.Done(): return cs.ctx.Err() - case msg, ok := <-cs.s.recv: - if !ok { + case <-cs.s.recvClose: + // If recv has a pending message, process that first + select { + case msg = <-cs.s.recv: + default: return cs.s.recvErr } + case msg = <-cs.s.recv: + } - if msg.header.Type == messageTypeResponse { - resp := &Response{} - err := proto.Unmarshal(msg.payload[:msg.header.Length], resp) - // return the payload buffer for reuse - cs.c.channel.putmbuf(msg.payload) - if err != nil { - return err - } + if msg.header.Type == messageTypeResponse { + resp := &Response{} + err := proto.Unmarshal(msg.payload[:msg.header.Length], resp) + // return the payload buffer for reuse + cs.c.channel.putmbuf(msg.payload) + if err != nil { + return err + } - if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil { - return err - } + if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil { + return err + } - if resp.Status != nil && resp.Status.Code != int32(codes.OK) { - return status.ErrorProto(resp.Status) - } + if resp.Status != nil && resp.Status.Code != int32(codes.OK) { + return status.ErrorProto(resp.Status) + } + cs.c.deleteStream(cs.s) + cs.remoteClosed = true + + return nil + } else if msg.header.Type == messageTypeData { + if !cs.desc.StreamingServer { + cs.c.deleteStream(cs.s) + cs.remoteClosed = true + return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol) + } + if msg.header.Flags&flagRemoteClosed == flagRemoteClosed { cs.c.deleteStream(cs.s) cs.remoteClosed = true - return nil - } else if msg.header.Type == messageTypeData { - if !cs.desc.StreamingServer { - cs.c.deleteStream(cs.s) - cs.remoteClosed = true - return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol) + if msg.header.Flags&flagNoData == flagNoData { + return io.EOF } - if msg.header.Flags&flagRemoteClosed == flagRemoteClosed { - cs.c.deleteStream(cs.s) - cs.remoteClosed = true - - if msg.header.Flags&flagNoData == flagNoData { - return io.EOF - } - } - - err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m) - cs.c.channel.putmbuf(msg.payload) - if err != nil { - return err - } - return nil } - return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol) + err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m) + cs.c.channel.putmbuf(msg.payload) + if err != nil { + return err + } + return nil } + + return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol) } // Close closes the ttrpc connection and underlying connection @@ -482,11 +488,9 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err return ctx.Err() case <-c.ctx.Done(): return ErrClosed - case msg, ok := <-s.recv: - if !ok { - return s.recvErr - } - + case <-s.recvClose: + return s.recvErr + case msg := <-s.recv: if msg.header.Type == messageTypeResponse { err = proto.Unmarshal(msg.payload[:msg.header.Length], resp) } else { diff --git a/stream.go b/stream.go index 5f264fe..739a4c9 100644 --- a/stream.go +++ b/stream.go @@ -35,27 +35,26 @@ type stream struct { closeOnce sync.Once recvErr error + recvClose chan struct{} } func newStream(id streamID, send sender) *stream { return &stream{ - id: id, - sender: send, - recv: make(chan *streamMessage, 1), + id: id, + sender: send, + recv: make(chan *streamMessage, 1), + recvClose: make(chan struct{}), } } func (s *stream) closeWithError(err error) error { s.closeOnce.Do(func() { - if s.recv != nil { - close(s.recv) - if err != nil { - s.recvErr = err - } else { - s.recvErr = ErrClosed - } - + if err != nil { + s.recvErr = err + } else { + s.recvErr = ErrClosed } + close(s.recvClose) }) return nil } @@ -65,10 +64,14 @@ func (s *stream) send(mt messageType, flags uint8, b []byte) error { } func (s *stream) receive(ctx context.Context, msg *streamMessage) error { - if s.recvErr != nil { + select { + case <-s.recvClose: return s.recvErr + default: } select { + case <-s.recvClose: + return s.recvErr case s.recv <- msg: return nil case <-ctx.Done():