diff --git a/server_test.go b/server_test.go index b99d5f4..dd4baf6 100644 --- a/server_test.go +++ b/server_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "reflect" + "strings" "testing" "github.com/gogo/protobuf/proto" @@ -44,12 +45,25 @@ func (r *testPayload) String() string { return fmt.Sprintf("%+#v", r) } func (r *testPayload) ProtoMessage() {} // testingServer is what would be implemented by the user of this package. -type testingServer struct { - payload *testPayload -} +type testingServer struct{} func (s *testingServer) Test(ctx context.Context, req *testPayload) (*testPayload, error) { - return s.payload, nil + return &testPayload{Foo: strings.Repeat(req.Foo, 2)}, nil +} + +// registerTestingService mocks more of what is generated code. Unlike grpc, we +// register with a closure so that the descriptor is allocated only on +// registration. +func registerTestingService(srv *Server, svc testingService) { + srv.Register(serviceName, map[string]Method{ + "Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + var req testPayload + if err := unmarshal(&req); err != nil { + return nil, err + } + return svc.Test(ctx, &req) + }, + }) } func init() { @@ -60,29 +74,15 @@ func init() { func TestServer(t *testing.T) { var ( - ctx = context.Background() - server = NewServer() - expectedResponse = &testPayload{Foo: "baz"} - testImpl = &testingServer{payload: expectedResponse} + ctx = context.Background() + server = NewServer() + testImpl = &testingServer{} ) - // more mocking of what is generated code. Unlike grpc, we register with a - // closure so that the descriptor is allocated only on registration. - registerTestingService := func(srv *Server, svc testingService) { - srv.Register(serviceName, map[string]Method{ - "Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { - var req testPayload - if err := unmarshal(&req); err != nil { - return nil, err - } - return svc.Test(ctx, &req) - }, - }) - } - registerTestingService(server, testImpl) - listener, err := net.Listen("tcp", ":0") + addr := "\x00" + t.Name() + listener, err := net.Listen("unix", addr) if err != nil { t.Fatal(err) } @@ -91,24 +91,48 @@ func TestServer(t *testing.T) { go server.Serve(listener) defer server.Shutdown(ctx) - conn, err := net.Dial("tcp", listener.Addr().String()) + conn, err := net.Dial("unix", addr) if err != nil { t.Fatal(err) } defer conn.Close() - client := newTestingClient(NewClient(conn)) - tp := &testPayload{ - Foo: "bar", + const calls = 2 + results := make(chan callResult, 2) + go roundTrip(ctx, t, client, "bar", results) + go roundTrip(ctx, t, client, "baz", results) + + for i := 0; i < calls; i++ { + result := <-results + if !reflect.DeepEqual(result.received, result.expected) { + t.Fatalf("unexpected response: %+#v != %+#v", result.received, result.expected) + } } +} + +type callResult struct { + input *testPayload + expected *testPayload + received *testPayload +} + +func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value string, results chan callResult) { + t.Helper() + var ( + tp = &testPayload{ + Foo: "bar", + } + ) resp, err := client.Test(ctx, tp) if err != nil { t.Fatal(err) } - if !reflect.DeepEqual(resp, expectedResponse) { - t.Fatalf("unexpected response: %+#v != %+#v", resp, expectedResponse) + results <- callResult{ + input: tp, + expected: &testPayload{Foo: strings.Repeat(tp.Foo, 2)}, + received: resp, } }