diff --git a/client.go b/client.go
index 9dbb3d3..342f7bd 100644
--- a/client.go
+++ b/client.go
@@ -40,22 +40,30 @@ func NewClient(conn net.Conn) *Client {
 }
 
 func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
+	requestID := atomic.AddUint32(&c.requestID, 2)
+	if err := c.sendRequest(ctx, requestID, service, method, req); err != nil {
+		return err
+	}
+
+	return c.recvResponse(ctx, requestID, resp)
+}
+
+func (c *Client) sendRequest(ctx context.Context, requestID uint32, service, method string, req interface{}) error {
 	payload, err := c.codec.Marshal(req)
 	if err != nil {
 		return err
 	}
 
-	requestID := atomic.AddUint32(&c.requestID, 2)
 	request := Request{
 		Service: service,
 		Method:  method,
 		Payload: payload,
 	}
 
-	if err := c.send(ctx, requestID, &request); err != nil {
-		return err
-	}
+	return c.send(ctx, requestID, &request)
+}
 
+func (c *Client) recvResponse(ctx context.Context, requestID uint32, resp interface{}) error {
 	var response Response
 	if err := c.recv(ctx, requestID, &response); err != nil {
 		return err
@@ -160,6 +168,10 @@ func (c *Client) run() {
 	// start one more goroutine to recv messages without blocking.
 	for {
 		var p [messageLengthMax]byte
+		// TODO(stevvooe): Something still isn't quite right with error
+		// handling on the client-side, causing EOFs to come through. We
+		// need other fixes in this changeset, so we'll address this
+		// correctly later.
 		mh, err := c.channel.recv(context.TODO(), p[:])
 		select {
 		case incoming <- received{
@@ -187,13 +199,12 @@ func (c *Client) run() {
 			}
 			waiters[req.id] = req
 		case r := <-incoming:
-			if r.err != nil {
-				c.err = r.err
-				return
-			}
-
 			if waiter, ok := waiters[r.mh.StreamID]; ok {
-				waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
+				if r.err != nil {
+					waiter.err <- r.err
+				} else {
+					waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
+				}
 			} else {
 				queued[r.mh.StreamID] = r
 			}
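The refactored `Call` above is what the shutdown test at the end of this change relies on: with `sendRequest` and `recvResponse` split apart, a caller can put several requests on the wire before collecting any responses. A minimal package-internal sketch of that pipelining — the helper function is hypothetical; `testPayload` and the `"testService"`/`"Test"` names come from server_test.go, and client-initiated stream IDs must be odd:

```go
package ttrpc

import "context"

// pipelinedCalls is a sketch, not part of this change: it issues two
// requests back to back and only then reads the responses.
func pipelinedCalls(ctx context.Context, client *Client) error {
	tp := testPayload{Foo: "ping"}

	// Both requests go out before either response is read. Stream IDs 1
	// and 3 are odd, as the server requires for client-initiated streams.
	if err := client.sendRequest(ctx, 1, "testService", "Test", &tp); err != nil {
		return err
	}
	if err := client.sendRequest(ctx, 3, "testService", "Test", &tp); err != nil {
		return err
	}

	// Client.run matches responses to waiters by stream ID (unmatched
	// responses are parked in queued), so these can be collected in
	// either order.
	var resp1, resp3 testPayload
	if err := client.recvResponse(ctx, 1, &resp1); err != nil {
		return err
	}
	return client.recvResponse(ctx, 3, &resp3)
}
```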
diff --git a/server.go b/server.go
index 407068f..845d160 100644
--- a/server.go
+++ b/server.go
@@ -2,21 +2,38 @@ package ttrpc
 
 import (
 	"context"
+	"math/rand"
 	"net"
+	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/containerd/containerd/log"
+	"github.com/pkg/errors"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
+var (
+	ErrServerClosed = errors.New("ttrpc: server closed")
+)
+
 type Server struct {
 	services *serviceSet
 	codec    codec
+
+	mu          sync.Mutex
+	listeners   map[net.Listener]struct{}
+	connections map[*serverConn]struct{} // all connections to current state
+	done        chan struct{}            // marks point at which we stop serving requests
 }
 
 func NewServer() *Server {
 	return &Server{
-		services: newServiceSet(),
+		services:    newServiceSet(),
+		done:        make(chan struct{}),
+		listeners:   make(map[net.Listener]struct{}),
+		connections: make(map[*serverConn]struct{}),
 	}
 }
 
@@ -24,28 +41,208 @@ func (s *Server) Register(name string, methods map[string]Method) {
 	s.services.register(name, methods)
 }
 
-func (s *Server) Shutdown(ctx context.Context) error {
-	// TODO(stevvooe): Wait on connection shutdown.
-	return nil
-}
-
 func (s *Server) Serve(l net.Listener) error {
+	s.addListener(l)
+	defer s.closeListener(l)
+
+	var (
+		ctx     = context.Background()
+		backoff time.Duration
+	)
+
 	for {
 		conn, err := l.Accept()
 		if err != nil {
-			log.L.WithError(err).Error("failed accept")
-			continue
+			select {
+			case <-s.done:
+				return ErrServerClosed
+			default:
+			}
+
+			if terr, ok := err.(interface {
+				Temporary() bool
+			}); ok && terr.Temporary() {
+				if backoff == 0 {
+					backoff = time.Millisecond
+				} else {
+					backoff *= 2
+				}
+
+				if max := time.Second; backoff > max {
+					backoff = max
+				}
+
+				sleep := time.Duration(rand.Int63n(int64(backoff)))
+				log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
+				time.Sleep(sleep)
+				continue
+			}
+
+			return err
 		}
 
-		go s.handleConn(conn)
+		backoff = 0
+		sc := s.newConn(conn)
+		go sc.run(ctx)
 	}
+}
+
+func (s *Server) Shutdown(ctx context.Context) error {
+	s.mu.Lock()
+	lnerr := s.closeListeners()
+	select {
+	case <-s.done:
+	default:
+		// protected by mutex
+		close(s.done)
+	}
+	s.mu.Unlock()
+
+	ticker := time.NewTicker(200 * time.Millisecond)
+	defer ticker.Stop()
+	for {
+		if s.closeIdleConns() {
+			return lnerr
+		}
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+		}
+	}
+}
+
+// Close the server without waiting for active connections.
+func (s *Server) Close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	select {
+	case <-s.done:
+	default:
+		// protected by mutex
+		close(s.done)
+	}
+
+	err := s.closeListeners()
+	for c := range s.connections {
+		c.close()
+		delete(s.connections, c)
+	}
+
+	return err
+}
+
+func (s *Server) addListener(l net.Listener) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.listeners[l] = struct{}{}
+}
+
+func (s *Server) closeListener(l net.Listener) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return s.closeListenerLocked(l)
+}
+
+func (s *Server) closeListenerLocked(l net.Listener) error {
+	defer delete(s.listeners, l)
+	return l.Close()
+}
+
+func (s *Server) closeListeners() error {
+	var err error
+	for l := range s.listeners {
+		if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
+			err = cerr
+		}
+	}
+	return err
+}
+
+func (s *Server) addConnection(c *serverConn) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.connections[c] = struct{}{}
+}
+
+func (s *Server) closeIdleConns() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	quiescent := true
+	for c := range s.connections {
+		st, ok := c.getState()
+		if !ok || st != connStateIdle {
+			quiescent = false
+			continue
+		}
+		c.close()
+		delete(s.connections, c)
+	}
+	return quiescent
+}
+
+type connState int
+
+const (
+	connStateActive = iota + 1 // outstanding requests
+	connStateIdle              // no requests
+	connStateClosed            // closed connection
+)
+
+func (cs connState) String() string {
+	switch cs {
+	case connStateActive:
+		return "active"
+	case connStateIdle:
+		return "idle"
+	case connStateClosed:
+		return "closed"
+	default:
+		return "unknown"
+	}
+}
+
+func (s *Server) newConn(conn net.Conn) *serverConn {
+	c := &serverConn{
+		server:   s,
+		conn:     conn,
+		shutdown: make(chan struct{}),
+	}
+	c.setState(connStateIdle)
+	s.addConnection(c)
+	return c
+}
+
+type serverConn struct {
+	server *Server
+	conn   net.Conn
+	state  atomic.Value
+
+	shutdownOnce sync.Once
+	shutdown     chan struct{} // forced shutdown, used by close
+}
+
+func (c *serverConn) getState() (connState, bool) {
+	cs, ok := c.state.Load().(connState)
+	return cs, ok
+}
+
+func (c *serverConn) setState(newstate connState) {
+	c.state.Store(newstate)
+}
+
+func (c *serverConn) close() error {
+	c.shutdownOnce.Do(func() {
+		close(c.shutdown)
+	})
 
 	return nil
 }
 
-func (s *Server) handleConn(conn net.Conn) {
-	defer conn.Close()
-
+func (c *serverConn) run(sctx context.Context) {
 	type (
 		request struct {
 			id  uint32
@@ -59,21 +256,33 @@ func (s *Server) handleConn(conn net.Conn) {
 	)
 
 	var (
-		ch          = newChannel(conn, conn)
-		ctx, cancel = context.WithCancel(context.Background())
-		responses   = make(chan response)
-		requests    = make(chan request)
-		recvErr     = make(chan error, 1)
-		done        = make(chan struct{})
+		ch          = newChannel(c.conn, c.conn)
+		ctx, cancel = context.WithCancel(sctx)
+		active      int
+		state       connState = connStateIdle
+		responses   = make(chan response)
+		requests    = make(chan request)
+		recvErr     = make(chan error, 1)
+		shutdown    = c.shutdown
+		done        = make(chan struct{})
 	)
 
+	defer c.conn.Close()
 	defer cancel()
 	defer close(done)
 
-	go func() {
+	go func(recvErr chan error) {
 		defer close(recvErr)
 		var p [messageLengthMax]byte
 		for {
+			select {
+			case <-c.shutdown:
+				return
+			case <-done:
+				return
+			default: // proceed
+			}
+
 			mh, err := ch.recv(ctx, p[:])
 			if err != nil {
 				recvErr <- err
@@ -85,14 +294,7 @@ func (s *Server) handleConn(conn net.Conn) {
 				continue
 			}
 
-			var req Request
-			if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
-				recvErr <- err
-				return
-			}
-
-			if mh.StreamID%2 != 1 {
-				// enforce odd client initiated identifiers.
+			sendImmediate := func(code codes.Code, msg string, args ...interface{}) bool {
 				select {
 				case responses <- response{
 					// even though we've had an invalid stream id, we send it
@@ -100,30 +302,68 @@ func (s *Server) handleConn(conn net.Conn) {
 					// stream id was bad.
 					id: mh.StreamID,
 					resp: &Response{
-						Status: status.New(codes.InvalidArgument, "StreamID must be odd for client initiated streams").Proto(),
+						Status: status.Newf(code, msg, args...).Proto(),
 					},
 				}:
+					return true
+				case <-c.shutdown:
+					return false
 				case <-done:
+					return false
 				}
+			}
 
+			var req Request
+			if err := c.server.codec.Unmarshal(p[:mh.Length], &req); err != nil {
+				if !sendImmediate(codes.InvalidArgument, "unmarshal request error: %v", err) {
+					return
+				}
 				continue
 			}
 
+			if mh.StreamID%2 != 1 {
+				// enforce odd client initiated identifiers.
+				if !sendImmediate(codes.InvalidArgument, "StreamID must be odd for client initiated streams") {
+					return
+				}
+				continue
+			}
+
+			// Forward the request to the main loop. We don't wait on s.done
+			// because we have already accepted the client request.
 			select {
 			case requests <- request{
 				id:  mh.StreamID,
 				req: &req,
 			}:
			case <-done:
+				return
 			}
 		}
-	}()
+	}(recvErr)
 
 	for {
+		newstate := state
+		switch {
+		case active > 0:
+			newstate = connStateActive
+			shutdown = nil
+		case active == 0:
+			newstate = connStateIdle
+			shutdown = c.shutdown // only enable this branch in idle mode
+		}
+
+		if newstate != state {
+			c.setState(newstate)
+			state = newstate
+		}
+
 		select {
 		case request := <-requests:
+			active++
+
 			go func(id uint32) {
-				p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
+				p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
 				resp := &Response{
 					Status:  status.Proto(),
 					Payload: p,
@@ -138,7 +378,7 @@ func (s *Server) handleConn(conn net.Conn) {
 				}
 			}(request.id)
 		case response := <-responses:
-			p, err := s.codec.Marshal(response.resp)
+			p, err := c.server.codec.Marshal(response.resp)
 			if err != nil {
 				log.L.WithError(err).Error("failed marshaling response")
 				return
@@ -147,8 +387,17 @@ func (s *Server) handleConn(conn net.Conn) {
 			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
 				log.L.WithError(err).Error("failed sending message on channel")
 				return
 			}
+
+			active--
 		case err := <-recvErr:
-			log.L.WithError(err).Error("error receiving message")
+			// TODO(stevvooe): Not wildly clear what we should do in this
+			// branch. Basically, it means that we are no longer receiving
+			// requests due to a terminal error.
+			recvErr = nil // connection is now "closing"
+			if err != nil {
+				log.L.WithError(err).Error("error receiving message")
+			}
+		case <-shutdown:
 			return
 		}
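Together, `Serve`, `Shutdown`, and `Close` give the server a net/http-style lifecycle: `Shutdown` closes the listeners and then polls `closeIdleConns` every 200ms until every connection has drained, while `Close` tears everything down immediately. A hedged sketch of how a caller might wire this up — the import path, signal handling, socket path, and the 10-second grace period are assumptions for illustration, not part of this change:

```go
package main

import (
	"context"
	"log"
	"net"
	"os"
	"os/signal"
	"time"

	"github.com/containerd/ttrpc" // assumed import path
)

func serveUntilSignal(s *ttrpc.Server, l net.Listener) error {
	errs := make(chan error, 1)
	go func() { errs <- s.Serve(l) }()

	// Wait for an interrupt before starting the drain.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, os.Interrupt)
	<-sigs

	// Bound the graceful drain so one stuck connection cannot hold the
	// process open forever; fall back to a forced close.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := s.Shutdown(ctx); err != nil {
		s.Close()
	}

	// Serve returns ErrServerClosed once Shutdown/Close has taken the
	// listener away; treat that as a clean exit.
	if err := <-errs; err != ttrpc.ErrServerClosed {
		return err
	}
	return nil
}

func main() {
	s := ttrpc.NewServer()
	// Register services here before serving.

	l, err := net.Listen("unix", "/tmp/example.sock") // illustrative path
	if err != nil {
		log.Fatal(err)
	}
	if err := serveUntilSignal(s, l); err != nil {
		log.Fatal(err)
	}
}
```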
diff --git a/server_test.go b/server_test.go
index dd4baf6..bbbe95d 100644
--- a/server_test.go
+++ b/server_test.go
@@ -7,6 +7,7 @@ import (
 	"reflect"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/gogo/protobuf/proto"
 )
@@ -74,34 +75,26 @@ func init() {
 
 func TestServer(t *testing.T) {
 	var (
-		ctx      = context.Background()
-		server   = NewServer()
-		testImpl = &testingServer{}
+		ctx             = context.Background()
+		server          = NewServer()
+		testImpl        = &testingServer{}
+		addr, listener  = newTestListener(t)
+		client, cleanup = newTestClient(t, addr)
+		tclient         = newTestingClient(client)
 	)
 
-	registerTestingService(server, testImpl)
-
-	addr := "\x00" + t.Name()
-	listener, err := net.Listen("unix", addr)
-	if err != nil {
-		t.Fatal(err)
-	}
 	defer listener.Close()
+	defer cleanup()
+
+	registerTestingService(server, testImpl)
 
 	go server.Serve(listener)
 	defer server.Shutdown(ctx)
 
-	conn, err := net.Dial("unix", addr)
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer conn.Close()
-	client := newTestingClient(NewClient(conn))
-
 	const calls = 2
 	results := make(chan callResult, 2)
-	go roundTrip(ctx, t, client, "bar", results)
-	go roundTrip(ctx, t, client, "baz", results)
+	go roundTrip(ctx, t, tclient, "bar", results)
+	go roundTrip(ctx, t, tclient, "baz", results)
 
 	for i := 0; i < calls; i++ {
 		result := <-results
@@ -111,6 +104,139 @@ func TestServer(t *testing.T) {
 	}
 }
 
+func newTestClient(t *testing.T, addr string) (*Client, func()) {
+	conn, err := net.Dial("unix", addr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	client := NewClient(conn)
+	return client, func() {
+		conn.Close()
+		client.Close()
+	}
+}
+
+func TestServerListenerClosed(t *testing.T) {
+	var (
+		server      = NewServer()
+		_, listener = newTestListener(t)
+		errs        = make(chan error, 1)
+	)
+
+	go func() {
+		errs <- server.Serve(listener)
+	}()
+
+	if err := listener.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	err := <-errs
+	if err == nil {
+		t.Fatal("expected error from Serve on a closed listener")
+	}
+}
+
+func TestServerShutdown(t *testing.T) {
+	var (
+		ctx              = context.Background()
+		server           = NewServer()
+		addr, listener   = newTestListener(t)
+		shutdownStarted  = make(chan struct{})
+		shutdownFinished = make(chan struct{})
+		errs             = make(chan error, 1)
+		client, cleanup  = newTestClient(t, addr)
+		_, cleanup2      = newTestClient(t, addr) // secondary connection
+	)
+	defer cleanup()
+	defer cleanup2()
+	defer server.Close()
+
+	// register a simple service to handle the requests below
+	server.Register(serviceName, map[string]Method{
+		"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
+			var req testPayload
+			if err := unmarshal(&req); err != nil {
+				return nil, err
+			}
+			return &testPayload{Foo: "waited"}, nil
+		},
+	})
+
+	go func() {
+		errs <- server.Serve(listener)
+	}()
+
+	tp := testPayload{Foo: "half"}
+	// send a couple of requests, leaving the responses unread for now
+	if err := client.sendRequest(ctx, 1, serviceName, "Test", &tp); err != nil {
+		t.Fatal(err)
+	}
+	if err := client.sendRequest(ctx, 3, serviceName, "Test", &tp); err != nil {
+		t.Fatal(err)
+	}
+
+	time.Sleep(1 * time.Millisecond) // ensure that requests make it through before shutdown
+	go func() {
+		close(shutdownStarted)
+		server.Shutdown(ctx)
+		close(shutdownFinished)
+	}()
+
+	<-shutdownStarted
+
+	// receive the responses
+	if err := client.recvResponse(ctx, 1, &tp); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := client.recvResponse(ctx, 3, &tp); err != nil {
+		t.Fatal(err)
+	}
+
+	<-shutdownFinished
+	checkServerShutdown(t, server)
+}
+
+func TestServerClose(t *testing.T) {
+	var (
+		server      = NewServer()
+		_, listener = newTestListener(t)
+		startClose  = make(chan struct{})
+		errs        = make(chan error, 1)
+	)
+
+	go func() {
+		close(startClose)
+		errs <- server.Serve(listener)
+	}()
+
+	<-startClose
+	if err := server.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	err := <-errs
+	if err != ErrServerClosed {
+		t.Fatal("expected an error from a closed server", err)
+	}
+
+	checkServerShutdown(t, server)
+}
+
+func checkServerShutdown(t *testing.T, server *Server) {
+	t.Helper()
+	server.mu.Lock()
+	defer server.mu.Unlock()
+	if len(server.listeners) > 0 {
+		t.Fatalf("expected listeners to be empty: %v", server.listeners)
+	}
+
+	if len(server.connections) > 0 {
+		t.Fatalf("expected connections to be empty: %v", server.connections)
+	}
+}
+
 type callResult struct {
 	input    *testPayload
 	expected *testPayload
@@ -136,3 +262,13 @@
 		received: resp,
 	}
 }
+
+func newTestListener(t *testing.T) (string, net.Listener) {
+	addr := "\x00" + t.Name()
+	listener, err := net.Listen("unix", addr)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	return addr, listener
+}
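The accept loop in `Serve` uses the same temporary-error backoff idea as `net/http.Server`, with full jitter on top. Distilled into a standalone sketch — the function, its parameters, and the package name are illustrative; the 1ms floor, 1s cap, and reset-on-success mirror the change above:

```go
package sketch

import (
	"math/rand"
	"net"
	"time"
)

// acceptLoop retries temporary accept errors with capped exponential
// backoff and a uniformly random sleep inside the window (full jitter),
// resetting the window after any successful accept.
func acceptLoop(l net.Listener, handle func(net.Conn)) error {
	var backoff time.Duration
	for {
		conn, err := l.Accept()
		if err != nil {
			if terr, ok := err.(interface{ Temporary() bool }); ok && terr.Temporary() {
				if backoff == 0 {
					backoff = time.Millisecond
				} else {
					backoff *= 2
				}
				if backoff > time.Second {
					backoff = time.Second
				}
				// Sleeping a random fraction of the window keeps many
				// servers from retrying in lockstep.
				time.Sleep(time.Duration(rand.Int63n(int64(backoff))))
				continue
			}
			return err // permanent failure, e.g. a closed listener
		}
		backoff = 0
		go handle(conn)
	}
}
```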