From 19445fddca137f3bfb0dc4dd1bd37e6f11b6cbda Mon Sep 17 00:00:00 2001 From: Austin Vazquez Date: Fri, 24 Feb 2023 22:42:07 +0000 Subject: [PATCH] Fix server shutdown logic Simplify close idle connections logic in server shutdown to be more intuitive. Modify add connection logic to check if server has been shutdown before adding any new connections. Modify test to make all calls before server shutdown. Signed-off-by: Austin Vazquez --- server.go | 46 ++++++++++++++++++++++++++++++++-------------- server_test.go | 36 ++++++++++++++++-------------------- 2 files changed, 48 insertions(+), 34 deletions(-) diff --git a/server.go b/server.go index e4dd10b..2efda2b 100644 --- a/server.go +++ b/server.go @@ -121,12 +121,18 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error { approved, handshake, err := handshaker.Handshake(ctx, conn) if err != nil { - logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake") + logrus.WithError(err).Error("ttrpc: refusing connection after handshake") + conn.Close() + continue + } + + sc, err := s.newConn(approved, handshake) + if err != nil { + logrus.WithError(err).Error("ttrpc: create connection failed") conn.Close() continue } - sc := s.newConn(approved, handshake) go sc.run(ctx) } } @@ -145,15 +151,20 @@ func (s *Server) Shutdown(ctx context.Context) error { ticker := time.NewTicker(200 * time.Millisecond) defer ticker.Stop() for { - if s.closeIdleConns() { - return lnerr + s.closeIdleConns() + + if s.countConnection() == 0 { + break } + select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: } } + + return lnerr } // Close the server without waiting for active connections. @@ -205,11 +216,18 @@ func (s *Server) closeListeners() error { return err } -func (s *Server) addConnection(c *serverConn) { +func (s *Server) addConnection(c *serverConn) error { s.mu.Lock() defer s.mu.Unlock() + select { + case <-s.done: + return ErrServerClosed + default: + } + s.connections[c] = struct{}{} + return nil } func (s *Server) delConnection(c *serverConn) { @@ -226,20 +244,17 @@ func (s *Server) countConnection() int { return len(s.connections) } -func (s *Server) closeIdleConns() bool { +func (s *Server) closeIdleConns() { s.mu.Lock() defer s.mu.Unlock() - quiescent := true + for c := range s.connections { - st, ok := c.getState() - if !ok || st != connStateIdle { - quiescent = false + if st, ok := c.getState(); !ok || st == connStateActive { continue } c.close() delete(s.connections, c) } - return quiescent } type connState int @@ -263,7 +278,7 @@ func (cs connState) String() string { } } -func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn { +func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) { c := &serverConn{ server: s, conn: conn, @@ -271,8 +286,11 @@ func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn { shutdown: make(chan struct{}), } c.setState(connStateIdle) - s.addConnection(c) - return c + if err := s.addConnection(c); err != nil { + c.close() + return nil, err + } + return c, nil } type serverConn struct { diff --git a/server_test.go b/server_test.go index 77ba4c9..ee8570c 100644 --- a/server_test.go +++ b/server_test.go @@ -201,20 +201,18 @@ func TestServerListenerClosed(t *testing.T) { func TestServerShutdown(t *testing.T) { const ncalls = 5 var ( - ctx = context.Background() - server = mustServer(t)(NewServer()) - addr, listener = newTestListener(t) - shutdownStarted = make(chan struct{}) - shutdownFinished = make(chan struct{}) - handlersStarted = make(chan struct{}) - handlersStartedCloseOnce sync.Once - proceed = make(chan struct{}) - serveErrs = make(chan error, 1) - callwg sync.WaitGroup - callErrs = make(chan error, ncalls) - shutdownErrs = make(chan error, 1) - client, cleanup = newTestClient(t, addr) - _, cleanup2 = newTestClient(t, addr) // secondary connection + ctx = context.Background() + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + shutdownStarted = make(chan struct{}) + shutdownFinished = make(chan struct{}) + handlersStarted sync.WaitGroup + proceed = make(chan struct{}) + serveErrs = make(chan error, 1) + callErrs = make(chan error, ncalls) + shutdownErrs = make(chan error, 1) + client, cleanup = newTestClient(t, addr) + _, cleanup2 = newTestClient(t, addr) // secondary connection ) defer cleanup() defer cleanup2() @@ -227,7 +225,7 @@ func TestServerShutdown(t *testing.T) { return nil, err } - handlersStartedCloseOnce.Do(func() { close(handlersStarted) }) + handlersStarted.Done() <-proceed return &internal.TestPayload{Foo: "waited"}, nil }, @@ -238,20 +236,18 @@ func TestServerShutdown(t *testing.T) { }() // send a series of requests that will get blocked - for i := 0; i < 5; i++ { - callwg.Add(1) + for i := 0; i < ncalls; i++ { + handlersStarted.Add(1) go func(i int) { - callwg.Done() tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)} callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp) }(i) } - <-handlersStarted + handlersStarted.Wait() go func() { close(shutdownStarted) shutdownErrs <- server.Shutdown(ctx) - // server.Close() close(shutdownFinished) }()