Merge pull request #4 from stevvooe/accept-loop-hardening
ttrpc: implement Close and Shutdown
This commit is contained in:
commit
7e38ac9e99
31
client.go
31
client.go
@ -40,22 +40,30 @@ func NewClient(conn net.Conn) *Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
||||||
|
requestID := atomic.AddUint32(&c.requestID, 2)
|
||||||
|
if err := c.sendRequest(ctx, requestID, service, method, req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.recvResponse(ctx, requestID, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) sendRequest(ctx context.Context, requestID uint32, service, method string, req interface{}) error {
|
||||||
payload, err := c.codec.Marshal(req)
|
payload, err := c.codec.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestID := atomic.AddUint32(&c.requestID, 2)
|
|
||||||
request := Request{
|
request := Request{
|
||||||
Service: service,
|
Service: service,
|
||||||
Method: method,
|
Method: method,
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.send(ctx, requestID, &request); err != nil {
|
return c.send(ctx, requestID, &request)
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
func (c *Client) recvResponse(ctx context.Context, requestID uint32, resp interface{}) error {
|
||||||
var response Response
|
var response Response
|
||||||
if err := c.recv(ctx, requestID, &response); err != nil {
|
if err := c.recv(ctx, requestID, &response); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -160,6 +168,10 @@ func (c *Client) run() {
|
|||||||
// start one more goroutine to recv messages without blocking.
|
// start one more goroutine to recv messages without blocking.
|
||||||
for {
|
for {
|
||||||
var p [messageLengthMax]byte
|
var p [messageLengthMax]byte
|
||||||
|
// TODO(stevvooe): Something still isn't quite right with error
|
||||||
|
// handling on the client-side, causing EOFs to come through. We
|
||||||
|
// need other fixes in this changeset, so we'll address this
|
||||||
|
// correctly later.
|
||||||
mh, err := c.channel.recv(context.TODO(), p[:])
|
mh, err := c.channel.recv(context.TODO(), p[:])
|
||||||
select {
|
select {
|
||||||
case incoming <- received{
|
case incoming <- received{
|
||||||
@ -187,13 +199,12 @@ func (c *Client) run() {
|
|||||||
}
|
}
|
||||||
waiters[req.id] = req
|
waiters[req.id] = req
|
||||||
case r := <-incoming:
|
case r := <-incoming:
|
||||||
if r.err != nil {
|
|
||||||
c.err = r.err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if waiter, ok := waiters[r.mh.StreamID]; ok {
|
if waiter, ok := waiters[r.mh.StreamID]; ok {
|
||||||
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
if r.err != nil {
|
||||||
|
waiter.err <- r.err
|
||||||
|
} else {
|
||||||
|
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
queued[r.mh.StreamID] = r
|
queued[r.mh.StreamID] = r
|
||||||
}
|
}
|
||||||
|
315
server.go
315
server.go
@ -2,21 +2,38 @@ package ttrpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/containerd/containerd/log"
|
"github.com/containerd/containerd/log"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrServerClosed = errors.New("ttrpc: server close")
|
||||||
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
services *serviceSet
|
services *serviceSet
|
||||||
codec codec
|
codec codec
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
listeners map[net.Listener]struct{}
|
||||||
|
connections map[*serverConn]struct{} // all connections to current state
|
||||||
|
done chan struct{} // marks point at which we stop serving requests
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
services: newServiceSet(),
|
services: newServiceSet(),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
listeners: make(map[net.Listener]struct{}),
|
||||||
|
connections: make(map[*serverConn]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,28 +41,210 @@ func (s *Server) Register(name string, methods map[string]Method) {
|
|||||||
s.services.register(name, methods)
|
s.services.register(name, methods)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Shutdown(ctx context.Context) error {
|
|
||||||
// TODO(stevvooe): Wait on connection shutdown.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Serve(l net.Listener) error {
|
func (s *Server) Serve(l net.Listener) error {
|
||||||
|
s.addListener(l)
|
||||||
|
defer s.closeListener(l)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
backoff time.Duration
|
||||||
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.L.WithError(err).Error("failed accept")
|
select {
|
||||||
continue
|
case <-s.done:
|
||||||
|
return ErrServerClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if terr, ok := err.(interface {
|
||||||
|
Temporary() bool
|
||||||
|
}); ok && terr.Temporary() {
|
||||||
|
if backoff == 0 {
|
||||||
|
backoff = time.Millisecond
|
||||||
|
} else {
|
||||||
|
backoff *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if max := time.Second; backoff > max {
|
||||||
|
backoff = max
|
||||||
|
}
|
||||||
|
|
||||||
|
sleep := time.Duration(rand.Int63n(int64(backoff)))
|
||||||
|
log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
|
||||||
|
time.Sleep(sleep)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go s.handleConn(conn)
|
backoff = 0
|
||||||
|
sc := s.newConn(conn)
|
||||||
|
go sc.run(ctx)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
lnerr := s.closeListeners()
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
default:
|
||||||
|
// protected by mutex
|
||||||
|
close(s.done)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(200 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
if s.closeIdleConns() {
|
||||||
|
return lnerr
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lnerr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the server without waiting for active connections.
|
||||||
|
func (s *Server) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
default:
|
||||||
|
// protected by mutex
|
||||||
|
close(s.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.closeListeners()
|
||||||
|
for c := range s.connections {
|
||||||
|
c.close()
|
||||||
|
delete(s.connections, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) addListener(l net.Listener) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.listeners[l] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeListener(l net.Listener) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.closeListenerLocked(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeListenerLocked(l net.Listener) error {
|
||||||
|
defer delete(s.listeners, l)
|
||||||
|
return l.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeListeners() error {
|
||||||
|
var err error
|
||||||
|
for l := range s.listeners {
|
||||||
|
if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
|
||||||
|
err = cerr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) addConnection(c *serverConn) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.connections[c] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeIdleConns() bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
quiescent := true
|
||||||
|
for c := range s.connections {
|
||||||
|
st, ok := c.getState()
|
||||||
|
if !ok || st != connStateIdle {
|
||||||
|
quiescent = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.close()
|
||||||
|
delete(s.connections, c)
|
||||||
|
}
|
||||||
|
return quiescent
|
||||||
|
}
|
||||||
|
|
||||||
|
type connState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
connStateActive = iota + 1 // outstanding requests
|
||||||
|
connStateIdle // no requests
|
||||||
|
connStateClosed // closed connection
|
||||||
|
)
|
||||||
|
|
||||||
|
func (cs connState) String() string {
|
||||||
|
switch cs {
|
||||||
|
case connStateActive:
|
||||||
|
return "active"
|
||||||
|
case connStateIdle:
|
||||||
|
return "idle"
|
||||||
|
case connStateClosed:
|
||||||
|
return "closed"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) newConn(conn net.Conn) *serverConn {
|
||||||
|
c := &serverConn{
|
||||||
|
server: s,
|
||||||
|
conn: conn,
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
}
|
||||||
|
c.setState(connStateIdle)
|
||||||
|
s.addConnection(c)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverConn struct {
|
||||||
|
server *Server
|
||||||
|
conn net.Conn
|
||||||
|
state atomic.Value
|
||||||
|
|
||||||
|
shutdownOnce sync.Once
|
||||||
|
shutdown chan struct{} // forced shutdown, used by close
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) getState() (connState, bool) {
|
||||||
|
cs, ok := c.state.Load().(connState)
|
||||||
|
return cs, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) setState(newstate connState) {
|
||||||
|
c.state.Store(newstate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverConn) close() error {
|
||||||
|
c.shutdownOnce.Do(func() {
|
||||||
|
close(c.shutdown)
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleConn(conn net.Conn) {
|
func (c *serverConn) run(sctx context.Context) {
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
request struct {
|
request struct {
|
||||||
id uint32
|
id uint32
|
||||||
@ -59,21 +258,33 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ch = newChannel(conn, conn)
|
ch = newChannel(c.conn, c.conn)
|
||||||
ctx, cancel = context.WithCancel(context.Background())
|
ctx, cancel = context.WithCancel(sctx)
|
||||||
responses = make(chan response)
|
active int
|
||||||
requests = make(chan request)
|
state connState = connStateIdle
|
||||||
recvErr = make(chan error, 1)
|
responses = make(chan response)
|
||||||
done = make(chan struct{})
|
requests = make(chan request)
|
||||||
|
recvErr = make(chan error, 1)
|
||||||
|
shutdown = c.shutdown
|
||||||
|
done = make(chan struct{})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
defer c.conn.Close()
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
|
||||||
go func() {
|
go func(recvErr chan error) {
|
||||||
defer close(recvErr)
|
defer close(recvErr)
|
||||||
var p [messageLengthMax]byte
|
var p [messageLengthMax]byte
|
||||||
for {
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.shutdown:
|
||||||
|
return
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
default: // proceed
|
||||||
|
}
|
||||||
|
|
||||||
mh, err := ch.recv(ctx, p[:])
|
mh, err := ch.recv(ctx, p[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recvErr <- err
|
recvErr <- err
|
||||||
@ -85,14 +296,7 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var req Request
|
sendImmediate := func(code codes.Code, msg string, args ...interface{}) bool {
|
||||||
if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
|
||||||
recvErr <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if mh.StreamID%2 != 1 {
|
|
||||||
// enforce odd client initiated identifiers.
|
|
||||||
select {
|
select {
|
||||||
case responses <- response{
|
case responses <- response{
|
||||||
// even though we've had an invalid stream id, we send it
|
// even though we've had an invalid stream id, we send it
|
||||||
@ -100,30 +304,68 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
// stream id was bad.
|
// stream id was bad.
|
||||||
id: mh.StreamID,
|
id: mh.StreamID,
|
||||||
resp: &Response{
|
resp: &Response{
|
||||||
Status: status.New(codes.InvalidArgument, "StreamID must be odd for client initiated streams").Proto(),
|
Status: status.Newf(code, msg, args...).Proto(),
|
||||||
},
|
},
|
||||||
}:
|
}:
|
||||||
|
return true
|
||||||
|
case <-c.shutdown:
|
||||||
|
return false
|
||||||
case <-done:
|
case <-done:
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var req Request
|
||||||
|
if err := c.server.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
||||||
|
if !sendImmediate(codes.InvalidArgument, "unmarshal request error: %v", err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if mh.StreamID%2 != 1 {
|
||||||
|
// enforce odd client initiated identifiers.
|
||||||
|
if !sendImmediate(codes.InvalidArgument, "StreamID must be odd for client initiated streams") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward the request to the main loop. We don't wait on s.done
|
||||||
|
// because we have already accepted the client request.
|
||||||
select {
|
select {
|
||||||
case requests <- request{
|
case requests <- request{
|
||||||
id: mh.StreamID,
|
id: mh.StreamID,
|
||||||
req: &req,
|
req: &req,
|
||||||
}:
|
}:
|
||||||
case <-done:
|
case <-done:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}(recvErr)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
newstate := state
|
||||||
|
switch {
|
||||||
|
case active > 0:
|
||||||
|
newstate = connStateActive
|
||||||
|
shutdown = nil
|
||||||
|
case active == 0:
|
||||||
|
newstate = connStateIdle
|
||||||
|
shutdown = c.shutdown // only enable this branch in idle mode
|
||||||
|
}
|
||||||
|
|
||||||
|
if newstate != state {
|
||||||
|
c.setState(newstate)
|
||||||
|
state = newstate
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case request := <-requests:
|
case request := <-requests:
|
||||||
|
active++
|
||||||
|
|
||||||
go func(id uint32) {
|
go func(id uint32) {
|
||||||
p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
||||||
resp := &Response{
|
resp := &Response{
|
||||||
Status: status.Proto(),
|
Status: status.Proto(),
|
||||||
Payload: p,
|
Payload: p,
|
||||||
@ -138,7 +380,7 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
}(request.id)
|
}(request.id)
|
||||||
case response := <-responses:
|
case response := <-responses:
|
||||||
p, err := s.codec.Marshal(response.resp)
|
p, err := c.server.codec.Marshal(response.resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.L.WithError(err).Error("failed marshaling response")
|
log.L.WithError(err).Error("failed marshaling response")
|
||||||
return
|
return
|
||||||
@ -147,8 +389,17 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
log.L.WithError(err).Error("failed sending message on channel")
|
log.L.WithError(err).Error("failed sending message on channel")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
active--
|
||||||
case err := <-recvErr:
|
case err := <-recvErr:
|
||||||
log.L.WithError(err).Error("error receiving message")
|
// TODO(stevvooe): Not wildly clear what we should do in this
|
||||||
|
// branch. Basically, it means that we are no longer receiving
|
||||||
|
// requests due to a terminal error.
|
||||||
|
recvErr = nil // connection is now "closing"
|
||||||
|
if err != nil {
|
||||||
|
log.L.WithError(err).Error("error receiving message")
|
||||||
|
}
|
||||||
|
case <-shutdown:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
175
server_test.go
175
server_test.go
@ -7,6 +7,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gogo/protobuf/proto"
|
"github.com/gogo/protobuf/proto"
|
||||||
)
|
)
|
||||||
@ -74,34 +75,26 @@ func init() {
|
|||||||
|
|
||||||
func TestServer(t *testing.T) {
|
func TestServer(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
server = NewServer()
|
server = NewServer()
|
||||||
testImpl = &testingServer{}
|
testImpl = &testingServer{}
|
||||||
|
addr, listener = newTestListener(t)
|
||||||
|
client, cleanup = newTestClient(t, addr)
|
||||||
|
tclient = newTestingClient(client)
|
||||||
)
|
)
|
||||||
|
|
||||||
registerTestingService(server, testImpl)
|
|
||||||
|
|
||||||
addr := "\x00" + t.Name()
|
|
||||||
listener, err := net.Listen("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerTestingService(server, testImpl)
|
||||||
|
|
||||||
go server.Serve(listener)
|
go server.Serve(listener)
|
||||||
defer server.Shutdown(ctx)
|
defer server.Shutdown(ctx)
|
||||||
|
|
||||||
conn, err := net.Dial("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
client := newTestingClient(NewClient(conn))
|
|
||||||
|
|
||||||
const calls = 2
|
const calls = 2
|
||||||
results := make(chan callResult, 2)
|
results := make(chan callResult, 2)
|
||||||
go roundTrip(ctx, t, client, "bar", results)
|
go roundTrip(ctx, t, tclient, "bar", results)
|
||||||
go roundTrip(ctx, t, client, "baz", results)
|
go roundTrip(ctx, t, tclient, "baz", results)
|
||||||
|
|
||||||
for i := 0; i < calls; i++ {
|
for i := 0; i < calls; i++ {
|
||||||
result := <-results
|
result := <-results
|
||||||
@ -111,6 +104,140 @@ func TestServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestClient(t *testing.T, addr string) (*Client, func()) {
|
||||||
|
conn, err := net.Dial("unix", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
client := NewClient(conn)
|
||||||
|
return client, func() {
|
||||||
|
conn.Close()
|
||||||
|
client.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerListenerClosed(t *testing.T) {
|
||||||
|
var (
|
||||||
|
server = NewServer()
|
||||||
|
_, listener = newTestListener(t)
|
||||||
|
errs = make(chan error, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errs <- server.Serve(listener)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := listener.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := <-errs
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerShutdown(t *testing.T) {
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
server = NewServer()
|
||||||
|
addr, listener = newTestListener(t)
|
||||||
|
shutdownStarted = make(chan struct{})
|
||||||
|
shutdownFinished = make(chan struct{})
|
||||||
|
errs = make(chan error, 1)
|
||||||
|
client, cleanup = newTestClient(t, addr)
|
||||||
|
_, cleanup2 = newTestClient(t, addr) // secondary connection
|
||||||
|
)
|
||||||
|
defer cleanup()
|
||||||
|
defer cleanup2()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// register a service that takes until we tell it to stop
|
||||||
|
server.Register(serviceName, map[string]Method{
|
||||||
|
"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
|
||||||
|
var req testPayload
|
||||||
|
if err := unmarshal(&req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &testPayload{Foo: "waited"}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errs <- server.Serve(listener)
|
||||||
|
}()
|
||||||
|
|
||||||
|
tp := testPayload{Foo: "half"}
|
||||||
|
// send a few half requests
|
||||||
|
if err := client.sendRequest(ctx, 1, "testService", "Test", &tp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := client.sendRequest(ctx, 3, "testService", "Test", &tp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Millisecond) // ensure that requests make it through before shutdown
|
||||||
|
go func() {
|
||||||
|
close(shutdownStarted)
|
||||||
|
server.Shutdown(ctx)
|
||||||
|
// server.Close()
|
||||||
|
close(shutdownFinished)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-shutdownStarted
|
||||||
|
|
||||||
|
// receive the responses
|
||||||
|
if err := client.recvResponse(ctx, 1, &tp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.recvResponse(ctx, 3, &tp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-shutdownFinished
|
||||||
|
checkServerShutdown(t, server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerClose(t *testing.T) {
|
||||||
|
var (
|
||||||
|
server = NewServer()
|
||||||
|
_, listener = newTestListener(t)
|
||||||
|
startClose = make(chan struct{})
|
||||||
|
errs = make(chan error, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
close(startClose)
|
||||||
|
errs <- server.Serve(listener)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-startClose
|
||||||
|
if err := server.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := <-errs
|
||||||
|
if err != ErrServerClosed {
|
||||||
|
t.Fatal("expected an error from a closed server", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkServerShutdown(t, server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkServerShutdown(t *testing.T, server *Server) {
|
||||||
|
t.Helper()
|
||||||
|
server.mu.Lock()
|
||||||
|
defer server.mu.Unlock()
|
||||||
|
if len(server.listeners) > 0 {
|
||||||
|
t.Fatalf("expected listeners to be empty: %v", server.listeners)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(server.connections) > 0 {
|
||||||
|
t.Fatalf("expected connections to be empty: %v", server.connections)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type callResult struct {
|
type callResult struct {
|
||||||
input *testPayload
|
input *testPayload
|
||||||
expected *testPayload
|
expected *testPayload
|
||||||
@ -136,3 +263,13 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s
|
|||||||
received: resp,
|
received: resp,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestListener(t *testing.T) (string, net.Listener) {
|
||||||
|
addr := "\x00" + t.Name()
|
||||||
|
listener, err := net.Listen("unix", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr, listener
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user