go.mod: runc v1.0.0

Signed-off-by: Akihiro Suda <akihiro.suda.cz@hco.ntt.co.jp>
This commit is contained in:
Akihiro Suda 2021-06-22 15:35:10 +09:00 committed by Davanum Srinivas
parent 28bb59c080
commit f913a42755
107 changed files with 9760 additions and 1798 deletions

10
go.mod
View File

@ -19,7 +19,7 @@ require (
github.com/containerd/typeurl v1.0.2 github.com/containerd/typeurl v1.0.2
github.com/containerd/zfs v1.0.0 github.com/containerd/zfs v1.0.0
github.com/containernetworking/plugins v0.9.1 github.com/containernetworking/plugins v0.9.1
github.com/coreos/go-systemd/v22 v22.3.1 github.com/coreos/go-systemd/v22 v22.3.2
github.com/davecgh/go-spew v1.1.1 github.com/davecgh/go-spew v1.1.1
github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c
github.com/docker/go-metrics v0.0.1 github.com/docker/go-metrics v0.0.1
@ -28,7 +28,7 @@ require (
github.com/fsnotify/fsnotify v1.4.9 github.com/fsnotify/fsnotify v1.4.9
github.com/gogo/googleapis v1.4.0 github.com/gogo/googleapis v1.4.0
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/go-cmp v0.5.4 github.com/google/go-cmp v0.5.5
github.com/google/uuid v1.2.0 github.com/google/uuid v1.2.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/hashicorp/go-multierror v1.0.0 github.com/hashicorp/go-multierror v1.0.0
@ -39,7 +39,7 @@ require (
github.com/moby/sys/symlink v0.1.0 github.com/moby/sys/symlink v0.1.0
github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/go-digest v1.0.0
github.com/opencontainers/image-spec v1.0.1 github.com/opencontainers/image-spec v1.0.1
github.com/opencontainers/runc v1.0.0-rc95 github.com/opencontainers/runc v1.0.0
github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417 github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417
github.com/opencontainers/selinux v1.8.2 github.com/opencontainers/selinux v1.8.2
github.com/pelletier/go-toml v1.8.1 github.com/pelletier/go-toml v1.8.1
@ -47,7 +47,7 @@ require (
github.com/prometheus/client_golang v1.7.1 github.com/prometheus/client_golang v1.7.1
github.com/prometheus/procfs v0.6.0 // indirect; temporarily force v0.6.0, which was previously defined in imgcrypt as explicit version github.com/prometheus/procfs v0.6.0 // indirect; temporarily force v0.6.0, which was previously defined in imgcrypt as explicit version
github.com/satori/go.uuid v1.2.0 // indirect github.com/satori/go.uuid v1.2.0 // indirect
github.com/sirupsen/logrus v1.7.0 github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.6.1 github.com/stretchr/testify v1.6.1
github.com/tchap/go-patricia v2.2.6+incompatible github.com/tchap/go-patricia v2.2.6+incompatible
github.com/urfave/cli v1.22.2 github.com/urfave/cli v1.22.2
@ -56,7 +56,7 @@ require (
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a golang.org/x/sync v0.0.0-20201207232520-09787c993a3a
golang.org/x/sys v0.0.0-20210426230700-d19ff857e887 golang.org/x/sys v0.0.0-20210426230700-d19ff857e887
google.golang.org/grpc v1.38.0 google.golang.org/grpc v1.38.0
google.golang.org/protobuf v1.25.0 google.golang.org/protobuf v1.26.0
gotest.tools/v3 v3.0.3 gotest.tools/v3 v3.0.3
k8s.io/api v0.20.6 k8s.io/api v0.20.6
k8s.io/apimachinery v0.20.6 k8s.io/apimachinery v0.20.6

27
go.sum
View File

@ -75,8 +75,8 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs= github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs=
github.com/cilium/ebpf v0.4.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs= github.com/cilium/ebpf v0.4.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs=
github.com/cilium/ebpf v0.5.0 h1:E1KshmrMEtkMP2UjlWzfmUV1owWY+BnbL5FxxuatnrU= github.com/cilium/ebpf v0.6.1 h1:n6ZUOkSFi6OwcMeTCFaDQx2Onx2rEikQo69315MNbdc=
github.com/cilium/ebpf v0.5.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs= github.com/cilium/ebpf v0.6.1/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
@ -129,8 +129,8 @@ github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e h1:Wf6HqHfScWJN9/ZjdUKyjop4mf3Qdd+1TvvltAvM3m8= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e h1:Wf6HqHfScWJN9/ZjdUKyjop4mf3Qdd+1TvvltAvM3m8=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
github.com/coreos/go-systemd/v22 v22.3.1 h1:7OO2CXWMYNDdaAzP51t4lCCZWwpQHmvPbm9sxWjm3So= github.com/coreos/go-systemd/v22 v22.3.2 h1:D9/bQk5vlXQFZ6Kwuu6zaiXJ9oTPe68++AzAJc1DzSI=
github.com/coreos/go-systemd/v22 v22.3.1/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
@ -233,8 +233,9 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
@ -243,8 +244,9 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g= github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@ -367,13 +369,12 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.0.1 h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI= github.com/opencontainers/image-spec v1.0.1 h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI=
github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0=
github.com/opencontainers/runc v1.0.0-rc95 h1:RMuWVfY3E1ILlVsC3RhIq38n4sJtlOFwU9gfFZSqrd0= github.com/opencontainers/runc v1.0.0 h1:QOhAQAYUlKeofuyeKdR6ITvOnXLPbEAjPMjz9wCUXcU=
github.com/opencontainers/runc v1.0.0-rc95/go.mod h1:z+bZxa/+Tz/FmYVWkhUajJdzFeOqjc5vrqskhVyHGUM= github.com/opencontainers/runc v1.0.0/go.mod h1:MU2S3KEB2ZExnhnAQYbwjdYV6HwKtDlNbA2Z2OeNDeA=
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.0.3-0.20200929063507-e6143ca7d51d/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.0.3-0.20200929063507-e6143ca7d51d/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417 h1:3snG66yBm59tKhhSPQrQ/0bCrv1LQbKt40LnUPiUxdc= github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417 h1:3snG66yBm59tKhhSPQrQ/0bCrv1LQbKt40LnUPiUxdc=
github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/selinux v1.8.0/go.mod h1:RScLhm78qiWa2gbVCcGkC7tCGdgk3ogry1nUQF8Evvo=
github.com/opencontainers/selinux v1.8.2 h1:c4ca10UMgRcvZ6h0K4HtS15UaVSBEaE+iln2LVpAuGc= github.com/opencontainers/selinux v1.8.2 h1:c4ca10UMgRcvZ6h0K4HtS15UaVSBEaE+iln2LVpAuGc=
github.com/opencontainers/selinux v1.8.2/go.mod h1:MUIHuUEvKB1wtJjQdOyYRgOnLD2xAPP8dBsCoU0KuF8= github.com/opencontainers/selinux v1.8.2/go.mod h1:MUIHuUEvKB1wtJjQdOyYRgOnLD2xAPP8dBsCoU0KuF8=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
@ -428,8 +429,9 @@ github.com/sirupsen/logrus v1.0.6/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjM
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
@ -470,7 +472,6 @@ github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYp
github.com/vishvananda/netlink v1.1.1-0.20201029203352-d40f9887b852/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netlink v1.1.1-0.20201029203352-d40f9887b852/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@ -718,8 +719,10 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -18,6 +18,23 @@ reason about the proposed changes.
## Running the tests ## Running the tests
Many of the tests require privileges to set resource limits and load eBPF code. Many of the tests require privileges to set resource limits and load eBPF code.
The easiest way to obtain these is to run the tests with `sudo`: The easiest way to obtain these is to run the tests with `sudo`.
To test the current package with your local kernel you can simply run:
```
go test -exec sudo ./...
```
To test the current package with a different kernel version you can use the [run-tests.sh](run-tests.sh) script.
It requires [virtme](https://github.com/amluto/virtme) and qemu to be installed.
Examples:
```bash
# Run all tests on a 5.4 kernel
./run-tests.sh 5.4
# Run a subset of tests:
./run-tests.sh 5.4 go test ./link
```
sudo go test ./...

View File

@ -1,7 +1,7 @@
# The development version of clang is distributed as the 'clang' binary, # The development version of clang is distributed as the 'clang' binary,
# while stable/released versions have a version number attached. # while stable/released versions have a version number attached.
# Pin the default clang to a stable version. # Pin the default clang to a stable version.
CLANG ?= clang-11 CLANG ?= clang-12
CFLAGS := -target bpf -O2 -g -Wall -Werror $(CFLAGS) CFLAGS := -target bpf -O2 -g -Wall -Werror $(CFLAGS)
# Obtain an absolute path to the directory of the Makefile. # Obtain an absolute path to the directory of the Makefile.
@ -17,7 +17,7 @@ VERSION := $(shell cat ${REPODIR}/testdata/docker/VERSION)
TARGETS := \ TARGETS := \
testdata/loader-clang-7 \ testdata/loader-clang-7 \
testdata/loader-clang-9 \ testdata/loader-clang-9 \
testdata/loader-clang-11 \ testdata/loader-$(CLANG) \
testdata/invalid_map \ testdata/invalid_map \
testdata/raw_tracepoint \ testdata/raw_tracepoint \
testdata/invalid_map_static \ testdata/invalid_map_static \
@ -33,6 +33,7 @@ TARGETS := \
docker-all: docker-all:
docker run --rm --user "${UIDGID}" \ docker run --rm --user "${UIDGID}" \
-v "${REPODIR}":/ebpf -w /ebpf --env MAKEFLAGS \ -v "${REPODIR}":/ebpf -w /ebpf --env MAKEFLAGS \
--env CFLAGS="-fdebug-prefix-map=/ebpf=." \
"${IMAGE}:${VERSION}" \ "${IMAGE}:${VERSION}" \
make all make all
@ -47,6 +48,8 @@ clean:
-$(RM) internal/btf/testdata/*.elf -$(RM) internal/btf/testdata/*.elf
all: $(addsuffix -el.elf,$(TARGETS)) $(addsuffix -eb.elf,$(TARGETS)) all: $(addsuffix -el.elf,$(TARGETS)) $(addsuffix -eb.elf,$(TARGETS))
ln -srf testdata/loader-$(CLANG)-el.elf testdata/loader-el.elf
ln -srf testdata/loader-$(CLANG)-eb.elf testdata/loader-eb.elf
testdata/loader-%-el.elf: testdata/loader.c testdata/loader-%-el.elf: testdata/loader.c
$* $(CFLAGS) -mlittle-endian -c $< -o $@ $* $(CFLAGS) -mlittle-endian -c $< -o $@

View File

@ -57,7 +57,7 @@ func (ins *Instruction) Unmarshal(r io.Reader, bo binary.ByteOrder) (uint64, err
return 0, fmt.Errorf("can't unmarshal registers: %s", err) return 0, fmt.Errorf("can't unmarshal registers: %s", err)
} }
if !bi.OpCode.isDWordLoad() { if !bi.OpCode.IsDWordLoad() {
return InstructionSize, nil return InstructionSize, nil
} }
@ -80,7 +80,7 @@ func (ins Instruction) Marshal(w io.Writer, bo binary.ByteOrder) (uint64, error)
return 0, errors.New("invalid opcode") return 0, errors.New("invalid opcode")
} }
isDWordLoad := ins.OpCode.isDWordLoad() isDWordLoad := ins.OpCode.IsDWordLoad()
cons := int32(ins.Constant) cons := int32(ins.Constant)
if isDWordLoad { if isDWordLoad {
@ -123,7 +123,7 @@ func (ins Instruction) Marshal(w io.Writer, bo binary.ByteOrder) (uint64, error)
// //
// Returns an error if the instruction doesn't load a map. // Returns an error if the instruction doesn't load a map.
func (ins *Instruction) RewriteMapPtr(fd int) error { func (ins *Instruction) RewriteMapPtr(fd int) error {
if !ins.OpCode.isDWordLoad() { if !ins.OpCode.IsDWordLoad() {
return fmt.Errorf("%s is not a 64 bit load", ins.OpCode) return fmt.Errorf("%s is not a 64 bit load", ins.OpCode)
} }
@ -138,15 +138,19 @@ func (ins *Instruction) RewriteMapPtr(fd int) error {
return nil return nil
} }
func (ins *Instruction) mapPtr() uint32 { // MapPtr returns the map fd for this instruction.
return uint32(uint64(ins.Constant) & math.MaxUint32) //
// The result is undefined if the instruction is not a load from a map,
// see IsLoadFromMap.
func (ins *Instruction) MapPtr() int {
return int(int32(uint64(ins.Constant) & math.MaxUint32))
} }
// RewriteMapOffset changes the offset of a direct load from a map. // RewriteMapOffset changes the offset of a direct load from a map.
// //
// Returns an error if the instruction is not a direct load. // Returns an error if the instruction is not a direct load.
func (ins *Instruction) RewriteMapOffset(offset uint32) error { func (ins *Instruction) RewriteMapOffset(offset uint32) error {
if !ins.OpCode.isDWordLoad() { if !ins.OpCode.IsDWordLoad() {
return fmt.Errorf("%s is not a 64 bit load", ins.OpCode) return fmt.Errorf("%s is not a 64 bit load", ins.OpCode)
} }
@ -163,10 +167,10 @@ func (ins *Instruction) mapOffset() uint32 {
return uint32(uint64(ins.Constant) >> 32) return uint32(uint64(ins.Constant) >> 32)
} }
// isLoadFromMap returns true if the instruction loads from a map. // IsLoadFromMap returns true if the instruction loads from a map.
// //
// This covers both loading the map pointer and direct map value loads. // This covers both loading the map pointer and direct map value loads.
func (ins *Instruction) isLoadFromMap() bool { func (ins *Instruction) IsLoadFromMap() bool {
return ins.OpCode == LoadImmOp(DWord) && (ins.Src == PseudoMapFD || ins.Src == PseudoMapValue) return ins.OpCode == LoadImmOp(DWord) && (ins.Src == PseudoMapFD || ins.Src == PseudoMapValue)
} }
@ -177,6 +181,12 @@ func (ins *Instruction) IsFunctionCall() bool {
return ins.OpCode.JumpOp() == Call && ins.Src == PseudoCall return ins.OpCode.JumpOp() == Call && ins.Src == PseudoCall
} }
// IsConstantLoad returns true if the instruction loads a constant of the
// given size.
func (ins *Instruction) IsConstantLoad(size Size) bool {
return ins.OpCode == LoadImmOp(size) && ins.Src == R0 && ins.Offset == 0
}
// Format implements fmt.Formatter. // Format implements fmt.Formatter.
func (ins Instruction) Format(f fmt.State, c rune) { func (ins Instruction) Format(f fmt.State, c rune) {
if c != 'v' { if c != 'v' {
@ -197,8 +207,8 @@ func (ins Instruction) Format(f fmt.State, c rune) {
return return
} }
if ins.isLoadFromMap() { if ins.IsLoadFromMap() {
fd := int32(ins.mapPtr()) fd := ins.MapPtr()
switch ins.Src { switch ins.Src {
case PseudoMapFD: case PseudoMapFD:
fmt.Fprintf(f, "LoadMapPtr dst: %s fd: %d", ins.Dst, fd) fmt.Fprintf(f, "LoadMapPtr dst: %s fd: %d", ins.Dst, fd)
@ -403,7 +413,7 @@ func (insns Instructions) Marshal(w io.Writer, bo binary.ByteOrder) error {
func (insns Instructions) Tag(bo binary.ByteOrder) (string, error) { func (insns Instructions) Tag(bo binary.ByteOrder) (string, error) {
h := sha1.New() h := sha1.New()
for i, ins := range insns { for i, ins := range insns {
if ins.isLoadFromMap() { if ins.IsLoadFromMap() {
ins.Constant = 0 ins.Constant = 0
} }
_, err := ins.Marshal(h, bo) _, err := ins.Marshal(h, bo)

View File

@ -111,7 +111,7 @@ func LoadMapPtr(dst Register, fd int) Instruction {
OpCode: LoadImmOp(DWord), OpCode: LoadImmOp(DWord),
Dst: dst, Dst: dst,
Src: PseudoMapFD, Src: PseudoMapFD,
Constant: int64(fd), Constant: int64(uint32(fd)),
} }
} }

View File

@ -69,13 +69,13 @@ const InvalidOpCode OpCode = 0xff
// rawInstructions returns the number of BPF instructions required // rawInstructions returns the number of BPF instructions required
// to encode this opcode. // to encode this opcode.
func (op OpCode) rawInstructions() int { func (op OpCode) rawInstructions() int {
if op.isDWordLoad() { if op.IsDWordLoad() {
return 2 return 2
} }
return 1 return 1
} }
func (op OpCode) isDWordLoad() bool { func (op OpCode) IsDWordLoad() bool {
return op == LoadImmOp(DWord) return op == LoadImmOp(DWord)
} }

View File

@ -3,6 +3,7 @@ package ebpf
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"math" "math"
"reflect" "reflect"
"strings" "strings"
@ -89,8 +90,8 @@ func (cs *CollectionSpec) RewriteMaps(maps map[string]*Map) error {
// //
// The constant must be defined like so in the C program: // The constant must be defined like so in the C program:
// //
// static volatile const type foobar; // volatile const type foobar;
// static volatile const type foobar = default; // volatile const type foobar = default;
// //
// Replacement values must be of the same length as the C sizeof(type). // Replacement values must be of the same length as the C sizeof(type).
// If necessary, they are marshalled according to the same rules as // If necessary, they are marshalled according to the same rules as
@ -269,11 +270,21 @@ func NewCollectionWithOptions(spec *CollectionSpec, opts CollectionOptions) (*Co
}, nil }, nil
} }
type btfHandleCache map[*btf.Spec]*btf.Handle type handleCache struct {
btfHandles map[*btf.Spec]*btf.Handle
btfSpecs map[io.ReaderAt]*btf.Spec
}
func (btfs btfHandleCache) load(spec *btf.Spec) (*btf.Handle, error) { func newHandleCache() *handleCache {
if btfs[spec] != nil { return &handleCache{
return btfs[spec], nil btfHandles: make(map[*btf.Spec]*btf.Handle),
btfSpecs: make(map[io.ReaderAt]*btf.Spec),
}
}
func (hc handleCache) btfHandle(spec *btf.Spec) (*btf.Handle, error) {
if hc.btfHandles[spec] != nil {
return hc.btfHandles[spec], nil
} }
handle, err := btf.NewHandle(spec) handle, err := btf.NewHandle(spec)
@ -281,14 +292,30 @@ func (btfs btfHandleCache) load(spec *btf.Spec) (*btf.Handle, error) {
return nil, err return nil, err
} }
btfs[spec] = handle hc.btfHandles[spec] = handle
return handle, nil return handle, nil
} }
func (btfs btfHandleCache) close() { func (hc handleCache) btfSpec(rd io.ReaderAt) (*btf.Spec, error) {
for _, handle := range btfs { if hc.btfSpecs[rd] != nil {
return hc.btfSpecs[rd], nil
}
spec, err := btf.LoadSpecFromReader(rd)
if err != nil {
return nil, err
}
hc.btfSpecs[rd] = spec
return spec, nil
}
func (hc handleCache) close() {
for _, handle := range hc.btfHandles {
handle.Close() handle.Close()
} }
hc.btfHandles = nil
hc.btfSpecs = nil
} }
func lazyLoadCollection(coll *CollectionSpec, opts *CollectionOptions) ( func lazyLoadCollection(coll *CollectionSpec, opts *CollectionOptions) (
@ -300,12 +327,12 @@ func lazyLoadCollection(coll *CollectionSpec, opts *CollectionOptions) (
var ( var (
maps = make(map[string]*Map) maps = make(map[string]*Map)
progs = make(map[string]*Program) progs = make(map[string]*Program)
btfs = make(btfHandleCache) handles = newHandleCache()
skipMapsAndProgs = false skipMapsAndProgs = false
) )
cleanup = func() { cleanup = func() {
btfs.close() handles.close()
if skipMapsAndProgs { if skipMapsAndProgs {
return return
@ -335,7 +362,7 @@ func lazyLoadCollection(coll *CollectionSpec, opts *CollectionOptions) (
return nil, fmt.Errorf("missing map %s", mapName) return nil, fmt.Errorf("missing map %s", mapName)
} }
m, err := newMapWithOptions(mapSpec, opts.Maps, btfs) m, err := newMapWithOptions(mapSpec, opts.Maps, handles)
if err != nil { if err != nil {
return nil, fmt.Errorf("map %s: %w", mapName, err) return nil, fmt.Errorf("map %s: %w", mapName, err)
} }
@ -360,7 +387,7 @@ func lazyLoadCollection(coll *CollectionSpec, opts *CollectionOptions) (
for i := range progSpec.Instructions { for i := range progSpec.Instructions {
ins := &progSpec.Instructions[i] ins := &progSpec.Instructions[i]
if ins.OpCode != asm.LoadImmOp(asm.DWord) || ins.Reference == "" { if !ins.IsLoadFromMap() || ins.Reference == "" {
continue continue
} }
@ -372,7 +399,7 @@ func lazyLoadCollection(coll *CollectionSpec, opts *CollectionOptions) (
m, err := loadMap(ins.Reference) m, err := loadMap(ins.Reference)
if err != nil { if err != nil {
return nil, fmt.Errorf("program %s: %s", progName, err) return nil, fmt.Errorf("program %s: %w", progName, err)
} }
fd := m.FD() fd := m.FD()
@ -384,7 +411,7 @@ func lazyLoadCollection(coll *CollectionSpec, opts *CollectionOptions) (
} }
} }
prog, err := newProgramWithOptions(progSpec, opts.Programs, btfs) prog, err := newProgramWithOptions(progSpec, opts.Programs, handles)
if err != nil { if err != nil {
return nil, fmt.Errorf("program %s: %w", progName, err) return nil, fmt.Errorf("program %s: %w", progName, err)
} }
@ -534,7 +561,7 @@ func assignValues(to interface{}, valueOf func(reflect.Type, string) (reflect.Va
} }
if err != nil { if err != nil {
return fmt.Errorf("field %s: %s", field.Name, err) return fmt.Errorf("field %s: %w", field.Name, err)
} }
} }

View File

@ -96,7 +96,7 @@ func LoadCollectionSpecFromReader(rd io.ReaderAt) (*CollectionSpec, error) {
} }
btfSpec, err := btf.LoadSpecFromReader(rd) btfSpec, err := btf.LoadSpecFromReader(rd)
if err != nil { if err != nil && !errors.Is(err, btf.ErrNotFound) {
return nil, fmt.Errorf("load BTF: %w", err) return nil, fmt.Errorf("load BTF: %w", err)
} }
@ -159,7 +159,7 @@ func LoadCollectionSpecFromReader(rd io.ReaderAt) (*CollectionSpec, error) {
} }
if target.Flags&elf.SHF_STRINGS > 0 { if target.Flags&elf.SHF_STRINGS > 0 {
return nil, fmt.Errorf("section %q: string %q is not stack allocated: %w", section.Name, rel.Name, ErrNotSupported) return nil, fmt.Errorf("section %q: string is not stack allocated: %w", section.Name, ErrNotSupported)
} }
target.references++ target.references++
@ -374,17 +374,25 @@ func (ec *elfCode) relocateInstruction(ins *asm.Instruction, rel elf.Symbol) err
} }
case dataSection: case dataSection:
var offset uint32
switch typ { switch typ {
case elf.STT_SECTION: case elf.STT_SECTION:
if bind != elf.STB_LOCAL { if bind != elf.STB_LOCAL {
return fmt.Errorf("direct load: %s: unsupported relocation %s", name, bind) return fmt.Errorf("direct load: %s: unsupported relocation %s", name, bind)
} }
// This is really a reference to a static symbol, which clang doesn't
// emit a symbol table entry for. Instead it encodes the offset in
// the instruction itself.
offset = uint32(uint64(ins.Constant))
case elf.STT_OBJECT: case elf.STT_OBJECT:
if bind != elf.STB_GLOBAL { if bind != elf.STB_GLOBAL {
return fmt.Errorf("direct load: %s: unsupported relocation %s", name, bind) return fmt.Errorf("direct load: %s: unsupported relocation %s", name, bind)
} }
offset = uint32(rel.Value)
default: default:
return fmt.Errorf("incorrect relocation type %v for direct map load", typ) return fmt.Errorf("incorrect relocation type %v for direct map load", typ)
} }
@ -394,10 +402,8 @@ func (ec *elfCode) relocateInstruction(ins *asm.Instruction, rel elf.Symbol) err
// it's not clear how to encode that into Instruction. // it's not clear how to encode that into Instruction.
name = target.Name name = target.Name
// For some reason, clang encodes the offset of the symbol its // The kernel expects the offset in the second basic BPF instruction.
// section in the first basic BPF instruction, while the kernel ins.Constant = int64(uint64(offset) << 32)
// expects it in the second one.
ins.Constant <<= 32
ins.Src = asm.PseudoMapValue ins.Src = asm.PseudoMapValue
// Mark the instruction as needing an update when creating the // Mark the instruction as needing an update when creating the
@ -491,33 +497,38 @@ func (ec *elfCode) loadMaps(maps map[string]*MapSpec) error {
return fmt.Errorf("section %s: missing symbol for map at offset %d", sec.Name, offset) return fmt.Errorf("section %s: missing symbol for map at offset %d", sec.Name, offset)
} }
if maps[mapSym.Name] != nil { mapName := mapSym.Name
if maps[mapName] != nil {
return fmt.Errorf("section %v: map %v already exists", sec.Name, mapSym) return fmt.Errorf("section %v: map %v already exists", sec.Name, mapSym)
} }
lr := io.LimitReader(r, int64(size)) lr := io.LimitReader(r, int64(size))
spec := MapSpec{ spec := MapSpec{
Name: SanitizeName(mapSym.Name, -1), Name: SanitizeName(mapName, -1),
} }
switch { switch {
case binary.Read(lr, ec.ByteOrder, &spec.Type) != nil: case binary.Read(lr, ec.ByteOrder, &spec.Type) != nil:
return fmt.Errorf("map %v: missing type", mapSym) return fmt.Errorf("map %s: missing type", mapName)
case binary.Read(lr, ec.ByteOrder, &spec.KeySize) != nil: case binary.Read(lr, ec.ByteOrder, &spec.KeySize) != nil:
return fmt.Errorf("map %v: missing key size", mapSym) return fmt.Errorf("map %s: missing key size", mapName)
case binary.Read(lr, ec.ByteOrder, &spec.ValueSize) != nil: case binary.Read(lr, ec.ByteOrder, &spec.ValueSize) != nil:
return fmt.Errorf("map %v: missing value size", mapSym) return fmt.Errorf("map %s: missing value size", mapName)
case binary.Read(lr, ec.ByteOrder, &spec.MaxEntries) != nil: case binary.Read(lr, ec.ByteOrder, &spec.MaxEntries) != nil:
return fmt.Errorf("map %v: missing max entries", mapSym) return fmt.Errorf("map %s: missing max entries", mapName)
case binary.Read(lr, ec.ByteOrder, &spec.Flags) != nil: case binary.Read(lr, ec.ByteOrder, &spec.Flags) != nil:
return fmt.Errorf("map %v: missing flags", mapSym) return fmt.Errorf("map %s: missing flags", mapName)
} }
if _, err := io.Copy(internal.DiscardZeroes{}, lr); err != nil { if _, err := io.Copy(internal.DiscardZeroes{}, lr); err != nil {
return fmt.Errorf("map %v: unknown and non-zero fields in definition", mapSym) return fmt.Errorf("map %s: unknown and non-zero fields in definition", mapName)
} }
maps[mapSym.Name] = &spec if err := spec.clampPerfEventArraySize(); err != nil {
return fmt.Errorf("map %s: %w", mapName, err)
}
maps[mapName] = &spec
} }
} }
@ -565,6 +576,10 @@ func (ec *elfCode) loadBTFMaps(maps map[string]*MapSpec) error {
return fmt.Errorf("map %v: %w", name, err) return fmt.Errorf("map %v: %w", name, err)
} }
if err := mapSpec.clampPerfEventArraySize(); err != nil {
return fmt.Errorf("map %v: %w", name, err)
}
maps[name] = mapSpec maps[name] = mapSpec
} }
} }
@ -847,6 +862,8 @@ func getProgType(sectionName string) (ProgramType, AttachType, uint32, string) {
"uretprobe/": {Kprobe, AttachNone, 0}, "uretprobe/": {Kprobe, AttachNone, 0},
"tracepoint/": {TracePoint, AttachNone, 0}, "tracepoint/": {TracePoint, AttachNone, 0},
"raw_tracepoint/": {RawTracepoint, AttachNone, 0}, "raw_tracepoint/": {RawTracepoint, AttachNone, 0},
"raw_tp/": {RawTracepoint, AttachNone, 0},
"tp_btf/": {Tracing, AttachTraceRawTp, 0},
"xdp": {XDP, AttachNone, 0}, "xdp": {XDP, AttachNone, 0},
"perf_event": {PerfEvent, AttachNone, 0}, "perf_event": {PerfEvent, AttachNone, 0},
"lwt_in": {LWTIn, AttachNone, 0}, "lwt_in": {LWTIn, AttachNone, 0},

View File

@ -35,7 +35,7 @@ type Spec struct {
namedTypes map[string][]namedType namedTypes map[string][]namedType
funcInfos map[string]extInfo funcInfos map[string]extInfo
lineInfos map[string]extInfo lineInfos map[string]extInfo
coreRelos map[string]bpfCoreRelos coreRelos map[string]coreRelos
byteOrder binary.ByteOrder byteOrder binary.ByteOrder
} }
@ -53,7 +53,7 @@ type btfHeader struct {
// LoadSpecFromReader reads BTF sections from an ELF. // LoadSpecFromReader reads BTF sections from an ELF.
// //
// Returns a nil Spec and no error if no BTF was present. // Returns ErrNotFound if the reader contains no BTF.
func LoadSpecFromReader(rd io.ReaderAt) (*Spec, error) { func LoadSpecFromReader(rd io.ReaderAt) (*Spec, error) {
file, err := internal.NewSafeELFFile(rd) file, err := internal.NewSafeELFFile(rd)
if err != nil { if err != nil {
@ -67,7 +67,7 @@ func LoadSpecFromReader(rd io.ReaderAt) (*Spec, error) {
} }
if btfSection == nil { if btfSection == nil {
return nil, nil return nil, fmt.Errorf("btf: %w", ErrNotFound)
} }
symbols, err := file.Symbols() symbols, err := file.Symbols()
@ -438,13 +438,13 @@ func (s *Spec) Program(name string, length uint64) (*Program, error) {
funcInfos, funcOK := s.funcInfos[name] funcInfos, funcOK := s.funcInfos[name]
lineInfos, lineOK := s.lineInfos[name] lineInfos, lineOK := s.lineInfos[name]
coreRelos, coreOK := s.coreRelos[name] relos, coreOK := s.coreRelos[name]
if !funcOK && !lineOK && !coreOK { if !funcOK && !lineOK && !coreOK {
return nil, fmt.Errorf("no extended BTF info for section %s", name) return nil, fmt.Errorf("no extended BTF info for section %s", name)
} }
return &Program{s, length, funcInfos, lineInfos, coreRelos}, nil return &Program{s, length, funcInfos, lineInfos, relos}, nil
} }
// Datasec returns the BTF required to create maps which represent data sections. // Datasec returns the BTF required to create maps which represent data sections.
@ -491,7 +491,8 @@ func (s *Spec) FindType(name string, typ Type) error {
return fmt.Errorf("type %s: %w", name, ErrNotFound) return fmt.Errorf("type %s: %w", name, ErrNotFound)
} }
value := reflect.Indirect(reflect.ValueOf(copyType(candidate))) cpy, _ := copyType(candidate, nil)
value := reflect.Indirect(reflect.ValueOf(cpy))
reflect.Indirect(reflect.ValueOf(typ)).Set(value) reflect.Indirect(reflect.ValueOf(typ)).Set(value)
return nil return nil
} }
@ -606,7 +607,7 @@ type Program struct {
spec *Spec spec *Spec
length uint64 length uint64
funcInfos, lineInfos extInfo funcInfos, lineInfos extInfo
coreRelos bpfCoreRelos coreRelos coreRelos
} }
// ProgramSpec returns the Spec needed for loading function and line infos into the kernel. // ProgramSpec returns the Spec needed for loading function and line infos into the kernel.
@ -665,16 +666,23 @@ func ProgramLineInfos(s *Program) (recordSize uint32, bytes []byte, err error) {
return s.lineInfos.recordSize, bytes, nil return s.lineInfos.recordSize, bytes, nil
} }
// ProgramRelocations returns the CO-RE relocations required to adjust the // ProgramFixups returns the changes required to adjust the program to the target.
// program to the target.
// //
// This is a free function instead of a method to hide it from users // This is a free function instead of a method to hide it from users
// of package ebpf. // of package ebpf.
func ProgramRelocations(s *Program, target *Spec) (map[uint64]Relocation, error) { func ProgramFixups(s *Program, target *Spec) (COREFixups, error) {
if len(s.coreRelos) == 0 { if len(s.coreRelos) == 0 {
return nil, nil return nil, nil
} }
if target == nil {
var err error
target, err = LoadKernelSpec()
if err != nil {
return nil, err
}
}
return coreRelocate(s.spec, target, s.coreRelos) return coreRelocate(s.spec, target, s.coreRelos)
} }

View File

@ -3,29 +3,146 @@ package btf
import ( import (
"errors" "errors"
"fmt" "fmt"
"math"
"reflect" "reflect"
"sort"
"strconv" "strconv"
"strings" "strings"
"github.com/cilium/ebpf/asm"
) )
// Code in this file is derived from libbpf, which is available under a BSD // Code in this file is derived from libbpf, which is available under a BSD
// 2-Clause license. // 2-Clause license.
// Relocation describes a CO-RE relocation. // COREFixup is the result of computing a CO-RE relocation for a target.
type Relocation struct { type COREFixup struct {
Current uint32 Kind COREKind
New uint32 Local uint32
Target uint32
Poison bool
} }
func (r Relocation) equal(other Relocation) bool { func (f COREFixup) equal(other COREFixup) bool {
return r.Current == other.Current && r.New == other.New return f.Local == other.Local && f.Target == other.Target
} }
// coreReloKind is the type of CO-RE relocation func (f COREFixup) String() string {
type coreReloKind uint32 if f.Poison {
return fmt.Sprintf("%s=poison", f.Kind)
}
return fmt.Sprintf("%s=%d->%d", f.Kind, f.Local, f.Target)
}
func (f COREFixup) apply(ins *asm.Instruction) error {
if f.Poison {
return errors.New("can't poison individual instruction")
}
switch class := ins.OpCode.Class(); class {
case asm.LdXClass, asm.StClass, asm.StXClass:
if want := int16(f.Local); want != ins.Offset {
return fmt.Errorf("invalid offset %d, expected %d", ins.Offset, want)
}
if f.Target > math.MaxInt16 {
return fmt.Errorf("offset %d exceeds MaxInt16", f.Target)
}
ins.Offset = int16(f.Target)
case asm.LdClass:
if !ins.IsConstantLoad(asm.DWord) {
return fmt.Errorf("not a dword-sized immediate load")
}
if want := int64(f.Local); want != ins.Constant {
return fmt.Errorf("invalid immediate %d, expected %d", ins.Constant, want)
}
ins.Constant = int64(f.Target)
case asm.ALUClass:
if ins.OpCode.ALUOp() == asm.Swap {
return fmt.Errorf("relocation against swap")
}
fallthrough
case asm.ALU64Class:
if src := ins.OpCode.Source(); src != asm.ImmSource {
return fmt.Errorf("invalid source %s", src)
}
if want := int64(f.Local); want != ins.Constant {
return fmt.Errorf("invalid immediate %d, expected %d", ins.Constant, want)
}
if f.Target > math.MaxInt32 {
return fmt.Errorf("immediate %d exceeds MaxInt32", f.Target)
}
ins.Constant = int64(f.Target)
default:
return fmt.Errorf("invalid class %s", class)
}
return nil
}
func (f COREFixup) isNonExistant() bool {
return f.Kind.checksForExistence() && f.Target == 0
}
type COREFixups map[uint64]COREFixup
// Apply a set of CO-RE relocations to a BPF program.
func (fs COREFixups) Apply(insns asm.Instructions) (asm.Instructions, error) {
if len(fs) == 0 {
cpy := make(asm.Instructions, len(insns))
copy(cpy, insns)
return insns, nil
}
cpy := make(asm.Instructions, 0, len(insns))
iter := insns.Iterate()
for iter.Next() {
fixup, ok := fs[iter.Offset.Bytes()]
if !ok {
cpy = append(cpy, *iter.Ins)
continue
}
ins := *iter.Ins
if fixup.Poison {
const badRelo = asm.BuiltinFunc(0xbad2310)
cpy = append(cpy, badRelo.Call())
if ins.OpCode.IsDWordLoad() {
// 64 bit constant loads occupy two raw bpf instructions, so
// we need to add another instruction as padding.
cpy = append(cpy, badRelo.Call())
}
continue
}
if err := fixup.apply(&ins); err != nil {
return nil, fmt.Errorf("instruction %d, offset %d: %s: %w", iter.Index, iter.Offset.Bytes(), fixup.Kind, err)
}
cpy = append(cpy, ins)
}
return cpy, nil
}
// COREKind is the type of CO-RE relocation
type COREKind uint32
const ( const (
reloFieldByteOffset coreReloKind = iota /* field byte offset */ reloFieldByteOffset COREKind = iota /* field byte offset */
reloFieldByteSize /* field size in bytes */ reloFieldByteSize /* field size in bytes */
reloFieldExists /* field existence in target kernel */ reloFieldExists /* field existence in target kernel */
reloFieldSigned /* field signedness (0 - unsigned, 1 - signed) */ reloFieldSigned /* field signedness (0 - unsigned, 1 - signed) */
@ -39,7 +156,7 @@ const (
reloEnumvalValue /* enum value integer value */ reloEnumvalValue /* enum value integer value */
) )
func (k coreReloKind) String() string { func (k COREKind) String() string {
switch k { switch k {
case reloFieldByteOffset: case reloFieldByteOffset:
return "byte_off" return "byte_off"
@ -70,103 +187,249 @@ func (k coreReloKind) String() string {
} }
} }
func coreRelocate(local, target *Spec, coreRelos bpfCoreRelos) (map[uint64]Relocation, error) { func (k COREKind) checksForExistence() bool {
if target == nil { return k == reloEnumvalExists || k == reloTypeExists || k == reloFieldExists
var err error }
target, err = loadKernelSpec()
if err != nil {
return nil, err
}
}
func coreRelocate(local, target *Spec, relos coreRelos) (COREFixups, error) {
if local.byteOrder != target.byteOrder { if local.byteOrder != target.byteOrder {
return nil, fmt.Errorf("can't relocate %s against %s", local.byteOrder, target.byteOrder) return nil, fmt.Errorf("can't relocate %s against %s", local.byteOrder, target.byteOrder)
} }
relocations := make(map[uint64]Relocation, len(coreRelos)) var ids []TypeID
for _, relo := range coreRelos { relosByID := make(map[TypeID]coreRelos)
accessorStr, err := local.strings.Lookup(relo.AccessStrOff) result := make(COREFixups, len(relos))
for _, relo := range relos {
if relo.kind == reloTypeIDLocal {
// Filtering out reloTypeIDLocal here makes our lives a lot easier
// down the line, since it doesn't have a target at all.
if len(relo.accessor) > 1 || relo.accessor[0] != 0 {
return nil, fmt.Errorf("%s: unexpected accessor %v", relo.kind, relo.accessor)
}
result[uint64(relo.insnOff)] = COREFixup{
relo.kind,
uint32(relo.typeID),
uint32(relo.typeID),
false,
}
continue
}
relos, ok := relosByID[relo.typeID]
if !ok {
ids = append(ids, relo.typeID)
}
relosByID[relo.typeID] = append(relos, relo)
}
// Ensure we work on relocations in a deterministic order.
sort.Slice(ids, func(i, j int) bool {
return ids[i] < ids[j]
})
for _, id := range ids {
if int(id) >= len(local.types) {
return nil, fmt.Errorf("invalid type id %d", id)
}
localType := local.types[id]
named, ok := localType.(namedType)
if !ok || named.name() == "" {
return nil, fmt.Errorf("relocate unnamed or anonymous type %s: %w", localType, ErrNotSupported)
}
relos := relosByID[id]
targets := target.namedTypes[named.essentialName()]
fixups, err := coreCalculateFixups(localType, targets, relos)
if err != nil {
return nil, fmt.Errorf("relocate %s: %w", localType, err)
}
for i, relo := range relos {
result[uint64(relo.insnOff)] = fixups[i]
}
}
return result, nil
}
var errAmbiguousRelocation = errors.New("ambiguous relocation")
var errImpossibleRelocation = errors.New("impossible relocation")
// coreCalculateFixups calculates the fixups for the given relocations using
// the "best" target.
//
// The best target is determined by scoring: the less poisoning we have to do
// the better the target is.
func coreCalculateFixups(local Type, targets []namedType, relos coreRelos) ([]COREFixup, error) {
localID := local.ID()
local, err := copyType(local, skipQualifierAndTypedef)
if err != nil { if err != nil {
return nil, err return nil, err
} }
accessor, err := parseCoreAccessor(accessorStr) bestScore := len(relos)
var bestFixups []COREFixup
for i := range targets {
targetID := targets[i].ID()
target, err := copyType(targets[i], skipQualifierAndTypedef)
if err != nil { if err != nil {
return nil, fmt.Errorf("accessor %q: %s", accessorStr, err) return nil, err
} }
if int(relo.TypeID) >= len(local.types) { score := 0 // lower is better
return nil, fmt.Errorf("invalid type id %d", relo.TypeID) fixups := make([]COREFixup, 0, len(relos))
for _, relo := range relos {
fixup, err := coreCalculateFixup(local, localID, target, targetID, relo)
if err != nil {
return nil, fmt.Errorf("target %s: %w", target, err)
}
if fixup.Poison || fixup.isNonExistant() {
score++
}
fixups = append(fixups, fixup)
} }
typ := local.types[relo.TypeID] if score > bestScore {
// We have a better target already, ignore this one.
if relo.ReloKind == reloTypeIDLocal {
relocations[uint64(relo.InsnOff)] = Relocation{
uint32(typ.ID()),
uint32(typ.ID()),
}
continue continue
} }
named, ok := typ.(namedType) if score < bestScore {
if !ok || named.name() == "" { // This is the best target yet, use it.
return nil, fmt.Errorf("relocate anonymous type %s: %w", typ.String(), ErrNotSupported) bestScore = score
bestFixups = fixups
continue
} }
name := essentialName(named.name()) // Some other target has the same score as the current one. Make sure
res, err := coreCalculateRelocation(typ, target.namedTypes[name], relo.ReloKind, accessor) // the fixups agree with each other.
if err != nil { for i, fixup := range bestFixups {
return nil, fmt.Errorf("relocate %s: %w", name, err) if !fixup.equal(fixups[i]) {
return nil, fmt.Errorf("%s: multiple types match: %w", fixup.Kind, errAmbiguousRelocation)
}
}
} }
relocations[uint64(relo.InsnOff)] = res if bestFixups == nil {
// Nothing at all matched, probably because there are no suitable
// targets at all. Poison everything!
bestFixups = make([]COREFixup, len(relos))
for i, relo := range relos {
bestFixups[i] = COREFixup{Kind: relo.kind, Poison: true}
}
} }
return relocations, nil return bestFixups, nil
} }
var errAmbiguousRelocation = errors.New("ambiguous relocation") // coreCalculateFixup calculates the fixup for a single local type, target type
// and relocation.
func coreCalculateFixup(local Type, localID TypeID, target Type, targetID TypeID, relo coreRelo) (COREFixup, error) {
fixup := func(local, target uint32) (COREFixup, error) {
return COREFixup{relo.kind, local, target, false}, nil
}
poison := func() (COREFixup, error) {
if relo.kind.checksForExistence() {
return fixup(1, 0)
}
return COREFixup{relo.kind, 0, 0, true}, nil
}
zero := COREFixup{}
switch relo.kind {
case reloTypeIDTarget, reloTypeSize, reloTypeExists:
if len(relo.accessor) > 1 || relo.accessor[0] != 0 {
return zero, fmt.Errorf("%s: unexpected accessor %v", relo.kind, relo.accessor)
}
err := coreAreTypesCompatible(local, target)
if errors.Is(err, errImpossibleRelocation) {
return poison()
}
if err != nil {
return zero, fmt.Errorf("relocation %s: %w", relo.kind, err)
}
switch relo.kind {
case reloTypeExists:
return fixup(1, 1)
func coreCalculateRelocation(local Type, targets []namedType, kind coreReloKind, localAccessor coreAccessor) (Relocation, error) {
var relos []Relocation
var matches []Type
for _, target := range targets {
switch kind {
case reloTypeIDTarget: case reloTypeIDTarget:
if localAccessor[0] != 0 { return fixup(uint32(localID), uint32(targetID))
return Relocation{}, fmt.Errorf("%s: unexpected non-zero accessor", kind)
case reloTypeSize:
localSize, err := Sizeof(local)
if err != nil {
return zero, err
} }
if compat, err := coreAreTypesCompatible(local, target); err != nil { targetSize, err := Sizeof(target)
return Relocation{}, fmt.Errorf("%s: %s", kind, err) if err != nil {
} else if !compat { return zero, err
continue
} }
relos = append(relos, Relocation{uint32(target.ID()), uint32(target.ID())}) return fixup(uint32(localSize), uint32(targetSize))
default:
return Relocation{}, fmt.Errorf("relocation %s: %w", kind, ErrNotSupported)
}
matches = append(matches, target)
} }
if len(relos) == 0 { case reloEnumvalValue, reloEnumvalExists:
// TODO: Add switch for existence checks like reloEnumvalExists here. localValue, targetValue, err := coreFindEnumValue(local, relo.accessor, target)
if errors.Is(err, errImpossibleRelocation) {
// TODO: This might have to be poisoned. return poison()
return Relocation{}, fmt.Errorf("no relocation found, tried %v", targets) }
if err != nil {
return zero, fmt.Errorf("relocation %s: %w", relo.kind, err)
} }
relo := relos[0] switch relo.kind {
for _, altRelo := range relos[1:] { case reloEnumvalExists:
if !altRelo.equal(relo) { return fixup(1, 1)
return Relocation{}, fmt.Errorf("multiple types %v match: %w", matches, errAmbiguousRelocation)
case reloEnumvalValue:
return fixup(uint32(localValue.Value), uint32(targetValue.Value))
}
case reloFieldByteOffset, reloFieldByteSize, reloFieldExists:
if _, ok := target.(*Fwd); ok {
// We can't relocate fields using a forward declaration, so
// skip it. If a non-forward declaration is present in the BTF
// we'll find it in one of the other iterations.
return poison()
}
localField, targetField, err := coreFindField(local, relo.accessor, target)
if errors.Is(err, errImpossibleRelocation) {
return poison()
}
if err != nil {
return zero, fmt.Errorf("target %s: %w", target, err)
}
switch relo.kind {
case reloFieldExists:
return fixup(1, 1)
case reloFieldByteOffset:
return fixup(localField.offset/8, targetField.offset/8)
case reloFieldByteSize:
localSize, err := Sizeof(localField.Type)
if err != nil {
return zero, err
}
targetSize, err := Sizeof(targetField.Type)
if err != nil {
return zero, err
}
return fixup(uint32(localSize), uint32(targetSize))
} }
} }
return relo, nil return zero, fmt.Errorf("relocation %s: %w", relo.kind, ErrNotSupported)
} }
/* coreAccessor contains a path through a struct. It contains at least one index. /* coreAccessor contains a path through a struct. It contains at least one index.
@ -219,6 +482,240 @@ func parseCoreAccessor(accessor string) (coreAccessor, error) {
return result, nil return result, nil
} }
func (ca coreAccessor) String() string {
strs := make([]string, 0, len(ca))
for _, i := range ca {
strs = append(strs, strconv.Itoa(i))
}
return strings.Join(strs, ":")
}
func (ca coreAccessor) enumValue(t Type) (*EnumValue, error) {
e, ok := t.(*Enum)
if !ok {
return nil, fmt.Errorf("not an enum: %s", t)
}
if len(ca) > 1 {
return nil, fmt.Errorf("invalid accessor %s for enum", ca)
}
i := ca[0]
if i >= len(e.Values) {
return nil, fmt.Errorf("invalid index %d for %s", i, e)
}
return &e.Values[i], nil
}
type coreField struct {
Type Type
offset uint32
}
func adjustOffset(base uint32, t Type, n int) (uint32, error) {
size, err := Sizeof(t)
if err != nil {
return 0, err
}
return base + (uint32(n) * uint32(size) * 8), nil
}
// coreFindField descends into the local type using the accessor and tries to
// find an equivalent field in target at each step.
//
// Returns the field and the offset of the field from the start of
// target in bits.
func coreFindField(local Type, localAcc coreAccessor, target Type) (_, _ coreField, _ error) {
// The first index is used to offset a pointer of the base type like
// when accessing an array.
localOffset, err := adjustOffset(0, local, localAcc[0])
if err != nil {
return coreField{}, coreField{}, err
}
targetOffset, err := adjustOffset(0, target, localAcc[0])
if err != nil {
return coreField{}, coreField{}, err
}
if err := coreAreMembersCompatible(local, target); err != nil {
return coreField{}, coreField{}, fmt.Errorf("fields: %w", err)
}
var localMaybeFlex, targetMaybeFlex bool
for _, acc := range localAcc[1:] {
switch localType := local.(type) {
case composite:
// For composite types acc is used to find the field in the local type,
// and then we try to find a field in target with the same name.
localMembers := localType.members()
if acc >= len(localMembers) {
return coreField{}, coreField{}, fmt.Errorf("invalid accessor %d for %s", acc, local)
}
localMember := localMembers[acc]
if localMember.Name == "" {
_, ok := localMember.Type.(composite)
if !ok {
return coreField{}, coreField{}, fmt.Errorf("unnamed field with type %s: %s", localMember.Type, ErrNotSupported)
}
// This is an anonymous struct or union, ignore it.
local = localMember.Type
localOffset += localMember.Offset
localMaybeFlex = false
continue
}
targetType, ok := target.(composite)
if !ok {
return coreField{}, coreField{}, fmt.Errorf("target not composite: %w", errImpossibleRelocation)
}
targetMember, last, err := coreFindMember(targetType, localMember.Name)
if err != nil {
return coreField{}, coreField{}, err
}
if targetMember.BitfieldSize > 0 {
return coreField{}, coreField{}, fmt.Errorf("field %q is a bitfield: %w", targetMember.Name, ErrNotSupported)
}
local = localMember.Type
localMaybeFlex = acc == len(localMembers)-1
localOffset += localMember.Offset
target = targetMember.Type
targetMaybeFlex = last
targetOffset += targetMember.Offset
case *Array:
// For arrays, acc is the index in the target.
targetType, ok := target.(*Array)
if !ok {
return coreField{}, coreField{}, fmt.Errorf("target not array: %w", errImpossibleRelocation)
}
if localType.Nelems == 0 && !localMaybeFlex {
return coreField{}, coreField{}, fmt.Errorf("local type has invalid flexible array")
}
if targetType.Nelems == 0 && !targetMaybeFlex {
return coreField{}, coreField{}, fmt.Errorf("target type has invalid flexible array")
}
if localType.Nelems > 0 && acc >= int(localType.Nelems) {
return coreField{}, coreField{}, fmt.Errorf("invalid access of %s at index %d", localType, acc)
}
if targetType.Nelems > 0 && acc >= int(targetType.Nelems) {
return coreField{}, coreField{}, fmt.Errorf("out of bounds access of target: %w", errImpossibleRelocation)
}
local = localType.Type
localMaybeFlex = false
localOffset, err = adjustOffset(localOffset, local, acc)
if err != nil {
return coreField{}, coreField{}, err
}
target = targetType.Type
targetMaybeFlex = false
targetOffset, err = adjustOffset(targetOffset, target, acc)
if err != nil {
return coreField{}, coreField{}, err
}
default:
return coreField{}, coreField{}, fmt.Errorf("relocate field of %T: %w", localType, ErrNotSupported)
}
if err := coreAreMembersCompatible(local, target); err != nil {
return coreField{}, coreField{}, err
}
}
return coreField{local, localOffset}, coreField{target, targetOffset}, nil
}
// coreFindMember finds a member in a composite type while handling anonymous
// structs and unions.
func coreFindMember(typ composite, name Name) (Member, bool, error) {
if name == "" {
return Member{}, false, errors.New("can't search for anonymous member")
}
type offsetTarget struct {
composite
offset uint32
}
targets := []offsetTarget{{typ, 0}}
visited := make(map[composite]bool)
for i := 0; i < len(targets); i++ {
target := targets[i]
// Only visit targets once to prevent infinite recursion.
if visited[target] {
continue
}
if len(visited) >= maxTypeDepth {
// This check is different than libbpf, which restricts the entire
// path to BPF_CORE_SPEC_MAX_LEN items.
return Member{}, false, fmt.Errorf("type is nested too deep")
}
visited[target] = true
members := target.members()
for j, member := range members {
if member.Name == name {
// NB: This is safe because member is a copy.
member.Offset += target.offset
return member, j == len(members)-1, nil
}
// The names don't match, but this member could be an anonymous struct
// or union.
if member.Name != "" {
continue
}
comp, ok := member.Type.(composite)
if !ok {
return Member{}, false, fmt.Errorf("anonymous non-composite type %T not allowed", member.Type)
}
targets = append(targets, offsetTarget{comp, target.offset + member.Offset})
}
}
return Member{}, false, fmt.Errorf("no matching member: %w", errImpossibleRelocation)
}
// coreFindEnumValue follows localAcc to find the equivalent enum value in target.
func coreFindEnumValue(local Type, localAcc coreAccessor, target Type) (localValue, targetValue *EnumValue, _ error) {
localValue, err := localAcc.enumValue(local)
if err != nil {
return nil, nil, err
}
targetEnum, ok := target.(*Enum)
if !ok {
return nil, nil, errImpossibleRelocation
}
localName := localValue.Name.essentialName()
for i, targetValue := range targetEnum.Values {
if targetValue.Name.essentialName() != localName {
continue
}
return localValue, &targetEnum.Values[i], nil
}
return nil, nil, errImpossibleRelocation
}
/* The comment below is from bpf_core_types_are_compat in libbpf.c: /* The comment below is from bpf_core_types_are_compat in libbpf.c:
* *
* Check local and target types for compatibility. This check is used for * Check local and target types for compatibility. This check is used for
@ -239,8 +736,10 @@ func parseCoreAccessor(accessor string) (coreAccessor, error) {
* number of input args and compatible return and argument types. * number of input args and compatible return and argument types.
* These rules are not set in stone and probably will be adjusted as we get * These rules are not set in stone and probably will be adjusted as we get
* more experience with using BPF CO-RE relocations. * more experience with using BPF CO-RE relocations.
*
* Returns errImpossibleRelocation if types are not compatible.
*/ */
func coreAreTypesCompatible(localType Type, targetType Type) (bool, error) { func coreAreTypesCompatible(localType Type, targetType Type) error {
var ( var (
localTs, targetTs typeDeque localTs, targetTs typeDeque
l, t = &localType, &targetType l, t = &localType, &targetType
@ -249,14 +748,14 @@ func coreAreTypesCompatible(localType Type, targetType Type) (bool, error) {
for ; l != nil && t != nil; l, t = localTs.shift(), targetTs.shift() { for ; l != nil && t != nil; l, t = localTs.shift(), targetTs.shift() {
if depth >= maxTypeDepth { if depth >= maxTypeDepth {
return false, errors.New("types are nested too deep") return errors.New("types are nested too deep")
} }
localType = skipQualifierAndTypedef(*l) localType = *l
targetType = skipQualifierAndTypedef(*t) targetType = *t
if reflect.TypeOf(localType) != reflect.TypeOf(targetType) { if reflect.TypeOf(localType) != reflect.TypeOf(targetType) {
return false, nil return fmt.Errorf("type mismatch: %w", errImpossibleRelocation)
} }
switch lv := (localType).(type) { switch lv := (localType).(type) {
@ -266,7 +765,7 @@ func coreAreTypesCompatible(localType Type, targetType Type) (bool, error) {
case *Int: case *Int:
tv := targetType.(*Int) tv := targetType.(*Int)
if lv.isBitfield() || tv.isBitfield() { if lv.isBitfield() || tv.isBitfield() {
return false, nil return fmt.Errorf("bitfield: %w", errImpossibleRelocation)
} }
case *Pointer, *Array: case *Pointer, *Array:
@ -277,7 +776,7 @@ func coreAreTypesCompatible(localType Type, targetType Type) (bool, error) {
case *FuncProto: case *FuncProto:
tv := targetType.(*FuncProto) tv := targetType.(*FuncProto)
if len(lv.Params) != len(tv.Params) { if len(lv.Params) != len(tv.Params) {
return false, nil return fmt.Errorf("function param mismatch: %w", errImpossibleRelocation)
} }
depth++ depth++
@ -285,22 +784,24 @@ func coreAreTypesCompatible(localType Type, targetType Type) (bool, error) {
targetType.walk(&targetTs) targetType.walk(&targetTs)
default: default:
return false, fmt.Errorf("unsupported type %T", localType) return fmt.Errorf("unsupported type %T", localType)
} }
} }
if l != nil { if l != nil {
return false, fmt.Errorf("dangling local type %T", *l) return fmt.Errorf("dangling local type %T", *l)
} }
if t != nil { if t != nil {
return false, fmt.Errorf("dangling target type %T", *t) return fmt.Errorf("dangling target type %T", *t)
} }
return true, nil return nil
} }
/* The comment below is from bpf_core_fields_are_compat in libbpf.c: /* coreAreMembersCompatible checks two types for field-based relocation compatibility.
*
* The comment below is from bpf_core_fields_are_compat in libbpf.c:
* *
* Check two types for compatibility for the purpose of field access * Check two types for compatibility for the purpose of field access
* relocation. const/volatile/restrict and typedefs are skipped to ensure we * relocation. const/volatile/restrict and typedefs are skipped to ensure we
@ -314,65 +815,63 @@ func coreAreTypesCompatible(localType Type, targetType Type) (bool, error) {
* - for INT, size and signedness are ignored; * - for INT, size and signedness are ignored;
* - for ARRAY, dimensionality is ignored, element types are checked for * - for ARRAY, dimensionality is ignored, element types are checked for
* compatibility recursively; * compatibility recursively;
* [ NB: coreAreMembersCompatible doesn't recurse, this check is done
* by coreFindField. ]
* - everything else shouldn't be ever a target of relocation. * - everything else shouldn't be ever a target of relocation.
* These rules are not set in stone and probably will be adjusted as we get * These rules are not set in stone and probably will be adjusted as we get
* more experience with using BPF CO-RE relocations. * more experience with using BPF CO-RE relocations.
*
* Returns errImpossibleRelocation if the members are not compatible.
*/ */
func coreAreMembersCompatible(localType Type, targetType Type) (bool, error) { func coreAreMembersCompatible(localType Type, targetType Type) error {
doNamesMatch := func(a, b string) bool { doNamesMatch := func(a, b string) error {
if a == "" || b == "" { if a == "" || b == "" {
// allow anonymous and named type to match // allow anonymous and named type to match
return true return nil
} }
return essentialName(a) == essentialName(b) if essentialName(a) == essentialName(b) {
return nil
} }
for depth := 0; depth <= maxTypeDepth; depth++ { return fmt.Errorf("names don't match: %w", errImpossibleRelocation)
localType = skipQualifierAndTypedef(localType) }
targetType = skipQualifierAndTypedef(targetType)
_, lok := localType.(composite) _, lok := localType.(composite)
_, tok := targetType.(composite) _, tok := targetType.(composite)
if lok && tok { if lok && tok {
return true, nil return nil
} }
if reflect.TypeOf(localType) != reflect.TypeOf(targetType) { if reflect.TypeOf(localType) != reflect.TypeOf(targetType) {
return false, nil return fmt.Errorf("type mismatch: %w", errImpossibleRelocation)
} }
switch lv := localType.(type) { switch lv := localType.(type) {
case *Pointer: case *Array, *Pointer:
return true, nil return nil
case *Enum: case *Enum:
tv := targetType.(*Enum) tv := targetType.(*Enum)
return doNamesMatch(lv.name(), tv.name()), nil return doNamesMatch(lv.name(), tv.name())
case *Fwd: case *Fwd:
tv := targetType.(*Fwd) tv := targetType.(*Fwd)
return doNamesMatch(lv.name(), tv.name()), nil return doNamesMatch(lv.name(), tv.name())
case *Int: case *Int:
tv := targetType.(*Int) tv := targetType.(*Int)
return !lv.isBitfield() && !tv.isBitfield(), nil if lv.isBitfield() || tv.isBitfield() {
return fmt.Errorf("bitfield: %w", errImpossibleRelocation)
case *Array: }
tv := targetType.(*Array) return nil
localType = lv.Type
targetType = tv.Type
default: default:
return false, fmt.Errorf("unsupported type %T", localType) return fmt.Errorf("type %s: %w", localType, ErrNotSupported)
} }
}
return false, errors.New("types are nested too deep")
} }
func skipQualifierAndTypedef(typ Type) Type { func skipQualifierAndTypedef(typ Type) (Type, error) {
result := typ result := typ
for depth := 0; depth <= maxTypeDepth; depth++ { for depth := 0; depth <= maxTypeDepth; depth++ {
switch v := (result).(type) { switch v := (result).(type) {
@ -381,8 +880,8 @@ func skipQualifierAndTypedef(typ Type) Type {
case *Typedef: case *Typedef:
result = v.Type result = v.Type
default: default:
return result return result, nil
} }
} }
return typ return nil, errors.New("exceeded type depth")
} }

View File

@ -30,7 +30,7 @@ type btfExtCoreHeader struct {
CoreReloLen uint32 CoreReloLen uint32
} }
func parseExtInfos(r io.ReadSeeker, bo binary.ByteOrder, strings stringTable) (funcInfo, lineInfo map[string]extInfo, coreRelos map[string]bpfCoreRelos, err error) { func parseExtInfos(r io.ReadSeeker, bo binary.ByteOrder, strings stringTable) (funcInfo, lineInfo map[string]extInfo, relos map[string]coreRelos, err error) {
var header btfExtHeader var header btfExtHeader
var coreHeader btfExtCoreHeader var coreHeader btfExtCoreHeader
if err := binary.Read(r, bo, &header); err != nil { if err := binary.Read(r, bo, &header); err != nil {
@ -94,13 +94,13 @@ func parseExtInfos(r io.ReadSeeker, bo binary.ByteOrder, strings stringTable) (f
return nil, nil, nil, fmt.Errorf("can't seek to CO-RE relocation section: %v", err) return nil, nil, nil, fmt.Errorf("can't seek to CO-RE relocation section: %v", err)
} }
coreRelos, err = parseExtInfoRelos(io.LimitReader(r, int64(coreHeader.CoreReloLen)), bo, strings) relos, err = parseExtInfoRelos(io.LimitReader(r, int64(coreHeader.CoreReloLen)), bo, strings)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("CO-RE relocation info: %w", err) return nil, nil, nil, fmt.Errorf("CO-RE relocation info: %w", err)
} }
} }
return funcInfo, lineInfo, coreRelos, nil return funcInfo, lineInfo, relos, nil
} }
type btfExtInfoSec struct { type btfExtInfoSec struct {
@ -208,18 +208,25 @@ type bpfCoreRelo struct {
InsnOff uint32 InsnOff uint32
TypeID TypeID TypeID TypeID
AccessStrOff uint32 AccessStrOff uint32
ReloKind coreReloKind Kind COREKind
} }
type bpfCoreRelos []bpfCoreRelo type coreRelo struct {
insnOff uint32
typeID TypeID
accessor coreAccessor
kind COREKind
}
type coreRelos []coreRelo
// append two slices of extInfoRelo to each other. The InsnOff of b are adjusted // append two slices of extInfoRelo to each other. The InsnOff of b are adjusted
// by offset. // by offset.
func (r bpfCoreRelos) append(other bpfCoreRelos, offset uint64) bpfCoreRelos { func (r coreRelos) append(other coreRelos, offset uint64) coreRelos {
result := make([]bpfCoreRelo, 0, len(r)+len(other)) result := make([]coreRelo, 0, len(r)+len(other))
result = append(result, r...) result = append(result, r...)
for _, relo := range other { for _, relo := range other {
relo.InsnOff += uint32(offset) relo.insnOff += uint32(offset)
result = append(result, relo) result = append(result, relo)
} }
return result return result
@ -227,7 +234,7 @@ func (r bpfCoreRelos) append(other bpfCoreRelos, offset uint64) bpfCoreRelos {
var extInfoReloSize = binary.Size(bpfCoreRelo{}) var extInfoReloSize = binary.Size(bpfCoreRelo{})
func parseExtInfoRelos(r io.Reader, bo binary.ByteOrder, strings stringTable) (map[string]bpfCoreRelos, error) { func parseExtInfoRelos(r io.Reader, bo binary.ByteOrder, strings stringTable) (map[string]coreRelos, error) {
var recordSize uint32 var recordSize uint32
if err := binary.Read(r, bo, &recordSize); err != nil { if err := binary.Read(r, bo, &recordSize); err != nil {
return nil, fmt.Errorf("read record size: %v", err) return nil, fmt.Errorf("read record size: %v", err)
@ -237,14 +244,14 @@ func parseExtInfoRelos(r io.Reader, bo binary.ByteOrder, strings stringTable) (m
return nil, fmt.Errorf("expected record size %d, got %d", extInfoReloSize, recordSize) return nil, fmt.Errorf("expected record size %d, got %d", extInfoReloSize, recordSize)
} }
result := make(map[string]bpfCoreRelos) result := make(map[string]coreRelos)
for { for {
secName, infoHeader, err := parseExtInfoHeader(r, bo, strings) secName, infoHeader, err := parseExtInfoHeader(r, bo, strings)
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
return result, nil return result, nil
} }
var relos []bpfCoreRelo var relos coreRelos
for i := uint32(0); i < infoHeader.NumInfo; i++ { for i := uint32(0); i < infoHeader.NumInfo; i++ {
var relo bpfCoreRelo var relo bpfCoreRelo
if err := binary.Read(r, bo, &relo); err != nil { if err := binary.Read(r, bo, &relo); err != nil {
@ -255,7 +262,22 @@ func parseExtInfoRelos(r io.Reader, bo binary.ByteOrder, strings stringTable) (m
return nil, fmt.Errorf("section %v: offset %v is not aligned with instruction size", secName, relo.InsnOff) return nil, fmt.Errorf("section %v: offset %v is not aligned with instruction size", secName, relo.InsnOff)
} }
relos = append(relos, relo) accessorStr, err := strings.Lookup(relo.AccessStrOff)
if err != nil {
return nil, err
}
accessor, err := parseCoreAccessor(accessorStr)
if err != nil {
return nil, fmt.Errorf("accessor %q: %s", accessorStr, err)
}
relos = append(relos, coreRelo{
relo.InsnOff,
relo.TypeID,
accessor,
relo.Kind,
})
} }
result[secName] = relos result[secName] = relos

View File

@ -1,7 +1,6 @@
package btf package btf
import ( import (
"errors"
"fmt" "fmt"
"math" "math"
"strings" "strings"
@ -37,6 +36,7 @@ type Type interface {
type namedType interface { type namedType interface {
Type Type
name() string name() string
essentialName() string
} }
// Name identifies a type. // Name identifies a type.
@ -48,6 +48,10 @@ func (n Name) name() string {
return string(n) return string(n)
} }
func (n Name) essentialName() string {
return essentialName(string(n))
}
// Void is the unit type of BTF. // Void is the unit type of BTF.
type Void struct{} type Void struct{}
@ -174,8 +178,7 @@ func (s *Struct) walk(tdq *typeDeque) {
func (s *Struct) copy() Type { func (s *Struct) copy() Type {
cpy := *s cpy := *s
cpy.Members = make([]Member, len(s.Members)) cpy.Members = copyMembers(s.Members)
copy(cpy.Members, s.Members)
return &cpy return &cpy
} }
@ -206,8 +209,7 @@ func (u *Union) walk(tdq *typeDeque) {
func (u *Union) copy() Type { func (u *Union) copy() Type {
cpy := *u cpy := *u
cpy.Members = make([]Member, len(u.Members)) cpy.Members = copyMembers(u.Members)
copy(cpy.Members, u.Members)
return &cpy return &cpy
} }
@ -215,6 +217,12 @@ func (u *Union) members() []Member {
return u.Members return u.Members
} }
func copyMembers(orig []Member) []Member {
cpy := make([]Member, len(orig))
copy(cpy, orig)
return cpy
}
type composite interface { type composite interface {
members() []Member members() []Member
} }
@ -511,7 +519,7 @@ func Sizeof(typ Type) (int, error) {
switch v := typ.(type) { switch v := typ.(type) {
case *Array: case *Array:
if n > 0 && int64(v.Nelems) > math.MaxInt64/n { if n > 0 && int64(v.Nelems) > math.MaxInt64/n {
return 0, errors.New("overflow") return 0, fmt.Errorf("type %s: overflow", typ)
} }
// Arrays may be of zero length, which allows // Arrays may be of zero length, which allows
@ -532,28 +540,30 @@ func Sizeof(typ Type) (int, error) {
continue continue
default: default:
return 0, fmt.Errorf("unrecognized type %T", typ) return 0, fmt.Errorf("unsized type %T", typ)
} }
if n > 0 && elem > math.MaxInt64/n { if n > 0 && elem > math.MaxInt64/n {
return 0, errors.New("overflow") return 0, fmt.Errorf("type %s: overflow", typ)
} }
size := n * elem size := n * elem
if int64(int(size)) != size { if int64(int(size)) != size {
return 0, errors.New("overflow") return 0, fmt.Errorf("type %s: overflow", typ)
} }
return int(size), nil return int(size), nil
} }
return 0, errors.New("exceeded type depth") return 0, fmt.Errorf("type %s: exceeded type depth", typ)
} }
// copy a Type recursively. // copy a Type recursively.
// //
// typ may form a cycle. // typ may form a cycle.
func copyType(typ Type) Type { //
// Returns any errors from transform verbatim.
func copyType(typ Type, transform func(Type) (Type, error)) (Type, error) {
var ( var (
copies = make(map[Type]Type) copies = make(map[Type]Type)
work typeDeque work typeDeque
@ -566,7 +576,17 @@ func copyType(typ Type) Type {
continue continue
} }
cpy := (*t).copy() var cpy Type
if transform != nil {
tf, err := transform(*t)
if err != nil {
return nil, fmt.Errorf("copy %s: %w", typ, err)
}
cpy = tf.copy()
} else {
cpy = (*t).copy()
}
copies[*t] = cpy copies[*t] = cpy
*t = cpy *t = cpy
@ -574,7 +594,7 @@ func copyType(typ Type) Type {
cpy.walk(&work) cpy.walk(&work)
} }
return typ return typ, nil
} }
// typeDeque keeps track of pointers to types which still // typeDeque keeps track of pointers to types which still

View File

@ -50,3 +50,19 @@ func (se *SafeELFFile) Symbols() (syms []elf.Symbol, err error) {
syms, err = se.File.Symbols() syms, err = se.File.Symbols()
return return
} }
// DynamicSymbols is the safe version of elf.File.DynamicSymbols.
func (se *SafeELFFile) DynamicSymbols() (syms []elf.Symbol, err error) {
defer func() {
r := recover()
if r == nil {
return
}
syms = nil
err = fmt.Errorf("reading ELF dynamic symbols panicked: %s", r)
}()
syms, err = se.File.DynamicSymbols()
return
}

View File

@ -9,11 +9,16 @@ import (
// depending on the host's endianness. // depending on the host's endianness.
var NativeEndian binary.ByteOrder var NativeEndian binary.ByteOrder
// Clang is set to either "el" or "eb" depending on the host's endianness.
var ClangEndian string
func init() { func init() {
if isBigEndian() { if isBigEndian() {
NativeEndian = binary.BigEndian NativeEndian = binary.BigEndian
ClangEndian = "eb"
} else { } else {
NativeEndian = binary.LittleEndian NativeEndian = binary.LittleEndian
ClangEndian = "el"
} }
} }

View File

@ -29,6 +29,10 @@ type VerifierError struct {
log string log string
} }
func (le *VerifierError) Unwrap() error {
return le.cause
}
func (le *VerifierError) Error() string { func (le *VerifierError) Error() string {
if le.log == "" { if le.log == "" {
return le.cause.Error() return le.cause.Error()

View File

@ -22,10 +22,6 @@ func NewSlicePointer(buf []byte) Pointer {
// NewStringPointer creates a 64-bit pointer from a string. // NewStringPointer creates a 64-bit pointer from a string.
func NewStringPointer(str string) Pointer { func NewStringPointer(str string) Pointer {
if str == "" {
return Pointer{}
}
p, err := unix.BytePtrFromString(str) p, err := unix.BytePtrFromString(str)
if err != nil { if err != nil {
return Pointer{} return Pointer{}

View File

@ -42,6 +42,7 @@ const (
PROT_READ = linux.PROT_READ PROT_READ = linux.PROT_READ
PROT_WRITE = linux.PROT_WRITE PROT_WRITE = linux.PROT_WRITE
MAP_SHARED = linux.MAP_SHARED MAP_SHARED = linux.MAP_SHARED
PERF_ATTR_SIZE_VER1 = linux.PERF_ATTR_SIZE_VER1
PERF_TYPE_SOFTWARE = linux.PERF_TYPE_SOFTWARE PERF_TYPE_SOFTWARE = linux.PERF_TYPE_SOFTWARE
PERF_TYPE_TRACEPOINT = linux.PERF_TYPE_TRACEPOINT PERF_TYPE_TRACEPOINT = linux.PERF_TYPE_TRACEPOINT
PERF_COUNT_SW_BPF_OUTPUT = linux.PERF_COUNT_SW_BPF_OUTPUT PERF_COUNT_SW_BPF_OUTPUT = linux.PERF_COUNT_SW_BPF_OUTPUT

View File

@ -43,6 +43,7 @@ const (
PROT_READ = 0x1 PROT_READ = 0x1
PROT_WRITE = 0x2 PROT_WRITE = 0x2
MAP_SHARED = 0x1 MAP_SHARED = 0x1
PERF_ATTR_SIZE_VER1 = 0
PERF_TYPE_SOFTWARE = 0x1 PERF_TYPE_SOFTWARE = 0x1
PERF_TYPE_TRACEPOINT = 0 PERF_TYPE_TRACEPOINT = 0
PERF_COUNT_SW_BPF_OUTPUT = 0xa PERF_COUNT_SW_BPF_OUTPUT = 0xa

View File

@ -1,12 +1,16 @@
package link package link
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sync"
"unsafe"
"github.com/cilium/ebpf" "github.com/cilium/ebpf"
"github.com/cilium/ebpf/internal" "github.com/cilium/ebpf/internal"
@ -15,13 +19,60 @@ import (
var ( var (
kprobeEventsPath = filepath.Join(tracefsPath, "kprobe_events") kprobeEventsPath = filepath.Join(tracefsPath, "kprobe_events")
kprobeRetprobeBit = struct {
once sync.Once
value uint64
err error
}{}
) )
type probeType uint8
const (
kprobeType probeType = iota
uprobeType
)
func (pt probeType) String() string {
if pt == kprobeType {
return "kprobe"
}
return "uprobe"
}
func (pt probeType) EventsPath() string {
if pt == kprobeType {
return kprobeEventsPath
}
return uprobeEventsPath
}
func (pt probeType) PerfEventType(ret bool) perfEventType {
if pt == kprobeType {
if ret {
return kretprobeEvent
}
return kprobeEvent
}
if ret {
return uretprobeEvent
}
return uprobeEvent
}
func (pt probeType) RetprobeBit() (uint64, error) {
if pt == kprobeType {
return kretprobeBit()
}
return uretprobeBit()
}
// Kprobe attaches the given eBPF program to a perf event that fires when the // Kprobe attaches the given eBPF program to a perf event that fires when the
// given kernel symbol starts executing. See /proc/kallsyms for available // given kernel symbol starts executing. See /proc/kallsyms for available
// symbols. For example, printk(): // symbols. For example, printk():
// //
// Kprobe("printk") // Kprobe("printk", prog)
// //
// The resulting Link must be Closed during program shutdown to avoid leaking // The resulting Link must be Closed during program shutdown to avoid leaking
// system resources. // system resources.
@ -44,7 +95,7 @@ func Kprobe(symbol string, prog *ebpf.Program) (Link, error) {
// before the given kernel symbol exits, with the function stack left intact. // before the given kernel symbol exits, with the function stack left intact.
// See /proc/kallsyms for available symbols. For example, printk(): // See /proc/kallsyms for available symbols. For example, printk():
// //
// Kretprobe("printk") // Kretprobe("printk", prog)
// //
// The resulting Link must be Closed during program shutdown to avoid leaking // The resulting Link must be Closed during program shutdown to avoid leaking
// system resources. // system resources.
@ -80,7 +131,10 @@ func kprobe(symbol string, prog *ebpf.Program, ret bool) (*perfEvent, error) {
} }
// Use kprobe PMU if the kernel has it available. // Use kprobe PMU if the kernel has it available.
tp, err := pmuKprobe(symbol, ret) tp, err := pmuKprobe(platformPrefix(symbol), ret)
if errors.Is(err, os.ErrNotExist) {
tp, err = pmuKprobe(symbol, ret)
}
if err == nil { if err == nil {
return tp, nil return tp, nil
} }
@ -89,7 +143,10 @@ func kprobe(symbol string, prog *ebpf.Program, ret bool) (*perfEvent, error) {
} }
// Use tracefs if kprobe PMU is missing. // Use tracefs if kprobe PMU is missing.
tp, err = tracefsKprobe(platformPrefix(symbol), ret)
if errors.Is(err, os.ErrNotExist) {
tp, err = tracefsKprobe(symbol, ret) tp, err = tracefsKprobe(symbol, ret)
}
if err != nil { if err != nil {
return nil, fmt.Errorf("creating trace event '%s' in tracefs: %w", symbol, err) return nil, fmt.Errorf("creating trace event '%s' in tracefs: %w", symbol, err)
} }
@ -97,36 +154,70 @@ func kprobe(symbol string, prog *ebpf.Program, ret bool) (*perfEvent, error) {
return tp, nil return tp, nil
} }
// pmuKprobe opens a perf event based on a Performance Monitoring Unit. // pmuKprobe opens a perf event based on the kprobe PMU.
// Requires at least 4.17 (e12f03d7031a "perf/core: Implement the // Returns os.ErrNotExist if the given symbol does not exist in the kernel.
// 'perf_kprobe' PMU").
// Returns ErrNotSupported if the kernel doesn't support perf_kprobe PMU,
// or os.ErrNotExist if the given symbol does not exist in the kernel.
func pmuKprobe(symbol string, ret bool) (*perfEvent, error) { func pmuKprobe(symbol string, ret bool) (*perfEvent, error) {
return pmuProbe(kprobeType, symbol, "", 0, ret)
}
// pmuProbe opens a perf event based on a Performance Monitoring Unit.
//
// Requires at least a 4.17 kernel.
// e12f03d7031a "perf/core: Implement the 'perf_kprobe' PMU"
// 33ea4b24277b "perf/core: Implement the 'perf_uprobe' PMU"
//
// Returns ErrNotSupported if the kernel doesn't support perf_[k,u]probe PMU
func pmuProbe(typ probeType, symbol, path string, offset uint64, ret bool) (*perfEvent, error) {
// Getting the PMU type will fail if the kernel doesn't support // Getting the PMU type will fail if the kernel doesn't support
// the perf_kprobe PMU. // the perf_[k,u]probe PMU.
et, err := getPMUEventType("kprobe") et, err := getPMUEventType(typ)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var config uint64
if ret {
bit, err := typ.RetprobeBit()
if err != nil {
return nil, err
}
config |= 1 << bit
}
var (
attr unix.PerfEventAttr
sp unsafe.Pointer
)
switch typ {
case kprobeType:
// Create a pointer to a NUL-terminated string for the kernel. // Create a pointer to a NUL-terminated string for the kernel.
sp, err := unsafeStringPtr(symbol) sp, err := unsafeStringPtr(symbol)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: Parse the position of the bit from /sys/bus/event_source/devices/%s/format/retprobe. attr = unix.PerfEventAttr{
config := 0
if ret {
config = 1
}
attr := unix.PerfEventAttr{
Type: uint32(et), // PMU event type read from sysfs Type: uint32(et), // PMU event type read from sysfs
Ext1: uint64(uintptr(sp)), // Kernel symbol to trace Ext1: uint64(uintptr(sp)), // Kernel symbol to trace
Config: uint64(config), // perf_kprobe PMU treats config as flags Config: config, // Retprobe flag
}
case uprobeType:
sp, err := unsafeStringPtr(path)
if err != nil {
return nil, err
}
attr = unix.PerfEventAttr{
// The minimum size required for PMU uprobes is PERF_ATTR_SIZE_VER1,
// since it added the config2 (Ext2) field. The Size field controls the
// size of the internal buffer the kernel allocates for reading the
// perf_event_attr argument from userspace.
Size: unix.PERF_ATTR_SIZE_VER1,
Type: uint32(et), // PMU event type read from sysfs
Ext1: uint64(uintptr(sp)), // Uprobe path
Ext2: offset, // Uprobe offset
Config: config, // Retprobe flag
}
} }
fd, err := unix.PerfEventOpen(&attr, perfAllThreads, 0, -1, unix.PERF_FLAG_FD_CLOEXEC) fd, err := unix.PerfEventOpen(&attr, perfAllThreads, 0, -1, unix.PERF_FLAG_FD_CLOEXEC)
@ -144,22 +235,27 @@ func pmuKprobe(symbol string, ret bool) (*perfEvent, error) {
// Ensure the string pointer is not collected before PerfEventOpen returns. // Ensure the string pointer is not collected before PerfEventOpen returns.
runtime.KeepAlive(sp) runtime.KeepAlive(sp)
// Kernel has perf_kprobe PMU available, initialize perf event. // Kernel has perf_[k,u]probe PMU available, initialize perf event.
return &perfEvent{ return &perfEvent{
fd: internal.NewFD(uint32(fd)), fd: internal.NewFD(uint32(fd)),
pmuID: et, pmuID: et,
name: symbol, name: symbol,
ret: ret, typ: typ.PerfEventType(ret),
progType: ebpf.Kprobe,
}, nil }, nil
} }
// tracefsKprobe creates a trace event by writing an entry to <tracefs>/kprobe_events. // tracefsKprobe creates a Kprobe tracefs entry.
// A new trace event group name is generated on every call to support creating
// multiple trace events for the same kernel symbol. A perf event is then opened
// on the newly-created trace event and returned to the caller.
func tracefsKprobe(symbol string, ret bool) (*perfEvent, error) { func tracefsKprobe(symbol string, ret bool) (*perfEvent, error) {
return tracefsProbe(kprobeType, symbol, "", 0, ret)
}
// tracefsProbe creates a trace event by writing an entry to <tracefs>/[k,u]probe_events.
// A new trace event group name is generated on every call to support creating
// multiple trace events for the same kernel or userspace symbol.
// Path and offset are only set in the case of uprobe(s) and are used to set
// the executable/library path on the filesystem and the offset where the probe is inserted.
// A perf event is then opened on the newly-created trace event and returned to the caller.
func tracefsProbe(typ probeType, symbol, path string, offset uint64, ret bool) (*perfEvent, error) {
// Generate a random string for each trace event we attempt to create. // Generate a random string for each trace event we attempt to create.
// This value is used as the 'group' token in tracefs to allow creating // This value is used as the 'group' token in tracefs to allow creating
// multiple kprobe trace events with the same name. // multiple kprobe trace events with the same name.
@ -176,14 +272,13 @@ func tracefsKprobe(symbol string, ret bool) (*perfEvent, error) {
if err == nil { if err == nil {
return nil, fmt.Errorf("trace event already exists: %s/%s", group, symbol) return nil, fmt.Errorf("trace event already exists: %s/%s", group, symbol)
} }
// The read is expected to fail with ErrNotSupported due to a non-existing event. if err != nil && !errors.Is(err, os.ErrNotExist) {
if err != nil && !errors.Is(err, ErrNotSupported) {
return nil, fmt.Errorf("checking trace event %s/%s: %w", group, symbol, err) return nil, fmt.Errorf("checking trace event %s/%s: %w", group, symbol, err)
} }
// Create the kprobe trace event using tracefs. // Create the [k,u]probe trace event using tracefs.
if err := createTraceFSKprobeEvent(group, symbol, ret); err != nil { if err := createTraceFSProbeEvent(typ, group, symbol, path, offset, ret); err != nil {
return nil, fmt.Errorf("creating kprobe event on tracefs: %w", err) return nil, fmt.Errorf("creating probe entry on tracefs: %w", err)
} }
// Get the newly-created trace event's id. // Get the newly-created trace event's id.
@ -202,23 +297,26 @@ func tracefsKprobe(symbol string, ret bool) (*perfEvent, error) {
fd: fd, fd: fd,
group: group, group: group,
name: symbol, name: symbol,
ret: ret,
tracefsID: tid, tracefsID: tid,
progType: ebpf.Kprobe, // kernel only allows attaching kprobe programs to kprobe events typ: typ.PerfEventType(ret),
}, nil }, nil
} }
// createTraceFSKprobeEvent creates a new ephemeral trace event by writing to // createTraceFSProbeEvent creates a new ephemeral trace event by writing to
// <tracefs>/kprobe_events. Returns ErrNotSupported if symbol is not a valid // <tracefs>/[k,u]probe_events. Returns os.ErrNotExist if symbol is not a valid
// kernel symbol, or if it is not traceable with kprobes. // kernel symbol, or if it is not traceable with kprobes. Returns os.ErrExist
func createTraceFSKprobeEvent(group, symbol string, ret bool) error { // if a probe with the same group and symbol already exists.
func createTraceFSProbeEvent(typ probeType, group, symbol, path string, offset uint64, ret bool) error {
// Open the kprobe_events file in tracefs. // Open the kprobe_events file in tracefs.
f, err := os.OpenFile(kprobeEventsPath, os.O_APPEND|os.O_WRONLY, 0666) f, err := os.OpenFile(typ.EventsPath(), os.O_APPEND|os.O_WRONLY, 0666)
if err != nil { if err != nil {
return fmt.Errorf("error opening kprobe_events: %w", err) return fmt.Errorf("error opening '%s': %w", typ.EventsPath(), err)
} }
defer f.Close() defer f.Close()
var pe string
switch typ {
case kprobeType:
// The kprobe_events syntax is as follows (see Documentation/trace/kprobetrace.txt): // The kprobe_events syntax is as follows (see Documentation/trace/kprobetrace.txt):
// p[:[GRP/]EVENT] [MOD:]SYM[+offs]|MEMADDR [FETCHARGS] : Set a probe // p[:[GRP/]EVENT] [MOD:]SYM[+offs]|MEMADDR [FETCHARGS] : Set a probe
// r[MAXACTIVE][:[GRP/]EVENT] [MOD:]SYM[+0] [FETCHARGS] : Set a return probe // r[MAXACTIVE][:[GRP/]EVENT] [MOD:]SYM[+0] [FETCHARGS] : Set a return probe
@ -231,36 +329,51 @@ func createTraceFSKprobeEvent(group, symbol string, ret bool) error {
// Leaving the kretprobe's MAXACTIVE set to 0 (or absent) will make the // Leaving the kretprobe's MAXACTIVE set to 0 (or absent) will make the
// kernel default to NR_CPUS. This is desired in most eBPF cases since // kernel default to NR_CPUS. This is desired in most eBPF cases since
// subsampling or rate limiting logic can be more accurately implemented in // subsampling or rate limiting logic can be more accurately implemented in
// the eBPF program itself. See Documentation/kprobes.txt for more details. // the eBPF program itself.
pe := fmt.Sprintf("%s:%s/%s %s", kprobePrefix(ret), group, symbol, symbol) // See Documentation/kprobes.txt for more details.
pe = fmt.Sprintf("%s:%s/%s %s", probePrefix(ret), group, symbol, symbol)
case uprobeType:
// The uprobe_events syntax is as follows:
// p[:[GRP/]EVENT] PATH:OFFSET [FETCHARGS] : Set a probe
// r[:[GRP/]EVENT] PATH:OFFSET [FETCHARGS] : Set a return probe
// -:[GRP/]EVENT : Clear a probe
//
// Some examples:
// r:ebpf_1234/readline /bin/bash:0x12345
// p:ebpf_5678/main_mySymbol /bin/mybin:0x12345
//
// See Documentation/trace/uprobetracer.txt for more details.
pathOffset := uprobePathOffset(path, offset)
pe = fmt.Sprintf("%s:%s/%s %s", probePrefix(ret), group, symbol, pathOffset)
}
_, err = f.WriteString(pe) _, err = f.WriteString(pe)
// Since commit 97c753e62e6c, ENOENT is correctly returned instead of EINVAL // Since commit 97c753e62e6c, ENOENT is correctly returned instead of EINVAL
// when trying to create a kretprobe for a missing symbol. Make sure ENOENT // when trying to create a kretprobe for a missing symbol. Make sure ENOENT
// is returned to the caller. // is returned to the caller.
if errors.Is(err, os.ErrNotExist) || errors.Is(err, unix.EINVAL) { if errors.Is(err, os.ErrNotExist) || errors.Is(err, unix.EINVAL) {
return fmt.Errorf("kernel symbol %s not found: %w", symbol, os.ErrNotExist) return fmt.Errorf("symbol %s not found: %w", symbol, os.ErrNotExist)
} }
if err != nil { if err != nil {
return fmt.Errorf("writing '%s' to kprobe_events: %w", pe, err) return fmt.Errorf("writing '%s' to '%s': %w", pe, typ.EventsPath(), err)
} }
return nil return nil
} }
// closeTraceFSKprobeEvent removes the kprobe with the given group, symbol and kind // closeTraceFSProbeEvent removes the [k,u]probe with the given type, group and symbol
// from <tracefs>/kprobe_events. // from <tracefs>/[k,u]probe_events.
func closeTraceFSKprobeEvent(group, symbol string) error { func closeTraceFSProbeEvent(typ probeType, group, symbol string) error {
f, err := os.OpenFile(kprobeEventsPath, os.O_APPEND|os.O_WRONLY, 0666) f, err := os.OpenFile(typ.EventsPath(), os.O_APPEND|os.O_WRONLY, 0666)
if err != nil { if err != nil {
return fmt.Errorf("error opening kprobe_events: %w", err) return fmt.Errorf("error opening %s: %w", typ.EventsPath(), err)
} }
defer f.Close() defer f.Close()
// See kprobe_events syntax above. Kprobe type does not need to be specified // See [k,u]probe_events syntax above. The probe type does not need to be specified
// for removals. // for removals.
pe := fmt.Sprintf("-:%s/%s", group, symbol) pe := fmt.Sprintf("-:%s/%s", group, symbol)
if _, err = f.WriteString(pe); err != nil { if _, err = f.WriteString(pe); err != nil {
return fmt.Errorf("writing '%s' to kprobe_events: %w", pe, err) return fmt.Errorf("writing '%s' to '%s': %w", pe, typ.EventsPath(), err)
} }
return nil return nil
@ -288,9 +401,38 @@ func randomGroup(prefix string) (string, error) {
return group, nil return group, nil
} }
func kprobePrefix(ret bool) string { func probePrefix(ret bool) string {
if ret { if ret {
return "r" return "r"
} }
return "p" return "p"
} }
// determineRetprobeBit reads a Performance Monitoring Unit's retprobe bit
// from /sys/bus/event_source/devices/<pmu>/format/retprobe.
func determineRetprobeBit(typ probeType) (uint64, error) {
p := filepath.Join("/sys/bus/event_source/devices/", typ.String(), "/format/retprobe")
data, err := ioutil.ReadFile(p)
if err != nil {
return 0, err
}
var rp uint64
n, err := fmt.Sscanf(string(bytes.TrimSpace(data)), "config:%d", &rp)
if err != nil {
return 0, fmt.Errorf("parse retprobe bit: %w", err)
}
if n != 1 {
return 0, fmt.Errorf("parse retprobe bit: expected 1 item, got %d", n)
}
return rp, nil
}
func kretprobeBit() (uint64, error) {
kprobeRetprobeBit.once.Do(func() {
kprobeRetprobeBit.value, kprobeRetprobeBit.err = determineRetprobeBit(kprobeType)
})
return kprobeRetprobeBit.value, kprobeRetprobeBit.err
}

View File

@ -31,6 +31,10 @@ import (
// exported kernel symbols. kprobe-based (tracefs) trace events can be // exported kernel symbols. kprobe-based (tracefs) trace events can be
// created system-wide by writing to the <tracefs>/kprobe_events file, or // created system-wide by writing to the <tracefs>/kprobe_events file, or
// they can be scoped to the current process by creating PMU perf events. // they can be scoped to the current process by creating PMU perf events.
// - u(ret)probe: Ephemeral trace events based on user provides ELF binaries
// and offsets. uprobe-based (tracefs) trace events can be
// created system-wide by writing to the <tracefs>/uprobe_events file, or
// they can be scoped to the current process by creating PMU perf events.
// - perf event: An object instantiated based on an existing trace event or // - perf event: An object instantiated based on an existing trace event or
// kernel symbol. Referred to by fd in userspace. // kernel symbol. Referred to by fd in userspace.
// Exactly one eBPF program can be attached to a perf event. Multiple perf // Exactly one eBPF program can be attached to a perf event. Multiple perf
@ -52,6 +56,16 @@ const (
perfAllThreads = -1 perfAllThreads = -1
) )
type perfEventType uint8
const (
tracepointEvent perfEventType = iota
kprobeEvent
kretprobeEvent
uprobeEvent
uretprobeEvent
)
// A perfEvent represents a perf event kernel object. Exactly one eBPF program // A perfEvent represents a perf event kernel object. Exactly one eBPF program
// can be attached to it. It is created based on a tracefs trace event or a // can be attached to it. It is created based on a tracefs trace event or a
// Performance Monitoring Unit (PMU). // Performance Monitoring Unit (PMU).
@ -66,11 +80,10 @@ type perfEvent struct {
// ID of the trace event read from tracefs. Valid IDs are non-zero. // ID of the trace event read from tracefs. Valid IDs are non-zero.
tracefsID uint64 tracefsID uint64
// True for kretprobes/uretprobes. // The event type determines the types of programs that can be attached.
ret bool typ perfEventType
fd *internal.FD fd *internal.FD
progType ebpf.ProgramType
} }
func (pe *perfEvent) isLink() {} func (pe *perfEvent) isLink() {}
@ -117,13 +130,18 @@ func (pe *perfEvent) Close() error {
return fmt.Errorf("closing perf event fd: %w", err) return fmt.Errorf("closing perf event fd: %w", err)
} }
switch t := pe.progType; t { switch pe.typ {
case ebpf.Kprobe: case kprobeEvent, kretprobeEvent:
// For kprobes created using tracefs, clean up the <tracefs>/kprobe_events entry. // Clean up kprobe tracefs entry.
if pe.tracefsID != 0 { if pe.tracefsID != 0 {
return closeTraceFSKprobeEvent(pe.group, pe.name) return closeTraceFSProbeEvent(kprobeType, pe.group, pe.name)
} }
case ebpf.TracePoint: case uprobeEvent, uretprobeEvent:
// Clean up uprobe tracefs entry.
if pe.tracefsID != 0 {
return closeTraceFSProbeEvent(uprobeType, pe.group, pe.name)
}
case tracepointEvent:
// Tracepoint trace events don't hold any extra resources. // Tracepoint trace events don't hold any extra resources.
return nil return nil
} }
@ -141,12 +159,21 @@ func (pe *perfEvent) attach(prog *ebpf.Program) error {
if pe.fd == nil { if pe.fd == nil {
return errors.New("cannot attach to nil perf event") return errors.New("cannot attach to nil perf event")
} }
if t := prog.Type(); t != pe.progType {
return fmt.Errorf("invalid program type (expected %s): %s", pe.progType, t)
}
if prog.FD() < 0 { if prog.FD() < 0 {
return fmt.Errorf("invalid program: %w", internal.ErrClosedFd) return fmt.Errorf("invalid program: %w", internal.ErrClosedFd)
} }
switch pe.typ {
case kprobeEvent, kretprobeEvent, uprobeEvent, uretprobeEvent:
if t := prog.Type(); t != ebpf.Kprobe {
return fmt.Errorf("invalid program type (expected %s): %s", ebpf.Kprobe, t)
}
case tracepointEvent:
if t := prog.Type(); t != ebpf.TracePoint {
return fmt.Errorf("invalid program type (expected %s): %s", ebpf.TracePoint, t)
}
default:
return fmt.Errorf("unknown perf event type: %d", pe.typ)
}
// The ioctl below will fail when the fd is invalid. // The ioctl below will fail when the fd is invalid.
kfd, _ := pe.fd.Value() kfd, _ := pe.fd.Value()
@ -180,8 +207,8 @@ func unsafeStringPtr(str string) (unsafe.Pointer, error) {
// group and name must be alphanumeric or underscore, as required by the kernel. // group and name must be alphanumeric or underscore, as required by the kernel.
func getTraceEventID(group, name string) (uint64, error) { func getTraceEventID(group, name string) (uint64, error) {
tid, err := uint64FromFile(tracefsPath, "events", group, name, "id") tid, err := uint64FromFile(tracefsPath, "events", group, name, "id")
if errors.Is(err, ErrNotSupported) { if errors.Is(err, os.ErrNotExist) {
return 0, fmt.Errorf("trace event %s/%s: %w", group, name, ErrNotSupported) return 0, fmt.Errorf("trace event %s/%s: %w", group, name, os.ErrNotExist)
} }
if err != nil { if err != nil {
return 0, fmt.Errorf("reading trace event ID of %s/%s: %w", group, name, err) return 0, fmt.Errorf("reading trace event ID of %s/%s: %w", group, name, err)
@ -192,20 +219,22 @@ func getTraceEventID(group, name string) (uint64, error) {
// getPMUEventType reads a Performance Monitoring Unit's type (numeric identifier) // getPMUEventType reads a Performance Monitoring Unit's type (numeric identifier)
// from /sys/bus/event_source/devices/<pmu>/type. // from /sys/bus/event_source/devices/<pmu>/type.
func getPMUEventType(pmu string) (uint64, error) { //
et, err := uint64FromFile("/sys/bus/event_source/devices", pmu, "type") // Returns ErrNotSupported if the pmu type is not supported.
if errors.Is(err, ErrNotSupported) { func getPMUEventType(typ probeType) (uint64, error) {
return 0, fmt.Errorf("pmu type %s: %w", pmu, ErrNotSupported) et, err := uint64FromFile("/sys/bus/event_source/devices", typ.String(), "type")
if errors.Is(err, os.ErrNotExist) {
return 0, fmt.Errorf("pmu type %s: %w", typ, ErrNotSupported)
} }
if err != nil { if err != nil {
return 0, fmt.Errorf("reading pmu type %s: %w", pmu, err) return 0, fmt.Errorf("reading pmu type %s: %w", typ, err)
} }
return et, nil return et, nil
} }
// openTracepointPerfEvent opens a tracepoint-type perf event. System-wide // openTracepointPerfEvent opens a tracepoint-type perf event. System-wide
// kprobes created by writing to <tracefs>/kprobe_events are tracepoints // [k,u]probes created by writing to <tracefs>/[k,u]probe_events are tracepoints
// behind the scenes, and can be attached to using these perf events. // behind the scenes, and can be attached to using these perf events.
func openTracepointPerfEvent(tid uint64) (*internal.FD, error) { func openTracepointPerfEvent(tid uint64) (*internal.FD, error) {
attr := unix.PerfEventAttr{ attr := unix.PerfEventAttr{
@ -228,22 +257,13 @@ func openTracepointPerfEvent(tid uint64) (*internal.FD, error) {
// and joined onto base. Returns error if base no longer prefixes the path after // and joined onto base. Returns error if base no longer prefixes the path after
// joining all components. // joining all components.
func uint64FromFile(base string, path ...string) (uint64, error) { func uint64FromFile(base string, path ...string) (uint64, error) {
// Resolve leaf path separately for error feedback. Makes the join onto
// base more readable (can't mix with variadic args).
l := filepath.Join(path...) l := filepath.Join(path...)
p := filepath.Join(base, l) p := filepath.Join(base, l)
if !strings.HasPrefix(p, base) { if !strings.HasPrefix(p, base) {
return 0, fmt.Errorf("path '%s' attempts to escape base path '%s': %w", l, base, errInvalidInput) return 0, fmt.Errorf("path '%s' attempts to escape base path '%s': %w", l, base, errInvalidInput)
} }
data, err := ioutil.ReadFile(p) data, err := ioutil.ReadFile(p)
if os.IsNotExist(err) {
// Only echo leaf path, the base path can be prepended at the call site
// if more verbosity is required.
return 0, fmt.Errorf("symbol %s: %w", l, ErrNotSupported)
}
if err != nil { if err != nil {
return 0, fmt.Errorf("reading file %s: %w", p, err) return 0, fmt.Errorf("reading file %s: %w", p, err)
} }

25
vendor/github.com/cilium/ebpf/link/platform.go generated vendored Normal file
View File

@ -0,0 +1,25 @@
package link
import (
"fmt"
"runtime"
)
func platformPrefix(symbol string) string {
prefix := runtime.GOARCH
// per https://github.com/golang/go/blob/master/src/go/build/syslist.go
switch prefix {
case "386":
prefix = "ia32"
case "amd64", "amd64p32":
prefix = "x64"
case "arm64", "arm64be":
prefix = "arm64"
default:
return symbol
}
return fmt.Sprintf("__%s_%s", prefix, symbol)
}

View File

@ -43,7 +43,7 @@ func RawAttachProgram(opts RawAttachProgramOptions) error {
} }
if err := internal.BPFProgAttach(&attr); err != nil { if err := internal.BPFProgAttach(&attr); err != nil {
return fmt.Errorf("can't attach program: %s", err) return fmt.Errorf("can't attach program: %w", err)
} }
return nil return nil
} }
@ -69,7 +69,7 @@ func RawDetachProgram(opts RawDetachProgramOptions) error {
AttachType: uint32(opts.Attach), AttachType: uint32(opts.Attach),
} }
if err := internal.BPFProgDetach(&attr); err != nil { if err := internal.BPFProgDetach(&attr); err != nil {
return fmt.Errorf("can't detach program: %s", err) return fmt.Errorf("can't detach program: %w", err)
} }
return nil return nil

View File

@ -11,7 +11,7 @@ import (
// tracepoints. The top-level directory is the group, the event's subdirectory // tracepoints. The top-level directory is the group, the event's subdirectory
// is the name. Example: // is the name. Example:
// //
// Tracepoint("syscalls", "sys_enter_fork") // Tracepoint("syscalls", "sys_enter_fork", prog)
// //
// Note that attaching eBPF programs to syscalls (sys_enter_*/sys_exit_*) is // Note that attaching eBPF programs to syscalls (sys_enter_*/sys_exit_*) is
// only possible as of kernel 4.14 (commit cf5f5ce). // only possible as of kernel 4.14 (commit cf5f5ce).
@ -44,7 +44,7 @@ func Tracepoint(group, name string, prog *ebpf.Program) (Link, error) {
tracefsID: tid, tracefsID: tid,
group: group, group: group,
name: name, name: name,
progType: ebpf.TracePoint, typ: tracepointEvent,
} }
if err := pe.attach(prog); err != nil { if err := pe.attach(prog); err != nil {

207
vendor/github.com/cilium/ebpf/link/uprobe.go generated vendored Normal file
View File

@ -0,0 +1,207 @@
package link
import (
"debug/elf"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"sync"
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/internal"
)
var (
uprobeEventsPath = filepath.Join(tracefsPath, "uprobe_events")
// rgxUprobeSymbol is used to strip invalid characters from the uprobe symbol
// as they are not allowed to be used as the EVENT token in tracefs.
rgxUprobeSymbol = regexp.MustCompile("[^a-zA-Z0-9]+")
uprobeRetprobeBit = struct {
once sync.Once
value uint64
err error
}{}
)
// Executable defines an executable program on the filesystem.
type Executable struct {
// Path of the executable on the filesystem.
path string
// Parsed ELF symbols and dynamic symbols.
symbols map[string]elf.Symbol
}
// To open a new Executable, use:
//
// OpenExecutable("/bin/bash")
//
// The returned value can then be used to open Uprobe(s).
func OpenExecutable(path string) (*Executable, error) {
if path == "" {
return nil, fmt.Errorf("path cannot be empty")
}
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open file '%s': %w", path, err)
}
defer f.Close()
se, err := internal.NewSafeELFFile(f)
if err != nil {
return nil, fmt.Errorf("parse ELF file: %w", err)
}
var ex = Executable{
path: path,
symbols: make(map[string]elf.Symbol),
}
if err := ex.addSymbols(se.Symbols); err != nil {
return nil, err
}
if err := ex.addSymbols(se.DynamicSymbols); err != nil {
return nil, err
}
return &ex, nil
}
func (ex *Executable) addSymbols(f func() ([]elf.Symbol, error)) error {
// elf.Symbols and elf.DynamicSymbols return ErrNoSymbols if the section is not found.
syms, err := f()
if err != nil && !errors.Is(err, elf.ErrNoSymbols) {
return err
}
for _, s := range syms {
ex.symbols[s.Name] = s
}
return nil
}
func (ex *Executable) symbol(symbol string) (*elf.Symbol, error) {
if s, ok := ex.symbols[symbol]; ok {
return &s, nil
}
return nil, fmt.Errorf("symbol %s not found", symbol)
}
// Uprobe attaches the given eBPF program to a perf event that fires when the
// given symbol starts executing in the given Executable.
// For example, /bin/bash::main():
//
// ex, _ = OpenExecutable("/bin/bash")
// ex.Uprobe("main", prog)
//
// The resulting Link must be Closed during program shutdown to avoid leaking
// system resources. Functions provided by shared libraries can currently not
// be traced and will result in an ErrNotSupported.
func (ex *Executable) Uprobe(symbol string, prog *ebpf.Program) (Link, error) {
u, err := ex.uprobe(symbol, prog, false)
if err != nil {
return nil, err
}
err = u.attach(prog)
if err != nil {
u.Close()
return nil, err
}
return u, nil
}
// Uretprobe attaches the given eBPF program to a perf event that fires right
// before the given symbol exits. For example, /bin/bash::main():
//
// ex, _ = OpenExecutable("/bin/bash")
// ex.Uretprobe("main", prog)
//
// The resulting Link must be Closed during program shutdown to avoid leaking
// system resources. Functions provided by shared libraries can currently not
// be traced and will result in an ErrNotSupported.
func (ex *Executable) Uretprobe(symbol string, prog *ebpf.Program) (Link, error) {
u, err := ex.uprobe(symbol, prog, true)
if err != nil {
return nil, err
}
err = u.attach(prog)
if err != nil {
u.Close()
return nil, err
}
return u, nil
}
// uprobe opens a perf event for the given binary/symbol and attaches prog to it.
// If ret is true, create a uretprobe.
func (ex *Executable) uprobe(symbol string, prog *ebpf.Program, ret bool) (*perfEvent, error) {
if prog == nil {
return nil, fmt.Errorf("prog cannot be nil: %w", errInvalidInput)
}
if prog.Type() != ebpf.Kprobe {
return nil, fmt.Errorf("eBPF program type %s is not Kprobe: %w", prog.Type(), errInvalidInput)
}
sym, err := ex.symbol(symbol)
if err != nil {
return nil, fmt.Errorf("symbol '%s' not found in '%s': %w", symbol, ex.path, err)
}
// Symbols with location 0 from section undef are shared library calls and
// are relocated before the binary is executed. Dynamic linking is not
// implemented by the library, so mark this as unsupported for now.
if sym.Section == elf.SHN_UNDEF && sym.Value == 0 {
return nil, fmt.Errorf("cannot resolve %s library call '%s': %w", ex.path, symbol, ErrNotSupported)
}
// Use uprobe PMU if the kernel has it available.
tp, err := pmuUprobe(sym.Name, ex.path, sym.Value, ret)
if err == nil {
return tp, nil
}
if err != nil && !errors.Is(err, ErrNotSupported) {
return nil, fmt.Errorf("creating perf_uprobe PMU: %w", err)
}
// Use tracefs if uprobe PMU is missing.
tp, err = tracefsUprobe(uprobeSanitizedSymbol(sym.Name), ex.path, sym.Value, ret)
if err != nil {
return nil, fmt.Errorf("creating trace event '%s:%s' in tracefs: %w", ex.path, symbol, err)
}
return tp, nil
}
// pmuUprobe opens a perf event based on the uprobe PMU.
func pmuUprobe(symbol, path string, offset uint64, ret bool) (*perfEvent, error) {
return pmuProbe(uprobeType, symbol, path, offset, ret)
}
// tracefsUprobe creates a Uprobe tracefs entry.
func tracefsUprobe(symbol, path string, offset uint64, ret bool) (*perfEvent, error) {
return tracefsProbe(uprobeType, symbol, path, offset, ret)
}
// uprobeSanitizedSymbol replaces every invalid characted for the tracefs api with an underscore.
func uprobeSanitizedSymbol(symbol string) string {
return rgxUprobeSymbol.ReplaceAllString(symbol, "_")
}
// uprobePathOffset creates the PATH:OFFSET token for the tracefs api.
func uprobePathOffset(path string, offset uint64) string {
return fmt.Sprintf("%s:%#x", path, offset)
}
func uretprobeBit() (uint64, error) {
uprobeRetprobeBit.once.Do(func() {
uprobeRetprobeBit.value, uprobeRetprobeBit.err = determineRetprobeBit(uprobeType)
})
return uprobeRetprobeBit.value, uprobeRetprobeBit.err
}

View File

@ -108,12 +108,16 @@ func fixupJumpsAndCalls(insns asm.Instructions) error {
offset := iter.Offset offset := iter.Offset
ins := iter.Ins ins := iter.Ins
if ins.Reference == "" {
continue
}
switch { switch {
case ins.IsFunctionCall() && ins.Constant == -1: case ins.IsFunctionCall() && ins.Constant == -1:
// Rewrite bpf to bpf call // Rewrite bpf to bpf call
callOffset, ok := symbolOffsets[ins.Reference] callOffset, ok := symbolOffsets[ins.Reference]
if !ok { if !ok {
return fmt.Errorf("instruction %d: reference to missing symbol %q", i, ins.Reference) return fmt.Errorf("call at %d: reference to missing symbol %q", i, ins.Reference)
} }
ins.Constant = int64(callOffset - offset - 1) ins.Constant = int64(callOffset - offset - 1)
@ -122,10 +126,13 @@ func fixupJumpsAndCalls(insns asm.Instructions) error {
// Rewrite jump to label // Rewrite jump to label
jumpOffset, ok := symbolOffsets[ins.Reference] jumpOffset, ok := symbolOffsets[ins.Reference]
if !ok { if !ok {
return fmt.Errorf("instruction %d: reference to missing symbol %q", i, ins.Reference) return fmt.Errorf("jump at %d: reference to missing symbol %q", i, ins.Reference)
} }
ins.Offset = int16(jumpOffset - offset - 1) ins.Offset = int16(jumpOffset - offset - 1)
case ins.IsLoadFromMap() && ins.MapPtr() == -1:
return fmt.Errorf("map %s: %w", ins.Reference, errUnsatisfiedReference)
} }
} }

50
vendor/github.com/cilium/ebpf/map.go generated vendored
View File

@ -18,6 +18,7 @@ var (
ErrKeyNotExist = errors.New("key does not exist") ErrKeyNotExist = errors.New("key does not exist")
ErrKeyExist = errors.New("key already exists") ErrKeyExist = errors.New("key already exists")
ErrIterationAborted = errors.New("iteration aborted") ErrIterationAborted = errors.New("iteration aborted")
ErrMapIncompatible = errors.New("map's spec is incompatible with pinned map")
) )
// MapOptions control loading a map into the kernel. // MapOptions control loading a map into the kernel.
@ -87,6 +88,23 @@ func (ms *MapSpec) Copy() *MapSpec {
return &cpy return &cpy
} }
func (ms *MapSpec) clampPerfEventArraySize() error {
if ms.Type != PerfEventArray {
return nil
}
n, err := internal.PossibleCPUs()
if err != nil {
return fmt.Errorf("perf event array: %w", err)
}
if n := uint32(n); ms.MaxEntries > n {
ms.MaxEntries = n
}
return nil
}
// MapKV is used to initialize the contents of a Map. // MapKV is used to initialize the contents of a Map.
type MapKV struct { type MapKV struct {
Key interface{} Key interface{}
@ -96,19 +114,19 @@ type MapKV struct {
func (ms *MapSpec) checkCompatibility(m *Map) error { func (ms *MapSpec) checkCompatibility(m *Map) error {
switch { switch {
case m.typ != ms.Type: case m.typ != ms.Type:
return fmt.Errorf("expected type %v, got %v", ms.Type, m.typ) return fmt.Errorf("expected type %v, got %v: %w", ms.Type, m.typ, ErrMapIncompatible)
case m.keySize != ms.KeySize: case m.keySize != ms.KeySize:
return fmt.Errorf("expected key size %v, got %v", ms.KeySize, m.keySize) return fmt.Errorf("expected key size %v, got %v: %w", ms.KeySize, m.keySize, ErrMapIncompatible)
case m.valueSize != ms.ValueSize: case m.valueSize != ms.ValueSize:
return fmt.Errorf("expected value size %v, got %v", ms.ValueSize, m.valueSize) return fmt.Errorf("expected value size %v, got %v: %w", ms.ValueSize, m.valueSize, ErrMapIncompatible)
case m.maxEntries != ms.MaxEntries: case m.maxEntries != ms.MaxEntries:
return fmt.Errorf("expected max entries %v, got %v", ms.MaxEntries, m.maxEntries) return fmt.Errorf("expected max entries %v, got %v: %w", ms.MaxEntries, m.maxEntries, ErrMapIncompatible)
case m.flags != ms.Flags: case m.flags != ms.Flags:
return fmt.Errorf("expected flags %v, got %v", ms.Flags, m.flags) return fmt.Errorf("expected flags %v, got %v: %w", ms.Flags, m.flags, ErrMapIncompatible)
} }
return nil return nil
} }
@ -171,14 +189,16 @@ func NewMap(spec *MapSpec) (*Map, error) {
// The caller is responsible for ensuring the process' rlimit is set // The caller is responsible for ensuring the process' rlimit is set
// sufficiently high for locking memory during map creation. This can be done // sufficiently high for locking memory during map creation. This can be done
// by calling unix.Setrlimit with unix.RLIMIT_MEMLOCK prior to calling NewMapWithOptions. // by calling unix.Setrlimit with unix.RLIMIT_MEMLOCK prior to calling NewMapWithOptions.
//
// May return an error wrapping ErrMapIncompatible.
func NewMapWithOptions(spec *MapSpec, opts MapOptions) (*Map, error) { func NewMapWithOptions(spec *MapSpec, opts MapOptions) (*Map, error) {
btfs := make(btfHandleCache) handles := newHandleCache()
defer btfs.close() defer handles.close()
return newMapWithOptions(spec, opts, btfs) return newMapWithOptions(spec, opts, handles)
} }
func newMapWithOptions(spec *MapSpec, opts MapOptions, btfs btfHandleCache) (_ *Map, err error) { func newMapWithOptions(spec *MapSpec, opts MapOptions, handles *handleCache) (_ *Map, err error) {
closeOnError := func(c io.Closer) { closeOnError := func(c io.Closer) {
if err != nil { if err != nil {
c.Close() c.Close()
@ -202,7 +222,7 @@ func newMapWithOptions(spec *MapSpec, opts MapOptions, btfs btfHandleCache) (_ *
defer closeOnError(m) defer closeOnError(m)
if err := spec.checkCompatibility(m); err != nil { if err := spec.checkCompatibility(m); err != nil {
return nil, fmt.Errorf("use pinned map %s: %s", spec.Name, err) return nil, fmt.Errorf("use pinned map %s: %w", spec.Name, err)
} }
return m, nil return m, nil
@ -211,7 +231,7 @@ func newMapWithOptions(spec *MapSpec, opts MapOptions, btfs btfHandleCache) (_ *
// Nothing to do here // Nothing to do here
default: default:
return nil, fmt.Errorf("unsupported pin type %d", int(spec.Pinning)) return nil, fmt.Errorf("pin type %d: %w", int(spec.Pinning), ErrNotSupported)
} }
var innerFd *internal.FD var innerFd *internal.FD
@ -224,7 +244,7 @@ func newMapWithOptions(spec *MapSpec, opts MapOptions, btfs btfHandleCache) (_ *
return nil, errors.New("inner maps cannot be pinned") return nil, errors.New("inner maps cannot be pinned")
} }
template, err := createMap(spec.InnerMap, nil, opts, btfs) template, err := createMap(spec.InnerMap, nil, opts, handles)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -233,7 +253,7 @@ func newMapWithOptions(spec *MapSpec, opts MapOptions, btfs btfHandleCache) (_ *
innerFd = template.fd innerFd = template.fd
} }
m, err := createMap(spec, innerFd, opts, btfs) m, err := createMap(spec, innerFd, opts, handles)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -249,7 +269,7 @@ func newMapWithOptions(spec *MapSpec, opts MapOptions, btfs btfHandleCache) (_ *
return m, nil return m, nil
} }
func createMap(spec *MapSpec, inner *internal.FD, opts MapOptions, btfs btfHandleCache) (_ *Map, err error) { func createMap(spec *MapSpec, inner *internal.FD, opts MapOptions, handles *handleCache) (_ *Map, err error) {
closeOnError := func(closer io.Closer) { closeOnError := func(closer io.Closer) {
if err != nil { if err != nil {
closer.Close() closer.Close()
@ -320,7 +340,7 @@ func createMap(spec *MapSpec, inner *internal.FD, opts MapOptions, btfs btfHandl
var btfDisabled bool var btfDisabled bool
if spec.BTF != nil { if spec.BTF != nil {
handle, err := btfs.load(btf.MapSpec(spec.BTF)) handle, err := handles.btfHandle(btf.MapSpec(spec.BTF))
btfDisabled = errors.Is(err, btf.ErrNotSupported) btfDisabled = errors.Is(err, btf.ErrNotSupported)
if err != nil && !btfDisabled { if err != nil && !btfDisabled {
return nil, fmt.Errorf("load BTF: %w", err) return nil, fmt.Errorf("load BTF: %w", err)

153
vendor/github.com/cilium/ebpf/prog.go generated vendored
View File

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io"
"math" "math"
"path/filepath" "path/filepath"
"strings" "strings"
@ -19,6 +20,8 @@ import (
// ErrNotSupported is returned whenever the kernel doesn't support a feature. // ErrNotSupported is returned whenever the kernel doesn't support a feature.
var ErrNotSupported = internal.ErrNotSupported var ErrNotSupported = internal.ErrNotSupported
var errUnsatisfiedReference = errors.New("unsatisfied reference")
// ProgramID represents the unique ID of an eBPF program. // ProgramID represents the unique ID of an eBPF program.
type ProgramID uint32 type ProgramID uint32
@ -41,6 +44,12 @@ type ProgramOptions struct {
// Controls the output buffer size for the verifier. Defaults to // Controls the output buffer size for the verifier. Defaults to
// DefaultVerifierLogSize. // DefaultVerifierLogSize.
LogSize int LogSize int
// An ELF containing the target BTF for this program. It is used both to
// find the correct function to trace and to apply CO-RE relocations.
// This is useful in environments where the kernel BTF is not available
// (containers) or where it is in a non-standard location. Defaults to
// use the kernel BTF from a well-known location.
TargetBTF io.ReaderAt
} }
// ProgramSpec defines a Program. // ProgramSpec defines a Program.
@ -125,21 +134,21 @@ func NewProgram(spec *ProgramSpec) (*Program, error) {
// Loading a program for the first time will perform // Loading a program for the first time will perform
// feature detection by loading small, temporary programs. // feature detection by loading small, temporary programs.
func NewProgramWithOptions(spec *ProgramSpec, opts ProgramOptions) (*Program, error) { func NewProgramWithOptions(spec *ProgramSpec, opts ProgramOptions) (*Program, error) {
btfs := make(btfHandleCache) handles := newHandleCache()
defer btfs.close() defer handles.close()
return newProgramWithOptions(spec, opts, btfs) prog, err := newProgramWithOptions(spec, opts, handles)
if errors.Is(err, errUnsatisfiedReference) {
return nil, fmt.Errorf("cannot load program without loading its whole collection: %w", err)
}
return prog, err
} }
func newProgramWithOptions(spec *ProgramSpec, opts ProgramOptions, btfs btfHandleCache) (*Program, error) { func newProgramWithOptions(spec *ProgramSpec, opts ProgramOptions, handles *handleCache) (*Program, error) {
if len(spec.Instructions) == 0 { if len(spec.Instructions) == 0 {
return nil, errors.New("Instructions cannot be empty") return nil, errors.New("Instructions cannot be empty")
} }
if len(spec.License) == 0 {
return nil, errors.New("License cannot be empty")
}
if spec.ByteOrder != nil && spec.ByteOrder != internal.NativeEndian { if spec.ByteOrder != nil && spec.ByteOrder != internal.NativeEndian {
return nil, fmt.Errorf("can't load %s program on %s", spec.ByteOrder, internal.NativeEndian) return nil, fmt.Errorf("can't load %s program on %s", spec.ByteOrder, internal.NativeEndian)
} }
@ -157,27 +166,10 @@ func newProgramWithOptions(spec *ProgramSpec, opts ProgramOptions, btfs btfHandl
kv = v.Kernel() kv = v.Kernel()
} }
insns := make(asm.Instructions, len(spec.Instructions))
copy(insns, spec.Instructions)
if err := fixupJumpsAndCalls(insns); err != nil {
return nil, err
}
buf := bytes.NewBuffer(make([]byte, 0, len(spec.Instructions)*asm.InstructionSize))
err := insns.Marshal(buf, internal.NativeEndian)
if err != nil {
return nil, err
}
bytecode := buf.Bytes()
insCount := uint32(len(bytecode) / asm.InstructionSize)
attr := &bpfProgLoadAttr{ attr := &bpfProgLoadAttr{
progType: spec.Type, progType: spec.Type,
progFlags: spec.Flags, progFlags: spec.Flags,
expectedAttachType: spec.AttachType, expectedAttachType: spec.AttachType,
insCount: insCount,
instructions: internal.NewSlicePointer(bytecode),
license: internal.NewStringPointer(spec.License), license: internal.NewStringPointer(spec.License),
kernelVersion: kv, kernelVersion: kv,
} }
@ -186,15 +178,24 @@ func newProgramWithOptions(spec *ProgramSpec, opts ProgramOptions, btfs btfHandl
attr.progName = newBPFObjName(spec.Name) attr.progName = newBPFObjName(spec.Name)
} }
var btfDisabled bool var err error
if spec.BTF != nil { var targetBTF *btf.Spec
if relos, err := btf.ProgramRelocations(spec.BTF, nil); err != nil { if opts.TargetBTF != nil {
return nil, fmt.Errorf("CO-RE relocations: %s", err) targetBTF, err = handles.btfSpec(opts.TargetBTF)
} else if len(relos) > 0 { if err != nil {
return nil, fmt.Errorf("applying CO-RE relocations: %w", ErrNotSupported) return nil, fmt.Errorf("load target BTF: %w", err)
}
} }
handle, err := btfs.load(btf.ProgramSpec(spec.BTF)) var btfDisabled bool
var core btf.COREFixups
if spec.BTF != nil {
core, err = btf.ProgramFixups(spec.BTF, targetBTF)
if err != nil {
return nil, fmt.Errorf("CO-RE relocations: %w", err)
}
handle, err := handles.btfHandle(btf.ProgramSpec(spec.BTF))
btfDisabled = errors.Is(err, btf.ErrNotSupported) btfDisabled = errors.Is(err, btf.ErrNotSupported)
if err != nil && !btfDisabled { if err != nil && !btfDisabled {
return nil, fmt.Errorf("load BTF: %w", err) return nil, fmt.Errorf("load BTF: %w", err)
@ -221,8 +222,27 @@ func newProgramWithOptions(spec *ProgramSpec, opts ProgramOptions, btfs btfHandl
} }
} }
insns, err := core.Apply(spec.Instructions)
if err != nil {
return nil, fmt.Errorf("CO-RE fixup: %w", err)
}
if err := fixupJumpsAndCalls(insns); err != nil {
return nil, err
}
buf := bytes.NewBuffer(make([]byte, 0, len(spec.Instructions)*asm.InstructionSize))
err = insns.Marshal(buf, internal.NativeEndian)
if err != nil {
return nil, err
}
bytecode := buf.Bytes()
attr.instructions = internal.NewSlicePointer(bytecode)
attr.insCount = uint32(len(bytecode) / asm.InstructionSize)
if spec.AttachTo != "" { if spec.AttachTo != "" {
target, err := resolveBTFType(spec.AttachTo, spec.Type, spec.AttachType) target, err := resolveBTFType(targetBTF, spec.AttachTo, spec.Type, spec.AttachType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -250,7 +270,7 @@ func newProgramWithOptions(spec *ProgramSpec, opts ProgramOptions, btfs btfHandl
} }
logErr := err logErr := err
if opts.LogLevel == 0 { if opts.LogLevel == 0 && opts.LogSize >= 0 {
// Re-run with the verifier enabled to get better error messages. // Re-run with the verifier enabled to get better error messages.
logBuf = make([]byte, logSize) logBuf = make([]byte, logSize)
attr.logLevel = 1 attr.logLevel = 1
@ -664,52 +684,45 @@ func (p *Program) ID() (ProgramID, error) {
return ProgramID(info.id), nil return ProgramID(info.id), nil
} }
func findKernelType(name string, typ btf.Type) error { func resolveBTFType(kernel *btf.Spec, name string, progType ProgramType, attachType AttachType) (btf.Type, error) {
kernel, err := btf.LoadKernelSpec()
if err != nil {
return fmt.Errorf("can't load kernel spec: %w", err)
}
return kernel.FindType(name, typ)
}
func resolveBTFType(name string, progType ProgramType, attachType AttachType) (btf.Type, error) {
type match struct { type match struct {
p ProgramType p ProgramType
a AttachType a AttachType
} }
target := match{progType, attachType} var target btf.Type
switch target { var typeName, featureName string
switch (match{progType, attachType}) {
case match{LSM, AttachLSMMac}: case match{LSM, AttachLSMMac}:
var target btf.Func target = new(btf.Func)
err := findKernelType("bpf_lsm_"+name, &target) typeName = "bpf_lsm_" + name
if errors.Is(err, btf.ErrNotFound) { featureName = name + " LSM hook"
return nil, &internal.UnsupportedFeatureError{
Name: name + " LSM hook",
}
}
if err != nil {
return nil, fmt.Errorf("resolve BTF for LSM hook %s: %w", name, err)
}
return &target, nil
case match{Tracing, AttachTraceIter}: case match{Tracing, AttachTraceIter}:
var target btf.Func target = new(btf.Func)
err := findKernelType("bpf_iter_"+name, &target) typeName = "bpf_iter_" + name
if errors.Is(err, btf.ErrNotFound) { featureName = name + " iterator"
return nil, &internal.UnsupportedFeatureError{
Name: name + " iterator",
}
}
if err != nil {
return nil, fmt.Errorf("resolve BTF for iterator %s: %w", name, err)
}
return &target, nil
default: default:
return nil, nil return nil, nil
} }
if kernel == nil {
var err error
kernel, err = btf.LoadKernelSpec()
if err != nil {
return nil, fmt.Errorf("load kernel spec: %w", err)
}
}
err := kernel.FindType(typeName, target)
if errors.Is(err, btf.ErrNotFound) {
return nil, &internal.UnsupportedFeatureError{
Name: featureName,
}
}
if err != nil {
return nil, fmt.Errorf("resolve BTF for %s: %w", featureName, err)
}
return target, nil
} }

View File

@ -1,56 +1,95 @@
#!/bin/bash #!/bin/bash
# Test the current package under a different kernel. # Test the current package under a different kernel.
# Requires virtme and qemu to be installed. # Requires virtme and qemu to be installed.
# Examples:
# Run all tests on a 5.4 kernel
# $ ./run-tests.sh 5.4
# Run a subset of tests:
# $ ./run-tests.sh 5.4 go test ./link
set -eu set -euo pipefail
set -o pipefail
if [[ "${1:-}" = "--in-vm" ]]; then script="$(realpath "$0")"
readonly script
# This script is a bit like a Matryoshka doll since it keeps re-executing itself
# in various different contexts:
#
# 1. invoked by the user like run-tests.sh 5.4
# 2. invoked by go test like run-tests.sh --exec-vm
# 3. invoked by init in the vm like run-tests.sh --exec-test
#
# This allows us to use all available CPU on the host machine to compile our
# code, and then only use the VM to execute the test. This is because the VM
# is usually slower at compiling than the host.
if [[ "${1:-}" = "--exec-vm" ]]; then
shift
input="$1"
shift
# Use sudo if /dev/kvm isn't accessible by the current user.
sudo=""
if [[ ! -r /dev/kvm || ! -w /dev/kvm ]]; then
sudo="sudo"
fi
readonly sudo
testdir="$(dirname "$1")"
output="$(mktemp -d)"
printf -v cmd "%q " "$@"
if [[ "$(stat -c '%t:%T' -L /proc/$$/fd/0)" == "1:3" ]]; then
# stdin is /dev/null, which doesn't play well with qemu. Use a fifo as a
# blocking substitute.
mkfifo "${output}/fake-stdin"
# Open for reading and writing to avoid blocking.
exec 0<> "${output}/fake-stdin"
rm "${output}/fake-stdin"
fi
$sudo virtme-run --kimg "${input}/bzImage" --memory 768M --pwd \
--rwdir="${testdir}=${testdir}" \
--rodir=/run/input="${input}" \
--rwdir=/run/output="${output}" \
--script-sh "PATH=\"$PATH\" \"$script\" --exec-test $cmd" \
--qemu-opts -smp 2 # need at least two CPUs for some tests
if [[ ! -e "${output}/success" ]]; then
exit 1
fi
$sudo rm -r "$output"
exit 0
elif [[ "${1:-}" = "--exec-test" ]]; then
shift shift
mount -t bpf bpf /sys/fs/bpf mount -t bpf bpf /sys/fs/bpf
mount -t tracefs tracefs /sys/kernel/debug/tracing mount -t tracefs tracefs /sys/kernel/debug/tracing
export CGO_ENABLED=0
export GOFLAGS=-mod=readonly
export GOPATH=/run/go-path
export GOPROXY=file:///run/go-path/pkg/mod/cache/download
export GOSUMDB=off
export GOCACHE=/run/go-cache
if [[ -d "/run/input/bpf" ]]; then if [[ -d "/run/input/bpf" ]]; then
export KERNEL_SELFTESTS="/run/input/bpf" export KERNEL_SELFTESTS="/run/input/bpf"
fi fi
readonly output="${1}" dmesg -C
shift if ! "$@"; then
dmesg
echo Running tests... exit 1
go test -v -coverpkg=./... -coverprofile="$output/coverage.txt" -count 1 ./... fi
touch "$output/success" touch "/run/output/success"
exit 0 exit 0
fi fi
# Pull all dependencies, so that we can run tests without the
# vm having network access.
go mod download
# Use sudo if /dev/kvm isn't accessible by the current user.
sudo=""
if [[ ! -r /dev/kvm || ! -w /dev/kvm ]]; then
sudo="sudo"
fi
readonly sudo
readonly kernel_version="${1:-}" readonly kernel_version="${1:-}"
if [[ -z "${kernel_version}" ]]; then if [[ -z "${kernel_version}" ]]; then
echo "Expecting kernel version as first argument" echo "Expecting kernel version as first argument"
exit 1 exit 1
fi fi
shift
readonly kernel="linux-${kernel_version}.bz" readonly kernel="linux-${kernel_version}.bz"
readonly selftests="linux-${kernel_version}-selftests-bpf.bz" readonly selftests="linux-${kernel_version}-selftests-bpf.bz"
readonly input="$(mktemp -d)" readonly input="$(mktemp -d)"
readonly output="$(mktemp -d)"
readonly tmp_dir="${TMPDIR:-/tmp}" readonly tmp_dir="${TMPDIR:-/tmp}"
readonly branch="${BRANCH:-master}" readonly branch="${BRANCH:-master}"
@ -60,6 +99,7 @@ fetch() {
} }
fetch "${kernel}" fetch "${kernel}"
cp "${tmp_dir}/${kernel}" "${input}/bzImage"
if fetch "${selftests}"; then if fetch "${selftests}"; then
mkdir "${input}/bpf" mkdir "${input}/bpf"
@ -68,25 +108,16 @@ else
echo "No selftests found, disabling" echo "No selftests found, disabling"
fi fi
echo Testing on "${kernel_version}" args=(-v -short -coverpkg=./... -coverprofile=coverage.out -count 1 ./...)
$sudo virtme-run --kimg "${tmp_dir}/${kernel}" --memory 512M --pwd \ if (( $# > 0 )); then
--rw \ args=("$@")
--rwdir=/run/input="${input}" \
--rwdir=/run/output="${output}" \
--rodir=/run/go-path="$(go env GOPATH)" \
--rwdir=/run/go-cache="$(go env GOCACHE)" \
--script-sh "PATH=\"$PATH\" $(realpath "$0") --in-vm /run/output" \
--qemu-opts -smp 2 # need at least two CPUs for some tests
if [[ ! -e "${output}/success" ]]; then
echo "Test failed on ${kernel_version}"
exit 1
else
echo "Test successful on ${kernel_version}"
if [[ -v COVERALLS_TOKEN ]]; then
goveralls -coverprofile="${output}/coverage.txt" -service=semaphore -repotoken "$COVERALLS_TOKEN"
fi
fi fi
$sudo rm -r "${input}" export GOFLAGS=-mod=readonly
$sudo rm -r "${output}" export CGO_ENABLED=0
echo Testing on "${kernel_version}"
go test -exec "$script --exec-vm $input" "${args[@]}"
echo "Test successful on ${kernel_version}"
rm -r "${input}"

View File

@ -111,14 +111,13 @@ type Conn struct {
} }
} }
// New establishes a connection to any available bus and authenticates. // Deprecated: use NewWithContext instead.
// Callers should call Close() when done with the connection.
// Deprecated: use NewWithContext instead
func New() (*Conn, error) { func New() (*Conn, error) {
return NewWithContext(context.Background()) return NewWithContext(context.Background())
} }
// NewWithContext same as New with context // NewWithContext establishes a connection to any available bus and authenticates.
// Callers should call Close() when done with the connection.
func NewWithContext(ctx context.Context) (*Conn, error) { func NewWithContext(ctx context.Context) (*Conn, error) {
conn, err := NewSystemConnectionContext(ctx) conn, err := NewSystemConnectionContext(ctx)
if err != nil && os.Geteuid() == 0 { if err != nil && os.Geteuid() == 0 {
@ -127,44 +126,41 @@ func NewWithContext(ctx context.Context) (*Conn, error) {
return conn, err return conn, err
} }
// NewSystemConnection establishes a connection to the system bus and authenticates. // Deprecated: use NewSystemConnectionContext instead.
// Callers should call Close() when done with the connection
// Deprecated: use NewSystemConnectionContext instead
func NewSystemConnection() (*Conn, error) { func NewSystemConnection() (*Conn, error) {
return NewSystemConnectionContext(context.Background()) return NewSystemConnectionContext(context.Background())
} }
// NewSystemConnectionContext same as NewSystemConnection with context // NewSystemConnectionContext establishes a connection to the system bus and authenticates.
// Callers should call Close() when done with the connection.
func NewSystemConnectionContext(ctx context.Context) (*Conn, error) { func NewSystemConnectionContext(ctx context.Context) (*Conn, error) {
return NewConnection(func() (*dbus.Conn, error) { return NewConnection(func() (*dbus.Conn, error) {
return dbusAuthHelloConnection(ctx, dbus.SystemBusPrivate) return dbusAuthHelloConnection(ctx, dbus.SystemBusPrivate)
}) })
} }
// NewUserConnection establishes a connection to the session bus and // Deprecated: use NewUserConnectionContext instead.
// authenticates. This can be used to connect to systemd user instances.
// Callers should call Close() when done with the connection.
// Deprecated: use NewUserConnectionContext instead
func NewUserConnection() (*Conn, error) { func NewUserConnection() (*Conn, error) {
return NewUserConnectionContext(context.Background()) return NewUserConnectionContext(context.Background())
} }
// NewUserConnectionContext same as NewUserConnection with context // NewUserConnectionContext establishes a connection to the session bus and
// authenticates. This can be used to connect to systemd user instances.
// Callers should call Close() when done with the connection.
func NewUserConnectionContext(ctx context.Context) (*Conn, error) { func NewUserConnectionContext(ctx context.Context) (*Conn, error) {
return NewConnection(func() (*dbus.Conn, error) { return NewConnection(func() (*dbus.Conn, error) {
return dbusAuthHelloConnection(ctx, dbus.SessionBusPrivate) return dbusAuthHelloConnection(ctx, dbus.SessionBusPrivate)
}) })
} }
// NewSystemdConnection establishes a private, direct connection to systemd. // Deprecated: use NewSystemdConnectionContext instead.
// This can be used for communicating with systemd without a dbus daemon.
// Callers should call Close() when done with the connection.
// Deprecated: use NewSystemdConnectionContext instead
func NewSystemdConnection() (*Conn, error) { func NewSystemdConnection() (*Conn, error) {
return NewSystemdConnectionContext(context.Background()) return NewSystemdConnectionContext(context.Background())
} }
// NewSystemdConnectionContext same as NewSystemdConnection with context // NewSystemdConnectionContext establishes a private, direct connection to systemd.
// This can be used for communicating with systemd without a dbus daemon.
// Callers should call Close() when done with the connection.
func NewSystemdConnectionContext(ctx context.Context) (*Conn, error) { func NewSystemdConnectionContext(ctx context.Context) (*Conn, error) {
return NewConnection(func() (*dbus.Conn, error) { return NewConnection(func() (*dbus.Conn, error) {
// We skip Hello when talking directly to systemd. // We skip Hello when talking directly to systemd.
@ -174,7 +170,7 @@ func NewSystemdConnectionContext(ctx context.Context) (*Conn, error) {
}) })
} }
// Close closes an established connection // Close closes an established connection.
func (c *Conn) Close() { func (c *Conn) Close() {
c.sysconn.Close() c.sysconn.Close()
c.sigconn.Close() c.sigconn.Close()
@ -217,7 +213,7 @@ func NewConnection(dialBus func() (*dbus.Conn, error)) (*Conn, error) {
// GetManagerProperty returns the value of a property on the org.freedesktop.systemd1.Manager // GetManagerProperty returns the value of a property on the org.freedesktop.systemd1.Manager
// interface. The value is returned in its string representation, as defined at // interface. The value is returned in its string representation, as defined at
// https://developer.gnome.org/glib/unstable/gvariant-text.html // https://developer.gnome.org/glib/unstable/gvariant-text.html.
func (c *Conn) GetManagerProperty(prop string) (string, error) { func (c *Conn) GetManagerProperty(prop string) (string, error) {
variant, err := c.sysobj.GetProperty("org.freedesktop.systemd1.Manager." + prop) variant, err := c.sysobj.GetProperty("org.freedesktop.systemd1.Manager." + prop)
if err != nil { if err != nil {

View File

@ -73,7 +73,12 @@ func (c *Conn) startJob(ctx context.Context, ch chan<- string, job string, args
return jobID, nil return jobID, nil
} }
// StartUnit enqueues a start job and depending jobs, if any (unless otherwise // Deprecated: use StartUnitContext instead.
func (c *Conn) StartUnit(name string, mode string, ch chan<- string) (int, error) {
return c.StartUnitContext(context.Background(), name, mode, ch)
}
// StartUnitContext enqueues a start job and depending jobs, if any (unless otherwise
// specified by the mode string). // specified by the mode string).
// //
// Takes the unit to activate, plus a mode string. The mode needs to be one of // Takes the unit to activate, plus a mode string. The mode needs to be one of
@ -103,137 +108,124 @@ func (c *Conn) startJob(ctx context.Context, ch chan<- string, job string, args
// should not be considered authoritative. // should not be considered authoritative.
// //
// If an error does occur, it will be returned to the user alongside a job ID of 0. // If an error does occur, it will be returned to the user alongside a job ID of 0.
// Deprecated: use StartUnitContext instead
func (c *Conn) StartUnit(name string, mode string, ch chan<- string) (int, error) {
return c.StartUnitContext(context.Background(), name, mode, ch)
}
// StartUnitContext same as StartUnit with context
func (c *Conn) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { func (c *Conn) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StartUnit", name, mode) return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StartUnit", name, mode)
} }
// StopUnit is similar to StartUnit but stops the specified unit rather // Deprecated: use StopUnitContext instead.
// than starting it.
// Deprecated: use StopUnitContext instead
func (c *Conn) StopUnit(name string, mode string, ch chan<- string) (int, error) { func (c *Conn) StopUnit(name string, mode string, ch chan<- string) (int, error) {
return c.StopUnitContext(context.Background(), name, mode, ch) return c.StopUnitContext(context.Background(), name, mode, ch)
} }
// StopUnitContext same as StopUnit with context // StopUnitContext is similar to StartUnitContext, but stops the specified unit
// rather than starting it.
func (c *Conn) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { func (c *Conn) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StopUnit", name, mode) return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StopUnit", name, mode)
} }
// ReloadUnit reloads a unit. Reloading is done only if the unit is already running and fails otherwise. // Deprecated: use ReloadUnitContext instead.
// Deprecated: use ReloadUnitContext instead
func (c *Conn) ReloadUnit(name string, mode string, ch chan<- string) (int, error) { func (c *Conn) ReloadUnit(name string, mode string, ch chan<- string) (int, error) {
return c.ReloadUnitContext(context.Background(), name, mode, ch) return c.ReloadUnitContext(context.Background(), name, mode, ch)
} }
// ReloadUnitContext same as ReloadUnit with context // ReloadUnitContext reloads a unit. Reloading is done only if the unit
// is already running, and fails otherwise.
func (c *Conn) ReloadUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { func (c *Conn) ReloadUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadUnit", name, mode) return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadUnit", name, mode)
} }
// RestartUnit restarts a service. If a service is restarted that isn't // Deprecated: use RestartUnitContext instead.
// running it will be started.
// Deprecated: use RestartUnitContext instead
func (c *Conn) RestartUnit(name string, mode string, ch chan<- string) (int, error) { func (c *Conn) RestartUnit(name string, mode string, ch chan<- string) (int, error) {
return c.RestartUnitContext(context.Background(), name, mode, ch) return c.RestartUnitContext(context.Background(), name, mode, ch)
} }
// RestartUnitContext same as RestartUnit with context // RestartUnitContext restarts a service. If a service is restarted that isn't
// running it will be started.
func (c *Conn) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { func (c *Conn) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.RestartUnit", name, mode) return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.RestartUnit", name, mode)
} }
// TryRestartUnit is like RestartUnit, except that a service that isn't running // Deprecated: use TryRestartUnitContext instead.
// is not affected by the restart.
// Deprecated: use TryRestartUnitContext instead
func (c *Conn) TryRestartUnit(name string, mode string, ch chan<- string) (int, error) { func (c *Conn) TryRestartUnit(name string, mode string, ch chan<- string) (int, error) {
return c.TryRestartUnitContext(context.Background(), name, mode, ch) return c.TryRestartUnitContext(context.Background(), name, mode, ch)
} }
// TryRestartUnitContext same as TryRestartUnit with context // TryRestartUnitContext is like RestartUnitContext, except that a service that
// isn't running is not affected by the restart.
func (c *Conn) TryRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { func (c *Conn) TryRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.TryRestartUnit", name, mode) return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.TryRestartUnit", name, mode)
} }
// ReloadOrRestartUnit attempts a reload if the unit supports it and use a restart // Deprecated: use ReloadOrRestartUnitContext instead.
// otherwise.
// Deprecated: use ReloadOrRestartUnitContext instead
func (c *Conn) ReloadOrRestartUnit(name string, mode string, ch chan<- string) (int, error) { func (c *Conn) ReloadOrRestartUnit(name string, mode string, ch chan<- string) (int, error) {
return c.ReloadOrRestartUnitContext(context.Background(), name, mode, ch) return c.ReloadOrRestartUnitContext(context.Background(), name, mode, ch)
} }
// ReloadOrRestartUnitContext same as ReloadOrRestartUnit with context // ReloadOrRestartUnitContext attempts a reload if the unit supports it and use
// a restart otherwise.
func (c *Conn) ReloadOrRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { func (c *Conn) ReloadOrRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadOrRestartUnit", name, mode) return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadOrRestartUnit", name, mode)
} }
// ReloadOrTryRestartUnit attempts a reload if the unit supports it and use a "Try" // Deprecated: use ReloadOrTryRestartUnitContext instead.
// flavored restart otherwise.
// Deprecated: use ReloadOrTryRestartUnitContext instead
func (c *Conn) ReloadOrTryRestartUnit(name string, mode string, ch chan<- string) (int, error) { func (c *Conn) ReloadOrTryRestartUnit(name string, mode string, ch chan<- string) (int, error) {
return c.ReloadOrTryRestartUnitContext(context.Background(), name, mode, ch) return c.ReloadOrTryRestartUnitContext(context.Background(), name, mode, ch)
} }
// ReloadOrTryRestartUnitContext same as ReloadOrTryRestartUnit with context // ReloadOrTryRestartUnitContext attempts a reload if the unit supports it,
// and use a "Try" flavored restart otherwise.
func (c *Conn) ReloadOrTryRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { func (c *Conn) ReloadOrTryRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadOrTryRestartUnit", name, mode) return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadOrTryRestartUnit", name, mode)
} }
// StartTransientUnit() may be used to create and start a transient unit, which // Deprecated: use StartTransientUnitContext instead.
// will be released as soon as it is not running or referenced anymore or the
// system is rebooted. name is the unit name including suffix, and must be
// unique. mode is the same as in StartUnit(), properties contains properties
// of the unit.
// Deprecated: use StartTransientUnitContext instead
func (c *Conn) StartTransientUnit(name string, mode string, properties []Property, ch chan<- string) (int, error) { func (c *Conn) StartTransientUnit(name string, mode string, properties []Property, ch chan<- string) (int, error) {
return c.StartTransientUnitContext(context.Background(), name, mode, properties, ch) return c.StartTransientUnitContext(context.Background(), name, mode, properties, ch)
} }
// StartTransientUnitContext same as StartTransientUnit with context // StartTransientUnitContext may be used to create and start a transient unit, which
// will be released as soon as it is not running or referenced anymore or the
// system is rebooted. name is the unit name including suffix, and must be
// unique. mode is the same as in StartUnitContext, properties contains properties
// of the unit.
func (c *Conn) StartTransientUnitContext(ctx context.Context, name string, mode string, properties []Property, ch chan<- string) (int, error) { func (c *Conn) StartTransientUnitContext(ctx context.Context, name string, mode string, properties []Property, ch chan<- string) (int, error) {
return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StartTransientUnit", name, mode, properties, make([]PropertyCollection, 0)) return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StartTransientUnit", name, mode, properties, make([]PropertyCollection, 0))
} }
// KillUnit takes the unit name and a UNIX signal number to send. All of the unit's // Deprecated: use KillUnitContext instead.
// processes are killed.
// Deprecated: use KillUnitContext instead
func (c *Conn) KillUnit(name string, signal int32) { func (c *Conn) KillUnit(name string, signal int32) {
c.KillUnitContext(context.Background(), name, signal) c.KillUnitContext(context.Background(), name, signal)
} }
// KillUnitContext same as KillUnit with context // KillUnitContext takes the unit name and a UNIX signal number to send.
// All of the unit's processes are killed.
func (c *Conn) KillUnitContext(ctx context.Context, name string, signal int32) { func (c *Conn) KillUnitContext(ctx context.Context, name string, signal int32) {
c.KillUnitWithTarget(ctx, name, All, signal) c.KillUnitWithTarget(ctx, name, All, signal)
} }
// KillUnitWithTarget is like KillUnitContext, but allows you to specify which process in the unit to send the signal to // KillUnitWithTarget is like KillUnitContext, but allows you to specify which
// process in the unit to send the signal to.
func (c *Conn) KillUnitWithTarget(ctx context.Context, name string, target Who, signal int32) error { func (c *Conn) KillUnitWithTarget(ctx context.Context, name string, target Who, signal int32) error {
return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.KillUnit", 0, name, string(target), signal).Store() return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.KillUnit", 0, name, string(target), signal).Store()
} }
// ResetFailedUnit resets the "failed" state of a specific unit. // Deprecated: use ResetFailedUnitContext instead.
// Deprecated: use ResetFailedUnitContext instead
func (c *Conn) ResetFailedUnit(name string) error { func (c *Conn) ResetFailedUnit(name string) error {
return c.ResetFailedUnitContext(context.Background(), name) return c.ResetFailedUnitContext(context.Background(), name)
} }
// ResetFailedUnitContext same as ResetFailedUnit with context // ResetFailedUnitContext resets the "failed" state of a specific unit.
func (c *Conn) ResetFailedUnitContext(ctx context.Context, name string) error { func (c *Conn) ResetFailedUnitContext(ctx context.Context, name string) error {
return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ResetFailedUnit", 0, name).Store() return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ResetFailedUnit", 0, name).Store()
} }
// SystemState returns the systemd state. Equivalent to `systemctl is-system-running`. // Deprecated: use SystemStateContext instead.
// Deprecated: use SystemStateContext instead
func (c *Conn) SystemState() (*Property, error) { func (c *Conn) SystemState() (*Property, error) {
return c.SystemStateContext(context.Background()) return c.SystemStateContext(context.Background())
} }
// SystemStateContext same as SystemState with context // SystemStateContext returns the systemd state. Equivalent to
// systemctl is-system-running.
func (c *Conn) SystemStateContext(ctx context.Context) (*Property, error) { func (c *Conn) SystemStateContext(ctx context.Context) (*Property, error) {
var err error var err error
var prop dbus.Variant var prop dbus.Variant
@ -247,7 +239,7 @@ func (c *Conn) SystemStateContext(ctx context.Context) (*Property, error) {
return &Property{Name: "SystemState", Value: prop}, nil return &Property{Name: "SystemState", Value: prop}, nil
} }
// getProperties takes the unit path and returns all of its dbus object properties, for the given dbus interface // getProperties takes the unit path and returns all of its dbus object properties, for the given dbus interface.
func (c *Conn) getProperties(ctx context.Context, path dbus.ObjectPath, dbusInterface string) (map[string]interface{}, error) { func (c *Conn) getProperties(ctx context.Context, path dbus.ObjectPath, dbusInterface string) (map[string]interface{}, error) {
var err error var err error
var props map[string]dbus.Variant var props map[string]dbus.Variant
@ -270,36 +262,36 @@ func (c *Conn) getProperties(ctx context.Context, path dbus.ObjectPath, dbusInte
return out, nil return out, nil
} }
// GetUnitProperties takes the (unescaped) unit name and returns all of its dbus object properties. // Deprecated: use GetUnitPropertiesContext instead.
// Deprecated: use GetUnitPropertiesContext instead
func (c *Conn) GetUnitProperties(unit string) (map[string]interface{}, error) { func (c *Conn) GetUnitProperties(unit string) (map[string]interface{}, error) {
return c.GetUnitPropertiesContext(context.Background(), unit) return c.GetUnitPropertiesContext(context.Background(), unit)
} }
// GetUnitPropertiesContext same as GetUnitPropertiesContext with context // GetUnitPropertiesContext takes the (unescaped) unit name and returns all of
// its dbus object properties.
func (c *Conn) GetUnitPropertiesContext(ctx context.Context, unit string) (map[string]interface{}, error) { func (c *Conn) GetUnitPropertiesContext(ctx context.Context, unit string) (map[string]interface{}, error) {
path := unitPath(unit) path := unitPath(unit)
return c.getProperties(ctx, path, "org.freedesktop.systemd1.Unit") return c.getProperties(ctx, path, "org.freedesktop.systemd1.Unit")
} }
// GetUnitPathProperties takes the (escaped) unit path and returns all of its dbus object properties. // Deprecated: use GetUnitPathPropertiesContext instead.
// Deprecated: use GetUnitPathPropertiesContext instead
func (c *Conn) GetUnitPathProperties(path dbus.ObjectPath) (map[string]interface{}, error) { func (c *Conn) GetUnitPathProperties(path dbus.ObjectPath) (map[string]interface{}, error) {
return c.GetUnitPathPropertiesContext(context.Background(), path) return c.GetUnitPathPropertiesContext(context.Background(), path)
} }
// GetUnitPathPropertiesContext same as GetUnitPathProperties with context // GetUnitPathPropertiesContext takes the (escaped) unit path and returns all
// of its dbus object properties.
func (c *Conn) GetUnitPathPropertiesContext(ctx context.Context, path dbus.ObjectPath) (map[string]interface{}, error) { func (c *Conn) GetUnitPathPropertiesContext(ctx context.Context, path dbus.ObjectPath) (map[string]interface{}, error) {
return c.getProperties(ctx, path, "org.freedesktop.systemd1.Unit") return c.getProperties(ctx, path, "org.freedesktop.systemd1.Unit")
} }
// GetAllProperties takes the (unescaped) unit name and returns all of its dbus object properties. // Deprecated: use GetAllPropertiesContext instead.
// Deprecated: use GetAllPropertiesContext instead
func (c *Conn) GetAllProperties(unit string) (map[string]interface{}, error) { func (c *Conn) GetAllProperties(unit string) (map[string]interface{}, error) {
return c.GetAllPropertiesContext(context.Background(), unit) return c.GetAllPropertiesContext(context.Background(), unit)
} }
// GetAllPropertiesContext same as GetAllProperties with context // GetAllPropertiesContext takes the (unescaped) unit name and returns all of
// its dbus object properties.
func (c *Conn) GetAllPropertiesContext(ctx context.Context, unit string) (map[string]interface{}, error) { func (c *Conn) GetAllPropertiesContext(ctx context.Context, unit string) (map[string]interface{}, error) {
path := unitPath(unit) path := unitPath(unit)
return c.getProperties(ctx, path, "") return c.getProperties(ctx, path, "")
@ -323,64 +315,63 @@ func (c *Conn) getProperty(ctx context.Context, unit string, dbusInterface strin
return &Property{Name: propertyName, Value: prop}, nil return &Property{Name: propertyName, Value: prop}, nil
} }
// Deprecated: use GetUnitPropertyContext instead // Deprecated: use GetUnitPropertyContext instead.
func (c *Conn) GetUnitProperty(unit string, propertyName string) (*Property, error) { func (c *Conn) GetUnitProperty(unit string, propertyName string) (*Property, error) {
return c.GetUnitPropertyContext(context.Background(), unit, propertyName) return c.GetUnitPropertyContext(context.Background(), unit, propertyName)
} }
// GetUnitPropertyContext same as GetUnitProperty with context // GetUnitPropertyContext takes an (unescaped) unit name, and a property name,
// and returns the property value.
func (c *Conn) GetUnitPropertyContext(ctx context.Context, unit string, propertyName string) (*Property, error) { func (c *Conn) GetUnitPropertyContext(ctx context.Context, unit string, propertyName string) (*Property, error) {
return c.getProperty(ctx, unit, "org.freedesktop.systemd1.Unit", propertyName) return c.getProperty(ctx, unit, "org.freedesktop.systemd1.Unit", propertyName)
} }
// GetServiceProperty returns property for given service name and property name // Deprecated: use GetServicePropertyContext instead.
// Deprecated: use GetServicePropertyContext instead
func (c *Conn) GetServiceProperty(service string, propertyName string) (*Property, error) { func (c *Conn) GetServiceProperty(service string, propertyName string) (*Property, error) {
return c.GetServicePropertyContext(context.Background(), service, propertyName) return c.GetServicePropertyContext(context.Background(), service, propertyName)
} }
// GetServicePropertyContext same as GetServiceProperty with context // GetServiceProperty returns property for given service name and property name.
func (c *Conn) GetServicePropertyContext(ctx context.Context, service string, propertyName string) (*Property, error) { func (c *Conn) GetServicePropertyContext(ctx context.Context, service string, propertyName string) (*Property, error) {
return c.getProperty(ctx, service, "org.freedesktop.systemd1.Service", propertyName) return c.getProperty(ctx, service, "org.freedesktop.systemd1.Service", propertyName)
} }
// GetUnitTypeProperties returns the extra properties for a unit, specific to the unit type. // Deprecated: use GetUnitTypePropertiesContext instead.
// Valid values for unitType: Service, Socket, Target, Device, Mount, Automount, Snapshot, Timer, Swap, Path, Slice, Scope
// return "dbus.Error: Unknown interface" if the unitType is not the correct type of the unit
// Deprecated: use GetUnitTypePropertiesContext instead
func (c *Conn) GetUnitTypeProperties(unit string, unitType string) (map[string]interface{}, error) { func (c *Conn) GetUnitTypeProperties(unit string, unitType string) (map[string]interface{}, error) {
return c.GetUnitTypePropertiesContext(context.Background(), unit, unitType) return c.GetUnitTypePropertiesContext(context.Background(), unit, unitType)
} }
// GetUnitTypePropertiesContext same as GetUnitTypeProperties with context // GetUnitTypePropertiesContext returns the extra properties for a unit, specific to the unit type.
// Valid values for unitType: Service, Socket, Target, Device, Mount, Automount, Snapshot, Timer, Swap, Path, Slice, Scope.
// Returns "dbus.Error: Unknown interface" error if the unitType is not the correct type of the unit.
func (c *Conn) GetUnitTypePropertiesContext(ctx context.Context, unit string, unitType string) (map[string]interface{}, error) { func (c *Conn) GetUnitTypePropertiesContext(ctx context.Context, unit string, unitType string) (map[string]interface{}, error) {
path := unitPath(unit) path := unitPath(unit)
return c.getProperties(ctx, path, "org.freedesktop.systemd1."+unitType) return c.getProperties(ctx, path, "org.freedesktop.systemd1."+unitType)
} }
// SetUnitProperties() may be used to modify certain unit properties at runtime. // Deprecated: use SetUnitPropertiesContext instead.
func (c *Conn) SetUnitProperties(name string, runtime bool, properties ...Property) error {
return c.SetUnitPropertiesContext(context.Background(), name, runtime, properties...)
}
// SetUnitPropertiesContext may be used to modify certain unit properties at runtime.
// Not all properties may be changed at runtime, but many resource management // Not all properties may be changed at runtime, but many resource management
// settings (primarily those in systemd.cgroup(5)) may. The changes are applied // settings (primarily those in systemd.cgroup(5)) may. The changes are applied
// instantly, and stored on disk for future boots, unless runtime is true, in which // instantly, and stored on disk for future boots, unless runtime is true, in which
// case the settings only apply until the next reboot. name is the name of the unit // case the settings only apply until the next reboot. name is the name of the unit
// to modify. properties are the settings to set, encoded as an array of property // to modify. properties are the settings to set, encoded as an array of property
// name and value pairs. // name and value pairs.
// Deprecated: use SetUnitPropertiesContext instead
func (c *Conn) SetUnitProperties(name string, runtime bool, properties ...Property) error {
return c.SetUnitPropertiesContext(context.Background(), name, runtime, properties...)
}
// SetUnitPropertiesContext same as SetUnitProperties with context
func (c *Conn) SetUnitPropertiesContext(ctx context.Context, name string, runtime bool, properties ...Property) error { func (c *Conn) SetUnitPropertiesContext(ctx context.Context, name string, runtime bool, properties ...Property) error {
return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.SetUnitProperties", 0, name, runtime, properties).Store() return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.SetUnitProperties", 0, name, runtime, properties).Store()
} }
// Deprecated: use GetUnitTypePropertyContext instead // Deprecated: use GetUnitTypePropertyContext instead.
func (c *Conn) GetUnitTypeProperty(unit string, unitType string, propertyName string) (*Property, error) { func (c *Conn) GetUnitTypeProperty(unit string, unitType string, propertyName string) (*Property, error) {
return c.GetUnitTypePropertyContext(context.Background(), unit, unitType, propertyName) return c.GetUnitTypePropertyContext(context.Background(), unit, unitType, propertyName)
} }
// GetUnitTypePropertyContext same as GetUnitTypeProperty with context // GetUnitTypePropertyContext takes a property name, a unit name, and a unit type,
// and returns a property value. For valid values of unitType, see GetUnitTypePropertiesContext.
func (c *Conn) GetUnitTypePropertyContext(ctx context.Context, unit string, unitType string, propertyName string) (*Property, error) { func (c *Conn) GetUnitTypePropertyContext(ctx context.Context, unit string, unitType string, propertyName string) (*Property, error) {
return c.getProperty(ctx, unit, "org.freedesktop.systemd1."+unitType, propertyName) return c.getProperty(ctx, unit, "org.freedesktop.systemd1."+unitType, propertyName)
} }
@ -426,58 +417,55 @@ func (c *Conn) listUnitsInternal(f storeFunc) ([]UnitStatus, error) {
return status, nil return status, nil
} }
// ListUnits returns an array with all currently loaded units. Note that // Deprecated: use ListUnitsContext instead.
// units may be known by multiple names at the same time, and hence there might
// be more unit names loaded than actual units behind them.
// Also note that a unit is only loaded if it is active and/or enabled.
// Units that are both disabled and inactive will thus not be returned.
// Deprecated: use ListUnitsContext instead
func (c *Conn) ListUnits() ([]UnitStatus, error) { func (c *Conn) ListUnits() ([]UnitStatus, error) {
return c.ListUnitsContext(context.Background()) return c.ListUnitsContext(context.Background())
} }
// ListUnitsContext same as ListUnits with context // ListUnitsContext returns an array with all currently loaded units. Note that
// units may be known by multiple names at the same time, and hence there might
// be more unit names loaded than actual units behind them.
// Also note that a unit is only loaded if it is active and/or enabled.
// Units that are both disabled and inactive will thus not be returned.
func (c *Conn) ListUnitsContext(ctx context.Context) ([]UnitStatus, error) { func (c *Conn) ListUnitsContext(ctx context.Context) ([]UnitStatus, error) {
return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnits", 0).Store) return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnits", 0).Store)
} }
// ListUnitsFiltered returns an array with units filtered by state. // Deprecated: use ListUnitsFilteredContext instead.
// It takes a list of units' statuses to filter.
// Deprecated: use ListUnitsFilteredContext instead
func (c *Conn) ListUnitsFiltered(states []string) ([]UnitStatus, error) { func (c *Conn) ListUnitsFiltered(states []string) ([]UnitStatus, error) {
return c.ListUnitsFilteredContext(context.Background(), states) return c.ListUnitsFilteredContext(context.Background(), states)
} }
// ListUnitsFilteredContext same as ListUnitsFiltered with context // ListUnitsFilteredContext returns an array with units filtered by state.
// It takes a list of units' statuses to filter.
func (c *Conn) ListUnitsFilteredContext(ctx context.Context, states []string) ([]UnitStatus, error) { func (c *Conn) ListUnitsFilteredContext(ctx context.Context, states []string) ([]UnitStatus, error) {
return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsFiltered", 0, states).Store) return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsFiltered", 0, states).Store)
} }
// ListUnitsByPatterns returns an array with units. // Deprecated: use ListUnitsByPatternsContext instead.
// It takes a list of units' statuses and names to filter.
// Note that units may be known by multiple names at the same time,
// and hence there might be more unit names loaded than actual units behind them.
// Deprecated: use ListUnitsByPatternsContext instead
func (c *Conn) ListUnitsByPatterns(states []string, patterns []string) ([]UnitStatus, error) { func (c *Conn) ListUnitsByPatterns(states []string, patterns []string) ([]UnitStatus, error) {
return c.ListUnitsByPatternsContext(context.Background(), states, patterns) return c.ListUnitsByPatternsContext(context.Background(), states, patterns)
} }
// ListUnitsByPatternsContext same as ListUnitsByPatterns with context // ListUnitsByPatternsContext returns an array with units.
// It takes a list of units' statuses and names to filter.
// Note that units may be known by multiple names at the same time,
// and hence there might be more unit names loaded than actual units behind them.
func (c *Conn) ListUnitsByPatternsContext(ctx context.Context, states []string, patterns []string) ([]UnitStatus, error) { func (c *Conn) ListUnitsByPatternsContext(ctx context.Context, states []string, patterns []string) ([]UnitStatus, error) {
return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsByPatterns", 0, states, patterns).Store) return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsByPatterns", 0, states, patterns).Store)
} }
// ListUnitsByNames returns an array with units. It takes a list of units' // Deprecated: use ListUnitsByNamesContext instead.
// names and returns an UnitStatus array. Comparing to ListUnitsByPatterns
// method, this method returns statuses even for inactive or non-existing
// units. Input array should contain exact unit names, but not patterns.
// Note: Requires systemd v230 or higher
// Deprecated: use ListUnitsByNamesContext instead
func (c *Conn) ListUnitsByNames(units []string) ([]UnitStatus, error) { func (c *Conn) ListUnitsByNames(units []string) ([]UnitStatus, error) {
return c.ListUnitsByNamesContext(context.Background(), units) return c.ListUnitsByNamesContext(context.Background(), units)
} }
// ListUnitsByNamesContext same as ListUnitsByNames with context // ListUnitsByNamesContext returns an array with units. It takes a list of units'
// names and returns an UnitStatus array. Comparing to ListUnitsByPatternsContext
// method, this method returns statuses even for inactive or non-existing
// units. Input array should contain exact unit names, but not patterns.
//
// Requires systemd v230 or higher.
func (c *Conn) ListUnitsByNamesContext(ctx context.Context, units []string) ([]UnitStatus, error) { func (c *Conn) ListUnitsByNamesContext(ctx context.Context, units []string) ([]UnitStatus, error) {
return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsByNames", 0, units).Store) return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsByNames", 0, units).Store)
} }
@ -513,37 +501,43 @@ func (c *Conn) listUnitFilesInternal(f storeFunc) ([]UnitFile, error) {
return files, nil return files, nil
} }
// ListUnitFiles returns an array of all available units on disk. // Deprecated: use ListUnitFilesContext instead.
// Deprecated: use ListUnitFilesContext instead
func (c *Conn) ListUnitFiles() ([]UnitFile, error) { func (c *Conn) ListUnitFiles() ([]UnitFile, error) {
return c.ListUnitFilesContext(context.Background()) return c.ListUnitFilesContext(context.Background())
} }
// ListUnitFilesContext same as ListUnitFiles with context // ListUnitFiles returns an array of all available units on disk.
func (c *Conn) ListUnitFilesContext(ctx context.Context) ([]UnitFile, error) { func (c *Conn) ListUnitFilesContext(ctx context.Context) ([]UnitFile, error) {
return c.listUnitFilesInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitFiles", 0).Store) return c.listUnitFilesInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitFiles", 0).Store)
} }
// ListUnitFilesByPatterns returns an array of all available units on disk matched the patterns. // Deprecated: use ListUnitFilesByPatternsContext instead.
// Deprecated: use ListUnitFilesByPatternsContext instead
func (c *Conn) ListUnitFilesByPatterns(states []string, patterns []string) ([]UnitFile, error) { func (c *Conn) ListUnitFilesByPatterns(states []string, patterns []string) ([]UnitFile, error) {
return c.ListUnitFilesByPatternsContext(context.Background(), states, patterns) return c.ListUnitFilesByPatternsContext(context.Background(), states, patterns)
} }
// ListUnitFilesByPatternsContext same as ListUnitFilesByPatterns with context // ListUnitFilesByPatternsContext returns an array of all available units on disk matched the patterns.
func (c *Conn) ListUnitFilesByPatternsContext(ctx context.Context, states []string, patterns []string) ([]UnitFile, error) { func (c *Conn) ListUnitFilesByPatternsContext(ctx context.Context, states []string, patterns []string) ([]UnitFile, error) {
return c.listUnitFilesInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitFilesByPatterns", 0, states, patterns).Store) return c.listUnitFilesInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitFilesByPatterns", 0, states, patterns).Store)
} }
type LinkUnitFileChange EnableUnitFileChange type LinkUnitFileChange EnableUnitFileChange
// LinkUnitFiles() links unit files (that are located outside of the // Deprecated: use LinkUnitFilesContext instead.
func (c *Conn) LinkUnitFiles(files []string, runtime bool, force bool) ([]LinkUnitFileChange, error) {
return c.LinkUnitFilesContext(context.Background(), files, runtime, force)
}
// LinkUnitFilesContext links unit files (that are located outside of the
// usual unit search paths) into the unit search path. // usual unit search paths) into the unit search path.
// //
// It takes a list of absolute paths to unit files to link and two // It takes a list of absolute paths to unit files to link and two
// booleans. The first boolean controls whether the unit shall be // booleans.
//
// The first boolean controls whether the unit shall be
// enabled for runtime only (true, /run), or persistently (false, // enabled for runtime only (true, /run), or persistently (false,
// /etc). // /etc).
//
// The second controls whether symlinks pointing to other units shall // The second controls whether symlinks pointing to other units shall
// be replaced if necessary. // be replaced if necessary.
// //
@ -551,12 +545,6 @@ type LinkUnitFileChange EnableUnitFileChange
// structures with three strings: the type of the change (one of symlink // structures with three strings: the type of the change (one of symlink
// or unlink), the file name of the symlink and the destination of the // or unlink), the file name of the symlink and the destination of the
// symlink. // symlink.
// Deprecated: use LinkUnitFilesContext instead
func (c *Conn) LinkUnitFiles(files []string, runtime bool, force bool) ([]LinkUnitFileChange, error) {
return c.LinkUnitFilesContext(context.Background(), files, runtime, force)
}
// LinkUnitFilesContext same as LinkUnitFiles with context
func (c *Conn) LinkUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) ([]LinkUnitFileChange, error) { func (c *Conn) LinkUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) ([]LinkUnitFileChange, error) {
result := make([][]interface{}, 0) result := make([][]interface{}, 0)
err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.LinkUnitFiles", 0, files, runtime, force).Store(&result) err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.LinkUnitFiles", 0, files, runtime, force).Store(&result)
@ -583,8 +571,13 @@ func (c *Conn) LinkUnitFilesContext(ctx context.Context, files []string, runtime
return changes, nil return changes, nil
} }
// EnableUnitFiles() may be used to enable one or more units in the system (by // Deprecated: use EnableUnitFilesContext instead.
// creating symlinks to them in /etc or /run). func (c *Conn) EnableUnitFiles(files []string, runtime bool, force bool) (bool, []EnableUnitFileChange, error) {
return c.EnableUnitFilesContext(context.Background(), files, runtime, force)
}
// EnableUnitFilesContext may be used to enable one or more units in the system
// (by creating symlinks to them in /etc or /run).
// //
// It takes a list of unit files to enable (either just file names or full // It takes a list of unit files to enable (either just file names or full
// absolute paths if the unit files are residing outside the usual unit // absolute paths if the unit files are residing outside the usual unit
@ -599,12 +592,6 @@ func (c *Conn) LinkUnitFilesContext(ctx context.Context, files []string, runtime
// structures with three strings: the type of the change (one of symlink // structures with three strings: the type of the change (one of symlink
// or unlink), the file name of the symlink and the destination of the // or unlink), the file name of the symlink and the destination of the
// symlink. // symlink.
// Deprecated: use EnableUnitFilesContext instead
func (c *Conn) EnableUnitFiles(files []string, runtime bool, force bool) (bool, []EnableUnitFileChange, error) {
return c.EnableUnitFilesContext(context.Background(), files, runtime, force)
}
// EnableUnitFilesContext same as EnableUnitFiles with context
func (c *Conn) EnableUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) (bool, []EnableUnitFileChange, error) { func (c *Conn) EnableUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) (bool, []EnableUnitFileChange, error) {
var carries_install_info bool var carries_install_info bool
@ -639,8 +626,13 @@ type EnableUnitFileChange struct {
Destination string // Destination of the symlink Destination string // Destination of the symlink
} }
// DisableUnitFiles() may be used to disable one or more units in the system (by // Deprecated: use DisableUnitFilesContext instead.
// removing symlinks to them from /etc or /run). func (c *Conn) DisableUnitFiles(files []string, runtime bool) ([]DisableUnitFileChange, error) {
return c.DisableUnitFilesContext(context.Background(), files, runtime)
}
// DisableUnitFilesContext may be used to disable one or more units in the
// system (by removing symlinks to them from /etc or /run).
// //
// It takes a list of unit files to disable (either just file names or full // It takes a list of unit files to disable (either just file names or full
// absolute paths if the unit files are residing outside the usual unit // absolute paths if the unit files are residing outside the usual unit
@ -651,12 +643,6 @@ type EnableUnitFileChange struct {
// consists of structures with three strings: the type of the change (one of // consists of structures with three strings: the type of the change (one of
// symlink or unlink), the file name of the symlink and the destination of the // symlink or unlink), the file name of the symlink and the destination of the
// symlink. // symlink.
// Deprecated: use DisableUnitFilesContext instead
func (c *Conn) DisableUnitFiles(files []string, runtime bool) ([]DisableUnitFileChange, error) {
return c.DisableUnitFilesContext(context.Background(), files, runtime)
}
// DisableUnitFilesContext same as DisableUnitFiles with context
func (c *Conn) DisableUnitFilesContext(ctx context.Context, files []string, runtime bool) ([]DisableUnitFileChange, error) { func (c *Conn) DisableUnitFilesContext(ctx context.Context, files []string, runtime bool) ([]DisableUnitFileChange, error) {
result := make([][]interface{}, 0) result := make([][]interface{}, 0)
err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.DisableUnitFiles", 0, files, runtime).Store(&result) err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.DisableUnitFiles", 0, files, runtime).Store(&result)
@ -689,21 +675,20 @@ type DisableUnitFileChange struct {
Destination string // Destination of the symlink Destination string // Destination of the symlink
} }
// MaskUnitFiles masks one or more units in the system // Deprecated: use MaskUnitFilesContext instead.
//
// It takes three arguments:
// * list of units to mask (either just file names or full
// absolute paths if the unit files are residing outside
// the usual unit search paths)
// * runtime to specify whether the unit was enabled for runtime
// only (true, /run/systemd/..), or persistently (false, /etc/systemd/..)
// * force flag
// Deprecated: use MaskUnitFilesContext instead
func (c *Conn) MaskUnitFiles(files []string, runtime bool, force bool) ([]MaskUnitFileChange, error) { func (c *Conn) MaskUnitFiles(files []string, runtime bool, force bool) ([]MaskUnitFileChange, error) {
return c.MaskUnitFilesContext(context.Background(), files, runtime, force) return c.MaskUnitFilesContext(context.Background(), files, runtime, force)
} }
// MaskUnitFilesContext same as MaskUnitFiles with context // MaskUnitFilesContext masks one or more units in the system.
//
// The files argument contains a list of units to mask (either just file names
// or full absolute paths if the unit files are residing outside the usual unit
// search paths).
//
// The runtime argument is used to specify whether the unit was enabled for
// runtime only (true, /run/systemd/..), or persistently (false,
// /etc/systemd/..).
func (c *Conn) MaskUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) ([]MaskUnitFileChange, error) { func (c *Conn) MaskUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) ([]MaskUnitFileChange, error) {
result := make([][]interface{}, 0) result := make([][]interface{}, 0)
err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.MaskUnitFiles", 0, files, runtime, force).Store(&result) err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.MaskUnitFiles", 0, files, runtime, force).Store(&result)
@ -736,20 +721,18 @@ type MaskUnitFileChange struct {
Destination string // Destination of the symlink Destination string // Destination of the symlink
} }
// UnmaskUnitFiles unmasks one or more units in the system // Deprecated: use UnmaskUnitFilesContext instead.
//
// It takes two arguments:
// * list of unit files to mask (either just file names or full
// absolute paths if the unit files are residing outside
// the usual unit search paths)
// * runtime to specify whether the unit was enabled for runtime
// only (true, /run/systemd/..), or persistently (false, /etc/systemd/..)
// Deprecated: use UnmaskUnitFilesContext instead
func (c *Conn) UnmaskUnitFiles(files []string, runtime bool) ([]UnmaskUnitFileChange, error) { func (c *Conn) UnmaskUnitFiles(files []string, runtime bool) ([]UnmaskUnitFileChange, error) {
return c.UnmaskUnitFilesContext(context.Background(), files, runtime) return c.UnmaskUnitFilesContext(context.Background(), files, runtime)
} }
// UnmaskUnitFilesContext same as UnmaskUnitFiles with context // UnmaskUnitFilesContext unmasks one or more units in the system.
//
// It takes the list of unit files to mask (either just file names or full
// absolute paths if the unit files are residing outside the usual unit search
// paths), and a boolean runtime flag to specify whether the unit was enabled
// for runtime only (true, /run/systemd/..), or persistently (false,
// /etc/systemd/..).
func (c *Conn) UnmaskUnitFilesContext(ctx context.Context, files []string, runtime bool) ([]UnmaskUnitFileChange, error) { func (c *Conn) UnmaskUnitFilesContext(ctx context.Context, files []string, runtime bool) ([]UnmaskUnitFileChange, error) {
result := make([][]interface{}, 0) result := make([][]interface{}, 0)
err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.UnmaskUnitFiles", 0, files, runtime).Store(&result) err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.UnmaskUnitFiles", 0, files, runtime).Store(&result)
@ -782,14 +765,13 @@ type UnmaskUnitFileChange struct {
Destination string // Destination of the symlink Destination string // Destination of the symlink
} }
// Reload instructs systemd to scan for and reload unit files. This is // Deprecated: use ReloadContext instead.
// equivalent to a 'systemctl daemon-reload'.
// Deprecated: use ReloadContext instead
func (c *Conn) Reload() error { func (c *Conn) Reload() error {
return c.ReloadContext(context.Background()) return c.ReloadContext(context.Background())
} }
// ReloadContext same as Reload with context // ReloadContext instructs systemd to scan for and reload unit files. This is
// an equivalent to systemctl daemon-reload.
func (c *Conn) ReloadContext(ctx context.Context) error { func (c *Conn) ReloadContext(ctx context.Context) error {
return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.Reload", 0).Store() return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.Reload", 0).Store()
} }
@ -798,12 +780,12 @@ func unitPath(name string) dbus.ObjectPath {
return dbus.ObjectPath("/org/freedesktop/systemd1/unit/" + PathBusEscape(name)) return dbus.ObjectPath("/org/freedesktop/systemd1/unit/" + PathBusEscape(name))
} }
// unitName returns the unescaped base element of the supplied escaped path // unitName returns the unescaped base element of the supplied escaped path.
func unitName(dpath dbus.ObjectPath) string { func unitName(dpath dbus.ObjectPath) string {
return pathBusUnescape(path.Base(string(dpath))) return pathBusUnescape(path.Base(string(dpath)))
} }
// Currently queued job definition // JobStatus holds a currently queued job definition.
type JobStatus struct { type JobStatus struct {
Id uint32 // The numeric job id Id uint32 // The numeric job id
Unit string // The primary unit name for this job Unit string // The primary unit name for this job
@ -813,13 +795,12 @@ type JobStatus struct {
UnitPath dbus.ObjectPath // The unit object path UnitPath dbus.ObjectPath // The unit object path
} }
// ListJobs returns an array with all currently queued jobs // Deprecated: use ListJobsContext instead.
// Deprecated: use ListJobsContext instead
func (c *Conn) ListJobs() ([]JobStatus, error) { func (c *Conn) ListJobs() ([]JobStatus, error) {
return c.ListJobsContext(context.Background()) return c.ListJobsContext(context.Background())
} }
// ListJobsContext same as ListJobs with context // ListJobsContext returns an array with all currently queued jobs.
func (c *Conn) ListJobsContext(ctx context.Context) ([]JobStatus, error) { func (c *Conn) ListJobsContext(ctx context.Context) ([]JobStatus, error) {
return c.listJobsInternal(ctx) return c.listJobsInternal(ctx)
} }

View File

@ -13,6 +13,7 @@ import (
"strings" "strings"
"sync" "sync"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoimpl" "google.golang.org/protobuf/runtime/protoimpl"
@ -62,14 +63,7 @@ func FileDescriptor(s filePath) fileDescGZIP {
// Find the descriptor in the v2 registry. // Find the descriptor in the v2 registry.
var b []byte var b []byte
if fd, _ := protoregistry.GlobalFiles.FindFileByPath(s); fd != nil { if fd, _ := protoregistry.GlobalFiles.FindFileByPath(s); fd != nil {
if fd, ok := fd.(interface{ ProtoLegacyRawDesc() []byte }); ok { b, _ = Marshal(protodesc.ToFileDescriptorProto(fd))
b = fd.ProtoLegacyRawDesc()
} else {
// TODO: Use protodesc.ToFileDescriptorProto to construct
// a descriptorpb.FileDescriptorProto and marshal it.
// However, doing so causes the proto package to have a dependency
// on descriptorpb, leading to cyclic dependency issues.
}
} }
// Locally cache the raw descriptor form for the file. // Locally cache the raw descriptor form for the file.

View File

@ -19,6 +19,8 @@ const urlPrefix = "type.googleapis.com/"
// AnyMessageName returns the message name contained in an anypb.Any message. // AnyMessageName returns the message name contained in an anypb.Any message.
// Most type assertions should use the Is function instead. // Most type assertions should use the Is function instead.
//
// Deprecated: Call the any.MessageName method instead.
func AnyMessageName(any *anypb.Any) (string, error) { func AnyMessageName(any *anypb.Any) (string, error) {
name, err := anyMessageName(any) name, err := anyMessageName(any)
return string(name), err return string(name), err
@ -38,6 +40,8 @@ func anyMessageName(any *anypb.Any) (protoreflect.FullName, error) {
} }
// MarshalAny marshals the given message m into an anypb.Any message. // MarshalAny marshals the given message m into an anypb.Any message.
//
// Deprecated: Call the anypb.New function instead.
func MarshalAny(m proto.Message) (*anypb.Any, error) { func MarshalAny(m proto.Message) (*anypb.Any, error) {
switch dm := m.(type) { switch dm := m.(type) {
case DynamicAny: case DynamicAny:
@ -58,6 +62,9 @@ func MarshalAny(m proto.Message) (*anypb.Any, error) {
// Empty returns a new message of the type specified in an anypb.Any message. // Empty returns a new message of the type specified in an anypb.Any message.
// It returns protoregistry.NotFound if the corresponding message type could not // It returns protoregistry.NotFound if the corresponding message type could not
// be resolved in the global registry. // be resolved in the global registry.
//
// Deprecated: Use protoregistry.GlobalTypes.FindMessageByName instead
// to resolve the message name and create a new instance of it.
func Empty(any *anypb.Any) (proto.Message, error) { func Empty(any *anypb.Any) (proto.Message, error) {
name, err := anyMessageName(any) name, err := anyMessageName(any)
if err != nil { if err != nil {
@ -76,6 +83,8 @@ func Empty(any *anypb.Any) (proto.Message, error) {
// //
// The target message m may be a *DynamicAny message. If the underlying message // The target message m may be a *DynamicAny message. If the underlying message
// type could not be resolved, then this returns protoregistry.NotFound. // type could not be resolved, then this returns protoregistry.NotFound.
//
// Deprecated: Call the any.UnmarshalTo method instead.
func UnmarshalAny(any *anypb.Any, m proto.Message) error { func UnmarshalAny(any *anypb.Any, m proto.Message) error {
if dm, ok := m.(*DynamicAny); ok { if dm, ok := m.(*DynamicAny); ok {
if dm.Message == nil { if dm.Message == nil {
@ -100,6 +109,8 @@ func UnmarshalAny(any *anypb.Any, m proto.Message) error {
} }
// Is reports whether the Any message contains a message of the specified type. // Is reports whether the Any message contains a message of the specified type.
//
// Deprecated: Call the any.MessageIs method instead.
func Is(any *anypb.Any, m proto.Message) bool { func Is(any *anypb.Any, m proto.Message) bool {
if any == nil || m == nil { if any == nil || m == nil {
return false return false
@ -119,6 +130,9 @@ func Is(any *anypb.Any, m proto.Message) bool {
// var x ptypes.DynamicAny // var x ptypes.DynamicAny
// if err := ptypes.UnmarshalAny(a, &x); err != nil { ... } // if err := ptypes.UnmarshalAny(a, &x); err != nil { ... }
// fmt.Printf("unmarshaled message: %v", x.Message) // fmt.Printf("unmarshaled message: %v", x.Message)
//
// Deprecated: Use the any.UnmarshalNew method instead to unmarshal
// the any message contents into a new instance of the underlying message.
type DynamicAny struct{ proto.Message } type DynamicAny struct{ proto.Message }
func (m DynamicAny) String() string { func (m DynamicAny) String() string {

View File

@ -3,4 +3,8 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package ptypes provides functionality for interacting with well-known types. // Package ptypes provides functionality for interacting with well-known types.
//
// Deprecated: Well-known types have specialized functionality directly
// injected into the generated packages for each message type.
// See the deprecation notice for each function for the suggested alternative.
package ptypes package ptypes

View File

@ -21,6 +21,8 @@ const (
// Duration converts a durationpb.Duration to a time.Duration. // Duration converts a durationpb.Duration to a time.Duration.
// Duration returns an error if dur is invalid or overflows a time.Duration. // Duration returns an error if dur is invalid or overflows a time.Duration.
//
// Deprecated: Call the dur.AsDuration and dur.CheckValid methods instead.
func Duration(dur *durationpb.Duration) (time.Duration, error) { func Duration(dur *durationpb.Duration) (time.Duration, error) {
if err := validateDuration(dur); err != nil { if err := validateDuration(dur); err != nil {
return 0, err return 0, err
@ -39,6 +41,8 @@ func Duration(dur *durationpb.Duration) (time.Duration, error) {
} }
// DurationProto converts a time.Duration to a durationpb.Duration. // DurationProto converts a time.Duration to a durationpb.Duration.
//
// Deprecated: Call the durationpb.New function instead.
func DurationProto(d time.Duration) *durationpb.Duration { func DurationProto(d time.Duration) *durationpb.Duration {
nanos := d.Nanoseconds() nanos := d.Nanoseconds()
secs := nanos / 1e9 secs := nanos / 1e9

View File

@ -33,6 +33,8 @@ const (
// //
// A nil Timestamp returns an error. The first return value in that case is // A nil Timestamp returns an error. The first return value in that case is
// undefined. // undefined.
//
// Deprecated: Call the ts.AsTime and ts.CheckValid methods instead.
func Timestamp(ts *timestamppb.Timestamp) (time.Time, error) { func Timestamp(ts *timestamppb.Timestamp) (time.Time, error) {
// Don't return the zero value on error, because corresponds to a valid // Don't return the zero value on error, because corresponds to a valid
// timestamp. Instead return whatever time.Unix gives us. // timestamp. Instead return whatever time.Unix gives us.
@ -46,6 +48,8 @@ func Timestamp(ts *timestamppb.Timestamp) (time.Time, error) {
} }
// TimestampNow returns a google.protobuf.Timestamp for the current time. // TimestampNow returns a google.protobuf.Timestamp for the current time.
//
// Deprecated: Call the timestamppb.Now function instead.
func TimestampNow() *timestamppb.Timestamp { func TimestampNow() *timestamppb.Timestamp {
ts, err := TimestampProto(time.Now()) ts, err := TimestampProto(time.Now())
if err != nil { if err != nil {
@ -56,6 +60,8 @@ func TimestampNow() *timestamppb.Timestamp {
// TimestampProto converts the time.Time to a google.protobuf.Timestamp proto. // TimestampProto converts the time.Time to a google.protobuf.Timestamp proto.
// It returns an error if the resulting Timestamp is invalid. // It returns an error if the resulting Timestamp is invalid.
//
// Deprecated: Call the timestamppb.New function instead.
func TimestampProto(t time.Time) (*timestamppb.Timestamp, error) { func TimestampProto(t time.Time) (*timestamppb.Timestamp, error) {
ts := &timestamppb.Timestamp{ ts := &timestamppb.Timestamp{
Seconds: t.Unix(), Seconds: t.Unix(),
@ -69,6 +75,9 @@ func TimestampProto(t time.Time) (*timestamppb.Timestamp, error) {
// TimestampString returns the RFC 3339 string for valid Timestamps. // TimestampString returns the RFC 3339 string for valid Timestamps.
// For invalid Timestamps, it returns an error message in parentheses. // For invalid Timestamps, it returns an error message in parentheses.
//
// Deprecated: Call the ts.AsTime method instead,
// followed by a call to the Format method on the time.Time value.
func TimestampString(ts *timestamppb.Timestamp) string { func TimestampString(ts *timestamppb.Timestamp) string {
t, err := Timestamp(ts) t, err := Timestamp(ts)
if err != nil { if err != nil {

View File

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"golang.org/x/xerrors"
) )
func equateAlways(_, _ interface{}) bool { return true } func equateAlways(_, _ interface{}) bool { return true }
@ -147,10 +146,3 @@ func areConcreteErrors(x, y interface{}) bool {
_, ok2 := y.(error) _, ok2 := y.(error)
return ok1 && ok2 return ok1 && ok2
} }
func compareErrors(x, y interface{}) bool {
xe := x.(error)
ye := y.(error)
// TODO(≥go1.13): Use standard definition of errors.Is.
return xerrors.Is(xe, ye) || xerrors.Is(ye, xe)
}

View File

@ -0,0 +1,15 @@
// Copyright 2021, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.13
package cmpopts
import "errors"
func compareErrors(x, y interface{}) bool {
xe := x.(error)
ye := y.(error)
return errors.Is(xe, ye) || errors.Is(ye, xe)
}

View File

@ -0,0 +1,18 @@
// Copyright 2021, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.13
// TODO(≥go1.13): For support on <go1.13, we use the xerrors package.
// Drop this file when we no longer support older Go versions.
package cmpopts
import "golang.org/x/xerrors"
func compareErrors(x, y interface{}) bool {
xe := x.(error)
ye := y.(error)
return xerrors.Is(xe, ye) || xerrors.Is(ye, xe)
}

View File

@ -79,7 +79,7 @@ func (opts formatOptions) verbosity() uint {
} }
} }
const maxVerbosityPreset = 3 const maxVerbosityPreset = 6
// verbosityPreset modifies the verbosity settings given an index // verbosityPreset modifies the verbosity settings given an index
// between 0 and maxVerbosityPreset, inclusive. // between 0 and maxVerbosityPreset, inclusive.
@ -100,7 +100,7 @@ func verbosityPreset(opts formatOptions, i int) formatOptions {
func (opts formatOptions) FormatDiff(v *valueNode, ptrs *pointerReferences) (out textNode) { func (opts formatOptions) FormatDiff(v *valueNode, ptrs *pointerReferences) (out textNode) {
if opts.DiffMode == diffIdentical { if opts.DiffMode == diffIdentical {
opts = opts.WithVerbosity(1) opts = opts.WithVerbosity(1)
} else { } else if opts.verbosity() < 3 {
opts = opts.WithVerbosity(3) opts = opts.WithVerbosity(3)
} }

View File

@ -26,8 +26,6 @@ func (opts formatOptions) CanFormatDiffSlice(v *valueNode) bool {
return false // No differences detected return false // No differences detected
case !v.ValueX.IsValid() || !v.ValueY.IsValid(): case !v.ValueX.IsValid() || !v.ValueY.IsValid():
return false // Both values must be valid return false // Both values must be valid
case v.Type.Kind() == reflect.Slice && (v.ValueX.Len() == 0 || v.ValueY.Len() == 0):
return false // Both slice values have to be non-empty
case v.NumIgnored > 0: case v.NumIgnored > 0:
return false // Some ignore option was used return false // Some ignore option was used
case v.NumTransformed > 0: case v.NumTransformed > 0:
@ -45,7 +43,16 @@ func (opts formatOptions) CanFormatDiffSlice(v *valueNode) bool {
return false return false
} }
switch t := v.Type; t.Kind() { // Check whether this is an interface with the same concrete types.
t := v.Type
vx, vy := v.ValueX, v.ValueY
if t.Kind() == reflect.Interface && !vx.IsNil() && !vy.IsNil() && vx.Elem().Type() == vy.Elem().Type() {
vx, vy = vx.Elem(), vy.Elem()
t = vx.Type()
}
// Check whether we provide specialized diffing for this type.
switch t.Kind() {
case reflect.String: case reflect.String:
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
// Only slices of primitive types have specialized handling. // Only slices of primitive types have specialized handling.
@ -57,6 +64,11 @@ func (opts formatOptions) CanFormatDiffSlice(v *valueNode) bool {
return false return false
} }
// Both slice values have to be non-empty.
if t.Kind() == reflect.Slice && (vx.Len() == 0 || vy.Len() == 0) {
return false
}
// If a sufficient number of elements already differ, // If a sufficient number of elements already differ,
// use specialized formatting even if length requirement is not met. // use specialized formatting even if length requirement is not met.
if v.NumDiff > v.NumSame { if v.NumDiff > v.NumSame {
@ -68,7 +80,7 @@ func (opts formatOptions) CanFormatDiffSlice(v *valueNode) bool {
// Use specialized string diffing for longer slices or strings. // Use specialized string diffing for longer slices or strings.
const minLength = 64 const minLength = 64
return v.ValueX.Len() >= minLength && v.ValueY.Len() >= minLength return vx.Len() >= minLength && vy.Len() >= minLength
} }
// FormatDiffSlice prints a diff for the slices (or strings) represented by v. // FormatDiffSlice prints a diff for the slices (or strings) represented by v.
@ -77,6 +89,11 @@ func (opts formatOptions) CanFormatDiffSlice(v *valueNode) bool {
func (opts formatOptions) FormatDiffSlice(v *valueNode) textNode { func (opts formatOptions) FormatDiffSlice(v *valueNode) textNode {
assert(opts.DiffMode == diffUnknown) assert(opts.DiffMode == diffUnknown)
t, vx, vy := v.Type, v.ValueX, v.ValueY t, vx, vy := v.Type, v.ValueX, v.ValueY
if t.Kind() == reflect.Interface {
vx, vy = vx.Elem(), vy.Elem()
t = vx.Type()
opts = opts.WithTypeMode(emitType)
}
// Auto-detect the type of the data. // Auto-detect the type of the data.
var isLinedText, isText, isBinary bool var isLinedText, isText, isBinary bool

View File

@ -11,19 +11,17 @@ import (
) )
const ( const (
minId = 0 minID = 0
maxId = 1<<31 - 1 //for 32-bit systems compatibility maxID = 1<<31 - 1 // for 32-bit systems compatibility
) )
var ( var (
// The current operating system does not provide the required data for user lookups. // ErrNoPasswdEntries is returned if no matching entries were found in /etc/group.
ErrUnsupported = errors.New("user lookup: operating system does not provide passwd-formatted data")
// No matching entries found in file.
ErrNoPasswdEntries = errors.New("no matching entries in passwd file") ErrNoPasswdEntries = errors.New("no matching entries in passwd file")
// ErrNoGroupEntries is returned if no matching entries were found in /etc/passwd.
ErrNoGroupEntries = errors.New("no matching entries in group file") ErrNoGroupEntries = errors.New("no matching entries in group file")
// ErrRange is returned if a UID or GID is outside of the valid range.
ErrRange = fmt.Errorf("uids and gids must be in range %d-%d", minId, maxId) ErrRange = fmt.Errorf("uids and gids must be in range %d-%d", minID, maxID)
) )
type User struct { type User struct {
@ -328,7 +326,7 @@ func GetExecUser(userSpec string, defaults *ExecUser, passwd, group io.Reader) (
user.Uid = uidArg user.Uid = uidArg
// Must be inside valid uid range. // Must be inside valid uid range.
if user.Uid < minId || user.Uid > maxId { if user.Uid < minID || user.Uid > maxID {
return nil, ErrRange return nil, ErrRange
} }
@ -377,7 +375,7 @@ func GetExecUser(userSpec string, defaults *ExecUser, passwd, group io.Reader) (
user.Gid = gidArg user.Gid = gidArg
// Must be inside valid gid range. // Must be inside valid gid range.
if user.Gid < minId || user.Gid > maxId { if user.Gid < minID || user.Gid > maxID {
return nil, ErrRange return nil, ErrRange
} }
@ -401,7 +399,7 @@ func GetExecUser(userSpec string, defaults *ExecUser, passwd, group io.Reader) (
// or the given group data is nil, the id will be returned as-is // or the given group data is nil, the id will be returned as-is
// provided it is in the legal range. // provided it is in the legal range.
func GetAdditionalGroups(additionalGroups []string, group io.Reader) ([]int, error) { func GetAdditionalGroups(additionalGroups []string, group io.Reader) ([]int, error) {
var groups = []Group{} groups := []Group{}
if group != nil { if group != nil {
var err error var err error
groups, err = ParseGroupFilter(group, func(g Group) bool { groups, err = ParseGroupFilter(group, func(g Group) bool {
@ -439,7 +437,7 @@ func GetAdditionalGroups(additionalGroups []string, group io.Reader) ([]int, err
return nil, fmt.Errorf("Unable to find group %s", ag) return nil, fmt.Errorf("Unable to find group %s", ag)
} }
// Ensure gid is inside gid range. // Ensure gid is inside gid range.
if gid < minId || gid > maxId { if gid < minID || gid > maxID {
return nil, ErrRange return nil, ErrRange
} }
gidMap[int(gid)] = struct{}{} gidMap[int(gid)] = struct{}{}

View File

@ -4,14 +4,12 @@ git:
depth: 1 depth: 1
env: env:
- GO111MODULE=on - GO111MODULE=on
go: [1.13.x, 1.14.x] go: 1.15.x
os: [linux, osx] os: linux
install: install:
- ./travis/install.sh - ./travis/install.sh
script: script:
- ./travis/cross_build.sh - cd ci
- ./travis/lint.sh - go run mage.go -v -w ../ crossBuild
- export GOMAXPROCS=4 - go run mage.go -v -w ../ lint
- export GORACE=halt_on_error=1 - go run mage.go -v -w ../ test
- go test -race -v ./...
- if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then go test -race -v -tags appengine ./... ; fi

View File

@ -1,3 +1,39 @@
# 1.8.1
Code quality:
* move magefile in its own subdir/submodule to remove magefile dependency on logrus consumer
* improve timestamp format documentation
Fixes:
* fix race condition on logger hooks
# 1.8.0
Correct versioning number replacing v1.7.1.
# 1.7.1
Beware this release has introduced a new public API and its semver is therefore incorrect.
Code quality:
* use go 1.15 in travis
* use magefile as task runner
Fixes:
* small fixes about new go 1.13 error formatting system
* Fix for long time race condiction with mutating data hooks
Features:
* build support for zos
# 1.7.0
Fixes:
* the dependency toward a windows terminal library has been removed
Features:
* a new buffer pool management API has been added
* a set of `<LogLevel>Fn()` functions have been added
# 1.6.0 # 1.6.0
Fixes: Fixes:
* end of line cleanup * end of line cleanup

View File

@ -402,7 +402,7 @@ func (f *MyJSONFormatter) Format(entry *Entry) ([]byte, error) {
// source of the official loggers. // source of the official loggers.
serialized, err := json.Marshal(entry.Data) serialized, err := json.Marshal(entry.Data)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err) return nil, fmt.Errorf("Failed to marshal fields to JSON, %w", err)
} }
return append(serialized, '\n'), nil return append(serialized, '\n'), nil
} }

View File

@ -78,6 +78,14 @@ func NewEntry(logger *Logger) *Entry {
} }
} }
func (entry *Entry) Dup() *Entry {
data := make(Fields, len(entry.Data))
for k, v := range entry.Data {
data[k] = v
}
return &Entry{Logger: entry.Logger, Data: data, Time: entry.Time, Context: entry.Context, err: entry.err}
}
// Returns the bytes representation of this entry from the formatter. // Returns the bytes representation of this entry from the formatter.
func (entry *Entry) Bytes() ([]byte, error) { func (entry *Entry) Bytes() ([]byte, error) {
return entry.Logger.Formatter.Format(entry) return entry.Logger.Formatter.Format(entry)
@ -123,11 +131,9 @@ func (entry *Entry) WithFields(fields Fields) *Entry {
for k, v := range fields { for k, v := range fields {
isErrField := false isErrField := false
if t := reflect.TypeOf(v); t != nil { if t := reflect.TypeOf(v); t != nil {
switch t.Kind() { switch {
case reflect.Func: case t.Kind() == reflect.Func, t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Func:
isErrField = true isErrField = true
case reflect.Ptr:
isErrField = t.Elem().Kind() == reflect.Func
} }
} }
if isErrField { if isErrField {
@ -212,68 +218,72 @@ func (entry Entry) HasCaller() (has bool) {
entry.Caller != nil entry.Caller != nil
} }
// This function is not declared with a pointer value because otherwise func (entry *Entry) log(level Level, msg string) {
// race conditions will occur when using multiple goroutines
func (entry Entry) log(level Level, msg string) {
var buffer *bytes.Buffer var buffer *bytes.Buffer
// Default to now, but allow users to override if they want. newEntry := entry.Dup()
//
// We don't have to worry about polluting future calls to Entry#log() if newEntry.Time.IsZero() {
// with this assignment because this function is declared with a newEntry.Time = time.Now()
// non-pointer receiver.
if entry.Time.IsZero() {
entry.Time = time.Now()
} }
entry.Level = level newEntry.Level = level
entry.Message = msg newEntry.Message = msg
entry.Logger.mu.Lock()
if entry.Logger.ReportCaller {
entry.Caller = getCaller()
}
entry.Logger.mu.Unlock()
entry.fireHooks() newEntry.Logger.mu.Lock()
reportCaller := newEntry.Logger.ReportCaller
newEntry.Logger.mu.Unlock()
if reportCaller {
newEntry.Caller = getCaller()
}
newEntry.fireHooks()
buffer = getBuffer() buffer = getBuffer()
defer func() { defer func() {
entry.Buffer = nil newEntry.Buffer = nil
putBuffer(buffer) putBuffer(buffer)
}() }()
buffer.Reset() buffer.Reset()
entry.Buffer = buffer newEntry.Buffer = buffer
entry.write() newEntry.write()
entry.Buffer = nil newEntry.Buffer = nil
// To avoid Entry#log() returning a value that only would make sense for // To avoid Entry#log() returning a value that only would make sense for
// panic() to use in Entry#Panic(), we avoid the allocation by checking // panic() to use in Entry#Panic(), we avoid the allocation by checking
// directly here. // directly here.
if level <= PanicLevel { if level <= PanicLevel {
panic(&entry) panic(newEntry)
} }
} }
func (entry *Entry) fireHooks() { func (entry *Entry) fireHooks() {
var tmpHooks LevelHooks
entry.Logger.mu.Lock() entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock() tmpHooks = make(LevelHooks, len(entry.Logger.Hooks))
err := entry.Logger.Hooks.Fire(entry.Level, entry) for k, v := range entry.Logger.Hooks {
tmpHooks[k] = v
}
entry.Logger.mu.Unlock()
err := tmpHooks.Fire(entry.Level, entry)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err) fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err)
} }
} }
func (entry *Entry) write() { func (entry *Entry) write() {
entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock()
serialized, err := entry.Logger.Formatter.Format(entry) serialized, err := entry.Logger.Formatter.Format(entry)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err) fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err)
return return
} }
if _, err = entry.Logger.Out.Write(serialized); err != nil { entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock()
if _, err := entry.Logger.Out.Write(serialized); err != nil {
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err) fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
} }
} }
@ -319,7 +329,6 @@ func (entry *Entry) Fatal(args ...interface{}) {
func (entry *Entry) Panic(args ...interface{}) { func (entry *Entry) Panic(args ...interface{}) {
entry.Log(PanicLevel, args...) entry.Log(PanicLevel, args...)
panic(fmt.Sprint(args...))
} }
// Entry Printf family functions // Entry Printf family functions

View File

@ -4,7 +4,5 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@ -23,6 +23,9 @@ func (f FieldMap) resolve(key fieldKey) string {
// JSONFormatter formats logs into parsable json // JSONFormatter formats logs into parsable json
type JSONFormatter struct { type JSONFormatter struct {
// TimestampFormat sets the format used for marshaling timestamps. // TimestampFormat sets the format used for marshaling timestamps.
// The format to use is the same than for time.Format or time.Parse from the standard
// library.
// The standard Library already provides a set of predefined format.
TimestampFormat string TimestampFormat string
// DisableTimestamp allows disabling automatic timestamps in output // DisableTimestamp allows disabling automatic timestamps in output
@ -118,7 +121,7 @@ func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) {
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
} }
if err := encoder.Encode(data); err != nil { if err := encoder.Encode(data); err != nil {
return nil, fmt.Errorf("failed to marshal fields to JSON, %v", err) return nil, fmt.Errorf("failed to marshal fields to JSON, %w", err)
} }
return b.Bytes(), nil return b.Bytes(), nil

View File

@ -12,7 +12,7 @@ import (
// LogFunction For big messages, it can be more efficient to pass a function // LogFunction For big messages, it can be more efficient to pass a function
// and only call it if the log level is actually enables rather than // and only call it if the log level is actually enables rather than
// generating the log message and then checking if the level is enabled // generating the log message and then checking if the level is enabled
type LogFunction func()[]interface{} type LogFunction func() []interface{}
type Logger struct { type Logger struct {
// The logs are `io.Copy`'d to this in a mutex. It's common to set this to a // The logs are `io.Copy`'d to this in a mutex. It's common to set this to a

View File

@ -1,4 +1,4 @@
// +build linux aix // +build linux aix zos
// +build !js // +build !js
package logrus package logrus

View File

@ -53,7 +53,10 @@ type TextFormatter struct {
// the time passed since beginning of execution. // the time passed since beginning of execution.
FullTimestamp bool FullTimestamp bool
// TimestampFormat to use for display when a full timestamp is printed // TimestampFormat to use for display when a full timestamp is printed.
// The format to use is the same than for time.Format or time.Parse from the standard
// library.
// The standard Library already provides a set of predefined format.
TimestampFormat string TimestampFormat string
// The fields are sorted by default for a consistent output. For applications // The fields are sorted by default for a consistent output. For applications
@ -235,6 +238,8 @@ func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []strin
levelColor = yellow levelColor = yellow
case ErrorLevel, FatalLevel, PanicLevel: case ErrorLevel, FatalLevel, PanicLevel:
levelColor = red levelColor = red
case InfoLevel:
levelColor = blue
default: default:
levelColor = blue levelColor = blue
} }

View File

@ -6,7 +6,6 @@ package prototext
import ( import (
"fmt" "fmt"
"strings"
"unicode/utf8" "unicode/utf8"
"google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/encoding/messageset"
@ -23,6 +22,7 @@ import (
) )
// Unmarshal reads the given []byte into the given proto.Message. // Unmarshal reads the given []byte into the given proto.Message.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func Unmarshal(b []byte, m proto.Message) error { func Unmarshal(b []byte, m proto.Message) error {
return UnmarshalOptions{}.Unmarshal(b, m) return UnmarshalOptions{}.Unmarshal(b, m)
} }
@ -51,8 +51,9 @@ type UnmarshalOptions struct {
} }
} }
// Unmarshal reads the given []byte and populates the given proto.Message using options in // Unmarshal reads the given []byte and populates the given proto.Message
// UnmarshalOptions object. // using options in the UnmarshalOptions object.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error { func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
return o.unmarshal(b, m) return o.unmarshal(b, m)
} }
@ -158,21 +159,11 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error {
switch tok.NameKind() { switch tok.NameKind() {
case text.IdentName: case text.IdentName:
name = pref.Name(tok.IdentName()) name = pref.Name(tok.IdentName())
fd = fieldDescs.ByName(name) fd = fieldDescs.ByTextName(string(name))
if fd == nil {
// The proto name of a group field is in all lowercase,
// while the textproto field name is the group message name.
gd := fieldDescs.ByName(pref.Name(strings.ToLower(string(name))))
if gd != nil && gd.Kind() == pref.GroupKind && gd.Message().Name() == name {
fd = gd
}
} else if fd.Kind() == pref.GroupKind && fd.Message().Name() != name {
fd = nil // reset since field name is actually the message name
}
case text.TypeName: case text.TypeName:
// Handle extensions only. This code path is not for Any. // Handle extensions only. This code path is not for Any.
xt, xtErr = d.findExtension(pref.FullName(tok.TypeName())) xt, xtErr = d.opts.Resolver.FindExtensionByName(pref.FullName(tok.TypeName()))
case text.FieldNumber: case text.FieldNumber:
isFieldNumberName = true isFieldNumberName = true
@ -269,15 +260,6 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error {
return nil return nil
} }
// findExtension returns protoreflect.ExtensionType from the Resolver if found.
func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
xt, err := d.opts.Resolver.FindExtensionByName(xtName)
if err == nil {
return xt, nil
}
return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
}
// unmarshalSingular unmarshals a non-repeated field value specified by the // unmarshalSingular unmarshals a non-repeated field value specified by the
// given FieldDescriptor. // given FieldDescriptor.
func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, m pref.Message) error { func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, m pref.Message) error {

View File

@ -6,7 +6,6 @@ package prototext
import ( import (
"fmt" "fmt"
"sort"
"strconv" "strconv"
"unicode/utf8" "unicode/utf8"
@ -16,10 +15,11 @@ import (
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/genid" "google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/mapsort" "google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/strs" "google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/reflect/protoregistry"
) )
@ -169,35 +169,15 @@ func (e encoder) marshalMessage(m pref.Message, inclDelims bool) error {
// If unable to expand, continue on to marshal Any as a regular message. // If unable to expand, continue on to marshal Any as a regular message.
} }
// Marshal known fields. // Marshal fields.
fieldDescs := messageDesc.Fields() var err error
size := fieldDescs.Len() order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
for i := 0; i < size; { if err = e.marshalField(fd.TextName(), v, fd); err != nil {
fd := fieldDescs.Get(i) return false
if od := fd.ContainingOneof(); od != nil {
fd = m.WhichOneof(od)
i += od.Fields().Len()
} else {
i++
} }
return true
if fd == nil || !m.Has(fd) { })
continue if err != nil {
}
name := fd.Name()
// Use type name for group field name.
if fd.Kind() == pref.GroupKind {
name = fd.Message().Name()
}
val := m.Get(fd)
if err := e.marshalField(string(name), val, fd); err != nil {
return err
}
}
// Marshal extensions.
if err := e.marshalExtensions(m); err != nil {
return err return err
} }
@ -290,7 +270,7 @@ func (e encoder) marshalList(name string, list pref.List, fd pref.FieldDescripto
// marshalMap marshals the given protoreflect.Map as multiple name-value fields. // marshalMap marshals the given protoreflect.Map as multiple name-value fields.
func (e encoder) marshalMap(name string, mmap pref.Map, fd pref.FieldDescriptor) error { func (e encoder) marshalMap(name string, mmap pref.Map, fd pref.FieldDescriptor) error {
var err error var err error
mapsort.Range(mmap, fd.MapKey().Kind(), func(key pref.MapKey, val pref.Value) bool { order.RangeEntries(mmap, order.GenericKeyOrder, func(key pref.MapKey, val pref.Value) bool {
e.WriteName(name) e.WriteName(name)
e.StartMessage() e.StartMessage()
defer e.EndMessage() defer e.EndMessage()
@ -311,48 +291,6 @@ func (e encoder) marshalMap(name string, mmap pref.Map, fd pref.FieldDescriptor)
return err return err
} }
// marshalExtensions marshals extension fields.
func (e encoder) marshalExtensions(m pref.Message) error {
type entry struct {
key string
value pref.Value
desc pref.FieldDescriptor
}
// Get a sorted list based on field key first.
var entries []entry
m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
if !fd.IsExtension() {
return true
}
// For MessageSet extensions, the name used is the parent message.
name := fd.FullName()
if messageset.IsMessageSetExtension(fd) {
name = name.Parent()
}
entries = append(entries, entry{
key: string(name),
value: v,
desc: fd,
})
return true
})
// Sort extensions lexicographically.
sort.Slice(entries, func(i, j int) bool {
return entries[i].key < entries[j].key
})
// Write out sorted list.
for _, entry := range entries {
// Extension field name is the proto field name enclosed in [].
name := "[" + entry.key + "]"
if err := e.marshalField(name, entry.value, entry.desc); err != nil {
return err
}
}
return nil
}
// marshalUnknown parses the given []byte and marshals fields out. // marshalUnknown parses the given []byte and marshals fields out.
// This function assumes proper encoding in the given []byte. // This function assumes proper encoding in the given []byte.
func (e encoder) marshalUnknown(b []byte) { func (e encoder) marshalUnknown(b []byte) {

View File

@ -42,6 +42,8 @@ func formatListOpt(vs list, isRoot, allowMulti bool) string {
name = "FileImports" name = "FileImports"
case pref.Descriptor: case pref.Descriptor:
name = reflect.ValueOf(vs).MethodByName("Get").Type().Out(0).Name() + "s" name = reflect.ValueOf(vs).MethodByName("Get").Type().Out(0).Name() + "s"
default:
name = reflect.ValueOf(vs).Elem().Type().Name()
} }
start, end = name+"{", "}" start, end = name+"{", "}"
} }

View File

@ -26,6 +26,14 @@ func Bool() bool {
return randSeed%2 == 1 return randSeed%2 == 1
} }
// Intn returns a deterministically random integer between 0 and n-1, inclusive.
func Intn(n int) int {
if n <= 0 {
panic("must be positive")
}
return int(randSeed % uint64(n))
}
// randSeed is a best-effort at an approximate hash of the Go binary. // randSeed is a best-effort at an approximate hash of the Go binary.
var randSeed = binaryHash() var randSeed = binaryHash()

View File

@ -11,10 +11,9 @@ import (
"google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
pref "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
) )
// The MessageSet wire format is equivalent to a message defiend as follows, // The MessageSet wire format is equivalent to a message defined as follows,
// where each Item defines an extension field with a field number of 'type_id' // where each Item defines an extension field with a field number of 'type_id'
// and content of 'message'. MessageSet extensions must be non-repeated message // and content of 'message'. MessageSet extensions must be non-repeated message
// fields. // fields.
@ -48,33 +47,17 @@ func IsMessageSet(md pref.MessageDescriptor) bool {
return ok && xmd.IsMessageSet() return ok && xmd.IsMessageSet()
} }
// IsMessageSetExtension reports this field extends a MessageSet. // IsMessageSetExtension reports this field properly extends a MessageSet.
func IsMessageSetExtension(fd pref.FieldDescriptor) bool { func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != ExtensionName { switch {
case fd.Name() != ExtensionName:
return false
case !IsMessageSet(fd.ContainingMessage()):
return false
case fd.FullName().Parent() != fd.Message().FullName():
return false return false
} }
if fd.FullName().Parent() != fd.Message().FullName() { return true
return false
}
return IsMessageSet(fd.ContainingMessage())
}
// FindMessageSetExtension locates a MessageSet extension field by name.
// In text and JSON formats, the extension name used is the message itself.
// The extension field name is derived by appending ExtensionName.
func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
name := s.Append(ExtensionName)
xt, err := r.FindExtensionByName(name)
if err != nil {
if err == preg.NotFound {
return nil, err
}
return nil, errors.Wrap(err, "%q", name)
}
if !IsMessageSetExtension(xt.TypeDescriptor()) {
return nil, preg.NotFound
}
return xt, nil
} }
// SizeField returns the size of a MessageSet item field containing an extension // SizeField returns the size of a MessageSet item field containing an extension

View File

@ -104,7 +104,7 @@ func Unmarshal(tag string, goType reflect.Type, evs pref.EnumValueDescriptors) p
case strings.HasPrefix(s, "json="): case strings.HasPrefix(s, "json="):
jsonName := s[len("json="):] jsonName := s[len("json="):]
if jsonName != strs.JSONCamelCase(string(f.L0.FullName.Name())) { if jsonName != strs.JSONCamelCase(string(f.L0.FullName.Name())) {
f.L1.JSONName.Init(jsonName) f.L1.StringName.InitJSON(jsonName)
} }
case s == "packed": case s == "packed":
f.L1.HasPacked = true f.L1.HasPacked = true

View File

@ -32,7 +32,6 @@ type Encoder struct {
encoderState encoderState
indent string indent string
newline string // set to "\n" if len(indent) > 0
delims [2]byte delims [2]byte
outputASCII bool outputASCII bool
} }
@ -61,7 +60,6 @@ func NewEncoder(indent string, delims [2]byte, outputASCII bool) (*Encoder, erro
return nil, errors.New("indent may only be composed of space and tab characters") return nil, errors.New("indent may only be composed of space and tab characters")
} }
e.indent = indent e.indent = indent
e.newline = "\n"
} }
switch delims { switch delims {
case [2]byte{0, 0}: case [2]byte{0, 0}:
@ -126,7 +124,7 @@ func appendString(out []byte, in string, outputASCII bool) []byte {
// are used to represent both the proto string and bytes type. // are used to represent both the proto string and bytes type.
r = rune(in[0]) r = rune(in[0])
fallthrough fallthrough
case r < ' ' || r == '"' || r == '\\': case r < ' ' || r == '"' || r == '\\' || r == 0x7f:
out = append(out, '\\') out = append(out, '\\')
switch r { switch r {
case '"', '\\': case '"', '\\':
@ -143,7 +141,7 @@ func appendString(out []byte, in string, outputASCII bool) []byte {
out = strconv.AppendUint(out, uint64(r), 16) out = strconv.AppendUint(out, uint64(r), 16)
} }
in = in[n:] in = in[n:]
case outputASCII && r >= utf8.RuneSelf: case r >= utf8.RuneSelf && (outputASCII || r <= 0x009f):
out = append(out, '\\') out = append(out, '\\')
if r <= math.MaxUint16 { if r <= math.MaxUint16 {
out = append(out, 'u') out = append(out, 'u')
@ -168,7 +166,7 @@ func appendString(out []byte, in string, outputASCII bool) []byte {
// escaping. If no characters need escaping, this returns the input length. // escaping. If no characters need escaping, this returns the input length.
func indexNeedEscapeInString(s string) int { func indexNeedEscapeInString(s string) int {
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
if c := s[i]; c < ' ' || c == '"' || c == '\'' || c == '\\' || c >= utf8.RuneSelf { if c := s[i]; c < ' ' || c == '"' || c == '\'' || c == '\\' || c >= 0x7f {
return i return i
} }
} }

View File

@ -1,40 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fieldsort defines an ordering of fields.
//
// The ordering defined by this package matches the historic behavior of the proto
// package, placing extensions first and oneofs last.
//
// There is no guarantee about stability of the wire encoding, and users should not
// depend on the order defined in this package as it is subject to change without
// notice.
package fieldsort
import (
"google.golang.org/protobuf/reflect/protoreflect"
)
// Less returns true if field a comes before field j in ordered wire marshal output.
func Less(a, b protoreflect.FieldDescriptor) bool {
ea := a.IsExtension()
eb := b.IsExtension()
oa := a.ContainingOneof()
ob := b.ContainingOneof()
switch {
case ea != eb:
return ea
case oa != nil && ob != nil:
if oa == ob {
return a.Number() < b.Number()
}
return oa.Index() < ob.Index()
case oa != nil && !oa.IsSynthetic():
return false
case ob != nil && !ob.IsSynthetic():
return true
default:
return a.Number() < b.Number()
}
}

View File

@ -3,6 +3,9 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package filedesc provides functionality for constructing descriptors. // Package filedesc provides functionality for constructing descriptors.
//
// The types in this package implement interfaces in the protoreflect package
// related to protobuf descripriptors.
package filedesc package filedesc
import ( import (

View File

@ -13,6 +13,7 @@ import (
"google.golang.org/protobuf/internal/descfmt" "google.golang.org/protobuf/internal/descfmt"
"google.golang.org/protobuf/internal/descopts" "google.golang.org/protobuf/internal/descopts"
"google.golang.org/protobuf/internal/encoding/defval" "google.golang.org/protobuf/internal/encoding/defval"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/genid" "google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/strs" "google.golang.org/protobuf/internal/strs"
@ -99,15 +100,6 @@ func (fd *File) lazyInitOnce() {
fd.mu.Unlock() fd.mu.Unlock()
} }
// ProtoLegacyRawDesc is a pseudo-internal API for allowing the v1 code
// to be able to retrieve the raw descriptor.
//
// WARNING: This method is exempt from the compatibility promise and may be
// removed in the future without warning.
func (fd *File) ProtoLegacyRawDesc() []byte {
return fd.builder.RawDescriptor
}
// GoPackagePath is a pseudo-internal API for determining the Go package path // GoPackagePath is a pseudo-internal API for determining the Go package path
// that this file descriptor is declared in. // that this file descriptor is declared in.
// //
@ -207,7 +199,7 @@ type (
Number pref.FieldNumber Number pref.FieldNumber
Cardinality pref.Cardinality // must be consistent with Message.RequiredNumbers Cardinality pref.Cardinality // must be consistent with Message.RequiredNumbers
Kind pref.Kind Kind pref.Kind
JSONName jsonName StringName stringName
IsProto3Optional bool // promoted from google.protobuf.FieldDescriptorProto IsProto3Optional bool // promoted from google.protobuf.FieldDescriptorProto
IsWeak bool // promoted from google.protobuf.FieldOptions IsWeak bool // promoted from google.protobuf.FieldOptions
HasPacked bool // promoted from google.protobuf.FieldOptions HasPacked bool // promoted from google.protobuf.FieldOptions
@ -277,8 +269,9 @@ func (fd *Field) Options() pref.ProtoMessage {
func (fd *Field) Number() pref.FieldNumber { return fd.L1.Number } func (fd *Field) Number() pref.FieldNumber { return fd.L1.Number }
func (fd *Field) Cardinality() pref.Cardinality { return fd.L1.Cardinality } func (fd *Field) Cardinality() pref.Cardinality { return fd.L1.Cardinality }
func (fd *Field) Kind() pref.Kind { return fd.L1.Kind } func (fd *Field) Kind() pref.Kind { return fd.L1.Kind }
func (fd *Field) HasJSONName() bool { return fd.L1.JSONName.has } func (fd *Field) HasJSONName() bool { return fd.L1.StringName.hasJSON }
func (fd *Field) JSONName() string { return fd.L1.JSONName.get(fd) } func (fd *Field) JSONName() string { return fd.L1.StringName.getJSON(fd) }
func (fd *Field) TextName() string { return fd.L1.StringName.getText(fd) }
func (fd *Field) HasPresence() bool { func (fd *Field) HasPresence() bool {
return fd.L1.Cardinality != pref.Repeated && (fd.L0.ParentFile.L1.Syntax == pref.Proto2 || fd.L1.Message != nil || fd.L1.ContainingOneof != nil) return fd.L1.Cardinality != pref.Repeated && (fd.L0.ParentFile.L1.Syntax == pref.Proto2 || fd.L1.Message != nil || fd.L1.ContainingOneof != nil)
} }
@ -373,7 +366,7 @@ type (
} }
ExtensionL2 struct { ExtensionL2 struct {
Options func() pref.ProtoMessage Options func() pref.ProtoMessage
JSONName jsonName StringName stringName
IsProto3Optional bool // promoted from google.protobuf.FieldDescriptorProto IsProto3Optional bool // promoted from google.protobuf.FieldDescriptorProto
IsPacked bool // promoted from google.protobuf.FieldOptions IsPacked bool // promoted from google.protobuf.FieldOptions
Default defaultValue Default defaultValue
@ -391,8 +384,9 @@ func (xd *Extension) Options() pref.ProtoMessage {
func (xd *Extension) Number() pref.FieldNumber { return xd.L1.Number } func (xd *Extension) Number() pref.FieldNumber { return xd.L1.Number }
func (xd *Extension) Cardinality() pref.Cardinality { return xd.L1.Cardinality } func (xd *Extension) Cardinality() pref.Cardinality { return xd.L1.Cardinality }
func (xd *Extension) Kind() pref.Kind { return xd.L1.Kind } func (xd *Extension) Kind() pref.Kind { return xd.L1.Kind }
func (xd *Extension) HasJSONName() bool { return xd.lazyInit().JSONName.has } func (xd *Extension) HasJSONName() bool { return xd.lazyInit().StringName.hasJSON }
func (xd *Extension) JSONName() string { return xd.lazyInit().JSONName.get(xd) } func (xd *Extension) JSONName() string { return xd.lazyInit().StringName.getJSON(xd) }
func (xd *Extension) TextName() string { return xd.lazyInit().StringName.getText(xd) }
func (xd *Extension) HasPresence() bool { return xd.L1.Cardinality != pref.Repeated } func (xd *Extension) HasPresence() bool { return xd.L1.Cardinality != pref.Repeated }
func (xd *Extension) HasOptionalKeyword() bool { func (xd *Extension) HasOptionalKeyword() bool {
return (xd.L0.ParentFile.L1.Syntax == pref.Proto2 && xd.L1.Cardinality == pref.Optional) || xd.lazyInit().IsProto3Optional return (xd.L0.ParentFile.L1.Syntax == pref.Proto2 && xd.L1.Cardinality == pref.Optional) || xd.lazyInit().IsProto3Optional
@ -506,27 +500,50 @@ func (d *Base) Syntax() pref.Syntax { return d.L0.ParentFile.Syn
func (d *Base) IsPlaceholder() bool { return false } func (d *Base) IsPlaceholder() bool { return false }
func (d *Base) ProtoInternal(pragma.DoNotImplement) {} func (d *Base) ProtoInternal(pragma.DoNotImplement) {}
type jsonName struct { type stringName struct {
has bool hasJSON bool
once sync.Once once sync.Once
name string nameJSON string
nameText string
} }
// Init initializes the name. It is exported for use by other internal packages. // InitJSON initializes the name. It is exported for use by other internal packages.
func (js *jsonName) Init(s string) { func (s *stringName) InitJSON(name string) {
js.has = true s.hasJSON = true
js.name = s s.nameJSON = name
} }
func (js *jsonName) get(fd pref.FieldDescriptor) string { func (s *stringName) lazyInit(fd pref.FieldDescriptor) *stringName {
if !js.has { s.once.Do(func() {
js.once.Do(func() { if fd.IsExtension() {
js.name = strs.JSONCamelCase(string(fd.Name())) // For extensions, JSON and text are formatted the same way.
}) var name string
if messageset.IsMessageSetExtension(fd) {
name = string("[" + fd.FullName().Parent() + "]")
} else {
name = string("[" + fd.FullName() + "]")
} }
return js.name s.nameJSON = name
s.nameText = name
} else {
// Format the JSON name.
if !s.hasJSON {
s.nameJSON = strs.JSONCamelCase(string(fd.Name()))
}
// Format the text name.
s.nameText = string(fd.Name())
if fd.Kind() == pref.GroupKind {
s.nameText = string(fd.Message().Name())
}
}
})
return s
} }
func (s *stringName) getJSON(fd pref.FieldDescriptor) string { return s.lazyInit(fd).nameJSON }
func (s *stringName) getText(fd pref.FieldDescriptor) string { return s.lazyInit(fd).nameText }
func DefaultValue(v pref.Value, ev pref.EnumValueDescriptor) defaultValue { func DefaultValue(v pref.Value, ev pref.EnumValueDescriptor) defaultValue {
dv := defaultValue{has: v.IsValid(), val: v, enum: ev} dv := defaultValue{has: v.IsValid(), val: v, enum: ev}
if b, ok := v.Interface().([]byte); ok { if b, ok := v.Interface().([]byte); ok {

View File

@ -451,7 +451,7 @@ func (fd *Field) unmarshalFull(b []byte, sb *strs.Builder, pf *File, pd pref.Des
case genid.FieldDescriptorProto_Name_field_number: case genid.FieldDescriptorProto_Name_field_number:
fd.L0.FullName = appendFullName(sb, pd.FullName(), v) fd.L0.FullName = appendFullName(sb, pd.FullName(), v)
case genid.FieldDescriptorProto_JsonName_field_number: case genid.FieldDescriptorProto_JsonName_field_number:
fd.L1.JSONName.Init(sb.MakeString(v)) fd.L1.StringName.InitJSON(sb.MakeString(v))
case genid.FieldDescriptorProto_DefaultValue_field_number: case genid.FieldDescriptorProto_DefaultValue_field_number:
fd.L1.Default.val = pref.ValueOfBytes(v) // temporarily store as bytes; later resolved in resolveMessages fd.L1.Default.val = pref.ValueOfBytes(v) // temporarily store as bytes; later resolved in resolveMessages
case genid.FieldDescriptorProto_TypeName_field_number: case genid.FieldDescriptorProto_TypeName_field_number:
@ -551,7 +551,7 @@ func (xd *Extension) unmarshalFull(b []byte, sb *strs.Builder) {
b = b[m:] b = b[m:]
switch num { switch num {
case genid.FieldDescriptorProto_JsonName_field_number: case genid.FieldDescriptorProto_JsonName_field_number:
xd.L2.JSONName.Init(sb.MakeString(v)) xd.L2.StringName.InitJSON(sb.MakeString(v))
case genid.FieldDescriptorProto_DefaultValue_field_number: case genid.FieldDescriptorProto_DefaultValue_field_number:
xd.L2.Default.val = pref.ValueOfBytes(v) // temporarily store as bytes; later resolved in resolveExtensions xd.L2.Default.val = pref.ValueOfBytes(v) // temporarily store as bytes; later resolved in resolveExtensions
case genid.FieldDescriptorProto_TypeName_field_number: case genid.FieldDescriptorProto_TypeName_field_number:

View File

@ -6,9 +6,12 @@ package filedesc
import ( import (
"fmt" "fmt"
"math"
"sort" "sort"
"sync" "sync"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/descfmt" "google.golang.org/protobuf/internal/descfmt"
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
@ -245,6 +248,7 @@ type OneofFields struct {
once sync.Once once sync.Once
byName map[pref.Name]pref.FieldDescriptor // protected by once byName map[pref.Name]pref.FieldDescriptor // protected by once
byJSON map[string]pref.FieldDescriptor // protected by once byJSON map[string]pref.FieldDescriptor // protected by once
byText map[string]pref.FieldDescriptor // protected by once
byNum map[pref.FieldNumber]pref.FieldDescriptor // protected by once byNum map[pref.FieldNumber]pref.FieldDescriptor // protected by once
} }
@ -252,6 +256,7 @@ func (p *OneofFields) Len() int { return
func (p *OneofFields) Get(i int) pref.FieldDescriptor { return p.List[i] } func (p *OneofFields) Get(i int) pref.FieldDescriptor { return p.List[i] }
func (p *OneofFields) ByName(s pref.Name) pref.FieldDescriptor { return p.lazyInit().byName[s] } func (p *OneofFields) ByName(s pref.Name) pref.FieldDescriptor { return p.lazyInit().byName[s] }
func (p *OneofFields) ByJSONName(s string) pref.FieldDescriptor { return p.lazyInit().byJSON[s] } func (p *OneofFields) ByJSONName(s string) pref.FieldDescriptor { return p.lazyInit().byJSON[s] }
func (p *OneofFields) ByTextName(s string) pref.FieldDescriptor { return p.lazyInit().byText[s] }
func (p *OneofFields) ByNumber(n pref.FieldNumber) pref.FieldDescriptor { return p.lazyInit().byNum[n] } func (p *OneofFields) ByNumber(n pref.FieldNumber) pref.FieldDescriptor { return p.lazyInit().byNum[n] }
func (p *OneofFields) Format(s fmt.State, r rune) { descfmt.FormatList(s, r, p) } func (p *OneofFields) Format(s fmt.State, r rune) { descfmt.FormatList(s, r, p) }
func (p *OneofFields) ProtoInternal(pragma.DoNotImplement) {} func (p *OneofFields) ProtoInternal(pragma.DoNotImplement) {}
@ -261,11 +266,13 @@ func (p *OneofFields) lazyInit() *OneofFields {
if len(p.List) > 0 { if len(p.List) > 0 {
p.byName = make(map[pref.Name]pref.FieldDescriptor, len(p.List)) p.byName = make(map[pref.Name]pref.FieldDescriptor, len(p.List))
p.byJSON = make(map[string]pref.FieldDescriptor, len(p.List)) p.byJSON = make(map[string]pref.FieldDescriptor, len(p.List))
p.byText = make(map[string]pref.FieldDescriptor, len(p.List))
p.byNum = make(map[pref.FieldNumber]pref.FieldDescriptor, len(p.List)) p.byNum = make(map[pref.FieldNumber]pref.FieldDescriptor, len(p.List))
for _, f := range p.List { for _, f := range p.List {
// Field names and numbers are guaranteed to be unique. // Field names and numbers are guaranteed to be unique.
p.byName[f.Name()] = f p.byName[f.Name()] = f
p.byJSON[f.JSONName()] = f p.byJSON[f.JSONName()] = f
p.byText[f.TextName()] = f
p.byNum[f.Number()] = f p.byNum[f.Number()] = f
} }
} }
@ -274,9 +281,170 @@ func (p *OneofFields) lazyInit() *OneofFields {
} }
type SourceLocations struct { type SourceLocations struct {
// List is a list of SourceLocations.
// The SourceLocation.Next field does not need to be populated
// as it will be lazily populated upon first need.
List []pref.SourceLocation List []pref.SourceLocation
// File is the parent file descriptor that these locations are relative to.
// If non-nil, ByDescriptor verifies that the provided descriptor
// is a child of this file descriptor.
File pref.FileDescriptor
once sync.Once
byPath map[pathKey]int
} }
func (p *SourceLocations) Len() int { return len(p.List) } func (p *SourceLocations) Len() int { return len(p.List) }
func (p *SourceLocations) Get(i int) pref.SourceLocation { return p.List[i] } func (p *SourceLocations) Get(i int) pref.SourceLocation { return p.lazyInit().List[i] }
func (p *SourceLocations) byKey(k pathKey) pref.SourceLocation {
if i, ok := p.lazyInit().byPath[k]; ok {
return p.List[i]
}
return pref.SourceLocation{}
}
func (p *SourceLocations) ByPath(path pref.SourcePath) pref.SourceLocation {
return p.byKey(newPathKey(path))
}
func (p *SourceLocations) ByDescriptor(desc pref.Descriptor) pref.SourceLocation {
if p.File != nil && desc != nil && p.File != desc.ParentFile() {
return pref.SourceLocation{} // mismatching parent files
}
var pathArr [16]int32
path := pathArr[:0]
for {
switch desc.(type) {
case pref.FileDescriptor:
// Reverse the path since it was constructed in reverse.
for i, j := 0, len(path)-1; i < j; i, j = i+1, j-1 {
path[i], path[j] = path[j], path[i]
}
return p.byKey(newPathKey(path))
case pref.MessageDescriptor:
path = append(path, int32(desc.Index()))
desc = desc.Parent()
switch desc.(type) {
case pref.FileDescriptor:
path = append(path, int32(genid.FileDescriptorProto_MessageType_field_number))
case pref.MessageDescriptor:
path = append(path, int32(genid.DescriptorProto_NestedType_field_number))
default:
return pref.SourceLocation{}
}
case pref.FieldDescriptor:
isExtension := desc.(pref.FieldDescriptor).IsExtension()
path = append(path, int32(desc.Index()))
desc = desc.Parent()
if isExtension {
switch desc.(type) {
case pref.FileDescriptor:
path = append(path, int32(genid.FileDescriptorProto_Extension_field_number))
case pref.MessageDescriptor:
path = append(path, int32(genid.DescriptorProto_Extension_field_number))
default:
return pref.SourceLocation{}
}
} else {
switch desc.(type) {
case pref.MessageDescriptor:
path = append(path, int32(genid.DescriptorProto_Field_field_number))
default:
return pref.SourceLocation{}
}
}
case pref.OneofDescriptor:
path = append(path, int32(desc.Index()))
desc = desc.Parent()
switch desc.(type) {
case pref.MessageDescriptor:
path = append(path, int32(genid.DescriptorProto_OneofDecl_field_number))
default:
return pref.SourceLocation{}
}
case pref.EnumDescriptor:
path = append(path, int32(desc.Index()))
desc = desc.Parent()
switch desc.(type) {
case pref.FileDescriptor:
path = append(path, int32(genid.FileDescriptorProto_EnumType_field_number))
case pref.MessageDescriptor:
path = append(path, int32(genid.DescriptorProto_EnumType_field_number))
default:
return pref.SourceLocation{}
}
case pref.EnumValueDescriptor:
path = append(path, int32(desc.Index()))
desc = desc.Parent()
switch desc.(type) {
case pref.EnumDescriptor:
path = append(path, int32(genid.EnumDescriptorProto_Value_field_number))
default:
return pref.SourceLocation{}
}
case pref.ServiceDescriptor:
path = append(path, int32(desc.Index()))
desc = desc.Parent()
switch desc.(type) {
case pref.FileDescriptor:
path = append(path, int32(genid.FileDescriptorProto_Service_field_number))
default:
return pref.SourceLocation{}
}
case pref.MethodDescriptor:
path = append(path, int32(desc.Index()))
desc = desc.Parent()
switch desc.(type) {
case pref.ServiceDescriptor:
path = append(path, int32(genid.ServiceDescriptorProto_Method_field_number))
default:
return pref.SourceLocation{}
}
default:
return pref.SourceLocation{}
}
}
}
func (p *SourceLocations) lazyInit() *SourceLocations {
p.once.Do(func() {
if len(p.List) > 0 {
// Collect all the indexes for a given path.
pathIdxs := make(map[pathKey][]int, len(p.List))
for i, l := range p.List {
k := newPathKey(l.Path)
pathIdxs[k] = append(pathIdxs[k], i)
}
// Update the next index for all locations.
p.byPath = make(map[pathKey]int, len(p.List))
for k, idxs := range pathIdxs {
for i := 0; i < len(idxs)-1; i++ {
p.List[idxs[i]].Next = idxs[i+1]
}
p.List[idxs[len(idxs)-1]].Next = 0
p.byPath[k] = idxs[0] // record the first location for this path
}
}
})
return p
}
func (p *SourceLocations) ProtoInternal(pragma.DoNotImplement) {} func (p *SourceLocations) ProtoInternal(pragma.DoNotImplement) {}
// pathKey is a comparable representation of protoreflect.SourcePath.
type pathKey struct {
arr [16]uint8 // first n-1 path segments; last element is the length
str string // used if the path does not fit in arr
}
func newPathKey(p pref.SourcePath) (k pathKey) {
if len(p) < len(k.arr) {
for i, ps := range p {
if ps < 0 || math.MaxUint8 <= ps {
return pathKey{str: p.String()}
}
k.arr[i] = uint8(ps)
}
k.arr[len(k.arr)-1] = uint8(len(p))
return k
}
return pathKey{str: p.String()}
}

View File

@ -142,6 +142,7 @@ type Fields struct {
once sync.Once once sync.Once
byName map[protoreflect.Name]*Field // protected by once byName map[protoreflect.Name]*Field // protected by once
byJSON map[string]*Field // protected by once byJSON map[string]*Field // protected by once
byText map[string]*Field // protected by once
byNum map[protoreflect.FieldNumber]*Field // protected by once byNum map[protoreflect.FieldNumber]*Field // protected by once
} }
@ -163,6 +164,12 @@ func (p *Fields) ByJSONName(s string) protoreflect.FieldDescriptor {
} }
return nil return nil
} }
func (p *Fields) ByTextName(s string) protoreflect.FieldDescriptor {
if d := p.lazyInit().byText[s]; d != nil {
return d
}
return nil
}
func (p *Fields) ByNumber(n protoreflect.FieldNumber) protoreflect.FieldDescriptor { func (p *Fields) ByNumber(n protoreflect.FieldNumber) protoreflect.FieldDescriptor {
if d := p.lazyInit().byNum[n]; d != nil { if d := p.lazyInit().byNum[n]; d != nil {
return d return d
@ -178,6 +185,7 @@ func (p *Fields) lazyInit() *Fields {
if len(p.List) > 0 { if len(p.List) > 0 {
p.byName = make(map[protoreflect.Name]*Field, len(p.List)) p.byName = make(map[protoreflect.Name]*Field, len(p.List))
p.byJSON = make(map[string]*Field, len(p.List)) p.byJSON = make(map[string]*Field, len(p.List))
p.byText = make(map[string]*Field, len(p.List))
p.byNum = make(map[protoreflect.FieldNumber]*Field, len(p.List)) p.byNum = make(map[protoreflect.FieldNumber]*Field, len(p.List))
for i := range p.List { for i := range p.List {
d := &p.List[i] d := &p.List[i]
@ -187,6 +195,9 @@ func (p *Fields) lazyInit() *Fields {
if _, ok := p.byJSON[d.JSONName()]; !ok { if _, ok := p.byJSON[d.JSONName()]; !ok {
p.byJSON[d.JSONName()] = d p.byJSON[d.JSONName()] = d
} }
if _, ok := p.byText[d.TextName()]; !ok {
p.byText[d.TextName()] = d
}
if _, ok := p.byNum[d.Number()]; !ok { if _, ok := p.byNum[d.Number()]; !ok {
p.byNum[d.Number()] = d p.byNum[d.Number()] = d
} }

View File

@ -167,7 +167,7 @@ func (Export) MessageTypeOf(m message) pref.MessageType {
if mv := (Export{}).protoMessageV2Of(m); mv != nil { if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv.ProtoReflect().Type() return mv.ProtoReflect().Type()
} }
return legacyLoadMessageInfo(reflect.TypeOf(m), "") return legacyLoadMessageType(reflect.TypeOf(m), "")
} }
// MessageStringOf returns the message value as a string, // MessageStringOf returns the message value as a string,

View File

@ -10,6 +10,7 @@ import (
"sync" "sync"
"google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry" preg "google.golang.org/protobuf/reflect/protoregistry"
@ -20,6 +21,7 @@ type errInvalidUTF8 struct{}
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" } func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
func (errInvalidUTF8) InvalidUTF8() bool { return true } func (errInvalidUTF8) InvalidUTF8() bool { return true }
func (errInvalidUTF8) Unwrap() error { return errors.Error }
// initOneofFieldCoders initializes the fast-path functions for the fields in a oneof. // initOneofFieldCoders initializes the fast-path functions for the fields in a oneof.
// //
@ -242,7 +244,7 @@ func consumeMessageInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldI
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
if p.Elem().IsNil() { if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))) p.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
@ -276,7 +278,7 @@ func consumeMessage(b []byte, m proto.Message, wtyp protowire.Type, opts unmarsh
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{ o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
Buf: v, Buf: v,
@ -420,7 +422,7 @@ func consumeGroup(b []byte, m proto.Message, num protowire.Number, wtyp protowir
} }
b, n := protowire.ConsumeGroup(num, b) b, n := protowire.ConsumeGroup(num, b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{ o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
Buf: b, Buf: b,
@ -494,7 +496,7 @@ func consumeMessageSliceInfo(b []byte, p pointer, wtyp protowire.Type, f *coderF
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
m := reflect.New(f.mi.GoReflectType.Elem()).Interface() m := reflect.New(f.mi.GoReflectType.Elem()).Interface()
mp := pointerOfIface(m) mp := pointerOfIface(m)
@ -550,7 +552,7 @@ func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp protowir
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
mp := reflect.New(goType.Elem()) mp := reflect.New(goType.Elem())
o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{ o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
@ -613,7 +615,7 @@ func consumeMessageSliceValue(b []byte, listv pref.Value, _ protowire.Number, wt
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return pref.Value{}, out, protowire.ParseError(n) return pref.Value{}, out, errDecode
} }
m := list.NewElement() m := list.NewElement()
o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{ o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
@ -681,7 +683,7 @@ func consumeGroupSliceValue(b []byte, listv pref.Value, num protowire.Number, wt
} }
b, n := protowire.ConsumeGroup(num, b) b, n := protowire.ConsumeGroup(num, b)
if n < 0 { if n < 0 {
return pref.Value{}, out, protowire.ParseError(n) return pref.Value{}, out, errDecode
} }
m := list.NewElement() m := list.NewElement()
o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{ o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
@ -767,7 +769,7 @@ func consumeGroupSlice(b []byte, p pointer, num protowire.Number, wtyp protowire
} }
b, n := protowire.ConsumeGroup(num, b) b, n := protowire.ConsumeGroup(num, b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
mp := reflect.New(goType.Elem()) mp := reflect.New(goType.Elem())
o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{ o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,6 @@
package impl package impl
import ( import (
"errors"
"reflect" "reflect"
"sort" "sort"
@ -118,7 +117,7 @@ func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo
} }
b, n := protowire.ConsumeBytes(b) b, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
var ( var (
key = mapi.keyZero key = mapi.keyZero
@ -127,10 +126,10 @@ func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo
for len(b) > 0 { for len(b) > 0 {
num, wtyp, n := protowire.ConsumeTag(b) num, wtyp, n := protowire.ConsumeTag(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
if num > protowire.MaxValidNumber { if num > protowire.MaxValidNumber {
return out, errors.New("invalid field number") return out, errDecode
} }
b = b[n:] b = b[n:]
err := errUnknown err := errUnknown
@ -157,7 +156,7 @@ func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo
if err == errUnknown { if err == errUnknown {
n = protowire.ConsumeFieldValue(num, wtyp, b) n = protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
} else if err != nil { } else if err != nil {
return out, err return out, err
@ -175,7 +174,7 @@ func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi
} }
b, n := protowire.ConsumeBytes(b) b, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
var ( var (
key = mapi.keyZero key = mapi.keyZero
@ -184,10 +183,10 @@ func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi
for len(b) > 0 { for len(b) > 0 {
num, wtyp, n := protowire.ConsumeTag(b) num, wtyp, n := protowire.ConsumeTag(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
if num > protowire.MaxValidNumber { if num > protowire.MaxValidNumber {
return out, errors.New("invalid field number") return out, errDecode
} }
b = b[n:] b = b[n:]
err := errUnknown err := errUnknown
@ -208,7 +207,7 @@ func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi
var v []byte var v []byte
v, n = protowire.ConsumeBytes(b) v, n = protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
var o unmarshalOutput var o unmarshalOutput
o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts) o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
@ -221,7 +220,7 @@ func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi
if err == errUnknown { if err == errUnknown {
n = protowire.ConsumeFieldValue(num, wtyp, b) n = protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
} else if err != nil { } else if err != nil {
return out, err return out, err

View File

@ -11,7 +11,7 @@ import (
"google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/fieldsort" "google.golang.org/protobuf/internal/order"
pref "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect"
piface "google.golang.org/protobuf/runtime/protoiface" piface "google.golang.org/protobuf/runtime/protoiface"
) )
@ -27,6 +27,7 @@ type coderMessageInfo struct {
coderFields map[protowire.Number]*coderFieldInfo coderFields map[protowire.Number]*coderFieldInfo
sizecacheOffset offset sizecacheOffset offset
unknownOffset offset unknownOffset offset
unknownPtrKind bool
extensionOffset offset extensionOffset offset
needsInitCheck bool needsInitCheck bool
isMessageSet bool isMessageSet bool
@ -47,9 +48,20 @@ type coderFieldInfo struct {
} }
func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) { func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
mi.sizecacheOffset = invalidOffset
mi.unknownOffset = invalidOffset
mi.extensionOffset = invalidOffset
if si.sizecacheOffset.IsValid() && si.sizecacheType == sizecacheType {
mi.sizecacheOffset = si.sizecacheOffset mi.sizecacheOffset = si.sizecacheOffset
}
if si.unknownOffset.IsValid() && (si.unknownType == unknownFieldsAType || si.unknownType == unknownFieldsBType) {
mi.unknownOffset = si.unknownOffset mi.unknownOffset = si.unknownOffset
mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr
}
if si.extensionOffset.IsValid() && si.extensionType == extensionFieldsType {
mi.extensionOffset = si.extensionOffset mi.extensionOffset = si.extensionOffset
}
mi.coderFields = make(map[protowire.Number]*coderFieldInfo) mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
fields := mi.Desc.Fields() fields := mi.Desc.Fields()
@ -73,6 +85,27 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
var funcs pointerCoderFuncs var funcs pointerCoderFuncs
var childMessage *MessageInfo var childMessage *MessageInfo
switch { switch {
case ft == nil:
// This never occurs for generated message types.
// It implies that a hand-crafted type has missing Go fields
// for specific protobuf message fields.
funcs = pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
return 0
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
return nil, nil
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
panic("missing Go struct field for " + string(fd.FullName()))
},
isInit: func(p pointer, f *coderFieldInfo) error {
panic("missing Go struct field for " + string(fd.FullName()))
},
merge: func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
panic("missing Go struct field for " + string(fd.FullName()))
},
}
case isOneof: case isOneof:
fieldOffset = offsetOf(fs, mi.Exporter) fieldOffset = offsetOf(fs, mi.Exporter)
case fd.IsWeak(): case fd.IsWeak():
@ -136,7 +169,7 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
sort.Slice(mi.orderedCoderFields, func(i, j int) bool { sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
fi := fields.ByNumber(mi.orderedCoderFields[i].num) fi := fields.ByNumber(mi.orderedCoderFields[i].num)
fj := fields.ByNumber(mi.orderedCoderFields[j].num) fj := fields.ByNumber(mi.orderedCoderFields[j].num)
return fieldsort.Less(fi, fj) return order.LegacyFieldOrder(fi, fj)
}) })
} }
@ -157,3 +190,28 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
mi.methods.Merge = mi.merge mi.methods.Merge = mi.merge
} }
} }
// getUnknownBytes returns a *[]byte for the unknown fields.
// It is the caller's responsibility to check whether the pointer is nil.
// This function is specially designed to be inlineable.
func (mi *MessageInfo) getUnknownBytes(p pointer) *[]byte {
if mi.unknownPtrKind {
return *p.Apply(mi.unknownOffset).BytesPtr()
} else {
return p.Apply(mi.unknownOffset).Bytes()
}
}
// mutableUnknownBytes returns a *[]byte for the unknown fields.
// The returned pointer is guaranteed to not be nil.
func (mi *MessageInfo) mutableUnknownBytes(p pointer) *[]byte {
if mi.unknownPtrKind {
bp := p.Apply(mi.unknownOffset).BytesPtr()
if *bp == nil {
*bp = new([]byte)
}
return *bp
} else {
return p.Apply(mi.unknownOffset).Bytes()
}
}

View File

@ -29,8 +29,9 @@ func sizeMessageSet(mi *MessageInfo, p pointer, opts marshalOptions) (size int)
size += xi.funcs.size(x.Value(), protowire.SizeTag(messageset.FieldMessage), opts) size += xi.funcs.size(x.Value(), protowire.SizeTag(messageset.FieldMessage), opts)
} }
unknown := *p.Apply(mi.unknownOffset).Bytes() if u := mi.getUnknownBytes(p); u != nil {
size += messageset.SizeUnknown(unknown) size += messageset.SizeUnknown(*u)
}
return size return size
} }
@ -69,11 +70,13 @@ func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts marshalOptions
} }
} }
unknown := *p.Apply(mi.unknownOffset).Bytes() if u := mi.getUnknownBytes(p); u != nil {
b, err := messageset.AppendUnknown(b, unknown) var err error
b, err = messageset.AppendUnknown(b, *u)
if err != nil { if err != nil {
return b, err return b, err
} }
}
return b, nil return b, nil
} }
@ -100,13 +103,13 @@ func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts unmarshalOpt
*ep = make(map[int32]ExtensionField) *ep = make(map[int32]ExtensionField)
} }
ext := *ep ext := *ep
unknown := p.Apply(mi.unknownOffset).Bytes()
initialized := true initialized := true
err = messageset.Unmarshal(b, true, func(num protowire.Number, v []byte) error { err = messageset.Unmarshal(b, true, func(num protowire.Number, v []byte) error {
o, err := mi.unmarshalExtension(v, num, protowire.BytesType, ext, opts) o, err := mi.unmarshalExtension(v, num, protowire.BytesType, ext, opts)
if err == errUnknown { if err == errUnknown {
*unknown = protowire.AppendTag(*unknown, num, protowire.BytesType) u := mi.mutableUnknownBytes(p)
*unknown = append(*unknown, v...) *u = protowire.AppendTag(*u, num, protowire.BytesType)
*u = append(*u, v...)
return nil return nil
} }
if !o.initialized { if !o.initialized {

View File

@ -30,7 +30,7 @@ func consumeEnum(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, _
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
p.v.Elem().SetInt(int64(v)) p.v.Elem().SetInt(int64(v))
out.n = n out.n = n
@ -130,12 +130,12 @@ func consumeEnumSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInf
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
b, n := protowire.ConsumeBytes(b) b, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
for len(b) > 0 { for len(b) > 0 {
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
rv := reflect.New(s.Type().Elem()).Elem() rv := reflect.New(s.Type().Elem()).Elem()
rv.SetInt(int64(v)) rv.SetInt(int64(v))
@ -150,7 +150,7 @@ func consumeEnumSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInf
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
rv := reflect.New(s.Type().Elem()).Elem() rv := reflect.New(s.Type().Elem()).Elem()
rv.SetInt(int64(v)) rv.SetInt(int64(v))

View File

@ -423,6 +423,13 @@ func (c *messageConverter) PBValueOf(v reflect.Value) pref.Value {
if v.Type() != c.goType { if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
} }
if c.isNonPointer() {
if v.CanAddr() {
v = v.Addr() // T => *T
} else {
v = reflect.Zero(reflect.PtrTo(v.Type()))
}
}
if m, ok := v.Interface().(pref.ProtoMessage); ok { if m, ok := v.Interface().(pref.ProtoMessage); ok {
return pref.ValueOfMessage(m.ProtoReflect()) return pref.ValueOfMessage(m.ProtoReflect())
} }
@ -437,6 +444,16 @@ func (c *messageConverter) GoValueOf(v pref.Value) reflect.Value {
} else { } else {
rv = reflect.ValueOf(m.Interface()) rv = reflect.ValueOf(m.Interface())
} }
if c.isNonPointer() {
if rv.Type() != reflect.PtrTo(c.goType) {
panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), reflect.PtrTo(c.goType)))
}
if !rv.IsNil() {
rv = rv.Elem() // *T => T
} else {
rv = reflect.Zero(rv.Type().Elem())
}
}
if rv.Type() != c.goType { if rv.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), c.goType)) panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), c.goType))
} }
@ -451,6 +468,9 @@ func (c *messageConverter) IsValidPB(v pref.Value) bool {
} else { } else {
rv = reflect.ValueOf(m.Interface()) rv = reflect.ValueOf(m.Interface())
} }
if c.isNonPointer() {
return rv.Type() == reflect.PtrTo(c.goType)
}
return rv.Type() == c.goType return rv.Type() == c.goType
} }
@ -459,9 +479,18 @@ func (c *messageConverter) IsValidGo(v reflect.Value) bool {
} }
func (c *messageConverter) New() pref.Value { func (c *messageConverter) New() pref.Value {
if c.isNonPointer() {
return c.PBValueOf(reflect.New(c.goType).Elem())
}
return c.PBValueOf(reflect.New(c.goType.Elem())) return c.PBValueOf(reflect.New(c.goType.Elem()))
} }
func (c *messageConverter) Zero() pref.Value { func (c *messageConverter) Zero() pref.Value {
return c.PBValueOf(reflect.Zero(c.goType)) return c.PBValueOf(reflect.Zero(c.goType))
} }
// isNonPointer reports whether the type is a non-pointer type.
// This never occurs for generated message types.
func (c *messageConverter) isNonPointer() bool {
return c.goType.Kind() != reflect.Ptr
}

View File

@ -17,6 +17,8 @@ import (
piface "google.golang.org/protobuf/runtime/protoiface" piface "google.golang.org/protobuf/runtime/protoiface"
) )
var errDecode = errors.New("cannot parse invalid wire-format data")
type unmarshalOptions struct { type unmarshalOptions struct {
flags protoiface.UnmarshalInputFlags flags protoiface.UnmarshalInputFlags
resolver interface { resolver interface {
@ -100,13 +102,13 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.
var n int var n int
tag, n = protowire.ConsumeVarint(b) tag, n = protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
b = b[n:] b = b[n:]
} }
var num protowire.Number var num protowire.Number
if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
return out, errors.New("invalid field number") return out, errDecode
} else { } else {
num = protowire.Number(n) num = protowire.Number(n)
} }
@ -114,7 +116,7 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.
if wtyp == protowire.EndGroupType { if wtyp == protowire.EndGroupType {
if num != groupTag { if num != groupTag {
return out, errors.New("mismatching end group marker") return out, errDecode
} }
groupTag = 0 groupTag = 0
break break
@ -170,10 +172,10 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.
} }
n = protowire.ConsumeFieldValue(num, wtyp, b) n = protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 { if n < 0 {
return out, protowire.ParseError(n) return out, errDecode
} }
if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
u := p.Apply(mi.unknownOffset).Bytes() u := mi.mutableUnknownBytes(p)
*u = protowire.AppendTag(*u, num, wtyp) *u = protowire.AppendTag(*u, num, wtyp)
*u = append(*u, b[:n]...) *u = append(*u, b[:n]...)
} }
@ -181,7 +183,7 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.
b = b[n:] b = b[n:]
} }
if groupTag != 0 { if groupTag != 0 {
return out, errors.New("missing end group marker") return out, errDecode
} }
if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
initialized = false initialized = false
@ -221,7 +223,7 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp p
return out, nil return out, nil
} }
case ValidationInvalid: case ValidationInvalid:
return out, errors.New("invalid wire format") return out, errDecode
case ValidationUnknown: case ValidationUnknown:
} }
} }

View File

@ -79,8 +79,9 @@ func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int
size += f.funcs.size(fptr, f, opts) size += f.funcs.size(fptr, f, opts)
} }
if mi.unknownOffset.IsValid() { if mi.unknownOffset.IsValid() {
u := *p.Apply(mi.unknownOffset).Bytes() if u := mi.getUnknownBytes(p); u != nil {
size += len(u) size += len(*u)
}
} }
if mi.sizecacheOffset.IsValid() { if mi.sizecacheOffset.IsValid() {
if size > math.MaxInt32 { if size > math.MaxInt32 {
@ -141,8 +142,9 @@ func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOpt
} }
} }
if mi.unknownOffset.IsValid() && !mi.isMessageSet { if mi.unknownOffset.IsValid() && !mi.isMessageSet {
u := *p.Apply(mi.unknownOffset).Bytes() if u := mi.getUnknownBytes(p); u != nil {
b = append(b, u...) b = append(b, (*u)...)
}
} }
return b, nil return b, nil
} }

View File

@ -30,7 +30,7 @@ func (Export) LegacyMessageTypeOf(m piface.MessageV1, name pref.FullName) pref.M
if mv := (Export{}).protoMessageV2Of(m); mv != nil { if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv.ProtoReflect().Type() return mv.ProtoReflect().Type()
} }
return legacyLoadMessageInfo(reflect.TypeOf(m), name) return legacyLoadMessageType(reflect.TypeOf(m), name)
} }
// UnmarshalJSONEnum unmarshals an enum from a JSON-encoded input. // UnmarshalJSONEnum unmarshals an enum from a JSON-encoded input.

View File

@ -154,7 +154,8 @@ func (x placeholderExtension) Number() pref.FieldNumber { retu
func (x placeholderExtension) Cardinality() pref.Cardinality { return 0 } func (x placeholderExtension) Cardinality() pref.Cardinality { return 0 }
func (x placeholderExtension) Kind() pref.Kind { return 0 } func (x placeholderExtension) Kind() pref.Kind { return 0 }
func (x placeholderExtension) HasJSONName() bool { return false } func (x placeholderExtension) HasJSONName() bool { return false }
func (x placeholderExtension) JSONName() string { return "" } func (x placeholderExtension) JSONName() string { return "[" + string(x.name) + "]" }
func (x placeholderExtension) TextName() string { return "[" + string(x.name) + "]" }
func (x placeholderExtension) HasPresence() bool { return false } func (x placeholderExtension) HasPresence() bool { return false }
func (x placeholderExtension) HasOptionalKeyword() bool { return false } func (x placeholderExtension) HasOptionalKeyword() bool { return false }
func (x placeholderExtension) IsExtension() bool { return true } func (x placeholderExtension) IsExtension() bool { return true }

View File

@ -24,14 +24,24 @@ import (
// legacyWrapMessage wraps v as a protoreflect.Message, // legacyWrapMessage wraps v as a protoreflect.Message,
// where v must be a *struct kind and not implement the v2 API already. // where v must be a *struct kind and not implement the v2 API already.
func legacyWrapMessage(v reflect.Value) pref.Message { func legacyWrapMessage(v reflect.Value) pref.Message {
typ := v.Type() t := v.Type()
if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct { if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
return aberrantMessage{v: v} return aberrantMessage{v: v}
} }
mt := legacyLoadMessageInfo(typ, "") mt := legacyLoadMessageInfo(t, "")
return mt.MessageOf(v.Interface()) return mt.MessageOf(v.Interface())
} }
// legacyLoadMessageType dynamically loads a protoreflect.Type for t,
// where t must be not implement the v2 API already.
// The provided name is used if it cannot be determined from the message.
func legacyLoadMessageType(t reflect.Type, name pref.FullName) protoreflect.MessageType {
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
return aberrantMessageType{t}
}
return legacyLoadMessageInfo(t, name)
}
var legacyMessageTypeCache sync.Map // map[reflect.Type]*MessageInfo var legacyMessageTypeCache sync.Map // map[reflect.Type]*MessageInfo
// legacyLoadMessageInfo dynamically loads a *MessageInfo for t, // legacyLoadMessageInfo dynamically loads a *MessageInfo for t,
@ -49,8 +59,9 @@ func legacyLoadMessageInfo(t reflect.Type, name pref.FullName) *MessageInfo {
GoReflectType: t, GoReflectType: t,
} }
var hasMarshal, hasUnmarshal bool
v := reflect.Zero(t).Interface() v := reflect.Zero(t).Interface()
if _, ok := v.(legacyMarshaler); ok { if _, hasMarshal = v.(legacyMarshaler); hasMarshal {
mi.methods.Marshal = legacyMarshal mi.methods.Marshal = legacyMarshal
// We have no way to tell whether the type's Marshal method // We have no way to tell whether the type's Marshal method
@ -59,10 +70,10 @@ func legacyLoadMessageInfo(t reflect.Type, name pref.FullName) *MessageInfo {
// calling Marshal methods when present. // calling Marshal methods when present.
mi.methods.Flags |= piface.SupportMarshalDeterministic mi.methods.Flags |= piface.SupportMarshalDeterministic
} }
if _, ok := v.(legacyUnmarshaler); ok { if _, hasUnmarshal = v.(legacyUnmarshaler); hasUnmarshal {
mi.methods.Unmarshal = legacyUnmarshal mi.methods.Unmarshal = legacyUnmarshal
} }
if _, ok := v.(legacyMerger); ok { if _, hasMerge := v.(legacyMerger); hasMerge || (hasMarshal && hasUnmarshal) {
mi.methods.Merge = legacyMerge mi.methods.Merge = legacyMerge
} }
@ -75,7 +86,7 @@ func legacyLoadMessageInfo(t reflect.Type, name pref.FullName) *MessageInfo {
var legacyMessageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor var legacyMessageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
// LegacyLoadMessageDesc returns an MessageDescriptor derived from the Go type, // LegacyLoadMessageDesc returns an MessageDescriptor derived from the Go type,
// which must be a *struct kind and not implement the v2 API already. // which should be a *struct kind and must not implement the v2 API already.
// //
// This is exported for testing purposes. // This is exported for testing purposes.
func LegacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor { func LegacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
@ -114,6 +125,7 @@ func legacyLoadMessageDesc(t reflect.Type, name pref.FullName) pref.MessageDescr
// If the Go type has no fields, then this might be a proto3 empty message // If the Go type has no fields, then this might be a proto3 empty message
// from before the size cache was added. If there are any fields, check to // from before the size cache was added. If there are any fields, check to
// see that at least one of them looks like something we generated. // see that at least one of them looks like something we generated.
if t.Elem().Kind() == reflect.Struct {
if nfield := t.Elem().NumField(); nfield > 0 { if nfield := t.Elem().NumField(); nfield > 0 {
hasProtoField := false hasProtoField := false
for i := 0; i < nfield; i++ { for i := 0; i < nfield; i++ {
@ -127,6 +139,7 @@ func legacyLoadMessageDesc(t reflect.Type, name pref.FullName) pref.MessageDescr
return aberrantLoadMessageDesc(t, name) return aberrantLoadMessageDesc(t, name)
} }
} }
}
md := legacyLoadFileDesc(b).Messages().Get(idxs[0]) md := legacyLoadFileDesc(b).Messages().Get(idxs[0])
for _, i := range idxs[1:] { for _, i := range idxs[1:] {
@ -370,7 +383,7 @@ type legacyMerger interface {
Merge(protoiface.MessageV1) Merge(protoiface.MessageV1)
} }
var legacyProtoMethods = &piface.Methods{ var aberrantProtoMethods = &piface.Methods{
Marshal: legacyMarshal, Marshal: legacyMarshal,
Unmarshal: legacyUnmarshal, Unmarshal: legacyUnmarshal,
Merge: legacyMerge, Merge: legacyMerge,
@ -401,18 +414,40 @@ func legacyUnmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
v := in.Message.(unwrapper).protoUnwrap() v := in.Message.(unwrapper).protoUnwrap()
unmarshaler, ok := v.(legacyUnmarshaler) unmarshaler, ok := v.(legacyUnmarshaler)
if !ok { if !ok {
return piface.UnmarshalOutput{}, errors.New("%T does not implement Marshal", v) return piface.UnmarshalOutput{}, errors.New("%T does not implement Unmarshal", v)
} }
return piface.UnmarshalOutput{}, unmarshaler.Unmarshal(in.Buf) return piface.UnmarshalOutput{}, unmarshaler.Unmarshal(in.Buf)
} }
func legacyMerge(in piface.MergeInput) piface.MergeOutput { func legacyMerge(in piface.MergeInput) piface.MergeOutput {
// Check whether this supports the legacy merger.
dstv := in.Destination.(unwrapper).protoUnwrap() dstv := in.Destination.(unwrapper).protoUnwrap()
merger, ok := dstv.(legacyMerger) merger, ok := dstv.(legacyMerger)
if ok {
merger.Merge(Export{}.ProtoMessageV1Of(in.Source))
return piface.MergeOutput{Flags: piface.MergeComplete}
}
// If legacy merger is unavailable, implement merge in terms of
// a marshal and unmarshal operation.
srcv := in.Source.(unwrapper).protoUnwrap()
marshaler, ok := srcv.(legacyMarshaler)
if !ok { if !ok {
return piface.MergeOutput{} return piface.MergeOutput{}
} }
merger.Merge(Export{}.ProtoMessageV1Of(in.Source)) dstv = in.Destination.(unwrapper).protoUnwrap()
unmarshaler, ok := dstv.(legacyUnmarshaler)
if !ok {
return piface.MergeOutput{}
}
b, err := marshaler.Marshal()
if err != nil {
return piface.MergeOutput{}
}
err = unmarshaler.Unmarshal(b)
if err != nil {
return piface.MergeOutput{}
}
return piface.MergeOutput{Flags: piface.MergeComplete} return piface.MergeOutput{Flags: piface.MergeComplete}
} }
@ -422,6 +457,9 @@ type aberrantMessageType struct {
} }
func (mt aberrantMessageType) New() pref.Message { func (mt aberrantMessageType) New() pref.Message {
if mt.t.Kind() == reflect.Ptr {
return aberrantMessage{reflect.New(mt.t.Elem())}
}
return aberrantMessage{reflect.Zero(mt.t)} return aberrantMessage{reflect.Zero(mt.t)}
} }
func (mt aberrantMessageType) Zero() pref.Message { func (mt aberrantMessageType) Zero() pref.Message {
@ -443,6 +481,17 @@ type aberrantMessage struct {
v reflect.Value v reflect.Value
} }
// Reset implements the v1 proto.Message.Reset method.
func (m aberrantMessage) Reset() {
if mr, ok := m.v.Interface().(interface{ Reset() }); ok {
mr.Reset()
return
}
if m.v.Kind() == reflect.Ptr && !m.v.IsNil() {
m.v.Elem().Set(reflect.Zero(m.v.Type().Elem()))
}
}
func (m aberrantMessage) ProtoReflect() pref.Message { func (m aberrantMessage) ProtoReflect() pref.Message {
return m return m
} }
@ -454,33 +503,40 @@ func (m aberrantMessage) Type() pref.MessageType {
return aberrantMessageType{m.v.Type()} return aberrantMessageType{m.v.Type()}
} }
func (m aberrantMessage) New() pref.Message { func (m aberrantMessage) New() pref.Message {
if m.v.Type().Kind() == reflect.Ptr {
return aberrantMessage{reflect.New(m.v.Type().Elem())}
}
return aberrantMessage{reflect.Zero(m.v.Type())} return aberrantMessage{reflect.Zero(m.v.Type())}
} }
func (m aberrantMessage) Interface() pref.ProtoMessage { func (m aberrantMessage) Interface() pref.ProtoMessage {
return m return m
} }
func (m aberrantMessage) Range(f func(pref.FieldDescriptor, pref.Value) bool) { func (m aberrantMessage) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
return
} }
func (m aberrantMessage) Has(pref.FieldDescriptor) bool { func (m aberrantMessage) Has(pref.FieldDescriptor) bool {
panic("invalid field descriptor") return false
} }
func (m aberrantMessage) Clear(pref.FieldDescriptor) { func (m aberrantMessage) Clear(pref.FieldDescriptor) {
panic("invalid field descriptor") panic("invalid Message.Clear on " + string(m.Descriptor().FullName()))
} }
func (m aberrantMessage) Get(pref.FieldDescriptor) pref.Value { func (m aberrantMessage) Get(fd pref.FieldDescriptor) pref.Value {
panic("invalid field descriptor") if fd.Default().IsValid() {
return fd.Default()
}
panic("invalid Message.Get on " + string(m.Descriptor().FullName()))
} }
func (m aberrantMessage) Set(pref.FieldDescriptor, pref.Value) { func (m aberrantMessage) Set(pref.FieldDescriptor, pref.Value) {
panic("invalid field descriptor") panic("invalid Message.Set on " + string(m.Descriptor().FullName()))
} }
func (m aberrantMessage) Mutable(pref.FieldDescriptor) pref.Value { func (m aberrantMessage) Mutable(pref.FieldDescriptor) pref.Value {
panic("invalid field descriptor") panic("invalid Message.Mutable on " + string(m.Descriptor().FullName()))
} }
func (m aberrantMessage) NewField(pref.FieldDescriptor) pref.Value { func (m aberrantMessage) NewField(pref.FieldDescriptor) pref.Value {
panic("invalid field descriptor") panic("invalid Message.NewField on " + string(m.Descriptor().FullName()))
} }
func (m aberrantMessage) WhichOneof(pref.OneofDescriptor) pref.FieldDescriptor { func (m aberrantMessage) WhichOneof(pref.OneofDescriptor) pref.FieldDescriptor {
panic("invalid oneof descriptor") panic("invalid Message.WhichOneof descriptor on " + string(m.Descriptor().FullName()))
} }
func (m aberrantMessage) GetUnknown() pref.RawFields { func (m aberrantMessage) GetUnknown() pref.RawFields {
return nil return nil
@ -489,13 +545,13 @@ func (m aberrantMessage) SetUnknown(pref.RawFields) {
// SetUnknown discards its input on messages which don't support unknown field storage. // SetUnknown discards its input on messages which don't support unknown field storage.
} }
func (m aberrantMessage) IsValid() bool { func (m aberrantMessage) IsValid() bool {
// An invalid message is a read-only, empty message. Since we don't know anything if m.v.Kind() == reflect.Ptr {
// about the alleged contents of this message, we can't say with confidence that return !m.v.IsNil()
// it is invalid in this sense. Therefore, report it as valid. }
return true return false
} }
func (m aberrantMessage) ProtoMethods() *piface.Methods { func (m aberrantMessage) ProtoMethods() *piface.Methods {
return legacyProtoMethods return aberrantProtoMethods
} }
func (m aberrantMessage) protoUnwrap() interface{} { func (m aberrantMessage) protoUnwrap() interface{} {
return m.v.Interface() return m.v.Interface()

View File

@ -77,9 +77,9 @@ func (mi *MessageInfo) mergePointer(dst, src pointer, opts mergeOptions) {
} }
} }
if mi.unknownOffset.IsValid() { if mi.unknownOffset.IsValid() {
du := dst.Apply(mi.unknownOffset).Bytes() su := mi.getUnknownBytes(src)
su := src.Apply(mi.unknownOffset).Bytes() if su != nil && len(*su) > 0 {
if len(*su) > 0 { du := mi.mutableUnknownBytes(dst)
*du = append(*du, *su...) *du = append(*du, *su...)
} }
} }

View File

@ -15,6 +15,7 @@ import (
"google.golang.org/protobuf/internal/genid" "google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
) )
// MessageInfo provides protobuf related functionality for a given Go type // MessageInfo provides protobuf related functionality for a given Go type
@ -109,22 +110,29 @@ func (mi *MessageInfo) getPointer(m pref.Message) (p pointer, ok bool) {
type ( type (
SizeCache = int32 SizeCache = int32
WeakFields = map[int32]protoreflect.ProtoMessage WeakFields = map[int32]protoreflect.ProtoMessage
UnknownFields = []byte UnknownFields = unknownFieldsA // TODO: switch to unknownFieldsB
unknownFieldsA = []byte
unknownFieldsB = *[]byte
ExtensionFields = map[int32]ExtensionField ExtensionFields = map[int32]ExtensionField
) )
var ( var (
sizecacheType = reflect.TypeOf(SizeCache(0)) sizecacheType = reflect.TypeOf(SizeCache(0))
weakFieldsType = reflect.TypeOf(WeakFields(nil)) weakFieldsType = reflect.TypeOf(WeakFields(nil))
unknownFieldsType = reflect.TypeOf(UnknownFields(nil)) unknownFieldsAType = reflect.TypeOf(unknownFieldsA(nil))
unknownFieldsBType = reflect.TypeOf(unknownFieldsB(nil))
extensionFieldsType = reflect.TypeOf(ExtensionFields(nil)) extensionFieldsType = reflect.TypeOf(ExtensionFields(nil))
) )
type structInfo struct { type structInfo struct {
sizecacheOffset offset sizecacheOffset offset
sizecacheType reflect.Type
weakOffset offset weakOffset offset
weakType reflect.Type
unknownOffset offset unknownOffset offset
unknownType reflect.Type
extensionOffset offset extensionOffset offset
extensionType reflect.Type
fieldsByNumber map[pref.FieldNumber]reflect.StructField fieldsByNumber map[pref.FieldNumber]reflect.StructField
oneofsByName map[pref.Name]reflect.StructField oneofsByName map[pref.Name]reflect.StructField
@ -151,18 +159,22 @@ fieldLoop:
case genid.SizeCache_goname, genid.SizeCacheA_goname: case genid.SizeCache_goname, genid.SizeCacheA_goname:
if f.Type == sizecacheType { if f.Type == sizecacheType {
si.sizecacheOffset = offsetOf(f, mi.Exporter) si.sizecacheOffset = offsetOf(f, mi.Exporter)
si.sizecacheType = f.Type
} }
case genid.WeakFields_goname, genid.WeakFieldsA_goname: case genid.WeakFields_goname, genid.WeakFieldsA_goname:
if f.Type == weakFieldsType { if f.Type == weakFieldsType {
si.weakOffset = offsetOf(f, mi.Exporter) si.weakOffset = offsetOf(f, mi.Exporter)
si.weakType = f.Type
} }
case genid.UnknownFields_goname, genid.UnknownFieldsA_goname: case genid.UnknownFields_goname, genid.UnknownFieldsA_goname:
if f.Type == unknownFieldsType { if f.Type == unknownFieldsAType || f.Type == unknownFieldsBType {
si.unknownOffset = offsetOf(f, mi.Exporter) si.unknownOffset = offsetOf(f, mi.Exporter)
si.unknownType = f.Type
} }
case genid.ExtensionFields_goname, genid.ExtensionFieldsA_goname, genid.ExtensionFieldsB_goname: case genid.ExtensionFields_goname, genid.ExtensionFieldsA_goname, genid.ExtensionFieldsB_goname:
if f.Type == extensionFieldsType { if f.Type == extensionFieldsType {
si.extensionOffset = offsetOf(f, mi.Exporter) si.extensionOffset = offsetOf(f, mi.Exporter)
si.extensionType = f.Type
} }
default: default:
for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") { for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") {
@ -212,4 +224,53 @@ func (mi *MessageInfo) New() protoreflect.Message {
func (mi *MessageInfo) Zero() protoreflect.Message { func (mi *MessageInfo) Zero() protoreflect.Message {
return mi.MessageOf(reflect.Zero(mi.GoReflectType).Interface()) return mi.MessageOf(reflect.Zero(mi.GoReflectType).Interface())
} }
func (mi *MessageInfo) Descriptor() protoreflect.MessageDescriptor { return mi.Desc } func (mi *MessageInfo) Descriptor() protoreflect.MessageDescriptor {
return mi.Desc
}
func (mi *MessageInfo) Enum(i int) protoreflect.EnumType {
mi.init()
fd := mi.Desc.Fields().Get(i)
return Export{}.EnumTypeOf(mi.fieldTypes[fd.Number()])
}
func (mi *MessageInfo) Message(i int) protoreflect.MessageType {
mi.init()
fd := mi.Desc.Fields().Get(i)
switch {
case fd.IsWeak():
mt, _ := preg.GlobalTypes.FindMessageByName(fd.Message().FullName())
return mt
case fd.IsMap():
return mapEntryType{fd.Message(), mi.fieldTypes[fd.Number()]}
default:
return Export{}.MessageTypeOf(mi.fieldTypes[fd.Number()])
}
}
type mapEntryType struct {
desc protoreflect.MessageDescriptor
valType interface{} // zero value of enum or message type
}
func (mt mapEntryType) New() protoreflect.Message {
return nil
}
func (mt mapEntryType) Zero() protoreflect.Message {
return nil
}
func (mt mapEntryType) Descriptor() protoreflect.MessageDescriptor {
return mt.desc
}
func (mt mapEntryType) Enum(i int) protoreflect.EnumType {
fd := mt.desc.Fields().Get(i)
if fd.Enum() == nil {
return nil
}
return Export{}.EnumTypeOf(mt.valType)
}
func (mt mapEntryType) Message(i int) protoreflect.MessageType {
fd := mt.desc.Fields().Get(i)
if fd.Message() == nil {
return nil
}
return Export{}.MessageTypeOf(mt.valType)
}

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/pragma"
pref "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect"
) )
@ -16,6 +17,11 @@ type reflectMessageInfo struct {
fields map[pref.FieldNumber]*fieldInfo fields map[pref.FieldNumber]*fieldInfo
oneofs map[pref.Name]*oneofInfo oneofs map[pref.Name]*oneofInfo
// fieldTypes contains the zero value of an enum or message field.
// For lists, it contains the element type.
// For maps, it contains the entry value type.
fieldTypes map[pref.FieldNumber]interface{}
// denseFields is a subset of fields where: // denseFields is a subset of fields where:
// 0 < fieldDesc.Number() < len(denseFields) // 0 < fieldDesc.Number() < len(denseFields)
// It provides faster access to the fieldInfo, but may be incomplete. // It provides faster access to the fieldInfo, but may be incomplete.
@ -36,6 +42,7 @@ func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
mi.makeKnownFieldsFunc(si) mi.makeKnownFieldsFunc(si)
mi.makeUnknownFieldsFunc(t, si) mi.makeUnknownFieldsFunc(t, si)
mi.makeExtensionFieldsFunc(t, si) mi.makeExtensionFieldsFunc(t, si)
mi.makeFieldTypes(si)
} }
// makeKnownFieldsFunc generates functions for operations that can be performed // makeKnownFieldsFunc generates functions for operations that can be performed
@ -51,17 +58,23 @@ func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
for i := 0; i < fds.Len(); i++ { for i := 0; i < fds.Len(); i++ {
fd := fds.Get(i) fd := fds.Get(i)
fs := si.fieldsByNumber[fd.Number()] fs := si.fieldsByNumber[fd.Number()]
isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
if isOneof {
fs = si.oneofsByName[fd.ContainingOneof().Name()]
}
var fi fieldInfo var fi fieldInfo
switch { switch {
case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic(): case fs.Type == nil:
fi = fieldInfoForOneof(fd, si.oneofsByName[fd.ContainingOneof().Name()], mi.Exporter, si.oneofWrappersByNumber[fd.Number()]) fi = fieldInfoForMissing(fd) // never occurs for officially generated message types
case isOneof:
fi = fieldInfoForOneof(fd, fs, mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
case fd.IsMap(): case fd.IsMap():
fi = fieldInfoForMap(fd, fs, mi.Exporter) fi = fieldInfoForMap(fd, fs, mi.Exporter)
case fd.IsList(): case fd.IsList():
fi = fieldInfoForList(fd, fs, mi.Exporter) fi = fieldInfoForList(fd, fs, mi.Exporter)
case fd.IsWeak(): case fd.IsWeak():
fi = fieldInfoForWeakMessage(fd, si.weakOffset) fi = fieldInfoForWeakMessage(fd, si.weakOffset)
case fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind: case fd.Message() != nil:
fi = fieldInfoForMessage(fd, fs, mi.Exporter) fi = fieldInfoForMessage(fd, fs, mi.Exporter)
default: default:
fi = fieldInfoForScalar(fd, fs, mi.Exporter) fi = fieldInfoForScalar(fd, fs, mi.Exporter)
@ -92,27 +105,53 @@ func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
i++ i++
} }
} }
// Introduce instability to iteration order, but keep it deterministic.
if len(mi.rangeInfos) > 1 && detrand.Bool() {
i := detrand.Intn(len(mi.rangeInfos) - 1)
mi.rangeInfos[i], mi.rangeInfos[i+1] = mi.rangeInfos[i+1], mi.rangeInfos[i]
}
} }
func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) { func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
mi.getUnknown = func(pointer) pref.RawFields { return nil } switch {
mi.setUnknown = func(pointer, pref.RawFields) { return } case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
if si.unknownOffset.IsValid() { // Handle as []byte.
mi.getUnknown = func(p pointer) pref.RawFields { mi.getUnknown = func(p pointer) pref.RawFields {
if p.IsNil() { if p.IsNil() {
return nil return nil
} }
rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType) return *p.Apply(mi.unknownOffset).Bytes()
return pref.RawFields(*rv.Interface().(*[]byte))
} }
mi.setUnknown = func(p pointer, b pref.RawFields) { mi.setUnknown = func(p pointer, b pref.RawFields) {
if p.IsNil() { if p.IsNil() {
panic("invalid SetUnknown on nil Message") panic("invalid SetUnknown on nil Message")
} }
rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType) *p.Apply(mi.unknownOffset).Bytes() = b
*rv.Interface().(*[]byte) = []byte(b)
} }
} else { case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
// Handle as *[]byte.
mi.getUnknown = func(p pointer) pref.RawFields {
if p.IsNil() {
return nil
}
bp := p.Apply(mi.unknownOffset).BytesPtr()
if *bp == nil {
return nil
}
return **bp
}
mi.setUnknown = func(p pointer, b pref.RawFields) {
if p.IsNil() {
panic("invalid SetUnknown on nil Message")
}
bp := p.Apply(mi.unknownOffset).BytesPtr()
if *bp == nil {
*bp = new([]byte)
}
**bp = b
}
default:
mi.getUnknown = func(pointer) pref.RawFields { mi.getUnknown = func(pointer) pref.RawFields {
return nil return nil
} }
@ -139,6 +178,58 @@ func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
} }
} }
} }
func (mi *MessageInfo) makeFieldTypes(si structInfo) {
md := mi.Desc
fds := md.Fields()
for i := 0; i < fds.Len(); i++ {
var ft reflect.Type
fd := fds.Get(i)
fs := si.fieldsByNumber[fd.Number()]
isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
if isOneof {
fs = si.oneofsByName[fd.ContainingOneof().Name()]
}
var isMessage bool
switch {
case fs.Type == nil:
continue // never occurs for officially generated message types
case isOneof:
if fd.Enum() != nil || fd.Message() != nil {
ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
}
case fd.IsMap():
if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
ft = fs.Type.Elem()
}
isMessage = fd.MapValue().Message() != nil
case fd.IsList():
if fd.Enum() != nil || fd.Message() != nil {
ft = fs.Type.Elem()
}
isMessage = fd.Message() != nil
case fd.Enum() != nil:
ft = fs.Type
if fd.HasPresence() && ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
case fd.Message() != nil:
ft = fs.Type
if fd.IsWeak() {
ft = nil
}
isMessage = true
}
if isMessage && ft != nil && ft.Kind() != reflect.Ptr {
ft = reflect.PtrTo(ft) // never occurs for officially generated message types
}
if ft != nil {
if mi.fieldTypes == nil {
mi.fieldTypes = make(map[pref.FieldNumber]interface{})
}
mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
}
}
}
type extensionMap map[int32]ExtensionField type extensionMap map[int32]ExtensionField
@ -306,7 +397,6 @@ var (
// pointer to a named Go struct. If the provided type has a ProtoReflect method, // pointer to a named Go struct. If the provided type has a ProtoReflect method,
// it must be implemented by calling this method. // it must be implemented by calling this method.
func (mi *MessageInfo) MessageOf(m interface{}) pref.Message { func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
// TODO: Switch the input to be an opaque Pointer.
if reflect.TypeOf(m) != mi.GoReflectType { if reflect.TypeOf(m) != mi.GoReflectType {
panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType)) panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
} }
@ -320,6 +410,17 @@ func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
func (m *messageReflectWrapper) pointer() pointer { return m.p } func (m *messageReflectWrapper) pointer() pointer { return m.p }
func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi } func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
// Reset implements the v1 proto.Message.Reset method.
func (m *messageIfaceWrapper) Reset() {
if mr, ok := m.protoUnwrap().(interface{ Reset() }); ok {
mr.Reset()
return
}
rv := reflect.ValueOf(m.protoUnwrap())
if rv.Kind() == reflect.Ptr && !rv.IsNil() {
rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
}
}
func (m *messageIfaceWrapper) ProtoReflect() pref.Message { func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
return (*messageReflectWrapper)(m) return (*messageReflectWrapper)(m)
} }

View File

@ -28,6 +28,39 @@ type fieldInfo struct {
newField func() pref.Value newField func() pref.Value
} }
func fieldInfoForMissing(fd pref.FieldDescriptor) fieldInfo {
// This never occurs for generated message types.
// It implies that a hand-crafted type has missing Go fields
// for specific protobuf message fields.
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
return false
},
clear: func(p pointer) {
panic("missing Go struct field for " + string(fd.FullName()))
},
get: func(p pointer) pref.Value {
return fd.Default()
},
set: func(p pointer, v pref.Value) {
panic("missing Go struct field for " + string(fd.FullName()))
},
mutable: func(p pointer) pref.Value {
panic("missing Go struct field for " + string(fd.FullName()))
},
newMessage: func() pref.Message {
panic("missing Go struct field for " + string(fd.FullName()))
},
newField: func() pref.Value {
if v := fd.Default(); v.IsValid() {
return v
}
panic("missing Go struct field for " + string(fd.FullName()))
},
}
}
func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x exporter, ot reflect.Type) fieldInfo { func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x exporter, ot reflect.Type) fieldInfo {
ft := fs.Type ft := fs.Type
if ft.Kind() != reflect.Interface { if ft.Kind() != reflect.Interface {
@ -97,7 +130,7 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x export
rv.Set(reflect.New(ot)) rv.Set(reflect.New(ot))
} }
rv = rv.Elem().Elem().Field(0) rv = rv.Elem().Elem().Field(0)
if rv.IsNil() { if rv.Kind() == reflect.Ptr && rv.IsNil() {
rv.Set(conv.GoValueOf(pref.ValueOfMessage(conv.New().Message()))) rv.Set(conv.GoValueOf(pref.ValueOfMessage(conv.New().Message())))
} }
return conv.PBValueOf(rv) return conv.PBValueOf(rv)
@ -225,7 +258,10 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField, x expor
isBytes := ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 isBytes := ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8
if nullable { if nullable {
if ft.Kind() != reflect.Ptr && ft.Kind() != reflect.Slice { if ft.Kind() != reflect.Ptr && ft.Kind() != reflect.Slice {
panic(fmt.Sprintf("field %v has invalid type: got %v, want pointer", fd.FullName(), ft)) // This never occurs for generated message types.
// Despite the protobuf type system specifying presence,
// the Go field type cannot represent it.
nullable = false
} }
if ft.Kind() == reflect.Ptr { if ft.Kind() == reflect.Ptr {
ft = ft.Elem() ft = ft.Elem()
@ -388,6 +424,9 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField, x expo
return false return false
} }
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if fs.Type.Kind() != reflect.Ptr {
return !isZero(rv)
}
return !rv.IsNil() return !rv.IsNil()
}, },
clear: func(p pointer) { clear: func(p pointer) {
@ -404,13 +443,13 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField, x expo
set: func(p pointer, v pref.Value) { set: func(p pointer, v pref.Value) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(conv.GoValueOf(v)) rv.Set(conv.GoValueOf(v))
if rv.IsNil() { if fs.Type.Kind() == reflect.Ptr && rv.IsNil() {
panic(fmt.Sprintf("field %v has invalid nil pointer", fd.FullName())) panic(fmt.Sprintf("field %v has invalid nil pointer", fd.FullName()))
} }
}, },
mutable: func(p pointer) pref.Value { mutable: func(p pointer) pref.Value {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem() rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() { if fs.Type.Kind() == reflect.Ptr && rv.IsNil() {
rv.Set(conv.GoValueOf(conv.New())) rv.Set(conv.GoValueOf(conv.New()))
} }
return conv.PBValueOf(rv) return conv.PBValueOf(rv)
@ -464,3 +503,41 @@ func makeOneofInfo(od pref.OneofDescriptor, si structInfo, x exporter) *oneofInf
} }
return oi return oi
} }
// isZero is identical to reflect.Value.IsZero.
// TODO: Remove this when Go1.13 is the minimally supported Go version.
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return math.Float64bits(v.Float()) == 0
case reflect.Complex64, reflect.Complex128:
c := v.Complex()
return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0
case reflect.Array:
for i := 0; i < v.Len(); i++ {
if !isZero(v.Index(i)) {
return false
}
}
return true
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
return v.IsNil()
case reflect.String:
return v.Len() == 0
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
if !isZero(v.Field(i)) {
return false
}
}
return true
default:
panic(&reflect.ValueError{"reflect.Value.IsZero", v.Kind()})
}
}

View File

@ -121,6 +121,7 @@ func (p pointer) String() *string { return p.v.Interface().(*string) }
func (p pointer) StringPtr() **string { return p.v.Interface().(**string) } func (p pointer) StringPtr() **string { return p.v.Interface().(**string) }
func (p pointer) StringSlice() *[]string { return p.v.Interface().(*[]string) } func (p pointer) StringSlice() *[]string { return p.v.Interface().(*[]string) }
func (p pointer) Bytes() *[]byte { return p.v.Interface().(*[]byte) } func (p pointer) Bytes() *[]byte { return p.v.Interface().(*[]byte) }
func (p pointer) BytesPtr() **[]byte { return p.v.Interface().(**[]byte) }
func (p pointer) BytesSlice() *[][]byte { return p.v.Interface().(*[][]byte) } func (p pointer) BytesSlice() *[][]byte { return p.v.Interface().(*[][]byte) }
func (p pointer) WeakFields() *weakFields { return (*weakFields)(p.v.Interface().(*WeakFields)) } func (p pointer) WeakFields() *weakFields { return (*weakFields)(p.v.Interface().(*WeakFields)) }
func (p pointer) Extensions() *map[int32]ExtensionField { func (p pointer) Extensions() *map[int32]ExtensionField {

View File

@ -109,6 +109,7 @@ func (p pointer) String() *string { return (*string)(p.p)
func (p pointer) StringPtr() **string { return (**string)(p.p) } func (p pointer) StringPtr() **string { return (**string)(p.p) }
func (p pointer) StringSlice() *[]string { return (*[]string)(p.p) } func (p pointer) StringSlice() *[]string { return (*[]string)(p.p) }
func (p pointer) Bytes() *[]byte { return (*[]byte)(p.p) } func (p pointer) Bytes() *[]byte { return (*[]byte)(p.p) }
func (p pointer) BytesPtr() **[]byte { return (**[]byte)(p.p) }
func (p pointer) BytesSlice() *[][]byte { return (*[][]byte)(p.p) } func (p pointer) BytesSlice() *[][]byte { return (*[][]byte)(p.p) }
func (p pointer) WeakFields() *weakFields { return (*weakFields)(p.p) } func (p pointer) WeakFields() *weakFields { return (*weakFields)(p.p) }
func (p pointer) Extensions() *map[int32]ExtensionField { return (*map[int32]ExtensionField)(p.p) } func (p pointer) Extensions() *map[int32]ExtensionField { return (*map[int32]ExtensionField)(p.p) }

View File

@ -1,43 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package mapsort provides sorted access to maps.
package mapsort
import (
"sort"
"google.golang.org/protobuf/reflect/protoreflect"
)
// Range iterates over every map entry in sorted key order,
// calling f for each key and value encountered.
func Range(mapv protoreflect.Map, keyKind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) {
var keys []protoreflect.MapKey
mapv.Range(func(key protoreflect.MapKey, _ protoreflect.Value) bool {
keys = append(keys, key)
return true
})
sort.Slice(keys, func(i, j int) bool {
switch keyKind {
case protoreflect.BoolKind:
return !keys[i].Bool() && keys[j].Bool()
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
return keys[i].Int() < keys[j].Int()
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
return keys[i].Uint() < keys[j].Uint()
case protoreflect.StringKind:
return keys[i].String() < keys[j].String()
default:
panic("invalid kind: " + keyKind.String())
}
})
for _, key := range keys {
if !f(key, mapv.Get(key)) {
break
}
}
}

View File

@ -0,0 +1,89 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package order
import (
pref "google.golang.org/protobuf/reflect/protoreflect"
)
// FieldOrder specifies the ordering to visit message fields.
// It is a function that reports whether x is ordered before y.
type FieldOrder func(x, y pref.FieldDescriptor) bool
var (
// AnyFieldOrder specifies no specific field ordering.
AnyFieldOrder FieldOrder = nil
// LegacyFieldOrder sorts fields in the same ordering as emitted by
// wire serialization in the github.com/golang/protobuf implementation.
LegacyFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
ox, oy := x.ContainingOneof(), y.ContainingOneof()
inOneof := func(od pref.OneofDescriptor) bool {
return od != nil && !od.IsSynthetic()
}
// Extension fields sort before non-extension fields.
if x.IsExtension() != y.IsExtension() {
return x.IsExtension() && !y.IsExtension()
}
// Fields not within a oneof sort before those within a oneof.
if inOneof(ox) != inOneof(oy) {
return !inOneof(ox) && inOneof(oy)
}
// Fields in disjoint oneof sets are sorted by declaration index.
if ox != nil && oy != nil && ox != oy {
return ox.Index() < oy.Index()
}
// Fields sorted by field number.
return x.Number() < y.Number()
}
// NumberFieldOrder sorts fields by their field number.
NumberFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
return x.Number() < y.Number()
}
// IndexNameFieldOrder sorts non-extension fields before extension fields.
// Non-extensions are sorted according to their declaration index.
// Extensions are sorted according to their full name.
IndexNameFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
// Non-extension fields sort before extension fields.
if x.IsExtension() != y.IsExtension() {
return !x.IsExtension() && y.IsExtension()
}
// Extensions sorted by fullname.
if x.IsExtension() && y.IsExtension() {
return x.FullName() < y.FullName()
}
// Non-extensions sorted by declaration index.
return x.Index() < y.Index()
}
)
// KeyOrder specifies the ordering to visit map entries.
// It is a function that reports whether x is ordered before y.
type KeyOrder func(x, y pref.MapKey) bool
var (
// AnyKeyOrder specifies no specific key ordering.
AnyKeyOrder KeyOrder = nil
// GenericKeyOrder sorts false before true, numeric keys in ascending order,
// and strings in lexicographical ordering according to UTF-8 codepoints.
GenericKeyOrder KeyOrder = func(x, y pref.MapKey) bool {
switch x.Interface().(type) {
case bool:
return !x.Bool() && y.Bool()
case int32, int64:
return x.Int() < y.Int()
case uint32, uint64:
return x.Uint() < y.Uint()
case string:
return x.String() < y.String()
default:
panic("invalid map key type")
}
}
)

View File

@ -0,0 +1,115 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package order provides ordered access to messages and maps.
package order
import (
"sort"
"sync"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
type messageField struct {
fd pref.FieldDescriptor
v pref.Value
}
var messageFieldPool = sync.Pool{
New: func() interface{} { return new([]messageField) },
}
type (
// FieldRnger is an interface for visiting all fields in a message.
// The protoreflect.Message type implements this interface.
FieldRanger interface{ Range(VisitField) }
// VisitField is called everytime a message field is visited.
VisitField = func(pref.FieldDescriptor, pref.Value) bool
)
// RangeFields iterates over the fields of fs according to the specified order.
func RangeFields(fs FieldRanger, less FieldOrder, fn VisitField) {
if less == nil {
fs.Range(fn)
return
}
// Obtain a pre-allocated scratch buffer.
p := messageFieldPool.Get().(*[]messageField)
fields := (*p)[:0]
defer func() {
if cap(fields) < 1024 {
*p = fields
messageFieldPool.Put(p)
}
}()
// Collect all fields in the message and sort them.
fs.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
fields = append(fields, messageField{fd, v})
return true
})
sort.Slice(fields, func(i, j int) bool {
return less(fields[i].fd, fields[j].fd)
})
// Visit the fields in the specified ordering.
for _, f := range fields {
if !fn(f.fd, f.v) {
return
}
}
}
type mapEntry struct {
k pref.MapKey
v pref.Value
}
var mapEntryPool = sync.Pool{
New: func() interface{} { return new([]mapEntry) },
}
type (
// EntryRanger is an interface for visiting all fields in a message.
// The protoreflect.Map type implements this interface.
EntryRanger interface{ Range(VisitEntry) }
// VisitEntry is called everytime a map entry is visited.
VisitEntry = func(pref.MapKey, pref.Value) bool
)
// RangeEntries iterates over the entries of es according to the specified order.
func RangeEntries(es EntryRanger, less KeyOrder, fn VisitEntry) {
if less == nil {
es.Range(fn)
return
}
// Obtain a pre-allocated scratch buffer.
p := mapEntryPool.Get().(*[]mapEntry)
entries := (*p)[:0]
defer func() {
if cap(entries) < 1024 {
*p = entries
mapEntryPool.Put(p)
}
}()
// Collect all entries in the map and sort them.
es.Range(func(k pref.MapKey, v pref.Value) bool {
entries = append(entries, mapEntry{k, v})
return true
})
sort.Slice(entries, func(i, j int) bool {
return less(entries[i].k, entries[j].k)
})
// Visit the entries in the specified ordering.
for _, e := range entries {
if !fn(e.k, e.v) {
return
}
}
}

View File

@ -52,7 +52,7 @@ import (
// 10. Send out the CL for review and submit it. // 10. Send out the CL for review and submit it.
const ( const (
Major = 1 Major = 1
Minor = 25 Minor = 26
Patch = 0 Patch = 0
PreRelease = "" PreRelease = ""
) )

View File

@ -45,12 +45,14 @@ type UnmarshalOptions struct {
} }
// Unmarshal parses the wire-format message in b and places the result in m. // Unmarshal parses the wire-format message in b and places the result in m.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func Unmarshal(b []byte, m Message) error { func Unmarshal(b []byte, m Message) error {
_, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect()) _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
return err return err
} }
// Unmarshal parses the wire-format message in b and places the result in m. // Unmarshal parses the wire-format message in b and places the result in m.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error { func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
_, err := o.unmarshal(b, m.ProtoReflect()) _, err := o.unmarshal(b, m.ProtoReflect())
return err return err
@ -116,10 +118,10 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
// Parse the tag (field number and wire type). // Parse the tag (field number and wire type).
num, wtyp, tagLen := protowire.ConsumeTag(b) num, wtyp, tagLen := protowire.ConsumeTag(b)
if tagLen < 0 { if tagLen < 0 {
return protowire.ParseError(tagLen) return errDecode
} }
if num > protowire.MaxValidNumber { if num > protowire.MaxValidNumber {
return errors.New("invalid field number") return errDecode
} }
// Find the field descriptor for this field number. // Find the field descriptor for this field number.
@ -159,7 +161,7 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
} }
valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:]) valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
if valLen < 0 { if valLen < 0 {
return protowire.ParseError(valLen) return errDecode
} }
if !o.DiscardUnknown { if !o.DiscardUnknown {
m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...)) m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
@ -194,7 +196,7 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv proto
} }
b, n = protowire.ConsumeBytes(b) b, n = protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
var ( var (
keyField = fd.MapKey() keyField = fd.MapKey()
@ -213,10 +215,10 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv proto
for len(b) > 0 { for len(b) > 0 {
num, wtyp, n := protowire.ConsumeTag(b) num, wtyp, n := protowire.ConsumeTag(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
if num > protowire.MaxValidNumber { if num > protowire.MaxValidNumber {
return 0, errors.New("invalid field number") return 0, errDecode
} }
b = b[n:] b = b[n:]
err = errUnknown err = errUnknown
@ -246,7 +248,7 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv proto
if err == errUnknown { if err == errUnknown {
n = protowire.ConsumeFieldValue(num, wtyp, b) n = protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
} else if err != nil { } else if err != nil {
return 0, err return 0, err
@ -272,3 +274,5 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv proto
// to the unknown field set of a message. It is never returned from an exported // to the unknown field set of a message. It is never returned from an exported
// function. // function.
var errUnknown = errors.New("BUG: internal error (unknown)") var errUnknown = errors.New("BUG: internal error (unknown)")
var errDecode = errors.New("cannot parse invalid wire-format data")

View File

@ -27,7 +27,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfBool(protowire.DecodeBool(v)), n, nil return protoreflect.ValueOfBool(protowire.DecodeBool(v)), n, nil
case protoreflect.EnumKind: case protoreflect.EnumKind:
@ -36,7 +36,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfEnum(protoreflect.EnumNumber(v)), n, nil return protoreflect.ValueOfEnum(protoreflect.EnumNumber(v)), n, nil
case protoreflect.Int32Kind: case protoreflect.Int32Kind:
@ -45,7 +45,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfInt32(int32(v)), n, nil return protoreflect.ValueOfInt32(int32(v)), n, nil
case protoreflect.Sint32Kind: case protoreflect.Sint32Kind:
@ -54,7 +54,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32))), n, nil return protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32))), n, nil
case protoreflect.Uint32Kind: case protoreflect.Uint32Kind:
@ -63,7 +63,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfUint32(uint32(v)), n, nil return protoreflect.ValueOfUint32(uint32(v)), n, nil
case protoreflect.Int64Kind: case protoreflect.Int64Kind:
@ -72,7 +72,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfInt64(int64(v)), n, nil return protoreflect.ValueOfInt64(int64(v)), n, nil
case protoreflect.Sint64Kind: case protoreflect.Sint64Kind:
@ -81,7 +81,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfInt64(protowire.DecodeZigZag(v)), n, nil return protoreflect.ValueOfInt64(protowire.DecodeZigZag(v)), n, nil
case protoreflect.Uint64Kind: case protoreflect.Uint64Kind:
@ -90,7 +90,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfUint64(v), n, nil return protoreflect.ValueOfUint64(v), n, nil
case protoreflect.Sfixed32Kind: case protoreflect.Sfixed32Kind:
@ -99,7 +99,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeFixed32(b) v, n := protowire.ConsumeFixed32(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfInt32(int32(v)), n, nil return protoreflect.ValueOfInt32(int32(v)), n, nil
case protoreflect.Fixed32Kind: case protoreflect.Fixed32Kind:
@ -108,7 +108,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeFixed32(b) v, n := protowire.ConsumeFixed32(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfUint32(uint32(v)), n, nil return protoreflect.ValueOfUint32(uint32(v)), n, nil
case protoreflect.FloatKind: case protoreflect.FloatKind:
@ -117,7 +117,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeFixed32(b) v, n := protowire.ConsumeFixed32(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v))), n, nil return protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v))), n, nil
case protoreflect.Sfixed64Kind: case protoreflect.Sfixed64Kind:
@ -126,7 +126,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeFixed64(b) v, n := protowire.ConsumeFixed64(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfInt64(int64(v)), n, nil return protoreflect.ValueOfInt64(int64(v)), n, nil
case protoreflect.Fixed64Kind: case protoreflect.Fixed64Kind:
@ -135,7 +135,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeFixed64(b) v, n := protowire.ConsumeFixed64(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfUint64(v), n, nil return protoreflect.ValueOfUint64(v), n, nil
case protoreflect.DoubleKind: case protoreflect.DoubleKind:
@ -144,7 +144,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeFixed64(b) v, n := protowire.ConsumeFixed64(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfFloat64(math.Float64frombits(v)), n, nil return protoreflect.ValueOfFloat64(math.Float64frombits(v)), n, nil
case protoreflect.StringKind: case protoreflect.StringKind:
@ -153,7 +153,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
if strs.EnforceUTF8(fd) && !utf8.Valid(v) { if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName())) return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
@ -165,7 +165,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfBytes(append(emptyBuf[:], v...)), n, nil return protoreflect.ValueOfBytes(append(emptyBuf[:], v...)), n, nil
case protoreflect.MessageKind: case protoreflect.MessageKind:
@ -174,7 +174,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfBytes(v), n, nil return protoreflect.ValueOfBytes(v), n, nil
case protoreflect.GroupKind: case protoreflect.GroupKind:
@ -183,7 +183,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd prot
} }
v, n := protowire.ConsumeGroup(fd.Number(), b) v, n := protowire.ConsumeGroup(fd.Number(), b)
if n < 0 { if n < 0 {
return val, 0, protowire.ParseError(n) return val, 0, errDecode
} }
return protoreflect.ValueOfBytes(v), n, nil return protoreflect.ValueOfBytes(v), n, nil
default: default:
@ -197,12 +197,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeVarint(buf) v, n := protowire.ConsumeVarint(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfBool(protowire.DecodeBool(v))) list.Append(protoreflect.ValueOfBool(protowire.DecodeBool(v)))
@ -214,7 +214,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfBool(protowire.DecodeBool(v))) list.Append(protoreflect.ValueOfBool(protowire.DecodeBool(v)))
return n, nil return n, nil
@ -222,12 +222,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeVarint(buf) v, n := protowire.ConsumeVarint(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))) list.Append(protoreflect.ValueOfEnum(protoreflect.EnumNumber(v)))
@ -239,7 +239,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))) list.Append(protoreflect.ValueOfEnum(protoreflect.EnumNumber(v)))
return n, nil return n, nil
@ -247,12 +247,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeVarint(buf) v, n := protowire.ConsumeVarint(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfInt32(int32(v))) list.Append(protoreflect.ValueOfInt32(int32(v)))
@ -264,7 +264,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfInt32(int32(v))) list.Append(protoreflect.ValueOfInt32(int32(v)))
return n, nil return n, nil
@ -272,12 +272,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeVarint(buf) v, n := protowire.ConsumeVarint(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32)))) list.Append(protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32))))
@ -289,7 +289,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32)))) list.Append(protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32))))
return n, nil return n, nil
@ -297,12 +297,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeVarint(buf) v, n := protowire.ConsumeVarint(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfUint32(uint32(v))) list.Append(protoreflect.ValueOfUint32(uint32(v)))
@ -314,7 +314,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfUint32(uint32(v))) list.Append(protoreflect.ValueOfUint32(uint32(v)))
return n, nil return n, nil
@ -322,12 +322,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeVarint(buf) v, n := protowire.ConsumeVarint(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfInt64(int64(v))) list.Append(protoreflect.ValueOfInt64(int64(v)))
@ -339,7 +339,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfInt64(int64(v))) list.Append(protoreflect.ValueOfInt64(int64(v)))
return n, nil return n, nil
@ -347,12 +347,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeVarint(buf) v, n := protowire.ConsumeVarint(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfInt64(protowire.DecodeZigZag(v))) list.Append(protoreflect.ValueOfInt64(protowire.DecodeZigZag(v)))
@ -364,7 +364,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfInt64(protowire.DecodeZigZag(v))) list.Append(protoreflect.ValueOfInt64(protowire.DecodeZigZag(v)))
return n, nil return n, nil
@ -372,12 +372,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeVarint(buf) v, n := protowire.ConsumeVarint(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfUint64(v)) list.Append(protoreflect.ValueOfUint64(v))
@ -389,7 +389,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeVarint(b) v, n := protowire.ConsumeVarint(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfUint64(v)) list.Append(protoreflect.ValueOfUint64(v))
return n, nil return n, nil
@ -397,12 +397,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeFixed32(buf) v, n := protowire.ConsumeFixed32(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfInt32(int32(v))) list.Append(protoreflect.ValueOfInt32(int32(v)))
@ -414,7 +414,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeFixed32(b) v, n := protowire.ConsumeFixed32(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfInt32(int32(v))) list.Append(protoreflect.ValueOfInt32(int32(v)))
return n, nil return n, nil
@ -422,12 +422,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeFixed32(buf) v, n := protowire.ConsumeFixed32(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfUint32(uint32(v))) list.Append(protoreflect.ValueOfUint32(uint32(v)))
@ -439,7 +439,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeFixed32(b) v, n := protowire.ConsumeFixed32(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfUint32(uint32(v))) list.Append(protoreflect.ValueOfUint32(uint32(v)))
return n, nil return n, nil
@ -447,12 +447,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeFixed32(buf) v, n := protowire.ConsumeFixed32(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))) list.Append(protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v))))
@ -464,7 +464,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeFixed32(b) v, n := protowire.ConsumeFixed32(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))) list.Append(protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v))))
return n, nil return n, nil
@ -472,12 +472,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeFixed64(buf) v, n := protowire.ConsumeFixed64(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfInt64(int64(v))) list.Append(protoreflect.ValueOfInt64(int64(v)))
@ -489,7 +489,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeFixed64(b) v, n := protowire.ConsumeFixed64(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfInt64(int64(v))) list.Append(protoreflect.ValueOfInt64(int64(v)))
return n, nil return n, nil
@ -497,12 +497,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeFixed64(buf) v, n := protowire.ConsumeFixed64(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfUint64(v)) list.Append(protoreflect.ValueOfUint64(v))
@ -514,7 +514,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeFixed64(b) v, n := protowire.ConsumeFixed64(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfUint64(v)) list.Append(protoreflect.ValueOfUint64(v))
return n, nil return n, nil
@ -522,12 +522,12 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
if wtyp == protowire.BytesType { if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b) buf, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
for len(buf) > 0 { for len(buf) > 0 {
v, n := protowire.ConsumeFixed64(buf) v, n := protowire.ConsumeFixed64(buf)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
buf = buf[n:] buf = buf[n:]
list.Append(protoreflect.ValueOfFloat64(math.Float64frombits(v))) list.Append(protoreflect.ValueOfFloat64(math.Float64frombits(v)))
@ -539,7 +539,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeFixed64(b) v, n := protowire.ConsumeFixed64(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfFloat64(math.Float64frombits(v))) list.Append(protoreflect.ValueOfFloat64(math.Float64frombits(v)))
return n, nil return n, nil
@ -549,7 +549,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
if strs.EnforceUTF8(fd) && !utf8.Valid(v) { if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return 0, errors.InvalidUTF8(string(fd.FullName())) return 0, errors.InvalidUTF8(string(fd.FullName()))
@ -562,7 +562,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
list.Append(protoreflect.ValueOfBytes(append(emptyBuf[:], v...))) list.Append(protoreflect.ValueOfBytes(append(emptyBuf[:], v...)))
return n, nil return n, nil
@ -572,7 +572,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeBytes(b) v, n := protowire.ConsumeBytes(b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
m := list.NewElement() m := list.NewElement()
if err := o.unmarshalMessage(v, m.Message()); err != nil { if err := o.unmarshalMessage(v, m.Message()); err != nil {
@ -586,7 +586,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list prot
} }
v, n := protowire.ConsumeGroup(fd.Number(), b) v, n := protowire.ConsumeGroup(fd.Number(), b)
if n < 0 { if n < 0 {
return 0, protowire.ParseError(n) return 0, errDecode
} }
m := list.NewElement() m := list.NewElement()
if err := o.unmarshalMessage(v, m.Message()); err != nil { if err := o.unmarshalMessage(v, m.Message()); err != nil {

View File

@ -5,12 +5,9 @@
package proto package proto
import ( import (
"sort"
"google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/fieldsort" "google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface" "google.golang.org/protobuf/runtime/protoiface"
@ -211,14 +208,15 @@ func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]
if messageset.IsMessageSet(m.Descriptor()) { if messageset.IsMessageSet(m.Descriptor()) {
return o.marshalMessageSet(b, m) return o.marshalMessageSet(b, m)
} }
// There are many choices for what order we visit fields in. The default one here fieldOrder := order.AnyFieldOrder
// is chosen for reasonable efficiency and simplicity given the protoreflect API. if o.Deterministic {
// It is not deterministic, since Message.Range does not return fields in any // TODO: This should use a more natural ordering like NumberFieldOrder,
// defined order. // but doing so breaks golden tests that make invalid assumption about
// // output stability of this implementation.
// When using deterministic serialization, we sort the known fields. fieldOrder = order.LegacyFieldOrder
}
var err error var err error
o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
b, err = o.marshalField(b, fd, v) b, err = o.marshalField(b, fd, v)
return err == nil return err == nil
}) })
@ -229,27 +227,6 @@ func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]
return b, nil return b, nil
} }
// rangeFields visits fields in a defined order when deterministic serialization is enabled.
func (o MarshalOptions) rangeFields(m protoreflect.Message, f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
if !o.Deterministic {
m.Range(f)
return
}
var fds []protoreflect.FieldDescriptor
m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
fds = append(fds, fd)
return true
})
sort.Slice(fds, func(a, b int) bool {
return fieldsort.Less(fds[a], fds[b])
})
for _, fd := range fds {
if !f(fd, m.Get(fd)) {
break
}
}
}
func (o MarshalOptions) marshalField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) { func (o MarshalOptions) marshalField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) {
switch { switch {
case fd.IsList(): case fd.IsList():
@ -292,8 +269,12 @@ func (o MarshalOptions) marshalList(b []byte, fd protoreflect.FieldDescriptor, l
func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) { func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) {
keyf := fd.MapKey() keyf := fd.MapKey()
valf := fd.MapValue() valf := fd.MapValue()
keyOrder := order.AnyKeyOrder
if o.Deterministic {
keyOrder = order.GenericKeyOrder
}
var err error var err error
o.rangeMap(mapv, keyf.Kind(), func(key protoreflect.MapKey, value protoreflect.Value) bool { order.RangeEntries(mapv, keyOrder, func(key protoreflect.MapKey, value protoreflect.Value) bool {
b = protowire.AppendTag(b, fd.Number(), protowire.BytesType) b = protowire.AppendTag(b, fd.Number(), protowire.BytesType)
var pos int var pos int
b, pos = appendSpeculativeLength(b) b, pos = appendSpeculativeLength(b)
@ -312,14 +293,6 @@ func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, ma
return b, err return b, err
} }
func (o MarshalOptions) rangeMap(mapv protoreflect.Map, kind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) {
if !o.Deterministic {
mapv.Range(f)
return
}
mapsort.Range(mapv, kind, f)
}
// When encoding length-prefixed fields, we speculatively set aside some number of bytes // When encoding length-prefixed fields, we speculatively set aside some number of bytes
// for the length, encode the data, and then encode the length (shifting the data if necessary // for the length, encode the data, and then encode the length (shifting the data if necessary
// to make room). // to make room).

View File

@ -111,18 +111,31 @@ func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
// equalValue compares two singular values. // equalValue compares two singular values.
func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool { func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
switch { switch fd.Kind() {
case fd.Message() != nil: case pref.BoolKind:
return equalMessage(x.Message(), y.Message()) return x.Bool() == y.Bool()
case fd.Kind() == pref.BytesKind: case pref.EnumKind:
return bytes.Equal(x.Bytes(), y.Bytes()) return x.Enum() == y.Enum()
case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind: case pref.Int32Kind, pref.Sint32Kind,
pref.Int64Kind, pref.Sint64Kind,
pref.Sfixed32Kind, pref.Sfixed64Kind:
return x.Int() == y.Int()
case pref.Uint32Kind, pref.Uint64Kind,
pref.Fixed32Kind, pref.Fixed64Kind:
return x.Uint() == y.Uint()
case pref.FloatKind, pref.DoubleKind:
fx := x.Float() fx := x.Float()
fy := y.Float() fy := y.Float()
if math.IsNaN(fx) || math.IsNaN(fy) { if math.IsNaN(fx) || math.IsNaN(fy) {
return math.IsNaN(fx) && math.IsNaN(fy) return math.IsNaN(fx) && math.IsNaN(fy)
} }
return fx == fy return fx == fy
case pref.StringKind:
return x.String() == y.String()
case pref.BytesKind:
return bytes.Equal(x.Bytes(), y.Bytes())
case pref.MessageKind, pref.GroupKind:
return equalMessage(x.Message(), y.Message())
default: default:
return x.Interface() == y.Interface() return x.Interface() == y.Interface()
} }

View File

@ -9,6 +9,7 @@ import (
"google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/reflect/protoregistry"
) )
@ -28,8 +29,12 @@ func (o MarshalOptions) marshalMessageSet(b []byte, m protoreflect.Message) ([]b
if !flags.ProtoLegacy { if !flags.ProtoLegacy {
return b, errors.New("no support for message_set_wire_format") return b, errors.New("no support for message_set_wire_format")
} }
fieldOrder := order.AnyFieldOrder
if o.Deterministic {
fieldOrder = order.NumberFieldOrder
}
var err error var err error
o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
b, err = o.marshalMessageSetField(b, fd, v) b, err = o.marshalMessageSetField(b, fd, v)
return err == nil return err == nil
}) })

View File

@ -32,3 +32,12 @@ var Error error
func init() { func init() {
Error = errors.Error Error = errors.Error
} }
// MessageName returns the full name of m.
// If m is nil, it returns an empty string.
func MessageName(m Message) protoreflect.FullName {
if m == nil {
return ""
}
return m.ProtoReflect().Descriptor().FullName()
}

View File

@ -0,0 +1,276 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package protodesc provides functionality for converting
// FileDescriptorProto messages to/from protoreflect.FileDescriptor values.
//
// The google.protobuf.FileDescriptorProto is a protobuf message that describes
// the type information for a .proto file in a form that is easily serializable.
// The protoreflect.FileDescriptor is a more structured representation of
// the FileDescriptorProto message where references and remote dependencies
// can be directly followed.
package protodesc
import (
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
)
// Resolver is the resolver used by NewFile to resolve dependencies.
// The enums and messages provided must belong to some parent file,
// which is also registered.
//
// It is implemented by protoregistry.Files.
type Resolver interface {
FindFileByPath(string) (protoreflect.FileDescriptor, error)
FindDescriptorByName(protoreflect.FullName) (protoreflect.Descriptor, error)
}
// FileOptions configures the construction of file descriptors.
type FileOptions struct {
pragma.NoUnkeyedLiterals
// AllowUnresolvable configures New to permissively allow unresolvable
// file, enum, or message dependencies. Unresolved dependencies are replaced
// by placeholder equivalents.
//
// The following dependencies may be left unresolved:
// • Resolving an imported file.
// • Resolving the type for a message field or extension field.
// If the kind of the field is unknown, then a placeholder is used for both
// the Enum and Message accessors on the protoreflect.FieldDescriptor.
// • Resolving an enum value set as the default for an optional enum field.
// If unresolvable, the protoreflect.FieldDescriptor.Default is set to the
// first value in the associated enum (or zero if the also enum dependency
// is also unresolvable). The protoreflect.FieldDescriptor.DefaultEnumValue
// is populated with a placeholder.
// • Resolving the extended message type for an extension field.
// • Resolving the input or output message type for a service method.
//
// If the unresolved dependency uses a relative name,
// then the placeholder will contain an invalid FullName with a "*." prefix,
// indicating that the starting prefix of the full name is unknown.
AllowUnresolvable bool
}
// NewFile creates a new protoreflect.FileDescriptor from the provided
// file descriptor message. See FileOptions.New for more information.
func NewFile(fd *descriptorpb.FileDescriptorProto, r Resolver) (protoreflect.FileDescriptor, error) {
return FileOptions{}.New(fd, r)
}
// NewFiles creates a new protoregistry.Files from the provided
// FileDescriptorSet message. See FileOptions.NewFiles for more information.
func NewFiles(fd *descriptorpb.FileDescriptorSet) (*protoregistry.Files, error) {
return FileOptions{}.NewFiles(fd)
}
// New creates a new protoreflect.FileDescriptor from the provided
// file descriptor message. The file must represent a valid proto file according
// to protobuf semantics. The returned descriptor is a deep copy of the input.
//
// Any imported files, enum types, or message types referenced in the file are
// resolved using the provided registry. When looking up an import file path,
// the path must be unique. The newly created file descriptor is not registered
// back into the provided file registry.
func (o FileOptions) New(fd *descriptorpb.FileDescriptorProto, r Resolver) (protoreflect.FileDescriptor, error) {
if r == nil {
r = (*protoregistry.Files)(nil) // empty resolver
}
// Handle the file descriptor content.
f := &filedesc.File{L2: &filedesc.FileL2{}}
switch fd.GetSyntax() {
case "proto2", "":
f.L1.Syntax = protoreflect.Proto2
case "proto3":
f.L1.Syntax = protoreflect.Proto3
default:
return nil, errors.New("invalid syntax: %q", fd.GetSyntax())
}
f.L1.Path = fd.GetName()
if f.L1.Path == "" {
return nil, errors.New("file path must be populated")
}
f.L1.Package = protoreflect.FullName(fd.GetPackage())
if !f.L1.Package.IsValid() && f.L1.Package != "" {
return nil, errors.New("invalid package: %q", f.L1.Package)
}
if opts := fd.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.FileOptions)
f.L2.Options = func() protoreflect.ProtoMessage { return opts }
}
f.L2.Imports = make(filedesc.FileImports, len(fd.GetDependency()))
for _, i := range fd.GetPublicDependency() {
if !(0 <= i && int(i) < len(f.L2.Imports)) || f.L2.Imports[i].IsPublic {
return nil, errors.New("invalid or duplicate public import index: %d", i)
}
f.L2.Imports[i].IsPublic = true
}
for _, i := range fd.GetWeakDependency() {
if !(0 <= i && int(i) < len(f.L2.Imports)) || f.L2.Imports[i].IsWeak {
return nil, errors.New("invalid or duplicate weak import index: %d", i)
}
f.L2.Imports[i].IsWeak = true
}
imps := importSet{f.Path(): true}
for i, path := range fd.GetDependency() {
imp := &f.L2.Imports[i]
f, err := r.FindFileByPath(path)
if err == protoregistry.NotFound && (o.AllowUnresolvable || imp.IsWeak) {
f = filedesc.PlaceholderFile(path)
} else if err != nil {
return nil, errors.New("could not resolve import %q: %v", path, err)
}
imp.FileDescriptor = f
if imps[imp.Path()] {
return nil, errors.New("already imported %q", path)
}
imps[imp.Path()] = true
}
for i := range fd.GetDependency() {
imp := &f.L2.Imports[i]
imps.importPublic(imp.Imports())
}
// Handle source locations.
f.L2.Locations.File = f
for _, loc := range fd.GetSourceCodeInfo().GetLocation() {
var l protoreflect.SourceLocation
// TODO: Validate that the path points to an actual declaration?
l.Path = protoreflect.SourcePath(loc.GetPath())
s := loc.GetSpan()
switch len(s) {
case 3:
l.StartLine, l.StartColumn, l.EndLine, l.EndColumn = int(s[0]), int(s[1]), int(s[0]), int(s[2])
case 4:
l.StartLine, l.StartColumn, l.EndLine, l.EndColumn = int(s[0]), int(s[1]), int(s[2]), int(s[3])
default:
return nil, errors.New("invalid span: %v", s)
}
// TODO: Validate that the span information is sensible?
// See https://github.com/protocolbuffers/protobuf/issues/6378.
if false && (l.EndLine < l.StartLine || l.StartLine < 0 || l.StartColumn < 0 || l.EndColumn < 0 ||
(l.StartLine == l.EndLine && l.EndColumn <= l.StartColumn)) {
return nil, errors.New("invalid span: %v", s)
}
l.LeadingDetachedComments = loc.GetLeadingDetachedComments()
l.LeadingComments = loc.GetLeadingComments()
l.TrailingComments = loc.GetTrailingComments()
f.L2.Locations.List = append(f.L2.Locations.List, l)
}
// Step 1: Allocate and derive the names for all declarations.
// This copies all fields from the descriptor proto except:
// google.protobuf.FieldDescriptorProto.type_name
// google.protobuf.FieldDescriptorProto.default_value
// google.protobuf.FieldDescriptorProto.oneof_index
// google.protobuf.FieldDescriptorProto.extendee
// google.protobuf.MethodDescriptorProto.input
// google.protobuf.MethodDescriptorProto.output
var err error
sb := new(strs.Builder)
r1 := make(descsByName)
if f.L1.Enums.List, err = r1.initEnumDeclarations(fd.GetEnumType(), f, sb); err != nil {
return nil, err
}
if f.L1.Messages.List, err = r1.initMessagesDeclarations(fd.GetMessageType(), f, sb); err != nil {
return nil, err
}
if f.L1.Extensions.List, err = r1.initExtensionDeclarations(fd.GetExtension(), f, sb); err != nil {
return nil, err
}
if f.L1.Services.List, err = r1.initServiceDeclarations(fd.GetService(), f, sb); err != nil {
return nil, err
}
// Step 2: Resolve every dependency reference not handled by step 1.
r2 := &resolver{local: r1, remote: r, imports: imps, allowUnresolvable: o.AllowUnresolvable}
if err := r2.resolveMessageDependencies(f.L1.Messages.List, fd.GetMessageType()); err != nil {
return nil, err
}
if err := r2.resolveExtensionDependencies(f.L1.Extensions.List, fd.GetExtension()); err != nil {
return nil, err
}
if err := r2.resolveServiceDependencies(f.L1.Services.List, fd.GetService()); err != nil {
return nil, err
}
// Step 3: Validate every enum, message, and extension declaration.
if err := validateEnumDeclarations(f.L1.Enums.List, fd.GetEnumType()); err != nil {
return nil, err
}
if err := validateMessageDeclarations(f.L1.Messages.List, fd.GetMessageType()); err != nil {
return nil, err
}
if err := validateExtensionDeclarations(f.L1.Extensions.List, fd.GetExtension()); err != nil {
return nil, err
}
return f, nil
}
type importSet map[string]bool
func (is importSet) importPublic(imps protoreflect.FileImports) {
for i := 0; i < imps.Len(); i++ {
if imp := imps.Get(i); imp.IsPublic {
is[imp.Path()] = true
is.importPublic(imp.Imports())
}
}
}
// NewFiles creates a new protoregistry.Files from the provided
// FileDescriptorSet message. The descriptor set must include only
// valid files according to protobuf semantics. The returned descriptors
// are a deep copy of the input.
func (o FileOptions) NewFiles(fds *descriptorpb.FileDescriptorSet) (*protoregistry.Files, error) {
files := make(map[string]*descriptorpb.FileDescriptorProto)
for _, fd := range fds.File {
if _, ok := files[fd.GetName()]; ok {
return nil, errors.New("file appears multiple times: %q", fd.GetName())
}
files[fd.GetName()] = fd
}
r := &protoregistry.Files{}
for _, fd := range files {
if err := o.addFileDeps(r, fd, files); err != nil {
return nil, err
}
}
return r, nil
}
func (o FileOptions) addFileDeps(r *protoregistry.Files, fd *descriptorpb.FileDescriptorProto, files map[string]*descriptorpb.FileDescriptorProto) error {
// Set the entry to nil while descending into a file's dependencies to detect cycles.
files[fd.GetName()] = nil
for _, dep := range fd.Dependency {
depfd, ok := files[dep]
if depfd == nil {
if ok {
return errors.New("import cycle in file: %q", dep)
}
continue
}
if err := o.addFileDeps(r, depfd, files); err != nil {
return err
}
}
// Delete the entry once dependencies are processed.
delete(files, fd.GetName())
f, err := o.New(fd, r)
if err != nil {
return err
}
return r.RegisterFile(f)
}

View File

@ -0,0 +1,248 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protodesc
import (
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)
type descsByName map[protoreflect.FullName]protoreflect.Descriptor
func (r descsByName) initEnumDeclarations(eds []*descriptorpb.EnumDescriptorProto, parent protoreflect.Descriptor, sb *strs.Builder) (es []filedesc.Enum, err error) {
es = make([]filedesc.Enum, len(eds)) // allocate up-front to ensure stable pointers
for i, ed := range eds {
e := &es[i]
e.L2 = new(filedesc.EnumL2)
if e.L0, err = r.makeBase(e, parent, ed.GetName(), i, sb); err != nil {
return nil, err
}
if opts := ed.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.EnumOptions)
e.L2.Options = func() protoreflect.ProtoMessage { return opts }
}
for _, s := range ed.GetReservedName() {
e.L2.ReservedNames.List = append(e.L2.ReservedNames.List, protoreflect.Name(s))
}
for _, rr := range ed.GetReservedRange() {
e.L2.ReservedRanges.List = append(e.L2.ReservedRanges.List, [2]protoreflect.EnumNumber{
protoreflect.EnumNumber(rr.GetStart()),
protoreflect.EnumNumber(rr.GetEnd()),
})
}
if e.L2.Values.List, err = r.initEnumValuesFromDescriptorProto(ed.GetValue(), e, sb); err != nil {
return nil, err
}
}
return es, nil
}
func (r descsByName) initEnumValuesFromDescriptorProto(vds []*descriptorpb.EnumValueDescriptorProto, parent protoreflect.Descriptor, sb *strs.Builder) (vs []filedesc.EnumValue, err error) {
vs = make([]filedesc.EnumValue, len(vds)) // allocate up-front to ensure stable pointers
for i, vd := range vds {
v := &vs[i]
if v.L0, err = r.makeBase(v, parent, vd.GetName(), i, sb); err != nil {
return nil, err
}
if opts := vd.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.EnumValueOptions)
v.L1.Options = func() protoreflect.ProtoMessage { return opts }
}
v.L1.Number = protoreflect.EnumNumber(vd.GetNumber())
}
return vs, nil
}
func (r descsByName) initMessagesDeclarations(mds []*descriptorpb.DescriptorProto, parent protoreflect.Descriptor, sb *strs.Builder) (ms []filedesc.Message, err error) {
ms = make([]filedesc.Message, len(mds)) // allocate up-front to ensure stable pointers
for i, md := range mds {
m := &ms[i]
m.L2 = new(filedesc.MessageL2)
if m.L0, err = r.makeBase(m, parent, md.GetName(), i, sb); err != nil {
return nil, err
}
if opts := md.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.MessageOptions)
m.L2.Options = func() protoreflect.ProtoMessage { return opts }
m.L1.IsMapEntry = opts.GetMapEntry()
m.L1.IsMessageSet = opts.GetMessageSetWireFormat()
}
for _, s := range md.GetReservedName() {
m.L2.ReservedNames.List = append(m.L2.ReservedNames.List, protoreflect.Name(s))
}
for _, rr := range md.GetReservedRange() {
m.L2.ReservedRanges.List = append(m.L2.ReservedRanges.List, [2]protoreflect.FieldNumber{
protoreflect.FieldNumber(rr.GetStart()),
protoreflect.FieldNumber(rr.GetEnd()),
})
}
for _, xr := range md.GetExtensionRange() {
m.L2.ExtensionRanges.List = append(m.L2.ExtensionRanges.List, [2]protoreflect.FieldNumber{
protoreflect.FieldNumber(xr.GetStart()),
protoreflect.FieldNumber(xr.GetEnd()),
})
var optsFunc func() protoreflect.ProtoMessage
if opts := xr.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.ExtensionRangeOptions)
optsFunc = func() protoreflect.ProtoMessage { return opts }
}
m.L2.ExtensionRangeOptions = append(m.L2.ExtensionRangeOptions, optsFunc)
}
if m.L2.Fields.List, err = r.initFieldsFromDescriptorProto(md.GetField(), m, sb); err != nil {
return nil, err
}
if m.L2.Oneofs.List, err = r.initOneofsFromDescriptorProto(md.GetOneofDecl(), m, sb); err != nil {
return nil, err
}
if m.L1.Enums.List, err = r.initEnumDeclarations(md.GetEnumType(), m, sb); err != nil {
return nil, err
}
if m.L1.Messages.List, err = r.initMessagesDeclarations(md.GetNestedType(), m, sb); err != nil {
return nil, err
}
if m.L1.Extensions.List, err = r.initExtensionDeclarations(md.GetExtension(), m, sb); err != nil {
return nil, err
}
}
return ms, nil
}
func (r descsByName) initFieldsFromDescriptorProto(fds []*descriptorpb.FieldDescriptorProto, parent protoreflect.Descriptor, sb *strs.Builder) (fs []filedesc.Field, err error) {
fs = make([]filedesc.Field, len(fds)) // allocate up-front to ensure stable pointers
for i, fd := range fds {
f := &fs[i]
if f.L0, err = r.makeBase(f, parent, fd.GetName(), i, sb); err != nil {
return nil, err
}
f.L1.IsProto3Optional = fd.GetProto3Optional()
if opts := fd.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.FieldOptions)
f.L1.Options = func() protoreflect.ProtoMessage { return opts }
f.L1.IsWeak = opts.GetWeak()
f.L1.HasPacked = opts.Packed != nil
f.L1.IsPacked = opts.GetPacked()
}
f.L1.Number = protoreflect.FieldNumber(fd.GetNumber())
f.L1.Cardinality = protoreflect.Cardinality(fd.GetLabel())
if fd.Type != nil {
f.L1.Kind = protoreflect.Kind(fd.GetType())
}
if fd.JsonName != nil {
f.L1.StringName.InitJSON(fd.GetJsonName())
}
}
return fs, nil
}
func (r descsByName) initOneofsFromDescriptorProto(ods []*descriptorpb.OneofDescriptorProto, parent protoreflect.Descriptor, sb *strs.Builder) (os []filedesc.Oneof, err error) {
os = make([]filedesc.Oneof, len(ods)) // allocate up-front to ensure stable pointers
for i, od := range ods {
o := &os[i]
if o.L0, err = r.makeBase(o, parent, od.GetName(), i, sb); err != nil {
return nil, err
}
if opts := od.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.OneofOptions)
o.L1.Options = func() protoreflect.ProtoMessage { return opts }
}
}
return os, nil
}
func (r descsByName) initExtensionDeclarations(xds []*descriptorpb.FieldDescriptorProto, parent protoreflect.Descriptor, sb *strs.Builder) (xs []filedesc.Extension, err error) {
xs = make([]filedesc.Extension, len(xds)) // allocate up-front to ensure stable pointers
for i, xd := range xds {
x := &xs[i]
x.L2 = new(filedesc.ExtensionL2)
if x.L0, err = r.makeBase(x, parent, xd.GetName(), i, sb); err != nil {
return nil, err
}
if opts := xd.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.FieldOptions)
x.L2.Options = func() protoreflect.ProtoMessage { return opts }
x.L2.IsPacked = opts.GetPacked()
}
x.L1.Number = protoreflect.FieldNumber(xd.GetNumber())
x.L1.Cardinality = protoreflect.Cardinality(xd.GetLabel())
if xd.Type != nil {
x.L1.Kind = protoreflect.Kind(xd.GetType())
}
if xd.JsonName != nil {
x.L2.StringName.InitJSON(xd.GetJsonName())
}
}
return xs, nil
}
func (r descsByName) initServiceDeclarations(sds []*descriptorpb.ServiceDescriptorProto, parent protoreflect.Descriptor, sb *strs.Builder) (ss []filedesc.Service, err error) {
ss = make([]filedesc.Service, len(sds)) // allocate up-front to ensure stable pointers
for i, sd := range sds {
s := &ss[i]
s.L2 = new(filedesc.ServiceL2)
if s.L0, err = r.makeBase(s, parent, sd.GetName(), i, sb); err != nil {
return nil, err
}
if opts := sd.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.ServiceOptions)
s.L2.Options = func() protoreflect.ProtoMessage { return opts }
}
if s.L2.Methods.List, err = r.initMethodsFromDescriptorProto(sd.GetMethod(), s, sb); err != nil {
return nil, err
}
}
return ss, nil
}
func (r descsByName) initMethodsFromDescriptorProto(mds []*descriptorpb.MethodDescriptorProto, parent protoreflect.Descriptor, sb *strs.Builder) (ms []filedesc.Method, err error) {
ms = make([]filedesc.Method, len(mds)) // allocate up-front to ensure stable pointers
for i, md := range mds {
m := &ms[i]
if m.L0, err = r.makeBase(m, parent, md.GetName(), i, sb); err != nil {
return nil, err
}
if opts := md.GetOptions(); opts != nil {
opts = proto.Clone(opts).(*descriptorpb.MethodOptions)
m.L1.Options = func() protoreflect.ProtoMessage { return opts }
}
m.L1.IsStreamingClient = md.GetClientStreaming()
m.L1.IsStreamingServer = md.GetServerStreaming()
}
return ms, nil
}
func (r descsByName) makeBase(child, parent protoreflect.Descriptor, name string, idx int, sb *strs.Builder) (filedesc.BaseL0, error) {
if !protoreflect.Name(name).IsValid() {
return filedesc.BaseL0{}, errors.New("descriptor %q has an invalid nested name: %q", parent.FullName(), name)
}
// Derive the full name of the child.
// Note that enum values are a sibling to the enum parent in the namespace.
var fullName protoreflect.FullName
if _, ok := parent.(protoreflect.EnumDescriptor); ok {
fullName = sb.AppendFullName(parent.FullName().Parent(), protoreflect.Name(name))
} else {
fullName = sb.AppendFullName(parent.FullName(), protoreflect.Name(name))
}
if _, ok := r[fullName]; ok {
return filedesc.BaseL0{}, errors.New("descriptor %q already declared", fullName)
}
r[fullName] = child
// TODO: Verify that the full name does not already exist in the resolver?
// This is not as critical since most usages of NewFile will register
// the created file back into the registry, which will perform this check.
return filedesc.BaseL0{
FullName: fullName,
ParentFile: parent.ParentFile().(*filedesc.File),
Parent: parent,
Index: idx,
}, nil
}

View File

@ -0,0 +1,286 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protodesc
import (
"google.golang.org/protobuf/internal/encoding/defval"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
)
// resolver is a wrapper around a local registry of declarations within the file
// and the remote resolver. The remote resolver is restricted to only return
// descriptors that have been imported.
type resolver struct {
local descsByName
remote Resolver
imports importSet
allowUnresolvable bool
}
func (r *resolver) resolveMessageDependencies(ms []filedesc.Message, mds []*descriptorpb.DescriptorProto) (err error) {
for i, md := range mds {
m := &ms[i]
for j, fd := range md.GetField() {
f := &m.L2.Fields.List[j]
if f.L1.Cardinality == protoreflect.Required {
m.L2.RequiredNumbers.List = append(m.L2.RequiredNumbers.List, f.L1.Number)
}
if fd.OneofIndex != nil {
k := int(fd.GetOneofIndex())
if !(0 <= k && k < len(md.GetOneofDecl())) {
return errors.New("message field %q has an invalid oneof index: %d", f.FullName(), k)
}
o := &m.L2.Oneofs.List[k]
f.L1.ContainingOneof = o
o.L1.Fields.List = append(o.L1.Fields.List, f)
}
if f.L1.Kind, f.L1.Enum, f.L1.Message, err = r.findTarget(f.Kind(), f.Parent().FullName(), partialName(fd.GetTypeName()), f.IsWeak()); err != nil {
return errors.New("message field %q cannot resolve type: %v", f.FullName(), err)
}
if fd.DefaultValue != nil {
v, ev, err := unmarshalDefault(fd.GetDefaultValue(), f, r.allowUnresolvable)
if err != nil {
return errors.New("message field %q has invalid default: %v", f.FullName(), err)
}
f.L1.Default = filedesc.DefaultValue(v, ev)
}
}
if err := r.resolveMessageDependencies(m.L1.Messages.List, md.GetNestedType()); err != nil {
return err
}
if err := r.resolveExtensionDependencies(m.L1.Extensions.List, md.GetExtension()); err != nil {
return err
}
}
return nil
}
func (r *resolver) resolveExtensionDependencies(xs []filedesc.Extension, xds []*descriptorpb.FieldDescriptorProto) (err error) {
for i, xd := range xds {
x := &xs[i]
if x.L1.Extendee, err = r.findMessageDescriptor(x.Parent().FullName(), partialName(xd.GetExtendee()), false); err != nil {
return errors.New("extension field %q cannot resolve extendee: %v", x.FullName(), err)
}
if x.L1.Kind, x.L2.Enum, x.L2.Message, err = r.findTarget(x.Kind(), x.Parent().FullName(), partialName(xd.GetTypeName()), false); err != nil {
return errors.New("extension field %q cannot resolve type: %v", x.FullName(), err)
}
if xd.DefaultValue != nil {
v, ev, err := unmarshalDefault(xd.GetDefaultValue(), x, r.allowUnresolvable)
if err != nil {
return errors.New("extension field %q has invalid default: %v", x.FullName(), err)
}
x.L2.Default = filedesc.DefaultValue(v, ev)
}
}
return nil
}
func (r *resolver) resolveServiceDependencies(ss []filedesc.Service, sds []*descriptorpb.ServiceDescriptorProto) (err error) {
for i, sd := range sds {
s := &ss[i]
for j, md := range sd.GetMethod() {
m := &s.L2.Methods.List[j]
m.L1.Input, err = r.findMessageDescriptor(m.Parent().FullName(), partialName(md.GetInputType()), false)
if err != nil {
return errors.New("service method %q cannot resolve input: %v", m.FullName(), err)
}
m.L1.Output, err = r.findMessageDescriptor(s.FullName(), partialName(md.GetOutputType()), false)
if err != nil {
return errors.New("service method %q cannot resolve output: %v", m.FullName(), err)
}
}
}
return nil
}
// findTarget finds an enum or message descriptor if k is an enum, message,
// group, or unknown. If unknown, and the name could be resolved, the kind
// returned kind is set based on the type of the resolved descriptor.
func (r *resolver) findTarget(k protoreflect.Kind, scope protoreflect.FullName, ref partialName, isWeak bool) (protoreflect.Kind, protoreflect.EnumDescriptor, protoreflect.MessageDescriptor, error) {
switch k {
case protoreflect.EnumKind:
ed, err := r.findEnumDescriptor(scope, ref, isWeak)
if err != nil {
return 0, nil, nil, err
}
return k, ed, nil, nil
case protoreflect.MessageKind, protoreflect.GroupKind:
md, err := r.findMessageDescriptor(scope, ref, isWeak)
if err != nil {
return 0, nil, nil, err
}
return k, nil, md, nil
case 0:
// Handle unspecified kinds (possible with parsers that operate
// on a per-file basis without knowledge of dependencies).
d, err := r.findDescriptor(scope, ref)
if err == protoregistry.NotFound && (r.allowUnresolvable || isWeak) {
return k, filedesc.PlaceholderEnum(ref.FullName()), filedesc.PlaceholderMessage(ref.FullName()), nil
} else if err == protoregistry.NotFound {
return 0, nil, nil, errors.New("%q not found", ref.FullName())
} else if err != nil {
return 0, nil, nil, err
}
switch d := d.(type) {
case protoreflect.EnumDescriptor:
return protoreflect.EnumKind, d, nil, nil
case protoreflect.MessageDescriptor:
return protoreflect.MessageKind, nil, d, nil
default:
return 0, nil, nil, errors.New("unknown kind")
}
default:
if ref != "" {
return 0, nil, nil, errors.New("target name cannot be specified for %v", k)
}
if !k.IsValid() {
return 0, nil, nil, errors.New("invalid kind: %d", k)
}
return k, nil, nil, nil
}
}
// findDescriptor finds the descriptor by name,
// which may be a relative name within some scope.
//
// Suppose the scope was "fizz.buzz" and the reference was "Foo.Bar",
// then the following full names are searched:
// * fizz.buzz.Foo.Bar
// * fizz.Foo.Bar
// * Foo.Bar
func (r *resolver) findDescriptor(scope protoreflect.FullName, ref partialName) (protoreflect.Descriptor, error) {
if !ref.IsValid() {
return nil, errors.New("invalid name reference: %q", ref)
}
if ref.IsFull() {
scope, ref = "", ref[1:]
}
var foundButNotImported protoreflect.Descriptor
for {
// Derive the full name to search.
s := protoreflect.FullName(ref)
if scope != "" {
s = scope + "." + s
}
// Check the current file for the descriptor.
if d, ok := r.local[s]; ok {
return d, nil
}
// Check the remote registry for the descriptor.
d, err := r.remote.FindDescriptorByName(s)
if err == nil {
// Only allow descriptors covered by one of the imports.
if r.imports[d.ParentFile().Path()] {
return d, nil
}
foundButNotImported = d
} else if err != protoregistry.NotFound {
return nil, errors.Wrap(err, "%q", s)
}
// Continue on at a higher level of scoping.
if scope == "" {
if d := foundButNotImported; d != nil {
return nil, errors.New("resolved %q, but %q is not imported", d.FullName(), d.ParentFile().Path())
}
return nil, protoregistry.NotFound
}
scope = scope.Parent()
}
}
func (r *resolver) findEnumDescriptor(scope protoreflect.FullName, ref partialName, isWeak bool) (protoreflect.EnumDescriptor, error) {
d, err := r.findDescriptor(scope, ref)
if err == protoregistry.NotFound && (r.allowUnresolvable || isWeak) {
return filedesc.PlaceholderEnum(ref.FullName()), nil
} else if err == protoregistry.NotFound {
return nil, errors.New("%q not found", ref.FullName())
} else if err != nil {
return nil, err
}
ed, ok := d.(protoreflect.EnumDescriptor)
if !ok {
return nil, errors.New("resolved %q, but it is not an enum", d.FullName())
}
return ed, nil
}
func (r *resolver) findMessageDescriptor(scope protoreflect.FullName, ref partialName, isWeak bool) (protoreflect.MessageDescriptor, error) {
d, err := r.findDescriptor(scope, ref)
if err == protoregistry.NotFound && (r.allowUnresolvable || isWeak) {
return filedesc.PlaceholderMessage(ref.FullName()), nil
} else if err == protoregistry.NotFound {
return nil, errors.New("%q not found", ref.FullName())
} else if err != nil {
return nil, err
}
md, ok := d.(protoreflect.MessageDescriptor)
if !ok {
return nil, errors.New("resolved %q, but it is not an message", d.FullName())
}
return md, nil
}
// partialName is the partial name. A leading dot means that the name is full,
// otherwise the name is relative to some current scope.
// See google.protobuf.FieldDescriptorProto.type_name.
type partialName string
func (s partialName) IsFull() bool {
return len(s) > 0 && s[0] == '.'
}
func (s partialName) IsValid() bool {
if s.IsFull() {
return protoreflect.FullName(s[1:]).IsValid()
}
return protoreflect.FullName(s).IsValid()
}
const unknownPrefix = "*."
// FullName converts the partial name to a full name on a best-effort basis.
// If relative, it creates an invalid full name, using a "*." prefix
// to indicate that the start of the full name is unknown.
func (s partialName) FullName() protoreflect.FullName {
if s.IsFull() {
return protoreflect.FullName(s[1:])
}
return protoreflect.FullName(unknownPrefix + s)
}
func unmarshalDefault(s string, fd protoreflect.FieldDescriptor, allowUnresolvable bool) (protoreflect.Value, protoreflect.EnumValueDescriptor, error) {
var evs protoreflect.EnumValueDescriptors
if fd.Enum() != nil {
evs = fd.Enum().Values()
}
v, ev, err := defval.Unmarshal(s, fd.Kind(), evs, defval.Descriptor)
if err != nil && allowUnresolvable && evs != nil && protoreflect.Name(s).IsValid() {
v = protoreflect.ValueOfEnum(0)
if evs.Len() > 0 {
v = protoreflect.ValueOfEnum(evs.Get(0).Number())
}
ev = filedesc.PlaceholderEnumValue(fd.Enum().FullName().Parent().Append(protoreflect.Name(s)))
} else if err != nil {
return v, ev, err
}
if fd.Syntax() == protoreflect.Proto3 {
return v, ev, errors.New("cannot be specified under proto3 semantics")
}
if fd.Kind() == protoreflect.MessageKind || fd.Kind() == protoreflect.GroupKind || fd.Cardinality() == protoreflect.Repeated {
return v, ev, errors.New("cannot be specified on composite types")
}
return v, ev, nil
}

View File

@ -0,0 +1,374 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protodesc
import (
"strings"
"unicode"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)
func validateEnumDeclarations(es []filedesc.Enum, eds []*descriptorpb.EnumDescriptorProto) error {
for i, ed := range eds {
e := &es[i]
if err := e.L2.ReservedNames.CheckValid(); err != nil {
return errors.New("enum %q reserved names has %v", e.FullName(), err)
}
if err := e.L2.ReservedRanges.CheckValid(); err != nil {
return errors.New("enum %q reserved ranges has %v", e.FullName(), err)
}
if len(ed.GetValue()) == 0 {
return errors.New("enum %q must contain at least one value declaration", e.FullName())
}
allowAlias := ed.GetOptions().GetAllowAlias()
foundAlias := false
for i := 0; i < e.Values().Len(); i++ {
v1 := e.Values().Get(i)
if v2 := e.Values().ByNumber(v1.Number()); v1 != v2 {
foundAlias = true
if !allowAlias {
return errors.New("enum %q has conflicting non-aliased values on number %d: %q with %q", e.FullName(), v1.Number(), v1.Name(), v2.Name())
}
}
}
if allowAlias && !foundAlias {
return errors.New("enum %q allows aliases, but none were found", e.FullName())
}
if e.Syntax() == protoreflect.Proto3 {
if v := e.Values().Get(0); v.Number() != 0 {
return errors.New("enum %q using proto3 semantics must have zero number for the first value", v.FullName())
}
// Verify that value names in proto3 do not conflict if the
// case-insensitive prefix is removed.
// See protoc v3.8.0: src/google/protobuf/descriptor.cc:4991-5055
names := map[string]protoreflect.EnumValueDescriptor{}
prefix := strings.Replace(strings.ToLower(string(e.Name())), "_", "", -1)
for i := 0; i < e.Values().Len(); i++ {
v1 := e.Values().Get(i)
s := strs.EnumValueName(strs.TrimEnumPrefix(string(v1.Name()), prefix))
if v2, ok := names[s]; ok && v1.Number() != v2.Number() {
return errors.New("enum %q using proto3 semantics has conflict: %q with %q", e.FullName(), v1.Name(), v2.Name())
}
names[s] = v1
}
}
for j, vd := range ed.GetValue() {
v := &e.L2.Values.List[j]
if vd.Number == nil {
return errors.New("enum value %q must have a specified number", v.FullName())
}
if e.L2.ReservedNames.Has(v.Name()) {
return errors.New("enum value %q must not use reserved name", v.FullName())
}
if e.L2.ReservedRanges.Has(v.Number()) {
return errors.New("enum value %q must not use reserved number %d", v.FullName(), v.Number())
}
}
}
return nil
}
func validateMessageDeclarations(ms []filedesc.Message, mds []*descriptorpb.DescriptorProto) error {
for i, md := range mds {
m := &ms[i]
// Handle the message descriptor itself.
isMessageSet := md.GetOptions().GetMessageSetWireFormat()
if err := m.L2.ReservedNames.CheckValid(); err != nil {
return errors.New("message %q reserved names has %v", m.FullName(), err)
}
if err := m.L2.ReservedRanges.CheckValid(isMessageSet); err != nil {
return errors.New("message %q reserved ranges has %v", m.FullName(), err)
}
if err := m.L2.ExtensionRanges.CheckValid(isMessageSet); err != nil {
return errors.New("message %q extension ranges has %v", m.FullName(), err)
}
if err := (*filedesc.FieldRanges).CheckOverlap(&m.L2.ReservedRanges, &m.L2.ExtensionRanges); err != nil {
return errors.New("message %q reserved and extension ranges has %v", m.FullName(), err)
}
for i := 0; i < m.Fields().Len(); i++ {
f1 := m.Fields().Get(i)
if f2 := m.Fields().ByNumber(f1.Number()); f1 != f2 {
return errors.New("message %q has conflicting fields: %q with %q", m.FullName(), f1.Name(), f2.Name())
}
}
if isMessageSet && !flags.ProtoLegacy {
return errors.New("message %q is a MessageSet, which is a legacy proto1 feature that is no longer supported", m.FullName())
}
if isMessageSet && (m.Syntax() != protoreflect.Proto2 || m.Fields().Len() > 0 || m.ExtensionRanges().Len() == 0) {
return errors.New("message %q is an invalid proto1 MessageSet", m.FullName())
}
if m.Syntax() == protoreflect.Proto3 {
if m.ExtensionRanges().Len() > 0 {
return errors.New("message %q using proto3 semantics cannot have extension ranges", m.FullName())
}
// Verify that field names in proto3 do not conflict if lowercased
// with all underscores removed.
// See protoc v3.8.0: src/google/protobuf/descriptor.cc:5830-5847
names := map[string]protoreflect.FieldDescriptor{}
for i := 0; i < m.Fields().Len(); i++ {
f1 := m.Fields().Get(i)
s := strings.Replace(strings.ToLower(string(f1.Name())), "_", "", -1)
if f2, ok := names[s]; ok {
return errors.New("message %q using proto3 semantics has conflict: %q with %q", m.FullName(), f1.Name(), f2.Name())
}
names[s] = f1
}
}
for j, fd := range md.GetField() {
f := &m.L2.Fields.List[j]
if m.L2.ReservedNames.Has(f.Name()) {
return errors.New("message field %q must not use reserved name", f.FullName())
}
if !f.Number().IsValid() {
return errors.New("message field %q has an invalid number: %d", f.FullName(), f.Number())
}
if !f.Cardinality().IsValid() {
return errors.New("message field %q has an invalid cardinality: %d", f.FullName(), f.Cardinality())
}
if m.L2.ReservedRanges.Has(f.Number()) {
return errors.New("message field %q must not use reserved number %d", f.FullName(), f.Number())
}
if m.L2.ExtensionRanges.Has(f.Number()) {
return errors.New("message field %q with number %d in extension range", f.FullName(), f.Number())
}
if fd.Extendee != nil {
return errors.New("message field %q may not have extendee: %q", f.FullName(), fd.GetExtendee())
}
if f.L1.IsProto3Optional {
if f.Syntax() != protoreflect.Proto3 {
return errors.New("message field %q under proto3 optional semantics must be specified in the proto3 syntax", f.FullName())
}
if f.Cardinality() != protoreflect.Optional {
return errors.New("message field %q under proto3 optional semantics must have optional cardinality", f.FullName())
}
if f.ContainingOneof() != nil && f.ContainingOneof().Fields().Len() != 1 {
return errors.New("message field %q under proto3 optional semantics must be within a single element oneof", f.FullName())
}
}
if f.IsWeak() && !flags.ProtoLegacy {
return errors.New("message field %q is a weak field, which is a legacy proto1 feature that is no longer supported", f.FullName())
}
if f.IsWeak() && (f.Syntax() != protoreflect.Proto2 || !isOptionalMessage(f) || f.ContainingOneof() != nil) {
return errors.New("message field %q may only be weak for an optional message", f.FullName())
}
if f.IsPacked() && !isPackable(f) {
return errors.New("message field %q is not packable", f.FullName())
}
if err := checkValidGroup(f); err != nil {
return errors.New("message field %q is an invalid group: %v", f.FullName(), err)
}
if err := checkValidMap(f); err != nil {
return errors.New("message field %q is an invalid map: %v", f.FullName(), err)
}
if f.Syntax() == protoreflect.Proto3 {
if f.Cardinality() == protoreflect.Required {
return errors.New("message field %q using proto3 semantics cannot be required", f.FullName())
}
if f.Enum() != nil && !f.Enum().IsPlaceholder() && f.Enum().Syntax() != protoreflect.Proto3 {
return errors.New("message field %q using proto3 semantics may only depend on a proto3 enum", f.FullName())
}
}
}
seenSynthetic := false // synthetic oneofs for proto3 optional must come after real oneofs
for j := range md.GetOneofDecl() {
o := &m.L2.Oneofs.List[j]
if o.Fields().Len() == 0 {
return errors.New("message oneof %q must contain at least one field declaration", o.FullName())
}
if n := o.Fields().Len(); n-1 != (o.Fields().Get(n-1).Index() - o.Fields().Get(0).Index()) {
return errors.New("message oneof %q must have consecutively declared fields", o.FullName())
}
if o.IsSynthetic() {
seenSynthetic = true
continue
}
if !o.IsSynthetic() && seenSynthetic {
return errors.New("message oneof %q must be declared before synthetic oneofs", o.FullName())
}
for i := 0; i < o.Fields().Len(); i++ {
f := o.Fields().Get(i)
if f.Cardinality() != protoreflect.Optional {
return errors.New("message field %q belongs in a oneof and must be optional", f.FullName())
}
if f.IsWeak() {
return errors.New("message field %q belongs in a oneof and must not be a weak reference", f.FullName())
}
}
}
if err := validateEnumDeclarations(m.L1.Enums.List, md.GetEnumType()); err != nil {
return err
}
if err := validateMessageDeclarations(m.L1.Messages.List, md.GetNestedType()); err != nil {
return err
}
if err := validateExtensionDeclarations(m.L1.Extensions.List, md.GetExtension()); err != nil {
return err
}
}
return nil
}
func validateExtensionDeclarations(xs []filedesc.Extension, xds []*descriptorpb.FieldDescriptorProto) error {
for i, xd := range xds {
x := &xs[i]
// NOTE: Avoid using the IsValid method since extensions to MessageSet
// may have a field number higher than normal. This check only verifies
// that the number is not negative or reserved. We check again later
// if we know that the extendee is definitely not a MessageSet.
if n := x.Number(); n < 0 || (protowire.FirstReservedNumber <= n && n <= protowire.LastReservedNumber) {
return errors.New("extension field %q has an invalid number: %d", x.FullName(), x.Number())
}
if !x.Cardinality().IsValid() || x.Cardinality() == protoreflect.Required {
return errors.New("extension field %q has an invalid cardinality: %d", x.FullName(), x.Cardinality())
}
if xd.JsonName != nil {
// A bug in older versions of protoc would always populate the
// "json_name" option for extensions when it is meaningless.
// When it did so, it would always use the camel-cased field name.
if xd.GetJsonName() != strs.JSONCamelCase(string(x.Name())) {
return errors.New("extension field %q may not have an explicitly set JSON name: %q", x.FullName(), xd.GetJsonName())
}
}
if xd.OneofIndex != nil {
return errors.New("extension field %q may not be part of a oneof", x.FullName())
}
if md := x.ContainingMessage(); !md.IsPlaceholder() {
if !md.ExtensionRanges().Has(x.Number()) {
return errors.New("extension field %q extends %q with non-extension field number: %d", x.FullName(), md.FullName(), x.Number())
}
isMessageSet := md.Options().(*descriptorpb.MessageOptions).GetMessageSetWireFormat()
if isMessageSet && !isOptionalMessage(x) {
return errors.New("extension field %q extends MessageSet and must be an optional message", x.FullName())
}
if !isMessageSet && !x.Number().IsValid() {
return errors.New("extension field %q has an invalid number: %d", x.FullName(), x.Number())
}
}
if xd.GetOptions().GetWeak() {
return errors.New("extension field %q cannot be a weak reference", x.FullName())
}
if x.IsPacked() && !isPackable(x) {
return errors.New("extension field %q is not packable", x.FullName())
}
if err := checkValidGroup(x); err != nil {
return errors.New("extension field %q is an invalid group: %v", x.FullName(), err)
}
if md := x.Message(); md != nil && md.IsMapEntry() {
return errors.New("extension field %q cannot be a map entry", x.FullName())
}
if x.Syntax() == protoreflect.Proto3 {
switch x.ContainingMessage().FullName() {
case (*descriptorpb.FileOptions)(nil).ProtoReflect().Descriptor().FullName():
case (*descriptorpb.EnumOptions)(nil).ProtoReflect().Descriptor().FullName():
case (*descriptorpb.EnumValueOptions)(nil).ProtoReflect().Descriptor().FullName():
case (*descriptorpb.MessageOptions)(nil).ProtoReflect().Descriptor().FullName():
case (*descriptorpb.FieldOptions)(nil).ProtoReflect().Descriptor().FullName():
case (*descriptorpb.OneofOptions)(nil).ProtoReflect().Descriptor().FullName():
case (*descriptorpb.ExtensionRangeOptions)(nil).ProtoReflect().Descriptor().FullName():
case (*descriptorpb.ServiceOptions)(nil).ProtoReflect().Descriptor().FullName():
case (*descriptorpb.MethodOptions)(nil).ProtoReflect().Descriptor().FullName():
default:
return errors.New("extension field %q cannot be declared in proto3 unless extended descriptor options", x.FullName())
}
}
}
return nil
}
// isOptionalMessage reports whether this is an optional message.
// If the kind is unknown, it is assumed to be a message.
func isOptionalMessage(fd protoreflect.FieldDescriptor) bool {
return (fd.Kind() == 0 || fd.Kind() == protoreflect.MessageKind) && fd.Cardinality() == protoreflect.Optional
}
// isPackable checks whether the pack option can be specified.
func isPackable(fd protoreflect.FieldDescriptor) bool {
switch fd.Kind() {
case protoreflect.StringKind, protoreflect.BytesKind, protoreflect.MessageKind, protoreflect.GroupKind:
return false
}
return fd.IsList()
}
// checkValidGroup reports whether fd is a valid group according to the same
// rules that protoc imposes.
func checkValidGroup(fd protoreflect.FieldDescriptor) error {
md := fd.Message()
switch {
case fd.Kind() != protoreflect.GroupKind:
return nil
case fd.Syntax() != protoreflect.Proto2:
return errors.New("invalid under proto2 semantics")
case md == nil || md.IsPlaceholder():
return errors.New("message must be resolvable")
case fd.FullName().Parent() != md.FullName().Parent():
return errors.New("message and field must be declared in the same scope")
case !unicode.IsUpper(rune(md.Name()[0])):
return errors.New("message name must start with an uppercase")
case fd.Name() != protoreflect.Name(strings.ToLower(string(md.Name()))):
return errors.New("field name must be lowercased form of the message name")
}
return nil
}
// checkValidMap checks whether the field is a valid map according to the same
// rules that protoc imposes.
// See protoc v3.8.0: src/google/protobuf/descriptor.cc:6045-6115
func checkValidMap(fd protoreflect.FieldDescriptor) error {
md := fd.Message()
switch {
case md == nil || !md.IsMapEntry():
return nil
case fd.FullName().Parent() != md.FullName().Parent():
return errors.New("message and field must be declared in the same scope")
case md.Name() != protoreflect.Name(strs.MapEntryName(string(fd.Name()))):
return errors.New("incorrect implicit map entry name")
case fd.Cardinality() != protoreflect.Repeated:
return errors.New("field must be repeated")
case md.Fields().Len() != 2:
return errors.New("message must have exactly two fields")
case md.ExtensionRanges().Len() > 0:
return errors.New("message must not have any extension ranges")
case md.Enums().Len()+md.Messages().Len()+md.Extensions().Len() > 0:
return errors.New("message must not have any nested declarations")
}
kf := md.Fields().Get(0)
vf := md.Fields().Get(1)
switch {
case kf.Name() != genid.MapEntry_Key_field_name || kf.Number() != genid.MapEntry_Key_field_number || kf.Cardinality() != protoreflect.Optional || kf.ContainingOneof() != nil || kf.HasDefault():
return errors.New("invalid key field")
case vf.Name() != genid.MapEntry_Value_field_name || vf.Number() != genid.MapEntry_Value_field_number || vf.Cardinality() != protoreflect.Optional || vf.ContainingOneof() != nil || vf.HasDefault():
return errors.New("invalid value field")
}
switch kf.Kind() {
case protoreflect.BoolKind: // bool
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: // int32
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: // int64
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: // uint32
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: // uint64
case protoreflect.StringKind: // string
default:
return errors.New("invalid key kind: %v", kf.Kind())
}
if e := vf.Enum(); e != nil && e.Values().Len() > 0 && e.Values().Get(0).Number() != 0 {
return errors.New("map enum value must have zero number for the first value")
}
return nil
}

View File

@ -0,0 +1,252 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protodesc
import (
"fmt"
"strings"
"google.golang.org/protobuf/internal/encoding/defval"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)
// ToFileDescriptorProto copies a protoreflect.FileDescriptor into a
// google.protobuf.FileDescriptorProto message.
func ToFileDescriptorProto(file protoreflect.FileDescriptor) *descriptorpb.FileDescriptorProto {
p := &descriptorpb.FileDescriptorProto{
Name: proto.String(file.Path()),
Options: proto.Clone(file.Options()).(*descriptorpb.FileOptions),
}
if file.Package() != "" {
p.Package = proto.String(string(file.Package()))
}
for i, imports := 0, file.Imports(); i < imports.Len(); i++ {
imp := imports.Get(i)
p.Dependency = append(p.Dependency, imp.Path())
if imp.IsPublic {
p.PublicDependency = append(p.PublicDependency, int32(i))
}
if imp.IsWeak {
p.WeakDependency = append(p.WeakDependency, int32(i))
}
}
for i, locs := 0, file.SourceLocations(); i < locs.Len(); i++ {
loc := locs.Get(i)
l := &descriptorpb.SourceCodeInfo_Location{}
l.Path = append(l.Path, loc.Path...)
if loc.StartLine == loc.EndLine {
l.Span = []int32{int32(loc.StartLine), int32(loc.StartColumn), int32(loc.EndColumn)}
} else {
l.Span = []int32{int32(loc.StartLine), int32(loc.StartColumn), int32(loc.EndLine), int32(loc.EndColumn)}
}
l.LeadingDetachedComments = append([]string(nil), loc.LeadingDetachedComments...)
if loc.LeadingComments != "" {
l.LeadingComments = proto.String(loc.LeadingComments)
}
if loc.TrailingComments != "" {
l.TrailingComments = proto.String(loc.TrailingComments)
}
if p.SourceCodeInfo == nil {
p.SourceCodeInfo = &descriptorpb.SourceCodeInfo{}
}
p.SourceCodeInfo.Location = append(p.SourceCodeInfo.Location, l)
}
for i, messages := 0, file.Messages(); i < messages.Len(); i++ {
p.MessageType = append(p.MessageType, ToDescriptorProto(messages.Get(i)))
}
for i, enums := 0, file.Enums(); i < enums.Len(); i++ {
p.EnumType = append(p.EnumType, ToEnumDescriptorProto(enums.Get(i)))
}
for i, services := 0, file.Services(); i < services.Len(); i++ {
p.Service = append(p.Service, ToServiceDescriptorProto(services.Get(i)))
}
for i, exts := 0, file.Extensions(); i < exts.Len(); i++ {
p.Extension = append(p.Extension, ToFieldDescriptorProto(exts.Get(i)))
}
if syntax := file.Syntax(); syntax != protoreflect.Proto2 {
p.Syntax = proto.String(file.Syntax().String())
}
return p
}
// ToDescriptorProto copies a protoreflect.MessageDescriptor into a
// google.protobuf.DescriptorProto message.
func ToDescriptorProto(message protoreflect.MessageDescriptor) *descriptorpb.DescriptorProto {
p := &descriptorpb.DescriptorProto{
Name: proto.String(string(message.Name())),
Options: proto.Clone(message.Options()).(*descriptorpb.MessageOptions),
}
for i, fields := 0, message.Fields(); i < fields.Len(); i++ {
p.Field = append(p.Field, ToFieldDescriptorProto(fields.Get(i)))
}
for i, exts := 0, message.Extensions(); i < exts.Len(); i++ {
p.Extension = append(p.Extension, ToFieldDescriptorProto(exts.Get(i)))
}
for i, messages := 0, message.Messages(); i < messages.Len(); i++ {
p.NestedType = append(p.NestedType, ToDescriptorProto(messages.Get(i)))
}
for i, enums := 0, message.Enums(); i < enums.Len(); i++ {
p.EnumType = append(p.EnumType, ToEnumDescriptorProto(enums.Get(i)))
}
for i, xranges := 0, message.ExtensionRanges(); i < xranges.Len(); i++ {
xrange := xranges.Get(i)
p.ExtensionRange = append(p.ExtensionRange, &descriptorpb.DescriptorProto_ExtensionRange{
Start: proto.Int32(int32(xrange[0])),
End: proto.Int32(int32(xrange[1])),
Options: proto.Clone(message.ExtensionRangeOptions(i)).(*descriptorpb.ExtensionRangeOptions),
})
}
for i, oneofs := 0, message.Oneofs(); i < oneofs.Len(); i++ {
p.OneofDecl = append(p.OneofDecl, ToOneofDescriptorProto(oneofs.Get(i)))
}
for i, ranges := 0, message.ReservedRanges(); i < ranges.Len(); i++ {
rrange := ranges.Get(i)
p.ReservedRange = append(p.ReservedRange, &descriptorpb.DescriptorProto_ReservedRange{
Start: proto.Int32(int32(rrange[0])),
End: proto.Int32(int32(rrange[1])),
})
}
for i, names := 0, message.ReservedNames(); i < names.Len(); i++ {
p.ReservedName = append(p.ReservedName, string(names.Get(i)))
}
return p
}
// ToFieldDescriptorProto copies a protoreflect.FieldDescriptor into a
// google.protobuf.FieldDescriptorProto message.
func ToFieldDescriptorProto(field protoreflect.FieldDescriptor) *descriptorpb.FieldDescriptorProto {
p := &descriptorpb.FieldDescriptorProto{
Name: proto.String(string(field.Name())),
Number: proto.Int32(int32(field.Number())),
Label: descriptorpb.FieldDescriptorProto_Label(field.Cardinality()).Enum(),
Options: proto.Clone(field.Options()).(*descriptorpb.FieldOptions),
}
if field.IsExtension() {
p.Extendee = fullNameOf(field.ContainingMessage())
}
if field.Kind().IsValid() {
p.Type = descriptorpb.FieldDescriptorProto_Type(field.Kind()).Enum()
}
if field.Enum() != nil {
p.TypeName = fullNameOf(field.Enum())
}
if field.Message() != nil {
p.TypeName = fullNameOf(field.Message())
}
if field.HasJSONName() {
// A bug in older versions of protoc would always populate the
// "json_name" option for extensions when it is meaningless.
// When it did so, it would always use the camel-cased field name.
if field.IsExtension() {
p.JsonName = proto.String(strs.JSONCamelCase(string(field.Name())))
} else {
p.JsonName = proto.String(field.JSONName())
}
}
if field.Syntax() == protoreflect.Proto3 && field.HasOptionalKeyword() {
p.Proto3Optional = proto.Bool(true)
}
if field.HasDefault() {
def, err := defval.Marshal(field.Default(), field.DefaultEnumValue(), field.Kind(), defval.Descriptor)
if err != nil && field.DefaultEnumValue() != nil {
def = string(field.DefaultEnumValue().Name()) // occurs for unresolved enum values
} else if err != nil {
panic(fmt.Sprintf("%v: %v", field.FullName(), err))
}
p.DefaultValue = proto.String(def)
}
if oneof := field.ContainingOneof(); oneof != nil {
p.OneofIndex = proto.Int32(int32(oneof.Index()))
}
return p
}
// ToOneofDescriptorProto copies a protoreflect.OneofDescriptor into a
// google.protobuf.OneofDescriptorProto message.
func ToOneofDescriptorProto(oneof protoreflect.OneofDescriptor) *descriptorpb.OneofDescriptorProto {
return &descriptorpb.OneofDescriptorProto{
Name: proto.String(string(oneof.Name())),
Options: proto.Clone(oneof.Options()).(*descriptorpb.OneofOptions),
}
}
// ToEnumDescriptorProto copies a protoreflect.EnumDescriptor into a
// google.protobuf.EnumDescriptorProto message.
func ToEnumDescriptorProto(enum protoreflect.EnumDescriptor) *descriptorpb.EnumDescriptorProto {
p := &descriptorpb.EnumDescriptorProto{
Name: proto.String(string(enum.Name())),
Options: proto.Clone(enum.Options()).(*descriptorpb.EnumOptions),
}
for i, values := 0, enum.Values(); i < values.Len(); i++ {
p.Value = append(p.Value, ToEnumValueDescriptorProto(values.Get(i)))
}
for i, ranges := 0, enum.ReservedRanges(); i < ranges.Len(); i++ {
rrange := ranges.Get(i)
p.ReservedRange = append(p.ReservedRange, &descriptorpb.EnumDescriptorProto_EnumReservedRange{
Start: proto.Int32(int32(rrange[0])),
End: proto.Int32(int32(rrange[1])),
})
}
for i, names := 0, enum.ReservedNames(); i < names.Len(); i++ {
p.ReservedName = append(p.ReservedName, string(names.Get(i)))
}
return p
}
// ToEnumValueDescriptorProto copies a protoreflect.EnumValueDescriptor into a
// google.protobuf.EnumValueDescriptorProto message.
func ToEnumValueDescriptorProto(value protoreflect.EnumValueDescriptor) *descriptorpb.EnumValueDescriptorProto {
return &descriptorpb.EnumValueDescriptorProto{
Name: proto.String(string(value.Name())),
Number: proto.Int32(int32(value.Number())),
Options: proto.Clone(value.Options()).(*descriptorpb.EnumValueOptions),
}
}
// ToServiceDescriptorProto copies a protoreflect.ServiceDescriptor into a
// google.protobuf.ServiceDescriptorProto message.
func ToServiceDescriptorProto(service protoreflect.ServiceDescriptor) *descriptorpb.ServiceDescriptorProto {
p := &descriptorpb.ServiceDescriptorProto{
Name: proto.String(string(service.Name())),
Options: proto.Clone(service.Options()).(*descriptorpb.ServiceOptions),
}
for i, methods := 0, service.Methods(); i < methods.Len(); i++ {
p.Method = append(p.Method, ToMethodDescriptorProto(methods.Get(i)))
}
return p
}
// ToMethodDescriptorProto copies a protoreflect.MethodDescriptor into a
// google.protobuf.MethodDescriptorProto message.
func ToMethodDescriptorProto(method protoreflect.MethodDescriptor) *descriptorpb.MethodDescriptorProto {
p := &descriptorpb.MethodDescriptorProto{
Name: proto.String(string(method.Name())),
InputType: fullNameOf(method.Input()),
OutputType: fullNameOf(method.Output()),
Options: proto.Clone(method.Options()).(*descriptorpb.MethodOptions),
}
if method.IsStreamingClient() {
p.ClientStreaming = proto.Bool(true)
}
if method.IsStreamingServer() {
p.ServerStreaming = proto.Bool(true)
}
return p
}
func fullNameOf(d protoreflect.Descriptor) *string {
if d == nil {
return nil
}
if strings.HasPrefix(string(d.FullName()), unknownPrefix) {
return proto.String(string(d.FullName()[len(unknownPrefix):]))
}
return proto.String("." + string(d.FullName()))
}

View File

@ -4,6 +4,10 @@
package protoreflect package protoreflect
import (
"strconv"
)
// SourceLocations is a list of source locations. // SourceLocations is a list of source locations.
type SourceLocations interface { type SourceLocations interface {
// Len reports the number of source locations in the proto file. // Len reports the number of source locations in the proto file.
@ -11,9 +15,20 @@ type SourceLocations interface {
// Get returns the ith SourceLocation. It panics if out of bounds. // Get returns the ith SourceLocation. It panics if out of bounds.
Get(int) SourceLocation Get(int) SourceLocation
doNotImplement // ByPath returns the SourceLocation for the given path,
// returning the first location if multiple exist for the same path.
// If multiple locations exist for the same path,
// then SourceLocation.Next index can be used to identify the
// index of the next SourceLocation.
// If no location exists for this path, it returns the zero value.
ByPath(path SourcePath) SourceLocation
// TODO: Add ByPath and ByDescriptor helper methods. // ByDescriptor returns the SourceLocation for the given descriptor,
// returning the first location if multiple exist for the same path.
// If no location exists for this descriptor, it returns the zero value.
ByDescriptor(desc Descriptor) SourceLocation
doNotImplement
} }
// SourceLocation describes a source location and // SourceLocation describes a source location and
@ -39,6 +54,10 @@ type SourceLocation struct {
LeadingComments string LeadingComments string
// TrailingComments is the trailing attached comment for the declaration. // TrailingComments is the trailing attached comment for the declaration.
TrailingComments string TrailingComments string
// Next is an index into SourceLocations for the next source location that
// has the same Path. It is zero if there is no next location.
Next int
} }
// SourcePath identifies part of a file descriptor for a source location. // SourcePath identifies part of a file descriptor for a source location.
@ -48,5 +67,62 @@ type SourceLocation struct {
// See google.protobuf.SourceCodeInfo.Location.path. // See google.protobuf.SourceCodeInfo.Location.path.
type SourcePath []int32 type SourcePath []int32
// TODO: Add SourcePath.String method to pretty-print the path. For example: // Equal reports whether p1 equals p2.
// ".message_type[6].nested_type[15].field[3]" func (p1 SourcePath) Equal(p2 SourcePath) bool {
if len(p1) != len(p2) {
return false
}
for i := range p1 {
if p1[i] != p2[i] {
return false
}
}
return true
}
// String formats the path in a humanly readable manner.
// The output is guaranteed to be deterministic,
// making it suitable for use as a key into a Go map.
// It is not guaranteed to be stable as the exact output could change
// in a future version of this module.
//
// Example output:
// .message_type[6].nested_type[15].field[3]
func (p SourcePath) String() string {
b := p.appendFileDescriptorProto(nil)
for _, i := range p {
b = append(b, '.')
b = strconv.AppendInt(b, int64(i), 10)
}
return string(b)
}
type appendFunc func(*SourcePath, []byte) []byte
func (p *SourcePath) appendSingularField(b []byte, name string, f appendFunc) []byte {
if len(*p) == 0 {
return b
}
b = append(b, '.')
b = append(b, name...)
*p = (*p)[1:]
if f != nil {
b = f(p, b)
}
return b
}
func (p *SourcePath) appendRepeatedField(b []byte, name string, f appendFunc) []byte {
b = p.appendSingularField(b, name, nil)
if len(*p) == 0 || (*p)[0] < 0 {
return b
}
b = append(b, '[')
b = strconv.AppendUint(b, uint64((*p)[0]), 10)
b = append(b, ']')
*p = (*p)[1:]
if f != nil {
b = f(p, b)
}
return b
}

View File

@ -0,0 +1,461 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate-protos. DO NOT EDIT.
package protoreflect
func (p *SourcePath) appendFileDescriptorProto(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name", nil)
case 2:
b = p.appendSingularField(b, "package", nil)
case 3:
b = p.appendRepeatedField(b, "dependency", nil)
case 10:
b = p.appendRepeatedField(b, "public_dependency", nil)
case 11:
b = p.appendRepeatedField(b, "weak_dependency", nil)
case 4:
b = p.appendRepeatedField(b, "message_type", (*SourcePath).appendDescriptorProto)
case 5:
b = p.appendRepeatedField(b, "enum_type", (*SourcePath).appendEnumDescriptorProto)
case 6:
b = p.appendRepeatedField(b, "service", (*SourcePath).appendServiceDescriptorProto)
case 7:
b = p.appendRepeatedField(b, "extension", (*SourcePath).appendFieldDescriptorProto)
case 8:
b = p.appendSingularField(b, "options", (*SourcePath).appendFileOptions)
case 9:
b = p.appendSingularField(b, "source_code_info", (*SourcePath).appendSourceCodeInfo)
case 12:
b = p.appendSingularField(b, "syntax", nil)
}
return b
}
func (p *SourcePath) appendDescriptorProto(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name", nil)
case 2:
b = p.appendRepeatedField(b, "field", (*SourcePath).appendFieldDescriptorProto)
case 6:
b = p.appendRepeatedField(b, "extension", (*SourcePath).appendFieldDescriptorProto)
case 3:
b = p.appendRepeatedField(b, "nested_type", (*SourcePath).appendDescriptorProto)
case 4:
b = p.appendRepeatedField(b, "enum_type", (*SourcePath).appendEnumDescriptorProto)
case 5:
b = p.appendRepeatedField(b, "extension_range", (*SourcePath).appendDescriptorProto_ExtensionRange)
case 8:
b = p.appendRepeatedField(b, "oneof_decl", (*SourcePath).appendOneofDescriptorProto)
case 7:
b = p.appendSingularField(b, "options", (*SourcePath).appendMessageOptions)
case 9:
b = p.appendRepeatedField(b, "reserved_range", (*SourcePath).appendDescriptorProto_ReservedRange)
case 10:
b = p.appendRepeatedField(b, "reserved_name", nil)
}
return b
}
func (p *SourcePath) appendEnumDescriptorProto(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name", nil)
case 2:
b = p.appendRepeatedField(b, "value", (*SourcePath).appendEnumValueDescriptorProto)
case 3:
b = p.appendSingularField(b, "options", (*SourcePath).appendEnumOptions)
case 4:
b = p.appendRepeatedField(b, "reserved_range", (*SourcePath).appendEnumDescriptorProto_EnumReservedRange)
case 5:
b = p.appendRepeatedField(b, "reserved_name", nil)
}
return b
}
func (p *SourcePath) appendServiceDescriptorProto(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name", nil)
case 2:
b = p.appendRepeatedField(b, "method", (*SourcePath).appendMethodDescriptorProto)
case 3:
b = p.appendSingularField(b, "options", (*SourcePath).appendServiceOptions)
}
return b
}
func (p *SourcePath) appendFieldDescriptorProto(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name", nil)
case 3:
b = p.appendSingularField(b, "number", nil)
case 4:
b = p.appendSingularField(b, "label", nil)
case 5:
b = p.appendSingularField(b, "type", nil)
case 6:
b = p.appendSingularField(b, "type_name", nil)
case 2:
b = p.appendSingularField(b, "extendee", nil)
case 7:
b = p.appendSingularField(b, "default_value", nil)
case 9:
b = p.appendSingularField(b, "oneof_index", nil)
case 10:
b = p.appendSingularField(b, "json_name", nil)
case 8:
b = p.appendSingularField(b, "options", (*SourcePath).appendFieldOptions)
case 17:
b = p.appendSingularField(b, "proto3_optional", nil)
}
return b
}
func (p *SourcePath) appendFileOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "java_package", nil)
case 8:
b = p.appendSingularField(b, "java_outer_classname", nil)
case 10:
b = p.appendSingularField(b, "java_multiple_files", nil)
case 20:
b = p.appendSingularField(b, "java_generate_equals_and_hash", nil)
case 27:
b = p.appendSingularField(b, "java_string_check_utf8", nil)
case 9:
b = p.appendSingularField(b, "optimize_for", nil)
case 11:
b = p.appendSingularField(b, "go_package", nil)
case 16:
b = p.appendSingularField(b, "cc_generic_services", nil)
case 17:
b = p.appendSingularField(b, "java_generic_services", nil)
case 18:
b = p.appendSingularField(b, "py_generic_services", nil)
case 42:
b = p.appendSingularField(b, "php_generic_services", nil)
case 23:
b = p.appendSingularField(b, "deprecated", nil)
case 31:
b = p.appendSingularField(b, "cc_enable_arenas", nil)
case 36:
b = p.appendSingularField(b, "objc_class_prefix", nil)
case 37:
b = p.appendSingularField(b, "csharp_namespace", nil)
case 39:
b = p.appendSingularField(b, "swift_prefix", nil)
case 40:
b = p.appendSingularField(b, "php_class_prefix", nil)
case 41:
b = p.appendSingularField(b, "php_namespace", nil)
case 44:
b = p.appendSingularField(b, "php_metadata_namespace", nil)
case 45:
b = p.appendSingularField(b, "ruby_package", nil)
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendSourceCodeInfo(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendRepeatedField(b, "location", (*SourcePath).appendSourceCodeInfo_Location)
}
return b
}
func (p *SourcePath) appendDescriptorProto_ExtensionRange(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "start", nil)
case 2:
b = p.appendSingularField(b, "end", nil)
case 3:
b = p.appendSingularField(b, "options", (*SourcePath).appendExtensionRangeOptions)
}
return b
}
func (p *SourcePath) appendOneofDescriptorProto(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name", nil)
case 2:
b = p.appendSingularField(b, "options", (*SourcePath).appendOneofOptions)
}
return b
}
func (p *SourcePath) appendMessageOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "message_set_wire_format", nil)
case 2:
b = p.appendSingularField(b, "no_standard_descriptor_accessor", nil)
case 3:
b = p.appendSingularField(b, "deprecated", nil)
case 7:
b = p.appendSingularField(b, "map_entry", nil)
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendDescriptorProto_ReservedRange(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "start", nil)
case 2:
b = p.appendSingularField(b, "end", nil)
}
return b
}
func (p *SourcePath) appendEnumValueDescriptorProto(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name", nil)
case 2:
b = p.appendSingularField(b, "number", nil)
case 3:
b = p.appendSingularField(b, "options", (*SourcePath).appendEnumValueOptions)
}
return b
}
func (p *SourcePath) appendEnumOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 2:
b = p.appendSingularField(b, "allow_alias", nil)
case 3:
b = p.appendSingularField(b, "deprecated", nil)
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendEnumDescriptorProto_EnumReservedRange(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "start", nil)
case 2:
b = p.appendSingularField(b, "end", nil)
}
return b
}
func (p *SourcePath) appendMethodDescriptorProto(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name", nil)
case 2:
b = p.appendSingularField(b, "input_type", nil)
case 3:
b = p.appendSingularField(b, "output_type", nil)
case 4:
b = p.appendSingularField(b, "options", (*SourcePath).appendMethodOptions)
case 5:
b = p.appendSingularField(b, "client_streaming", nil)
case 6:
b = p.appendSingularField(b, "server_streaming", nil)
}
return b
}
func (p *SourcePath) appendServiceOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 33:
b = p.appendSingularField(b, "deprecated", nil)
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendFieldOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "ctype", nil)
case 2:
b = p.appendSingularField(b, "packed", nil)
case 6:
b = p.appendSingularField(b, "jstype", nil)
case 5:
b = p.appendSingularField(b, "lazy", nil)
case 3:
b = p.appendSingularField(b, "deprecated", nil)
case 10:
b = p.appendSingularField(b, "weak", nil)
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendUninterpretedOption(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 2:
b = p.appendRepeatedField(b, "name", (*SourcePath).appendUninterpretedOption_NamePart)
case 3:
b = p.appendSingularField(b, "identifier_value", nil)
case 4:
b = p.appendSingularField(b, "positive_int_value", nil)
case 5:
b = p.appendSingularField(b, "negative_int_value", nil)
case 6:
b = p.appendSingularField(b, "double_value", nil)
case 7:
b = p.appendSingularField(b, "string_value", nil)
case 8:
b = p.appendSingularField(b, "aggregate_value", nil)
}
return b
}
func (p *SourcePath) appendSourceCodeInfo_Location(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendRepeatedField(b, "path", nil)
case 2:
b = p.appendRepeatedField(b, "span", nil)
case 3:
b = p.appendSingularField(b, "leading_comments", nil)
case 4:
b = p.appendSingularField(b, "trailing_comments", nil)
case 6:
b = p.appendRepeatedField(b, "leading_detached_comments", nil)
}
return b
}
func (p *SourcePath) appendExtensionRangeOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendOneofOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendEnumValueOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "deprecated", nil)
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendMethodOptions(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 33:
b = p.appendSingularField(b, "deprecated", nil)
case 34:
b = p.appendSingularField(b, "idempotency_level", nil)
case 999:
b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption)
}
return b
}
func (p *SourcePath) appendUninterpretedOption_NamePart(b []byte) []byte {
if len(*p) == 0 {
return b
}
switch (*p)[0] {
case 1:
b = p.appendSingularField(b, "name_part", nil)
case 2:
b = p.appendSingularField(b, "is_extension", nil)
}
return b
}

Some files were not shown because too many files have changed in this diff Show More