Refactor close handling for ttrpc clients
Signed-off-by: Michael Crosby <crosbymichael@gmail.com>
This commit is contained in:
		@@ -18,7 +18,6 @@ package ttrpc
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
	"context"
 | 
					 | 
				
			||||||
	"encoding/binary"
 | 
						"encoding/binary"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
@@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
 | 
				
			|||||||
// returned will be valid and caller should send that along to
 | 
					// returned will be valid and caller should send that along to
 | 
				
			||||||
// the correct consumer. The bytes on the underlying channel
 | 
					// the correct consumer. The bytes on the underlying channel
 | 
				
			||||||
// will be discarded.
 | 
					// will be discarded.
 | 
				
			||||||
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
 | 
					func (ch *channel) recv() (messageHeader, []byte, error) {
 | 
				
			||||||
	mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
 | 
						mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return messageHeader{}, nil, err
 | 
							return messageHeader{}, nil, err
 | 
				
			||||||
@@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
 | 
				
			|||||||
	return mh, p, nil
 | 
						return mh, p, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
 | 
					func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
 | 
				
			||||||
	if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
 | 
						if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,7 +18,6 @@ package ttrpc
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"context"
 | 
					 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
@@ -31,7 +30,6 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestReadWriteMessage(t *testing.T) {
 | 
					func TestReadWriteMessage(t *testing.T) {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		ctx      = context.Background()
 | 
					 | 
				
			||||||
		w, r     = net.Pipe()
 | 
							w, r     = net.Pipe()
 | 
				
			||||||
		ch       = newChannel(w)
 | 
							ch       = newChannel(w)
 | 
				
			||||||
		rch      = newChannel(r)
 | 
							rch      = newChannel(r)
 | 
				
			||||||
@@ -46,7 +44,7 @@ func TestReadWriteMessage(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		for i, msg := range messages {
 | 
							for i, msg := range messages {
 | 
				
			||||||
			if err := ch.send(ctx, uint32(i), 1, msg); err != nil {
 | 
								if err := ch.send(uint32(i), 1, msg); err != nil {
 | 
				
			||||||
				errs <- err
 | 
									errs <- err
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -56,7 +54,7 @@ func TestReadWriteMessage(t *testing.T) {
 | 
				
			|||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		_, p, err := rch.recv(ctx)
 | 
							_, p, err := rch.recv()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			if errors.Cause(err) != io.EOF {
 | 
								if errors.Cause(err) != io.EOF {
 | 
				
			||||||
				t.Fatal(err)
 | 
									t.Fatal(err)
 | 
				
			||||||
@@ -91,7 +89,6 @@ func TestReadWriteMessage(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestMessageOversize(t *testing.T) {
 | 
					func TestMessageOversize(t *testing.T) {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		ctx      = context.Background()
 | 
					 | 
				
			||||||
		w, r     = net.Pipe()
 | 
							w, r     = net.Pipe()
 | 
				
			||||||
		wch, rch = newChannel(w), newChannel(r)
 | 
							wch, rch = newChannel(w), newChannel(r)
 | 
				
			||||||
		msg      = bytes.Repeat([]byte("a message of massive length"), 512<<10)
 | 
							msg      = bytes.Repeat([]byte("a message of massive length"), 512<<10)
 | 
				
			||||||
@@ -99,12 +96,12 @@ func TestMessageOversize(t *testing.T) {
 | 
				
			|||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		if err := wch.send(ctx, 1, 1, msg); err != nil {
 | 
							if err := wch.send(1, 1, msg); err != nil {
 | 
				
			||||||
			errs <- err
 | 
								errs <- err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, _, err := rch.recv(ctx)
 | 
						_, _, err := rch.recv()
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		t.Fatalf("error expected reading with small buffer")
 | 
							t.Fatalf("error expected reading with small buffer")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										163
									
								
								client.go
									
									
									
									
									
								
							
							
						
						
									
										163
									
								
								client.go
									
									
									
									
									
								
							@@ -43,10 +43,13 @@ type Client struct {
 | 
				
			|||||||
	channel *channel
 | 
						channel *channel
 | 
				
			||||||
	calls   chan *callRequest
 | 
						calls   chan *callRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	closed      chan struct{}
 | 
						ctx    context.Context
 | 
				
			||||||
	closeOnce   sync.Once
 | 
						closed func()
 | 
				
			||||||
	closeFunc   func()
 | 
					
 | 
				
			||||||
	done        chan struct{}
 | 
						closeOnce     sync.Once
 | 
				
			||||||
 | 
						userCloseFunc func()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						errOnce     sync.Once
 | 
				
			||||||
	err         error
 | 
						err         error
 | 
				
			||||||
	interceptor UnaryClientInterceptor
 | 
						interceptor UnaryClientInterceptor
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -57,7 +60,7 @@ type ClientOpts func(c *Client)
 | 
				
			|||||||
// WithOnClose sets the close func whenever the client's Close() method is called
 | 
					// WithOnClose sets the close func whenever the client's Close() method is called
 | 
				
			||||||
func WithOnClose(onClose func()) ClientOpts {
 | 
					func WithOnClose(onClose func()) ClientOpts {
 | 
				
			||||||
	return func(c *Client) {
 | 
						return func(c *Client) {
 | 
				
			||||||
		c.closeFunc = onClose
 | 
							c.userCloseFunc = onClose
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -69,15 +72,16 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
 | 
					func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
 | 
				
			||||||
 | 
						ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
	c := &Client{
 | 
						c := &Client{
 | 
				
			||||||
		codec:       codec{},
 | 
							codec:         codec{},
 | 
				
			||||||
		conn:        conn,
 | 
							conn:          conn,
 | 
				
			||||||
		channel:     newChannel(conn),
 | 
							channel:       newChannel(conn),
 | 
				
			||||||
		calls:       make(chan *callRequest),
 | 
							calls:         make(chan *callRequest),
 | 
				
			||||||
		closed:      make(chan struct{}),
 | 
							closed:        cancel,
 | 
				
			||||||
		done:        make(chan struct{}),
 | 
							ctx:           ctx,
 | 
				
			||||||
		closeFunc:   func() {},
 | 
							userCloseFunc: func() {},
 | 
				
			||||||
		interceptor: defaultClientInterceptor,
 | 
							interceptor:   defaultClientInterceptor,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, o := range opts {
 | 
						for _, o := range opts {
 | 
				
			||||||
@@ -150,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
 | 
				
			|||||||
	case <-ctx.Done():
 | 
						case <-ctx.Done():
 | 
				
			||||||
		return ctx.Err()
 | 
							return ctx.Err()
 | 
				
			||||||
	case c.calls <- call:
 | 
						case c.calls <- call:
 | 
				
			||||||
	case <-c.done:
 | 
						case <-c.ctx.Done():
 | 
				
			||||||
		return c.err
 | 
							return c.error()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	select {
 | 
						select {
 | 
				
			||||||
@@ -159,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
 | 
				
			|||||||
		return ctx.Err()
 | 
							return ctx.Err()
 | 
				
			||||||
	case err := <-errs:
 | 
						case err := <-errs:
 | 
				
			||||||
		return filterCloseErr(err)
 | 
							return filterCloseErr(err)
 | 
				
			||||||
	case <-c.done:
 | 
						case <-c.ctx.Done():
 | 
				
			||||||
		return c.err
 | 
							return c.error()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Client) Close() error {
 | 
					func (c *Client) Close() error {
 | 
				
			||||||
	c.closeOnce.Do(func() {
 | 
						c.closeOnce.Do(func() {
 | 
				
			||||||
		close(c.closed)
 | 
							c.closed()
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -178,51 +181,82 @@ type message struct {
 | 
				
			|||||||
	err error
 | 
						err error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Client) run() {
 | 
					type receiver struct {
 | 
				
			||||||
	var (
 | 
						wg       *sync.WaitGroup
 | 
				
			||||||
		streamID    uint32 = 1
 | 
						messages chan *message
 | 
				
			||||||
		waiters            = make(map[uint32]*callRequest)
 | 
						err      error
 | 
				
			||||||
		calls              = c.calls
 | 
					}
 | 
				
			||||||
		incoming           = make(chan *message)
 | 
					 | 
				
			||||||
		shutdown           = make(chan struct{})
 | 
					 | 
				
			||||||
		shutdownErr error
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
					func (r *receiver) run(ctx context.Context, c *channel) {
 | 
				
			||||||
		defer close(shutdown)
 | 
						defer r.wg.Done()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// start one more goroutine to recv messages without blocking.
 | 
						for {
 | 
				
			||||||
		for {
 | 
							select {
 | 
				
			||||||
			mh, p, err := c.channel.recv(context.TODO())
 | 
							case <-ctx.Done():
 | 
				
			||||||
 | 
								r.err = ctx.Err()
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								mh, p, err := c.recv()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				_, ok := status.FromError(err)
 | 
									_, ok := status.FromError(err)
 | 
				
			||||||
				if !ok {
 | 
									if !ok {
 | 
				
			||||||
					// treat all errors that are not an rpc status as terminal.
 | 
										// treat all errors that are not an rpc status as terminal.
 | 
				
			||||||
					// all others poison the connection.
 | 
										// all others poison the connection.
 | 
				
			||||||
					shutdownErr = err
 | 
										r.err = err
 | 
				
			||||||
					return
 | 
										return
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			select {
 | 
								select {
 | 
				
			||||||
			case incoming <- &message{
 | 
								case r.messages <- &message{
 | 
				
			||||||
				messageHeader: mh,
 | 
									messageHeader: mh,
 | 
				
			||||||
				p:             p[:mh.Length],
 | 
									p:             p[:mh.Length],
 | 
				
			||||||
				err:           err,
 | 
									err:           err,
 | 
				
			||||||
			}:
 | 
								}:
 | 
				
			||||||
			case <-c.done:
 | 
								case <-ctx.Done():
 | 
				
			||||||
 | 
									r.err = ctx.Err()
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer c.conn.Close()
 | 
					func (c *Client) run() {
 | 
				
			||||||
	defer close(c.done)
 | 
						var (
 | 
				
			||||||
	defer c.closeFunc()
 | 
							streamID      uint32 = 1
 | 
				
			||||||
 | 
							waiters              = make(map[uint32]*callRequest)
 | 
				
			||||||
 | 
							calls                = c.calls
 | 
				
			||||||
 | 
							incoming             = make(chan *message)
 | 
				
			||||||
 | 
							receiversDone        = make(chan struct{})
 | 
				
			||||||
 | 
							wg            sync.WaitGroup
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// broadcast the shutdown error to the remaining waiters.
 | 
				
			||||||
 | 
						abortWaiters := func(wErr error) {
 | 
				
			||||||
 | 
							for _, waiter := range waiters {
 | 
				
			||||||
 | 
								waiter.errs <- wErr
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						recv := &receiver{
 | 
				
			||||||
 | 
							wg:       &wg,
 | 
				
			||||||
 | 
							messages: incoming,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						wg.Add(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							wg.Wait()
 | 
				
			||||||
 | 
							close(receiversDone)
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						go recv.run(c.ctx, c.channel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							c.conn.Close()
 | 
				
			||||||
 | 
							c.userCloseFunc()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case call := <-calls:
 | 
							case call := <-calls:
 | 
				
			||||||
			if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
 | 
								if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
 | 
				
			||||||
				call.errs <- err
 | 
									call.errs <- err
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -238,41 +272,42 @@ func (c *Client) run() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			call.errs <- c.recv(call.resp, msg)
 | 
								call.errs <- c.recv(call.resp, msg)
 | 
				
			||||||
			delete(waiters, msg.StreamID)
 | 
								delete(waiters, msg.StreamID)
 | 
				
			||||||
		case <-shutdown:
 | 
							case <-receiversDone:
 | 
				
			||||||
			if shutdownErr != nil {
 | 
								// all the receivers have exited
 | 
				
			||||||
				shutdownErr = filterCloseErr(shutdownErr)
 | 
								if recv.err != nil {
 | 
				
			||||||
			} else {
 | 
									c.setError(recv.err)
 | 
				
			||||||
				shutdownErr = ErrClosed
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			c.err = shutdownErr
 | 
					 | 
				
			||||||
			for _, waiter := range waiters {
 | 
					 | 
				
			||||||
				waiter.errs <- shutdownErr
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								// don't return out, let the close of the context trigger the abort of waiters
 | 
				
			||||||
			c.Close()
 | 
								c.Close()
 | 
				
			||||||
			return
 | 
							case <-c.ctx.Done():
 | 
				
			||||||
		case <-c.closed:
 | 
								abortWaiters(c.error())
 | 
				
			||||||
			if c.err == nil {
 | 
					 | 
				
			||||||
				c.err = ErrClosed
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			// broadcast the shutdown error to the remaining waiters.
 | 
					 | 
				
			||||||
			for _, waiter := range waiters {
 | 
					 | 
				
			||||||
				waiter.errs <- c.err
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
 | 
					func (c *Client) error() error {
 | 
				
			||||||
 | 
						c.errOnce.Do(func() {
 | 
				
			||||||
 | 
							if c.err == nil {
 | 
				
			||||||
 | 
								c.err = ErrClosed
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return c.err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *Client) setError(err error) {
 | 
				
			||||||
 | 
						c.errOnce.Do(func() {
 | 
				
			||||||
 | 
							c.err = err
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
 | 
				
			||||||
	p, err := c.codec.Marshal(msg)
 | 
						p, err := c.codec.Marshal(msg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return c.channel.send(ctx, streamID, mtype, p)
 | 
						return c.channel.send(streamID, mtype, p)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Client) recv(resp *Response, msg *message) error {
 | 
					func (c *Client) recv(resp *Response, msg *message) error {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -344,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
 | 
				
			|||||||
			default: // proceed
 | 
								default: // proceed
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			mh, p, err := ch.recv(ctx)
 | 
								mh, p, err := ch.recv()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				status, ok := status.FromError(err)
 | 
									status, ok := status.FromError(err)
 | 
				
			||||||
				if !ok {
 | 
									if !ok {
 | 
				
			||||||
@@ -441,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
 | 
				
			|||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
 | 
								if err := ch.send(response.id, messageTypeResponse, p); err != nil {
 | 
				
			||||||
				logrus.WithError(err).Error("failed sending message on channel")
 | 
									logrus.WithError(err).Error("failed sending message on channel")
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user