Fix server shutdown logic

Simplify close idle connections logic in server shutdown to be more
intuitive. Modify add connection logic to check if server has been
shutdown before adding any new connections. Modify test to make all
calls before server shutdown.

Signed-off-by: Austin Vazquez <macedonv@amazon.com>
This commit is contained in:
Austin Vazquez 2023-02-24 22:42:07 +00:00
parent 32fab23746
commit 19445fddca
2 changed files with 48 additions and 34 deletions

View File

@ -121,12 +121,18 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error {
approved, handshake, err := handshaker.Handshake(ctx, conn) approved, handshake, err := handshaker.Handshake(ctx, conn)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake") logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
conn.Close()
continue
}
sc, err := s.newConn(approved, handshake)
if err != nil {
logrus.WithError(err).Error("ttrpc: create connection failed")
conn.Close() conn.Close()
continue continue
} }
sc := s.newConn(approved, handshake)
go sc.run(ctx) go sc.run(ctx)
} }
} }
@ -145,15 +151,20 @@ func (s *Server) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(200 * time.Millisecond) ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
for { for {
if s.closeIdleConns() { s.closeIdleConns()
return lnerr
if s.countConnection() == 0 {
break
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-ticker.C: case <-ticker.C:
} }
} }
return lnerr
} }
// Close the server without waiting for active connections. // Close the server without waiting for active connections.
@ -205,11 +216,18 @@ func (s *Server) closeListeners() error {
return err return err
} }
func (s *Server) addConnection(c *serverConn) { func (s *Server) addConnection(c *serverConn) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
select {
case <-s.done:
return ErrServerClosed
default:
}
s.connections[c] = struct{}{} s.connections[c] = struct{}{}
return nil
} }
func (s *Server) delConnection(c *serverConn) { func (s *Server) delConnection(c *serverConn) {
@ -226,20 +244,17 @@ func (s *Server) countConnection() int {
return len(s.connections) return len(s.connections)
} }
func (s *Server) closeIdleConns() bool { func (s *Server) closeIdleConns() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
quiescent := true
for c := range s.connections { for c := range s.connections {
st, ok := c.getState() if st, ok := c.getState(); !ok || st == connStateActive {
if !ok || st != connStateIdle {
quiescent = false
continue continue
} }
c.close() c.close()
delete(s.connections, c) delete(s.connections, c)
} }
return quiescent
} }
type connState int type connState int
@ -263,7 +278,7 @@ func (cs connState) String() string {
} }
} }
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn { func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
c := &serverConn{ c := &serverConn{
server: s, server: s,
conn: conn, conn: conn,
@ -271,8 +286,11 @@ func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
} }
c.setState(connStateIdle) c.setState(connStateIdle)
s.addConnection(c) if err := s.addConnection(c); err != nil {
return c c.close()
return nil, err
}
return c, nil
} }
type serverConn struct { type serverConn struct {

View File

@ -201,20 +201,18 @@ func TestServerListenerClosed(t *testing.T) {
func TestServerShutdown(t *testing.T) { func TestServerShutdown(t *testing.T) {
const ncalls = 5 const ncalls = 5
var ( var (
ctx = context.Background() ctx = context.Background()
server = mustServer(t)(NewServer()) server = mustServer(t)(NewServer())
addr, listener = newTestListener(t) addr, listener = newTestListener(t)
shutdownStarted = make(chan struct{}) shutdownStarted = make(chan struct{})
shutdownFinished = make(chan struct{}) shutdownFinished = make(chan struct{})
handlersStarted = make(chan struct{}) handlersStarted sync.WaitGroup
handlersStartedCloseOnce sync.Once proceed = make(chan struct{})
proceed = make(chan struct{}) serveErrs = make(chan error, 1)
serveErrs = make(chan error, 1) callErrs = make(chan error, ncalls)
callwg sync.WaitGroup shutdownErrs = make(chan error, 1)
callErrs = make(chan error, ncalls) client, cleanup = newTestClient(t, addr)
shutdownErrs = make(chan error, 1) _, cleanup2 = newTestClient(t, addr) // secondary connection
client, cleanup = newTestClient(t, addr)
_, cleanup2 = newTestClient(t, addr) // secondary connection
) )
defer cleanup() defer cleanup()
defer cleanup2() defer cleanup2()
@ -227,7 +225,7 @@ func TestServerShutdown(t *testing.T) {
return nil, err return nil, err
} }
handlersStartedCloseOnce.Do(func() { close(handlersStarted) }) handlersStarted.Done()
<-proceed <-proceed
return &internal.TestPayload{Foo: "waited"}, nil return &internal.TestPayload{Foo: "waited"}, nil
}, },
@ -238,20 +236,18 @@ func TestServerShutdown(t *testing.T) {
}() }()
// send a series of requests that will get blocked // send a series of requests that will get blocked
for i := 0; i < 5; i++ { for i := 0; i < ncalls; i++ {
callwg.Add(1) handlersStarted.Add(1)
go func(i int) { go func(i int) {
callwg.Done()
tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)} tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp) callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
}(i) }(i)
} }
<-handlersStarted handlersStarted.Wait()
go func() { go func() {
close(shutdownStarted) close(shutdownStarted)
shutdownErrs <- server.Shutdown(ctx) shutdownErrs <- server.Shutdown(ctx)
// server.Close()
close(shutdownFinished) close(shutdownFinished)
}() }()