diff --git a/client.go b/client.go new file mode 100644 index 0000000..6efd7e9 --- /dev/null +++ b/client.go @@ -0,0 +1,51 @@ +package mgrpc + +import ( + "context" + "net" + + "github.com/containerd/typeurl" +) + +type Client struct { + channel *channel +} + +func NewClient(conn net.Conn) *Client { + return &Client{ + channel: newChannel(conn), + } +} + +func (c *Client) Call(ctx context.Context, service, method string, req interface{}) (interface{}, error) { + payload, err := typeurl.MarshalAny(req) + if err != nil { + return nil, err + } + + request := Request{ + Service: service, + Method: method, + Payload: payload, + } + + if err := c.channel.send(ctx, &request); err != nil { + return nil, err + } + + var response Response + + if err := c.channel.recv(ctx, &response); err != nil { + return nil, err + } + + // TODO(stevvooe): Reliance on the typeurl isn't great for bootstrapping + // and ease of use. Let's consider a request header frame and body frame as + // a better solution. This will allow the caller to set the exact type. + rpayload, err := typeurl.UnmarshalAny(response.Payload) + if err != nil { + return nil, err + } + + return rpayload, nil +} diff --git a/handlers.go b/handlers.go new file mode 100644 index 0000000..c34fec4 --- /dev/null +++ b/handlers.go @@ -0,0 +1,13 @@ +package mgrpc + +import "context" + +type Handler interface { + Handle(ctx context.Context, req interface{}) (interface{}, error) +} + +type HandlerFunc func(ctx context.Context, req interface{}) (interface{}, error) + +func (fn HandlerFunc) Handle(ctx context.Context, req interface{}) (interface{}, error) { + return fn(ctx, req) +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..3691a36 --- /dev/null +++ b/server.go @@ -0,0 +1,131 @@ +package mgrpc + +import ( + "context" + "net" + "path" + + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" + "github.com/containerd/typeurl" + "github.com/pkg/errors" +) + +type Server struct { + handlers map[string]map[string]Handler +} + +func NewServer() *Server { + return &Server{handlers: make(map[string]map[string]Handler)} +} + +func (s *Server) Register(name string, methods map[string]Handler) error { + if _, ok := s.handlers[name]; ok { + return errors.Errorf("duplicate service %v registered", name) + } + + s.handlers[name] = methods + return nil +} + +func (s *Server) Shutdown(ctx context.Context) error { + // TODO(stevvooe): Wait on connection shutdown. + return nil +} + +func (s *Server) Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if err != nil { + log.L.WithError(err).Error("failed accept") + } + + go s.handleConn(conn) + } + + return nil +} + +const maxMessageSize = 1 << 20 // TODO(stevvooe): Cut these down, since they are pre-alloced. + +func (s *Server) handleConn(conn net.Conn) { + defer conn.Close() + + var ( + ch = newChannel(conn) + req Request + ctx, cancel = context.WithCancel(context.Background()) + ) + + defer cancel() + + // TODO(stevvooe): Recover here or in dispatch to handle panics in service + // methods. + + // every connection is just a simple in/out request loop. No complexity for + // multiplexing streams or dealing with head of line blocking, as this + // isn't necessary for shim control. + for { + if err := ch.recv(ctx, &req); err != nil { + log.L.WithError(err).Error("failed receiving message on channel") + return + } + + resp, err := s.dispatch(ctx, &req) + if err != nil { + log.L.WithError(err).Error("failed to dispatch request") + return + } + + if err := ch.send(ctx, resp); err != nil { + log.L.WithError(err).Error("failed sending message on channel") + return + } + } +} + +func (s *Server) dispatch(ctx context.Context, req *Request) (*Response, error) { + ctx = log.WithLogger(ctx, log.G(ctx).WithField("method", path.Join("/", req.Service, req.Method))) + handler, err := s.resolve(req.Service, req.Method) + if err != nil { + log.L.WithError(err).Error("failed to resolve handler") + return nil, err + } + + payload, err := typeurl.UnmarshalAny(req.Payload) + if err != nil { + return nil, err + } + + resp, err := handler.Handle(ctx, payload) + if err != nil { + log.L.WithError(err).Error("handler returned an error") + return nil, err + } + + apayload, err := typeurl.MarshalAny(resp) + if err != nil { + return nil, err + } + + rresp := &Response{ + // Status: *st, + Payload: apayload, + } + + return rresp, nil +} + +func (s *Server) resolve(service, method string) (Handler, error) { + srv, ok := s.handlers[service] + if !ok { + return nil, errors.Wrapf(errdefs.ErrNotFound, "could not resolve service %v", service) + } + + handler, ok := srv[method] + if !ok { + return nil, errors.Wrapf(errdefs.ErrNotFound, "could not resolve method %v", method) + } + + return handler, nil +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..3380f2c --- /dev/null +++ b/server_test.go @@ -0,0 +1,71 @@ +package mgrpc + +import ( + "context" + "fmt" + "net" + "testing" + + "github.com/gogo/protobuf/proto" +) + +// var serverMethods = map[string]Handler{ +// "Create": HandlerFunc(func(ctx context.Context, req interface{}) (interface{}, error) { + +// }, +// } + +type testPayload struct { + Foo string `protobuf:"bytes,1,opt,name=foo,proto3"` +} + +func (r *testPayload) Reset() { *r = testPayload{} } +func (r *testPayload) String() string { return fmt.Sprintf("%+#v", r) } +func (r *testPayload) ProtoMessage() {} + +func init() { + proto.RegisterType((*testPayload)(nil), "testpayload") + proto.RegisterType((*Request)(nil), "Request") + proto.RegisterType((*Response)(nil), "Response") +} + +func TestServer(t *testing.T) { + server := NewServer() + ctx := context.Background() + + if err := server.Register("test-service", map[string]Handler{ + "Test": HandlerFunc(func(ctx context.Context, req interface{}) (interface{}, error) { + fmt.Println(req) + + return &testPayload{Foo: "baz"}, nil + }), + }); err != nil { + t.Fatal(err) + } + + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + go server.Serve(listener) + defer server.Shutdown(ctx) + + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + client := NewClient(conn) + + tp := &testPayload{ + Foo: "bar", + } + resp, err := client.Call(ctx, "test-service", "Test", tp) + if err != nil { + t.Fatal(err) + } + fmt.Println(resp) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..fc78a3c --- /dev/null +++ b/types.go @@ -0,0 +1,27 @@ +package mgrpc + +import ( + "fmt" + + "github.com/containerd/containerd/protobuf/google/rpc" + "github.com/gogo/protobuf/types" +) + +type Request struct { + Service string `protobuf:"bytes,1,opt,name=service,proto3"` + Method string `protobuf:"bytes,2,opt,name=method,proto3"` + Payload *types.Any `protobuf:"bytes,3,opt,name=payload,proto3"` +} + +func (r *Request) Reset() { *r = Request{} } +func (r *Request) String() string { return fmt.Sprintf("%+#v", r) } +func (r *Request) ProtoMessage() {} + +type Response struct { + Status *rpc.Status `protobuf:"bytes,1,opt,name=status,proto3"` + Payload *types.Any `protobuf:"bytes,2,opt,name=payload,proto3"` +} + +func (r *Response) Reset() { *r = Response{} } +func (r *Response) String() string { return fmt.Sprintf("%+#v", r) } +func (r *Response) ProtoMessage() {}