diff --git a/cmd/protoc-gen-go-fieldpath/generator.go b/cmd/protoc-gen-go-fieldpath/generator.go index 2aa6a8cb3..229f9c48e 100644 --- a/cmd/protoc-gen-go-fieldpath/generator.go +++ b/cmd/protoc-gen-go-fieldpath/generator.go @@ -130,7 +130,7 @@ func (gen *generator) genFieldMethod(m *protogen.Message) { } func isMessageField(f *protogen.Field) bool { - return f.Desc.Kind() == protoreflect.MessageKind && f.GoIdent.GoName != "Timestamp" + return f.Desc.Kind() == protoreflect.MessageKind && f.Desc.Cardinality() != protoreflect.Repeated && f.Message.GoIdent.GoName != "Timestamp" } func isLabelsField(f *protogen.Field) bool { @@ -138,7 +138,24 @@ func isLabelsField(f *protogen.Field) bool { } func isAnyField(f *protogen.Field) bool { - return f.Desc.Kind() == protoreflect.MessageKind && f.GoIdent.GoName == "Any" + return f.Desc.Kind() == protoreflect.MessageKind && f.Message.GoIdent.GoName == "Any" +} + +func collectChildlen(parent *protogen.Message) ([]*protogen.Message, error) { + var children []*protogen.Message + for _, child := range parent.Messages { + if child.Desc.IsMapEntry() { + continue + } + children = append(children, child) + + xs, err := collectChildlen(child) + if err != nil { + return nil, err + } + children = append(children, xs...) + } + return children, nil } func generate(plugin *protogen.Plugin, input *protogen.File) error { @@ -148,7 +165,18 @@ func generate(plugin *protogen.Plugin, input *protogen.File) error { file.P("package ", input.GoPackageName) gen := newGenerator(file) + + var messages []*protogen.Message for _, m := range input.Messages { + messages = append(messages, m) + children, err := collectChildlen(m) + if err != nil { + return err + } + messages = append(messages, children...) + } + + for _, m := range messages { gen.genFieldMethod(m) } return nil