Merge pull request #130 from austinvazquez/fix-server-shutdown

Fix server shutdown logic
This commit is contained in:
Derek McGowan 2023-03-08 08:17:47 -08:00 committed by GitHub
commit 7e006e71c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 34 deletions

View File

@ -121,12 +121,18 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error {
approved, handshake, err := handshaker.Handshake(ctx, conn) approved, handshake, err := handshaker.Handshake(ctx, conn)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake") logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
conn.Close()
continue
}
sc, err := s.newConn(approved, handshake)
if err != nil {
logrus.WithError(err).Error("ttrpc: create connection failed")
conn.Close() conn.Close()
continue continue
} }
sc := s.newConn(approved, handshake)
go sc.run(ctx) go sc.run(ctx)
} }
} }
@ -145,15 +151,20 @@ func (s *Server) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(200 * time.Millisecond) ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
for { for {
if s.closeIdleConns() { s.closeIdleConns()
return lnerr
if s.countConnection() == 0 {
break
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-ticker.C: case <-ticker.C:
} }
} }
return lnerr
} }
// Close the server without waiting for active connections. // Close the server without waiting for active connections.
@ -205,11 +216,18 @@ func (s *Server) closeListeners() error {
return err return err
} }
func (s *Server) addConnection(c *serverConn) { func (s *Server) addConnection(c *serverConn) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
select {
case <-s.done:
return ErrServerClosed
default:
}
s.connections[c] = struct{}{} s.connections[c] = struct{}{}
return nil
} }
func (s *Server) delConnection(c *serverConn) { func (s *Server) delConnection(c *serverConn) {
@ -226,20 +244,17 @@ func (s *Server) countConnection() int {
return len(s.connections) return len(s.connections)
} }
func (s *Server) closeIdleConns() bool { func (s *Server) closeIdleConns() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
quiescent := true
for c := range s.connections { for c := range s.connections {
st, ok := c.getState() if st, ok := c.getState(); !ok || st == connStateActive {
if !ok || st != connStateIdle {
quiescent = false
continue continue
} }
c.close() c.close()
delete(s.connections, c) delete(s.connections, c)
} }
return quiescent
} }
type connState int type connState int
@ -263,7 +278,7 @@ func (cs connState) String() string {
} }
} }
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn { func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
c := &serverConn{ c := &serverConn{
server: s, server: s,
conn: conn, conn: conn,
@ -271,8 +286,11 @@ func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
} }
c.setState(connStateIdle) c.setState(connStateIdle)
s.addConnection(c) if err := s.addConnection(c); err != nil {
return c c.close()
return nil, err
}
return c, nil
} }
type serverConn struct { type serverConn struct {

View File

@ -201,20 +201,18 @@ func TestServerListenerClosed(t *testing.T) {
func TestServerShutdown(t *testing.T) { func TestServerShutdown(t *testing.T) {
const ncalls = 5 const ncalls = 5
var ( var (
ctx = context.Background() ctx = context.Background()
server = mustServer(t)(NewServer()) server = mustServer(t)(NewServer())
addr, listener = newTestListener(t) addr, listener = newTestListener(t)
shutdownStarted = make(chan struct{}) shutdownStarted = make(chan struct{})
shutdownFinished = make(chan struct{}) shutdownFinished = make(chan struct{})
handlersStarted = make(chan struct{}) handlersStarted sync.WaitGroup
handlersStartedCloseOnce sync.Once proceed = make(chan struct{})
proceed = make(chan struct{}) serveErrs = make(chan error, 1)
serveErrs = make(chan error, 1) callErrs = make(chan error, ncalls)
callwg sync.WaitGroup shutdownErrs = make(chan error, 1)
callErrs = make(chan error, ncalls) client, cleanup = newTestClient(t, addr)
shutdownErrs = make(chan error, 1) _, cleanup2 = newTestClient(t, addr) // secondary connection
client, cleanup = newTestClient(t, addr)
_, cleanup2 = newTestClient(t, addr) // secondary connection
) )
defer cleanup() defer cleanup()
defer cleanup2() defer cleanup2()
@ -227,7 +225,7 @@ func TestServerShutdown(t *testing.T) {
return nil, err return nil, err
} }
handlersStartedCloseOnce.Do(func() { close(handlersStarted) }) handlersStarted.Done()
<-proceed <-proceed
return &internal.TestPayload{Foo: "waited"}, nil return &internal.TestPayload{Foo: "waited"}, nil
}, },
@ -238,20 +236,18 @@ func TestServerShutdown(t *testing.T) {
}() }()
// send a series of requests that will get blocked // send a series of requests that will get blocked
for i := 0; i < 5; i++ { for i := 0; i < ncalls; i++ {
callwg.Add(1) handlersStarted.Add(1)
go func(i int) { go func(i int) {
callwg.Done()
tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)} tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp) callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
}(i) }(i)
} }
<-handlersStarted handlersStarted.Wait()
go func() { go func() {
close(shutdownStarted) close(shutdownStarted)
shutdownErrs <- server.Shutdown(ctx) shutdownErrs <- server.Shutdown(ctx)
// server.Close()
close(shutdownFinished) close(shutdownFinished)
}() }()