Merge pull request #4 from stevvooe/accept-loop-hardening
ttrpc: implement Close and Shutdown
This commit is contained in:
commit
7e38ac9e99
29
client.go
29
client.go
@ -40,22 +40,30 @@ func NewClient(conn net.Conn) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
||||
requestID := atomic.AddUint32(&c.requestID, 2)
|
||||
if err := c.sendRequest(ctx, requestID, service, method, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.recvResponse(ctx, requestID, resp)
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(ctx context.Context, requestID uint32, service, method string, req interface{}) error {
|
||||
payload, err := c.codec.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestID := atomic.AddUint32(&c.requestID, 2)
|
||||
request := Request{
|
||||
Service: service,
|
||||
Method: method,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
if err := c.send(ctx, requestID, &request); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.send(ctx, requestID, &request)
|
||||
}
|
||||
|
||||
func (c *Client) recvResponse(ctx context.Context, requestID uint32, resp interface{}) error {
|
||||
var response Response
|
||||
if err := c.recv(ctx, requestID, &response); err != nil {
|
||||
return err
|
||||
@ -160,6 +168,10 @@ func (c *Client) run() {
|
||||
// start one more goroutine to recv messages without blocking.
|
||||
for {
|
||||
var p [messageLengthMax]byte
|
||||
// TODO(stevvooe): Something still isn't quite right with error
|
||||
// handling on the client-side, causing EOFs to come through. We
|
||||
// need other fixes in this changeset, so we'll address this
|
||||
// correctly later.
|
||||
mh, err := c.channel.recv(context.TODO(), p[:])
|
||||
select {
|
||||
case incoming <- received{
|
||||
@ -187,13 +199,12 @@ func (c *Client) run() {
|
||||
}
|
||||
waiters[req.id] = req
|
||||
case r := <-incoming:
|
||||
if r.err != nil {
|
||||
c.err = r.err
|
||||
return
|
||||
}
|
||||
|
||||
if waiter, ok := waiters[r.mh.StreamID]; ok {
|
||||
if r.err != nil {
|
||||
waiter.err <- r.err
|
||||
} else {
|
||||
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
||||
}
|
||||
} else {
|
||||
queued[r.mh.StreamID] = r
|
||||
}
|
||||
|
301
server.go
301
server.go
@ -2,21 +2,38 @@ package ttrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/containerd/containerd/log"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
	// ErrServerClosed is returned by Serve when accepting fails after the
	// server's done channel has been closed by Shutdown or Close.
	ErrServerClosed = errors.New("ttrpc: server closed")
)
|
||||
|
||||
type Server struct {
|
||||
services *serviceSet
|
||||
codec codec
|
||||
|
||||
mu sync.Mutex
|
||||
listeners map[net.Listener]struct{}
|
||||
connections map[*serverConn]struct{} // all connections to current state
|
||||
done chan struct{} // marks point at which we stop serving requests
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
return &Server{
|
||||
services: newServiceSet(),
|
||||
done: make(chan struct{}),
|
||||
listeners: make(map[net.Listener]struct{}),
|
||||
connections: make(map[*serverConn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,28 +41,210 @@ func (s *Server) Register(name string, methods map[string]Method) {
|
||||
s.services.register(name, methods)
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
// TODO(stevvooe): Wait on connection shutdown.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Serve(l net.Listener) error {
|
||||
s.addListener(l)
|
||||
defer s.closeListener(l)
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
backoff time.Duration
|
||||
)
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.L.WithError(err).Error("failed accept")
|
||||
select {
|
||||
case <-s.done:
|
||||
return ErrServerClosed
|
||||
default:
|
||||
}
|
||||
|
||||
if terr, ok := err.(interface {
|
||||
Temporary() bool
|
||||
}); ok && terr.Temporary() {
|
||||
if backoff == 0 {
|
||||
backoff = time.Millisecond
|
||||
} else {
|
||||
backoff *= 2
|
||||
}
|
||||
|
||||
if max := time.Second; backoff > max {
|
||||
backoff = max
|
||||
}
|
||||
|
||||
sleep := time.Duration(rand.Int63n(int64(backoff)))
|
||||
log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
|
||||
time.Sleep(sleep)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleConn(conn)
|
||||
return err
|
||||
}
|
||||
|
||||
backoff = 0
|
||||
sc := s.newConn(conn)
|
||||
go sc.run(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
lnerr := s.closeListeners()
|
||||
select {
|
||||
case <-s.done:
|
||||
default:
|
||||
// protected by mutex
|
||||
close(s.done)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
if s.closeIdleConns() {
|
||||
return lnerr
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
|
||||
return lnerr
|
||||
}
|
||||
|
||||
// Close the server without waiting for active connections.
|
||||
func (s *Server) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-s.done:
|
||||
default:
|
||||
// protected by mutex
|
||||
close(s.done)
|
||||
}
|
||||
|
||||
err := s.closeListeners()
|
||||
for c := range s.connections {
|
||||
c.close()
|
||||
delete(s.connections, c)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) addListener(l net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.listeners[l] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *Server) closeListener(l net.Listener) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return s.closeListenerLocked(l)
|
||||
}
|
||||
|
||||
func (s *Server) closeListenerLocked(l net.Listener) error {
|
||||
defer delete(s.listeners, l)
|
||||
return l.Close()
|
||||
}
|
||||
|
||||
func (s *Server) closeListeners() error {
|
||||
var err error
|
||||
for l := range s.listeners {
|
||||
if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
|
||||
err = cerr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) addConnection(c *serverConn) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.connections[c] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *Server) closeIdleConns() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
quiescent := true
|
||||
for c := range s.connections {
|
||||
st, ok := c.getState()
|
||||
if !ok || st != connStateIdle {
|
||||
quiescent = false
|
||||
continue
|
||||
}
|
||||
c.close()
|
||||
delete(s.connections, c)
|
||||
}
|
||||
return quiescent
|
||||
}
|
||||
|
||||
// connState is the coarse state of a server connection.
type connState int

// The constants are declared with the connState type (rather than as untyped
// ints) so that they carry the String method and cannot be mixed up with
// plain integers. Values start at 1, leaving the zero value as "unknown".
const (
	connStateActive connState = iota + 1 // outstanding requests
	connStateIdle                        // no requests
	connStateClosed                      // closed connection
)

// String returns "active", "idle" or "closed" for the declared states and
// "unknown" for any other value.
func (cs connState) String() string {
	switch cs {
	case connStateActive:
		return "active"
	case connStateIdle:
		return "idle"
	case connStateClosed:
		return "closed"
	default:
		return "unknown"
	}
}
|
||||
|
||||
func (s *Server) newConn(conn net.Conn) *serverConn {
|
||||
c := &serverConn{
|
||||
server: s,
|
||||
conn: conn,
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
c.setState(connStateIdle)
|
||||
s.addConnection(c)
|
||||
return c
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
server *Server
|
||||
conn net.Conn
|
||||
state atomic.Value
|
||||
|
||||
shutdownOnce sync.Once
|
||||
shutdown chan struct{} // forced shutdown, used by close
|
||||
}
|
||||
|
||||
func (c *serverConn) getState() (connState, bool) {
|
||||
cs, ok := c.state.Load().(connState)
|
||||
return cs, ok
|
||||
}
|
||||
|
||||
func (c *serverConn) setState(newstate connState) {
|
||||
c.state.Store(newstate)
|
||||
}
|
||||
|
||||
func (c *serverConn) close() error {
|
||||
c.shutdownOnce.Do(func() {
|
||||
close(c.shutdown)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
func (c *serverConn) run(sctx context.Context) {
|
||||
type (
|
||||
request struct {
|
||||
id uint32
|
||||
@ -59,21 +258,33 @@ func (s *Server) handleConn(conn net.Conn) {
|
||||
)
|
||||
|
||||
var (
|
||||
ch = newChannel(conn, conn)
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
ch = newChannel(c.conn, c.conn)
|
||||
ctx, cancel = context.WithCancel(sctx)
|
||||
active int
|
||||
state connState = connStateIdle
|
||||
responses = make(chan response)
|
||||
requests = make(chan request)
|
||||
recvErr = make(chan error, 1)
|
||||
shutdown = c.shutdown
|
||||
done = make(chan struct{})
|
||||
)
|
||||
|
||||
defer c.conn.Close()
|
||||
defer cancel()
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
go func(recvErr chan error) {
|
||||
defer close(recvErr)
|
||||
var p [messageLengthMax]byte
|
||||
for {
|
||||
select {
|
||||
case <-c.shutdown:
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
default: // proceed
|
||||
}
|
||||
|
||||
mh, err := ch.recv(ctx, p[:])
|
||||
if err != nil {
|
||||
recvErr <- err
|
||||
@ -85,14 +296,7 @@ func (s *Server) handleConn(conn net.Conn) {
|
||||
continue
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
||||
recvErr <- err
|
||||
return
|
||||
}
|
||||
|
||||
if mh.StreamID%2 != 1 {
|
||||
// enforce odd client initiated identifiers.
|
||||
sendImmediate := func(code codes.Code, msg string, args ...interface{}) bool {
|
||||
select {
|
||||
case responses <- response{
|
||||
// even though we've had an invalid stream id, we send it
|
||||
@ -100,30 +304,68 @@ func (s *Server) handleConn(conn net.Conn) {
|
||||
// stream id was bad.
|
||||
id: mh.StreamID,
|
||||
resp: &Response{
|
||||
Status: status.New(codes.InvalidArgument, "StreamID must be odd for client initiated streams").Proto(),
|
||||
Status: status.Newf(code, msg, args...).Proto(),
|
||||
},
|
||||
}:
|
||||
return true
|
||||
case <-c.shutdown:
|
||||
return false
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := c.server.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
||||
if !sendImmediate(codes.InvalidArgument, "unmarshal request error: %v", err) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if mh.StreamID%2 != 1 {
|
||||
// enforce odd client initiated identifiers.
|
||||
if !sendImmediate(codes.InvalidArgument, "StreamID must be odd for client initiated streams") {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Forward the request to the main loop. We don't wait on s.done
|
||||
// because we have already accepted the client request.
|
||||
select {
|
||||
case requests <- request{
|
||||
id: mh.StreamID,
|
||||
req: &req,
|
||||
}:
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}(recvErr)
|
||||
|
||||
for {
|
||||
newstate := state
|
||||
switch {
|
||||
case active > 0:
|
||||
newstate = connStateActive
|
||||
shutdown = nil
|
||||
case active == 0:
|
||||
newstate = connStateIdle
|
||||
shutdown = c.shutdown // only enable this branch in idle mode
|
||||
}
|
||||
|
||||
if newstate != state {
|
||||
c.setState(newstate)
|
||||
state = newstate
|
||||
}
|
||||
|
||||
select {
|
||||
case request := <-requests:
|
||||
active++
|
||||
|
||||
go func(id uint32) {
|
||||
p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
||||
p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
||||
resp := &Response{
|
||||
Status: status.Proto(),
|
||||
Payload: p,
|
||||
@ -138,7 +380,7 @@ func (s *Server) handleConn(conn net.Conn) {
|
||||
}
|
||||
}(request.id)
|
||||
case response := <-responses:
|
||||
p, err := s.codec.Marshal(response.resp)
|
||||
p, err := c.server.codec.Marshal(response.resp)
|
||||
if err != nil {
|
||||
log.L.WithError(err).Error("failed marshaling response")
|
||||
return
|
||||
@ -147,8 +389,17 @@ func (s *Server) handleConn(conn net.Conn) {
|
||||
log.L.WithError(err).Error("failed sending message on channel")
|
||||
return
|
||||
}
|
||||
|
||||
active--
|
||||
case err := <-recvErr:
|
||||
// TODO(stevvooe): Not wildly clear what we should do in this
|
||||
// branch. Basically, it means that we are no longer receiving
|
||||
// requests due to a terminal error.
|
||||
recvErr = nil // connection is now "closing"
|
||||
if err != nil {
|
||||
log.L.WithError(err).Error("error receiving message")
|
||||
}
|
||||
case <-shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
169
server_test.go
169
server_test.go
@ -7,6 +7,7 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
)
|
||||
@ -77,31 +78,23 @@ func TestServer(t *testing.T) {
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
testImpl = &testingServer{}
|
||||
addr, listener = newTestListener(t)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
tclient = newTestingClient(client)
|
||||
)
|
||||
|
||||
registerTestingService(server, testImpl)
|
||||
|
||||
addr := "\x00" + t.Name()
|
||||
listener, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
defer cleanup()
|
||||
|
||||
registerTestingService(server, testImpl)
|
||||
|
||||
go server.Serve(listener)
|
||||
defer server.Shutdown(ctx)
|
||||
|
||||
conn, err := net.Dial("unix", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
client := newTestingClient(NewClient(conn))
|
||||
|
||||
const calls = 2
|
||||
results := make(chan callResult, 2)
|
||||
go roundTrip(ctx, t, client, "bar", results)
|
||||
go roundTrip(ctx, t, client, "baz", results)
|
||||
go roundTrip(ctx, t, tclient, "bar", results)
|
||||
go roundTrip(ctx, t, tclient, "baz", results)
|
||||
|
||||
for i := 0; i < calls; i++ {
|
||||
result := <-results
|
||||
@ -111,6 +104,140 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newTestClient(t *testing.T, addr string) (*Client, func()) {
|
||||
conn, err := net.Dial("unix", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client := NewClient(conn)
|
||||
return client, func() {
|
||||
conn.Close()
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerListenerClosed(t *testing.T) {
|
||||
var (
|
||||
server = NewServer()
|
||||
_, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
)
|
||||
|
||||
go func() {
|
||||
errs <- server.Serve(listener)
|
||||
}()
|
||||
|
||||
if err := listener.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := <-errs
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerShutdown(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||
)
|
||||
defer cleanup()
|
||||
defer cleanup2()
|
||||
defer server.Close()
|
||||
|
||||
// register a service that takes until we tell it to stop
|
||||
server.Register(serviceName, map[string]Method{
|
||||
"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
|
||||
var req testPayload
|
||||
if err := unmarshal(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testPayload{Foo: "waited"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
go func() {
|
||||
errs <- server.Serve(listener)
|
||||
}()
|
||||
|
||||
tp := testPayload{Foo: "half"}
|
||||
// send a few half requests
|
||||
if err := client.sendRequest(ctx, 1, "testService", "Test", &tp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := client.sendRequest(ctx, 3, "testService", "Test", &tp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Millisecond) // ensure that requests make it through before shutdown
|
||||
go func() {
|
||||
close(shutdownStarted)
|
||||
server.Shutdown(ctx)
|
||||
// server.Close()
|
||||
close(shutdownFinished)
|
||||
}()
|
||||
|
||||
<-shutdownStarted
|
||||
|
||||
// receive the responses
|
||||
if err := client.recvResponse(ctx, 1, &tp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := client.recvResponse(ctx, 3, &tp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
<-shutdownFinished
|
||||
checkServerShutdown(t, server)
|
||||
}
|
||||
|
||||
func TestServerClose(t *testing.T) {
|
||||
var (
|
||||
server = NewServer()
|
||||
_, listener = newTestListener(t)
|
||||
startClose = make(chan struct{})
|
||||
errs = make(chan error, 1)
|
||||
)
|
||||
|
||||
go func() {
|
||||
close(startClose)
|
||||
errs <- server.Serve(listener)
|
||||
}()
|
||||
|
||||
<-startClose
|
||||
if err := server.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := <-errs
|
||||
if err != ErrServerClosed {
|
||||
t.Fatal("expected an error from a closed server", err)
|
||||
}
|
||||
|
||||
checkServerShutdown(t, server)
|
||||
}
|
||||
|
||||
func checkServerShutdown(t *testing.T, server *Server) {
|
||||
t.Helper()
|
||||
server.mu.Lock()
|
||||
defer server.mu.Unlock()
|
||||
if len(server.listeners) > 0 {
|
||||
t.Fatalf("expected listeners to be empty: %v", server.listeners)
|
||||
}
|
||||
|
||||
if len(server.connections) > 0 {
|
||||
t.Fatalf("expected connections to be empty: %v", server.connections)
|
||||
}
|
||||
}
|
||||
|
||||
type callResult struct {
|
||||
input *testPayload
|
||||
expected *testPayload
|
||||
@ -136,3 +263,13 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s
|
||||
received: resp,
|
||||
}
|
||||
}
|
||||
|
||||
// newTestListener listens on a unix socket whose address is a NUL byte
// followed by the test's name, and returns that address together with the
// listener. The test is failed immediately if listening fails.
func newTestListener(t *testing.T) (string, net.Listener) {
	// Report listen failures at the caller's line, matching checkServerShutdown.
	t.Helper()

	addr := "\x00" + t.Name()
	listener, err := net.Listen("unix", addr)
	if err != nil {
		t.Fatal(err)
	}

	return addr, listener
}
|
||||
|
Loading…
Reference in New Issue
Block a user