diff --git a/internal/test.pb.go b/internal/test.pb.go index 743afa8..9f94b56 100644 --- a/internal/test.pb.go +++ b/internal/test.pb.go @@ -83,6 +83,61 @@ func (x *TestPayload) GetMetadata() string { return "" } +type EchoPayload struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Seq int64 `protobuf:"varint,1,opt,name=seq,proto3" json:"seq,omitempty"` + Msg string `protobuf:"bytes,2,opt,name=msg,proto3" json:"msg,omitempty"` +} + +func (x *EchoPayload) Reset() { + *x = EchoPayload{} + if protoimpl.UnsafeEnabled { + mi := &file_github_com_containerd_ttrpc_test_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoPayload) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoPayload) ProtoMessage() {} + +func (x *EchoPayload) ProtoReflect() protoreflect.Message { + mi := &file_github_com_containerd_ttrpc_test_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoPayload.ProtoReflect.Descriptor instead. +func (*EchoPayload) Descriptor() ([]byte, []int) { + return file_github_com_containerd_ttrpc_test_proto_rawDescGZIP(), []int{1} +} + +func (x *EchoPayload) GetSeq() int64 { + if x != nil { + return x.Seq + } + return 0 +} + +func (x *EchoPayload) GetMsg() string { + if x != nil { + return x.Msg + } + return "" +} + var File_github_com_containerd_ttrpc_test_proto protoreflect.FileDescriptor var file_github_com_containerd_ttrpc_test_proto_rawDesc = []byte{ @@ -94,10 +149,13 @@ var file_github_com_containerd_ttrpc_test_proto_rawDesc = []byte{ 0x12, 0x1a, 0x0a, 0x08, 0x64, 0x65, 0x61, 0x64, 0x6c, 0x69, 0x6e, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x64, 0x65, 0x61, 0x64, 0x6c, 0x69, 0x6e, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x42, 0x26, 0x5a, 0x24, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, - 0x64, 0x2f, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x22, 0x31, 0x0a, 0x0b, 0x45, 0x63, 0x68, 0x6f, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x65, 0x71, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x73, 0x65, 0x71, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x42, 0x26, 0x5a, 0x24, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, + 0x6e, 0x65, 0x72, 0x64, 0x2f, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, + 0x6e, 0x61, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -112,9 +170,10 @@ func file_github_com_containerd_ttrpc_test_proto_rawDescGZIP() []byte { return file_github_com_containerd_ttrpc_test_proto_rawDescData } -var file_github_com_containerd_ttrpc_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_github_com_containerd_ttrpc_test_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_github_com_containerd_ttrpc_test_proto_goTypes = []interface{}{ (*TestPayload)(nil), // 0: ttrpc.TestPayload + (*EchoPayload)(nil), // 1: ttrpc.EchoPayload } var file_github_com_containerd_ttrpc_test_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type @@ -142,6 +201,18 @@ func file_github_com_containerd_ttrpc_test_proto_init() { return nil } } + file_github_com_containerd_ttrpc_test_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoPayload); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -149,7 +220,7 @@ func file_github_com_containerd_ttrpc_test_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_containerd_ttrpc_test_proto_rawDesc, NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..1e04f1b --- /dev/null +++ b/stream_test.go @@ -0,0 +1,118 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "context" + "io" + "testing" + + "github.com/containerd/ttrpc/internal" +) + +func TestStreamClient(t *testing.T) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + client, cleanup = newTestClient(t, addr) + serviceName = "streamService" + ) + + defer listener.Close() + defer cleanup() + + desc := &ServiceDesc{ + Methods: map[string]Method{ + "Echo": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + var req internal.EchoPayload + if err := unmarshal(&req); err != nil { + return nil, err + } + req.Seq++ + return &req, nil + }, + }, + Streams: map[string]Stream{ + "EchoStream": { + Handler: func(ctx context.Context, ss StreamServer) (interface{}, error) { + for { + var req internal.EchoPayload + if err := ss.RecvMsg(&req); err != nil { + if err == io.EOF { + err = nil + } + return nil, err + } + req.Seq++ + if err := ss.SendMsg(&req); err != nil { + return nil, err + } + } + + }, + StreamingClient: true, + StreamingServer: true, + }, + }, + } + server.RegisterService(serviceName, desc) + + go server.Serve(ctx, listener) + defer server.Shutdown(ctx) + + //func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string) (ClientStream, error) { + var req, resp internal.EchoPayload + if err := client.Call(ctx, serviceName, "Echo", &req, &resp); err != nil { + t.Fatal(err) + } + + stream, err := client.NewStream(ctx, &StreamDesc{true, true}, serviceName, "EchoStream", nil) + if err != nil { + t.Fatal(err) + } + for i := 1; i <= 100; i++ { + req := internal.EchoPayload{ + Seq: int64(i), + Msg: "should be returned", + } + if err := stream.SendMsg(&req); err != nil { + t.Fatalf("%d: %v", i, err) + } + var resp internal.EchoPayload + if err := stream.RecvMsg(&resp); err != nil { + t.Fatalf("%d: %v", i, err) + } + if resp.Seq != int64(i)+1 { + t.Fatalf("%d: unexpected sequence value: %d, expected %d", i, resp.Seq, i+1) + } + if resp.Msg != req.Msg { + t.Fatalf("%d: unexpected message: %q, expected %q", i, resp.Msg, req.Msg) + } + } + if err := stream.CloseSend(); err != nil { + t.Fatal(err) + } + + err = stream.RecvMsg(&resp) + if err == nil { + t.Fatal("expected io.EOF after close send") + } + if err != io.EOF { + t.Fatalf("expected io.EOF after close send, got %v", err) + } +} diff --git a/test.proto b/test.proto index deeed94..0e114d5 100644 --- a/test.proto +++ b/test.proto @@ -9,3 +9,8 @@ message TestPayload { int64 deadline = 2; string metadata = 3; } + +message EchoPayload { + int64 seq = 1; + string msg = 2; +}