Merge pull request #130 from austinvazquez/fix-server-shutdown
Fix server shutdown logic
This commit is contained in:
commit
7e006e71c5
46
server.go
46
server.go
@ -121,12 +121,18 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error {
|
|||||||
|
|
||||||
approved, handshake, err := handshaker.Handshake(ctx, conn)
|
approved, handshake, err := handshaker.Handshake(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake")
|
logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
|
||||||
|
conn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sc, err := s.newConn(approved, handshake)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("ttrpc: create connection failed")
|
||||||
conn.Close()
|
conn.Close()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
sc := s.newConn(approved, handshake)
|
|
||||||
go sc.run(ctx)
|
go sc.run(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -145,15 +151,20 @@ func (s *Server) Shutdown(ctx context.Context) error {
|
|||||||
ticker := time.NewTicker(200 * time.Millisecond)
|
ticker := time.NewTicker(200 * time.Millisecond)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for {
|
for {
|
||||||
if s.closeIdleConns() {
|
s.closeIdleConns()
|
||||||
return lnerr
|
|
||||||
|
if s.countConnection() == 0 {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return lnerr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the server without waiting for active connections.
|
// Close the server without waiting for active connections.
|
||||||
@ -205,11 +216,18 @@ func (s *Server) closeListeners() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) addConnection(c *serverConn) {
|
func (s *Server) addConnection(c *serverConn) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
return ErrServerClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
s.connections[c] = struct{}{}
|
s.connections[c] = struct{}{}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) delConnection(c *serverConn) {
|
func (s *Server) delConnection(c *serverConn) {
|
||||||
@ -226,20 +244,17 @@ func (s *Server) countConnection() int {
|
|||||||
return len(s.connections)
|
return len(s.connections)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) closeIdleConns() bool {
|
func (s *Server) closeIdleConns() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
quiescent := true
|
|
||||||
for c := range s.connections {
|
for c := range s.connections {
|
||||||
st, ok := c.getState()
|
if st, ok := c.getState(); !ok || st == connStateActive {
|
||||||
if !ok || st != connStateIdle {
|
|
||||||
quiescent = false
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
c.close()
|
c.close()
|
||||||
delete(s.connections, c)
|
delete(s.connections, c)
|
||||||
}
|
}
|
||||||
return quiescent
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type connState int
|
type connState int
|
||||||
@ -263,7 +278,7 @@ func (cs connState) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
|
func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
|
||||||
c := &serverConn{
|
c := &serverConn{
|
||||||
server: s,
|
server: s,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
@ -271,8 +286,11 @@ func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
|
|||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
}
|
}
|
||||||
c.setState(connStateIdle)
|
c.setState(connStateIdle)
|
||||||
s.addConnection(c)
|
if err := s.addConnection(c); err != nil {
|
||||||
return c
|
c.close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type serverConn struct {
|
type serverConn struct {
|
||||||
|
@ -201,20 +201,18 @@ func TestServerListenerClosed(t *testing.T) {
|
|||||||
func TestServerShutdown(t *testing.T) {
|
func TestServerShutdown(t *testing.T) {
|
||||||
const ncalls = 5
|
const ncalls = 5
|
||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
server = mustServer(t)(NewServer())
|
server = mustServer(t)(NewServer())
|
||||||
addr, listener = newTestListener(t)
|
addr, listener = newTestListener(t)
|
||||||
shutdownStarted = make(chan struct{})
|
shutdownStarted = make(chan struct{})
|
||||||
shutdownFinished = make(chan struct{})
|
shutdownFinished = make(chan struct{})
|
||||||
handlersStarted = make(chan struct{})
|
handlersStarted sync.WaitGroup
|
||||||
handlersStartedCloseOnce sync.Once
|
proceed = make(chan struct{})
|
||||||
proceed = make(chan struct{})
|
serveErrs = make(chan error, 1)
|
||||||
serveErrs = make(chan error, 1)
|
callErrs = make(chan error, ncalls)
|
||||||
callwg sync.WaitGroup
|
shutdownErrs = make(chan error, 1)
|
||||||
callErrs = make(chan error, ncalls)
|
client, cleanup = newTestClient(t, addr)
|
||||||
shutdownErrs = make(chan error, 1)
|
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||||
client, cleanup = newTestClient(t, addr)
|
|
||||||
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
|
||||||
)
|
)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
defer cleanup2()
|
defer cleanup2()
|
||||||
@ -227,7 +225,7 @@ func TestServerShutdown(t *testing.T) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
handlersStartedCloseOnce.Do(func() { close(handlersStarted) })
|
handlersStarted.Done()
|
||||||
<-proceed
|
<-proceed
|
||||||
return &internal.TestPayload{Foo: "waited"}, nil
|
return &internal.TestPayload{Foo: "waited"}, nil
|
||||||
},
|
},
|
||||||
@ -238,20 +236,18 @@ func TestServerShutdown(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// send a series of requests that will get blocked
|
// send a series of requests that will get blocked
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < ncalls; i++ {
|
||||||
callwg.Add(1)
|
handlersStarted.Add(1)
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
callwg.Done()
|
|
||||||
tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
|
tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
|
||||||
callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
|
callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
<-handlersStarted
|
handlersStarted.Wait()
|
||||||
go func() {
|
go func() {
|
||||||
close(shutdownStarted)
|
close(shutdownStarted)
|
||||||
shutdownErrs <- server.Shutdown(ctx)
|
shutdownErrs <- server.Shutdown(ctx)
|
||||||
// server.Close()
|
|
||||||
close(shutdownFinished)
|
close(shutdownFinished)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user