diff --git a/server.go b/server.go index 40804ea..1431251 100644 --- a/server.go +++ b/server.go @@ -452,6 +452,11 @@ func (c *serverConn) run(sctx context.Context) { if err != nil && err != io.EOF { logrus.WithError(err).Error("error receiving message") } + if err == io.EOF || err == io.ErrUnexpectedEOF { + // The client went away and we should stop processing + // requests, so that the client connection is closed + return + } case <-shutdown: return } diff --git a/server_test.go b/server_test.go index 4f31feb..9ee7e0d 100644 --- a/server_test.go +++ b/server_test.go @@ -28,6 +28,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pkg/errors" + "github.com/prometheus/procfs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -359,6 +360,48 @@ func TestClientEOF(t *testing.T) { } } +func TestServerEOF(t *testing.T) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + client, cleanup = newTestClient(t, addr) + ) + defer cleanup() + defer listener.Close() + + socketCountBefore := socketCount(t) + + go server.Serve(ctx, listener) + + registerTestingService(server, &testingServer{}) + + tp := &testPayload{} + // do a regular call + if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil { + t.Fatalf("unexpected error during test call: %v", err) + } + + // close the client, so that server gets EOF + if err := client.Close(); err != nil { + t.Fatalf("unexpected error while closing client: %v", err) + } + + // server should eventually close the client connection + maxAttempts := 20 + for i := 1; i <= maxAttempts; i++ { + socketCountAfter := socketCount(t) + if socketCountAfter < socketCountBefore { + break + } + if i == maxAttempts { + t.Fatalf("expected number of open sockets to be less than %d after client close, got %d open sockets", + socketCountBefore, socketCountAfter) + } + time.Sleep(100 * time.Millisecond) + } +} + func TestUnixSocketHandshake(t *testing.T) { var ( ctx = context.Background() @@ -541,3 +584,22 @@ func mustServer(t testing.TB) func(server *Server, err error) *Server { return server } } + +func socketCount(t *testing.T) int { + proc, err := procfs.Self() + if err != nil { + t.Fatalf("unexpected error while reading procfs: %v", err) + } + fds, err := proc.FileDescriptorTargets() + if err != nil { + t.Fatalf("unexpected error while listing open file descriptors: %v", err) + } + + sockets := 0 + for _, fd := range fds { + if strings.Contains(fd, "socket") { + sockets++ + } + } + return sockets +}