Merge pull request #94 from kevpar/deadlock-new
client: Handle sending/receiving in separate goroutines
This commit is contained in:
commit
d2d6bb6f89
177
client.go
177
client.go
@ -194,72 +194,131 @@ type message struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
type receiver struct {
|
// callMap provides access to a map of active calls, guarded by a mutex.
|
||||||
wg *sync.WaitGroup
|
type callMap struct {
|
||||||
messages chan *message
|
m sync.Mutex
|
||||||
err error
|
activeCalls map[uint32]*callRequest
|
||||||
|
closeErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *receiver) run(ctx context.Context, c *channel) {
|
// newCallMap returns a new callMap with an empty set of active calls.
|
||||||
defer r.wg.Done()
|
func newCallMap() *callMap {
|
||||||
|
return &callMap{
|
||||||
|
activeCalls: make(map[uint32]*callRequest),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set adds a call entry to the map with the given streamID key.
|
||||||
|
func (cm *callMap) set(streamID uint32, cr *callRequest) error {
|
||||||
|
cm.m.Lock()
|
||||||
|
defer cm.m.Unlock()
|
||||||
|
if cm.closeErr != nil {
|
||||||
|
return cm.closeErr
|
||||||
|
}
|
||||||
|
cm.activeCalls[streamID] = cr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get looks up the call entry for the given streamID key, then removes it
|
||||||
|
// from the map and returns it.
|
||||||
|
func (cm *callMap) get(streamID uint32) (cr *callRequest, ok bool, err error) {
|
||||||
|
cm.m.Lock()
|
||||||
|
defer cm.m.Unlock()
|
||||||
|
if cm.closeErr != nil {
|
||||||
|
return nil, false, cm.closeErr
|
||||||
|
}
|
||||||
|
cr, ok = cm.activeCalls[streamID]
|
||||||
|
if ok {
|
||||||
|
delete(cm.activeCalls, streamID)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// abort sends the given error to each active call, and clears the map.
|
||||||
|
// Once abort has been called, any subsequent calls to the callMap will return the error passed to abort.
|
||||||
|
func (cm *callMap) abort(err error) error {
|
||||||
|
cm.m.Lock()
|
||||||
|
defer cm.m.Unlock()
|
||||||
|
if cm.closeErr != nil {
|
||||||
|
return cm.closeErr
|
||||||
|
}
|
||||||
|
for streamID, call := range cm.activeCalls {
|
||||||
|
call.errs <- err
|
||||||
|
delete(cm.activeCalls, streamID)
|
||||||
|
}
|
||||||
|
cm.closeErr = err
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) run() {
|
||||||
|
var (
|
||||||
|
waiters = newCallMap()
|
||||||
|
receiverDone = make(chan struct{})
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sender goroutine
|
||||||
|
// Receives calls from dispatch, adds them to the set of active calls, and sends them
|
||||||
|
// to the server.
|
||||||
|
go func() {
|
||||||
|
var streamID uint32 = 1
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-c.ctx.Done():
|
||||||
r.err = ctx.Err()
|
return
|
||||||
|
case call := <-c.calls:
|
||||||
|
id := streamID
|
||||||
|
streamID += 2 // enforce odd client initiated request ids
|
||||||
|
if err := waiters.set(id, call); err != nil {
|
||||||
|
call.errs <- err // errs is buffered so should not block.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := c.send(id, messageTypeRequest, call.req); err != nil {
|
||||||
|
call.errs <- err // errs is buffered so should not block.
|
||||||
|
waiters.get(id) // remove from waiters set
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Receiver goroutine
|
||||||
|
// Receives responses from the server, looks up the call info in the set of active calls,
|
||||||
|
// and notifies the caller of the response.
|
||||||
|
go func() {
|
||||||
|
defer close(receiverDone)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
c.setError(c.ctx.Err())
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
mh, p, err := c.recv()
|
mh, p, err := c.channel.recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, ok := status.FromError(err)
|
_, ok := status.FromError(err)
|
||||||
if !ok {
|
if !ok {
|
||||||
// treat all errors that are not an rpc status as terminal.
|
// treat all errors that are not an rpc status as terminal.
|
||||||
// all others poison the connection.
|
// all others poison the connection.
|
||||||
r.err = filterCloseErr(err)
|
c.setError(filterCloseErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
select {
|
msg := &message{
|
||||||
case r.messages <- &message{
|
|
||||||
messageHeader: mh,
|
messageHeader: mh,
|
||||||
p: p[:mh.Length],
|
p: p[:mh.Length],
|
||||||
err: err,
|
err: err,
|
||||||
}:
|
}
|
||||||
case <-ctx.Done():
|
call, ok, err := waiters.get(mh.StreamID)
|
||||||
r.err = ctx.Err()
|
if err != nil {
|
||||||
return
|
logrus.Errorf("ttrpc: failed to look up active call: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
logrus.Errorf("ttrpc: received message for unknown channel %v", mh.StreamID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
call.errs <- c.recv(call.resp, msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) run() {
|
|
||||||
var (
|
|
||||||
streamID uint32 = 1
|
|
||||||
waiters = make(map[uint32]*callRequest)
|
|
||||||
calls = c.calls
|
|
||||||
incoming = make(chan *message)
|
|
||||||
receiversDone = make(chan struct{})
|
|
||||||
wg sync.WaitGroup
|
|
||||||
)
|
|
||||||
|
|
||||||
// broadcast the shutdown error to the remaining waiters.
|
|
||||||
abortWaiters := func(wErr error) {
|
|
||||||
for _, waiter := range waiters {
|
|
||||||
waiter.errs <- wErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
recv := &receiver{
|
|
||||||
wg: &wg,
|
|
||||||
messages: incoming,
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
close(receiversDone)
|
|
||||||
}()
|
}()
|
||||||
go recv.run(c.ctx, c.channel)
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
@ -269,32 +328,14 @@ func (c *Client) run() {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case call := <-calls:
|
case <-receiverDone:
|
||||||
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
|
// The receiver has exited.
|
||||||
call.errs <- err
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
waiters[streamID] = call
|
|
||||||
streamID += 2 // enforce odd client initiated request ids
|
|
||||||
case msg := <-incoming:
|
|
||||||
call, ok := waiters[msg.StreamID]
|
|
||||||
if !ok {
|
|
||||||
logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
call.errs <- c.recv(call.resp, msg)
|
|
||||||
delete(waiters, msg.StreamID)
|
|
||||||
case <-receiversDone:
|
|
||||||
// all the receivers have exited
|
|
||||||
if recv.err != nil {
|
|
||||||
c.setError(recv.err)
|
|
||||||
}
|
|
||||||
// don't return out, let the close of the context trigger the abort of waiters
|
// don't return out, let the close of the context trigger the abort of waiters
|
||||||
c.Close()
|
c.Close()
|
||||||
case <-c.ctx.Done():
|
case <-c.ctx.Done():
|
||||||
abortWaiters(c.error())
|
// Abort all active calls. This will also prevent any new calls from being added
|
||||||
|
// to waiters.
|
||||||
|
waiters.abort(c.error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user