diff --git a/config.go b/config.go new file mode 100644 index 0000000..23bc603 --- /dev/null +++ b/config.go @@ -0,0 +1,23 @@ +package ttrpc + +import "github.com/pkg/errors" + +type serverConfig struct { + handshaker Handshaker +} + +type ServerOpt func(*serverConfig) error + +// WithServerHandshaker can be passed to NewServer to ensure that the +// handshaker is called before every connection attempt. +// +// Only one handshaker is allowed per server. +func WithServerHandshaker(handshaker Handshaker) ServerOpt { + return func(c *serverConfig) error { + if c.handshaker != nil { + return errors.New("only one handshaker allowed per server") + } + c.handshaker = handshaker + return nil + } +} diff --git a/handshake.go b/handshake.go new file mode 100644 index 0000000..80e1cc1 --- /dev/null +++ b/handshake.go @@ -0,0 +1,24 @@ +package ttrpc + +import ( + "context" + "net" +) + +// Handshaker defines the interface for connection handshakes performed on the +// server or client when first connecting. +type Handshaker interface { + // Handshake should confirm or decorate a connection that may be incoming + // to a server or outgoing from a client. + // + // If this returns without an error, the caller should use the connection + // in place of the original connection. + // + // The second return value can contain credential specific data, such as + // unix socket credentials or TLS information. + // + // While we currently only have implementations on the server-side, this + // interface should be sufficient to implement similar handshakes on the + // client-side. + Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) +} diff --git a/server.go b/server.go index ed2d14c..beae908 100644 --- a/server.go +++ b/server.go @@ -19,6 +19,7 @@ var ( ) type Server struct { + config *serverConfig services *serviceSet codec codec @@ -28,13 +29,21 @@ type Server struct { done chan struct{} // marks point at which we stop serving requests } -func NewServer() *Server { +func NewServer(opts ...ServerOpt) (*Server, error) { + config := &serverConfig{} + for _, opt := range opts { + if err := opt(config); err != nil { + return nil, err + } + } + return &Server{ + config: config, services: newServiceSet(), done: make(chan struct{}), listeners: make(map[net.Listener]struct{}), connections: make(map[*serverConn]struct{}), - } + }, nil } func (s *Server) Register(name string, methods map[string]Method) { @@ -82,7 +91,15 @@ func (s *Server) Serve(l net.Listener) error { } backoff = 0 - sc := s.newConn(conn) + + approved, handshake, err := s.handshake(ctx, conn) + if err != nil { + log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake") + conn.Close() + continue + } + + sc := s.newConn(approved, handshake) go sc.run(ctx) } } @@ -133,6 +150,14 @@ func (s *Server) Close() error { return err } +func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { + if s.config.handshaker == nil { + return conn, nil, nil + } + + return s.config.handshaker.Handshake(ctx, conn) +} + func (s *Server) addListener(l net.Listener) { s.mu.Lock() defer s.mu.Unlock() @@ -205,11 +230,12 @@ func (cs connState) String() string { } } -func (s *Server) newConn(conn net.Conn) *serverConn { +func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn { c := &serverConn{ - server: s, - conn: conn, - shutdown: make(chan struct{}), + server: s, + conn: conn, + handshake: handshake, + shutdown: make(chan struct{}), } c.setState(connStateIdle) s.addConnection(c) @@ -217,9 +243,10 @@ func (s *Server) newConn(conn net.Conn) *serverConn { } type serverConn struct { - server *Server - conn net.Conn - state atomic.Value + server *Server + conn net.Conn + handshake interface{} // data from handshake, not used for now + state atomic.Value shutdownOnce sync.Once shutdown chan struct{} // forced shutdown, used by close diff --git a/server_test.go b/server_test.go index 9f84e33..6ba8ddb 100644 --- a/server_test.go +++ b/server_test.go @@ -78,7 +78,7 @@ func init() { func TestServer(t *testing.T) { var ( ctx = context.Background() - server = NewServer() + server = mustServer(t)(NewServer()) testImpl = &testingServer{} addr, listener = newTestListener(t) client, cleanup = newTestClient(t, addr) @@ -109,7 +109,7 @@ func TestServer(t *testing.T) { func BenchmarkRoundTrip(b *testing.B) { var ( ctx = context.Background() - server = NewServer() + server = mustServer(b)(NewServer()) testImpl = &testingServer{} addr, listener = newTestListener(b) client, cleanup = newTestClient(b, addr) @@ -137,7 +137,7 @@ func BenchmarkRoundTrip(b *testing.B) { func TestServerNotFound(t *testing.T) { var ( ctx = context.Background() - server = NewServer() + server = mustServer(t)(NewServer()) addr, listener = newTestListener(t) errs = make(chan error, 1) client, cleanup = newTestClient(t, addr) @@ -167,7 +167,7 @@ func TestServerNotFound(t *testing.T) { func TestServerListenerClosed(t *testing.T) { var ( - server = NewServer() + server = mustServer(t)(NewServer()) _, listener = newTestListener(t) errs = make(chan error, 1) ) @@ -190,7 +190,7 @@ func TestServerShutdown(t *testing.T) { const ncalls = 5 var ( ctx = context.Background() - server = NewServer() + server = mustServer(t)(NewServer()) addr, listener = newTestListener(t) shutdownStarted = make(chan struct{}) shutdownFinished = make(chan struct{}) @@ -265,7 +265,7 @@ func TestServerShutdown(t *testing.T) { func TestServerClose(t *testing.T) { var ( - server = NewServer() + server = mustServer(t)(NewServer()) _, listener = newTestListener(t) startClose = make(chan struct{}) errs = make(chan error, 1) @@ -292,7 +292,7 @@ func TestServerClose(t *testing.T) { func TestOversizeCall(t *testing.T) { var ( ctx = context.Background() - server = NewServer() + server = mustServer(t)(NewServer()) addr, listener = newTestListener(t) errs = make(chan error, 1) client, cleanup = newTestClient(t, addr) @@ -327,7 +327,7 @@ func TestOversizeCall(t *testing.T) { func TestClientEOF(t *testing.T) { var ( ctx = context.Background() - server = NewServer() + server = mustServer(t)(NewServer()) addr, listener = newTestListener(t) errs = make(chan error, 1) client, cleanup = newTestClient(t, addr) @@ -360,6 +360,29 @@ func TestClientEOF(t *testing.T) { } } +func TestUnixSocketHandshake(t *testing.T) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser))) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + client, cleanup = newTestClient(t, addr) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(listener) + }() + + registerTestingService(server, &testingServer{}) + + var tp testPayload + // server shutdown, but we still make a call. + if err := client.Call(ctx, serviceName, "Test", &tp, &tp); err != nil { + t.Fatalf("unexpected error making call: %v", err) + } +} + func checkServerShutdown(t *testing.T, server *Server) { t.Helper() server.mu.Lock() @@ -420,3 +443,14 @@ func newTestListener(t testing.TB) (string, net.Listener) { return addr, listener } + +func mustServer(t testing.TB) func(server *Server, err error) *Server { + return func(server *Server, err error) *Server { + t.Helper() + if err != nil { + t.Fatal(err) + } + + return server + } +} diff --git a/unixcreds.go b/unixcreds.go new file mode 100644 index 0000000..776761d --- /dev/null +++ b/unixcreds.go @@ -0,0 +1,90 @@ +// +build linux freebsd solaris + +package ttrpc + +import ( + "context" + "net" + "os/user" + "strconv" + "syscall" + + "github.com/pkg/errors" + "golang.org/x/sys/unix" +) + +var ( + UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser) + UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot) +) + +type UnixCredentialsFunc func(*unix.Ucred) error + +func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { + uc, err := requireUnixSocket(conn) + if err != nil { + return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket") + } + + fp, err := uc.File() + if err != nil { + return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file") + } + defer fp.Close() // this gets duped and must be closed when this method is complete. + + ucred, err := unix.GetsockoptUcred(int(fp.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED) + if err != nil { + return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: failed to retrieve socket peer credentials") + } + + if err := fn(ucred); err != nil { + return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: credential check failed") + } + + return uc, ucred, nil +} + +func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc { + return func(ucred *unix.Ucred) error { + return requireUidGid(ucred, uid, gid) + } +} + +func requireRoot(ucred *unix.Ucred) error { + return requireUidGid(ucred, 0, 0) +} + +func requireSameUser(ucred *unix.Ucred) error { + u, err := user.Current() + if err != nil { + return errors.Wrapf(err, "could not resolve current user") + } + + uid, err := strconv.ParseUint(u.Uid, 10, 32) + if err != nil { + return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid) + } + + gid, err := strconv.ParseUint(u.Gid, 10, 32) + if err != nil { + return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid) + } + + return requireUidGid(ucred, uint32(uid), uint32(gid)) +} + +func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error { + if (uid != ucred.Uid) || (gid != ucred.Gid) { + return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials") + } + return nil +} + +func requireUnixSocket(conn net.Conn) (*net.UnixConn, error) { + uc, ok := conn.(*net.UnixConn) + if !ok { + return nil, errors.New("a unix socket connection is required") + } + + return uc, nil +}