Merge pull request #42 from crosbymichael/client
Refactor close handling for ttrpc clients
This commit is contained in:
commit
1fb3814edf
@ -18,7 +18,6 @@ package ttrpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
|
|||||||
// returned will be valid and caller should send that along to
|
// returned will be valid and caller should send that along to
|
||||||
// the correct consumer. The bytes on the underlying channel
|
// the correct consumer. The bytes on the underlying channel
|
||||||
// will be discarded.
|
// will be discarded.
|
||||||
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
|
func (ch *channel) recv() (messageHeader, []byte, error) {
|
||||||
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
|
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return messageHeader{}, nil, err
|
return messageHeader{}, nil, err
|
||||||
@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
|
|||||||
return mh, p, nil
|
return mh, p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
|
func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
|
||||||
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
|
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,6 @@ package ttrpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -31,7 +30,6 @@ import (
|
|||||||
|
|
||||||
func TestReadWriteMessage(t *testing.T) {
|
func TestReadWriteMessage(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
|
||||||
w, r = net.Pipe()
|
w, r = net.Pipe()
|
||||||
ch = newChannel(w)
|
ch = newChannel(w)
|
||||||
rch = newChannel(r)
|
rch = newChannel(r)
|
||||||
@ -46,7 +44,7 @@ func TestReadWriteMessage(t *testing.T) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for i, msg := range messages {
|
for i, msg := range messages {
|
||||||
if err := ch.send(ctx, uint32(i), 1, msg); err != nil {
|
if err := ch.send(uint32(i), 1, msg); err != nil {
|
||||||
errs <- err
|
errs <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -56,7 +54,7 @@ func TestReadWriteMessage(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
_, p, err := rch.recv(ctx)
|
_, p, err := rch.recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Cause(err) != io.EOF {
|
if errors.Cause(err) != io.EOF {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -91,7 +89,6 @@ func TestReadWriteMessage(t *testing.T) {
|
|||||||
|
|
||||||
func TestMessageOversize(t *testing.T) {
|
func TestMessageOversize(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
|
||||||
w, r = net.Pipe()
|
w, r = net.Pipe()
|
||||||
wch, rch = newChannel(w), newChannel(r)
|
wch, rch = newChannel(w), newChannel(r)
|
||||||
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
|
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
|
||||||
@ -99,12 +96,12 @@ func TestMessageOversize(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := wch.send(ctx, 1, 1, msg); err != nil {
|
if err := wch.send(1, 1, msg); err != nil {
|
||||||
errs <- err
|
errs <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, _, err := rch.recv(ctx)
|
_, _, err := rch.recv()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("error expected reading with small buffer")
|
t.Fatalf("error expected reading with small buffer")
|
||||||
}
|
}
|
||||||
|
188
client.go
188
client.go
@ -43,10 +43,13 @@ type Client struct {
|
|||||||
channel *channel
|
channel *channel
|
||||||
calls chan *callRequest
|
calls chan *callRequest
|
||||||
|
|
||||||
closed chan struct{}
|
ctx context.Context
|
||||||
closeOnce sync.Once
|
closed func()
|
||||||
closeFunc func()
|
|
||||||
done chan struct{}
|
closeOnce sync.Once
|
||||||
|
userCloseFunc func()
|
||||||
|
|
||||||
|
errOnce sync.Once
|
||||||
err error
|
err error
|
||||||
interceptor UnaryClientInterceptor
|
interceptor UnaryClientInterceptor
|
||||||
}
|
}
|
||||||
@ -57,7 +60,7 @@ type ClientOpts func(c *Client)
|
|||||||
// WithOnClose sets the close func whenever the client's Close() method is called
|
// WithOnClose sets the close func whenever the client's Close() method is called
|
||||||
func WithOnClose(onClose func()) ClientOpts {
|
func WithOnClose(onClose func()) ClientOpts {
|
||||||
return func(c *Client) {
|
return func(c *Client) {
|
||||||
c.closeFunc = onClose
|
c.userCloseFunc = onClose
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,15 +72,16 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
|
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
c := &Client{
|
c := &Client{
|
||||||
codec: codec{},
|
codec: codec{},
|
||||||
conn: conn,
|
conn: conn,
|
||||||
channel: newChannel(conn),
|
channel: newChannel(conn),
|
||||||
calls: make(chan *callRequest),
|
calls: make(chan *callRequest),
|
||||||
closed: make(chan struct{}),
|
closed: cancel,
|
||||||
done: make(chan struct{}),
|
ctx: ctx,
|
||||||
closeFunc: func() {},
|
userCloseFunc: func() {},
|
||||||
interceptor: defaultClientInterceptor,
|
interceptor: defaultClientInterceptor,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
@ -150,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case c.calls <- call:
|
case c.calls <- call:
|
||||||
case <-c.done:
|
case <-c.ctx.Done():
|
||||||
return c.err
|
return c.error()
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -159,16 +163,15 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case err := <-errs:
|
case err := <-errs:
|
||||||
return filterCloseErr(err)
|
return filterCloseErr(err)
|
||||||
case <-c.done:
|
case <-c.ctx.Done():
|
||||||
return c.err
|
return c.error()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
c.closeOnce.Do(func() {
|
c.closeOnce.Do(func() {
|
||||||
close(c.closed)
|
c.closed()
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,51 +181,82 @@ type message struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) run() {
|
type receiver struct {
|
||||||
var (
|
wg *sync.WaitGroup
|
||||||
streamID uint32 = 1
|
messages chan *message
|
||||||
waiters = make(map[uint32]*callRequest)
|
err error
|
||||||
calls = c.calls
|
}
|
||||||
incoming = make(chan *message)
|
|
||||||
shutdown = make(chan struct{})
|
|
||||||
shutdownErr error
|
|
||||||
)
|
|
||||||
|
|
||||||
go func() {
|
func (r *receiver) run(ctx context.Context, c *channel) {
|
||||||
defer close(shutdown)
|
defer r.wg.Done()
|
||||||
|
|
||||||
// start one more goroutine to recv messages without blocking.
|
for {
|
||||||
for {
|
select {
|
||||||
mh, p, err := c.channel.recv(context.TODO())
|
case <-ctx.Done():
|
||||||
|
r.err = ctx.Err()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
mh, p, err := c.recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, ok := status.FromError(err)
|
_, ok := status.FromError(err)
|
||||||
if !ok {
|
if !ok {
|
||||||
// treat all errors that are not an rpc status as terminal.
|
// treat all errors that are not an rpc status as terminal.
|
||||||
// all others poison the connection.
|
// all others poison the connection.
|
||||||
shutdownErr = err
|
r.err = filterCloseErr(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case incoming <- &message{
|
case r.messages <- &message{
|
||||||
messageHeader: mh,
|
messageHeader: mh,
|
||||||
p: p[:mh.Length],
|
p: p[:mh.Length],
|
||||||
err: err,
|
err: err,
|
||||||
}:
|
}:
|
||||||
case <-c.done:
|
case <-ctx.Done():
|
||||||
|
r.err = ctx.Err()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
}
|
||||||
|
|
||||||
defer c.conn.Close()
|
func (c *Client) run() {
|
||||||
defer close(c.done)
|
var (
|
||||||
defer c.closeFunc()
|
streamID uint32 = 1
|
||||||
|
waiters = make(map[uint32]*callRequest)
|
||||||
|
calls = c.calls
|
||||||
|
incoming = make(chan *message)
|
||||||
|
receiversDone = make(chan struct{})
|
||||||
|
wg sync.WaitGroup
|
||||||
|
)
|
||||||
|
|
||||||
|
// broadcast the shutdown error to the remaining waiters.
|
||||||
|
abortWaiters := func(wErr error) {
|
||||||
|
for _, waiter := range waiters {
|
||||||
|
waiter.errs <- wErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recv := &receiver{
|
||||||
|
wg: &wg,
|
||||||
|
messages: incoming,
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(receiversDone)
|
||||||
|
}()
|
||||||
|
go recv.run(c.ctx, c.channel)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
c.conn.Close()
|
||||||
|
c.userCloseFunc()
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case call := <-calls:
|
case call := <-calls:
|
||||||
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
|
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
|
||||||
call.errs <- err
|
call.errs <- err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -238,41 +272,42 @@ func (c *Client) run() {
|
|||||||
|
|
||||||
call.errs <- c.recv(call.resp, msg)
|
call.errs <- c.recv(call.resp, msg)
|
||||||
delete(waiters, msg.StreamID)
|
delete(waiters, msg.StreamID)
|
||||||
case <-shutdown:
|
case <-receiversDone:
|
||||||
if shutdownErr != nil {
|
// all the receivers have exited
|
||||||
shutdownErr = filterCloseErr(shutdownErr)
|
if recv.err != nil {
|
||||||
} else {
|
c.setError(recv.err)
|
||||||
shutdownErr = ErrClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
|
|
||||||
|
|
||||||
c.err = shutdownErr
|
|
||||||
for _, waiter := range waiters {
|
|
||||||
waiter.errs <- shutdownErr
|
|
||||||
}
|
}
|
||||||
|
// don't return out, let the close of the context trigger the abort of waiters
|
||||||
c.Close()
|
c.Close()
|
||||||
return
|
case <-c.ctx.Done():
|
||||||
case <-c.closed:
|
abortWaiters(c.error())
|
||||||
if c.err == nil {
|
|
||||||
c.err = ErrClosed
|
|
||||||
}
|
|
||||||
// broadcast the shutdown error to the remaining waiters.
|
|
||||||
for _, waiter := range waiters {
|
|
||||||
waiter.errs <- c.err
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
|
func (c *Client) error() error {
|
||||||
|
c.errOnce.Do(func() {
|
||||||
|
if c.err == nil {
|
||||||
|
c.err = ErrClosed
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) setError(err error) {
|
||||||
|
c.errOnce.Do(func() {
|
||||||
|
c.err = err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
|
||||||
p, err := c.codec.Marshal(msg)
|
p, err := c.codec.Marshal(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.channel.send(ctx, streamID, mtype, p)
|
return c.channel.send(streamID, mtype, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) recv(resp *Response, msg *message) error {
|
func (c *Client) recv(resp *Response, msg *message) error {
|
||||||
@ -293,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
|
|||||||
//
|
//
|
||||||
// This purposely ignores errors with a wrapped cause.
|
// This purposely ignores errors with a wrapped cause.
|
||||||
func filterCloseErr(err error) error {
|
func filterCloseErr(err error) error {
|
||||||
if err == nil {
|
switch {
|
||||||
|
case err == nil:
|
||||||
return nil
|
return nil
|
||||||
}
|
case err == io.EOF:
|
||||||
|
|
||||||
if err == io.EOF {
|
|
||||||
return ErrClosed
|
return ErrClosed
|
||||||
}
|
case errors.Cause(err) == io.EOF:
|
||||||
|
|
||||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
|
||||||
return ErrClosed
|
return ErrClosed
|
||||||
}
|
case strings.Contains(err.Error(), "use of closed network connection"):
|
||||||
|
return ErrClosed
|
||||||
// if we have an epipe on a write, we cast to errclosed
|
default:
|
||||||
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
|
// if we have an epipe on a write, we cast to errclosed
|
||||||
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
|
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
|
||||||
return ErrClosed
|
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
|
||||||
|
return ErrClosed
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -344,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
|
|||||||
default: // proceed
|
default: // proceed
|
||||||
}
|
}
|
||||||
|
|
||||||
mh, p, err := ch.recv(ctx)
|
mh, p, err := ch.recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
status, ok := status.FromError(err)
|
status, ok := status.FromError(err)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -441,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
|
if err := ch.send(response.id, messageTypeResponse, p); err != nil {
|
||||||
logrus.WithError(err).Error("failed sending message on channel")
|
logrus.WithError(err).Error("failed sending message on channel")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -351,7 +351,7 @@ func TestClientEOF(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// shutdown the server so the client stops receiving stuff.
|
// shutdown the server so the client stops receiving stuff.
|
||||||
if err := server.Shutdown(ctx); err != nil {
|
if err := server.Close(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := <-errs; err != ErrServerClosed {
|
if err := <-errs; err != ErrServerClosed {
|
||||||
|
Loading…
Reference in New Issue
Block a user