Add streaming support to go-ttrpc generator
Signed-off-by: Derek McGowan <derek@mcg.dev>
This commit is contained in:
parent
86d3c77a60
commit
a323535118
@ -17,6 +17,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/compiler/protogen"
|
||||
@ -29,10 +30,19 @@ type generator struct {
|
||||
out *protogen.GeneratedFile
|
||||
|
||||
ident struct {
|
||||
context string
|
||||
server string
|
||||
client string
|
||||
method string
|
||||
context string
|
||||
server string
|
||||
client string
|
||||
method string
|
||||
stream string
|
||||
serviceDesc string
|
||||
streamDesc string
|
||||
|
||||
streamServerIdent protogen.GoIdent
|
||||
streamClientIdent protogen.GoIdent
|
||||
|
||||
streamServer string
|
||||
streamClient string
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,6 +64,29 @@ func newGenerator(out *protogen.GeneratedFile) *generator {
|
||||
GoImportPath: "github.com/containerd/ttrpc",
|
||||
GoName: "Method",
|
||||
})
|
||||
gen.ident.stream = out.QualifiedGoIdent(protogen.GoIdent{
|
||||
GoImportPath: "github.com/containerd/ttrpc",
|
||||
GoName: "Stream",
|
||||
})
|
||||
gen.ident.serviceDesc = out.QualifiedGoIdent(protogen.GoIdent{
|
||||
GoImportPath: "github.com/containerd/ttrpc",
|
||||
GoName: "ServiceDesc",
|
||||
})
|
||||
gen.ident.streamDesc = out.QualifiedGoIdent(protogen.GoIdent{
|
||||
GoImportPath: "github.com/containerd/ttrpc",
|
||||
GoName: "StreamDesc",
|
||||
})
|
||||
|
||||
gen.ident.streamServerIdent = protogen.GoIdent{
|
||||
GoImportPath: "github.com/containerd/ttrpc",
|
||||
GoName: "StreamServer",
|
||||
}
|
||||
gen.ident.streamClientIdent = protogen.GoIdent{
|
||||
GoImportPath: "github.com/containerd/ttrpc",
|
||||
GoName: "ClientStream",
|
||||
}
|
||||
gen.ident.streamServer = out.QualifiedGoIdent(gen.ident.streamServerIdent)
|
||||
gen.ident.streamClient = out.QualifiedGoIdent(gen.ident.streamClientIdent)
|
||||
return &gen
|
||||
}
|
||||
|
||||
@ -74,52 +107,277 @@ func (gen *generator) genService(service *protogen.Service) {
|
||||
fullName := service.Desc.FullName()
|
||||
p := gen.out
|
||||
|
||||
var methods []*protogen.Method
|
||||
var streams []*protogen.Method
|
||||
|
||||
serviceName := service.GoName + "Service"
|
||||
p.P("type ", serviceName, " interface{")
|
||||
for _, method := range service.Methods {
|
||||
p.P(method.GoName,
|
||||
"(ctx ", gen.ident.context, ",",
|
||||
"req *", method.Input.GoIdent, ")",
|
||||
"(*", method.Output.GoIdent, ", error)")
|
||||
var sendArgs, retArgs string
|
||||
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
||||
streams = append(streams, method)
|
||||
sendArgs = fmt.Sprintf("%s_%sServer", service.GoName, method.GoName)
|
||||
if !method.Desc.IsStreamingClient() {
|
||||
sendArgs = fmt.Sprintf("*%s, %s", p.QualifiedGoIdent(method.Input.GoIdent), sendArgs)
|
||||
}
|
||||
if method.Desc.IsStreamingServer() {
|
||||
retArgs = "error"
|
||||
} else {
|
||||
retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
|
||||
}
|
||||
} else {
|
||||
methods = append(methods, method)
|
||||
sendArgs = fmt.Sprintf("*%s", p.QualifiedGoIdent(method.Input.GoIdent))
|
||||
retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
|
||||
}
|
||||
p.P(method.GoName, "(", gen.ident.context, ", ", sendArgs, ") ", retArgs)
|
||||
}
|
||||
p.P("}")
|
||||
p.P()
|
||||
|
||||
for _, method := range streams {
|
||||
structName := strings.ToLower(service.GoName) + method.GoName + "Server"
|
||||
|
||||
p.P("type ", service.GoName, "_", method.GoName, "Server interface {")
|
||||
if method.Desc.IsStreamingServer() {
|
||||
p.P("Send(*", method.Output.GoIdent, ") error")
|
||||
}
|
||||
if method.Desc.IsStreamingClient() {
|
||||
p.P("Recv() (*", method.Input.GoIdent, ", error)")
|
||||
|
||||
}
|
||||
p.P(gen.ident.streamServer)
|
||||
p.P("}")
|
||||
p.P()
|
||||
|
||||
p.P("type ", structName, " struct {")
|
||||
p.P(gen.ident.streamServer)
|
||||
p.P("}")
|
||||
p.P()
|
||||
|
||||
if method.Desc.IsStreamingServer() {
|
||||
p.P("func (x *", structName, ") Send(m *", method.Output.GoIdent, ") error {")
|
||||
p.P("return x.StreamServer.SendMsg(m)")
|
||||
p.P("}")
|
||||
p.P()
|
||||
}
|
||||
|
||||
if method.Desc.IsStreamingClient() {
|
||||
p.P("func (x *", structName, ") Recv() (*", method.Input.GoIdent, ", error) {")
|
||||
p.P("m := new(", method.Input.GoIdent, ")")
|
||||
p.P("if err := x.StreamServer.RecvMsg(m); err != nil {")
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
p.P("return m, nil")
|
||||
p.P("}")
|
||||
p.P()
|
||||
}
|
||||
}
|
||||
|
||||
// registration method
|
||||
p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){")
|
||||
p.P(`srv.Register("`, fullName, `", map[string]`, gen.ident.method, "{")
|
||||
for _, method := range service.Methods {
|
||||
p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
|
||||
p.P("var req ", method.Input.GoIdent)
|
||||
p.P("if err := unmarshal(&req); err != nil {")
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
p.P("return svc.", method.GoName, "(ctx, &req)")
|
||||
p.P(`srv.RegisterService("`, fullName, `", &`, gen.ident.serviceDesc, "{")
|
||||
if len(methods) > 0 {
|
||||
p.P(`Methods: map[string]`, gen.ident.method, "{")
|
||||
for _, method := range methods {
|
||||
p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
|
||||
p.P("var req ", method.Input.GoIdent)
|
||||
p.P("if err := unmarshal(&req); err != nil {")
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
p.P("return svc.", method.GoName, "(ctx, &req)")
|
||||
p.P("},")
|
||||
}
|
||||
p.P("},")
|
||||
}
|
||||
if len(streams) > 0 {
|
||||
p.P(`Streams: map[string]`, gen.ident.stream, "{")
|
||||
for _, method := range streams {
|
||||
p.P(`"`, method.GoName, `": {`)
|
||||
p.P(`Handler: func(ctx `, gen.ident.context, ", stream ", gen.ident.streamServer, ") (interface{}, error) {")
|
||||
|
||||
structName := strings.ToLower(service.GoName) + method.GoName + "Server"
|
||||
var sendArg string
|
||||
if !method.Desc.IsStreamingClient() {
|
||||
sendArg = "m, "
|
||||
p.P("m := new(", method.Input.GoIdent, ")")
|
||||
p.P("if err := stream.RecvMsg(m); err != nil {")
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
}
|
||||
if method.Desc.IsStreamingServer() {
|
||||
p.P("return nil, svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
|
||||
} else {
|
||||
p.P("return svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
|
||||
|
||||
}
|
||||
p.P("},")
|
||||
if method.Desc.IsStreamingClient() {
|
||||
p.P("StreamingClient: true,")
|
||||
} else {
|
||||
p.P("StreamingClient: false,")
|
||||
}
|
||||
if method.Desc.IsStreamingServer() {
|
||||
p.P("StreamingServer: true,")
|
||||
} else {
|
||||
p.P("StreamingServer: false,")
|
||||
}
|
||||
p.P("},")
|
||||
}
|
||||
p.P("},")
|
||||
}
|
||||
p.P("})")
|
||||
p.P("}")
|
||||
p.P()
|
||||
|
||||
clientType := service.GoName + "Client"
|
||||
|
||||
// For consistency with ttrpc 1.0 without streaming, just use
|
||||
// the service name if no streams are defined
|
||||
clientInterface := serviceName
|
||||
if len(streams) > 0 {
|
||||
clientInterface = clientType
|
||||
// Stream client interfaces are different than the server interface
|
||||
p.P("type ", clientInterface, " interface{")
|
||||
for _, method := range service.Methods {
|
||||
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
||||
streams = append(streams, method)
|
||||
var sendArg string
|
||||
if !method.Desc.IsStreamingClient() {
|
||||
sendArg = fmt.Sprintf("*%s, ", p.QualifiedGoIdent(method.Input.GoIdent))
|
||||
}
|
||||
p.P(method.GoName,
|
||||
"(", gen.ident.context, ", ", sendArg,
|
||||
") (", service.GoName, "_", method.GoName, "Client, error)")
|
||||
} else {
|
||||
methods = append(methods, method)
|
||||
p.P(method.GoName,
|
||||
"(", gen.ident.context, ", ",
|
||||
"*", method.Input.GoIdent, ")",
|
||||
"(*", method.Output.GoIdent, ", error)")
|
||||
}
|
||||
}
|
||||
p.P("}")
|
||||
p.P()
|
||||
}
|
||||
|
||||
clientStructType := strings.ToLower(clientType[:1]) + clientType[1:]
|
||||
p.P("type ", clientStructType, " struct{")
|
||||
p.P("client *", gen.ident.client)
|
||||
p.P("}")
|
||||
p.P("func New", clientType, "(client *", gen.ident.client, ")", serviceName, "{")
|
||||
p.P("func New", clientType, "(client *", gen.ident.client, ")", clientInterface, "{")
|
||||
p.P("return &", clientStructType, "{")
|
||||
p.P("client:client,")
|
||||
p.P("}")
|
||||
p.P("}")
|
||||
p.P()
|
||||
|
||||
for _, method := range service.Methods {
|
||||
p.P("func (c *", clientStructType, ")", method.GoName, "(",
|
||||
"ctx ", gen.ident.context, ",",
|
||||
"req *", method.Input.GoIdent, ")",
|
||||
"(*", method.Output.GoIdent, ", error){")
|
||||
p.P("var resp ", method.Output.GoIdent)
|
||||
p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
p.P("return &resp, nil")
|
||||
p.P("}")
|
||||
var sendArg string
|
||||
if !method.Desc.IsStreamingClient() {
|
||||
sendArg = ", req *" + gen.out.QualifiedGoIdent(method.Input.GoIdent)
|
||||
}
|
||||
|
||||
intName := service.GoName + "_" + method.GoName + "Client"
|
||||
var retArg string
|
||||
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
||||
retArg = intName
|
||||
} else {
|
||||
retArg = "*" + gen.out.QualifiedGoIdent(method.Output.GoIdent)
|
||||
}
|
||||
|
||||
p.P("func (c *", clientStructType, ") ", method.GoName,
|
||||
"(ctx ", gen.ident.context, "", sendArg, ") ",
|
||||
"(", retArg, ", error) {")
|
||||
|
||||
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
||||
var streamingClient, streamingServer, req string
|
||||
if method.Desc.IsStreamingClient() {
|
||||
streamingClient = "true"
|
||||
req = "nil"
|
||||
} else {
|
||||
streamingClient = "false"
|
||||
req = "req"
|
||||
}
|
||||
if method.Desc.IsStreamingServer() {
|
||||
streamingServer = "true"
|
||||
} else {
|
||||
streamingServer = "false"
|
||||
}
|
||||
p.P("stream, err := c.client.NewStream(ctx, &", gen.ident.streamDesc, "{")
|
||||
p.P("StreamingClient: ", streamingClient, ",")
|
||||
p.P("StreamingServer: ", streamingServer, ",")
|
||||
p.P("}, ", `"`+fullName+`", `, `"`+method.GoName+`", `, req, `)`)
|
||||
p.P("if err != nil {")
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
|
||||
structName := strings.ToLower(service.GoName) + method.GoName + "Client"
|
||||
|
||||
p.P("x := &", structName, "{stream}")
|
||||
|
||||
p.P("return x, nil")
|
||||
p.P("}")
|
||||
p.P()
|
||||
|
||||
// Create interface
|
||||
p.P("type ", intName, " interface {")
|
||||
if method.Desc.IsStreamingClient() {
|
||||
p.P("Send(*", method.Input.GoIdent, ") error")
|
||||
}
|
||||
if method.Desc.IsStreamingServer() {
|
||||
p.P("Recv() (*", method.Output.GoIdent, ", error)")
|
||||
} else {
|
||||
p.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
|
||||
}
|
||||
|
||||
p.P(gen.ident.streamClient)
|
||||
p.P("}")
|
||||
p.P()
|
||||
|
||||
// Create struct
|
||||
p.P("type ", structName, " struct {")
|
||||
p.P(gen.ident.streamClient)
|
||||
p.P("}")
|
||||
p.P()
|
||||
|
||||
if method.Desc.IsStreamingClient() {
|
||||
p.P("func (x *", structName, ") Send(m *", method.Input.GoIdent, ") error {")
|
||||
p.P("return x.", gen.ident.streamClientIdent.GoName, ".SendMsg(m)")
|
||||
p.P("}")
|
||||
p.P()
|
||||
}
|
||||
|
||||
if method.Desc.IsStreamingServer() {
|
||||
p.P("func (x *", structName, ") Recv() (*", method.Output.GoIdent, ", error) {")
|
||||
p.P("m := new(", method.Output.GoIdent, ")")
|
||||
p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
p.P("return m, nil")
|
||||
p.P("}")
|
||||
p.P()
|
||||
} else {
|
||||
p.P("func (x *", structName, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
|
||||
p.P("if err := x.ClientStream.CloseSend(); err != nil {")
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
p.P("m := new(", method.Output.GoIdent, ")")
|
||||
p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
p.P("return m, nil")
|
||||
p.P("}")
|
||||
p.P()
|
||||
}
|
||||
} else {
|
||||
p.P("var resp ", method.Output.GoIdent)
|
||||
p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
|
||||
p.P("return nil, err")
|
||||
p.P("}")
|
||||
p.P("return &resp, nil")
|
||||
p.P("}")
|
||||
p.P()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user