diff --git a/client.go b/client.go index f0fb9ec..0325e60 100644 --- a/client.go +++ b/client.go @@ -4,19 +4,18 @@ import ( "context" "net" "sync" - "sync/atomic" + "github.com/containerd/containerd/log" "github.com/gogo/protobuf/proto" "github.com/pkg/errors" "google.golang.org/grpc/status" ) type Client struct { - codec codec - channel *channel - requestID uint32 - sendRequests chan sendRequest - recvRequests chan recvRequest + codec codec + conn net.Conn + channel *channel + calls chan *callRequest closed chan struct{} closeOnce sync.Once @@ -26,58 +25,76 @@ type Client struct { func NewClient(conn net.Conn) *Client { c := &Client{ - codec: codec{}, - requestID: 1, - channel: newChannel(conn, conn), - sendRequests: make(chan sendRequest), - recvRequests: make(chan recvRequest), - closed: make(chan struct{}), - done: make(chan struct{}), + codec: codec{}, + conn: conn, + channel: newChannel(conn, conn), + calls: make(chan *callRequest), + closed: make(chan struct{}), + done: make(chan struct{}), } go c.run() return c } -func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { - requestID := atomic.AddUint32(&c.requestID, 2) - if err := c.sendRequest(ctx, requestID, service, method, req); err != nil { - return err - } - - return c.recvResponse(ctx, requestID, resp) +type callRequest struct { + ctx context.Context + req *Request + resp *Response // response will be written back here + errs chan error // error written here on completion } -func (c *Client) sendRequest(ctx context.Context, requestID uint32, service, method string, req interface{}) error { +func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { payload, err := c.codec.Marshal(req) if err != nil { return err } - request := Request{ - Service: service, - Method: method, - Payload: payload, - } + var ( + creq = &Request{ + Service: service, + Method: method, + Payload: payload, + } - return c.send(ctx, requestID, &request) -} + cresp = &Response{} + ) -func (c *Client) recvResponse(ctx context.Context, requestID uint32, resp interface{}) error { - var response Response - if err := c.recv(ctx, requestID, &response); err != nil { + if err := c.dispatch(ctx, creq, cresp); err != nil { return err } - if err := c.codec.Unmarshal(response.Payload, resp); err != nil { + if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil { return err } - if response.Status == nil { + if cresp.Status == nil { return errors.New("no status provided on response") } - return status.ErrorProto(response.Status) + return status.ErrorProto(cresp.Status) +} + +func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { + errs := make(chan error, 1) + call := &callRequest{ + req: req, + resp: resp, + errs: errs, + } + + select { + case c.calls <- call: + case <-c.done: + return c.err + } + + select { + case err := <-errs: + return err + case <-c.done: + return c.err + } } func (c *Client) Close() error { @@ -88,95 +105,42 @@ func (c *Client) Close() error { return nil } -type sendRequest struct { - ctx context.Context - id uint32 - msg interface{} - err chan error -} - -func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error { - errs := make(chan error, 1) - select { - case c.sendRequests <- sendRequest{ - ctx: ctx, - id: id, - msg: msg, - err: errs, - }: - case <-ctx.Done(): - return ctx.Err() - case <-c.done: - return c.err - } - - select { - case err := <-errs: - return err - case <-ctx.Done(): - return ctx.Err() - case <-c.done: - return c.err - } -} - -type recvRequest struct { - id uint32 - msg interface{} - err chan error -} - -func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error { - errs := make(chan error, 1) - select { - case c.recvRequests <- recvRequest{ - id: id, - msg: msg, - err: errs, - }: - case <-c.done: - return c.err - case <-ctx.Done(): - return ctx.Err() - } - - select { - case err := <-errs: - return err - case <-c.done: - return c.err - case <-ctx.Done(): - return ctx.Err() - } -} - -type received struct { - mh messageHeader +type message struct { + messageHeader p []byte err error } func (c *Client) run() { - defer close(c.done) var ( - waiters = map[uint32]recvRequest{} - queued = map[uint32]received{} // messages unmatched by waiter - incoming = make(chan received) + streamID uint32 = 1 + waiters = make(map[uint32]*callRequest) + calls = c.calls + incoming = make(chan *message) + shutdown = make(chan struct{}) + shutdownErr error ) go func() { + defer close(shutdown) + // start one more goroutine to recv messages without blocking. for { - // TODO(stevvooe): Something still isn't quite right with error - // handling on the client-side, causing EOFs to come through. We - // need other fixes in this changeset, so we'll address this - // correctly later. mh, p, err := c.channel.recv(context.TODO()) + if err != nil { + _, ok := status.FromError(err) + if !ok { + // treat all errors that are not an rpc status as terminal. + // all others poison the connection. + shutdownErr = err + return + } + } select { - case incoming <- received{ - mh: mh, - p: p[:mh.Length], - err: err, + case incoming <- &message{ + messageHeader: mh, + p: p[:mh.Length], + err: err, }: case <-c.done: return @@ -184,32 +148,63 @@ func (c *Client) run() { } }() + defer c.conn.Close() + defer close(c.done) + for { select { - case req := <-c.sendRequests: - if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil { - req.err <- err - } else { - req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p) + case call := <-calls: + if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil { + call.errs <- err + continue } - case req := <-c.recvRequests: - if r, ok := queued[req.id]; ok { - req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message)) + + waiters[streamID] = call + streamID += 2 // enforce odd client initiated request ids + case msg := <-incoming: + call, ok := waiters[msg.StreamID] + if !ok { + log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID) + continue } - waiters[req.id] = req - case r := <-incoming: - if waiter, ok := waiters[r.mh.StreamID]; ok { - if r.err != nil { - waiter.err <- r.err - } else { - waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message)) - c.channel.putmbuf(r.p) - } - } else { - queued[r.mh.StreamID] = r + + call.errs <- c.recv(call.resp, msg) + delete(waiters, msg.StreamID) + case <-shutdown: + shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down") + c.err = shutdownErr + for _, waiter := range waiters { + waiter.errs <- shutdownErr } + c.Close() + return case <-c.closed: + // broadcast the shutdown error to the remaining waiters. + for _, waiter := range waiters { + waiter.errs <- shutdownErr + } return } } } + +func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error { + p, err := c.codec.Marshal(msg) + if err != nil { + return err + } + + return c.channel.send(ctx, streamID, mtype, p) +} + +func (c *Client) recv(resp *Response, msg *message) error { + if msg.err != nil { + return msg.err + } + + if msg.Type != messageTypeResponse { + return errors.New("unkown message type received") + } + + return proto.Unmarshal(msg.p, resp) +} diff --git a/server.go b/server.go index 279b5be..7830378 100644 --- a/server.go +++ b/server.go @@ -110,8 +110,6 @@ func (s *Server) Shutdown(ctx context.Context) error { case <-ticker.C: } } - - return lnerr } // Close the server without waiting for active connections. @@ -373,7 +371,6 @@ func (c *serverConn) run(sctx context.Context) { select { case request := <-requests: active++ - go func(id uint32) { p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) resp := &Response{ @@ -395,6 +392,7 @@ func (c *serverConn) run(sctx context.Context) { log.L.WithError(err).Error("failed marshaling response") return } + if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { log.L.WithError(err).Error("failed sending message on channel") return diff --git a/server_test.go b/server_test.go index 8ce6fd4..d01fc07 100644 --- a/server_test.go +++ b/server_test.go @@ -6,8 +6,8 @@ import ( "net" "reflect" "strings" + "sync" "testing" - "time" "github.com/gogo/protobuf/proto" "google.golang.org/grpc/codes" @@ -159,19 +159,25 @@ func TestServerListenerClosed(t *testing.T) { } func TestServerShutdown(t *testing.T) { + const ncalls = 5 var ( - ctx = context.Background() - server = NewServer() - addr, listener = newTestListener(t) - shutdownStarted = make(chan struct{}) - shutdownFinished = make(chan struct{}) - errs = make(chan error, 1) - client, cleanup = newTestClient(t, addr) - _, cleanup2 = newTestClient(t, addr) // secondary connection + ctx = context.Background() + server = NewServer() + addr, listener = newTestListener(t) + shutdownStarted = make(chan struct{}) + shutdownFinished = make(chan struct{}) + handlersStarted = make(chan struct{}) + handlersStartedCloseOnce sync.Once + proceed = make(chan struct{}) + serveErrs = make(chan error, 1) + callwg sync.WaitGroup + callErrs = make(chan error, ncalls) + shutdownErrs = make(chan error, 1) + client, cleanup = newTestClient(t, addr) + _, cleanup2 = newTestClient(t, addr) // secondary connection ) defer cleanup() defer cleanup2() - defer server.Close() // register a service that takes until we tell it to stop server.Register(serviceName, map[string]Method{ @@ -180,43 +186,52 @@ func TestServerShutdown(t *testing.T) { if err := unmarshal(&req); err != nil { return nil, err } + + handlersStartedCloseOnce.Do(func() { close(handlersStarted) }) + <-proceed return &testPayload{Foo: "waited"}, nil }, }) go func() { - errs <- server.Serve(listener) + serveErrs <- server.Serve(listener) }() - tp := testPayload{Foo: "half"} - // send a few half requests - if err := client.sendRequest(ctx, 1, "testService", "Test", &tp); err != nil { - t.Fatal(err) - } - if err := client.sendRequest(ctx, 3, "testService", "Test", &tp); err != nil { - t.Fatal(err) + // send a series of requests that will get blocked + for i := 0; i < 5; i++ { + callwg.Add(1) + go func(i int) { + callwg.Done() + tp := testPayload{Foo: "half" + fmt.Sprint(i)} + callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp) + }(i) } - time.Sleep(1 * time.Millisecond) // ensure that requests make it through before shutdown + <-handlersStarted go func() { close(shutdownStarted) - server.Shutdown(ctx) + shutdownErrs <- server.Shutdown(ctx) // server.Close() close(shutdownFinished) }() <-shutdownStarted - - // receive the responses - if err := client.recvResponse(ctx, 1, &tp); err != nil { - t.Fatal(err) - } - - if err := client.recvResponse(ctx, 3, &tp); err != nil { - t.Fatal(err) - } - + close(proceed) <-shutdownFinished + + for i := 0; i < ncalls; i++ { + if err := <-callErrs; err != nil { + t.Fatal(err) + } + } + + if err := <-shutdownErrs; err != nil { + t.Fatal(err) + } + + if err := <-serveErrs; err != ErrServerClosed { + t.Fatal(err) + } checkServerShutdown(t, server) } @@ -281,6 +296,42 @@ func TestOversizeCall(t *testing.T) { } } +func TestClientEOF(t *testing.T) { + var ( + ctx = context.Background() + server = NewServer() + addr, listener = newTestListener(t) + errs = make(chan error, 1) + client, cleanup = newTestClient(t, addr) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(listener) + }() + + registerTestingService(server, &testingServer{}) + + tp := &testPayload{} + // do a regular call + if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // shutdown the server so the client stops receiving stuff. + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + if err := <-errs; err != ErrServerClosed { + t.Fatal(err) + } + + // server shutdown, but we still make a call. + if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { + t.Fatalf("expected error when calling against shutdown server") + } +} + func checkServerShutdown(t *testing.T, server *Server) { t.Helper() server.mu.Lock()