Merge pull request #13 from stevvooe/glibc-not-static
ttrpc: use os.Getuid/os.Getgid directly
This commit is contained in:
10
handshake.go
10
handshake.go
@@ -22,3 +22,13 @@ type Handshaker interface {
|
|||||||
// client-side.
|
// client-side.
|
||||||
Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
|
Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type handshakerFunc func(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
|
||||||
|
|
||||||
|
func (fn handshakerFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||||
|
return fn(ctx, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func noopHandshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||||
|
return conn, nil, nil
|
||||||
|
}
|
||||||
|
|||||||
22
server.go
22
server.go
@@ -2,6 +2,7 @@ package ttrpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -55,10 +56,15 @@ func (s *Server) Serve(l net.Listener) error {
|
|||||||
defer s.closeListener(l)
|
defer s.closeListener(l)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
backoff time.Duration
|
backoff time.Duration
|
||||||
|
handshaker = s.config.handshaker
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if handshaker == nil {
|
||||||
|
handshaker = handshakerFunc(noopHandshake)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -92,7 +98,7 @@ func (s *Server) Serve(l net.Listener) error {
|
|||||||
|
|
||||||
backoff = 0
|
backoff = 0
|
||||||
|
|
||||||
approved, handshake, err := s.handshake(ctx, conn)
|
approved, handshake, err := handshaker.Handshake(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake")
|
log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake")
|
||||||
conn.Close()
|
conn.Close()
|
||||||
@@ -150,14 +156,6 @@ func (s *Server) Close() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
|
||||||
if s.config.handshaker == nil {
|
|
||||||
return conn, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.config.handshaker.Handshake(ctx, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) addListener(l net.Listener) {
|
func (s *Server) addListener(l net.Listener) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -433,7 +431,7 @@ func (c *serverConn) run(sctx context.Context) {
|
|||||||
// branch. Basically, it means that we are no longer receiving
|
// branch. Basically, it means that we are no longer receiving
|
||||||
// requests due to a terminal error.
|
// requests due to a terminal error.
|
||||||
recvErr = nil // connection is now "closing"
|
recvErr = nil // connection is now "closing"
|
||||||
if err != nil {
|
if err != nil && err != io.EOF {
|
||||||
log.L.WithError(err).Error("error receiving message")
|
log.L.WithError(err).Error("error receiving message")
|
||||||
}
|
}
|
||||||
case <-shutdown:
|
case <-shutdown:
|
||||||
|
|||||||
@@ -106,34 +106,6 @@ func TestServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkRoundTrip(b *testing.B) {
|
|
||||||
var (
|
|
||||||
ctx = context.Background()
|
|
||||||
server = mustServer(b)(NewServer())
|
|
||||||
testImpl = &testingServer{}
|
|
||||||
addr, listener = newTestListener(b)
|
|
||||||
client, cleanup = newTestClient(b, addr)
|
|
||||||
tclient = newTestingClient(client)
|
|
||||||
)
|
|
||||||
|
|
||||||
defer listener.Close()
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
registerTestingService(server, testImpl)
|
|
||||||
|
|
||||||
go server.Serve(listener)
|
|
||||||
defer server.Shutdown(ctx)
|
|
||||||
|
|
||||||
var tp testPayload
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
if _, err := tclient.Test(ctx, &tp); err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerNotFound(t *testing.T) {
|
func TestServerNotFound(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
@@ -363,7 +335,7 @@ func TestClientEOF(t *testing.T) {
|
|||||||
func TestUnixSocketHandshake(t *testing.T) {
|
func TestUnixSocketHandshake(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser)))
|
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser())))
|
||||||
addr, listener = newTestListener(t)
|
addr, listener = newTestListener(t)
|
||||||
errs = make(chan error, 1)
|
errs = make(chan error, 1)
|
||||||
client, cleanup = newTestClient(t, addr)
|
client, cleanup = newTestClient(t, addr)
|
||||||
@@ -383,6 +355,66 @@ func TestUnixSocketHandshake(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkRoundTrip(b *testing.B) {
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
server = mustServer(b)(NewServer())
|
||||||
|
testImpl = &testingServer{}
|
||||||
|
addr, listener = newTestListener(b)
|
||||||
|
client, cleanup = newTestClient(b, addr)
|
||||||
|
tclient = newTestingClient(client)
|
||||||
|
)
|
||||||
|
|
||||||
|
defer listener.Close()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerTestingService(server, testImpl)
|
||||||
|
|
||||||
|
go server.Serve(listener)
|
||||||
|
defer server.Shutdown(ctx)
|
||||||
|
|
||||||
|
var tp testPayload
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := tclient.Test(ctx, &tp); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRoundTripUnixSocketCreds(b *testing.B) {
|
||||||
|
// TODO(stevvooe): Right now, there is a 5x performance decrease when using
|
||||||
|
// unix socket credentials. See (UnixCredentialsFunc).Handshake for
|
||||||
|
// details.
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
server = mustServer(b)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser())))
|
||||||
|
testImpl = &testingServer{}
|
||||||
|
addr, listener = newTestListener(b)
|
||||||
|
client, cleanup = newTestClient(b, addr)
|
||||||
|
tclient = newTestingClient(client)
|
||||||
|
)
|
||||||
|
|
||||||
|
defer listener.Close()
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
registerTestingService(server, testImpl)
|
||||||
|
|
||||||
|
go server.Serve(listener)
|
||||||
|
defer server.Shutdown(ctx)
|
||||||
|
|
||||||
|
var tp testPayload
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := tclient.Test(ctx, &tp); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkServerShutdown(t *testing.T, server *Server) {
|
func checkServerShutdown(t *testing.T, server *Server) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
server.mu.Lock()
|
server.mu.Lock()
|
||||||
|
|||||||
50
unixcreds.go
50
unixcreds.go
@@ -5,19 +5,13 @@ package ttrpc
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"os/user"
|
"os"
|
||||||
"strconv"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser)
|
|
||||||
UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot)
|
|
||||||
)
|
|
||||||
|
|
||||||
type UnixCredentialsFunc func(*unix.Ucred) error
|
type UnixCredentialsFunc func(*unix.Ucred) error
|
||||||
|
|
||||||
func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||||
@@ -26,6 +20,9 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net
|
|||||||
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket")
|
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(stevvooe): Calling (*UnixConn).File causes a 5x performance
|
||||||
|
// decrease vs just accessing the fd directly. Need to do some more
|
||||||
|
// troubleshooting to isolate this to Go runtime or kernel.
|
||||||
fp, err := uc.File()
|
fp, err := uc.File()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file")
|
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file")
|
||||||
@@ -44,37 +41,32 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net
|
|||||||
return uc, ucred, nil
|
return uc, ucred, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc {
|
func UnixSocketRequireUidGid(uid, gid int) UnixCredentialsFunc {
|
||||||
return func(ucred *unix.Ucred) error {
|
return func(ucred *unix.Ucred) error {
|
||||||
return requireUidGid(ucred, uid, gid)
|
return requireUidGid(ucred, uid, gid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UnixSocketRequireRoot() UnixCredentialsFunc {
|
||||||
|
return UnixSocketRequireUidGid(0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnixSocketRequireSameUser resolves the current unix user and returns a
|
||||||
|
// UnixCredentialsFunc that will validate incoming unix connections against the
|
||||||
|
// current credentials.
|
||||||
|
//
|
||||||
|
// This is useful when using abstract sockets that are accessible by all users.
|
||||||
|
func UnixSocketRequireSameUser() UnixCredentialsFunc {
|
||||||
|
uid, gid := os.Getuid(), os.Getgid()
|
||||||
|
return UnixSocketRequireUidGid(uid, gid)
|
||||||
|
}
|
||||||
|
|
||||||
func requireRoot(ucred *unix.Ucred) error {
|
func requireRoot(ucred *unix.Ucred) error {
|
||||||
return requireUidGid(ucred, 0, 0)
|
return requireUidGid(ucred, 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func requireSameUser(ucred *unix.Ucred) error {
|
func requireUidGid(ucred *unix.Ucred, uid, gid int) error {
|
||||||
u, err := user.Current()
|
if (uid != -1 && uint32(uid) != ucred.Uid) || (gid != -1 && uint32(gid) != ucred.Gid) {
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "could not resolve current user")
|
|
||||||
}
|
|
||||||
|
|
||||||
uid, err := strconv.ParseUint(u.Uid, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid)
|
|
||||||
}
|
|
||||||
|
|
||||||
gid, err := strconv.ParseUint(u.Gid, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid)
|
|
||||||
}
|
|
||||||
|
|
||||||
return requireUidGid(ucred, uint32(uid), uint32(gid))
|
|
||||||
}
|
|
||||||
|
|
||||||
func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error {
|
|
||||||
if (uid != ucred.Uid) || (gid != ucred.Gid) {
|
|
||||||
return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials")
|
return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
Reference in New Issue
Block a user