Merge pull request #11 from stevvooe/unix-socket-credentials
ttrpc: implement unix socket credentials
This commit is contained in:
commit
af6e7491e5
23
config.go
Normal file
23
config.go
Normal file
@ -0,0 +1,23 @@
|
||||
package ttrpc
|
||||
|
||||
import "github.com/pkg/errors"
|
||||
|
||||
type serverConfig struct {
|
||||
handshaker Handshaker
|
||||
}
|
||||
|
||||
type ServerOpt func(*serverConfig) error
|
||||
|
||||
// WithServerHandshaker can be passed to NewServer to ensure that the
|
||||
// handshaker is called before every connection attempt.
|
||||
//
|
||||
// Only one handshaker is allowed per server.
|
||||
func WithServerHandshaker(handshaker Handshaker) ServerOpt {
|
||||
return func(c *serverConfig) error {
|
||||
if c.handshaker != nil {
|
||||
return errors.New("only one handshaker allowed per server")
|
||||
}
|
||||
c.handshaker = handshaker
|
||||
return nil
|
||||
}
|
||||
}
|
24
handshake.go
Normal file
24
handshake.go
Normal file
@ -0,0 +1,24 @@
|
||||
package ttrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Handshaker defines the interface for connection handshakes performed on the
|
||||
// server or client when first connecting.
|
||||
type Handshaker interface {
|
||||
// Handshake should confirm or decorate a connection that may be incoming
|
||||
// to a server or outgoing from a client.
|
||||
//
|
||||
// If this returns without an error, the caller should use the connection
|
||||
// in place of the original connection.
|
||||
//
|
||||
// The second return value can contain credential specific data, such as
|
||||
// unix socket credentials or TLS information.
|
||||
//
|
||||
// While we currently only have implementations on the server-side, this
|
||||
// interface should be sufficient to implement similar handshakes on the
|
||||
// client-side.
|
||||
Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
|
||||
}
|
47
server.go
47
server.go
@ -19,6 +19,7 @@ var (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
config *serverConfig
|
||||
services *serviceSet
|
||||
codec codec
|
||||
|
||||
@ -28,13 +29,21 @@ type Server struct {
|
||||
done chan struct{} // marks point at which we stop serving requests
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
func NewServer(opts ...ServerOpt) (*Server, error) {
|
||||
config := &serverConfig{}
|
||||
for _, opt := range opts {
|
||||
if err := opt(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
config: config,
|
||||
services: newServiceSet(),
|
||||
done: make(chan struct{}),
|
||||
listeners: make(map[net.Listener]struct{}),
|
||||
connections: make(map[*serverConn]struct{}),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Register(name string, methods map[string]Method) {
|
||||
@ -82,7 +91,15 @@ func (s *Server) Serve(l net.Listener) error {
|
||||
}
|
||||
|
||||
backoff = 0
|
||||
sc := s.newConn(conn)
|
||||
|
||||
approved, handshake, err := s.handshake(ctx, conn)
|
||||
if err != nil {
|
||||
log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake")
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
sc := s.newConn(approved, handshake)
|
||||
go sc.run(ctx)
|
||||
}
|
||||
}
|
||||
@ -133,6 +150,14 @@ func (s *Server) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||
if s.config.handshaker == nil {
|
||||
return conn, nil, nil
|
||||
}
|
||||
|
||||
return s.config.handshaker.Handshake(ctx, conn)
|
||||
}
|
||||
|
||||
func (s *Server) addListener(l net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@ -205,11 +230,12 @@ func (cs connState) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) newConn(conn net.Conn) *serverConn {
|
||||
func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
|
||||
c := &serverConn{
|
||||
server: s,
|
||||
conn: conn,
|
||||
shutdown: make(chan struct{}),
|
||||
server: s,
|
||||
conn: conn,
|
||||
handshake: handshake,
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
c.setState(connStateIdle)
|
||||
s.addConnection(c)
|
||||
@ -217,9 +243,10 @@ func (s *Server) newConn(conn net.Conn) *serverConn {
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
server *Server
|
||||
conn net.Conn
|
||||
state atomic.Value
|
||||
server *Server
|
||||
conn net.Conn
|
||||
handshake interface{} // data from handshake, not used for now
|
||||
state atomic.Value
|
||||
|
||||
shutdownOnce sync.Once
|
||||
shutdown chan struct{} // forced shutdown, used by close
|
||||
|
@ -78,7 +78,7 @@ func init() {
|
||||
func TestServer(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
testImpl = &testingServer{}
|
||||
addr, listener = newTestListener(t)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -109,7 +109,7 @@ func TestServer(t *testing.T) {
|
||||
func BenchmarkRoundTrip(b *testing.B) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(b)(NewServer())
|
||||
testImpl = &testingServer{}
|
||||
addr, listener = newTestListener(b)
|
||||
client, cleanup = newTestClient(b, addr)
|
||||
@ -137,7 +137,7 @@ func BenchmarkRoundTrip(b *testing.B) {
|
||||
func TestServerNotFound(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -167,7 +167,7 @@ func TestServerNotFound(t *testing.T) {
|
||||
|
||||
func TestServerListenerClosed(t *testing.T) {
|
||||
var (
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
_, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
)
|
||||
@ -190,7 +190,7 @@ func TestServerShutdown(t *testing.T) {
|
||||
const ncalls = 5
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
shutdownStarted = make(chan struct{})
|
||||
shutdownFinished = make(chan struct{})
|
||||
@ -265,7 +265,7 @@ func TestServerShutdown(t *testing.T) {
|
||||
|
||||
func TestServerClose(t *testing.T) {
|
||||
var (
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
_, listener = newTestListener(t)
|
||||
startClose = make(chan struct{})
|
||||
errs = make(chan error, 1)
|
||||
@ -292,7 +292,7 @@ func TestServerClose(t *testing.T) {
|
||||
func TestOversizeCall(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -327,7 +327,7 @@ func TestOversizeCall(t *testing.T) {
|
||||
func TestClientEOF(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
server = mustServer(t)(NewServer())
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -360,6 +360,29 @@ func TestClientEOF(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketHandshake(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser)))
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
)
|
||||
defer cleanup()
|
||||
defer listener.Close()
|
||||
go func() {
|
||||
errs <- server.Serve(listener)
|
||||
}()
|
||||
|
||||
registerTestingService(server, &testingServer{})
|
||||
|
||||
var tp testPayload
|
||||
// server shutdown, but we still make a call.
|
||||
if err := client.Call(ctx, serviceName, "Test", &tp, &tp); err != nil {
|
||||
t.Fatalf("unexpected error making call: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func checkServerShutdown(t *testing.T, server *Server) {
|
||||
t.Helper()
|
||||
server.mu.Lock()
|
||||
@ -420,3 +443,14 @@ func newTestListener(t testing.TB) (string, net.Listener) {
|
||||
|
||||
return addr, listener
|
||||
}
|
||||
|
||||
func mustServer(t testing.TB) func(server *Server, err error) *Server {
|
||||
return func(server *Server, err error) *Server {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
}
|
||||
|
90
unixcreds.go
Normal file
90
unixcreds.go
Normal file
@ -0,0 +1,90 @@
|
||||
// +build linux freebsd solaris
|
||||
|
||||
package ttrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser)
|
||||
UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot)
|
||||
)
|
||||
|
||||
type UnixCredentialsFunc func(*unix.Ucred) error
|
||||
|
||||
func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||
uc, err := requireUnixSocket(conn)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket")
|
||||
}
|
||||
|
||||
fp, err := uc.File()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file")
|
||||
}
|
||||
defer fp.Close() // this gets duped and must be closed when this method is complete.
|
||||
|
||||
ucred, err := unix.GetsockoptUcred(int(fp.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: failed to retrieve socket peer credentials")
|
||||
}
|
||||
|
||||
if err := fn(ucred); err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "ttrpc.UnixCredentialsFunc: credential check failed")
|
||||
}
|
||||
|
||||
return uc, ucred, nil
|
||||
}
|
||||
|
||||
func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc {
|
||||
return func(ucred *unix.Ucred) error {
|
||||
return requireUidGid(ucred, uid, gid)
|
||||
}
|
||||
}
|
||||
|
||||
func requireRoot(ucred *unix.Ucred) error {
|
||||
return requireUidGid(ucred, 0, 0)
|
||||
}
|
||||
|
||||
func requireSameUser(ucred *unix.Ucred) error {
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "could not resolve current user")
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid)
|
||||
}
|
||||
|
||||
gid, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid)
|
||||
}
|
||||
|
||||
return requireUidGid(ucred, uint32(uid), uint32(gid))
|
||||
}
|
||||
|
||||
func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error {
|
||||
if (uid != ucred.Uid) || (gid != ucred.Gid) {
|
||||
return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func requireUnixSocket(conn net.Conn) (*net.UnixConn, error) {
|
||||
uc, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return nil, errors.New("a unix socket connection is required")
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user