ttrpc: better test for concurrency
Improve the test to conduct two concurrent calls. The test server now just doubles the input and we use a unix socket to better represent what will be used in production. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
parent
bdb2ab7a81
commit
5e1096a4c2
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
@ -44,12 +45,25 @@ func (r *testPayload) String() string { return fmt.Sprintf("%+#v", r) }
|
||||
func (r *testPayload) ProtoMessage() {}
|
||||
|
||||
// testingServer is what would be implemented by the user of this package.
|
||||
type testingServer struct {
|
||||
payload *testPayload
|
||||
}
|
||||
type testingServer struct{}
|
||||
|
||||
func (s *testingServer) Test(ctx context.Context, req *testPayload) (*testPayload, error) {
|
||||
return s.payload, nil
|
||||
return &testPayload{Foo: strings.Repeat(req.Foo, 2)}, nil
|
||||
}
|
||||
|
||||
// registerTestingService mocks more of what is generated code. Unlike grpc, we
|
||||
// register with a closure so that the descriptor is allocated only on
|
||||
// registration.
|
||||
func registerTestingService(srv *Server, svc testingService) {
|
||||
srv.Register(serviceName, map[string]Method{
|
||||
"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
|
||||
var req testPayload
|
||||
if err := unmarshal(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return svc.Test(ctx, &req)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -60,29 +74,15 @@ func init() {
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
expectedResponse = &testPayload{Foo: "baz"}
|
||||
testImpl = &testingServer{payload: expectedResponse}
|
||||
ctx = context.Background()
|
||||
server = NewServer()
|
||||
testImpl = &testingServer{}
|
||||
)
|
||||
|
||||
// more mocking of what is generated code. Unlike grpc, we register with a
|
||||
// closure so that the descriptor is allocated only on registration.
|
||||
registerTestingService := func(srv *Server, svc testingService) {
|
||||
srv.Register(serviceName, map[string]Method{
|
||||
"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
|
||||
var req testPayload
|
||||
if err := unmarshal(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return svc.Test(ctx, &req)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
registerTestingService(server, testImpl)
|
||||
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
addr := "\x00" + t.Name()
|
||||
listener, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -91,24 +91,48 @@ func TestServer(t *testing.T) {
|
||||
go server.Serve(listener)
|
||||
defer server.Shutdown(ctx)
|
||||
|
||||
conn, err := net.Dial("tcp", listener.Addr().String())
|
||||
conn, err := net.Dial("unix", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := newTestingClient(NewClient(conn))
|
||||
|
||||
tp := &testPayload{
|
||||
Foo: "bar",
|
||||
const calls = 2
|
||||
results := make(chan callResult, 2)
|
||||
go roundTrip(ctx, t, client, "bar", results)
|
||||
go roundTrip(ctx, t, client, "baz", results)
|
||||
|
||||
for i := 0; i < calls; i++ {
|
||||
result := <-results
|
||||
if !reflect.DeepEqual(result.received, result.expected) {
|
||||
t.Fatalf("unexpected response: %+#v != %+#v", result.received, result.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type callResult struct {
|
||||
input *testPayload
|
||||
expected *testPayload
|
||||
received *testPayload
|
||||
}
|
||||
|
||||
func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value string, results chan callResult) {
|
||||
t.Helper()
|
||||
var (
|
||||
tp = &testPayload{
|
||||
Foo: "bar",
|
||||
}
|
||||
)
|
||||
|
||||
resp, err := client.Test(ctx, tp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp, expectedResponse) {
|
||||
t.Fatalf("unexpected response: %+#v != %+#v", resp, expectedResponse)
|
||||
results <- callResult{
|
||||
input: tp,
|
||||
expected: &testPayload{Foo: strings.Repeat(tp.Foo, 2)},
|
||||
received: resp,
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user