Merge pull request #4 from stevvooe/accept-loop-hardening

ttrpc: implement Close and Shutdown
Stephen Day 2017-11-29 11:13:22 -08:00 committed by GitHub
commit 7e38ac9e99
3 changed files with 460 additions and 61 deletions

client.go

@@ -40,22 +40,30 @@ func NewClient(conn net.Conn) *Client {
}
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
requestID := atomic.AddUint32(&c.requestID, 2)
if err := c.sendRequest(ctx, requestID, service, method, req); err != nil {
return err
}
return c.recvResponse(ctx, requestID, resp)
}
func (c *Client) sendRequest(ctx context.Context, requestID uint32, service, method string, req interface{}) error {
payload, err := c.codec.Marshal(req)
if err != nil {
return err
}
requestID := atomic.AddUint32(&c.requestID, 2)
request := Request{
Service: service,
Method: method,
Payload: payload,
}
if err := c.send(ctx, requestID, &request); err != nil {
return err
}
return c.send(ctx, requestID, &request)
}
func (c *Client) recvResponse(ctx context.Context, requestID uint32, resp interface{}) error {
var response Response
if err := c.recv(ctx, requestID, &response); err != nil {
return err
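Note: the refactor above splits Call into sendRequest and recvResponse so a request can be sent and its response awaited separately (TestServerShutdown below depends on exactly that). Stream IDs advance by two on every call, so client-initiated IDs stay odd, the parity the server enforces. A minimal standalone sketch of the allocation scheme; the seed value here is an assumption, since the Client's initializer is not part of this diff:

package main

import (
	"fmt"
	"sync/atomic"
)

// idAllocator hands out odd stream IDs: 1, 3, 5, ...
// Seeding with ^uint32(0) makes the first AddUint32(+2) wrap to 1.
type idAllocator struct{ last uint32 }

func newIDAllocator() *idAllocator {
	return &idAllocator{last: ^uint32(0)}
}

func (a *idAllocator) next() uint32 {
	return atomic.AddUint32(&a.last, 2)
}

func main() {
	ids := newIDAllocator()
	for i := 0; i < 3; i++ {
		id := ids.next()
		fmt.Println(id, "odd:", id%2 == 1) // 1 true, 3 true, 5 true
	}
}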
@@ -160,6 +168,10 @@ func (c *Client) run() {
// start one more goroutine to recv messages without blocking.
for {
var p [messageLengthMax]byte
// TODO(stevvooe): Something still isn't quite right with error
// handling on the client-side, causing EOFs to come through. We
// need other fixes in this changeset, so we'll address this
// correctly later.
mh, err := c.channel.recv(context.TODO(), p[:])
select {
case incoming <- received{
@@ -187,13 +199,12 @@ func (c *Client) run() {
}
waiters[req.id] = req
case r := <-incoming:
if r.err != nil {
c.err = r.err
return
}
if waiter, ok := waiters[r.mh.StreamID]; ok {
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
if r.err != nil {
waiter.err <- r.err
} else {
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
}
} else {
queued[r.mh.StreamID] = r
}
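Note: the change in this hunk routes a receive error to the waiter registered for that stream instead of tearing down the whole run loop, while queued parks responses that arrive before their caller registers. A self-contained sketch of the demultiplexer pattern; all names are illustrative, not ttrpc's:

package main

import (
	"errors"
	"fmt"
)

type result struct {
	id  uint32
	err error
}

type waiter struct {
	id uint32
	ch chan error
}

// demux delivers each incoming result to the waiter for its stream ID,
// parking early arrivals in queued until the matching waiter registers.
func demux(calls <-chan waiter, incoming <-chan result, done <-chan struct{}) {
	waiters := map[uint32]waiter{}
	queued := map[uint32]result{}
	for {
		select {
		case w := <-calls:
			if r, ok := queued[w.id]; ok { // response arrived first
				w.ch <- r.err
				delete(queued, w.id)
				continue
			}
			waiters[w.id] = w
		case r := <-incoming:
			if w, ok := waiters[r.id]; ok {
				w.ch <- r.err // per-stream delivery, even for errors
				delete(waiters, r.id)
			} else {
				queued[r.id] = r
			}
		case <-done:
			return
		}
	}
}

func main() {
	calls := make(chan waiter)
	incoming := make(chan result)
	done := make(chan struct{})
	defer close(done)
	go demux(calls, incoming, done)

	w := waiter{id: 1, ch: make(chan error, 1)}
	calls <- w
	incoming <- result{id: 1, err: errors.New("recv failed")}
	fmt.Println("stream 1:", <-w.ch) // only this caller sees the error
}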

server.go

@@ -2,21 +2,38 @@ package ttrpc
import (
"context"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/containerd/containerd/log"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
ErrServerClosed = errors.New("ttrpc: server close")
)
type Server struct {
services *serviceSet
codec codec
mu sync.Mutex
listeners map[net.Listener]struct{}
connections map[*serverConn]struct{} // all connections to current state
done chan struct{} // marks point at which we stop serving requests
}
func NewServer() *Server {
return &Server{
services: newServiceSet(),
done: make(chan struct{}),
listeners: make(map[net.Listener]struct{}),
connections: make(map[*serverConn]struct{}),
}
}
@@ -24,28 +41,210 @@ func (s *Server) Register(name string, methods map[string]Method) {
s.services.register(name, methods)
}
func (s *Server) Shutdown(ctx context.Context) error {
// TODO(stevvooe): Wait on connection shutdown.
return nil
}
func (s *Server) Serve(l net.Listener) error {
s.addListener(l)
defer s.closeListener(l)
var (
ctx = context.Background()
backoff time.Duration
)
for {
conn, err := l.Accept()
if err != nil {
log.L.WithError(err).Error("failed accept")
continue
select {
case <-s.done:
return ErrServerClosed
default:
}
if terr, ok := err.(interface {
Temporary() bool
}); ok && terr.Temporary() {
if backoff == 0 {
backoff = time.Millisecond
} else {
backoff *= 2
}
if max := time.Second; backoff > max {
backoff = max
}
sleep := time.Duration(rand.Int63n(int64(backoff)))
log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
time.Sleep(sleep)
continue
}
return err
}
go s.handleConn(conn)
backoff = 0
sc := s.newConn(conn)
go sc.run(ctx)
}
}
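Note: the accept loop now distinguishes three outcomes: a closed s.done yields ErrServerClosed, a temporary error (EMFILE and friends) triggers exponential backoff capped at one second with random jitter, and anything else aborts Serve. The backoff arithmetic, extracted into a standalone sketch for clarity:

package main

import (
	"fmt"
	"math/rand"
	"time"
)

// nextBackoff doubles the delay after each temporary accept failure,
// caps it at one second, and picks a jittered sleep in [0, backoff).
func nextBackoff(backoff time.Duration) (next, sleep time.Duration) {
	if backoff == 0 {
		backoff = time.Millisecond
	} else {
		backoff *= 2
	}
	if max := time.Second; backoff > max {
		backoff = max
	}
	return backoff, time.Duration(rand.Int63n(int64(backoff)))
}

func main() {
	var backoff, sleep time.Duration
	for i := 0; i < 12; i++ {
		backoff, sleep = nextBackoff(backoff)
		fmt.Printf("failure %2d: backoff=%-6v sleep=%v\n", i+1, backoff, sleep)
	}
}

A successful accept resets the delay, as the loop above does with backoff = 0, so one bad stretch doesn't slow the server permanently.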
func (s *Server) Shutdown(ctx context.Context) error {
s.mu.Lock()
lnerr := s.closeListeners()
select {
case <-s.done:
default:
// protected by mutex
close(s.done)
}
s.mu.Unlock()
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
if s.closeIdleConns() {
return lnerr
}
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
return lnerr
}
// Close the server without waiting for active connections.
func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
select {
case <-s.done:
default:
// protected by mutex
close(s.done)
}
err := s.closeListeners()
for c := range s.connections {
c.close()
delete(s.connections, c)
}
return err
}
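Note: Shutdown and Close mirror net/http's split: Shutdown closes the listeners, then polls closeIdleConns every 200ms until every connection is idle or the caller's context expires, while Close also tears down active connections immediately. A hedged usage sketch; only the Shutdown and Close signatures are taken from the diff, the rest is assumed wiring (hypothetical package and names):

package ttrpcutil

import (
	"context"
	"time"
)

// stoppable matches the two methods introduced above.
type stoppable interface {
	Shutdown(ctx context.Context) error
	Close() error
}

// stop attempts a graceful Shutdown with a deadline and falls back to
// Close if active connections outlive it.
func stop(server stoppable) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := server.Shutdown(ctx); err != nil {
		// deadline hit while requests were still active
		return server.Close()
	}
	return nil
}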
func (s *Server) addListener(l net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.listeners[l] = struct{}{}
}
func (s *Server) closeListener(l net.Listener) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.closeListenerLocked(l)
}
func (s *Server) closeListenerLocked(l net.Listener) error {
defer delete(s.listeners, l)
return l.Close()
}
func (s *Server) closeListeners() error {
var err error
for l := range s.listeners {
if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
err = cerr
}
}
return err
}
func (s *Server) addConnection(c *serverConn) {
s.mu.Lock()
defer s.mu.Unlock()
s.connections[c] = struct{}{}
}
func (s *Server) closeIdleConns() bool {
s.mu.Lock()
defer s.mu.Unlock()
quiescent := true
for c := range s.connections {
st, ok := c.getState()
if !ok || st != connStateIdle {
quiescent = false
continue
}
c.close()
delete(s.connections, c)
}
return quiescent
}
type connState int
const (
connStateActive = iota + 1 // outstanding requests
connStateIdle // no requests
connStateClosed // closed connection
)
func (cs connState) String() string {
switch cs {
case connStateActive:
return "active"
case connStateIdle:
return "idle"
case connStateClosed:
return "closed"
default:
return "unknown"
}
}
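Note: each connection publishes its state through an atomic.Value (see serverConn below) so that Shutdown's closeIdleConns can poll all connections without per-connection locks; only connections reporting connStateIdle get reaped, which is what makes Shutdown graceful. A minimal sketch of that publish/poll pattern:

package main

import (
	"fmt"
	"sync/atomic"
)

type connState int

const (
	connStateActive connState = iota + 1 // outstanding requests
	connStateIdle                        // no requests
	connStateClosed                      // closed connection
)

type conn struct{ state atomic.Value }

func (c *conn) setState(s connState) { c.state.Store(s) }

// getState reports false until the first Store, matching the
// two-value pattern used by serverConn.getState above.
func (c *conn) getState() (connState, bool) {
	s, ok := c.state.Load().(connState)
	return s, ok
}

func main() {
	c := &conn{}
	if _, ok := c.getState(); !ok {
		fmt.Println("no state published yet")
	}
	c.setState(connStateIdle)
	st, _ := c.getState()
	fmt.Println("reapable:", st == connStateIdle) // true
}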
func (s *Server) newConn(conn net.Conn) *serverConn {
c := &serverConn{
server: s,
conn: conn,
shutdown: make(chan struct{}),
}
c.setState(connStateIdle)
s.addConnection(c)
return c
}
type serverConn struct {
server *Server
conn net.Conn
state atomic.Value
shutdownOnce sync.Once
shutdown chan struct{} // forced shutdown, used by close
}
func (c *serverConn) getState() (connState, bool) {
cs, ok := c.state.Load().(connState)
return cs, ok
}
func (c *serverConn) setState(newstate connState) {
c.state.Store(newstate)
}
func (c *serverConn) close() error {
c.shutdownOnce.Do(func() {
close(c.shutdown)
})
return nil
}
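Note: close may race: Server.Close walks every connection while the connection's own run loop can also decide to shut down, and closing an already-closed channel panics. The sync.Once guard makes the signal idempotent; a tiny demonstration:

package main

import (
	"fmt"
	"sync"
)

type serverConn struct {
	shutdownOnce sync.Once
	shutdown     chan struct{}
}

// close is safe to call any number of times from any goroutine;
// without the Once, a second close(c.shutdown) would panic.
func (c *serverConn) close() error {
	c.shutdownOnce.Do(func() { close(c.shutdown) })
	return nil
}

func main() {
	c := &serverConn{shutdown: make(chan struct{})}
	c.close()
	c.close() // no-op instead of a double-close panic
	<-c.shutdown
	fmt.Println("shutdown signalled")
}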
func (s *Server) handleConn(conn net.Conn) {
defer conn.Close()
func (c *serverConn) run(sctx context.Context) {
type (
request struct {
id uint32
@@ -59,21 +258,33 @@ func (s *Server) handleConn(conn net.Conn) {
)
var (
ch = newChannel(conn, conn)
ctx, cancel = context.WithCancel(context.Background())
responses = make(chan response)
requests = make(chan request)
recvErr = make(chan error, 1)
done = make(chan struct{})
ch = newChannel(c.conn, c.conn)
ctx, cancel = context.WithCancel(sctx)
active int
state connState = connStateIdle
responses = make(chan response)
requests = make(chan request)
recvErr = make(chan error, 1)
shutdown = c.shutdown
done = make(chan struct{})
)
defer c.conn.Close()
defer cancel()
defer close(done)
go func() {
go func(recvErr chan error) {
defer close(recvErr)
var p [messageLengthMax]byte
for {
select {
case <-c.shutdown:
return
case <-done:
return
default: // proceed
}
mh, err := ch.recv(ctx, p[:])
if err != nil {
recvErr <- err
@@ -85,14 +296,7 @@ func (s *Server) handleConn(conn net.Conn) {
continue
}
var req Request
if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
recvErr <- err
return
}
if mh.StreamID%2 != 1 {
// enforce odd client initiated identifiers.
sendImmediate := func(code codes.Code, msg string, args ...interface{}) bool {
select {
case responses <- response{
// even though we've had an invalid stream id, we send it
@@ -100,30 +304,68 @@ func (s *Server) handleConn(conn net.Conn) {
// stream id was bad.
id: mh.StreamID,
resp: &Response{
Status: status.New(codes.InvalidArgument, "StreamID must be odd for client initiated streams").Proto(),
Status: status.Newf(code, msg, args...).Proto(),
},
}:
return true
case <-c.shutdown:
return false
case <-done:
return false
}
}
var req Request
if err := c.server.codec.Unmarshal(p[:mh.Length], &req); err != nil {
if !sendImmediate(codes.InvalidArgument, "unmarshal request error: %v", err) {
return
}
continue
}
if mh.StreamID%2 != 1 {
// enforce odd client initiated identifiers.
if !sendImmediate(codes.InvalidArgument, "StreamID must be odd for client initiated streams") {
return
}
continue
}
// Forward the request to the main loop. We don't wait on s.done
// because we have already accepted the client request.
select {
case requests <- request{
id: mh.StreamID,
req: &req,
}:
case <-done:
return
}
}
}()
}(recvErr)
for {
newstate := state
switch {
case active > 0:
newstate = connStateActive
shutdown = nil
case active == 0:
newstate = connStateIdle
shutdown = c.shutdown // only enable this branch in idle mode
}
if newstate != state {
c.setState(newstate)
state = newstate
}
select {
case request := <-requests:
active++
go func(id uint32) {
p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
resp := &Response{
Status: status.Proto(),
Payload: p,
@@ -138,7 +380,7 @@ func (s *Server) handleConn(conn net.Conn) {
}
}(request.id)
case response := <-responses:
p, err := s.codec.Marshal(response.resp)
p, err := c.server.codec.Marshal(response.resp)
if err != nil {
log.L.WithError(err).Error("failed marshaling response")
return
@@ -147,8 +389,17 @@ func (s *Server) handleConn(conn net.Conn) {
log.L.WithError(err).Error("failed sending message on channel")
return
}
active--
case err := <-recvErr:
log.L.WithError(err).Error("error receiving message")
// TODO(stevvooe): Not wildly clear what we should do in this
// branch. Basically, it means that we are no longer receiving
// requests due to a terminal error.
recvErr = nil // connection is now "closing"
if err != nil {
log.L.WithError(err).Error("error receiving message")
}
case <-shutdown:
return
}
}
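Note: the loop above gates shutdown on idleness with a classic Go trick: receiving from a nil channel blocks forever, so assigning shutdown = nil while requests are active disables that select case, and restoring shutdown = c.shutdown once active drops to zero re-enables it. A standalone illustration:

package main

import "fmt"

func main() {
	requested := make(chan struct{})
	close(requested) // pretend Shutdown has already been called

	var shutdown chan struct{} // nil receive blocks forever
	active := 2                // two in-flight requests

	for {
		if active == 0 {
			shutdown = requested // idle: the case below can fire
		} else {
			shutdown = nil // busy: the case below never fires
		}
		select {
		case <-shutdown:
			fmt.Println("idle, honoring shutdown")
			return
		default:
			fmt.Println("finishing request", active)
			active--
		}
	}
}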

server_test.go

@@ -7,6 +7,7 @@ import (
"reflect"
"strings"
"testing"
"time"
"github.com/gogo/protobuf/proto"
)
@@ -74,34 +75,26 @@ func init() {
func TestServer(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
testImpl = &testingServer{}
addr, listener = newTestListener(t)
client, cleanup = newTestClient(t, addr)
tclient = newTestingClient(client)
)
registerTestingService(server, testImpl)
addr := "\x00" + t.Name()
listener, err := net.Listen("unix", addr)
if err != nil {
t.Fatal(err)
}
defer listener.Close()
defer cleanup()
registerTestingService(server, testImpl)
go server.Serve(listener)
defer server.Shutdown(ctx)
conn, err := net.Dial("unix", addr)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
client := newTestingClient(NewClient(conn))
const calls = 2
results := make(chan callResult, 2)
go roundTrip(ctx, t, client, "bar", results)
go roundTrip(ctx, t, client, "baz", results)
go roundTrip(ctx, t, tclient, "bar", results)
go roundTrip(ctx, t, tclient, "baz", results)
for i := 0; i < calls; i++ {
result := <-results
@@ -111,6 +104,140 @@ func TestServer(t *testing.T) {
}
}
func newTestClient(t *testing.T, addr string) (*Client, func()) {
conn, err := net.Dial("unix", addr)
if err != nil {
t.Fatal(err)
}
client := NewClient(conn)
return client, func() {
conn.Close()
client.Close()
}
}
func TestServerListenerClosed(t *testing.T) {
var (
server = NewServer()
_, listener = newTestListener(t)
errs = make(chan error, 1)
)
go func() {
errs <- server.Serve(listener)
}()
if err := listener.Close(); err != nil {
t.Fatal(err)
}
err := <-errs
if err == nil {
t.Fatal("expected error from closed listener")
}
}
func TestServerShutdown(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
addr, listener = newTestListener(t)
shutdownStarted = make(chan struct{})
shutdownFinished = make(chan struct{})
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
_, cleanup2 = newTestClient(t, addr) // secondary connection
)
defer cleanup()
defer cleanup2()
defer server.Close()
// register a service so requests can be in flight when shutdown begins
server.Register(serviceName, map[string]Method{
"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
var req testPayload
if err := unmarshal(&req); err != nil {
return nil, err
}
return &testPayload{Foo: "waited"}, nil
},
})
go func() {
errs <- server.Serve(listener)
}()
tp := testPayload{Foo: "half"}
// send a couple of requests, leaving the responses unread for now
if err := client.sendRequest(ctx, 1, "testService", "Test", &tp); err != nil {
t.Fatal(err)
}
if err := client.sendRequest(ctx, 3, "testService", "Test", &tp); err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Millisecond) // ensure that requests make it through before shutdown
go func() {
close(shutdownStarted)
server.Shutdown(ctx)
// server.Close()
close(shutdownFinished)
}()
<-shutdownStarted
// receive the responses
if err := client.recvResponse(ctx, 1, &tp); err != nil {
t.Fatal(err)
}
if err := client.recvResponse(ctx, 3, &tp); err != nil {
t.Fatal(err)
}
<-shutdownFinished
checkServerShutdown(t, server)
}
func TestServerClose(t *testing.T) {
var (
server = NewServer()
_, listener = newTestListener(t)
startClose = make(chan struct{})
errs = make(chan error, 1)
)
go func() {
close(startClose)
errs <- server.Serve(listener)
}()
<-startClose
if err := server.Close(); err != nil {
t.Fatal(err)
}
err := <-errs
if err != ErrServerClosed {
t.Fatal("expected an error from a closed server", err)
}
checkServerShutdown(t, server)
}
func checkServerShutdown(t *testing.T, server *Server) {
t.Helper()
server.mu.Lock()
defer server.mu.Unlock()
if len(server.listeners) > 0 {
t.Fatalf("expected listeners to be empty: %v", server.listeners)
}
if len(server.connections) > 0 {
t.Fatalf("expected connections to be empty: %v", server.connections)
}
}
type callResult struct {
input *testPayload
expected *testPayload
@@ -136,3 +263,13 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s
received: resp,
}
}
func newTestListener(t *testing.T) (string, net.Listener) {
addr := "\x00" + t.Name()
listener, err := net.Listen("unix", addr)
if err != nil {
t.Fatal(err)
}
return addr, listener
}
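Note on newTestListener: the leading \x00 byte puts the socket in Linux's abstract namespace, so nothing is created on the filesystem and the name disappears as soon as the listener closes, sparing the tests any temp-file cleanup. A standalone, Linux-only example:

package main

import (
	"fmt"
	"net"
)

func main() {
	// The leading NUL requests an abstract-namespace unix socket on
	// Linux: no filesystem entry, automatically released on close.
	l, err := net.Listen("unix", "\x00example-abstract-socket")
	if err != nil {
		fmt.Println("listen:", err) // expected on non-Linux platforms
		return
	}
	defer l.Close()
	fmt.Println("listening at", l.Addr())
}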