diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 505c649..f6530f9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -86,6 +86,11 @@ jobs: run: | make coverage TESTFLAGS_RACE=-race + - name: Integration Tests + working-directory: src/github.com/containerd/ttrpc + run: | + make integration + - name: Code Coverage uses: codecov/codecov-action@v2 with: diff --git a/PROTOCOL.md b/PROTOCOL.md index de0835d..5475e0b 100644 --- a/PROTOCOL.md +++ b/PROTOCOL.md @@ -95,11 +95,20 @@ client or server may send data. A data message is not allowed on a unary stream. A data message should not be sent after indicating `remote closed` to the peer. The last data message on a stream must set the `remote closed` flag. +The `no data` flag is used to indicate that the data message does not include +any data. This is normally used with the `remote closed` flag to indicate the +stream is now closed without transmitting any data. Since ttrpc normally +transmits a single object per message, a zero length data message may be +interpreted as an empty object. For example, transmitting the number zero as a +protobuf message ends up with a data length of zero, but the message is still +considered data and should be processed. + #### Data | Flag | Name | Description | |------|-----------------|-----------------------------------| | 0x01 | `remote closed` | No more data expected from remote | +| 0x04 | `no data` | This message does not have data | ## Streaming diff --git a/Protobuild.toml b/Protobuild.toml index d2aba1f..34e61aa 100644 --- a/Protobuild.toml +++ b/Protobuild.toml @@ -16,4 +16,10 @@ generators = ["go"] # This section maps protobuf imports to Go packages. These will become # `-M` directives in the call to the go protobuf generator. [packages] + "google/protobuf/any.proto" = "github.com/gogo/protobuf/types" "proto/status.proto" = "google.golang.org/genproto/googleapis/rpc/status" + +[[overrides]] +# enable ttrpc and disable fieldpath and grpc for the shim +prefixes = ["github.com/containerd/ttrpc/integration/streaming"] +generators = ["go", "go-ttrpc"] diff --git a/channel.go b/channel.go index 81116a5..1e06274 100644 --- a/channel.go +++ b/channel.go @@ -19,6 +19,7 @@ package ttrpc import ( "bufio" "encoding/binary" + "errors" "fmt" "io" "net" @@ -38,15 +39,37 @@ type messageType uint8 const ( messageTypeRequest messageType = 0x1 messageTypeResponse messageType = 0x2 + messageTypeData messageType = 0x3 ) +func (mt messageType) String() string { + switch mt { + case messageTypeRequest: + return "request" + case messageTypeResponse: + return "response" + case messageTypeData: + return "data" + default: + return "unknown" + } +} + +const ( + flagRemoteClosed uint8 = 0x1 + flagRemoteOpen uint8 = 0x2 + flagNoData uint8 = 0x4 +) + +var ErrProtocol = errors.New("protocol error") + // messageHeader represents the fixed-length message header of 10 bytes sent // with every request. type messageHeader struct { Length uint32 // length excluding this header. b[:4] StreamID uint32 // identifies which request stream message is a part of. 
b[4:8] Type messageType // message type b[8] - Flags uint8 // reserved b[9] + Flags uint8 // type specific flags b[9] } func readMessageHeader(p []byte, r io.Reader) (messageHeader, error) { @@ -111,22 +134,31 @@ func (ch *channel) recv() (messageHeader, []byte, error) { return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) } - p := ch.getmbuf(int(mh.Length)) - if _, err := io.ReadFull(ch.br, p); err != nil { - return messageHeader{}, nil, fmt.Errorf("failed reading message: %w", err) + var p []byte + if mh.Length > 0 { + p = ch.getmbuf(int(mh.Length)) + if _, err := io.ReadFull(ch.br, p); err != nil { + return messageHeader{}, nil, fmt.Errorf("failed reading message: %w", err) + } } return mh, p, nil } -func (ch *channel) send(streamID uint32, t messageType, p []byte) error { - if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { +func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error { + // TODO: Error on send rather than on recv + //if len(p) > messageLengthMax { + // return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax) + //} + if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil { return err } - _, err := ch.bw.Write(p) - if err != nil { - return err + if len(p) > 0 { + _, err := ch.bw.Write(p) + if err != nil { + return err + } } return ch.bw.Flush() diff --git a/channel_test.go b/channel_test.go index 0ee1508..de8b66d 100644 --- a/channel_test.go +++ b/channel_test.go @@ -44,7 +44,7 @@ func TestReadWriteMessage(t *testing.T) { go func() { for i, msg := range messages { - if err := ch.send(uint32(i), 1, msg); err != nil { + if err := ch.send(uint32(i), 1, 0, msg); err != nil { errs <- err return } @@ -96,7 +96,7 @@ func TestMessageOversize(t *testing.T) { ) go func() { - if err := wch.send(1, 1, msg); err != nil { + if err := wch.send(1, 1, 0, msg); err != nil { errs <- err } }() diff --git a/client.go b/client.go index 3e98847..384eb02 100644 --- a/client.go +++ b/client.go @@ -19,9 +19,9 @@ package ttrpc import ( "context" "errors" + "fmt" "io" "net" - "os" "strings" "sync" "syscall" @@ -42,7 +42,11 @@ type Client struct { codec codec conn net.Conn channel *channel - calls chan *callRequest + + streamLock sync.RWMutex + streams map[streamID]*stream + nextStreamID streamID + sendLock sync.Mutex ctx context.Context closed func() @@ -51,8 +55,6 @@ type Client struct { userCloseFunc func() userCloseWaitCh chan struct{} - errOnce sync.Once - err error interceptor UnaryClientInterceptor } @@ -73,13 +75,16 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts { } } +// NewClient creates a new ttrpc client using the given connection func NewClient(conn net.Conn, opts ...ClientOpts) *Client { ctx, cancel := context.WithCancel(context.Background()) + channel := newChannel(conn) c := &Client{ codec: codec{}, conn: conn, - channel: newChannel(conn), - calls: make(chan *callRequest), + channel: channel, + streams: make(map[streamID]*stream), + nextStreamID: 1, closed: cancel, ctx: ctx, userCloseFunc: func() {}, @@ -95,13 +100,13 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client { return c } -type callRequest struct { - ctx context.Context - req *Request - resp *Response // response will be written 
back here - errs chan error // error written here on completion +func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error { + c.sendLock.Lock() + defer c.sendLock.Unlock() + return c.channel.send(sid, mt, flags, b) } +// Call makes a unary request and returns with response func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { payload, err := c.codec.Marshal(req) if err != nil { @@ -113,6 +118,7 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int Service: service, Method: method, Payload: payload, + // TODO: metadata from context } cresp = &Response{} @@ -143,36 +149,137 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int return nil } -func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { - errs := make(chan error, 1) - call := &callRequest{ - ctx: ctx, - req: req, - resp: resp, - errs: errs, - } +// StreamDesc describes the stream properties, whether the stream has +// a streaming client, a streaming server, or both +type StreamDesc struct { + StreamingClient bool + StreamingServer bool +} - select { - case <-ctx.Done(): - return ctx.Err() - case c.calls <- call: - case <-c.ctx.Done(): - return c.error() - } +// ClientStream is used to send or recv messages on the underlying stream +type ClientStream interface { + CloseSend() error + SendMsg(m interface{}) error + RecvMsg(m interface{}) error +} - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-errs: +type clientStream struct { + ctx context.Context + s *stream + c *Client + desc *StreamDesc + localClosed bool + remoteClosed bool +} + +func (cs *clientStream) CloseSend() error { + if !cs.desc.StreamingClient { + return fmt.Errorf("%w: cannot close non-streaming client", ErrProtocol) + } + if cs.localClosed { + return ErrStreamClosed + } + err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil) + if err != nil { return filterCloseErr(err) - case <-c.ctx.Done(): - return c.error() + } + cs.localClosed = true + return nil +} + +func (cs *clientStream) SendMsg(m interface{}) error { + if !cs.desc.StreamingClient { + return fmt.Errorf("%w: cannot send data from non-streaming client", ErrProtocol) + } + if cs.localClosed { + return ErrStreamClosed + } + + var ( + payload []byte + err error + ) + if m != nil { + payload, err = cs.c.codec.Marshal(m) + if err != nil { + return err + } + } + + err = cs.s.send(messageTypeData, 0, payload) + if err != nil { + return filterCloseErr(err) + } + + return nil +} + +func (cs *clientStream) RecvMsg(m interface{}) error { + if cs.remoteClosed { + return io.EOF + } + select { + case <-cs.ctx.Done(): + return cs.ctx.Err() + case msg, ok := <-cs.s.recv: + if !ok { + return cs.s.recvErr + } + + if msg.header.Type == messageTypeResponse { + resp := &Response{} + err := proto.Unmarshal(msg.payload[:msg.header.Length], resp) + // return the payload buffer for reuse + cs.c.channel.putmbuf(msg.payload) + if err != nil { + return err + } + + if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil { + return err + } + + if resp.Status != nil && resp.Status.Code != int32(codes.OK) { + return status.ErrorProto(resp.Status) + } + + cs.c.deleteStream(cs.s) + cs.remoteClosed = true + + return nil + } else if msg.header.Type == messageTypeData { + if !cs.desc.StreamingServer { + cs.c.deleteStream(cs.s) + cs.remoteClosed = true + return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol) + } + if 
msg.header.Flags&flagRemoteClosed == flagRemoteClosed { + cs.c.deleteStream(cs.s) + cs.remoteClosed = true + + if msg.header.Flags&flagNoData == flagNoData { + return io.EOF + } + } + + err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m) + cs.c.channel.putmbuf(msg.payload) + if err != nil { + return err + } + return nil + } + + return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol) } } +// Close closes the ttrpc connection and underlying connection func (c *Client) Close() error { c.closeOnce.Do(func() { c.closed() + + c.conn.Close() }) return nil } @@ -188,194 +295,105 @@ func (c *Client) UserOnCloseWait(ctx context.Context) error { } } -type message struct { - messageHeader - p []byte - err error -} - -// callMap provides access to a map of active calls, guarded by a mutex. -type callMap struct { - m sync.Mutex - activeCalls map[uint32]*callRequest - closeErr error -} - -// newCallMap returns a new callMap with an empty set of active calls. -func newCallMap() *callMap { - return &callMap{ - activeCalls: make(map[uint32]*callRequest), - } -} - -// set adds a call entry to the map with the given streamID key. -func (cm *callMap) set(streamID uint32, cr *callRequest) error { - cm.m.Lock() - defer cm.m.Unlock() - if cm.closeErr != nil { - return cm.closeErr - } - cm.activeCalls[streamID] = cr - return nil -} - -// get looks up the call entry for the given streamID key, then removes it -// from the map and returns it. -func (cm *callMap) get(streamID uint32) (cr *callRequest, ok bool, err error) { - cm.m.Lock() - defer cm.m.Unlock() - if cm.closeErr != nil { - return nil, false, cm.closeErr - } - cr, ok = cm.activeCalls[streamID] - if ok { - delete(cm.activeCalls, streamID) - } - return -} - -// abort sends the given error to each active call, and clears the map. -// Once abort has been called, any subsequent calls to the callMap will return the error passed to abort. -func (cm *callMap) abort(err error) error { - cm.m.Lock() - defer cm.m.Unlock() - if cm.closeErr != nil { - return cm.closeErr - } - for streamID, call := range cm.activeCalls { - call.errs <- err - delete(cm.activeCalls, streamID) - } - cm.closeErr = err - return nil -} - func (c *Client) run() { - var ( - waiters = newCallMap() - receiverDone = make(chan struct{}) - ) + err := c.receiveLoop() + c.Close() + c.cleanupStreams(err) - // Sender goroutine - // Receives calls from dispatch, adds them to the set of active calls, and sends them - // to the server. - go func() { - var streamID uint32 = 1 - for { - select { - case <-c.ctx.Done(): - return - case call := <-c.calls: - id := streamID - streamID += 2 // enforce odd client initiated request ids - if err := waiters.set(id, call); err != nil { - call.errs <- err // errs is buffered so should not block. - continue - } - if err := c.send(id, messageTypeRequest, call.req); err != nil { - call.errs <- err // errs is buffered so should not block. - waiters.get(id) // remove from waiters set - } - } - } - }() - - // Receiver goroutine - // Receives responses from the server, looks up the call info in the set of active calls, - // and notifies the caller of the response. - go func() { - defer close(receiverDone) - for { - select { - case <-c.ctx.Done(): - c.setError(c.ctx.Err()) - return - default: - mh, p, err := c.channel.recv() - if err != nil { - _, ok := status.FromError(err) - if !ok { - // treat all errors that are not an rpc status as terminal. - // all others poison the connection. 
- c.setError(filterCloseErr(err)) - return - } - } - msg := &message{ - messageHeader: mh, - p: p[:mh.Length], - err: err, - } - call, ok, err := waiters.get(mh.StreamID) - if err != nil { - logrus.Errorf("ttrpc: failed to look up active call: %s", err) - continue - } - if !ok { - logrus.Errorf("ttrpc: received message for unknown channel %v", mh.StreamID) - continue - } - call.errs <- c.recv(call.resp, msg) - } - } - }() - - defer func() { - c.conn.Close() - c.userCloseFunc() - close(c.userCloseWaitCh) - }() + c.userCloseFunc() + close(c.userCloseWaitCh) +} +func (c *Client) receiveLoop() error { for { select { - case <-receiverDone: - // The receiver has exited. - // don't return out, let the close of the context trigger the abort of waiters - c.Close() case <-c.ctx.Done(): - // Abort all active calls. This will also prevent any new calls from being added - // to waiters. - waiters.abort(c.error()) - return + return ErrClosed + default: + var ( + msg = &streamMessage{} + err error + ) + + msg.header, msg.payload, err = c.channel.recv() + if err != nil { + _, ok := status.FromError(err) + if !ok { + // treat all errors that are not an rpc status as terminal. + // all others poison the connection. + return filterCloseErr(err) + } + } + sid := streamID(msg.header.StreamID) + s := c.getStream(sid) + if s == nil { + logrus.WithField("stream", sid).Errorf("ttrpc: received message on inactive stream") + continue + } + + if err != nil { + s.closeWithError(err) + } else { + if err := s.receive(c.ctx, msg); err != nil { + logrus.WithError(err).WithField("stream", sid).Errorf("ttrpc: failed to handle message") + } + } } } } -func (c *Client) error() error { - c.errOnce.Do(func() { - if c.err == nil { - c.err = ErrClosed - } - }) - return c.err -} +// createStream creates a new stream and registers it with the client +// Introduce stream types for multiple or single response +func (c *Client) createStream(flags uint8, b []byte) (*stream, error) { + c.streamLock.Lock() -func (c *Client) setError(err error) { - c.errOnce.Do(func() { - c.err = err - }) -} - -func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error { - p, err := c.codec.Marshal(msg) - if err != nil { - return err + // Check if closed since lock acquired to prevent adding + // anything after cleanup completes + select { + case <-c.ctx.Done(): + c.streamLock.Unlock() + return nil, ErrClosed + default: } - return c.channel.send(streamID, mtype, p) + // Stream ID should be allocated at same time + s := newStream(c.nextStreamID, c) + c.streams[s.id] = s + c.nextStreamID = c.nextStreamID + 2 + + c.sendLock.Lock() + defer c.sendLock.Unlock() + c.streamLock.Unlock() + + if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil { + return s, filterCloseErr(err) + } + + return s, nil } -func (c *Client) recv(resp *Response, msg *message) error { - if msg.err != nil { - return msg.err - } +func (c *Client) deleteStream(s *stream) { + c.streamLock.Lock() + delete(c.streams, s.id) + c.streamLock.Unlock() + s.closeWithError(nil) +} - if msg.Type != messageTypeResponse { - return errors.New("unknown message type received") - } +func (c *Client) getStream(sid streamID) *stream { + c.streamLock.RLock() + s := c.streams[sid] + c.streamLock.RUnlock() + return s +} - defer c.channel.putmbuf(msg.p) - return proto.Unmarshal(msg.p, resp) +func (c *Client) cleanupStreams(err error) { + c.streamLock.Lock() + defer c.streamLock.Unlock() + + for sid, s := range c.streams { + s.closeWithError(err) + delete(c.streams, 
sid) + } } // filterCloseErr rewrites EOF and EPIPE errors to ErrClosed. Use when @@ -388,6 +406,8 @@ func filterCloseErr(err error) error { return nil case err == io.EOF: return ErrClosed + case errors.Is(err, io.ErrClosedPipe): + return ErrClosed case errors.Is(err, io.EOF): return ErrClosed case strings.Contains(err.Error(), "use of closed network connection"): @@ -395,11 +415,9 @@ func filterCloseErr(err error) error { default: // if we have an epipe on a write or econnreset on a read , we cast to errclosed var oerr *net.OpError - if errors.As(err, &oerr) && (oerr.Op == "write" || oerr.Op == "read") { - serr, sok := oerr.Err.(*os.SyscallError) - if sok && ((serr.Err == syscall.EPIPE && oerr.Op == "write") || - (serr.Err == syscall.ECONNRESET && oerr.Op == "read")) { - + if errors.As(err, &oerr) { + if (oerr.Op == "write" && errors.Is(err, syscall.EPIPE)) || + (oerr.Op == "read" && errors.Is(err, syscall.ECONNRESET)) { return ErrClosed } } @@ -407,3 +425,81 @@ func filterCloseErr(err error) error { return err } + +// NewStream creates a new stream with the given stream descriptor to the +// specified service and method. If not a streaming client, the request object +// may be provided. +func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string, req interface{}) (ClientStream, error) { + var payload []byte + if req != nil { + var err error + payload, err = c.codec.Marshal(req) + if err != nil { + return nil, err + } + } + + request := &Request{ + Service: service, + Method: method, + Payload: payload, + // TODO: metadata from context + } + p, err := c.codec.Marshal(request) + if err != nil { + return nil, err + } + + var flags uint8 + if desc.StreamingClient { + flags = flagRemoteOpen + } else { + flags = flagRemoteClosed + } + s, err := c.createStream(flags, p) + if err != nil { + return nil, err + } + + return &clientStream{ + ctx: ctx, + s: s, + c: c, + desc: desc, + }, nil +} + +func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { + p, err := c.codec.Marshal(req) + if err != nil { + return err + } + + s, err := c.createStream(0, p) + if err != nil { + return err + } + defer c.deleteStream(s) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.ctx.Done(): + return ErrClosed + case msg, ok := <-s.recv: + if !ok { + return s.recvErr + } + + if msg.header.Type == messageTypeResponse { + err = proto.Unmarshal(msg.payload[:msg.header.Length], resp) + } else { + err = fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol) + } + + // return the payload buffer for reuse + c.channel.putmbuf(msg.payload) + + return err + } +} diff --git a/cmd/protoc-gen-go-ttrpc/generator.go b/cmd/protoc-gen-go-ttrpc/generator.go index 0886beb..6c0a54c 100644 --- a/cmd/protoc-gen-go-ttrpc/generator.go +++ b/cmd/protoc-gen-go-ttrpc/generator.go @@ -17,6 +17,7 @@ package main import ( + "fmt" "strings" "google.golang.org/protobuf/compiler/protogen" @@ -29,10 +30,19 @@ type generator struct { out *protogen.GeneratedFile ident struct { - context string - server string - client string - method string + context string + server string + client string + method string + stream string + serviceDesc string + streamDesc string + + streamServerIdent protogen.GoIdent + streamClientIdent protogen.GoIdent + + streamServer string + streamClient string } } @@ -54,6 +64,29 @@ func newGenerator(out *protogen.GeneratedFile) *generator { GoImportPath: "github.com/containerd/ttrpc", GoName: "Method", }) + gen.ident.stream = 
out.QualifiedGoIdent(protogen.GoIdent{ + GoImportPath: "github.com/containerd/ttrpc", + GoName: "Stream", + }) + gen.ident.serviceDesc = out.QualifiedGoIdent(protogen.GoIdent{ + GoImportPath: "github.com/containerd/ttrpc", + GoName: "ServiceDesc", + }) + gen.ident.streamDesc = out.QualifiedGoIdent(protogen.GoIdent{ + GoImportPath: "github.com/containerd/ttrpc", + GoName: "StreamDesc", + }) + + gen.ident.streamServerIdent = protogen.GoIdent{ + GoImportPath: "github.com/containerd/ttrpc", + GoName: "StreamServer", + } + gen.ident.streamClientIdent = protogen.GoIdent{ + GoImportPath: "github.com/containerd/ttrpc", + GoName: "ClientStream", + } + gen.ident.streamServer = out.QualifiedGoIdent(gen.ident.streamServerIdent) + gen.ident.streamClient = out.QualifiedGoIdent(gen.ident.streamClientIdent) return &gen } @@ -74,52 +107,277 @@ func (gen *generator) genService(service *protogen.Service) { fullName := service.Desc.FullName() p := gen.out + var methods []*protogen.Method + var streams []*protogen.Method + serviceName := service.GoName + "Service" p.P("type ", serviceName, " interface{") for _, method := range service.Methods { - p.P(method.GoName, - "(ctx ", gen.ident.context, ",", - "req *", method.Input.GoIdent, ")", - "(*", method.Output.GoIdent, ", error)") + var sendArgs, retArgs string + if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { + streams = append(streams, method) + sendArgs = fmt.Sprintf("%s_%sServer", service.GoName, method.GoName) + if !method.Desc.IsStreamingClient() { + sendArgs = fmt.Sprintf("*%s, %s", p.QualifiedGoIdent(method.Input.GoIdent), sendArgs) + } + if method.Desc.IsStreamingServer() { + retArgs = "error" + } else { + retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent)) + } + } else { + methods = append(methods, method) + sendArgs = fmt.Sprintf("*%s", p.QualifiedGoIdent(method.Input.GoIdent)) + retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent)) + } + p.P(method.GoName, "(", gen.ident.context, ", ", sendArgs, ") ", retArgs) } p.P("}") + p.P() + + for _, method := range streams { + structName := strings.ToLower(service.GoName) + method.GoName + "Server" + + p.P("type ", service.GoName, "_", method.GoName, "Server interface {") + if method.Desc.IsStreamingServer() { + p.P("Send(*", method.Output.GoIdent, ") error") + } + if method.Desc.IsStreamingClient() { + p.P("Recv() (*", method.Input.GoIdent, ", error)") + + } + p.P(gen.ident.streamServer) + p.P("}") + p.P() + + p.P("type ", structName, " struct {") + p.P(gen.ident.streamServer) + p.P("}") + p.P() + + if method.Desc.IsStreamingServer() { + p.P("func (x *", structName, ") Send(m *", method.Output.GoIdent, ") error {") + p.P("return x.StreamServer.SendMsg(m)") + p.P("}") + p.P() + } + + if method.Desc.IsStreamingClient() { + p.P("func (x *", structName, ") Recv() (*", method.Input.GoIdent, ", error) {") + p.P("m := new(", method.Input.GoIdent, ")") + p.P("if err := x.StreamServer.RecvMsg(m); err != nil {") + p.P("return nil, err") + p.P("}") + p.P("return m, nil") + p.P("}") + p.P() + } + } // registration method p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){") - p.P(`srv.Register("`, fullName, `", map[string]`, gen.ident.method, "{") - for _, method := range service.Methods { - p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){") - p.P("var req ", method.Input.GoIdent) - p.P("if err := unmarshal(&req); err != nil {") - 
p.P("return nil, err") - p.P("}") - p.P("return svc.", method.GoName, "(ctx, &req)") + p.P(`srv.RegisterService("`, fullName, `", &`, gen.ident.serviceDesc, "{") + if len(methods) > 0 { + p.P(`Methods: map[string]`, gen.ident.method, "{") + for _, method := range methods { + p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){") + p.P("var req ", method.Input.GoIdent) + p.P("if err := unmarshal(&req); err != nil {") + p.P("return nil, err") + p.P("}") + p.P("return svc.", method.GoName, "(ctx, &req)") + p.P("},") + } + p.P("},") + } + if len(streams) > 0 { + p.P(`Streams: map[string]`, gen.ident.stream, "{") + for _, method := range streams { + p.P(`"`, method.GoName, `": {`) + p.P(`Handler: func(ctx `, gen.ident.context, ", stream ", gen.ident.streamServer, ") (interface{}, error) {") + + structName := strings.ToLower(service.GoName) + method.GoName + "Server" + var sendArg string + if !method.Desc.IsStreamingClient() { + sendArg = "m, " + p.P("m := new(", method.Input.GoIdent, ")") + p.P("if err := stream.RecvMsg(m); err != nil {") + p.P("return nil, err") + p.P("}") + } + if method.Desc.IsStreamingServer() { + p.P("return nil, svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})") + } else { + p.P("return svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})") + + } + p.P("},") + if method.Desc.IsStreamingClient() { + p.P("StreamingClient: true,") + } else { + p.P("StreamingClient: false,") + } + if method.Desc.IsStreamingServer() { + p.P("StreamingServer: true,") + } else { + p.P("StreamingServer: false,") + } + p.P("},") + } p.P("},") } p.P("})") p.P("}") + p.P() clientType := service.GoName + "Client" + + // For consistency with ttrpc 1.0 without streaming, just use + // the service name if no streams are defined + clientInterface := serviceName + if len(streams) > 0 { + clientInterface = clientType + // Stream client interfaces are different than the server interface + p.P("type ", clientInterface, " interface{") + for _, method := range service.Methods { + if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { + streams = append(streams, method) + var sendArg string + if !method.Desc.IsStreamingClient() { + sendArg = fmt.Sprintf("*%s, ", p.QualifiedGoIdent(method.Input.GoIdent)) + } + p.P(method.GoName, + "(", gen.ident.context, ", ", sendArg, + ") (", service.GoName, "_", method.GoName, "Client, error)") + } else { + methods = append(methods, method) + p.P(method.GoName, + "(", gen.ident.context, ", ", + "*", method.Input.GoIdent, ")", + "(*", method.Output.GoIdent, ", error)") + } + } + p.P("}") + p.P() + } + clientStructType := strings.ToLower(clientType[:1]) + clientType[1:] p.P("type ", clientStructType, " struct{") p.P("client *", gen.ident.client) p.P("}") - p.P("func New", clientType, "(client *", gen.ident.client, ")", serviceName, "{") + p.P("func New", clientType, "(client *", gen.ident.client, ")", clientInterface, "{") p.P("return &", clientStructType, "{") p.P("client:client,") p.P("}") p.P("}") + p.P() for _, method := range service.Methods { - p.P("func (c *", clientStructType, ")", method.GoName, "(", - "ctx ", gen.ident.context, ",", - "req *", method.Input.GoIdent, ")", - "(*", method.Output.GoIdent, ", error){") - p.P("var resp ", method.Output.GoIdent) - p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`) - p.P("return nil, err") - p.P("}") - p.P("return &resp, nil") - p.P("}") + var sendArg 
string + if !method.Desc.IsStreamingClient() { + sendArg = ", req *" + gen.out.QualifiedGoIdent(method.Input.GoIdent) + } + + intName := service.GoName + "_" + method.GoName + "Client" + var retArg string + if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { + retArg = intName + } else { + retArg = "*" + gen.out.QualifiedGoIdent(method.Output.GoIdent) + } + + p.P("func (c *", clientStructType, ") ", method.GoName, + "(ctx ", gen.ident.context, "", sendArg, ") ", + "(", retArg, ", error) {") + + if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { + var streamingClient, streamingServer, req string + if method.Desc.IsStreamingClient() { + streamingClient = "true" + req = "nil" + } else { + streamingClient = "false" + req = "req" + } + if method.Desc.IsStreamingServer() { + streamingServer = "true" + } else { + streamingServer = "false" + } + p.P("stream, err := c.client.NewStream(ctx, &", gen.ident.streamDesc, "{") + p.P("StreamingClient: ", streamingClient, ",") + p.P("StreamingServer: ", streamingServer, ",") + p.P("}, ", `"`+fullName+`", `, `"`+method.GoName+`", `, req, `)`) + p.P("if err != nil {") + p.P("return nil, err") + p.P("}") + + structName := strings.ToLower(service.GoName) + method.GoName + "Client" + + p.P("x := &", structName, "{stream}") + + p.P("return x, nil") + p.P("}") + p.P() + + // Create interface + p.P("type ", intName, " interface {") + if method.Desc.IsStreamingClient() { + p.P("Send(*", method.Input.GoIdent, ") error") + } + if method.Desc.IsStreamingServer() { + p.P("Recv() (*", method.Output.GoIdent, ", error)") + } else { + p.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)") + } + + p.P(gen.ident.streamClient) + p.P("}") + p.P() + + // Create struct + p.P("type ", structName, " struct {") + p.P(gen.ident.streamClient) + p.P("}") + p.P() + + if method.Desc.IsStreamingClient() { + p.P("func (x *", structName, ") Send(m *", method.Input.GoIdent, ") error {") + p.P("return x.", gen.ident.streamClientIdent.GoName, ".SendMsg(m)") + p.P("}") + p.P() + } + + if method.Desc.IsStreamingServer() { + p.P("func (x *", structName, ") Recv() (*", method.Output.GoIdent, ", error) {") + p.P("m := new(", method.Output.GoIdent, ")") + p.P("if err := x.ClientStream.RecvMsg(m); err != nil {") + p.P("return nil, err") + p.P("}") + p.P("return m, nil") + p.P("}") + p.P() + } else { + p.P("func (x *", structName, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {") + p.P("if err := x.ClientStream.CloseSend(); err != nil {") + p.P("return nil, err") + p.P("}") + p.P("m := new(", method.Output.GoIdent, ")") + p.P("if err := x.ClientStream.RecvMsg(m); err != nil {") + p.P("return nil, err") + p.P("}") + p.P("return m, nil") + p.P("}") + p.P() + } + } else { + p.P("var resp ", method.Output.GoIdent) + p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`) + p.P("return nil, err") + p.P("}") + p.P("return &resp, nil") + p.P("}") + p.P() + } } } diff --git a/integration/streaming/doc.go b/integration/streaming/doc.go new file mode 100644 index 0000000..04c4362 --- /dev/null +++ b/integration/streaming/doc.go @@ -0,0 +1,17 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package streaming diff --git a/integration/streaming/test.pb.go b/integration/streaming/test.pb.go new file mode 100644 index 0000000..3155dd3 --- /dev/null +++ b/integration/streaming/test.pb.go @@ -0,0 +1,355 @@ +// +//Copyright The containerd Authors. +// +//Licensed under the Apache License, Version 2.0 (the "License"); +//you may not use this file except in compliance with the License. +//You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +//Unless required by applicable law or agreed to in writing, software +//distributed under the License is distributed on an "AS IS" BASIS, +//WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +//See the License for the specific language governing permissions and +//limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.11.4 +// source: github.com/containerd/ttrpc/integration/streaming/test.proto + +package streaming + +import ( + empty "github.com/golang/protobuf/ptypes/empty" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type EchoPayload struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Seq uint32 `protobuf:"varint,1,opt,name=seq,proto3" json:"seq,omitempty"` + Msg string `protobuf:"bytes,2,opt,name=msg,proto3" json:"msg,omitempty"` +} + +func (x *EchoPayload) Reset() { + *x = EchoPayload{} + if protoimpl.UnsafeEnabled { + mi := &file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoPayload) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoPayload) ProtoMessage() {} + +func (x *EchoPayload) ProtoReflect() protoreflect.Message { + mi := &file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoPayload.ProtoReflect.Descriptor instead. 
+func (*EchoPayload) Descriptor() ([]byte, []int) { + return file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescGZIP(), []int{0} +} + +func (x *EchoPayload) GetSeq() uint32 { + if x != nil { + return x.Seq + } + return 0 +} + +func (x *EchoPayload) GetMsg() string { + if x != nil { + return x.Msg + } + return "" +} + +type Part struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Add int32 `protobuf:"varint,1,opt,name=add,proto3" json:"add,omitempty"` +} + +func (x *Part) Reset() { + *x = Part{} + if protoimpl.UnsafeEnabled { + mi := &file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Part) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Part) ProtoMessage() {} + +func (x *Part) ProtoReflect() protoreflect.Message { + mi := &file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Part.ProtoReflect.Descriptor instead. +func (*Part) Descriptor() ([]byte, []int) { + return file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescGZIP(), []int{1} +} + +func (x *Part) GetAdd() int32 { + if x != nil { + return x.Add + } + return 0 +} + +type Sum struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Sum int32 `protobuf:"varint,1,opt,name=sum,proto3" json:"sum,omitempty"` + Num int32 `protobuf:"varint,2,opt,name=num,proto3" json:"num,omitempty"` +} + +func (x *Sum) Reset() { + *x = Sum{} + if protoimpl.UnsafeEnabled { + mi := &file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Sum) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Sum) ProtoMessage() {} + +func (x *Sum) ProtoReflect() protoreflect.Message { + mi := &file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Sum.ProtoReflect.Descriptor instead. 
+func (*Sum) Descriptor() ([]byte, []int) { + return file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescGZIP(), []int{2} +} + +func (x *Sum) GetSum() int32 { + if x != nil { + return x.Sum + } + return 0 +} + +func (x *Sum) GetNum() int32 { + if x != nil { + return x.Num + } + return 0 +} + +var File_github_com_containerd_ttrpc_integration_streaming_test_proto protoreflect.FileDescriptor + +var file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDesc = []byte{ + 0x0a, 0x3c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, + 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x64, 0x2f, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x6e, + 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, + 0x69, 0x6e, 0x67, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x1b, + 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, + 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x31, 0x0a, 0x0b, 0x45, 0x63, 0x68, 0x6f, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x65, 0x71, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x73, 0x65, 0x71, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x22, 0x18, 0x0a, 0x04, 0x50, + 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x64, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, + 0x52, 0x03, 0x61, 0x64, 0x64, 0x22, 0x29, 0x0a, 0x03, 0x53, 0x75, 0x6d, 0x12, 0x10, 0x0a, 0x03, + 0x73, 0x75, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x73, 0x75, 0x6d, 0x12, 0x10, + 0x0a, 0x03, 0x6e, 0x75, 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6e, 0x75, 0x6d, + 0x32, 0xa0, 0x04, 0x0a, 0x09, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x12, 0x5a, + 0x0a, 0x04, 0x45, 0x63, 0x68, 0x6f, 0x12, 0x28, 0x2e, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, + 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, + 0x1a, 0x28, 0x2e, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x45, + 0x63, 0x68, 0x6f, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x64, 0x0a, 0x0a, 0x45, 0x63, + 0x68, 0x6f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x28, 0x2e, 0x74, 0x74, 0x72, 0x70, 0x63, + 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x73, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x50, 0x61, 0x79, 0x6c, 0x6f, + 0x61, 0x64, 0x1a, 0x28, 0x2e, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x67, + 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, + 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x28, 0x01, 0x30, 0x01, + 0x12, 0x52, 0x0a, 0x09, 0x53, 0x75, 0x6d, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x21, 0x2e, + 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x50, 0x61, 0x72, 0x74, + 0x1a, 
0x20, 0x2e, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x53, + 0x75, 0x6d, 0x28, 0x01, 0x12, 0x55, 0x0a, 0x0c, 0x44, 0x69, 0x76, 0x69, 0x64, 0x65, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, 0x6e, 0x74, + 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, + 0x6e, 0x67, 0x2e, 0x53, 0x75, 0x6d, 0x1a, 0x21, 0x2e, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, + 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x50, 0x61, 0x72, 0x74, 0x30, 0x01, 0x12, 0x4e, 0x0a, 0x08, 0x45, + 0x63, 0x68, 0x6f, 0x4e, 0x75, 0x6c, 0x6c, 0x12, 0x28, 0x2e, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, + 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, + 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, + 0x64, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x28, 0x01, 0x12, 0x56, 0x0a, 0x0e, 0x45, + 0x63, 0x68, 0x6f, 0x4e, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x28, 0x2e, + 0x74, 0x74, 0x72, 0x70, 0x63, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x2e, 0x45, 0x63, 0x68, 0x6f, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x28, + 0x01, 0x30, 0x01, 0x42, 0x3d, 0x5a, 0x3b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x64, 0x2f, 0x74, 0x74, 0x72, + 0x70, 0x63, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x67, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x3b, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, + 0x6e, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescOnce sync.Once + file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescData = file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDesc +) + +func file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescGZIP() []byte { + file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescOnce.Do(func() { + file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescData) + }) + return file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDescData +} + +var file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_github_com_containerd_ttrpc_integration_streaming_test_proto_goTypes = []interface{}{ + (*EchoPayload)(nil), // 0: ttrpc.integration.streaming.EchoPayload + (*Part)(nil), // 1: ttrpc.integration.streaming.Part + (*Sum)(nil), // 2: ttrpc.integration.streaming.Sum + (*empty.Empty)(nil), // 3: google.protobuf.Empty +} +var file_github_com_containerd_ttrpc_integration_streaming_test_proto_depIdxs = []int32{ + 0, // 0: ttrpc.integration.streaming.Streaming.Echo:input_type -> 
ttrpc.integration.streaming.EchoPayload + 0, // 1: ttrpc.integration.streaming.Streaming.EchoStream:input_type -> ttrpc.integration.streaming.EchoPayload + 1, // 2: ttrpc.integration.streaming.Streaming.SumStream:input_type -> ttrpc.integration.streaming.Part + 2, // 3: ttrpc.integration.streaming.Streaming.DivideStream:input_type -> ttrpc.integration.streaming.Sum + 0, // 4: ttrpc.integration.streaming.Streaming.EchoNull:input_type -> ttrpc.integration.streaming.EchoPayload + 0, // 5: ttrpc.integration.streaming.Streaming.EchoNullStream:input_type -> ttrpc.integration.streaming.EchoPayload + 0, // 6: ttrpc.integration.streaming.Streaming.Echo:output_type -> ttrpc.integration.streaming.EchoPayload + 0, // 7: ttrpc.integration.streaming.Streaming.EchoStream:output_type -> ttrpc.integration.streaming.EchoPayload + 2, // 8: ttrpc.integration.streaming.Streaming.SumStream:output_type -> ttrpc.integration.streaming.Sum + 1, // 9: ttrpc.integration.streaming.Streaming.DivideStream:output_type -> ttrpc.integration.streaming.Part + 3, // 10: ttrpc.integration.streaming.Streaming.EchoNull:output_type -> google.protobuf.Empty + 3, // 11: ttrpc.integration.streaming.Streaming.EchoNullStream:output_type -> google.protobuf.Empty + 6, // [6:12] is the sub-list for method output_type + 0, // [0:6] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_github_com_containerd_ttrpc_integration_streaming_test_proto_init() } +func file_github_com_containerd_ttrpc_integration_streaming_test_proto_init() { + if File_github_com_containerd_ttrpc_integration_streaming_test_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoPayload); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Part); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Sum); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDesc, + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_github_com_containerd_ttrpc_integration_streaming_test_proto_goTypes, + DependencyIndexes: file_github_com_containerd_ttrpc_integration_streaming_test_proto_depIdxs, + MessageInfos: file_github_com_containerd_ttrpc_integration_streaming_test_proto_msgTypes, + }.Build() + File_github_com_containerd_ttrpc_integration_streaming_test_proto = out.File + file_github_com_containerd_ttrpc_integration_streaming_test_proto_rawDesc = nil + file_github_com_containerd_ttrpc_integration_streaming_test_proto_goTypes = nil + 
file_github_com_containerd_ttrpc_integration_streaming_test_proto_depIdxs = nil +} diff --git a/integration/streaming/test.proto b/integration/streaming/test.proto new file mode 100644 index 0000000..3f46ccf --- /dev/null +++ b/integration/streaming/test.proto @@ -0,0 +1,51 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +syntax = "proto3"; + +package ttrpc.integration.streaming; + +import "google/protobuf/empty.proto"; + +option go_package = "github.com/containerd/ttrpc/integration/streaming;streaming"; + +// Shim service is launched for each container and is responsible for owning the IO +// for the container and its additional processes. The shim is also the parent of +// each container and allows reattaching to the IO and receiving the exit status +// for the container processes. + +service Streaming { + rpc Echo(EchoPayload) returns (EchoPayload); + rpc EchoStream(stream EchoPayload) returns (stream EchoPayload); + rpc SumStream(stream Part) returns (Sum); + rpc DivideStream(Sum) returns (stream Part); + rpc EchoNull(stream EchoPayload) returns (google.protobuf.Empty); + rpc EchoNullStream(stream EchoPayload) returns (stream google.protobuf.Empty); +} + +message EchoPayload { + uint32 seq = 1; + string msg = 2; +} + +message Part { + int32 add = 1; +} + +message Sum { + int32 sum = 1; + int32 num = 2; +} diff --git a/integration/streaming/test_ttrpc.pb.go b/integration/streaming/test_ttrpc.pb.go new file mode 100644 index 0000000..40dd05c --- /dev/null +++ b/integration/streaming/test_ttrpc.pb.go @@ -0,0 +1,362 @@ +// Code generated by protoc-gen-go-ttrpc. DO NOT EDIT. 
+// source: github.com/containerd/ttrpc/integration/streaming/test.proto +package streaming + +import ( + context "context" + ttrpc "github.com/containerd/ttrpc" + empty "github.com/golang/protobuf/ptypes/empty" +) + +type StreamingService interface { + Echo(context.Context, *EchoPayload) (*EchoPayload, error) + EchoStream(context.Context, Streaming_EchoStreamServer) error + SumStream(context.Context, Streaming_SumStreamServer) (*Sum, error) + DivideStream(context.Context, *Sum, Streaming_DivideStreamServer) error + EchoNull(context.Context, Streaming_EchoNullServer) (*empty.Empty, error) + EchoNullStream(context.Context, Streaming_EchoNullStreamServer) error +} + +type Streaming_EchoStreamServer interface { + Send(*EchoPayload) error + Recv() (*EchoPayload, error) + ttrpc.StreamServer +} + +type streamingEchoStreamServer struct { + ttrpc.StreamServer +} + +func (x *streamingEchoStreamServer) Send(m *EchoPayload) error { + return x.StreamServer.SendMsg(m) +} + +func (x *streamingEchoStreamServer) Recv() (*EchoPayload, error) { + m := new(EchoPayload) + if err := x.StreamServer.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +type Streaming_SumStreamServer interface { + Recv() (*Part, error) + ttrpc.StreamServer +} + +type streamingSumStreamServer struct { + ttrpc.StreamServer +} + +func (x *streamingSumStreamServer) Recv() (*Part, error) { + m := new(Part) + if err := x.StreamServer.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +type Streaming_DivideStreamServer interface { + Send(*Part) error + ttrpc.StreamServer +} + +type streamingDivideStreamServer struct { + ttrpc.StreamServer +} + +func (x *streamingDivideStreamServer) Send(m *Part) error { + return x.StreamServer.SendMsg(m) +} + +type Streaming_EchoNullServer interface { + Recv() (*EchoPayload, error) + ttrpc.StreamServer +} + +type streamingEchoNullServer struct { + ttrpc.StreamServer +} + +func (x *streamingEchoNullServer) Recv() (*EchoPayload, error) { + m := new(EchoPayload) + if err := x.StreamServer.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +type Streaming_EchoNullStreamServer interface { + Send(*empty.Empty) error + Recv() (*EchoPayload, error) + ttrpc.StreamServer +} + +type streamingEchoNullStreamServer struct { + ttrpc.StreamServer +} + +func (x *streamingEchoNullStreamServer) Send(m *empty.Empty) error { + return x.StreamServer.SendMsg(m) +} + +func (x *streamingEchoNullStreamServer) Recv() (*EchoPayload, error) { + m := new(EchoPayload) + if err := x.StreamServer.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func RegisterStreamingService(srv *ttrpc.Server, svc StreamingService) { + srv.RegisterService("ttrpc.integration.streaming.Streaming", &ttrpc.ServiceDesc{ + Methods: map[string]ttrpc.Method{ + "Echo": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + var req EchoPayload + if err := unmarshal(&req); err != nil { + return nil, err + } + return svc.Echo(ctx, &req) + }, + }, + Streams: map[string]ttrpc.Stream{ + "EchoStream": { + Handler: func(ctx context.Context, stream ttrpc.StreamServer) (interface{}, error) { + return nil, svc.EchoStream(ctx, &streamingEchoStreamServer{stream}) + }, + StreamingClient: true, + StreamingServer: true, + }, + "SumStream": { + Handler: func(ctx context.Context, stream ttrpc.StreamServer) (interface{}, error) { + return svc.SumStream(ctx, &streamingSumStreamServer{stream}) + }, + StreamingClient: true, + StreamingServer: false, + }, + "DivideStream": { + Handler: 
func(ctx context.Context, stream ttrpc.StreamServer) (interface{}, error) { + m := new(Sum) + if err := stream.RecvMsg(m); err != nil { + return nil, err + } + return nil, svc.DivideStream(ctx, m, &streamingDivideStreamServer{stream}) + }, + StreamingClient: false, + StreamingServer: true, + }, + "EchoNull": { + Handler: func(ctx context.Context, stream ttrpc.StreamServer) (interface{}, error) { + return svc.EchoNull(ctx, &streamingEchoNullServer{stream}) + }, + StreamingClient: true, + StreamingServer: false, + }, + "EchoNullStream": { + Handler: func(ctx context.Context, stream ttrpc.StreamServer) (interface{}, error) { + return nil, svc.EchoNullStream(ctx, &streamingEchoNullStreamServer{stream}) + }, + StreamingClient: true, + StreamingServer: true, + }, + }, + }) +} + +type StreamingClient interface { + Echo(context.Context, *EchoPayload) (*EchoPayload, error) + EchoStream(context.Context) (Streaming_EchoStreamClient, error) + SumStream(context.Context) (Streaming_SumStreamClient, error) + DivideStream(context.Context, *Sum) (Streaming_DivideStreamClient, error) + EchoNull(context.Context) (Streaming_EchoNullClient, error) + EchoNullStream(context.Context) (Streaming_EchoNullStreamClient, error) +} + +type streamingClient struct { + client *ttrpc.Client +} + +func NewStreamingClient(client *ttrpc.Client) StreamingClient { + return &streamingClient{ + client: client, + } +} + +func (c *streamingClient) Echo(ctx context.Context, req *EchoPayload) (*EchoPayload, error) { + var resp EchoPayload + if err := c.client.Call(ctx, "ttrpc.integration.streaming.Streaming", "Echo", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +func (c *streamingClient) EchoStream(ctx context.Context) (Streaming_EchoStreamClient, error) { + stream, err := c.client.NewStream(ctx, &ttrpc.StreamDesc{ + StreamingClient: true, + StreamingServer: true, + }, "ttrpc.integration.streaming.Streaming", "EchoStream", nil) + if err != nil { + return nil, err + } + x := &streamingEchoStreamClient{stream} + return x, nil +} + +type Streaming_EchoStreamClient interface { + Send(*EchoPayload) error + Recv() (*EchoPayload, error) + ttrpc.ClientStream +} + +type streamingEchoStreamClient struct { + ttrpc.ClientStream +} + +func (x *streamingEchoStreamClient) Send(m *EchoPayload) error { + return x.ClientStream.SendMsg(m) +} + +func (x *streamingEchoStreamClient) Recv() (*EchoPayload, error) { + m := new(EchoPayload) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *streamingClient) SumStream(ctx context.Context) (Streaming_SumStreamClient, error) { + stream, err := c.client.NewStream(ctx, &ttrpc.StreamDesc{ + StreamingClient: true, + StreamingServer: false, + }, "ttrpc.integration.streaming.Streaming", "SumStream", nil) + if err != nil { + return nil, err + } + x := &streamingSumStreamClient{stream} + return x, nil +} + +type Streaming_SumStreamClient interface { + Send(*Part) error + CloseAndRecv() (*Sum, error) + ttrpc.ClientStream +} + +type streamingSumStreamClient struct { + ttrpc.ClientStream +} + +func (x *streamingSumStreamClient) Send(m *Part) error { + return x.ClientStream.SendMsg(m) +} + +func (x *streamingSumStreamClient) CloseAndRecv() (*Sum, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(Sum) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *streamingClient) DivideStream(ctx context.Context, req *Sum) (Streaming_DivideStreamClient, 
error) { + stream, err := c.client.NewStream(ctx, &ttrpc.StreamDesc{ + StreamingClient: false, + StreamingServer: true, + }, "ttrpc.integration.streaming.Streaming", "DivideStream", req) + if err != nil { + return nil, err + } + x := &streamingDivideStreamClient{stream} + return x, nil +} + +type Streaming_DivideStreamClient interface { + Recv() (*Part, error) + ttrpc.ClientStream +} + +type streamingDivideStreamClient struct { + ttrpc.ClientStream +} + +func (x *streamingDivideStreamClient) Recv() (*Part, error) { + m := new(Part) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *streamingClient) EchoNull(ctx context.Context) (Streaming_EchoNullClient, error) { + stream, err := c.client.NewStream(ctx, &ttrpc.StreamDesc{ + StreamingClient: true, + StreamingServer: false, + }, "ttrpc.integration.streaming.Streaming", "EchoNull", nil) + if err != nil { + return nil, err + } + x := &streamingEchoNullClient{stream} + return x, nil +} + +type Streaming_EchoNullClient interface { + Send(*EchoPayload) error + CloseAndRecv() (*empty.Empty, error) + ttrpc.ClientStream +} + +type streamingEchoNullClient struct { + ttrpc.ClientStream +} + +func (x *streamingEchoNullClient) Send(m *EchoPayload) error { + return x.ClientStream.SendMsg(m) +} + +func (x *streamingEchoNullClient) CloseAndRecv() (*empty.Empty, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(empty.Empty) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *streamingClient) EchoNullStream(ctx context.Context) (Streaming_EchoNullStreamClient, error) { + stream, err := c.client.NewStream(ctx, &ttrpc.StreamDesc{ + StreamingClient: true, + StreamingServer: true, + }, "ttrpc.integration.streaming.Streaming", "EchoNullStream", nil) + if err != nil { + return nil, err + } + x := &streamingEchoNullStreamClient{stream} + return x, nil +} + +type Streaming_EchoNullStreamClient interface { + Send(*EchoPayload) error + Recv() (*empty.Empty, error) + ttrpc.ClientStream +} + +type streamingEchoNullStreamClient struct { + ttrpc.ClientStream +} + +func (x *streamingEchoNullStreamClient) Send(m *EchoPayload) error { + return x.ClientStream.SendMsg(m) +} + +func (x *streamingEchoNullStreamClient) Recv() (*empty.Empty, error) { + m := new(empty.Empty) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} diff --git a/integration/streaming_test.go b/integration/streaming_test.go new file mode 100644 index 0000000..5dcc717 --- /dev/null +++ b/integration/streaming_test.go @@ -0,0 +1,425 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package integration + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "net" + "os" + "sync" + "testing" + "time" + + "github.com/containerd/ttrpc" + "github.com/containerd/ttrpc/integration/streaming" + "github.com/golang/protobuf/ptypes/empty" +) + +func runService(ctx context.Context, t testing.TB, service streaming.StreamingService) (streaming.StreamingClient, func()) { + server, err := ttrpc.NewServer() + if err != nil { + t.Fatal(err) + } + + streaming.RegisterStreamingService(server, service) + + addr := t.Name() + ".sock" + if err := os.RemoveAll(addr); err != nil { + t.Fatal(err) + } + listener, err := net.Listen("unix", addr) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(ctx) + defer func() { + if t.Failed() { + cancel() + server.Close() + } + }() + + go func() { + err := server.Serve(ctx, listener) + if err != nil && !errors.Is(err, ttrpc.ErrServerClosed) { + t.Error(err) + } + }() + + conn, err := net.Dial("unix", addr) + if err != nil { + t.Fatal(err) + } + + client := ttrpc.NewClient(conn) + return streaming.NewStreamingClient(client), func() { + client.Close() + server.Close() + conn.Close() + cancel() + } +} + +type testStreamingService struct { + t testing.TB +} + +func (tss *testStreamingService) Echo(_ context.Context, e *streaming.EchoPayload) (*streaming.EchoPayload, error) { + e.Seq++ + return e, nil +} + +func (tss *testStreamingService) EchoStream(_ context.Context, es streaming.Streaming_EchoStreamServer) error { + for { + var e streaming.EchoPayload + if err := es.RecvMsg(&e); err != nil { + if err == io.EOF { + return nil + } + return err + } + e.Seq++ + if err := es.SendMsg(&e); err != nil { + return err + } + + } +} + +func (tss *testStreamingService) SumStream(_ context.Context, ss streaming.Streaming_SumStreamServer) (*streaming.Sum, error) { + var sum streaming.Sum + for { + var part streaming.Part + if err := ss.RecvMsg(&part); err != nil { + if err == io.EOF { + break + } + return nil, err + } + sum.Sum = sum.Sum + part.Add + sum.Num++ + } + + return &sum, nil +} + +func (tss *testStreamingService) DivideStream(_ context.Context, sum *streaming.Sum, ss streaming.Streaming_DivideStreamServer) error { + parts := divideSum(sum) + for _, part := range parts { + if err := ss.Send(part); err != nil { + return err + } + } + return nil +} +func (tss *testStreamingService) EchoNull(_ context.Context, es streaming.Streaming_EchoNullServer) (*empty.Empty, error) { + msg := "non-empty empty" + for seq := uint32(0); ; seq++ { + var e streaming.EchoPayload + if err := es.RecvMsg(&e); err != nil { + if err == io.EOF { + break + } + return nil, err + } + if e.Seq != seq { + return nil, fmt.Errorf("unexpected sequence %d, expected %d", e.Seq, seq) + } + if e.Msg != msg { + return nil, fmt.Errorf("unexpected message %q, expected %q", e.Msg, msg) + } + } + + return &empty.Empty{}, nil +} + +func (tss *testStreamingService) EchoNullStream(_ context.Context, es streaming.Streaming_EchoNullStreamServer) error { + msg := "non-empty empty" + empty := &empty.Empty{} + var wg sync.WaitGroup + var sendErr error + var errOnce sync.Once + for seq := uint32(0); ; seq++ { + var e streaming.EchoPayload + if err := es.RecvMsg(&e); err != nil { + if err == io.EOF { + break + } + return err + } + if e.Seq != seq { + return fmt.Errorf("unexpected sequence %d, expected %d", e.Seq, seq) + } + if e.Msg != msg { + return fmt.Errorf("unexpected message %q, expected %q", e.Msg, msg) + } + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + 
defer wg.Done() + if err := es.SendMsg(empty); err != nil { + errOnce.Do(func() { + sendErr = err + }) + } + }() + } + } + wg.Wait() + + return sendErr +} + +func TestStreamingService(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client, cleanup := runService(ctx, t, &testStreamingService{t}) + defer cleanup() + + t.Run("Echo", echoTest(ctx, client)) + t.Run("EchoStream", echoStreamTest(ctx, client)) + t.Run("SumStream", sumStreamTest(ctx, client)) + t.Run("DivideStream", divideStreamTest(ctx, client)) + t.Run("EchoNull", echoNullTest(ctx, client)) + t.Run("EchoNullStream", echoNullStreamTest(ctx, client)) +} + +func echoTest(ctx context.Context, client streaming.StreamingClient) func(t *testing.T) { + return func(t *testing.T) { + echo1 := &streaming.EchoPayload{ + Seq: 1, + Msg: "Echo Me", + } + resp, err := client.Echo(ctx, echo1) + if err != nil { + t.Fatal(err) + } + assertNextEcho(t, echo1, resp) + } + +} + +func echoStreamTest(ctx context.Context, client streaming.StreamingClient) func(t *testing.T) { + return func(t *testing.T) { + stream, err := client.EchoStream(ctx) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 100; i = i + 2 { + echoi := &streaming.EchoPayload{ + Seq: uint32(i), + Msg: fmt.Sprintf("%d: Echo in a stream", i), + } + if err := stream.Send(echoi); err != nil { + t.Fatal(err) + } + + resp, err := stream.Recv() + if err != nil { + t.Fatal(err) + } + assertNextEcho(t, echoi, resp) + } + + if err := stream.CloseSend(); err != nil { + t.Fatal(err) + } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("Expected io.EOF, got %v", err) + } + } +} + +func sumStreamTest(ctx context.Context, client streaming.StreamingClient) func(t *testing.T) { + return func(t *testing.T) { + stream, err := client.SumStream(ctx) + if err != nil { + t.Fatal(err) + } + var sum streaming.Sum + if err := stream.Send(&streaming.Part{}); err != nil { + t.Fatal(err) + } + sum.Num++ + for i := -99; i <= 100; i++ { + addi := &streaming.Part{ + Add: int32(i), + } + if err := stream.Send(addi); err != nil { + t.Fatal(err) + } + sum.Sum = sum.Sum + int32(i) + sum.Num++ + } + if err := stream.Send(&streaming.Part{}); err != nil { + t.Fatal(err) + } + sum.Num++ + + ssum, err := stream.CloseAndRecv() + if err != nil { + t.Fatal(err) + } + assertSum(t, ssum, &sum) + } +} + +func divideStreamTest(ctx context.Context, client streaming.StreamingClient) func(t *testing.T) { + return func(t *testing.T) { + expected := &streaming.Sum{ + Sum: 392, + Num: 30, + } + + stream, err := client.DivideStream(ctx, expected) + if err != nil { + t.Fatal(err) + } + + var actual streaming.Sum + for { + part, err := stream.Recv() + if err != nil { + if err == io.EOF { + break + } + t.Fatal(err) + } + actual.Sum = actual.Sum + part.Add + actual.Num++ + } + assertSum(t, &actual, expected) + } +} +func echoNullTest(ctx context.Context, client streaming.StreamingClient) func(t *testing.T) { + return func(t *testing.T) { + stream, err := client.EchoNull(ctx) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 100; i++ { + echoi := &streaming.EchoPayload{ + Seq: uint32(i), + Msg: "non-empty empty", + } + if err := stream.Send(echoi); err != nil { + t.Fatal(err) + } + } + + if _, err := stream.CloseAndRecv(); err != nil { + t.Fatal(err) + } + + } +} +func echoNullStreamTest(ctx context.Context, client streaming.StreamingClient) func(t *testing.T) { + return func(t *testing.T) { + stream, err := client.EchoNullStream(ctx) + if err != nil { + t.Fatal(err) + } + var c int 
+ wait := make(chan error) + go func() { + defer close(wait) + for { + _, err := stream.Recv() + if err != nil { + if err != io.EOF { + wait <- err + } + return + } + c++ + } + + }() + + for i := 0; i < 100; i++ { + echoi := &streaming.EchoPayload{ + Seq: uint32(i), + Msg: "non-empty empty", + } + if err := stream.Send(echoi); err != nil { + t.Fatal(err) + } + + } + + if err := stream.CloseSend(); err != nil { + t.Fatal(err) + } + + select { + + case err := <-wait: + if err != nil { + t.Fatal(err) + } + case <-time.After(time.Second * 10): + t.Fatal("did not receive EOF within 10 seconds") + } + + } +} + +func assertNextEcho(t testing.TB, a, b *streaming.EchoPayload) { + t.Helper() + if a.Msg != b.Msg { + t.Fatalf("Mismatched messages: %q != %q", a.Msg, b.Msg) + } + if b.Seq != a.Seq+1 { + t.Fatalf("Wrong sequence ID: got %d, expected %d", b.Seq, a.Seq+1) + } +} + +func assertSum(t testing.TB, a, b *streaming.Sum) { + t.Helper() + if a.Sum != b.Sum { + t.Fatalf("Wrong sum %d, expected %d", a.Sum, b.Sum) + } + if a.Num != b.Num { + t.Fatalf("Wrong num %d, expected %d", a.Num, b.Num) + } +} + +func divideSum(sum *streaming.Sum) []*streaming.Part { + r := rand.New(rand.NewSource(14)) + var total int32 + parts := make([]*streaming.Part, sum.Num) + for i := int32(1); i < sum.Num-2; i++ { + add := r.Int31()%1000 - 500 + parts[i] = &streaming.Part{ + Add: add, + } + total = total + add + } + parts[0] = &streaming.Part{} + parts[sum.Num-2] = &streaming.Part{ + Add: sum.Sum - total, + } + parts[sum.Num-1] = &streaming.Part{} + return parts +} diff --git a/interceptor.go b/interceptor.go index c6a4a8d..7ff5e9d 100644 --- a/interceptor.go +++ b/interceptor.go @@ -28,6 +28,13 @@ type UnaryClientInfo struct { FullMethod string } +// StreamServerInfo provides information about the server request +type StreamServerInfo struct { + FullMethod string + StreamingClient bool + StreamingServer bool +} + // Unmarshaler contains the server request data and allows it to be unmarshaled // into a concrete type type Unmarshaler func(interface{}) error @@ -48,3 +55,11 @@ func defaultServerInterceptor(ctx context.Context, unmarshal Unmarshaler, _ *Una func defaultClientInterceptor(ctx context.Context, req *Request, resp *Response, _ *UnaryClientInfo, invoker Invoker) error { return invoker(ctx, req, resp) } + +type StreamServerInterceptor func(context.Context, StreamServer, *StreamServerInfo, StreamHandler) (interface{}, error) + +func defaultStreamServerInterceptor(ctx context.Context, ss StreamServer, _ *StreamServerInfo, stream StreamHandler) (interface{}, error) { + return stream(ctx, ss) +} + +type StreamClientInterceptor func(context.Context) diff --git a/internal/test.pb.go b/internal/test.pb.go index 743afa8..9f94b56 100644 --- a/internal/test.pb.go +++ b/internal/test.pb.go @@ -83,6 +83,61 @@ func (x *TestPayload) GetMetadata() string { return "" } +type EchoPayload struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Seq int64 `protobuf:"varint,1,opt,name=seq,proto3" json:"seq,omitempty"` + Msg string `protobuf:"bytes,2,opt,name=msg,proto3" json:"msg,omitempty"` +} + +func (x *EchoPayload) Reset() { + *x = EchoPayload{} + if protoimpl.UnsafeEnabled { + mi := &file_github_com_containerd_ttrpc_test_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoPayload) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoPayload) ProtoMessage() {} + +func (x 
*EchoPayload) ProtoReflect() protoreflect.Message { + mi := &file_github_com_containerd_ttrpc_test_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoPayload.ProtoReflect.Descriptor instead. +func (*EchoPayload) Descriptor() ([]byte, []int) { + return file_github_com_containerd_ttrpc_test_proto_rawDescGZIP(), []int{1} +} + +func (x *EchoPayload) GetSeq() int64 { + if x != nil { + return x.Seq + } + return 0 +} + +func (x *EchoPayload) GetMsg() string { + if x != nil { + return x.Msg + } + return "" +} + var File_github_com_containerd_ttrpc_test_proto protoreflect.FileDescriptor var file_github_com_containerd_ttrpc_test_proto_rawDesc = []byte{ @@ -94,10 +149,13 @@ var file_github_com_containerd_ttrpc_test_proto_rawDesc = []byte{ 0x12, 0x1a, 0x0a, 0x08, 0x64, 0x65, 0x61, 0x64, 0x6c, 0x69, 0x6e, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x64, 0x65, 0x61, 0x64, 0x6c, 0x69, 0x6e, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x42, 0x26, 0x5a, 0x24, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, - 0x64, 0x2f, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x22, 0x31, 0x0a, 0x0b, 0x45, 0x63, 0x68, 0x6f, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x65, 0x71, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x73, 0x65, 0x71, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x42, 0x26, 0x5a, 0x24, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, + 0x6e, 0x65, 0x72, 0x64, 0x2f, 0x74, 0x74, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, + 0x6e, 0x61, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -112,9 +170,10 @@ func file_github_com_containerd_ttrpc_test_proto_rawDescGZIP() []byte { return file_github_com_containerd_ttrpc_test_proto_rawDescData } -var file_github_com_containerd_ttrpc_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_github_com_containerd_ttrpc_test_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_github_com_containerd_ttrpc_test_proto_goTypes = []interface{}{ (*TestPayload)(nil), // 0: ttrpc.TestPayload + (*EchoPayload)(nil), // 1: ttrpc.EchoPayload } var file_github_com_containerd_ttrpc_test_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type @@ -142,6 +201,18 @@ func file_github_com_containerd_ttrpc_test_proto_init() { return nil } } + file_github_com_containerd_ttrpc_test_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoPayload); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -149,7 +220,7 @@ func file_github_com_containerd_ttrpc_test_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_containerd_ttrpc_test_proto_rawDesc, NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 
0, }, diff --git a/server.go b/server.go index b0e4807..17c6bcf 100644 --- a/server.go +++ b/server.go @@ -66,8 +66,14 @@ func NewServer(opts ...ServerOpt) (*Server, error) { }, nil } +// Register registers a map of methods to method handlers +// TODO: Remove in 2.0, does not support streams func (s *Server) Register(name string, methods map[string]Method) { - s.services.register(name, methods) + s.services.register(name, &ServiceDesc{Methods: methods}) +} + +func (s *Server) RegisterService(name string, desc *ServiceDesc) { + s.services.register(name, desc) } func (s *Server) Serve(ctx context.Context, l net.Listener) error { @@ -301,27 +307,24 @@ func (c *serverConn) close() error { func (c *serverConn) run(sctx context.Context) { type ( - request struct { - id uint32 - req *Request - } - response struct { - id uint32 - resp *Response + id uint32 + status *status.Status + data []byte + closeStream bool + streaming bool } ) var ( - ch = newChannel(c.conn) - ctx, cancel = context.WithCancel(sctx) - active int - state connState = connStateIdle - responses = make(chan response) - requests = make(chan request) - recvErr = make(chan error, 1) - shutdown = c.shutdown - done = make(chan struct{}) + ch = newChannel(c.conn) + ctx, cancel = context.WithCancel(sctx) + state connState = connStateIdle + responses = make(chan response) + recvErr = make(chan error, 1) + done = make(chan struct{}) + active int32 + lastStreamID uint32 ) defer c.conn.Close() @@ -329,27 +332,27 @@ func (c *serverConn) run(sctx context.Context) { defer close(done) defer c.server.delConnection(c) + sendStatus := func(id uint32, st *status.Status) bool { + select { + case responses <- response{ + // even though we've had an invalid stream id, we send it + // back on the same stream id so the client knows which + // stream id was bad. + id: id, + status: st, + closeStream: true, + }: + return true + case <-c.shutdown: + return false + case <-done: + return false + } + } + go func(recvErr chan error) { defer close(recvErr) - sendImmediate := func(id uint32, st *status.Status) bool { - select { - case responses <- response{ - // even though we've had an invalid stream id, we send it - // back on the same stream id so the client knows which - // stream id was bad. - id: id, - resp: &Response{ - Status: st.Proto(), - }, - }: - return true - case <-c.shutdown: - return false - case <-done: - return false - } - } - + streams := map[uint32]*streamHandler{} for { select { case <-c.shutdown: @@ -369,99 +372,159 @@ func (c *serverConn) run(sctx context.Context) { // in this case, we send an error for that particular message // when the status is defined. - if !sendImmediate(mh.StreamID, status) { + if !sendStatus(mh.StreamID, status) { return } continue } - if mh.Type != messageTypeRequest { - // we must ignore this for future compat. - continue - } - - var req Request - if err := c.server.codec.Unmarshal(p, &req); err != nil { - ch.putmbuf(p) - if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) { - return - } - continue - } - ch.putmbuf(p) - if mh.StreamID%2 != 1 { // enforce odd client initiated identifiers. - if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) { + if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) { return } continue } - // Forward the request to the main loop. 
We don't wait on s.done
-			// because we have already accepted the client request.
-			select {
-			case requests <- request{
-				id:  mh.StreamID,
-				req: &req,
-			}:
-			case <-done:
-				return
+			if mh.Type == messageTypeData {
+				sh, ok := streams[mh.StreamID]
+				if !ok {
+					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID is no longer active")) {
+						return
+					}
+					// no handler registered for this stream; drop the message
+					// rather than dereferencing a nil stream handler below.
+					continue
+				}
+				if mh.Flags&flagNoData != flagNoData {
+					unmarshal := func(obj interface{}) error {
+						err := protoUnmarshal(p, obj)
+						ch.putmbuf(p)
+						return err
+					}
+
+					if err := sh.data(unmarshal); err != nil {
+						if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data handling error: %v", err)) {
+							return
+						}
+					}
+				}
+
+				if mh.Flags&flagRemoteClosed == flagRemoteClosed {
+					sh.closeSend()
+					if len(p) > 0 {
+						if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data close message cannot include data")) {
+							return
+						}
+					}
+				}
+			} else if mh.Type == messageTypeRequest {
+				if mh.StreamID <= lastStreamID {
+					// enforce odd client initiated identifiers.
+					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID cannot be re-used and must increment")) {
+						return
+					}
+					continue
+
+				}
+				lastStreamID = mh.StreamID
+
+				// TODO: Make request type configurable
+				// Unmarshaler which takes in a byte array and returns an interface?
+				var req Request
+				if err := c.server.codec.Unmarshal(p, &req); err != nil {
+					ch.putmbuf(p)
+					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
+						return
+					}
+					continue
+				}
+				ch.putmbuf(p)
+
+				id := mh.StreamID
+				respond := func(status *status.Status, data []byte, streaming, closeStream bool) error {
+					select {
+					case responses <- response{
+						id:          id,
+						status:      status,
+						data:        data,
+						closeStream: closeStream,
+						streaming:   streaming,
+					}:
+					case <-done:
+						return ErrClosed
+					}
+					return nil
+				}
+				sh, err := c.server.services.handle(ctx, &req, respond)
+				if err != nil {
+					status, _ := status.FromError(err)
+					if !sendStatus(mh.StreamID, status) {
+						return
+					}
+					continue
+				}
+
+				streams[id] = sh
+				atomic.AddInt32(&active, 1)
 			}
+			// TODO: else we must ignore this for future compat. log this?
} }(recvErr) for { - newstate := state - switch { - case active > 0: + var ( + newstate connState + shutdown chan struct{} + ) + + activeN := atomic.LoadInt32(&active) + if activeN > 0 { newstate = connStateActive shutdown = nil - case active == 0: + } else { newstate = connStateIdle shutdown = c.shutdown // only enable this branch in idle mode } - if newstate != state { c.setState(newstate) state = newstate } select { - case request := <-requests: - active++ - go func(id uint32) { - ctx, cancel := getRequestContext(ctx, request.req) - defer cancel() - - p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) - resp := &Response{ - Status: status.Proto(), - Payload: p, - } - - select { - case responses <- response{ - id: id, - resp: resp, - }: - case <-done: - } - }(request.id) case response := <-responses: - p, err := c.server.codec.Marshal(response.resp) - if err != nil { - logrus.WithError(err).Error("failed marshaling response") - return + if !response.streaming || response.status.Code() != codes.OK { + p, err := c.server.codec.Marshal(&Response{ + Status: response.status.Proto(), + Payload: response.data, + }) + if err != nil { + logrus.WithError(err).Error("failed marshaling response") + return + } + + if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil { + logrus.WithError(err).Error("failed sending message on channel") + return + } + } else { + var flags uint8 + if response.closeStream { + flags = flagRemoteClosed + } + if response.data == nil { + flags = flags | flagNoData + } + if err := ch.send(response.id, messageTypeData, flags, response.data); err != nil { + logrus.WithError(err).Error("failed sending message on channel") + return + } } - if err := ch.send(response.id, messageTypeResponse, p); err != nil { - logrus.WithError(err).Error("failed sending message on channel") - return + if response.closeStream { + // The ttrpc protocol currently does not support the case where + // the server is localClosed but not remoteClosed. Once the server + // is closing, the whole stream may be considered finished + atomic.AddInt32(&active, -1) } - - active-- case err := <-recvErr: // TODO(stevvooe): Not wildly clear what we should do in this // branch. 
Basically, it means that we are no longer receiving @@ -475,6 +538,7 @@ func (c *serverConn) run(sctx context.Context) { if err != nil { logrus.WithError(err).Error("error receiving message") } + // else, initiate shutdown case <-shutdown: return } diff --git a/server_test.go b/server_test.go index 791c835..7cb9d34 100644 --- a/server_test.go +++ b/server_test.go @@ -22,7 +22,6 @@ import ( "errors" "fmt" "net" - "os" "runtime" "strings" "sync" @@ -372,18 +371,9 @@ func TestClientEOF(t *testing.T) { if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { t.Fatalf("expected error when calling against shutdown server") } else if !errors.Is(err, ErrClosed) { - errno, ok := err.(syscall.Errno) - if ok { + var errno syscall.Errno + if errors.As(err, &errno) { t.Logf("errno=%d", errno) - } else { - var oerr *net.OpError - if errors.As(err, &oerr) { - serr, sok := oerr.Err.(*os.SyscallError) - if sok { - t.Logf("Op=%q, syscall=%s, Err=%v", oerr.Op, serr.Syscall, serr.Err) - } - } - t.Logf("error %q doesn't match syscall.Errno", err) } t.Fatalf("expected to have a cause of ErrClosed, got %v", err) diff --git a/services.go b/services.go index 57c8f80..6aabfbb 100644 --- a/services.go +++ b/services.go @@ -32,36 +32,55 @@ import ( type Method func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) +type StreamHandler func(context.Context, StreamServer) (interface{}, error) + +type Stream struct { + Handler StreamHandler + StreamingClient bool + StreamingServer bool +} + type ServiceDesc struct { Methods map[string]Method - - // TODO(stevvooe): Add stream support. + Streams map[string]Stream } type serviceSet struct { - services map[string]ServiceDesc - interceptor UnaryServerInterceptor + services map[string]*ServiceDesc + unaryInterceptor UnaryServerInterceptor + streamInterceptor StreamServerInterceptor } func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet { return &serviceSet{ - services: make(map[string]ServiceDesc), - interceptor: interceptor, + services: make(map[string]*ServiceDesc), + unaryInterceptor: interceptor, + streamInterceptor: defaultStreamServerInterceptor, } } -func (s *serviceSet) register(name string, methods map[string]Method) { +func (s *serviceSet) register(name string, desc *ServiceDesc) { if _, ok := s.services[name]; ok { panic(fmt.Errorf("duplicate service %v registered", name)) } - s.services[name] = ServiceDesc{ - Methods: methods, - } + s.services[name] = desc } -func (s *serviceSet) call(ctx context.Context, serviceName, methodName string, p []byte) ([]byte, *status.Status) { - p, err := s.dispatch(ctx, serviceName, methodName, p) +func (s *serviceSet) unaryCall(ctx context.Context, method Method, info *UnaryServerInfo, data []byte) (p []byte, st *status.Status) { + unmarshal := func(obj interface{}) error { + return protoUnmarshal(data, obj) + } + + resp, err := s.unaryInterceptor(ctx, unmarshal, info, method) + if err == nil { + if isNil(resp) { + err = errors.New("ttrpc: marshal called with nil") + } else { + p, err = protoMarshal(resp) + } + } + st, ok := status.FromError(err) if !ok { st = status.New(convertCode(err), err.Error()) @@ -70,38 +89,142 @@ func (s *serviceSet) call(ctx context.Context, serviceName, methodName string, p return p, st } -func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName string, p []byte) ([]byte, error) { - method, err := s.resolve(serviceName, methodName) - if err != nil { - return nil, err +func (s *serviceSet) streamCall(ctx context.Context, stream 
StreamHandler, info *StreamServerInfo, ss StreamServer) (p []byte, st *status.Status) { + resp, err := s.streamInterceptor(ctx, ss, info, stream) + if err == nil { + p, err = protoMarshal(resp) + } + st, ok := status.FromError(err) + if !ok { + st = status.New(convertCode(err), err.Error()) + } + return +} + +func (s *serviceSet) handle(ctx context.Context, req *Request, respond func(*status.Status, []byte, bool, bool) error) (*streamHandler, error) { + srv, ok := s.services[req.Service] + if !ok { + return nil, status.Errorf(codes.Unimplemented, "service %v", req.Service) } - unmarshal := func(obj interface{}) error { - switch v := obj.(type) { - case proto.Message: - if err := proto.Unmarshal(p, v); err != nil { - return status.Errorf(codes.Internal, "ttrpc: error unmarshalling payload: %v", err.Error()) + if method, ok := srv.Methods[req.Method]; ok { + go func() { + ctx, cancel := getRequestContext(ctx, req) + defer cancel() + + info := &UnaryServerInfo{ + FullMethod: fullPath(req.Service, req.Method), } - default: - return status.Errorf(codes.Internal, "ttrpc: error unsupported request type: %T", v) + p, st := s.unaryCall(ctx, method, info, req.Payload) + + respond(st, p, false, true) + }() + return nil, nil + } + if stream, ok := srv.Streams[req.Method]; ok { + ctx, cancel := getRequestContext(ctx, req) + info := &StreamServerInfo{ + FullMethod: fullPath(req.Service, req.Method), + StreamingClient: stream.StreamingClient, + StreamingServer: stream.StreamingServer, } + sh := &streamHandler{ + ctx: ctx, + respond: respond, + recv: make(chan Unmarshaler, 5), + info: info, + } + go func() { + defer cancel() + p, st := s.streamCall(ctx, stream.Handler, info, sh) + respond(st, p, stream.StreamingServer, true) + }() + + if req.Payload != nil { + unmarshal := func(obj interface{}) error { + return protoUnmarshal(req.Payload, obj) + } + if err := sh.data(unmarshal); err != nil { + return nil, err + } + } + + return sh, nil + } + return nil, status.Errorf(codes.Unimplemented, "method %v", req.Method) +} + +type streamHandler struct { + ctx context.Context + respond func(*status.Status, []byte, bool, bool) error + recv chan Unmarshaler + info *StreamServerInfo + + remoteClosed bool + localClosed bool +} + +func (s *streamHandler) closeSend() { + if !s.remoteClosed { + s.remoteClosed = true + close(s.recv) + } +} + +func (s *streamHandler) data(unmarshal Unmarshaler) error { + if s.remoteClosed { + return ErrStreamClosed + } + select { + case s.recv <- unmarshal: return nil + case <-s.ctx.Done(): + return s.ctx.Err() } +} - info := &UnaryServerInfo{ - FullMethod: fullPath(serviceName, methodName), +func (s *streamHandler) SendMsg(m interface{}) error { + if s.localClosed { + return ErrStreamClosed } - - resp, err := s.interceptor(ctx, unmarshal, info, method) + p, err := protoMarshal(m) if err != nil { - return nil, err + return err + } + return s.respond(nil, p, true, false) +} + +func (s *streamHandler) RecvMsg(m interface{}) error { + select { + case unmarshal, ok := <-s.recv: + if !ok { + return io.EOF + } + return unmarshal(m) + case <-s.ctx.Done(): + return s.ctx.Err() + + } +} + +func protoUnmarshal(p []byte, obj interface{}) error { + switch v := obj.(type) { + case proto.Message: + if err := proto.Unmarshal(p, v); err != nil { + return status.Errorf(codes.Internal, "ttrpc: error unmarshalling payload: %v", err.Error()) + } + default: + return status.Errorf(codes.Internal, "ttrpc: error unsupported request type: %T", v) + } + return nil +} + +func protoMarshal(obj interface{}) ([]byte, 
error) { + if obj == nil { + return nil, nil } - if isNil(resp) { - return nil, errors.New("ttrpc: marshal called with nil") - } - - switch v := resp.(type) { + switch v := obj.(type) { case proto.Message: r, err := proto.Marshal(v) if err != nil { @@ -114,20 +237,6 @@ func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName strin } } -func (s *serviceSet) resolve(service, method string) (Method, error) { - srv, ok := s.services[service] - if !ok { - return nil, status.Errorf(codes.Unimplemented, "service %v", service) - } - - mthd, ok := srv.Methods[method] - if !ok { - return nil, status.Errorf(codes.Unimplemented, "method %v", method) - } - - return mthd, nil -} - // convertCode maps stdlib go errors into grpc space. // // This is ripped from the grpc-go code base. diff --git a/stream.go b/stream.go new file mode 100644 index 0000000..5609e14 --- /dev/null +++ b/stream.go @@ -0,0 +1,84 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "context" + "errors" + "sync" +) + +var ErrStreamClosed = errors.New("ttrpc: stream closed") + +type streamID uint32 + +type streamMessage struct { + header messageHeader + payload []byte +} + +type stream struct { + id streamID + sender sender + recv chan *streamMessage + + closeOnce sync.Once + recvErr error +} + +func newStream(id streamID, send sender) *stream { + return &stream{ + id: id, + sender: send, + recv: make(chan *streamMessage, 1), + } +} + +func (s *stream) closeWithError(err error) error { + s.closeOnce.Do(func() { + if s.recv != nil { + close(s.recv) + if err != nil { + s.recvErr = err + } else { + s.recvErr = ErrClosed + } + + } + }) + return nil +} + +func (s *stream) send(mt messageType, flags uint8, b []byte) error { + return s.sender.send(uint32(s.id), mt, flags, b) +} + +func (s *stream) receive(ctx context.Context, msg *streamMessage) error { + if s.recvErr != nil { + return s.recvErr + } + select { + case s.recv <- msg: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +type sender interface { + send(uint32, messageType, uint8, []byte) error +} diff --git a/stream_server.go b/stream_server.go new file mode 100644 index 0000000..b6d1ba7 --- /dev/null +++ b/stream_server.go @@ -0,0 +1,22 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package ttrpc + +type StreamServer interface { + SendMsg(m interface{}) error + RecvMsg(m interface{}) error +} diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..1e04f1b --- /dev/null +++ b/stream_test.go @@ -0,0 +1,118 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "context" + "io" + "testing" + + "github.com/containerd/ttrpc/internal" +) + +func TestStreamClient(t *testing.T) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + client, cleanup = newTestClient(t, addr) + serviceName = "streamService" + ) + + defer listener.Close() + defer cleanup() + + desc := &ServiceDesc{ + Methods: map[string]Method{ + "Echo": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) { + var req internal.EchoPayload + if err := unmarshal(&req); err != nil { + return nil, err + } + req.Seq++ + return &req, nil + }, + }, + Streams: map[string]Stream{ + "EchoStream": { + Handler: func(ctx context.Context, ss StreamServer) (interface{}, error) { + for { + var req internal.EchoPayload + if err := ss.RecvMsg(&req); err != nil { + if err == io.EOF { + err = nil + } + return nil, err + } + req.Seq++ + if err := ss.SendMsg(&req); err != nil { + return nil, err + } + } + + }, + StreamingClient: true, + StreamingServer: true, + }, + }, + } + server.RegisterService(serviceName, desc) + + go server.Serve(ctx, listener) + defer server.Shutdown(ctx) + + //func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string) (ClientStream, error) { + var req, resp internal.EchoPayload + if err := client.Call(ctx, serviceName, "Echo", &req, &resp); err != nil { + t.Fatal(err) + } + + stream, err := client.NewStream(ctx, &StreamDesc{true, true}, serviceName, "EchoStream", nil) + if err != nil { + t.Fatal(err) + } + for i := 1; i <= 100; i++ { + req := internal.EchoPayload{ + Seq: int64(i), + Msg: "should be returned", + } + if err := stream.SendMsg(&req); err != nil { + t.Fatalf("%d: %v", i, err) + } + var resp internal.EchoPayload + if err := stream.RecvMsg(&resp); err != nil { + t.Fatalf("%d: %v", i, err) + } + if resp.Seq != int64(i)+1 { + t.Fatalf("%d: unexpected sequence value: %d, expected %d", i, resp.Seq, i+1) + } + if resp.Msg != req.Msg { + t.Fatalf("%d: unexpected message: %q, expected %q", i, resp.Msg, req.Msg) + } + } + if err := stream.CloseSend(); err != nil { + t.Fatal(err) + } + + err = stream.RecvMsg(&resp) + if err == nil { + t.Fatal("expected io.EOF after close send") + } + if err != io.EOF { + t.Fatalf("expected io.EOF after close send, got %v", err) + } +} diff --git a/test.proto b/test.proto index deeed94..0e114d5 100644 --- a/test.proto +++ b/test.proto @@ -9,3 +9,8 @@ message TestPayload { int64 deadline = 2; string metadata = 3; } + +message EchoPayload { + int64 seq = 1; + string msg = 2; +}
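
For reference, a minimal usage sketch of the generated streaming client added under integration/streaming. This is illustrative only and not part of the patch: the socket path is a placeholder, and it assumes a server that has called streaming.RegisterStreamingService is already listening on it.

// example_echo_stream.go (hypothetical): drives the bidirectional EchoStream
// RPC end to end using only APIs introduced in this change.
package main

import (
	"context"
	"io"
	"log"
	"net"

	"github.com/containerd/ttrpc"
	"github.com/containerd/ttrpc/integration/streaming"
)

func main() {
	// Placeholder socket path; the server side is assumed to be running already.
	conn, err := net.Dial("unix", "/tmp/streaming.sock")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := ttrpc.NewClient(conn)
	defer client.Close()

	ctx := context.Background()
	sc := streaming.NewStreamingClient(client)

	// Open the bidirectional stream; the integration test service echoes each
	// payload back with Seq incremented by one.
	stream, err := sc.EchoStream(ctx)
	if err != nil {
		log.Fatal(err)
	}
	for i := uint32(0); i < 3; i++ {
		if err := stream.Send(&streaming.EchoPayload{Seq: i, Msg: "hello"}); err != nil {
			log.Fatal(err)
		}
		resp, err := stream.Recv()
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("echoed seq=%d msg=%q", resp.Seq, resp.Msg)
	}

	// Half-close the sending side; once the server finishes, the next Recv
	// returns io.EOF.
	if err := stream.CloseSend(); err != nil {
		log.Fatal(err)
	}
	if _, err := stream.Recv(); err != io.EOF {
		log.Fatalf("expected io.EOF after CloseSend, got %v", err)
	}
}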