From 256c17bccdb9c3892a78b7ee82c534137672716a Mon Sep 17 00:00:00 2001 From: Stephen J Day Date: Thu, 30 Nov 2017 20:21:50 -0800 Subject: [PATCH] ttrpc: use os.Getuid/os.Getgid directly Because of issues with glibc, using the `os/user` package can cause when calling `user.Current()`. Neither the Go maintainers or glibc developers could be bothered to fix it, so we have to work around it by calling the uid and gid functions directly. This is probably better because we don't actually use much of the data provided in the `user.User` struct. This required some refactoring to have better control over when the uid and gid are resolved. Rather than checking the current user on every connection, we now resolve it once at initialization. To test that this provided an improvement in performance, a benchmark was added. Unfortunately, this exposed a regression in the performance of unix sockets in Go when `(*UnixConn).File` is called. The underlying culprit of this performance regression is still at large. The following open issues describe the underlying problem in more detail: https://github.com/golang/go/issues/13470 https://sourceware.org/bugzilla/show_bug.cgi?id=19341 In better news, I now have an entire herd of shaved yaks. Signed-off-by: Stephen J Day --- handshake.go | 10 ++++++ server.go | 22 ++++++------ server_test.go | 90 ++++++++++++++++++++++++++++++++++---------------- unixcreds.go | 50 ++++++++++++---------------- 4 files changed, 102 insertions(+), 70 deletions(-) diff --git a/handshake.go b/handshake.go index 80e1cc1..a08ae8e 100644 --- a/handshake.go +++ b/handshake.go @@ -22,3 +22,13 @@ type Handshaker interface { // client-side. Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) } + +type handshakerFunc func(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) + +func (fn handshakerFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { + return fn(ctx, conn) +} + +func noopHandshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { + return conn, nil, nil +} diff --git a/server.go b/server.go index beae908..edfca0c 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package ttrpc import ( "context" + "io" "math/rand" "net" "sync" @@ -55,10 +56,15 @@ func (s *Server) Serve(l net.Listener) error { defer s.closeListener(l) var ( - ctx = context.Background() - backoff time.Duration + ctx = context.Background() + backoff time.Duration + handshaker = s.config.handshaker ) + if handshaker == nil { + handshaker = handshakerFunc(noopHandshake) + } + for { conn, err := l.Accept() if err != nil { @@ -92,7 +98,7 @@ func (s *Server) Serve(l net.Listener) error { backoff = 0 - approved, handshake, err := s.handshake(ctx, conn) + approved, handshake, err := handshaker.Handshake(ctx, conn) if err != nil { log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake") conn.Close() @@ -150,14 +156,6 @@ func (s *Server) Close() error { return err } -func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { - if s.config.handshaker == nil { - return conn, nil, nil - } - - return s.config.handshaker.Handshake(ctx, conn) -} - func (s *Server) addListener(l net.Listener) { s.mu.Lock() defer s.mu.Unlock() @@ -433,7 +431,7 @@ func (c *serverConn) run(sctx context.Context) { // branch. Basically, it means that we are no longer receiving // requests due to a terminal error. recvErr = nil // connection is now "closing" - if err != nil { + if err != nil && err != io.EOF { log.L.WithError(err).Error("error receiving message") } case <-shutdown: diff --git a/server_test.go b/server_test.go index 6ba8ddb..8be008b 100644 --- a/server_test.go +++ b/server_test.go @@ -106,34 +106,6 @@ func TestServer(t *testing.T) { } } -func BenchmarkRoundTrip(b *testing.B) { - var ( - ctx = context.Background() - server = mustServer(b)(NewServer()) - testImpl = &testingServer{} - addr, listener = newTestListener(b) - client, cleanup = newTestClient(b, addr) - tclient = newTestingClient(client) - ) - - defer listener.Close() - defer cleanup() - - registerTestingService(server, testImpl) - - go server.Serve(listener) - defer server.Shutdown(ctx) - - var tp testPayload - b.ResetTimer() - - for i := 0; i < b.N; i++ { - if _, err := tclient.Test(ctx, &tp); err != nil { - b.Fatal(err) - } - } -} - func TestServerNotFound(t *testing.T) { var ( ctx = context.Background() @@ -363,7 +335,7 @@ func TestClientEOF(t *testing.T) { func TestUnixSocketHandshake(t *testing.T) { var ( ctx = context.Background() - server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser))) + server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser()))) addr, listener = newTestListener(t) errs = make(chan error, 1) client, cleanup = newTestClient(t, addr) @@ -383,6 +355,66 @@ func TestUnixSocketHandshake(t *testing.T) { } } +func BenchmarkRoundTrip(b *testing.B) { + var ( + ctx = context.Background() + server = mustServer(b)(NewServer()) + testImpl = &testingServer{} + addr, listener = newTestListener(b) + client, cleanup = newTestClient(b, addr) + tclient = newTestingClient(client) + ) + + defer listener.Close() + defer cleanup() + + registerTestingService(server, testImpl) + + go server.Serve(listener) + defer server.Shutdown(ctx) + + var tp testPayload + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, err := tclient.Test(ctx, &tp); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkRoundTripUnixSocketCreds(b *testing.B) { + // TODO(stevvooe): Right now, there is a 5x performance decrease when using + // unix socket credentials. See (UnixCredentialsFunc).Handshake for + // details. + + var ( + ctx = context.Background() + server = mustServer(b)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser()))) + testImpl = &testingServer{} + addr, listener = newTestListener(b) + client, cleanup = newTestClient(b, addr) + tclient = newTestingClient(client) + ) + + defer listener.Close() + defer cleanup() + + registerTestingService(server, testImpl) + + go server.Serve(listener) + defer server.Shutdown(ctx) + + var tp testPayload + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, err := tclient.Test(ctx, &tp); err != nil { + b.Fatal(err) + } + } +} + func checkServerShutdown(t *testing.T, server *Server) { t.Helper() server.mu.Lock() diff --git a/unixcreds.go b/unixcreds.go index 776761d..5de8b92 100644 --- a/unixcreds.go +++ b/unixcreds.go @@ -5,19 +5,13 @@ package ttrpc import ( "context" "net" - "os/user" - "strconv" + "os" "syscall" "github.com/pkg/errors" "golang.org/x/sys/unix" ) -var ( - UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser) - UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot) -) - type UnixCredentialsFunc func(*unix.Ucred) error func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { @@ -26,6 +20,9 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket") } + // TODO(stevvooe): Calling (*UnixConn).File causes a 5x performance + // decrease vs just accessing the fd directly. Need to do some more + // troubleshooting to isolate this to Go runtime or kernel. fp, err := uc.File() if err != nil { return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file") @@ -44,37 +41,32 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net return uc, ucred, nil } -func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc { +func UnixSocketRequireUidGid(uid, gid int) UnixCredentialsFunc { return func(ucred *unix.Ucred) error { return requireUidGid(ucred, uid, gid) } } +func UnixSocketRequireRoot() UnixCredentialsFunc { + return UnixSocketRequireUidGid(0, 0) +} + +// UnixSocketRequireSameUser resolves the current unix user and returns a +// UnixCredentialsFunc that will validate incoming unix connections against the +// current credentials. +// +// This is useful when using abstract sockets that are accessible by all users. +func UnixSocketRequireSameUser() UnixCredentialsFunc { + uid, gid := os.Getuid(), os.Getgid() + return UnixSocketRequireUidGid(uid, gid) +} + func requireRoot(ucred *unix.Ucred) error { return requireUidGid(ucred, 0, 0) } -func requireSameUser(ucred *unix.Ucred) error { - u, err := user.Current() - if err != nil { - return errors.Wrapf(err, "could not resolve current user") - } - - uid, err := strconv.ParseUint(u.Uid, 10, 32) - if err != nil { - return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid) - } - - gid, err := strconv.ParseUint(u.Gid, 10, 32) - if err != nil { - return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid) - } - - return requireUidGid(ucred, uint32(uid), uint32(gid)) -} - -func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error { - if (uid != ucred.Uid) || (gid != ucred.Gid) { +func requireUidGid(ucred *unix.Ucred, uid, gid int) error { + if (uid != -1 && uint32(uid) != ucred.Uid) || (gid != -1 && uint32(gid) != ucred.Gid) { return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials") } return nil