Merge pull request #42 from crosbymichael/client

Refactor close handling for ttrpc clients
This commit is contained in:
Phil Estes 2019-06-13 14:33:16 -04:00 committed by GitHub
commit 1fb3814edf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 90 deletions

View File

@ -18,7 +18,6 @@ package ttrpc
import ( import (
"bufio" "bufio"
"context"
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
// returned will be valid and caller should send that along to // returned will be valid and caller should send that along to
// the correct consumer. The bytes on the underlying channel // the correct consumer. The bytes on the underlying channel
// will be discarded. // will be discarded.
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { func (ch *channel) recv() (messageHeader, []byte, error) {
mh, err := readMessageHeader(ch.hrbuf[:], ch.br) mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
if err != nil { if err != nil {
return messageHeader{}, nil, err return messageHeader{}, nil, err
@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
return mh, p, nil return mh, p, nil
} }
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
return err return err
} }

View File

@ -18,7 +18,6 @@ package ttrpc
import ( import (
"bytes" "bytes"
"context"
"io" "io"
"net" "net"
"reflect" "reflect"
@ -31,7 +30,6 @@ import (
func TestReadWriteMessage(t *testing.T) { func TestReadWriteMessage(t *testing.T) {
var ( var (
ctx = context.Background()
w, r = net.Pipe() w, r = net.Pipe()
ch = newChannel(w) ch = newChannel(w)
rch = newChannel(r) rch = newChannel(r)
@ -46,7 +44,7 @@ func TestReadWriteMessage(t *testing.T) {
go func() { go func() {
for i, msg := range messages { for i, msg := range messages {
if err := ch.send(ctx, uint32(i), 1, msg); err != nil { if err := ch.send(uint32(i), 1, msg); err != nil {
errs <- err errs <- err
return return
} }
@ -56,7 +54,7 @@ func TestReadWriteMessage(t *testing.T) {
}() }()
for { for {
_, p, err := rch.recv(ctx) _, p, err := rch.recv()
if err != nil { if err != nil {
if errors.Cause(err) != io.EOF { if errors.Cause(err) != io.EOF {
t.Fatal(err) t.Fatal(err)
@ -91,7 +89,6 @@ func TestReadWriteMessage(t *testing.T) {
func TestMessageOversize(t *testing.T) { func TestMessageOversize(t *testing.T) {
var ( var (
ctx = context.Background()
w, r = net.Pipe() w, r = net.Pipe()
wch, rch = newChannel(w), newChannel(r) wch, rch = newChannel(w), newChannel(r)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
@ -99,12 +96,12 @@ func TestMessageOversize(t *testing.T) {
) )
go func() { go func() {
if err := wch.send(ctx, 1, 1, msg); err != nil { if err := wch.send(1, 1, msg); err != nil {
errs <- err errs <- err
} }
}() }()
_, _, err := rch.recv(ctx) _, _, err := rch.recv()
if err == nil { if err == nil {
t.Fatalf("error expected reading with small buffer") t.Fatalf("error expected reading with small buffer")
} }

188
client.go
View File

@ -43,10 +43,13 @@ type Client struct {
channel *channel channel *channel
calls chan *callRequest calls chan *callRequest
closed chan struct{} ctx context.Context
closeOnce sync.Once closed func()
closeFunc func()
done chan struct{} closeOnce sync.Once
userCloseFunc func()
errOnce sync.Once
err error err error
interceptor UnaryClientInterceptor interceptor UnaryClientInterceptor
} }
@ -57,7 +60,7 @@ type ClientOpts func(c *Client)
// WithOnClose sets the close func whenever the client's Close() method is called // WithOnClose sets the close func whenever the client's Close() method is called
func WithOnClose(onClose func()) ClientOpts { func WithOnClose(onClose func()) ClientOpts {
return func(c *Client) { return func(c *Client) {
c.closeFunc = onClose c.userCloseFunc = onClose
} }
} }
@ -69,15 +72,16 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
} }
func NewClient(conn net.Conn, opts ...ClientOpts) *Client { func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
ctx, cancel := context.WithCancel(context.Background())
c := &Client{ c := &Client{
codec: codec{}, codec: codec{},
conn: conn, conn: conn,
channel: newChannel(conn), channel: newChannel(conn),
calls: make(chan *callRequest), calls: make(chan *callRequest),
closed: make(chan struct{}), closed: cancel,
done: make(chan struct{}), ctx: ctx,
closeFunc: func() {}, userCloseFunc: func() {},
interceptor: defaultClientInterceptor, interceptor: defaultClientInterceptor,
} }
for _, o := range opts { for _, o := range opts {
@ -150,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case c.calls <- call: case c.calls <- call:
case <-c.done: case <-c.ctx.Done():
return c.err return c.error()
} }
select { select {
@ -159,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
return ctx.Err() return ctx.Err()
case err := <-errs: case err := <-errs:
return filterCloseErr(err) return filterCloseErr(err)
case <-c.done: case <-c.ctx.Done():
return c.err return c.error()
} }
} }
func (c *Client) Close() error { func (c *Client) Close() error {
c.closeOnce.Do(func() { c.closeOnce.Do(func() {
close(c.closed) c.closed()
}) })
return nil return nil
} }
@ -178,51 +181,82 @@ type message struct {
err error err error
} }
func (c *Client) run() { type receiver struct {
var ( wg *sync.WaitGroup
streamID uint32 = 1 messages chan *message
waiters = make(map[uint32]*callRequest) err error
calls = c.calls }
incoming = make(chan *message)
shutdown = make(chan struct{})
shutdownErr error
)
go func() { func (r *receiver) run(ctx context.Context, c *channel) {
defer close(shutdown) defer r.wg.Done()
// start one more goroutine to recv messages without blocking. for {
for { select {
mh, p, err := c.channel.recv(context.TODO()) case <-ctx.Done():
r.err = ctx.Err()
return
default:
mh, p, err := c.recv()
if err != nil { if err != nil {
_, ok := status.FromError(err) _, ok := status.FromError(err)
if !ok { if !ok {
// treat all errors that are not an rpc status as terminal. // treat all errors that are not an rpc status as terminal.
// all others poison the connection. // all others poison the connection.
shutdownErr = err r.err = filterCloseErr(err)
return return
} }
} }
select { select {
case incoming <- &message{ case r.messages <- &message{
messageHeader: mh, messageHeader: mh,
p: p[:mh.Length], p: p[:mh.Length],
err: err, err: err,
}: }:
case <-c.done: case <-ctx.Done():
r.err = ctx.Err()
return return
} }
} }
}() }
}
defer c.conn.Close() func (c *Client) run() {
defer close(c.done) var (
defer c.closeFunc() streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
)
// broadcast the shutdown error to the remaining waiters.
abortWaiters := func(wErr error) {
for _, waiter := range waiters {
waiter.errs <- wErr
}
}
recv := &receiver{
wg: &wg,
messages: incoming,
}
wg.Add(1)
go func() {
wg.Wait()
close(receiversDone)
}()
go recv.run(c.ctx, c.channel)
defer func() {
c.conn.Close()
c.userCloseFunc()
}()
for { for {
select { select {
case call := <-calls: case call := <-calls:
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil { if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err call.errs <- err
continue continue
} }
@ -238,41 +272,42 @@ func (c *Client) run() {
call.errs <- c.recv(call.resp, msg) call.errs <- c.recv(call.resp, msg)
delete(waiters, msg.StreamID) delete(waiters, msg.StreamID)
case <-shutdown: case <-receiversDone:
if shutdownErr != nil { // all the receivers have exited
shutdownErr = filterCloseErr(shutdownErr) if recv.err != nil {
} else { c.setError(recv.err)
shutdownErr = ErrClosed
}
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
c.err = shutdownErr
for _, waiter := range waiters {
waiter.errs <- shutdownErr
} }
// don't return out, let the close of the context trigger the abort of waiters
c.Close() c.Close()
return case <-c.ctx.Done():
case <-c.closed: abortWaiters(c.error())
if c.err == nil {
c.err = ErrClosed
}
// broadcast the shutdown error to the remaining waiters.
for _, waiter := range waiters {
waiter.errs <- c.err
}
return return
} }
} }
} }
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error { func (c *Client) error() error {
c.errOnce.Do(func() {
if c.err == nil {
c.err = ErrClosed
}
})
return c.err
}
func (c *Client) setError(err error) {
c.errOnce.Do(func() {
c.err = err
})
}
func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
p, err := c.codec.Marshal(msg) p, err := c.codec.Marshal(msg)
if err != nil { if err != nil {
return err return err
} }
return c.channel.send(ctx, streamID, mtype, p) return c.channel.send(streamID, mtype, p)
} }
func (c *Client) recv(resp *Response, msg *message) error { func (c *Client) recv(resp *Response, msg *message) error {
@ -293,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
// //
// This purposely ignores errors with a wrapped cause. // This purposely ignores errors with a wrapped cause.
func filterCloseErr(err error) error { func filterCloseErr(err error) error {
if err == nil { switch {
case err == nil:
return nil return nil
} case err == io.EOF:
if err == io.EOF {
return ErrClosed return ErrClosed
} case errors.Cause(err) == io.EOF:
if strings.Contains(err.Error(), "use of closed network connection") {
return ErrClosed return ErrClosed
} case strings.Contains(err.Error(), "use of closed network connection"):
return ErrClosed
// if we have an epipe on a write, we cast to errclosed default:
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" { // if we have an epipe on a write, we cast to errclosed
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE { if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
return ErrClosed if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
return ErrClosed
}
} }
} }

View File

@ -344,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
default: // proceed default: // proceed
} }
mh, p, err := ch.recv(ctx) mh, p, err := ch.recv()
if err != nil { if err != nil {
status, ok := status.FromError(err) status, ok := status.FromError(err)
if !ok { if !ok {
@ -441,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
return return
} }
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { if err := ch.send(response.id, messageTypeResponse, p); err != nil {
logrus.WithError(err).Error("failed sending message on channel") logrus.WithError(err).Error("failed sending message on channel")
return return
} }

View File

@ -351,7 +351,7 @@ func TestClientEOF(t *testing.T) {
} }
// shutdown the server so the client stops receiving stuff. // shutdown the server so the client stops receiving stuff.
if err := server.Shutdown(ctx); err != nil { if err := server.Close(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := <-errs; err != ErrServerClosed { if err := <-errs; err != ErrServerClosed {