Merge pull request #11 from stevvooe/unix-socket-credentials

ttrpc: implement unix socket credentials
This commit is contained in:
Stephen Day 2017-11-30 16:55:30 -08:00 committed by GitHub
commit af6e7491e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 216 additions and 18 deletions

23
config.go Normal file
View File

@ -0,0 +1,23 @@
package ttrpc
import "github.com/pkg/errors"
type serverConfig struct {
handshaker Handshaker
}
type ServerOpt func(*serverConfig) error
// WithServerHandshaker can be passed to NewServer to ensure that the
// handshaker is called before every connection attempt.
//
// Only one handshaker is allowed per server.
func WithServerHandshaker(handshaker Handshaker) ServerOpt {
return func(c *serverConfig) error {
if c.handshaker != nil {
return errors.New("only one handshaker allowed per server")
}
c.handshaker = handshaker
return nil
}
}

24
handshake.go Normal file
View File

@ -0,0 +1,24 @@
package ttrpc
import (
"context"
"net"
)
// Handshaker defines the interface for connection handshakes performed on the
// server or client when first connecting.
type Handshaker interface {
// Handshake should confirm or decorate a connection that may be incoming
// to a server or outgoing from a client.
//
// If this returns without an error, the caller should use the connection
// in place of the original connection.
//
// The second return value can contain credential specific data, such as
// unix socket credentials or TLS information.
//
// While we currently only have implementations on the server-side, this
// interface should be sufficient to implement similar handshakes on the
// client-side.
Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
}

View File

@ -19,6 +19,7 @@ var (
)
type Server struct {
config *serverConfig
services *serviceSet
codec codec
@ -28,13 +29,21 @@ type Server struct {
done chan struct{} // marks point at which we stop serving requests
}
func NewServer() *Server {
func NewServer(opts ...ServerOpt) (*Server, error) {
config := &serverConfig{}
for _, opt := range opts {
if err := opt(config); err != nil {
return nil, err
}
}
return &Server{
config: config,
services: newServiceSet(),
done: make(chan struct{}),
listeners: make(map[net.Listener]struct{}),
connections: make(map[*serverConn]struct{}),
}
}, nil
}
func (s *Server) Register(name string, methods map[string]Method) {
@ -82,7 +91,15 @@ func (s *Server) Serve(l net.Listener) error {
}
backoff = 0
sc := s.newConn(conn)
approved, handshake, err := s.handshake(ctx, conn)
if err != nil {
log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake")
conn.Close()
continue
}
sc := s.newConn(approved, handshake)
go sc.run(ctx)
}
}
@ -133,6 +150,14 @@ func (s *Server) Close() error {
return err
}
func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
if s.config.handshaker == nil {
return conn, nil, nil
}
return s.config.handshaker.Handshake(ctx, conn)
}
func (s *Server) addListener(l net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
@ -205,10 +230,11 @@ func (cs connState) String() string {
}
}
func (s *Server) newConn(conn net.Conn) *serverConn {
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
c := &serverConn{
server: s,
conn: conn,
handshake: handshake,
shutdown: make(chan struct{}),
}
c.setState(connStateIdle)
@ -219,6 +245,7 @@ func (s *Server) newConn(conn net.Conn) *serverConn {
type serverConn struct {
server *Server
conn net.Conn
handshake interface{} // data from handshake, not used for now
state atomic.Value
shutdownOnce sync.Once

View File

@ -78,7 +78,7 @@ func init() {
func TestServer(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
testImpl = &testingServer{}
addr, listener = newTestListener(t)
client, cleanup = newTestClient(t, addr)
@ -109,7 +109,7 @@ func TestServer(t *testing.T) {
func BenchmarkRoundTrip(b *testing.B) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(b)(NewServer())
testImpl = &testingServer{}
addr, listener = newTestListener(b)
client, cleanup = newTestClient(b, addr)
@ -137,7 +137,7 @@ func BenchmarkRoundTrip(b *testing.B) {
func TestServerNotFound(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
@ -167,7 +167,7 @@ func TestServerNotFound(t *testing.T) {
func TestServerListenerClosed(t *testing.T) {
var (
server = NewServer()
server = mustServer(t)(NewServer())
_, listener = newTestListener(t)
errs = make(chan error, 1)
)
@ -190,7 +190,7 @@ func TestServerShutdown(t *testing.T) {
const ncalls = 5
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
shutdownStarted = make(chan struct{})
shutdownFinished = make(chan struct{})
@ -265,7 +265,7 @@ func TestServerShutdown(t *testing.T) {
func TestServerClose(t *testing.T) {
var (
server = NewServer()
server = mustServer(t)(NewServer())
_, listener = newTestListener(t)
startClose = make(chan struct{})
errs = make(chan error, 1)
@ -292,7 +292,7 @@ func TestServerClose(t *testing.T) {
func TestOversizeCall(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
@ -327,7 +327,7 @@ func TestOversizeCall(t *testing.T) {
func TestClientEOF(t *testing.T) {
var (
ctx = context.Background()
server = NewServer()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
@ -360,6 +360,29 @@ func TestClientEOF(t *testing.T) {
}
}
func TestUnixSocketHandshake(t *testing.T) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser)))
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
)
defer cleanup()
defer listener.Close()
go func() {
errs <- server.Serve(listener)
}()
registerTestingService(server, &testingServer{})
var tp testPayload
// server shutdown, but we still make a call.
if err := client.Call(ctx, serviceName, "Test", &tp, &tp); err != nil {
t.Fatalf("unexpected error making call: %v", err)
}
}
func checkServerShutdown(t *testing.T, server *Server) {
t.Helper()
server.mu.Lock()
@ -420,3 +443,14 @@ func newTestListener(t testing.TB) (string, net.Listener) {
return addr, listener
}
func mustServer(t testing.TB) func(server *Server, err error) *Server {
return func(server *Server, err error) *Server {
t.Helper()
if err != nil {
t.Fatal(err)
}
return server
}
}

90
unixcreds.go Normal file
View File

@ -0,0 +1,90 @@
// +build linux freebsd solaris
package ttrpc
import (
"context"
"net"
"os/user"
"strconv"
"syscall"
"github.com/pkg/errors"
"golang.org/x/sys/unix"
)
var (
UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser)
UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot)
)
type UnixCredentialsFunc func(*unix.Ucred) error
func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
uc, err := requireUnixSocket(conn)
if err != nil {
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket")
}
fp, err := uc.File()
if err != nil {
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file")
}
defer fp.Close() // this gets duped and must be closed when this method is complete.
ucred, err := unix.GetsockoptUcred(int(fp.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED)
if err != nil {
return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: failed to retrieve socket peer credentials")
}
if err := fn(ucred); err != nil {
return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: credential check failed")
}
return uc, ucred, nil
}
func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc {
return func(ucred *unix.Ucred) error {
return requireUidGid(ucred, uid, gid)
}
}
func requireRoot(ucred *unix.Ucred) error {
return requireUidGid(ucred, 0, 0)
}
func requireSameUser(ucred *unix.Ucred) error {
u, err := user.Current()
if err != nil {
return errors.Wrapf(err, "could not resolve current user")
}
uid, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid)
}
gid, err := strconv.ParseUint(u.Gid, 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid)
}
return requireUidGid(ucred, uint32(uid), uint32(gid))
}
func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error {
if (uid != ucred.Uid) || (gid != ucred.Gid) {
return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials")
}
return nil
}
func requireUnixSocket(conn net.Conn) (*net.UnixConn, error) {
uc, ok := conn.(*net.UnixConn)
if !ok {
return nil, errors.New("a unix socket connection is required")
}
return uc, nil
}