ttrpc/server.go
Stephen J Day 256c17bccd
ttrpc: use os.Getuid/os.Getgid directly
Because of issues with glibc, using the `os/user` package can cause crashes
when calling `user.Current()`. Neither the Go maintainers nor glibc developers
could be bothered to fix it, so we have to work around it by calling the
uid and gid functions directly. This is probably better because we don't
actually use much of the data provided in the `user.User` struct.

This required some refactoring to have better control over when the uid
and gid are resolved. Rather than checking the current user on every
connection, we now resolve it once at initialization. To test that this
provided an improvement in performance, a benchmark was added.
Unfortunately, this exposed a regression in the performance of unix
sockets in Go when `(*UnixConn).File` is called. The underlying culprit
of this performance regression is still at large.

The following open issues describe the underlying problem in more
detail:

https://github.com/golang/go/issues/13470
https://sourceware.org/bugzilla/show_bug.cgi?id=19341

In better news, I now have an entire herd of shaved yaks.

Signed-off-by: Stephen J Day <stephen.day@docker.com>
2017-11-30 20:21:50 -08:00

442 lines
8.6 KiB
Go

package ttrpc
import (
"context"
"io"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/containerd/containerd/log"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
	// ErrServerClosed is returned by Serve when accepting fails after the
	// server has been stopped through Shutdown or Close.
	ErrServerClosed = errors.New("ttrpc: server closed")
)
// Server accepts connections from one or more listeners and dispatches the
// requests arriving on them to the registered services.
type Server struct {
	// config holds the settings collected from the ServerOpt values given
	// to NewServer.
	config *serverConfig
	// services is the registry consulted to invoke a requested method.
	services *serviceSet
	// codec marshals responses and unmarshals requests.
	codec codec

	// mu guards listeners and connections, and serializes closing done.
	mu sync.Mutex
	// listeners is the set of listeners currently being served.
	listeners map[net.Listener]struct{}
	// connections is the set of connections tracked by the server.
	connections map[*serverConn]struct{}
	// done is closed once the server stops serving requests.
	done chan struct{}
}
// NewServer builds a Server from the supplied options.
//
// Options are applied in order to a fresh configuration; the first option
// that returns an error aborts construction and that error is returned.
func NewServer(opts ...ServerOpt) (*Server, error) {
	cfg := new(serverConfig)
	for i := range opts {
		if err := opts[i](cfg); err != nil {
			return nil, err
		}
	}

	srv := &Server{
		config:      cfg,
		services:    newServiceSet(),
		done:        make(chan struct{}),
		listeners:   make(map[net.Listener]struct{}),
		connections: make(map[*serverConn]struct{}),
	}
	return srv, nil
}
// Register adds the methods of the service called name to the server's
// service set.
func (s *Server) Register(name string, methods map[string]Method) {
	registry := s.services
	registry.register(name, methods)
}
// Serve accepts connections on l and runs each one in its own goroutine.
//
// The listener is tracked by the server for the duration of the call and is
// closed when Serve returns. Temporary accept errors are retried after a
// randomized, exponentially growing delay capped at one second. Any other
// accept error is returned, except that ErrServerClosed is returned instead
// once the server has been stopped. Connections rejected by the handshaker
// are closed and skipped.
func (s *Server) Serve(l net.Listener) error {
	// temporary matches errors able to report themselves as transient.
	type temporary interface {
		Temporary() bool
	}

	// maxBackoff bounds the delay between accept retries.
	const maxBackoff = time.Second

	s.addListener(l)
	defer s.closeListener(l)

	hs := s.config.handshaker
	if hs == nil {
		hs = handshakerFunc(noopHandshake)
	}

	ctx := context.Background()
	var delay time.Duration
	for {
		conn, err := l.Accept()
		if err == nil {
			// A successful accept resets the retry delay.
			delay = 0

			approved, handshake, herr := hs.Handshake(ctx, conn)
			if herr != nil {
				log.L.WithError(herr).Errorf("ttrpc: refusing connection after handshake")
				conn.Close()
				continue
			}

			sc := s.newConn(approved, handshake)
			go sc.run(ctx)
			continue
		}

		// An accept failure after the server was stopped is expected.
		select {
		case <-s.done:
			return ErrServerClosed
		default:
		}

		terr, ok := err.(temporary)
		if !ok || !terr.Temporary() {
			return err
		}

		switch {
		case delay == 0:
			delay = time.Millisecond
		case delay*2 > maxBackoff:
			delay = maxBackoff
		default:
			delay *= 2
		}

		// Sleep for a random duration below the current delay.
		wait := time.Duration(rand.Int63n(int64(delay)))
		log.L.WithError(err).Errorf("ttrpc: failed accept; backoff %v", wait)
		time.Sleep(wait)
	}
}
// Shutdown stops the server gracefully.
//
// All listeners are closed and the server is marked as done, then idle
// connections are closed every 200ms until none remain. It returns the first
// listener close error (if any) once the server is quiescent, or ctx.Err()
// if ctx ends first.
func (s *Server) Shutdown(ctx context.Context) error {
	lnerr := func() error {
		s.mu.Lock()
		defer s.mu.Unlock()

		err := s.closeListeners()
		select {
		case <-s.done:
			// already marked as done by an earlier Shutdown or Close
		default:
			// the mutex guarantees only one caller closes the channel
			close(s.done)
		}
		return err
	}()

	poll := time.NewTicker(200 * time.Millisecond)
	defer poll.Stop()

	for !s.closeIdleConns() {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-poll.C:
		}
	}
	return lnerr
}
// Close stops the server immediately, without waiting for active
// connections to finish their outstanding requests.
//
// It marks the server as done, closes every listener, and signals every
// tracked connection to shut down. The first listener close error, if any,
// is returned.
func (s *Server) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	select {
	case <-s.done:
		// already marked as done
	default:
		// the mutex guarantees only one caller closes the channel
		close(s.done)
	}

	lnerr := s.closeListeners()
	for sc := range s.connections {
		sc.close()
		delete(s.connections, sc)
	}
	return lnerr
}
// addListener records l in the set of listeners owned by the server.
func (s *Server) addListener(l net.Listener) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var member struct{}
	s.listeners[l] = member
}
// closeListener closes l and stops tracking it, taking the server mutex for
// the duration of the call. It returns the error from closing l.
func (s *Server) closeListener(l net.Listener) (err error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	err = s.closeListenerLocked(l)
	return err
}
// closeListenerLocked closes l and removes it from the listener set whether
// or not closing succeeded. It does not take the server mutex itself; the
// callers in this file hold s.mu.
func (s *Server) closeListenerLocked(l net.Listener) error {
	tracked := s.listeners
	defer func() { delete(tracked, l) }()
	return l.Close()
}
// closeListeners closes and forgets every tracked listener, returning the
// first close error encountered. Later errors are dropped. It does not take
// the server mutex itself; the callers in this file hold s.mu.
func (s *Server) closeListeners() error {
	var first error
	for ln := range s.listeners {
		cerr := s.closeListenerLocked(ln)
		if first == nil && cerr != nil {
			first = cerr
		}
	}
	return first
}
// addConnection records c in the set of connections tracked by the server.
func (s *Server) addConnection(c *serverConn) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var member struct{}
	s.connections[c] = member
}
// closeIdleConns signals every idle connection to shut down and stops
// tracking it. It reports whether the server is quiescent, i.e. whether no
// tracked connection was found in a state other than idle.
func (s *Server) closeIdleConns() bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	busy := 0
	for sc := range s.connections {
		if st, ok := sc.getState(); ok && st == connStateIdle {
			sc.close()
			delete(s.connections, sc)
			continue
		}
		// Still has outstanding requests, or has no recorded state.
		busy++
	}
	return busy == 0
}
// connState describes the lifecycle phase of a server connection.
type connState int

// Connection states. Numbering starts at one, so the zero value of connState
// is not a valid state. The constants are explicitly typed so they cannot be
// used as plain integers by accident.
const (
	connStateActive connState = iota + 1 // outstanding requests
	connStateIdle                        // no requests
	connStateClosed                      // closed connection
)
// String returns the lowercase name of the state, or "unknown" for a value
// that is not one of the declared states.
func (cs connState) String() string {
	if cs == connStateActive {
		return "active"
	}
	if cs == connStateIdle {
		return "idle"
	}
	if cs == connStateClosed {
		return "closed"
	}
	return "unknown"
}
// newConn wraps conn in a serverConn bound to this server, marks it idle and
// registers it with the server's connection set. handshake is the opaque
// value produced by the handshaker for this connection.
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
	sc := new(serverConn)
	sc.server = s
	sc.conn = conn
	sc.handshake = handshake
	sc.shutdown = make(chan struct{})

	sc.setState(connStateIdle)
	s.addConnection(sc)
	return sc
}
// serverConn is the server-side state for one accepted connection.
type serverConn struct {
	// server is the Server that accepted this connection.
	server *Server
	// conn is the underlying transport.
	conn net.Conn
	// handshake is the data from the handshake; not used for now.
	handshake interface{}
	// state holds the current connState; access it via getState/setState.
	state atomic.Value

	// shutdownOnce ensures shutdown is closed at most once.
	shutdownOnce sync.Once
	// shutdown is closed by close to force the connection to stop.
	shutdown chan struct{}
}
// getState returns the most recently stored connection state. The boolean is
// false when the stored value is not a connState, which includes the case
// where nothing has been stored yet.
func (c *serverConn) getState() (connState, bool) {
	stored := c.state.Load()
	st, ok := stored.(connState)
	return st, ok
}
// setState atomically records newstate as the connection's current state.
func (c *serverConn) setState(newstate connState) {
	state := &c.state
	state.Store(newstate)
}
// close signals the connection to stop by closing its shutdown channel. It
// is safe to call more than once and always returns nil.
func (c *serverConn) close() error {
	signal := func() { close(c.shutdown) }
	c.shutdownOnce.Do(signal)
	return nil
}
// run services the connection until it is shut down or fails.
//
// A receiver goroutine reads messages from the channel and forwards valid
// requests to the main loop; malformed messages are answered right away with
// an error status. The main loop starts one goroutine per request, writes
// responses back to the client, and maintains the connection state (active
// while requests are outstanding, idle otherwise). A forced shutdown via
// close is only honoured while idle.
//
// run returns when:
//   - the connection is idle and close has been called,
//   - a response cannot be marshaled or sent, or
//   - the receiver has stopped (for example because the client disconnected)
//     and no requests remain in flight.
//
// On return the underlying connection is closed, the context handed to
// service methods is cancelled, and the connection is removed from the
// server's connection set.
func (c *serverConn) run(sctx context.Context) {
	type (
		request struct {
			id  uint32
			req *Request
		}

		response struct {
			id   uint32
			resp *Response
			// immediate marks an error reply generated by the receiver for
			// a message that never became an in-flight request; such
			// replies must not be counted against active.
			immediate bool
		}
	)

	var (
		ch          = newChannel(c.conn, c.conn)
		ctx, cancel = context.WithCancel(sctx)
		active      int
		state       connState = connStateIdle
		responses             = make(chan response)
		requests              = make(chan request)
		recvErr               = make(chan error, 1)
		shutdown              = c.shutdown
		done                  = make(chan struct{})
	)

	// Registered first so it runs last: the connection stays visible to the
	// server until it has been fully torn down. Without this, finished
	// connections would accumulate in the server's set, and one that exited
	// while active would keep Shutdown from ever observing quiescence.
	defer func() {
		c.server.mu.Lock()
		delete(c.server.connections, c)
		c.server.mu.Unlock()
	}()
	defer c.conn.Close()
	defer cancel()
	defer close(done)

	go func(recvErr chan error) {
		defer close(recvErr)

		sendImmediate := func(id uint32, st *status.Status) bool {
			select {
			case responses <- response{
				// even though we've had an invalid stream id, we send it
				// back on the same stream id so the client knows which
				// stream id was bad.
				id: id,
				resp: &Response{
					Status: st.Proto(),
				},
				immediate: true,
			}:
				return true
			case <-c.shutdown:
				return false
			case <-done:
				return false
			}
		}

		for {
			select {
			case <-c.shutdown:
				return
			case <-done:
				return
			default: // proceed
			}

			mh, p, err := ch.recv(ctx)
			if err != nil {
				st, ok := status.FromError(err)
				if !ok {
					// Terminal error: report it and stop receiving.
					recvErr <- err
					return
				}

				// in this case, we send an error for that particular message
				// when the status is defined.
				if !sendImmediate(mh.StreamID, st) {
					return
				}

				continue
			}

			if mh.Type != messageTypeRequest {
				// we must ignore this for future compat.
				continue
			}

			var req Request
			if err := c.server.codec.Unmarshal(p, &req); err != nil {
				ch.putmbuf(p)
				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
					return
				}
				continue
			}
			ch.putmbuf(p)

			if mh.StreamID%2 != 1 {
				// enforce odd client initiated identifiers.
				if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
					return
				}
				continue
			}

			// Forward the request to the main loop. We don't wait on s.done
			// because we have already accepted the client request.
			select {
			case requests <- request{
				id:  mh.StreamID,
				req: &req,
			}:
			case <-done:
				return
			}
		}
	}(recvErr)

	for {
		// recvErr is set to nil once the receiver has terminated. From then
		// on no new request can arrive, so when nothing is in flight there
		// is nothing left to wait for: return so that the connection is
		// closed instead of lingering until the server shuts down.
		if recvErr == nil && active == 0 {
			return
		}

		newstate := state
		switch {
		case active > 0:
			newstate = connStateActive
			shutdown = nil
		case active == 0:
			newstate = connStateIdle
			shutdown = c.shutdown // only enable this branch in idle mode
		}

		if newstate != state {
			c.setState(newstate)
			state = newstate
		}

		select {
		case request := <-requests:
			active++
			go func(id uint32, req *Request) {
				p, status := c.server.services.call(ctx, req.Service, req.Method, req.Payload)
				resp := &Response{
					Status:  status.Proto(),
					Payload: p,
				}

				select {
				case responses <- response{
					id:   id,
					resp: resp,
				}:
				case <-done:
				}
			}(request.id, request.req)
		case response := <-responses:
			p, err := c.server.codec.Marshal(response.resp)
			if err != nil {
				log.L.WithError(err).Error("failed marshaling response")
				return
			}

			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
				log.L.WithError(err).Error("failed sending message on channel")
				return
			}

			// Only replies to dispatched requests were counted in active.
			if !response.immediate {
				active--
			}
		case err := <-recvErr:
			// The receiver has stopped, either on a terminal error or
			// because it observed shutdown (in which case err is nil from
			// the closed channel). Disable this case; outstanding requests
			// are still answered before the loop exits.
			recvErr = nil // connection is now "closing"
			if err != nil && err != io.EOF {
				log.L.WithError(err).Error("error receiving message")
			}
		case <-shutdown:
			return
		}
	}
}