ttrpc: implement unix socket credentials

Because ttrpc can be used with abstract sockets, it is critical to
ensure that only certain users can connect to the unix socket. This is
of particular interest in the primary use case of containerd, where a
shim may run as root and any user can connection.

With this, we get a few nice features. The first is the concept of a
`Handshaker` that allows one to intercept each connection and replace it
with one of their own. The enables credential checks and other measures,
such as tls. The second is that servers now support configuration. This
allows one to inject a handshaker for each connection. Other options
will be added in the future.

Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
Stephen J Day 2017-11-30 15:07:25 -08:00
parent 8c92e22ce0
commit d4983e717b
No known key found for this signature in database
GPG Key ID: 67B3DED84EDC823F
5 changed files with 216 additions and 18 deletions

23
config.go Normal file
View File

@ -0,0 +1,23 @@
package ttrpc
import "github.com/pkg/errors"
type serverConfig struct {
handshaker Handshaker
}
type ServerOpt func(*serverConfig) error
// WithServerHandshaker can be passed to NewServer to ensure that the
// handshaker is called before every connection attempt.
//
// Only one handshaker is allowed per server.
func WithServerHandshaker(handshaker Handshaker) ServerOpt {
return func(c *serverConfig) error {
if c.handshaker != nil {
return errors.New("only one handshaker allowed per server")
}
c.handshaker = handshaker
return nil
}
}

24
handshake.go Normal file
View File

@ -0,0 +1,24 @@
package ttrpc
import (
"context"
"net"
)
// Handshaker defines the interface for connection handshakes performed on the
// server or client when first connecting.
type Handshaker interface {
// Handshake should confirm or decorate a connection that may be incoming
// to a server or outgoing from a client.
//
// If this returns without an error, the caller should use the connection
// in place of the original connection.
//
// The second return value can contain credential specific data, such as
// unix socket credentials or TLS information.
//
// While we currently only have implementations on the server-side, this
// interface should be sufficient to implement similar handshakes on the
// client-side.
Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
}

View File

@ -19,6 +19,7 @@ var (
)
type Server struct {
config *serverConfig
services *serviceSet
codec codec
@ -28,13 +29,21 @@ type Server struct {
done chan struct{} // marks point at which we stop serving requests
}
func NewServer() *Server {
func NewServer(opts ...ServerOpt) (*Server, error) {
config := &serverConfig{}
for _, opt := range opts {
if err := opt(config); err != nil {
return nil, err
}
}
return &Server{
config: config,
services: newServiceSet(),
done: make(chan struct{}),
listeners: make(map[net.Listener]struct{}),
connections: make(map[*serverConn]struct{}),
}
}, nil
}
func (s *Server) Register(name string, methods map[string]Method) {
@ -82,7 +91,15 @@ func (s *Server) Serve(l net.Listener) error {
}
backoff = 0
sc := s.newConn(conn)
approved, handshake, err := s.handshake(ctx, conn)
if err != nil {
log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake")
conn.Close()
continue
}
sc := s.newConn(approved, handshake)
go sc.run(ctx)
}
}
@ -133,6 +150,14 @@ func (s *Server) Close() error {
return err
}
func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
if s.config.handshaker == nil {
return conn, nil, nil
}
return s.config.handshaker.Handshake(ctx, conn)
}
func (s *Server) addListener(l net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
@ -205,11 +230,12 @@ func (cs connState) String() string {
}
}
func (s *Server) newConn(conn net.Conn) *serverConn {
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
c := &serverConn{
server: s,
conn: conn,
shutdown: make(chan struct{}),
server: s,
conn: conn,
handshake: handshake,
shutdown: make(chan struct{}),
}
c.setState(connStateIdle)
s.addConnection(c)
@ -217,9 +243,10 @@ func (s *Server) newConn(conn net.Conn) *serverConn {
}
type serverConn struct {
server *Server
conn net.Conn
state atomic.Value
server *Server
conn net.Conn
handshake interface{} // data from handshake, not used for now
state atomic.Value
shutdownOnce sync.Once
shutdown chan struct{} // forced shutdown, used by close

View File

@ -78,7 +78,7 @@ func init() {
func TestServer(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
testImpl = &testingServer{}
addr, listener = newTestListener(t)
client, cleanup = newTestClient(t, addr)
@ -109,7 +109,7 @@ func TestServer(t *testing.T) {
func BenchmarkRoundTrip(b *testing.B) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(b)(NewServer())
testImpl = &testingServer{}
addr, listener = newTestListener(b)
client, cleanup = newTestClient(b, addr)
@ -137,7 +137,7 @@ func BenchmarkRoundTrip(b *testing.B) {
func TestServerNotFound(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
@ -167,7 +167,7 @@ func TestServerNotFound(t *testing.T) {
func TestServerListenerClosed(t *testing.T) {
var (
server = NewServer()
server = mustServer(t)(NewServer())
_, listener = newTestListener(t)
errs = make(chan error, 1)
)
@ -190,7 +190,7 @@ func TestServerShutdown(t *testing.T) {
const ncalls = 5
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
shutdownStarted = make(chan struct{})
shutdownFinished = make(chan struct{})
@ -265,7 +265,7 @@ func TestServerShutdown(t *testing.T) {
func TestServerClose(t *testing.T) {
var (
server = NewServer()
server = mustServer(t)(NewServer())
_, listener = newTestListener(t)
startClose = make(chan struct{})
errs = make(chan error, 1)
@ -292,7 +292,7 @@ func TestServerClose(t *testing.T) {
func TestOversizeCall(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
@ -327,7 +327,7 @@ func TestOversizeCall(t *testing.T) {
func TestClientEOF(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
@ -360,6 +360,29 @@ func TestClientEOF(t *testing.T) {
}
}
func TestUnixSocketHandshake(t *testing.T) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser)))
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
)
defer cleanup()
defer listener.Close()
go func() {
errs <- server.Serve(listener)
}()
registerTestingService(server, &testingServer{})
var tp testPayload
// server shutdown, but we still make a call.
if err := client.Call(ctx, serviceName, "Test", &tp, &tp); err != nil {
t.Fatalf("unexpected error making call: %v", err)
}
}
func checkServerShutdown(t *testing.T, server *Server) {
t.Helper()
server.mu.Lock()
@ -420,3 +443,14 @@ func newTestListener(t testing.TB) (string, net.Listener) {
return addr, listener
}
func mustServer(t testing.TB) func(server *Server, err error) *Server {
return func(server *Server, err error) *Server {
t.Helper()
if err != nil {
t.Fatal(err)
}
return server
}
}

90
unixcreds.go Normal file
View File

@ -0,0 +1,90 @@
// +build linux freebsd solaris
package ttrpc
import (
"context"
"net"
"os/user"
"strconv"
"syscall"
"github.com/pkg/errors"
"golang.org/x/sys/unix"
)
var (
UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser)
UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot)
)
type UnixCredentialsFunc func(*unix.Ucred) error
func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
uc, err := requireUnixSocket(conn)
if err != nil {
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket")
}
fp, err := uc.File()
if err != nil {
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file")
}
defer fp.Close() // this gets duped and must be closed when this method is complete.
ucred, err := unix.GetsockoptUcred(int(fp.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED)
if err != nil {
return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: failed to retrieve socket peer credentials")
}
if err := fn(ucred); err != nil {
return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: credential check failed")
}
return uc, ucred, nil
}
func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc {
return func(ucred *unix.Ucred) error {
return requireUidGid(ucred, uid, gid)
}
}
func requireRoot(ucred *unix.Ucred) error {
return requireUidGid(ucred, 0, 0)
}
func requireSameUser(ucred *unix.Ucred) error {
u, err := user.Current()
if err != nil {
return errors.Wrapf(err, "could not resolve current user")
}
uid, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid)
}
gid, err := strconv.ParseUint(u.Gid, 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid)
}
return requireUidGid(ucred, uint32(uid), uint32(gid))
}
func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error {
if (uid != ucred.Uid) || (gid != ucred.Gid) {
return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials")
}
return nil
}
func requireUnixSocket(conn net.Conn) (*net.UnixConn, error) {
uc, ok := conn.(*net.UnixConn)
if !ok {
return nil, errors.New("a unix socket connection is required")
}
return uc, nil
}