Add streaming support to go-ttrpc generator
Signed-off-by: Derek McGowan <derek@mcg.dev>
This commit is contained in:
@@ -17,6 +17,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"google.golang.org/protobuf/compiler/protogen"
|
"google.golang.org/protobuf/compiler/protogen"
|
||||||
@@ -33,6 +34,15 @@ type generator struct {
|
|||||||
server string
|
server string
|
||||||
client string
|
client string
|
||||||
method string
|
method string
|
||||||
|
stream string
|
||||||
|
serviceDesc string
|
||||||
|
streamDesc string
|
||||||
|
|
||||||
|
streamServerIdent protogen.GoIdent
|
||||||
|
streamClientIdent protogen.GoIdent
|
||||||
|
|
||||||
|
streamServer string
|
||||||
|
streamClient string
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +64,29 @@ func newGenerator(out *protogen.GeneratedFile) *generator {
|
|||||||
GoImportPath: "github.com/containerd/ttrpc",
|
GoImportPath: "github.com/containerd/ttrpc",
|
||||||
GoName: "Method",
|
GoName: "Method",
|
||||||
})
|
})
|
||||||
|
gen.ident.stream = out.QualifiedGoIdent(protogen.GoIdent{
|
||||||
|
GoImportPath: "github.com/containerd/ttrpc",
|
||||||
|
GoName: "Stream",
|
||||||
|
})
|
||||||
|
gen.ident.serviceDesc = out.QualifiedGoIdent(protogen.GoIdent{
|
||||||
|
GoImportPath: "github.com/containerd/ttrpc",
|
||||||
|
GoName: "ServiceDesc",
|
||||||
|
})
|
||||||
|
gen.ident.streamDesc = out.QualifiedGoIdent(protogen.GoIdent{
|
||||||
|
GoImportPath: "github.com/containerd/ttrpc",
|
||||||
|
GoName: "StreamDesc",
|
||||||
|
})
|
||||||
|
|
||||||
|
gen.ident.streamServerIdent = protogen.GoIdent{
|
||||||
|
GoImportPath: "github.com/containerd/ttrpc",
|
||||||
|
GoName: "StreamServer",
|
||||||
|
}
|
||||||
|
gen.ident.streamClientIdent = protogen.GoIdent{
|
||||||
|
GoImportPath: "github.com/containerd/ttrpc",
|
||||||
|
GoName: "ClientStream",
|
||||||
|
}
|
||||||
|
gen.ident.streamServer = out.QualifiedGoIdent(gen.ident.streamServerIdent)
|
||||||
|
gen.ident.streamClient = out.QualifiedGoIdent(gen.ident.streamClientIdent)
|
||||||
return &gen
|
return &gen
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,20 +107,79 @@ func (gen *generator) genService(service *protogen.Service) {
|
|||||||
fullName := service.Desc.FullName()
|
fullName := service.Desc.FullName()
|
||||||
p := gen.out
|
p := gen.out
|
||||||
|
|
||||||
|
var methods []*protogen.Method
|
||||||
|
var streams []*protogen.Method
|
||||||
|
|
||||||
serviceName := service.GoName + "Service"
|
serviceName := service.GoName + "Service"
|
||||||
p.P("type ", serviceName, " interface{")
|
p.P("type ", serviceName, " interface{")
|
||||||
for _, method := range service.Methods {
|
for _, method := range service.Methods {
|
||||||
p.P(method.GoName,
|
var sendArgs, retArgs string
|
||||||
"(ctx ", gen.ident.context, ",",
|
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
||||||
"req *", method.Input.GoIdent, ")",
|
streams = append(streams, method)
|
||||||
"(*", method.Output.GoIdent, ", error)")
|
sendArgs = fmt.Sprintf("%s_%sServer", service.GoName, method.GoName)
|
||||||
|
if !method.Desc.IsStreamingClient() {
|
||||||
|
sendArgs = fmt.Sprintf("*%s, %s", p.QualifiedGoIdent(method.Input.GoIdent), sendArgs)
|
||||||
|
}
|
||||||
|
if method.Desc.IsStreamingServer() {
|
||||||
|
retArgs = "error"
|
||||||
|
} else {
|
||||||
|
retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
methods = append(methods, method)
|
||||||
|
sendArgs = fmt.Sprintf("*%s", p.QualifiedGoIdent(method.Input.GoIdent))
|
||||||
|
retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
|
||||||
|
}
|
||||||
|
p.P(method.GoName, "(", gen.ident.context, ", ", sendArgs, ") ", retArgs)
|
||||||
}
|
}
|
||||||
p.P("}")
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
|
||||||
|
for _, method := range streams {
|
||||||
|
structName := strings.ToLower(service.GoName) + method.GoName + "Server"
|
||||||
|
|
||||||
|
p.P("type ", service.GoName, "_", method.GoName, "Server interface {")
|
||||||
|
if method.Desc.IsStreamingServer() {
|
||||||
|
p.P("Send(*", method.Output.GoIdent, ") error")
|
||||||
|
}
|
||||||
|
if method.Desc.IsStreamingClient() {
|
||||||
|
p.P("Recv() (*", method.Input.GoIdent, ", error)")
|
||||||
|
|
||||||
|
}
|
||||||
|
p.P(gen.ident.streamServer)
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
|
||||||
|
p.P("type ", structName, " struct {")
|
||||||
|
p.P(gen.ident.streamServer)
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
|
||||||
|
if method.Desc.IsStreamingServer() {
|
||||||
|
p.P("func (x *", structName, ") Send(m *", method.Output.GoIdent, ") error {")
|
||||||
|
p.P("return x.StreamServer.SendMsg(m)")
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
}
|
||||||
|
|
||||||
|
if method.Desc.IsStreamingClient() {
|
||||||
|
p.P("func (x *", structName, ") Recv() (*", method.Input.GoIdent, ", error) {")
|
||||||
|
p.P("m := new(", method.Input.GoIdent, ")")
|
||||||
|
p.P("if err := x.StreamServer.RecvMsg(m); err != nil {")
|
||||||
|
p.P("return nil, err")
|
||||||
|
p.P("}")
|
||||||
|
p.P("return m, nil")
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// registration method
|
// registration method
|
||||||
p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){")
|
p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){")
|
||||||
p.P(`srv.Register("`, fullName, `", map[string]`, gen.ident.method, "{")
|
p.P(`srv.RegisterService("`, fullName, `", &`, gen.ident.serviceDesc, "{")
|
||||||
for _, method := range service.Methods {
|
if len(methods) > 0 {
|
||||||
|
p.P(`Methods: map[string]`, gen.ident.method, "{")
|
||||||
|
for _, method := range methods {
|
||||||
p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
|
p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
|
||||||
p.P("var req ", method.Input.GoIdent)
|
p.P("var req ", method.Input.GoIdent)
|
||||||
p.P("if err := unmarshal(&req); err != nil {")
|
p.P("if err := unmarshal(&req); err != nil {")
|
||||||
@@ -96,30 +188,196 @@ func (gen *generator) genService(service *protogen.Service) {
|
|||||||
p.P("return svc.", method.GoName, "(ctx, &req)")
|
p.P("return svc.", method.GoName, "(ctx, &req)")
|
||||||
p.P("},")
|
p.P("},")
|
||||||
}
|
}
|
||||||
|
p.P("},")
|
||||||
|
}
|
||||||
|
if len(streams) > 0 {
|
||||||
|
p.P(`Streams: map[string]`, gen.ident.stream, "{")
|
||||||
|
for _, method := range streams {
|
||||||
|
p.P(`"`, method.GoName, `": {`)
|
||||||
|
p.P(`Handler: func(ctx `, gen.ident.context, ", stream ", gen.ident.streamServer, ") (interface{}, error) {")
|
||||||
|
|
||||||
|
structName := strings.ToLower(service.GoName) + method.GoName + "Server"
|
||||||
|
var sendArg string
|
||||||
|
if !method.Desc.IsStreamingClient() {
|
||||||
|
sendArg = "m, "
|
||||||
|
p.P("m := new(", method.Input.GoIdent, ")")
|
||||||
|
p.P("if err := stream.RecvMsg(m); err != nil {")
|
||||||
|
p.P("return nil, err")
|
||||||
|
p.P("}")
|
||||||
|
}
|
||||||
|
if method.Desc.IsStreamingServer() {
|
||||||
|
p.P("return nil, svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
|
||||||
|
} else {
|
||||||
|
p.P("return svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
|
||||||
|
|
||||||
|
}
|
||||||
|
p.P("},")
|
||||||
|
if method.Desc.IsStreamingClient() {
|
||||||
|
p.P("StreamingClient: true,")
|
||||||
|
} else {
|
||||||
|
p.P("StreamingClient: false,")
|
||||||
|
}
|
||||||
|
if method.Desc.IsStreamingServer() {
|
||||||
|
p.P("StreamingServer: true,")
|
||||||
|
} else {
|
||||||
|
p.P("StreamingServer: false,")
|
||||||
|
}
|
||||||
|
p.P("},")
|
||||||
|
}
|
||||||
|
p.P("},")
|
||||||
|
}
|
||||||
p.P("})")
|
p.P("})")
|
||||||
p.P("}")
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
|
||||||
clientType := service.GoName + "Client"
|
clientType := service.GoName + "Client"
|
||||||
|
|
||||||
|
// For consistency with ttrpc 1.0 without streaming, just use
|
||||||
|
// the service name if no streams are defined
|
||||||
|
clientInterface := serviceName
|
||||||
|
if len(streams) > 0 {
|
||||||
|
clientInterface = clientType
|
||||||
|
// Stream client interfaces are different than the server interface
|
||||||
|
p.P("type ", clientInterface, " interface{")
|
||||||
|
for _, method := range service.Methods {
|
||||||
|
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
||||||
|
streams = append(streams, method)
|
||||||
|
var sendArg string
|
||||||
|
if !method.Desc.IsStreamingClient() {
|
||||||
|
sendArg = fmt.Sprintf("*%s, ", p.QualifiedGoIdent(method.Input.GoIdent))
|
||||||
|
}
|
||||||
|
p.P(method.GoName,
|
||||||
|
"(", gen.ident.context, ", ", sendArg,
|
||||||
|
") (", service.GoName, "_", method.GoName, "Client, error)")
|
||||||
|
} else {
|
||||||
|
methods = append(methods, method)
|
||||||
|
p.P(method.GoName,
|
||||||
|
"(", gen.ident.context, ", ",
|
||||||
|
"*", method.Input.GoIdent, ")",
|
||||||
|
"(*", method.Output.GoIdent, ", error)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
}
|
||||||
|
|
||||||
clientStructType := strings.ToLower(clientType[:1]) + clientType[1:]
|
clientStructType := strings.ToLower(clientType[:1]) + clientType[1:]
|
||||||
p.P("type ", clientStructType, " struct{")
|
p.P("type ", clientStructType, " struct{")
|
||||||
p.P("client *", gen.ident.client)
|
p.P("client *", gen.ident.client)
|
||||||
p.P("}")
|
p.P("}")
|
||||||
p.P("func New", clientType, "(client *", gen.ident.client, ")", serviceName, "{")
|
p.P("func New", clientType, "(client *", gen.ident.client, ")", clientInterface, "{")
|
||||||
p.P("return &", clientStructType, "{")
|
p.P("return &", clientStructType, "{")
|
||||||
p.P("client:client,")
|
p.P("client:client,")
|
||||||
p.P("}")
|
p.P("}")
|
||||||
p.P("}")
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
|
||||||
for _, method := range service.Methods {
|
for _, method := range service.Methods {
|
||||||
p.P("func (c *", clientStructType, ")", method.GoName, "(",
|
var sendArg string
|
||||||
"ctx ", gen.ident.context, ",",
|
if !method.Desc.IsStreamingClient() {
|
||||||
"req *", method.Input.GoIdent, ")",
|
sendArg = ", req *" + gen.out.QualifiedGoIdent(method.Input.GoIdent)
|
||||||
"(*", method.Output.GoIdent, ", error){")
|
}
|
||||||
|
|
||||||
|
intName := service.GoName + "_" + method.GoName + "Client"
|
||||||
|
var retArg string
|
||||||
|
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
||||||
|
retArg = intName
|
||||||
|
} else {
|
||||||
|
retArg = "*" + gen.out.QualifiedGoIdent(method.Output.GoIdent)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.P("func (c *", clientStructType, ") ", method.GoName,
|
||||||
|
"(ctx ", gen.ident.context, "", sendArg, ") ",
|
||||||
|
"(", retArg, ", error) {")
|
||||||
|
|
||||||
|
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
||||||
|
var streamingClient, streamingServer, req string
|
||||||
|
if method.Desc.IsStreamingClient() {
|
||||||
|
streamingClient = "true"
|
||||||
|
req = "nil"
|
||||||
|
} else {
|
||||||
|
streamingClient = "false"
|
||||||
|
req = "req"
|
||||||
|
}
|
||||||
|
if method.Desc.IsStreamingServer() {
|
||||||
|
streamingServer = "true"
|
||||||
|
} else {
|
||||||
|
streamingServer = "false"
|
||||||
|
}
|
||||||
|
p.P("stream, err := c.client.NewStream(ctx, &", gen.ident.streamDesc, "{")
|
||||||
|
p.P("StreamingClient: ", streamingClient, ",")
|
||||||
|
p.P("StreamingServer: ", streamingServer, ",")
|
||||||
|
p.P("}, ", `"`+fullName+`", `, `"`+method.GoName+`", `, req, `)`)
|
||||||
|
p.P("if err != nil {")
|
||||||
|
p.P("return nil, err")
|
||||||
|
p.P("}")
|
||||||
|
|
||||||
|
structName := strings.ToLower(service.GoName) + method.GoName + "Client"
|
||||||
|
|
||||||
|
p.P("x := &", structName, "{stream}")
|
||||||
|
|
||||||
|
p.P("return x, nil")
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
|
||||||
|
// Create interface
|
||||||
|
p.P("type ", intName, " interface {")
|
||||||
|
if method.Desc.IsStreamingClient() {
|
||||||
|
p.P("Send(*", method.Input.GoIdent, ") error")
|
||||||
|
}
|
||||||
|
if method.Desc.IsStreamingServer() {
|
||||||
|
p.P("Recv() (*", method.Output.GoIdent, ", error)")
|
||||||
|
} else {
|
||||||
|
p.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
|
||||||
|
}
|
||||||
|
|
||||||
|
p.P(gen.ident.streamClient)
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
|
||||||
|
// Create struct
|
||||||
|
p.P("type ", structName, " struct {")
|
||||||
|
p.P(gen.ident.streamClient)
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
|
||||||
|
if method.Desc.IsStreamingClient() {
|
||||||
|
p.P("func (x *", structName, ") Send(m *", method.Input.GoIdent, ") error {")
|
||||||
|
p.P("return x.", gen.ident.streamClientIdent.GoName, ".SendMsg(m)")
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
}
|
||||||
|
|
||||||
|
if method.Desc.IsStreamingServer() {
|
||||||
|
p.P("func (x *", structName, ") Recv() (*", method.Output.GoIdent, ", error) {")
|
||||||
|
p.P("m := new(", method.Output.GoIdent, ")")
|
||||||
|
p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
|
||||||
|
p.P("return nil, err")
|
||||||
|
p.P("}")
|
||||||
|
p.P("return m, nil")
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
} else {
|
||||||
|
p.P("func (x *", structName, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
|
||||||
|
p.P("if err := x.ClientStream.CloseSend(); err != nil {")
|
||||||
|
p.P("return nil, err")
|
||||||
|
p.P("}")
|
||||||
|
p.P("m := new(", method.Output.GoIdent, ")")
|
||||||
|
p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
|
||||||
|
p.P("return nil, err")
|
||||||
|
p.P("}")
|
||||||
|
p.P("return m, nil")
|
||||||
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
p.P("var resp ", method.Output.GoIdent)
|
p.P("var resp ", method.Output.GoIdent)
|
||||||
p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
|
p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
|
||||||
p.P("return nil, err")
|
p.P("return nil, err")
|
||||||
p.P("}")
|
p.P("}")
|
||||||
p.P("return &resp, nil")
|
p.P("return &resp, nil")
|
||||||
p.P("}")
|
p.P("}")
|
||||||
|
p.P()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user