Merge pull request #94 from kevpar/deadlock-new

client: Handle sending/receiving in separate goroutines
This commit is contained in:
Derek McGowan 2021-10-29 17:48:08 -07:00 committed by GitHub
commit d2d6bb6f89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

197
client.go
View File

@ -194,72 +194,131 @@ type message struct {
err error
}
type receiver struct {
wg *sync.WaitGroup
messages chan *message
err error
// callMap provides access to a map of active calls, guarded by a mutex.
type callMap struct {
m sync.Mutex
activeCalls map[uint32]*callRequest
closeErr error
}
func (r *receiver) run(ctx context.Context, c *channel) {
defer r.wg.Done()
for {
select {
case <-ctx.Done():
r.err = ctx.Err()
return
default:
mh, p, err := c.recv()
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
r.err = filterCloseErr(err)
return
}
}
select {
case r.messages <- &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}:
case <-ctx.Done():
r.err = ctx.Err()
return
}
}
// newCallMap returns a new callMap with an empty set of active calls.
func newCallMap() *callMap {
return &callMap{
activeCalls: make(map[uint32]*callRequest),
}
}
// set adds a call entry to the map with the given streamID key.
func (cm *callMap) set(streamID uint32, cr *callRequest) error {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return cm.closeErr
}
cm.activeCalls[streamID] = cr
return nil
}
// get looks up the call entry for the given streamID key, then removes it
// from the map and returns it.
func (cm *callMap) get(streamID uint32) (cr *callRequest, ok bool, err error) {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return nil, false, cm.closeErr
}
cr, ok = cm.activeCalls[streamID]
if ok {
delete(cm.activeCalls, streamID)
}
return
}
// abort sends the given error to each active call, and clears the map.
// Once abort has been called, any subsequent calls to the callMap will return the error passed to abort.
func (cm *callMap) abort(err error) error {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return cm.closeErr
}
for streamID, call := range cm.activeCalls {
call.errs <- err
delete(cm.activeCalls, streamID)
}
cm.closeErr = err
return nil
}
func (c *Client) run() {
var (
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
waiters = newCallMap()
receiverDone = make(chan struct{})
)
// broadcast the shutdown error to the remaining waiters.
abortWaiters := func(wErr error) {
for _, waiter := range waiters {
waiter.errs <- wErr
}
}
recv := &receiver{
wg: &wg,
messages: incoming,
}
wg.Add(1)
// Sender goroutine
// Receives calls from dispatch, adds them to the set of active calls, and sends them
// to the server.
go func() {
wg.Wait()
close(receiversDone)
var streamID uint32 = 1
for {
select {
case <-c.ctx.Done():
return
case call := <-c.calls:
id := streamID
streamID += 2 // enforce odd client initiated request ids
if err := waiters.set(id, call); err != nil {
call.errs <- err // errs is buffered so should not block.
continue
}
if err := c.send(id, messageTypeRequest, call.req); err != nil {
call.errs <- err // errs is buffered so should not block.
waiters.get(id) // remove from waiters set
}
}
}
}()
// Receiver goroutine
// Receives responses from the server, looks up the call info in the set of active calls,
// and notifies the caller of the response.
go func() {
defer close(receiverDone)
for {
select {
case <-c.ctx.Done():
c.setError(c.ctx.Err())
return
default:
mh, p, err := c.channel.recv()
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
c.setError(filterCloseErr(err))
return
}
}
msg := &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}
call, ok, err := waiters.get(mh.StreamID)
if err != nil {
logrus.Errorf("ttrpc: failed to look up active call: %s", err)
continue
}
if !ok {
logrus.Errorf("ttrpc: received message for unknown channel %v", mh.StreamID)
continue
}
call.errs <- c.recv(call.resp, msg)
}
}
}()
go recv.run(c.ctx, c.channel)
defer func() {
c.conn.Close()
@ -269,32 +328,14 @@ func (c *Client) run() {
for {
select {
case call := <-calls:
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err
continue
}
waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids
case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
continue
}
call.errs <- c.recv(call.resp, msg)
delete(waiters, msg.StreamID)
case <-receiversDone:
// all the receivers have exited
if recv.err != nil {
c.setError(recv.err)
}
case <-receiverDone:
// The receiver has exited.
// don't return out, let the close of the context trigger the abort of waiters
c.Close()
case <-c.ctx.Done():
abortWaiters(c.error())
// Abort all active calls. This will also prevent any new calls from being added
// to waiters.
waiters.abort(c.error())
return
}
}