Add streaming support to go-ttrpc generator

Signed-off-by: Derek McGowan <derek@mcg.dev>
This commit is contained in:
Derek McGowan 2022-03-04 21:02:14 -08:00
parent 86d3c77a60
commit a323535118

View File

@ -17,6 +17,7 @@
package main
import (
"fmt"
"strings"
"google.golang.org/protobuf/compiler/protogen"
@ -29,10 +30,19 @@ type generator struct {
out *protogen.GeneratedFile
ident struct {
context string
server string
client string
method string
context string
server string
client string
method string
stream string
serviceDesc string
streamDesc string
streamServerIdent protogen.GoIdent
streamClientIdent protogen.GoIdent
streamServer string
streamClient string
}
}
@ -54,6 +64,29 @@ func newGenerator(out *protogen.GeneratedFile) *generator {
GoImportPath: "github.com/containerd/ttrpc",
GoName: "Method",
})
gen.ident.stream = out.QualifiedGoIdent(protogen.GoIdent{
GoImportPath: "github.com/containerd/ttrpc",
GoName: "Stream",
})
gen.ident.serviceDesc = out.QualifiedGoIdent(protogen.GoIdent{
GoImportPath: "github.com/containerd/ttrpc",
GoName: "ServiceDesc",
})
gen.ident.streamDesc = out.QualifiedGoIdent(protogen.GoIdent{
GoImportPath: "github.com/containerd/ttrpc",
GoName: "StreamDesc",
})
gen.ident.streamServerIdent = protogen.GoIdent{
GoImportPath: "github.com/containerd/ttrpc",
GoName: "StreamServer",
}
gen.ident.streamClientIdent = protogen.GoIdent{
GoImportPath: "github.com/containerd/ttrpc",
GoName: "ClientStream",
}
gen.ident.streamServer = out.QualifiedGoIdent(gen.ident.streamServerIdent)
gen.ident.streamClient = out.QualifiedGoIdent(gen.ident.streamClientIdent)
return &gen
}
@ -74,52 +107,277 @@ func (gen *generator) genService(service *protogen.Service) {
fullName := service.Desc.FullName()
p := gen.out
var methods []*protogen.Method
var streams []*protogen.Method
serviceName := service.GoName + "Service"
p.P("type ", serviceName, " interface{")
for _, method := range service.Methods {
p.P(method.GoName,
"(ctx ", gen.ident.context, ",",
"req *", method.Input.GoIdent, ")",
"(*", method.Output.GoIdent, ", error)")
var sendArgs, retArgs string
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
streams = append(streams, method)
sendArgs = fmt.Sprintf("%s_%sServer", service.GoName, method.GoName)
if !method.Desc.IsStreamingClient() {
sendArgs = fmt.Sprintf("*%s, %s", p.QualifiedGoIdent(method.Input.GoIdent), sendArgs)
}
if method.Desc.IsStreamingServer() {
retArgs = "error"
} else {
retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
}
} else {
methods = append(methods, method)
sendArgs = fmt.Sprintf("*%s", p.QualifiedGoIdent(method.Input.GoIdent))
retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
}
p.P(method.GoName, "(", gen.ident.context, ", ", sendArgs, ") ", retArgs)
}
p.P("}")
p.P()
for _, method := range streams {
structName := strings.ToLower(service.GoName) + method.GoName + "Server"
p.P("type ", service.GoName, "_", method.GoName, "Server interface {")
if method.Desc.IsStreamingServer() {
p.P("Send(*", method.Output.GoIdent, ") error")
}
if method.Desc.IsStreamingClient() {
p.P("Recv() (*", method.Input.GoIdent, ", error)")
}
p.P(gen.ident.streamServer)
p.P("}")
p.P()
p.P("type ", structName, " struct {")
p.P(gen.ident.streamServer)
p.P("}")
p.P()
if method.Desc.IsStreamingServer() {
p.P("func (x *", structName, ") Send(m *", method.Output.GoIdent, ") error {")
p.P("return x.StreamServer.SendMsg(m)")
p.P("}")
p.P()
}
if method.Desc.IsStreamingClient() {
p.P("func (x *", structName, ") Recv() (*", method.Input.GoIdent, ", error) {")
p.P("m := new(", method.Input.GoIdent, ")")
p.P("if err := x.StreamServer.RecvMsg(m); err != nil {")
p.P("return nil, err")
p.P("}")
p.P("return m, nil")
p.P("}")
p.P()
}
}
// registration method
p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){")
p.P(`srv.Register("`, fullName, `", map[string]`, gen.ident.method, "{")
for _, method := range service.Methods {
p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
p.P("var req ", method.Input.GoIdent)
p.P("if err := unmarshal(&req); err != nil {")
p.P("return nil, err")
p.P("}")
p.P("return svc.", method.GoName, "(ctx, &req)")
p.P(`srv.RegisterService("`, fullName, `", &`, gen.ident.serviceDesc, "{")
if len(methods) > 0 {
p.P(`Methods: map[string]`, gen.ident.method, "{")
for _, method := range methods {
p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
p.P("var req ", method.Input.GoIdent)
p.P("if err := unmarshal(&req); err != nil {")
p.P("return nil, err")
p.P("}")
p.P("return svc.", method.GoName, "(ctx, &req)")
p.P("},")
}
p.P("},")
}
if len(streams) > 0 {
p.P(`Streams: map[string]`, gen.ident.stream, "{")
for _, method := range streams {
p.P(`"`, method.GoName, `": {`)
p.P(`Handler: func(ctx `, gen.ident.context, ", stream ", gen.ident.streamServer, ") (interface{}, error) {")
structName := strings.ToLower(service.GoName) + method.GoName + "Server"
var sendArg string
if !method.Desc.IsStreamingClient() {
sendArg = "m, "
p.P("m := new(", method.Input.GoIdent, ")")
p.P("if err := stream.RecvMsg(m); err != nil {")
p.P("return nil, err")
p.P("}")
}
if method.Desc.IsStreamingServer() {
p.P("return nil, svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
} else {
p.P("return svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
}
p.P("},")
if method.Desc.IsStreamingClient() {
p.P("StreamingClient: true,")
} else {
p.P("StreamingClient: false,")
}
if method.Desc.IsStreamingServer() {
p.P("StreamingServer: true,")
} else {
p.P("StreamingServer: false,")
}
p.P("},")
}
p.P("},")
}
p.P("})")
p.P("}")
p.P()
clientType := service.GoName + "Client"
// For consistency with ttrpc 1.0 without streaming, just use
// the service name if no streams are defined
clientInterface := serviceName
if len(streams) > 0 {
clientInterface = clientType
// Stream client interfaces are different than the server interface
p.P("type ", clientInterface, " interface{")
for _, method := range service.Methods {
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
streams = append(streams, method)
var sendArg string
if !method.Desc.IsStreamingClient() {
sendArg = fmt.Sprintf("*%s, ", p.QualifiedGoIdent(method.Input.GoIdent))
}
p.P(method.GoName,
"(", gen.ident.context, ", ", sendArg,
") (", service.GoName, "_", method.GoName, "Client, error)")
} else {
methods = append(methods, method)
p.P(method.GoName,
"(", gen.ident.context, ", ",
"*", method.Input.GoIdent, ")",
"(*", method.Output.GoIdent, ", error)")
}
}
p.P("}")
p.P()
}
clientStructType := strings.ToLower(clientType[:1]) + clientType[1:]
p.P("type ", clientStructType, " struct{")
p.P("client *", gen.ident.client)
p.P("}")
p.P("func New", clientType, "(client *", gen.ident.client, ")", serviceName, "{")
p.P("func New", clientType, "(client *", gen.ident.client, ")", clientInterface, "{")
p.P("return &", clientStructType, "{")
p.P("client:client,")
p.P("}")
p.P("}")
p.P()
for _, method := range service.Methods {
p.P("func (c *", clientStructType, ")", method.GoName, "(",
"ctx ", gen.ident.context, ",",
"req *", method.Input.GoIdent, ")",
"(*", method.Output.GoIdent, ", error){")
p.P("var resp ", method.Output.GoIdent)
p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
p.P("return nil, err")
p.P("}")
p.P("return &resp, nil")
p.P("}")
var sendArg string
if !method.Desc.IsStreamingClient() {
sendArg = ", req *" + gen.out.QualifiedGoIdent(method.Input.GoIdent)
}
intName := service.GoName + "_" + method.GoName + "Client"
var retArg string
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
retArg = intName
} else {
retArg = "*" + gen.out.QualifiedGoIdent(method.Output.GoIdent)
}
p.P("func (c *", clientStructType, ") ", method.GoName,
"(ctx ", gen.ident.context, "", sendArg, ") ",
"(", retArg, ", error) {")
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
var streamingClient, streamingServer, req string
if method.Desc.IsStreamingClient() {
streamingClient = "true"
req = "nil"
} else {
streamingClient = "false"
req = "req"
}
if method.Desc.IsStreamingServer() {
streamingServer = "true"
} else {
streamingServer = "false"
}
p.P("stream, err := c.client.NewStream(ctx, &", gen.ident.streamDesc, "{")
p.P("StreamingClient: ", streamingClient, ",")
p.P("StreamingServer: ", streamingServer, ",")
p.P("}, ", `"`+fullName+`", `, `"`+method.GoName+`", `, req, `)`)
p.P("if err != nil {")
p.P("return nil, err")
p.P("}")
structName := strings.ToLower(service.GoName) + method.GoName + "Client"
p.P("x := &", structName, "{stream}")
p.P("return x, nil")
p.P("}")
p.P()
// Create interface
p.P("type ", intName, " interface {")
if method.Desc.IsStreamingClient() {
p.P("Send(*", method.Input.GoIdent, ") error")
}
if method.Desc.IsStreamingServer() {
p.P("Recv() (*", method.Output.GoIdent, ", error)")
} else {
p.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
}
p.P(gen.ident.streamClient)
p.P("}")
p.P()
// Create struct
p.P("type ", structName, " struct {")
p.P(gen.ident.streamClient)
p.P("}")
p.P()
if method.Desc.IsStreamingClient() {
p.P("func (x *", structName, ") Send(m *", method.Input.GoIdent, ") error {")
p.P("return x.", gen.ident.streamClientIdent.GoName, ".SendMsg(m)")
p.P("}")
p.P()
}
if method.Desc.IsStreamingServer() {
p.P("func (x *", structName, ") Recv() (*", method.Output.GoIdent, ", error) {")
p.P("m := new(", method.Output.GoIdent, ")")
p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
p.P("return nil, err")
p.P("}")
p.P("return m, nil")
p.P("}")
p.P()
} else {
p.P("func (x *", structName, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
p.P("if err := x.ClientStream.CloseSend(); err != nil {")
p.P("return nil, err")
p.P("}")
p.P("m := new(", method.Output.GoIdent, ")")
p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
p.P("return nil, err")
p.P("}")
p.P("return m, nil")
p.P("}")
p.P()
}
} else {
p.P("var resp ", method.Output.GoIdent)
p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
p.P("return nil, err")
p.P("}")
p.P("return &resp, nil")
p.P("}")
p.P()
}
}
}