Add streaming support to go-ttrpc generator
Signed-off-by: Derek McGowan <derek@mcg.dev>
This commit is contained in:
		| @@ -17,6 +17,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"google.golang.org/protobuf/compiler/protogen" | 	"google.golang.org/protobuf/compiler/protogen" | ||||||
| @@ -29,10 +30,19 @@ type generator struct { | |||||||
| 	out *protogen.GeneratedFile | 	out *protogen.GeneratedFile | ||||||
|  |  | ||||||
| 	ident struct { | 	ident struct { | ||||||
| 		context string | 		context     string | ||||||
| 		server  string | 		server      string | ||||||
| 		client  string | 		client      string | ||||||
| 		method  string | 		method      string | ||||||
|  | 		stream      string | ||||||
|  | 		serviceDesc string | ||||||
|  | 		streamDesc  string | ||||||
|  |  | ||||||
|  | 		streamServerIdent protogen.GoIdent | ||||||
|  | 		streamClientIdent protogen.GoIdent | ||||||
|  |  | ||||||
|  | 		streamServer string | ||||||
|  | 		streamClient string | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -54,6 +64,29 @@ func newGenerator(out *protogen.GeneratedFile) *generator { | |||||||
| 		GoImportPath: "github.com/containerd/ttrpc", | 		GoImportPath: "github.com/containerd/ttrpc", | ||||||
| 		GoName:       "Method", | 		GoName:       "Method", | ||||||
| 	}) | 	}) | ||||||
|  | 	gen.ident.stream = out.QualifiedGoIdent(protogen.GoIdent{ | ||||||
|  | 		GoImportPath: "github.com/containerd/ttrpc", | ||||||
|  | 		GoName:       "Stream", | ||||||
|  | 	}) | ||||||
|  | 	gen.ident.serviceDesc = out.QualifiedGoIdent(protogen.GoIdent{ | ||||||
|  | 		GoImportPath: "github.com/containerd/ttrpc", | ||||||
|  | 		GoName:       "ServiceDesc", | ||||||
|  | 	}) | ||||||
|  | 	gen.ident.streamDesc = out.QualifiedGoIdent(protogen.GoIdent{ | ||||||
|  | 		GoImportPath: "github.com/containerd/ttrpc", | ||||||
|  | 		GoName:       "StreamDesc", | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	gen.ident.streamServerIdent = protogen.GoIdent{ | ||||||
|  | 		GoImportPath: "github.com/containerd/ttrpc", | ||||||
|  | 		GoName:       "StreamServer", | ||||||
|  | 	} | ||||||
|  | 	gen.ident.streamClientIdent = protogen.GoIdent{ | ||||||
|  | 		GoImportPath: "github.com/containerd/ttrpc", | ||||||
|  | 		GoName:       "ClientStream", | ||||||
|  | 	} | ||||||
|  | 	gen.ident.streamServer = out.QualifiedGoIdent(gen.ident.streamServerIdent) | ||||||
|  | 	gen.ident.streamClient = out.QualifiedGoIdent(gen.ident.streamClientIdent) | ||||||
| 	return &gen | 	return &gen | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -74,52 +107,277 @@ func (gen *generator) genService(service *protogen.Service) { | |||||||
| 	fullName := service.Desc.FullName() | 	fullName := service.Desc.FullName() | ||||||
| 	p := gen.out | 	p := gen.out | ||||||
|  |  | ||||||
|  | 	var methods []*protogen.Method | ||||||
|  | 	var streams []*protogen.Method | ||||||
|  |  | ||||||
| 	serviceName := service.GoName + "Service" | 	serviceName := service.GoName + "Service" | ||||||
| 	p.P("type ", serviceName, " interface{") | 	p.P("type ", serviceName, " interface{") | ||||||
| 	for _, method := range service.Methods { | 	for _, method := range service.Methods { | ||||||
| 		p.P(method.GoName, | 		var sendArgs, retArgs string | ||||||
| 			"(ctx ", gen.ident.context, ",", | 		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { | ||||||
| 			"req *", method.Input.GoIdent, ")", | 			streams = append(streams, method) | ||||||
| 			"(*", method.Output.GoIdent, ", error)") | 			sendArgs = fmt.Sprintf("%s_%sServer", service.GoName, method.GoName) | ||||||
|  | 			if !method.Desc.IsStreamingClient() { | ||||||
|  | 				sendArgs = fmt.Sprintf("*%s, %s", p.QualifiedGoIdent(method.Input.GoIdent), sendArgs) | ||||||
|  | 			} | ||||||
|  | 			if method.Desc.IsStreamingServer() { | ||||||
|  | 				retArgs = "error" | ||||||
|  | 			} else { | ||||||
|  | 				retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent)) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			methods = append(methods, method) | ||||||
|  | 			sendArgs = fmt.Sprintf("*%s", p.QualifiedGoIdent(method.Input.GoIdent)) | ||||||
|  | 			retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent)) | ||||||
|  | 		} | ||||||
|  | 		p.P(method.GoName, "(", gen.ident.context, ", ", sendArgs, ") ", retArgs) | ||||||
| 	} | 	} | ||||||
| 	p.P("}") | 	p.P("}") | ||||||
|  | 	p.P() | ||||||
|  |  | ||||||
|  | 	for _, method := range streams { | ||||||
|  | 		structName := strings.ToLower(service.GoName) + method.GoName + "Server" | ||||||
|  |  | ||||||
|  | 		p.P("type ", service.GoName, "_", method.GoName, "Server interface {") | ||||||
|  | 		if method.Desc.IsStreamingServer() { | ||||||
|  | 			p.P("Send(*", method.Output.GoIdent, ") error") | ||||||
|  | 		} | ||||||
|  | 		if method.Desc.IsStreamingClient() { | ||||||
|  | 			p.P("Recv() (*", method.Input.GoIdent, ", error)") | ||||||
|  |  | ||||||
|  | 		} | ||||||
|  | 		p.P(gen.ident.streamServer) | ||||||
|  | 		p.P("}") | ||||||
|  | 		p.P() | ||||||
|  |  | ||||||
|  | 		p.P("type ", structName, " struct {") | ||||||
|  | 		p.P(gen.ident.streamServer) | ||||||
|  | 		p.P("}") | ||||||
|  | 		p.P() | ||||||
|  |  | ||||||
|  | 		if method.Desc.IsStreamingServer() { | ||||||
|  | 			p.P("func (x *", structName, ") Send(m *", method.Output.GoIdent, ") error {") | ||||||
|  | 			p.P("return x.StreamServer.SendMsg(m)") | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P() | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if method.Desc.IsStreamingClient() { | ||||||
|  | 			p.P("func (x *", structName, ") Recv() (*", method.Input.GoIdent, ", error) {") | ||||||
|  | 			p.P("m := new(", method.Input.GoIdent, ")") | ||||||
|  | 			p.P("if err := x.StreamServer.RecvMsg(m); err != nil {") | ||||||
|  | 			p.P("return nil, err") | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P("return m, nil") | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// registration method | 	// registration method | ||||||
| 	p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){") | 	p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){") | ||||||
| 	p.P(`srv.Register("`, fullName, `", map[string]`, gen.ident.method, "{") | 	p.P(`srv.RegisterService("`, fullName, `", &`, gen.ident.serviceDesc, "{") | ||||||
| 	for _, method := range service.Methods { | 	if len(methods) > 0 { | ||||||
| 		p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){") | 		p.P(`Methods: map[string]`, gen.ident.method, "{") | ||||||
| 		p.P("var req ", method.Input.GoIdent) | 		for _, method := range methods { | ||||||
| 		p.P("if err := unmarshal(&req); err != nil {") | 			p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){") | ||||||
| 		p.P("return nil, err") | 			p.P("var req ", method.Input.GoIdent) | ||||||
| 		p.P("}") | 			p.P("if err := unmarshal(&req); err != nil {") | ||||||
| 		p.P("return svc.", method.GoName, "(ctx, &req)") | 			p.P("return nil, err") | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P("return svc.", method.GoName, "(ctx, &req)") | ||||||
|  | 			p.P("},") | ||||||
|  | 		} | ||||||
|  | 		p.P("},") | ||||||
|  | 	} | ||||||
|  | 	if len(streams) > 0 { | ||||||
|  | 		p.P(`Streams: map[string]`, gen.ident.stream, "{") | ||||||
|  | 		for _, method := range streams { | ||||||
|  | 			p.P(`"`, method.GoName, `": {`) | ||||||
|  | 			p.P(`Handler: func(ctx `, gen.ident.context, ", stream ", gen.ident.streamServer, ") (interface{}, error) {") | ||||||
|  |  | ||||||
|  | 			structName := strings.ToLower(service.GoName) + method.GoName + "Server" | ||||||
|  | 			var sendArg string | ||||||
|  | 			if !method.Desc.IsStreamingClient() { | ||||||
|  | 				sendArg = "m, " | ||||||
|  | 				p.P("m := new(", method.Input.GoIdent, ")") | ||||||
|  | 				p.P("if err := stream.RecvMsg(m); err != nil {") | ||||||
|  | 				p.P("return nil, err") | ||||||
|  | 				p.P("}") | ||||||
|  | 			} | ||||||
|  | 			if method.Desc.IsStreamingServer() { | ||||||
|  | 				p.P("return nil, svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})") | ||||||
|  | 			} else { | ||||||
|  | 				p.P("return svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})") | ||||||
|  |  | ||||||
|  | 			} | ||||||
|  | 			p.P("},") | ||||||
|  | 			if method.Desc.IsStreamingClient() { | ||||||
|  | 				p.P("StreamingClient: true,") | ||||||
|  | 			} else { | ||||||
|  | 				p.P("StreamingClient: false,") | ||||||
|  | 			} | ||||||
|  | 			if method.Desc.IsStreamingServer() { | ||||||
|  | 				p.P("StreamingServer: true,") | ||||||
|  | 			} else { | ||||||
|  | 				p.P("StreamingServer: false,") | ||||||
|  | 			} | ||||||
|  | 			p.P("},") | ||||||
|  | 		} | ||||||
| 		p.P("},") | 		p.P("},") | ||||||
| 	} | 	} | ||||||
| 	p.P("})") | 	p.P("})") | ||||||
| 	p.P("}") | 	p.P("}") | ||||||
|  | 	p.P() | ||||||
|  |  | ||||||
| 	clientType := service.GoName + "Client" | 	clientType := service.GoName + "Client" | ||||||
|  |  | ||||||
|  | 	// For consistency with ttrpc 1.0 without streaming, just use | ||||||
|  | 	// the service name if no streams are defined | ||||||
|  | 	clientInterface := serviceName | ||||||
|  | 	if len(streams) > 0 { | ||||||
|  | 		clientInterface = clientType | ||||||
|  | 		// Stream client interfaces are different than the server interface | ||||||
|  | 		p.P("type ", clientInterface, " interface{") | ||||||
|  | 		for _, method := range service.Methods { | ||||||
|  | 			if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { | ||||||
|  | 				streams = append(streams, method) | ||||||
|  | 				var sendArg string | ||||||
|  | 				if !method.Desc.IsStreamingClient() { | ||||||
|  | 					sendArg = fmt.Sprintf("*%s, ", p.QualifiedGoIdent(method.Input.GoIdent)) | ||||||
|  | 				} | ||||||
|  | 				p.P(method.GoName, | ||||||
|  | 					"(", gen.ident.context, ", ", sendArg, | ||||||
|  | 					") (", service.GoName, "_", method.GoName, "Client, error)") | ||||||
|  | 			} else { | ||||||
|  | 				methods = append(methods, method) | ||||||
|  | 				p.P(method.GoName, | ||||||
|  | 					"(", gen.ident.context, ", ", | ||||||
|  | 					"*", method.Input.GoIdent, ")", | ||||||
|  | 					"(*", method.Output.GoIdent, ", error)") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		p.P("}") | ||||||
|  | 		p.P() | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	clientStructType := strings.ToLower(clientType[:1]) + clientType[1:] | 	clientStructType := strings.ToLower(clientType[:1]) + clientType[1:] | ||||||
| 	p.P("type ", clientStructType, " struct{") | 	p.P("type ", clientStructType, " struct{") | ||||||
| 	p.P("client *", gen.ident.client) | 	p.P("client *", gen.ident.client) | ||||||
| 	p.P("}") | 	p.P("}") | ||||||
| 	p.P("func New", clientType, "(client *", gen.ident.client, ")", serviceName, "{") | 	p.P("func New", clientType, "(client *", gen.ident.client, ")", clientInterface, "{") | ||||||
| 	p.P("return &", clientStructType, "{") | 	p.P("return &", clientStructType, "{") | ||||||
| 	p.P("client:client,") | 	p.P("client:client,") | ||||||
| 	p.P("}") | 	p.P("}") | ||||||
| 	p.P("}") | 	p.P("}") | ||||||
|  | 	p.P() | ||||||
|  |  | ||||||
| 	for _, method := range service.Methods { | 	for _, method := range service.Methods { | ||||||
| 		p.P("func (c *", clientStructType, ")", method.GoName, "(", | 		var sendArg string | ||||||
| 			"ctx ", gen.ident.context, ",", | 		if !method.Desc.IsStreamingClient() { | ||||||
| 			"req *", method.Input.GoIdent, ")", | 			sendArg = ", req *" + gen.out.QualifiedGoIdent(method.Input.GoIdent) | ||||||
| 			"(*", method.Output.GoIdent, ", error){") | 		} | ||||||
| 		p.P("var resp ", method.Output.GoIdent) |  | ||||||
| 		p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`) | 		intName := service.GoName + "_" + method.GoName + "Client" | ||||||
| 		p.P("return nil, err") | 		var retArg string | ||||||
| 		p.P("}") | 		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { | ||||||
| 		p.P("return &resp, nil") | 			retArg = intName | ||||||
| 		p.P("}") | 		} else { | ||||||
|  | 			retArg = "*" + gen.out.QualifiedGoIdent(method.Output.GoIdent) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		p.P("func (c *", clientStructType, ") ", method.GoName, | ||||||
|  | 			"(ctx ", gen.ident.context, "", sendArg, ") ", | ||||||
|  | 			"(", retArg, ", error) {") | ||||||
|  |  | ||||||
|  | 		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { | ||||||
|  | 			var streamingClient, streamingServer, req string | ||||||
|  | 			if method.Desc.IsStreamingClient() { | ||||||
|  | 				streamingClient = "true" | ||||||
|  | 				req = "nil" | ||||||
|  | 			} else { | ||||||
|  | 				streamingClient = "false" | ||||||
|  | 				req = "req" | ||||||
|  | 			} | ||||||
|  | 			if method.Desc.IsStreamingServer() { | ||||||
|  | 				streamingServer = "true" | ||||||
|  | 			} else { | ||||||
|  | 				streamingServer = "false" | ||||||
|  | 			} | ||||||
|  | 			p.P("stream, err := c.client.NewStream(ctx, &", gen.ident.streamDesc, "{") | ||||||
|  | 			p.P("StreamingClient: ", streamingClient, ",") | ||||||
|  | 			p.P("StreamingServer: ", streamingServer, ",") | ||||||
|  | 			p.P("}, ", `"`+fullName+`", `, `"`+method.GoName+`", `, req, `)`) | ||||||
|  | 			p.P("if err != nil {") | ||||||
|  | 			p.P("return nil, err") | ||||||
|  | 			p.P("}") | ||||||
|  |  | ||||||
|  | 			structName := strings.ToLower(service.GoName) + method.GoName + "Client" | ||||||
|  |  | ||||||
|  | 			p.P("x := &", structName, "{stream}") | ||||||
|  |  | ||||||
|  | 			p.P("return x, nil") | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P() | ||||||
|  |  | ||||||
|  | 			// Create interface | ||||||
|  | 			p.P("type ", intName, " interface {") | ||||||
|  | 			if method.Desc.IsStreamingClient() { | ||||||
|  | 				p.P("Send(*", method.Input.GoIdent, ") error") | ||||||
|  | 			} | ||||||
|  | 			if method.Desc.IsStreamingServer() { | ||||||
|  | 				p.P("Recv() (*", method.Output.GoIdent, ", error)") | ||||||
|  | 			} else { | ||||||
|  | 				p.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			p.P(gen.ident.streamClient) | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P() | ||||||
|  |  | ||||||
|  | 			// Create struct | ||||||
|  | 			p.P("type ", structName, " struct {") | ||||||
|  | 			p.P(gen.ident.streamClient) | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P() | ||||||
|  |  | ||||||
|  | 			if method.Desc.IsStreamingClient() { | ||||||
|  | 				p.P("func (x *", structName, ") Send(m *", method.Input.GoIdent, ") error {") | ||||||
|  | 				p.P("return x.", gen.ident.streamClientIdent.GoName, ".SendMsg(m)") | ||||||
|  | 				p.P("}") | ||||||
|  | 				p.P() | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if method.Desc.IsStreamingServer() { | ||||||
|  | 				p.P("func (x *", structName, ") Recv() (*", method.Output.GoIdent, ", error) {") | ||||||
|  | 				p.P("m := new(", method.Output.GoIdent, ")") | ||||||
|  | 				p.P("if err := x.ClientStream.RecvMsg(m); err != nil {") | ||||||
|  | 				p.P("return nil, err") | ||||||
|  | 				p.P("}") | ||||||
|  | 				p.P("return m, nil") | ||||||
|  | 				p.P("}") | ||||||
|  | 				p.P() | ||||||
|  | 			} else { | ||||||
|  | 				p.P("func (x *", structName, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {") | ||||||
|  | 				p.P("if err := x.ClientStream.CloseSend(); err != nil {") | ||||||
|  | 				p.P("return nil, err") | ||||||
|  | 				p.P("}") | ||||||
|  | 				p.P("m := new(", method.Output.GoIdent, ")") | ||||||
|  | 				p.P("if err := x.ClientStream.RecvMsg(m); err != nil {") | ||||||
|  | 				p.P("return nil, err") | ||||||
|  | 				p.P("}") | ||||||
|  | 				p.P("return m, nil") | ||||||
|  | 				p.P("}") | ||||||
|  | 				p.P() | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			p.P("var resp ", method.Output.GoIdent) | ||||||
|  | 			p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`) | ||||||
|  | 			p.P("return nil, err") | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P("return &resp, nil") | ||||||
|  | 			p.P("}") | ||||||
|  | 			p.P() | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Derek McGowan
					Derek McGowan