ttrpc: implement unix socket credentials
Because ttrpc can be used with abstract sockets, it is critical to ensure that only certain users can connect to the unix socket. This is of particular interest in the primary use case of containerd, where a shim may run as root and any user can connection. With this, we get a few nice features. The first is the concept of a `Handshaker` that allows one to intercept each connection and replace it with one of their own. The enables credential checks and other measures, such as tls. The second is that servers now support configuration. This allows one to inject a handshaker for each connection. Other options will be added in the future. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
parent
8c92e22ce0
commit
d4983e717b
23
config.go
Normal file
23
config.go
Normal file
@ -0,0 +1,23 @@
|
||||
package ttrpc
|
||||
|
||||
import "github.com/pkg/errors"
|
||||
|
||||
type serverConfig struct {
|
||||
handshaker Handshaker
|
||||
}
|
||||
|
||||
type ServerOpt func(*serverConfig) error
|
||||
|
||||
// WithServerHandshaker can be passed to NewServer to ensure that the
|
||||
// handshaker is called before every connection attempt.
|
||||
//
|
||||
// Only one handshaker is allowed per server.
|
||||
func WithServerHandshaker(handshaker Handshaker) ServerOpt {
|
||||
return func(c *serverConfig) error {
|
||||
if c.handshaker != nil {
|
||||
return errors.New("only one handshaker allowed per server")
|
||||
}
|
||||
c.handshaker = handshaker
|
||||
return nil
|
||||
}
|
||||
}
|
24
handshake.go
Normal file
24
handshake.go
Normal file
@ -0,0 +1,24 @@
|
||||
package ttrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Handshaker defines the interface for connection handshakes performed on the
|
||||
// server or client when first connecting.
|
||||
type Handshaker interface {
|
||||
// Handshake should confirm or decorate a connection that may be incoming
|
||||
// to a server or outgoing from a client.
|
||||
//
|
||||
// If this returns without an error, the caller should use the connection
|
||||
// in place of the original connection.
|
||||
//
|
||||
// The second return value can contain credential specific data, such as
|
||||
// unix socket credentials or TLS information.
|
||||
//
|
||||
// While we currently only have implementations on the server-side, this
|
||||
// interface should be sufficient to implement similar handshakes on the
|
||||
// client-side.
|
||||
Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
|
||||
}
|
47
server.go
47
server.go
@ -19,6 +19,7 @@ var (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
config *serverConfig
|
||||
services *serviceSet
|
||||
codec codec
|
||||
|
||||
@ -28,13 +29,21 @@ type Server struct {
|
||||
done chan struct{} // marks point at which we stop serving requests
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
func NewServer(opts ...ServerOpt) (*Server, error) {
|
||||
config := &serverConfig{}
|
||||
for _, opt := range opts {
|
||||
if err := opt(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
config: config,
|
||||
services: newServiceSet(),
|
||||
done: make(chan struct{}),
|
||||
listeners: make(map[net.Listener]struct{}),
|
||||
connections: make(map[*serverConn]struct{}),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Register(name string, methods map[string]Method) {
|
||||
@ -82,7 +91,15 @@ func (s *Server) Serve(l net.Listener) error {
|
||||
}
|
||||
|
||||
backoff = 0
|
||||
sc := s.newConn(conn)
|
||||
|
||||
approved, handshake, err := s.handshake(ctx, conn)
|
||||
if err != nil {
|
||||
log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake")
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
sc := s.newConn(approved, handshake)
|
||||
go sc.run(ctx)
|
||||
}
|
||||
}
|
||||
@ -133,6 +150,14 @@ func (s *Server) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||
if s.config.handshaker == nil {
|
||||
return conn, nil, nil
|
||||
}
|
||||
|
||||
return s.config.handshaker.Handshake(ctx, conn)
|
||||
}
|
||||
|
||||
func (s *Server) addListener(l net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@ -205,11 +230,12 @@ func (cs connState) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) newConn(conn net.Conn) *serverConn {
|
||||
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
|
||||
c := &serverConn{
|
||||
server: s,
|
||||
conn: conn,
|
||||
shutdown: make(chan struct{}),
|
||||
server: s,
|
||||
conn: conn,
|
||||
handshake: handshake,
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
c.setState(connStateIdle)
|
||||
s.addConnection(c)
|
||||
@ -217,9 +243,10 @@ func (s *Server) newConn(conn net.Conn) *serverConn {
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
server *Server
|
||||
conn net.Conn
|
||||
state atomic.Value
|
||||
server *Server
|
||||
conn net.Conn
|
||||
handshake interface{} // data from handshake, not used for now
|
||||
state atomic.Value
|
||||
|
||||
shutdownOnce sync.Once
|
||||
shutdown chan struct{} // forced shutdown, used by close
|
||||
|
@ -78,7 +78,7 @@ func init() {
|
||||
func TestServer(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
testImpl = &testingServer{}
|
||||
addr, listener = newTestListener(t)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -109,7 +109,7 @@ func TestServer(t *testing.T) {
|
||||
func BenchmarkRoundTrip(b *testing.B) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(b)(NewServer())
|
||||
testImpl = &testingServer{}
|
||||
addr, listener = newTestListener(b)
|
||||
client, cleanup = newTestClient(b, addr)
|
||||
@ -137,7 +137,7 @@ func BenchmarkRoundTrip(b *testing.B) {
|
||||
func TestServerNotFound(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -167,7 +167,7 @@ func TestServerNotFound(t *testing.T) {
|
||||
|
||||
func TestServerListenerClosed(t *testing.T) {
|
||||
var (
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
_, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
)
|
||||
@ -190,7 +190,7 @@ func TestServerShutdown(t *testing.T) {
|
||||
const ncalls = 5
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
@ -265,7 +265,7 @@ func TestServerShutdown(t *testing.T) {
|
||||
|
||||
func TestServerClose(t *testing.T) {
|
||||
var (
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
_, listener = newTestListener(t)
|
||||
startClose = make(chan struct{})
|
||||
errs = make(chan error, 1)
|
||||
@ -292,7 +292,7 @@ func TestServerClose(t *testing.T) {
|
||||
func TestOversizeCall(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -327,7 +327,7 @@ func TestOversizeCall(t *testing.T) {
|
||||
func TestClientEOF(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -360,6 +360,29 @@ func TestClientEOF(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketHandshake(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser)))
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
)
|
||||
defer cleanup()
|
||||
defer listener.Close()
|
||||
go func() {
|
||||
errs <- server.Serve(listener)
|
||||
}()
|
||||
|
||||
registerTestingService(server, &testingServer{})
|
||||
|
||||
var tp testPayload
|
||||
// server shutdown, but we still make a call.
|
||||
if err := client.Call(ctx, serviceName, "Test", &tp, &tp); err != nil {
|
||||
t.Fatalf("unexpected error making call: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func checkServerShutdown(t *testing.T, server *Server) {
|
||||
t.Helper()
|
||||
server.mu.Lock()
|
||||
@ -420,3 +443,14 @@ func newTestListener(t testing.TB) (string, net.Listener) {
|
||||
|
||||
return addr, listener
|
||||
}
|
||||
|
||||
func mustServer(t testing.TB) func(server *Server, err error) *Server {
|
||||
return func(server *Server, err error) *Server {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
}
|
||||
|
90
unixcreds.go
Normal file
90
unixcreds.go
Normal file
@ -0,0 +1,90 @@
|
||||
// +build linux freebsd solaris
|
||||
|
||||
package ttrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser)
|
||||
UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot)
|
||||
)
|
||||
|
||||
type UnixCredentialsFunc func(*unix.Ucred) error
|
||||
|
||||
func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||
uc, err := requireUnixSocket(conn)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket")
|
||||
}
|
||||
|
||||
fp, err := uc.File()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file")
|
||||
}
|
||||
defer fp.Close() // this gets duped and must be closed when this method is complete.
|
||||
|
||||
ucred, err := unix.GetsockoptUcred(int(fp.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: failed to retrieve socket peer credentials")
|
||||
}
|
||||
|
||||
if err := fn(ucred); err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: credential check failed")
|
||||
}
|
||||
|
||||
return uc, ucred, nil
|
||||
}
|
||||
|
||||
func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc {
|
||||
return func(ucred *unix.Ucred) error {
|
||||
return requireUidGid(ucred, uid, gid)
|
||||
}
|
||||
}
|
||||
|
||||
func requireRoot(ucred *unix.Ucred) error {
|
||||
return requireUidGid(ucred, 0, 0)
|
||||
}
|
||||
|
||||
func requireSameUser(ucred *unix.Ucred) error {
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "could not resolve current user")
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid)
|
||||
}
|
||||
|
||||
gid, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid)
|
||||
}
|
||||
|
||||
return requireUidGid(ucred, uint32(uid), uint32(gid))
|
||||
}
|
||||
|
||||
func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error {
|
||||
if (uid != ucred.Uid) || (gid != ucred.Gid) {
|
||||
return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func requireUnixSocket(conn net.Conn) (*net.UnixConn, error) {
|
||||
uc, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return nil, errors.New("a unix socket connection is required")
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user