diff --git a/handshake.go b/handshake.go index 80e1cc1..a08ae8e 100644 --- a/handshake.go +++ b/handshake.go @@ -22,3 +22,13 @@ type Handshaker interface { // client-side. Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) } + +type handshakerFunc func(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) + +func (fn handshakerFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { + return fn(ctx, conn) +} + +func noopHandshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { + return conn, nil, nil +} diff --git a/server.go b/server.go index beae908..edfca0c 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package ttrpc import ( "context" + "io" "math/rand" "net" "sync" @@ -55,10 +56,15 @@ func (s *Server) Serve(l net.Listener) error { defer s.closeListener(l) var ( - ctx = context.Background() - backoff time.Duration + ctx = context.Background() + backoff time.Duration + handshaker = s.config.handshaker ) + if handshaker == nil { + handshaker = handshakerFunc(noopHandshake) + } + for { conn, err := l.Accept() if err != nil { @@ -92,7 +98,7 @@ func (s *Server) Serve(l net.Listener) error { backoff = 0 - approved, handshake, err := s.handshake(ctx, conn) + approved, handshake, err := handshaker.Handshake(ctx, conn) if err != nil { log.L.WithError(err).Errorf("ttrpc: refusing connection after handshake") conn.Close() @@ -150,14 +156,6 @@ func (s *Server) Close() error { return err } -func (s *Server) handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { - if s.config.handshaker == nil { - return conn, nil, nil - } - - return s.config.handshaker.Handshake(ctx, conn) -} - func (s *Server) addListener(l net.Listener) { s.mu.Lock() defer s.mu.Unlock() @@ -433,7 +431,7 @@ func (c *serverConn) run(sctx context.Context) { // branch. Basically, it means that we are no longer receiving // requests due to a terminal error. recvErr = nil // connection is now "closing" - if err != nil { + if err != nil && err != io.EOF { log.L.WithError(err).Error("error receiving message") } case <-shutdown: diff --git a/server_test.go b/server_test.go index 6ba8ddb..8be008b 100644 --- a/server_test.go +++ b/server_test.go @@ -106,34 +106,6 @@ func TestServer(t *testing.T) { } } -func BenchmarkRoundTrip(b *testing.B) { - var ( - ctx = context.Background() - server = mustServer(b)(NewServer()) - testImpl = &testingServer{} - addr, listener = newTestListener(b) - client, cleanup = newTestClient(b, addr) - tclient = newTestingClient(client) - ) - - defer listener.Close() - defer cleanup() - - registerTestingService(server, testImpl) - - go server.Serve(listener) - defer server.Shutdown(ctx) - - var tp testPayload - b.ResetTimer() - - for i := 0; i < b.N; i++ { - if _, err := tclient.Test(ctx, &tp); err != nil { - b.Fatal(err) - } - } -} - func TestServerNotFound(t *testing.T) { var ( ctx = context.Background() @@ -363,7 +335,7 @@ func TestClientEOF(t *testing.T) { func TestUnixSocketHandshake(t *testing.T) { var ( ctx = context.Background() - server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser))) + server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser()))) addr, listener = newTestListener(t) errs = make(chan error, 1) client, cleanup = newTestClient(t, addr) @@ -383,6 +355,66 @@ func TestUnixSocketHandshake(t *testing.T) { } } +func BenchmarkRoundTrip(b *testing.B) { + var ( + ctx = context.Background() + server = mustServer(b)(NewServer()) + testImpl = &testingServer{} + addr, listener = newTestListener(b) + client, cleanup = newTestClient(b, addr) + tclient = newTestingClient(client) + ) + + defer listener.Close() + defer cleanup() + + registerTestingService(server, testImpl) + + go server.Serve(listener) + defer server.Shutdown(ctx) + + var tp testPayload + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, err := tclient.Test(ctx, &tp); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkRoundTripUnixSocketCreds(b *testing.B) { + // TODO(stevvooe): Right now, there is a 5x performance decrease when using + // unix socket credentials. See (UnixCredentialsFunc).Handshake for + // details. + + var ( + ctx = context.Background() + server = mustServer(b)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser()))) + testImpl = &testingServer{} + addr, listener = newTestListener(b) + client, cleanup = newTestClient(b, addr) + tclient = newTestingClient(client) + ) + + defer listener.Close() + defer cleanup() + + registerTestingService(server, testImpl) + + go server.Serve(listener) + defer server.Shutdown(ctx) + + var tp testPayload + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, err := tclient.Test(ctx, &tp); err != nil { + b.Fatal(err) + } + } +} + func checkServerShutdown(t *testing.T, server *Server) { t.Helper() server.mu.Lock() diff --git a/unixcreds.go b/unixcreds.go index 776761d..5de8b92 100644 --- a/unixcreds.go +++ b/unixcreds.go @@ -5,19 +5,13 @@ package ttrpc import ( "context" "net" - "os/user" - "strconv" + "os" "syscall" "github.com/pkg/errors" "golang.org/x/sys/unix" ) -var ( - UnixSocketRequireSameUser = UnixCredentialsFunc(requireSameUser) - UnixSocketRequireRoot = UnixCredentialsFunc(requireRoot) -) - type UnixCredentialsFunc func(*unix.Ucred) error func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net.Conn, interface{}, error) { @@ -26,6 +20,9 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: require unix socket") } + // TODO(stevvooe): Calling (*UnixConn).File causes a 5x performance + // decrease vs just accessing the fd directly. Need to do some more + // troubleshooting to isolate this to Go runtime or kernel. fp, err := uc.File() if err != nil { return nil, nil, errors.Wrap(err, "ttrpc.UnixCredentialsFunc: failed to get unix file") @@ -44,37 +41,32 @@ func (fn UnixCredentialsFunc) Handshake(ctx context.Context, conn net.Conn) (net return uc, ucred, nil } -func UnixSocketRequireUidGid(uid, gid uint32) UnixCredentialsFunc { +func UnixSocketRequireUidGid(uid, gid int) UnixCredentialsFunc { return func(ucred *unix.Ucred) error { return requireUidGid(ucred, uid, gid) } } +func UnixSocketRequireRoot() UnixCredentialsFunc { + return UnixSocketRequireUidGid(0, 0) +} + +// UnixSocketRequireSameUser resolves the current unix user and returns a +// UnixCredentialsFunc that will validate incoming unix connections against the +// current credentials. +// +// This is useful when using abstract sockets that are accessible by all users. +func UnixSocketRequireSameUser() UnixCredentialsFunc { + uid, gid := os.Getuid(), os.Getgid() + return UnixSocketRequireUidGid(uid, gid) +} + func requireRoot(ucred *unix.Ucred) error { return requireUidGid(ucred, 0, 0) } -func requireSameUser(ucred *unix.Ucred) error { - u, err := user.Current() - if err != nil { - return errors.Wrapf(err, "could not resolve current user") - } - - uid, err := strconv.ParseUint(u.Uid, 10, 32) - if err != nil { - return errors.Wrapf(err, "failed to parse current user uid: %v", u.Uid) - } - - gid, err := strconv.ParseUint(u.Gid, 10, 32) - if err != nil { - return errors.Wrapf(err, "failed to parse current user gid: %v", u.Gid) - } - - return requireUidGid(ucred, uint32(uid), uint32(gid)) -} - -func requireUidGid(ucred *unix.Ucred, uid, gid uint32) error { - if (uid != ucred.Uid) || (gid != ucred.Gid) { +func requireUidGid(ucred *unix.Ucred, uid, gid int) error { + if (uid != -1 && uint32(uid) != ucred.Uid) || (gid != -1 && uint32(gid) != ucred.Gid) { return errors.Wrap(syscall.EPERM, "ttrpc: invalid credentials") } return nil