ttrpc: use os.Getuid/os.Getgid directly
Because of issues with glibc, using the `os/user` package can cause when calling `user.Current()`. Neither the Go maintainers or glibc developers could be bothered to fix it, so we have to work around it by calling the uid and gid functions directly. This is probably better because we don't actually use much of the data provided in the `user.User` struct. This required some refactoring to have better control over when the uid and gid are resolved. Rather than checking the current user on every connection, we now resolve it once at initialization. To test that this provided an improvement in performance, a benchmark was added. Unfortunately, this exposed a regression in the performance of unix sockets in Go when `(*UnixConn).File` is called. The underlying culprit of this performance regression is still at large. The following open issues describe the underlying problem in more detail: https://github.com/golang/go/issues/13470 https://sourceware.org/bugzilla/show_bug.cgi?id=19341 In better news, I now have an entire herd of shaved yaks. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
parent
af6e7491e5
commit
256c17bccd
10
handshake.go
10
handshake.go
@ -22,3 +22,13 @@ type Handshaker interface {
|
||||
// client-side.
|
||||
Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
|
||||
}
|
||||
|
||||
type handshakerFunc func(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error)
|
||||
|
||||
func (fn handshakerFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||
return fn(ctx, conn)
|
||||
}
|
||||
|
||||
func noopHandshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||
return conn, nil, nil
|
||||
}
|
||||
|
22
server.go
22
server.go
@ -2,6 +2,7 @@ package ttrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
@ -55,10 +56,15 @@ func (s *Server) Serve(l net.Listener) error {
|
||||
defer s.closeListener(l)
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
backoff time.Duration
|
||||
ctx = context.Background()
|
||||
backoff time.Duration
|
||||
handshaker = s.config.handshaker
|
||||
)
|
||||
|
||||
if handshaker == nil {
|
||||
handshaker = handshakerFunc(noopHandshake)
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
@ -92,7 +98,7 @@ func (s *Server) Serve(l net.Listener) error {
|
||||
|
||||
backoff = 0
|
||||
|
||||
approved, handshake, err := s.handshake(ctx, conn)
|
||||
approved, handshake, err := handshaker.Handshake(ctx, conn)
|
||||
if err != nil {
|
||||
log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake")
|
||||
conn.Close()
|
||||
@ -150,14 +156,6 @@ func (s *Server) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||
if s.config.handshaker == nil {
|
||||
return conn, nil, nil
|
||||
}
|
||||
|
||||
return s.config.handshaker.Handshake(ctx, conn)
|
||||
}
|
||||
|
||||
func (s *Server) addListener(l net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@ -433,7 +431,7 @@ func (c *serverConn) run(sctx context.Context) {
|
||||
// branch. Basically, it means that we are no longer receiving
|
||||
// requests due to a terminal error.
|
||||
recvErr = nil // connection is now "closing"
|
||||
if err != nil {
|
||||
if err != nil && err != io.EOF {
|
||||
log.L.WithError(err).Error("error receiving message")
|
||||
}
|
||||
case <-shutdown:
|
||||
|
@ -106,34 +106,6 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRoundTrip(b *testing.B) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = mustServer(b)(NewServer())
|
||||
testImpl = &testingServer{}
|
||||
addr, listener = newTestListener(b)
|
||||
client, cleanup = newTestClient(b, addr)
|
||||
tclient = newTestingClient(client)
|
||||
)
|
||||
|
||||
defer listener.Close()
|
||||
defer cleanup()
|
||||
|
||||
registerTestingService(server, testImpl)
|
||||
|
||||
go server.Serve(listener)
|
||||
defer server.Shutdown(ctx)
|
||||
|
||||
var tp testPayload
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := tclient.Test(ctx, &tp); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerNotFound(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
@ -363,7 +335,7 @@ func TestClientEOF(t *testing.T) {
|
||||
func TestUnixSocketHandshake(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser)))
|
||||
server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser())))
|
||||
addr, listener = newTestListener(t)
|
||||
errs = make(chan error, 1)
|
||||
client, cleanup = newTestClient(t, addr)
|
||||
@ -383,6 +355,66 @@ func TestUnixSocketHandshake(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRoundTrip(b *testing.B) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = mustServer(b)(NewServer())
|
||||
testImpl = &testingServer{}
|
||||
addr, listener = newTestListener(b)
|
||||
client, cleanup = newTestClient(b, addr)
|
||||
tclient = newTestingClient(client)
|
||||
)
|
||||
|
||||
defer listener.Close()
|
||||
defer cleanup()
|
||||
|
||||
registerTestingService(server, testImpl)
|
||||
|
||||
go server.Serve(listener)
|
||||
defer server.Shutdown(ctx)
|
||||
|
||||
var tp testPayload
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := tclient.Test(ctx, &tp); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRoundTripUnixSocketCreds(b *testing.B) {
|
||||
// TODO(stevvooe): Right now, there is a 5x performance decrease when using
|
||||
// unix socket credentials. See (UnixCredentialsFunc).Handshake for
|
||||
// details.
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = mustServer(b)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser())))
|
||||
testImpl = &testingServer{}
|
||||
addr, listener = newTestListener(b)
|
||||
client, cleanup = newTestClient(b, addr)
|
||||
tclient = newTestingClient(client)
|
||||
)
|
||||
|
||||
defer listener.Close()
|
||||
defer cleanup()
|
||||
|
||||
registerTestingService(server, testImpl)
|
||||
|
||||
go server.Serve(listener)
|
||||
defer server.Shutdown(ctx)
|
||||
|
||||
var tp testPayload
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := tclient.Test(ctx, &tp); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkServerShutdown(t *testing.T, server *Server) {
|
||||
t.Helper()
|
||||
server.mu.Lock()
|
||||
|
50
unixcreds.go
50
unixcreds.go
@ -5,19 +5,13 @@ package ttrpc
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser)
|
||||
UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot)
|
||||
)
|
||||
|
||||
type UnixCredentialsFunc func(*unix.Ucred) error
|
||||
|
||||
func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) {
|
||||
@ -26,6 +20,9 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net
|
||||
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket")
|
||||
}
|
||||
|
||||
// TODO(stevvooe): Calling (*UnixConn).File causes a 5x performance
|
||||
// decrease vs just accessing the fd directly. Need to do some more
|
||||
// troubleshooting to isolate this to Go runtime or kernel.
|
||||
fp, err := uc.File()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file")
|
||||
@ -44,37 +41,32 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net
|
||||
return uc, ucred, nil
|
||||
}
|
||||
|
||||
func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc {
|
||||
func UnixSocketRequireUidGid(uid, gid int) UnixCredentialsFunc {
|
||||
return func(ucred *unix.Ucred) error {
|
||||
return requireUidGid(ucred, uid, gid)
|
||||
}
|
||||
}
|
||||
|
||||
func UnixSocketRequireRoot() UnixCredentialsFunc {
|
||||
return UnixSocketRequireUidGid(0, 0)
|
||||
}
|
||||
|
||||
// UnixSocketRequireSameUser resolves the current unix user and returns a
|
||||
// UnixCredentialsFunc that will validate incoming unix connections against the
|
||||
// current credentials.
|
||||
//
|
||||
// This is useful when using abstract sockets that are accessible by all users.
|
||||
func UnixSocketRequireSameUser() UnixCredentialsFunc {
|
||||
uid, gid := os.Getuid(), os.Getgid()
|
||||
return UnixSocketRequireUidGid(uid, gid)
|
||||
}
|
||||
|
||||
func requireRoot(ucred *unix.Ucred) error {
|
||||
return requireUidGid(ucred, 0, 0)
|
||||
}
|
||||
|
||||
func requireSameUser(ucred *unix.Ucred) error {
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "could not resolve current user")
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid)
|
||||
}
|
||||
|
||||
gid, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid)
|
||||
}
|
||||
|
||||
return requireUidGid(ucred, uint32(uid), uint32(gid))
|
||||
}
|
||||
|
||||
func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error {
|
||||
if (uid != ucred.Uid) || (gid != ucred.Gid) {
|
||||
func requireUidGid(ucred *unix.Ucred, uid, gid int) error {
|
||||
if (uid != -1 && uint32(uid) != ucred.Uid) || (gid != -1 && uint32(gid) != ucred.Gid) {
|
||||
return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials")
|
||||
}
|
||||
return nil
|
||||
|
Loading…
Reference in New Issue
Block a user