Merge pull request #130 from austinvazquez/fix-server-shutdown
Fix server shutdown logic
This commit is contained in:
commit
7e006e71c5
46
server.go
46
server.go
@ -121,12 +121,18 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error {
|
||||
|
||||
approved, handshake, err := handshaker.Handshake(ctx, conn)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake")
|
||||
logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
sc, err := s.newConn(approved, handshake)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("ttrpc: create connection failed")
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
sc := s.newConn(approved, handshake)
|
||||
go sc.run(ctx)
|
||||
}
|
||||
}
|
||||
@ -145,15 +151,20 @@ func (s *Server) Shutdown(ctx context.Context) error {
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
if s.closeIdleConns() {
|
||||
return lnerr
|
||||
s.closeIdleConns()
|
||||
|
||||
if s.countConnection() == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
|
||||
return lnerr
|
||||
}
|
||||
|
||||
// Close the server without waiting for active connections.
|
||||
@ -205,11 +216,18 @@ func (s *Server) closeListeners() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) addConnection(c *serverConn) {
|
||||
func (s *Server) addConnection(c *serverConn) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-s.done:
|
||||
return ErrServerClosed
|
||||
default:
|
||||
}
|
||||
|
||||
s.connections[c] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) delConnection(c *serverConn) {
|
||||
@ -226,20 +244,17 @@ func (s *Server) countConnection() int {
|
||||
return len(s.connections)
|
||||
}
|
||||
|
||||
func (s *Server) closeIdleConns() bool {
|
||||
func (s *Server) closeIdleConns() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
quiescent := true
|
||||
|
||||
for c := range s.connections {
|
||||
st, ok := c.getState()
|
||||
if !ok || st != connStateIdle {
|
||||
quiescent = false
|
||||
if st, ok := c.getState(); !ok || st == connStateActive {
|
||||
continue
|
||||
}
|
||||
c.close()
|
||||
delete(s.connections, c)
|
||||
}
|
||||
return quiescent
|
||||
}
|
||||
|
||||
type connState int
|
||||
@ -263,7 +278,7 @@ func (cs connState) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
|
||||
func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
|
||||
c := &serverConn{
|
||||
server: s,
|
||||
conn: conn,
|
||||
@ -271,8 +286,11 @@ func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
c.setState(connStateIdle)
|
||||
s.addConnection(c)
|
||||
return c
|
||||
if err := s.addConnection(c); err != nil {
|
||||
c.close()
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
|
@ -201,20 +201,18 @@ func TestServerListenerClosed(t *testing.T) {
|
||||
func TestServerShutdown(t *testing.T) {
|
||||
const ncalls = 5
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
handlersStarted = make(chan struct{})
|
||||
handlersStartedCloseOnce sync.Once
|
||||
proceed = make(chan struct{})
|
||||
serveErrs = make(chan error, 1)
|
||||
callwg sync.WaitGroup
|
||||
callErrs = make(chan error, ncalls)
|
||||
shutdownErrs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||
ctx = context.Background()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
handlersStarted sync.WaitGroup
|
||||
proceed = make(chan struct{})
|
||||
serveErrs = make(chan error, 1)
|
||||
callErrs = make(chan error, ncalls)
|
||||
shutdownErrs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||
)
|
||||
defer cleanup()
|
||||
defer cleanup2()
|
||||
@ -227,7 +225,7 @@ func TestServerShutdown(t *testing.T) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handlersStartedCloseOnce.Do(func() { close(handlersStarted) })
|
||||
handlersStarted.Done()
|
||||
<-proceed
|
||||
return &internal.TestPayload{Foo: "waited"}, nil
|
||||
},
|
||||
@ -238,20 +236,18 @@ func TestServerShutdown(t *testing.T) {
|
||||
}()
|
||||
|
||||
// send a series of requests that will get blocked
|
||||
for i := 0; i < 5; i++ {
|
||||
callwg.Add(1)
|
||||
for i := 0; i < ncalls; i++ {
|
||||
handlersStarted.Add(1)
|
||||
go func(i int) {
|
||||
callwg.Done()
|
||||
tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
|
||||
callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
|
||||
}(i)
|
||||
}
|
||||
|
||||
<-handlersStarted
|
||||
handlersStarted.Wait()
|
||||
go func() {
|
||||
close(shutdownStarted)
|
||||
shutdownErrs <- server.Shutdown(ctx)
|
||||
// server.Close()
|
||||
close(shutdownFinished)
|
||||
}()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user