diff --git a/server.go b/server.go index e4dd10b..2efda2b 100644 --- a/server.go +++ b/server.go @@ -121,12 +121,18 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error { approved, handshake, err := handshaker.Handshake(ctx, conn) if err != nil { - logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake") + logrus.WithError(err).Error("ttrpc: refusing connection after handshake") + conn.Close() + continue + } + + sc, err := s.newConn(approved, handshake) + if err != nil { + logrus.WithError(err).Error("ttrpc: create connection failed") conn.Close() continue } - sc := s.newConn(approved, handshake) go sc.run(ctx) } } @@ -145,15 +151,20 @@ func (s *Server) Shutdown(ctx context.Context) error { ticker := time.NewTicker(200 * time.Millisecond) defer ticker.Stop() for { - if s.closeIdleConns() { - return lnerr + s.closeIdleConns() + + if s.countConnection() == 0 { + break } + select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: } } + + return lnerr } // Close the server without waiting for active connections. @@ -205,11 +216,18 @@ func (s *Server) closeListeners() error { return err } -func (s *Server) addConnection(c *serverConn) { +func (s *Server) addConnection(c *serverConn) error { s.mu.Lock() defer s.mu.Unlock() + select { + case <-s.done: + return ErrServerClosed + default: + } + s.connections[c] = struct{}{} + return nil } func (s *Server) delConnection(c *serverConn) { @@ -226,20 +244,17 @@ func (s *Server) countConnection() int { return len(s.connections) } -func (s *Server) closeIdleConns() bool { +func (s *Server) closeIdleConns() { s.mu.Lock() defer s.mu.Unlock() - quiescent := true + for c := range s.connections { - st, ok := c.getState() - if !ok || st != connStateIdle { - quiescent = false + if st, ok := c.getState(); !ok || st == connStateActive { continue } c.close() delete(s.connections, c) } - return quiescent } type connState int @@ -263,7 +278,7 @@ func (cs connState) String() string { } } -func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn { +func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) { c := &serverConn{ server: s, conn: conn, @@ -271,8 +286,11 @@ func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn { shutdown: make(chan struct{}), } c.setState(connStateIdle) - s.addConnection(c) - return c + if err := s.addConnection(c); err != nil { + c.close() + return nil, err + } + return c, nil } type serverConn struct { diff --git a/server_test.go b/server_test.go index 77ba4c9..ee8570c 100644 --- a/server_test.go +++ b/server_test.go @@ -201,20 +201,18 @@ func TestServerListenerClosed(t *testing.T) { func TestServerShutdown(t *testing.T) { const ncalls = 5 var ( - ctx = context.Background() - server = mustServer(t)(NewServer()) - addr, listener = newTestListener(t) - shutdownStarted = make(chan struct{}) - shutdownFinished = make(chan struct{}) - handlersStarted = make(chan struct{}) - handlersStartedCloseOnce sync.Once - proceed = make(chan struct{}) - serveErrs = make(chan error, 1) - callwg sync.WaitGroup - callErrs = make(chan error, ncalls) - shutdownErrs = make(chan error, 1) - client, cleanup = newTestClient(t, addr) - _, cleanup2 = newTestClient(t, addr) // secondary connection + ctx = context.Background() + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + shutdownStarted = make(chan struct{}) + shutdownFinished = make(chan struct{}) + handlersStarted sync.WaitGroup + proceed = make(chan struct{}) + serveErrs = make(chan error, 1) + callErrs = make(chan error, ncalls) + shutdownErrs = make(chan error, 1) + client, cleanup = newTestClient(t, addr) + _, cleanup2 = newTestClient(t, addr) // secondary connection ) defer cleanup() defer cleanup2() @@ -227,7 +225,7 @@ func TestServerShutdown(t *testing.T) { return nil, err } - handlersStartedCloseOnce.Do(func() { close(handlersStarted) }) + handlersStarted.Done() <-proceed return &internal.TestPayload{Foo: "waited"}, nil }, @@ -238,20 +236,18 @@ func TestServerShutdown(t *testing.T) { }() // send a series of requests that will get blocked - for i := 0; i < 5; i++ { - callwg.Add(1) + for i := 0; i < ncalls; i++ { + handlersStarted.Add(1) go func(i int) { - callwg.Done() tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)} callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp) }(i) } - <-handlersStarted + handlersStarted.Wait() go func() { close(shutdownStarted) shutdownErrs <- server.Shutdown(ctx) - // server.Close() close(shutdownFinished) }()