diff --git a/channel.go b/channel.go index 22f5496..aa8c954 100644 --- a/channel.go +++ b/channel.go @@ -18,7 +18,6 @@ package ttrpc import ( "bufio" - "context" "encoding/binary" "io" "net" @@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel { // returned will be valid and caller should send that along to // the correct consumer. The bytes on the underlying channel // will be discarded. -func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { +func (ch *channel) recv() (messageHeader, []byte, error) { mh, err := readMessageHeader(ch.hrbuf[:], ch.br) if err != nil { return messageHeader{}, nil, err @@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { return mh, p, nil } -func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { +func (ch *channel) send(streamID uint32, t messageType, p []byte) error { if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { return err } diff --git a/channel_test.go b/channel_test.go index 62030c2..2242336 100644 --- a/channel_test.go +++ b/channel_test.go @@ -18,7 +18,6 @@ package ttrpc import ( "bytes" - "context" "io" "net" "reflect" @@ -31,7 +30,6 @@ import ( func TestReadWriteMessage(t *testing.T) { var ( - ctx = context.Background() w, r = net.Pipe() ch = newChannel(w) rch = newChannel(r) @@ -46,7 +44,7 @@ func TestReadWriteMessage(t *testing.T) { go func() { for i, msg := range messages { - if err := ch.send(ctx, uint32(i), 1, msg); err != nil { + if err := ch.send(uint32(i), 1, msg); err != nil { errs <- err return } @@ -56,7 +54,7 @@ func TestReadWriteMessage(t *testing.T) { }() for { - _, p, err := rch.recv(ctx) + _, p, err := rch.recv() if err != nil { if errors.Cause(err) != io.EOF { t.Fatal(err) @@ -91,7 +89,6 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( - ctx = context.Background() w, r = net.Pipe() wch, rch = newChannel(w), newChannel(r) msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) @@ -99,12 +96,12 @@ func TestMessageOversize(t *testing.T) { ) go func() { - if err := wch.send(ctx, 1, 1, msg); err != nil { + if err := wch.send(1, 1, msg); err != nil { errs <- err } }() - _, _, err := rch.recv(ctx) + _, _, err := rch.recv() if err == nil { t.Fatalf("error expected reading with small buffer") } diff --git a/client.go b/client.go index 804024e..d72fec5 100644 --- a/client.go +++ b/client.go @@ -43,10 +43,13 @@ type Client struct { channel *channel calls chan *callRequest - closed chan struct{} - closeOnce sync.Once - closeFunc func() - done chan struct{} + ctx context.Context + closed func() + + closeOnce sync.Once + userCloseFunc func() + + errOnce sync.Once err error interceptor UnaryClientInterceptor } @@ -57,7 +60,7 @@ type ClientOpts func(c *Client) // WithOnClose sets the close func whenever the client's Close() method is called func WithOnClose(onClose func()) ClientOpts { return func(c *Client) { - c.closeFunc = onClose + c.userCloseFunc = onClose } } @@ -69,15 +72,16 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts { } func NewClient(conn net.Conn, opts ...ClientOpts) *Client { + ctx, cancel := context.WithCancel(context.Background()) c := &Client{ - codec: codec{}, - conn: conn, - channel: newChannel(conn), - calls: make(chan *callRequest), - closed: make(chan struct{}), - done: make(chan struct{}), - closeFunc: func() {}, - interceptor: defaultClientInterceptor, + codec: codec{}, + conn: conn, + channel: newChannel(conn), + calls: make(chan *callRequest), + closed: cancel, + ctx: ctx, + userCloseFunc: func() {}, + interceptor: defaultClientInterceptor, } for _, o := range opts { @@ -150,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err case <-ctx.Done(): return ctx.Err() case c.calls <- call: - case <-c.done: - return c.err + case <-c.ctx.Done(): + return c.error() } select { @@ -159,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err return ctx.Err() case err := <-errs: return filterCloseErr(err) - case <-c.done: - return c.err + case <-c.ctx.Done(): + return c.error() } } func (c *Client) Close() error { c.closeOnce.Do(func() { - close(c.closed) + c.closed() }) - return nil } @@ -178,51 +181,82 @@ type message struct { err error } -func (c *Client) run() { - var ( - streamID uint32 = 1 - waiters = make(map[uint32]*callRequest) - calls = c.calls - incoming = make(chan *message) - shutdown = make(chan struct{}) - shutdownErr error - ) +type receiver struct { + wg *sync.WaitGroup + messages chan *message + err error +} - go func() { - defer close(shutdown) +func (r *receiver) run(ctx context.Context, c *channel) { + defer r.wg.Done() - // start one more goroutine to recv messages without blocking. - for { - mh, p, err := c.channel.recv(context.TODO()) + for { + select { + case <-ctx.Done(): + r.err = ctx.Err() + return + default: + mh, p, err := c.recv() if err != nil { _, ok := status.FromError(err) if !ok { // treat all errors that are not an rpc status as terminal. // all others poison the connection. - shutdownErr = err + r.err = err return } } select { - case incoming <- &message{ + case r.messages <- &message{ messageHeader: mh, p: p[:mh.Length], err: err, }: - case <-c.done: + case <-ctx.Done(): + r.err = ctx.Err() return } } - }() + } +} - defer c.conn.Close() - defer close(c.done) - defer c.closeFunc() +func (c *Client) run() { + var ( + streamID uint32 = 1 + waiters = make(map[uint32]*callRequest) + calls = c.calls + incoming = make(chan *message) + receiversDone = make(chan struct{}) + wg sync.WaitGroup + ) + + // broadcast the shutdown error to the remaining waiters. + abortWaiters := func(wErr error) { + for _, waiter := range waiters { + waiter.errs <- wErr + } + } + recv := &receiver{ + wg: &wg, + messages: incoming, + } + wg.Add(1) + + go func() { + wg.Wait() + close(receiversDone) + }() + go recv.run(c.ctx, c.channel) + + defer func() { + c.conn.Close() + c.userCloseFunc() + }() for { select { case call := <-calls: - if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil { + if err := c.send(streamID, messageTypeRequest, call.req); err != nil { call.errs <- err continue } @@ -238,41 +272,42 @@ func (c *Client) run() { call.errs <- c.recv(call.resp, msg) delete(waiters, msg.StreamID) - case <-shutdown: - if shutdownErr != nil { - shutdownErr = filterCloseErr(shutdownErr) - } else { - shutdownErr = ErrClosed - } - - shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down") - - c.err = shutdownErr - for _, waiter := range waiters { - waiter.errs <- shutdownErr + case <-receiversDone: + // all the receivers have exited + if recv.err != nil { + c.setError(recv.err) } + // don't return out, let the close of the context trigger the abort of waiters c.Close() - return - case <-c.closed: - if c.err == nil { - c.err = ErrClosed - } - // broadcast the shutdown error to the remaining waiters. - for _, waiter := range waiters { - waiter.errs <- c.err - } + case <-c.ctx.Done(): + abortWaiters(c.error()) return } } } -func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error { +func (c *Client) error() error { + c.errOnce.Do(func() { + if c.err == nil { + c.err = ErrClosed + } + }) + return c.err +} + +func (c *Client) setError(err error) { + c.errOnce.Do(func() { + c.err = err + }) +} + +func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error { p, err := c.codec.Marshal(msg) if err != nil { return err } - return c.channel.send(ctx, streamID, mtype, p) + return c.channel.send(streamID, mtype, p) } func (c *Client) recv(resp *Response, msg *message) error { diff --git a/server.go b/server.go index 5c33559..585c9f0 100644 --- a/server.go +++ b/server.go @@ -344,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) { default: // proceed } - mh, p, err := ch.recv(ctx) + mh, p, err := ch.recv() if err != nil { status, ok := status.FromError(err) if !ok { @@ -441,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) { return } - if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { + if err := ch.send(response.id, messageTypeResponse, p); err != nil { logrus.WithError(err).Error("failed sending message on channel") return }