Fix server shutdown logic
Simplify the close-idle-connections logic in server shutdown to be more intuitive. Modify the add-connection logic to check whether the server has been shut down before adding any new connections. Modify the test to make all calls before server shutdown. Signed-off-by: Austin Vazquez <macedonv@amazon.com>
This commit is contained in:
parent
32fab23746
commit
19445fddca
46
server.go
46
server.go
@ -121,12 +121,18 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error {
|
||||
|
||||
approved, handshake, err := handshaker.Handshake(ctx, conn)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake")
|
||||
logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
sc, err := s.newConn(approved, handshake)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("ttrpc: create connection failed")
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
sc := s.newConn(approved, handshake)
|
||||
go sc.run(ctx)
|
||||
}
|
||||
}
|
||||
@ -145,15 +151,20 @@ func (s *Server) Shutdown(ctx context.Context) error {
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
if s.closeIdleConns() {
|
||||
return lnerr
|
||||
s.closeIdleConns()
|
||||
|
||||
if s.countConnection() == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
|
||||
return lnerr
|
||||
}
|
||||
|
||||
// Close the server without waiting for active connections.
|
||||
@ -205,11 +216,18 @@ func (s *Server) closeListeners() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) addConnection(c *serverConn) {
|
||||
func (s *Server) addConnection(c *serverConn) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-s.done:
|
||||
return ErrServerClosed
|
||||
default:
|
||||
}
|
||||
|
||||
s.connections[c] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) delConnection(c *serverConn) {
|
||||
@ -226,20 +244,17 @@ func (s *Server) countConnection() int {
|
||||
return len(s.connections)
|
||||
}
|
||||
|
||||
func (s *Server) closeIdleConns() bool {
|
||||
func (s *Server) closeIdleConns() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
quiescent := true
|
||||
|
||||
for c := range s.connections {
|
||||
st, ok := c.getState()
|
||||
if !ok || st != connStateIdle {
|
||||
quiescent = false
|
||||
if st, ok := c.getState(); !ok || st == connStateActive {
|
||||
continue
|
||||
}
|
||||
c.close()
|
||||
delete(s.connections, c)
|
||||
}
|
||||
return quiescent
|
||||
}
|
||||
|
||||
type connState int
|
||||
@ -263,7 +278,7 @@ func (cs connState) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
|
||||
func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
|
||||
c := &serverConn{
|
||||
server: s,
|
||||
conn: conn,
|
||||
@ -271,8 +286,11 @@ func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
c.setState(connStateIdle)
|
||||
s.addConnection(c)
|
||||
return c
|
||||
if err := s.addConnection(c); err != nil {
|
||||
c.close()
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
|
@ -201,20 +201,18 @@ func TestServerListenerClosed(t *testing.T) {
|
||||
func TestServerShutdown(t *testing.T) {
|
||||
const ncalls = 5
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
handlersStarted = make(chan struct{})
|
||||
handlersStartedCloseOnce sync.Once
|
||||
proceed = make(chan struct{})
|
||||
serveErrs = make(chan error, 1)
|
||||
callwg sync.WaitGroup
|
||||
callErrs = make(chan error, ncalls)
|
||||
shutdownErrs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||
ctx = context.Background()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
handlersStarted sync.WaitGroup
|
||||
proceed = make(chan struct{})
|
||||
serveErrs = make(chan error, 1)
|
||||
callErrs = make(chan error, ncalls)
|
||||
shutdownErrs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||
)
|
||||
defer cleanup()
|
||||
defer cleanup2()
|
||||
@ -227,7 +225,7 @@ func TestServerShutdown(t *testing.T) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handlersStartedCloseOnce.Do(func() { close(handlersStarted) })
|
||||
handlersStarted.Done()
|
||||
<-proceed
|
||||
return &internal.TestPayload{Foo: "waited"}, nil
|
||||
},
|
||||
@ -238,20 +236,18 @@ func TestServerShutdown(t *testing.T) {
|
||||
}()
|
||||
|
||||
// send a series of requests that will get blocked
|
||||
for i := 0; i < 5; i++ {
|
||||
callwg.Add(1)
|
||||
for i := 0; i < ncalls; i++ {
|
||||
handlersStarted.Add(1)
|
||||
go func(i int) {
|
||||
callwg.Done()
|
||||
tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
|
||||
callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
|
||||
}(i)
|
||||
}
|
||||
|
||||
<-handlersStarted
|
||||
handlersStarted.Wait()
|
||||
go func() {
|
||||
close(shutdownStarted)
|
||||
shutdownErrs <- server.Shutdown(ctx)
|
||||
// server.Close()
|
||||
close(shutdownFinished)
|
||||
}()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user