Merge pull request #8 from stevvooe/handle-client-eof
ttrpc: refactor client to better handle EOF
This commit is contained in:
commit
8f839f204c
253
client.go
253
client.go
@ -4,19 +4,18 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/containerd/containerd/log"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
codec codec
|
||||
channel *channel
|
||||
requestID uint32
|
||||
sendRequests chan sendRequest
|
||||
recvRequests chan recvRequest
|
||||
codec codec
|
||||
conn net.Conn
|
||||
channel *channel
|
||||
calls chan *callRequest
|
||||
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
@ -26,58 +25,76 @@ type Client struct {
|
||||
|
||||
func NewClient(conn net.Conn) *Client {
|
||||
c := &Client{
|
||||
codec: codec{},
|
||||
requestID: 1,
|
||||
channel: newChannel(conn, conn),
|
||||
sendRequests: make(chan sendRequest),
|
||||
recvRequests: make(chan recvRequest),
|
||||
closed: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
codec: codec{},
|
||||
conn: conn,
|
||||
channel: newChannel(conn, conn),
|
||||
calls: make(chan *callRequest),
|
||||
closed: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go c.run()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
||||
requestID := atomic.AddUint32(&c.requestID, 2)
|
||||
if err := c.sendRequest(ctx, requestID, service, method, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.recvResponse(ctx, requestID, resp)
|
||||
type callRequest struct {
|
||||
ctx context.Context
|
||||
req *Request
|
||||
resp *Response // response will be written back here
|
||||
errs chan error // error written here on completion
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(ctx context.Context, requestID uint32, service, method string, req interface{}) error {
|
||||
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
||||
payload, err := c.codec.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request := Request{
|
||||
Service: service,
|
||||
Method: method,
|
||||
Payload: payload,
|
||||
}
|
||||
var (
|
||||
creq = &Request{
|
||||
Service: service,
|
||||
Method: method,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
return c.send(ctx, requestID, &request)
|
||||
}
|
||||
cresp = &Response{}
|
||||
)
|
||||
|
||||
func (c *Client) recvResponse(ctx context.Context, requestID uint32, resp interface{}) error {
|
||||
var response Response
|
||||
if err := c.recv(ctx, requestID, &response); err != nil {
|
||||
if err := c.dispatch(ctx, creq, cresp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.codec.Unmarshal(response.Payload, resp); err != nil {
|
||||
if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if response.Status == nil {
|
||||
if cresp.Status == nil {
|
||||
return errors.New("no status provided on response")
|
||||
}
|
||||
|
||||
return status.ErrorProto(response.Status)
|
||||
return status.ErrorProto(cresp.Status)
|
||||
}
|
||||
|
||||
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
|
||||
errs := make(chan error, 1)
|
||||
call := &callRequest{
|
||||
req: req,
|
||||
resp: resp,
|
||||
errs: errs,
|
||||
}
|
||||
|
||||
select {
|
||||
case c.calls <- call:
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
@ -88,95 +105,42 @@ func (c *Client) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type sendRequest struct {
|
||||
ctx context.Context
|
||||
id uint32
|
||||
msg interface{}
|
||||
err chan error
|
||||
}
|
||||
|
||||
func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
|
||||
errs := make(chan error, 1)
|
||||
select {
|
||||
case c.sendRequests <- sendRequest{
|
||||
ctx: ctx,
|
||||
id: id,
|
||||
msg: msg,
|
||||
err: errs,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
}
|
||||
|
||||
type recvRequest struct {
|
||||
id uint32
|
||||
msg interface{}
|
||||
err chan error
|
||||
}
|
||||
|
||||
func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
|
||||
errs := make(chan error, 1)
|
||||
select {
|
||||
case c.recvRequests <- recvRequest{
|
||||
id: id,
|
||||
msg: msg,
|
||||
err: errs,
|
||||
}:
|
||||
case <-c.done:
|
||||
return c.err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
case <-c.done:
|
||||
return c.err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
type received struct {
|
||||
mh messageHeader
|
||||
type message struct {
|
||||
messageHeader
|
||||
p []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *Client) run() {
|
||||
defer close(c.done)
|
||||
var (
|
||||
waiters = map[uint32]recvRequest{}
|
||||
queued = map[uint32]received{} // messages unmatched by waiter
|
||||
incoming = make(chan received)
|
||||
streamID uint32 = 1
|
||||
waiters = make(map[uint32]*callRequest)
|
||||
calls = c.calls
|
||||
incoming = make(chan *message)
|
||||
shutdown = make(chan struct{})
|
||||
shutdownErr error
|
||||
)
|
||||
|
||||
go func() {
|
||||
defer close(shutdown)
|
||||
|
||||
// start one more goroutine to recv messages without blocking.
|
||||
for {
|
||||
// TODO(stevvooe): Something still isn't quite right with error
|
||||
// handling on the client-side, causing EOFs to come through. We
|
||||
// need other fixes in this changeset, so we'll address this
|
||||
// correctly later.
|
||||
mh, p, err := c.channel.recv(context.TODO())
|
||||
if err != nil {
|
||||
_, ok := status.FromError(err)
|
||||
if !ok {
|
||||
// treat all errors that are not an rpc status as terminal.
|
||||
// all others poison the connection.
|
||||
shutdownErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
select {
|
||||
case incoming <- received{
|
||||
mh: mh,
|
||||
p: p[:mh.Length],
|
||||
err: err,
|
||||
case incoming <- &message{
|
||||
messageHeader: mh,
|
||||
p: p[:mh.Length],
|
||||
err: err,
|
||||
}:
|
||||
case <-c.done:
|
||||
return
|
||||
@ -184,32 +148,63 @@ func (c *Client) run() {
|
||||
}
|
||||
}()
|
||||
|
||||
defer c.conn.Close()
|
||||
defer close(c.done)
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-c.sendRequests:
|
||||
if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
|
||||
req.err <- err
|
||||
} else {
|
||||
req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
|
||||
case call := <-calls:
|
||||
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
|
||||
call.errs <- err
|
||||
continue
|
||||
}
|
||||
case req := <-c.recvRequests:
|
||||
if r, ok := queued[req.id]; ok {
|
||||
req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
|
||||
|
||||
waiters[streamID] = call
|
||||
streamID += 2 // enforce odd client initiated request ids
|
||||
case msg := <-incoming:
|
||||
call, ok := waiters[msg.StreamID]
|
||||
if !ok {
|
||||
log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
|
||||
continue
|
||||
}
|
||||
waiters[req.id] = req
|
||||
case r := <-incoming:
|
||||
if waiter, ok := waiters[r.mh.StreamID]; ok {
|
||||
if r.err != nil {
|
||||
waiter.err <- r.err
|
||||
} else {
|
||||
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
||||
c.channel.putmbuf(r.p)
|
||||
}
|
||||
} else {
|
||||
queued[r.mh.StreamID] = r
|
||||
|
||||
call.errs <- c.recv(call.resp, msg)
|
||||
delete(waiters, msg.StreamID)
|
||||
case <-shutdown:
|
||||
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
|
||||
c.err = shutdownErr
|
||||
for _, waiter := range waiters {
|
||||
waiter.errs <- shutdownErr
|
||||
}
|
||||
c.Close()
|
||||
return
|
||||
case <-c.closed:
|
||||
// broadcast the shutdown error to the remaining waiters.
|
||||
for _, waiter := range waiters {
|
||||
waiter.errs <- shutdownErr
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
|
||||
p, err := c.codec.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.channel.send(ctx, streamID, mtype, p)
|
||||
}
|
||||
|
||||
func (c *Client) recv(resp *Response, msg *message) error {
|
||||
if msg.err != nil {
|
||||
return msg.err
|
||||
}
|
||||
|
||||
if msg.Type != messageTypeResponse {
|
||||
return errors.New("unkown message type received")
|
||||
}
|
||||
|
||||
return proto.Unmarshal(msg.p, resp)
|
||||
}
|
||||
|
@ -110,8 +110,6 @@ func (s *Server) Shutdown(ctx context.Context) error {
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
|
||||
return lnerr
|
||||
}
|
||||
|
||||
// Close the server without waiting for active connections.
|
||||
@ -373,7 +371,6 @@ func (c *serverConn) run(sctx context.Context) {
|
||||
select {
|
||||
case request := <-requests:
|
||||
active++
|
||||
|
||||
go func(id uint32) {
|
||||
p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
||||
resp := &Response{
|
||||
@ -395,6 +392,7 @@ func (c *serverConn) run(sctx context.Context) {
|
||||
log.L.WithError(err).Error("failed marshaling response")
|
||||
return
|
||||
}
|
||||
|
||||
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
|
||||
log.L.WithError(err).Error("failed sending message on channel")
|
||||
return
|
||||
|
111
server_test.go
111
server_test.go
@ -6,8 +6,8 @@ import (
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -159,19 +159,25 @@ func TestServerListenerClosed(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServerShutdown(t *testing.T) {
|
||||
const ncalls = 5
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
handlersStarted = make(chan struct{})
|
||||
handlersStartedCloseOnce sync.Once
|
||||
proceed = make(chan struct{})
|
||||
serveErrs = make(chan error, 1)
|
||||
callwg sync.WaitGroup
|
||||
callErrs = make(chan error, ncalls)
|
||||
shutdownErrs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||
)
|
||||
defer cleanup()
|
||||
defer cleanup2()
|
||||
defer server.Close()
|
||||
|
||||
// register a service that takes until we tell it to stop
|
||||
server.Register(serviceName, map[string]Method{
|
||||
@ -180,43 +186,52 @@ func TestServerShutdown(t *testing.T) {
|
||||
if err := unmarshal(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handlersStartedCloseOnce.Do(func() { close(handlersStarted) })
|
||||
<-proceed
|
||||
return &testPayload{Foo: "waited"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
go func() {
|
||||
errs <- server.Serve(listener)
|
||||
serveErrs <- server.Serve(listener)
|
||||
}()
|
||||
|
||||
tp := testPayload{Foo: "half"}
|
||||
// send a few half requests
|
||||
if err := client.sendRequest(ctx, 1, "testService", "Test", &tp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := client.sendRequest(ctx, 3, "testService", "Test", &tp); err != nil {
|
||||
t.Fatal(err)
|
||||
// send a series of requests that will get blocked
|
||||
for i := 0; i < 5; i++ {
|
||||
callwg.Add(1)
|
||||
go func(i int) {
|
||||
callwg.Done()
|
||||
tp := testPayload{Foo: "half" + fmt.Sprint(i)}
|
||||
callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Millisecond) // ensure that requests make it through before shutdown
|
||||
<-handlersStarted
|
||||
go func() {
|
||||
close(shutdownStarted)
|
||||
server.Shutdown(ctx)
|
||||
shutdownErrs <- server.Shutdown(ctx)
|
||||
// server.Close()
|
||||
close(shutdownFinished)
|
||||
}()
|
||||
|
||||
<-shutdownStarted
|
||||
|
||||
// receive the responses
|
||||
if err := client.recvResponse(ctx, 1, &tp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := client.recvResponse(ctx, 3, &tp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
close(proceed)
|
||||
<-shutdownFinished
|
||||
|
||||
for i := 0; i < ncalls; i++ {
|
||||
if err := <-callErrs; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := <-shutdownErrs; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := <-serveErrs; err != ErrServerClosed {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkServerShutdown(t, server)
|
||||
}
|
||||
|
||||
@ -281,6 +296,42 @@ func TestOversizeCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientEOF(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
)
|
||||
defer cleanup()
|
||||
defer listener.Close()
|
||||
go func() {
|
||||
errs <- server.Serve(listener)
|
||||
}()
|
||||
|
||||
registerTestingService(server, &testingServer{})
|
||||
|
||||
tp := &testPayload{}
|
||||
// do a regular call
|
||||
if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// shutdown the server so the client stops receiving stuff.
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := <-errs; err != ErrServerClosed {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// server shutdown, but we still make a call.
|
||||
if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
|
||||
t.Fatalf("expected error when calling against shutdown server")
|
||||
}
|
||||
}
|
||||
|
||||
func checkServerShutdown(t *testing.T, server *Server) {
|
||||
t.Helper()
|
||||
server.mu.Lock()
|
||||
|
Loading…
Reference in New Issue
Block a user