updated etcd to v3.5.5 and newer otel libraries as well

Signed-off-by: Davanum Srinivas <davanum@gmail.com>
This commit is contained in:
Davanum Srinivas
2022-09-17 14:04:16 -04:00
parent 23a4920d83
commit 5be80c05c4
439 changed files with 28673 additions and 20196 deletions

25
vendor/github.com/cenkalti/backoff/v4/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,25 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
# IDEs
.idea/

10
vendor/github.com/cenkalti/backoff/v4/.travis.yml generated vendored Normal file
View File

@@ -0,0 +1,10 @@
language: go
go:
- 1.13
- 1.x
- tip
before_install:
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
script:
- $HOME/gopath/bin/goveralls -service=travis-ci

20
vendor/github.com/cenkalti/backoff/v4/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2014 Cenk Altı
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

32
vendor/github.com/cenkalti/backoff/v4/README.md generated vendored Normal file
View File

@@ -0,0 +1,32 @@
# Exponential Backoff [![GoDoc][godoc image]][godoc] [![Build Status][travis image]][travis] [![Coverage Status][coveralls image]][coveralls]
This is a Go port of the exponential backoff algorithm from [Google's HTTP Client Library for Java][google-http-java-client].
[Exponential backoff][exponential backoff wiki]
is an algorithm that uses feedback to multiplicatively decrease the rate of some process,
in order to gradually find an acceptable rate.
The retries exponentially increase and stop increasing when a certain threshold is met.
## Usage
Import path is `github.com/cenkalti/backoff/v4`. Please note the version part at the end.
Use https://pkg.go.dev/github.com/cenkalti/backoff/v4 to view the documentation.
## Contributing
* I would like to keep this library as small as possible.
* Please don't send a PR without opening an issue and discussing it first.
* If proposed change is not a common use case, I will probably not accept it.
[godoc]: https://pkg.go.dev/github.com/cenkalti/backoff/v4
[godoc image]: https://godoc.org/github.com/cenkalti/backoff?status.png
[travis]: https://travis-ci.org/cenkalti/backoff
[travis image]: https://travis-ci.org/cenkalti/backoff.png?branch=master
[coveralls]: https://coveralls.io/github/cenkalti/backoff?branch=master
[coveralls image]: https://coveralls.io/repos/github/cenkalti/backoff/badge.svg?branch=master
[google-http-java-client]: https://github.com/google/google-http-java-client/blob/da1aa993e90285ec18579f1553339b00e19b3ab5/google-http-client/src/main/java/com/google/api/client/util/ExponentialBackOff.java
[exponential backoff wiki]: http://en.wikipedia.org/wiki/Exponential_backoff
[advanced example]: https://pkg.go.dev/github.com/cenkalti/backoff/v4?tab=doc#pkg-examples

66
vendor/github.com/cenkalti/backoff/v4/backoff.go generated vendored Normal file
View File

@@ -0,0 +1,66 @@
// Package backoff implements backoff algorithms for retrying operations.
//
// Use Retry function for retrying operations that may fail.
// If Retry does not meet your needs,
// copy/paste the function into your project and modify as you wish.
//
// There is also Ticker type similar to time.Ticker.
// You can use it if you need to work with channels.
//
// See Examples section below for usage examples.
package backoff
import "time"
// BackOff is a backoff policy for retrying an operation.
type BackOff interface {
// NextBackOff returns the duration to wait before retrying the operation,
// or backoff. Stop to indicate that no more retries should be made.
//
// Example usage:
//
// duration := backoff.NextBackOff();
// if (duration == backoff.Stop) {
// // Do not retry operation.
// } else {
// // Sleep for duration and retry operation.
// }
//
NextBackOff() time.Duration
// Reset to initial state.
Reset()
}
// Stop indicates that no more retries should be made for use in NextBackOff().
const Stop time.Duration = -1
// ZeroBackOff is a fixed backoff policy whose backoff time is always zero,
// meaning that the operation is retried immediately without waiting, indefinitely.
type ZeroBackOff struct{}
func (b *ZeroBackOff) Reset() {}
func (b *ZeroBackOff) NextBackOff() time.Duration { return 0 }
// StopBackOff is a fixed backoff policy that always returns backoff.Stop for
// NextBackOff(), meaning that the operation should never be retried.
type StopBackOff struct{}
func (b *StopBackOff) Reset() {}
func (b *StopBackOff) NextBackOff() time.Duration { return Stop }
// ConstantBackOff is a backoff policy that always returns the same backoff delay.
// This is in contrast to an exponential backoff policy,
// which returns a delay that grows longer as you call NextBackOff() over and over again.
type ConstantBackOff struct {
Interval time.Duration
}
func (b *ConstantBackOff) Reset() {}
func (b *ConstantBackOff) NextBackOff() time.Duration { return b.Interval }
func NewConstantBackOff(d time.Duration) *ConstantBackOff {
return &ConstantBackOff{Interval: d}
}

62
vendor/github.com/cenkalti/backoff/v4/context.go generated vendored Normal file
View File

@@ -0,0 +1,62 @@
package backoff
import (
"context"
"time"
)
// BackOffContext is a backoff policy that stops retrying after the context
// is canceled.
type BackOffContext interface { // nolint: golint
BackOff
Context() context.Context
}
type backOffContext struct {
BackOff
ctx context.Context
}
// WithContext returns a BackOffContext with context ctx
//
// ctx must not be nil
func WithContext(b BackOff, ctx context.Context) BackOffContext { // nolint: golint
if ctx == nil {
panic("nil context")
}
if b, ok := b.(*backOffContext); ok {
return &backOffContext{
BackOff: b.BackOff,
ctx: ctx,
}
}
return &backOffContext{
BackOff: b,
ctx: ctx,
}
}
func getContext(b BackOff) context.Context {
if cb, ok := b.(BackOffContext); ok {
return cb.Context()
}
if tb, ok := b.(*backOffTries); ok {
return getContext(tb.delegate)
}
return context.Background()
}
func (b *backOffContext) Context() context.Context {
return b.ctx
}
func (b *backOffContext) NextBackOff() time.Duration {
select {
case <-b.ctx.Done():
return Stop
default:
return b.BackOff.NextBackOff()
}
}

161
vendor/github.com/cenkalti/backoff/v4/exponential.go generated vendored Normal file
View File

@@ -0,0 +1,161 @@
package backoff
import (
"math/rand"
"time"
)
/*
ExponentialBackOff is a backoff implementation that increases the backoff
period for each retry attempt using a randomization function that grows exponentially.
NextBackOff() is calculated using the following formula:
randomized interval =
RetryInterval * (random value in range [1 - RandomizationFactor, 1 + RandomizationFactor])
In other words NextBackOff() will range between the randomization factor
percentage below and above the retry interval.
For example, given the following parameters:
RetryInterval = 2
RandomizationFactor = 0.5
Multiplier = 2
the actual backoff period used in the next retry attempt will range between 1 and 3 seconds,
multiplied by the exponential, that is, between 2 and 6 seconds.
Note: MaxInterval caps the RetryInterval and not the randomized interval.
If the time elapsed since an ExponentialBackOff instance is created goes past the
MaxElapsedTime, then the method NextBackOff() starts returning backoff.Stop.
The elapsed time can be reset by calling Reset().
Example: Given the following default arguments, for 10 tries the sequence will be,
and assuming we go over the MaxElapsedTime on the 10th try:
Request # RetryInterval (seconds) Randomized Interval (seconds)
1 0.5 [0.25, 0.75]
2 0.75 [0.375, 1.125]
3 1.125 [0.562, 1.687]
4 1.687 [0.8435, 2.53]
5 2.53 [1.265, 3.795]
6 3.795 [1.897, 5.692]
7 5.692 [2.846, 8.538]
8 8.538 [4.269, 12.807]
9 12.807 [6.403, 19.210]
10 19.210 backoff.Stop
Note: Implementation is not thread-safe.
*/
type ExponentialBackOff struct {
InitialInterval time.Duration
RandomizationFactor float64
Multiplier float64
MaxInterval time.Duration
// After MaxElapsedTime the ExponentialBackOff returns Stop.
// It never stops if MaxElapsedTime == 0.
MaxElapsedTime time.Duration
Stop time.Duration
Clock Clock
currentInterval time.Duration
startTime time.Time
}
// Clock is an interface that returns current time for BackOff.
type Clock interface {
Now() time.Time
}
// Default values for ExponentialBackOff.
const (
DefaultInitialInterval = 500 * time.Millisecond
DefaultRandomizationFactor = 0.5
DefaultMultiplier = 1.5
DefaultMaxInterval = 60 * time.Second
DefaultMaxElapsedTime = 15 * time.Minute
)
// NewExponentialBackOff creates an instance of ExponentialBackOff using default values.
func NewExponentialBackOff() *ExponentialBackOff {
b := &ExponentialBackOff{
InitialInterval: DefaultInitialInterval,
RandomizationFactor: DefaultRandomizationFactor,
Multiplier: DefaultMultiplier,
MaxInterval: DefaultMaxInterval,
MaxElapsedTime: DefaultMaxElapsedTime,
Stop: Stop,
Clock: SystemClock,
}
b.Reset()
return b
}
type systemClock struct{}
func (t systemClock) Now() time.Time {
return time.Now()
}
// SystemClock implements Clock interface that uses time.Now().
var SystemClock = systemClock{}
// Reset the interval back to the initial retry interval and restarts the timer.
// Reset must be called before using b.
func (b *ExponentialBackOff) Reset() {
b.currentInterval = b.InitialInterval
b.startTime = b.Clock.Now()
}
// NextBackOff calculates the next backoff interval using the formula:
// Randomized interval = RetryInterval * (1 ± RandomizationFactor)
func (b *ExponentialBackOff) NextBackOff() time.Duration {
// Make sure we have not gone over the maximum elapsed time.
elapsed := b.GetElapsedTime()
next := getRandomValueFromInterval(b.RandomizationFactor, rand.Float64(), b.currentInterval)
b.incrementCurrentInterval()
if b.MaxElapsedTime != 0 && elapsed+next > b.MaxElapsedTime {
return b.Stop
}
return next
}
// GetElapsedTime returns the elapsed time since an ExponentialBackOff instance
// is created and is reset when Reset() is called.
//
// The elapsed time is computed using time.Now().UnixNano(). It is
// safe to call even while the backoff policy is used by a running
// ticker.
func (b *ExponentialBackOff) GetElapsedTime() time.Duration {
return b.Clock.Now().Sub(b.startTime)
}
// Increments the current interval by multiplying it with the multiplier.
func (b *ExponentialBackOff) incrementCurrentInterval() {
// Check for overflow, if overflow is detected set the current interval to the max interval.
if float64(b.currentInterval) >= float64(b.MaxInterval)/b.Multiplier {
b.currentInterval = b.MaxInterval
} else {
b.currentInterval = time.Duration(float64(b.currentInterval) * b.Multiplier)
}
}
// Returns a random value from the following interval:
// [currentInterval - randomizationFactor * currentInterval, currentInterval + randomizationFactor * currentInterval].
func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration {
if randomizationFactor == 0 {
return currentInterval // make sure no randomness is used when randomizationFactor is 0.
}
var delta = randomizationFactor * float64(currentInterval)
var minInterval = float64(currentInterval) - delta
var maxInterval = float64(currentInterval) + delta
// Get a random value from the range [minInterval, maxInterval].
// The formula used below has a +1 because if the minInterval is 1 and the maxInterval is 3 then
// we want a 33% chance for selecting either 1, 2 or 3.
return time.Duration(minInterval + (random * (maxInterval - minInterval + 1)))
}

112
vendor/github.com/cenkalti/backoff/v4/retry.go generated vendored Normal file
View File

@@ -0,0 +1,112 @@
package backoff
import (
"errors"
"time"
)
// An Operation is executing by Retry() or RetryNotify().
// The operation will be retried using a backoff policy if it returns an error.
type Operation func() error
// Notify is a notify-on-error function. It receives an operation error and
// backoff delay if the operation failed (with an error).
//
// NOTE that if the backoff policy stated to stop retrying,
// the notify function isn't called.
type Notify func(error, time.Duration)
// Retry the operation o until it does not return error or BackOff stops.
// o is guaranteed to be run at least once.
//
// If o returns a *PermanentError, the operation is not retried, and the
// wrapped error is returned.
//
// Retry sleeps the goroutine for the duration returned by BackOff after a
// failed operation returns.
func Retry(o Operation, b BackOff) error {
return RetryNotify(o, b, nil)
}
// RetryNotify calls notify function with the error and wait duration
// for each failed attempt before sleep.
func RetryNotify(operation Operation, b BackOff, notify Notify) error {
return RetryNotifyWithTimer(operation, b, notify, nil)
}
// RetryNotifyWithTimer calls notify function with the error and wait duration using the given Timer
// for each failed attempt before sleep.
// A default timer that uses system timer is used when nil is passed.
func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer) error {
var err error
var next time.Duration
if t == nil {
t = &defaultTimer{}
}
defer func() {
t.Stop()
}()
ctx := getContext(b)
b.Reset()
for {
if err = operation(); err == nil {
return nil
}
var permanent *PermanentError
if errors.As(err, &permanent) {
return permanent.Err
}
if next = b.NextBackOff(); next == Stop {
if cerr := ctx.Err(); cerr != nil {
return cerr
}
return err
}
if notify != nil {
notify(err, next)
}
t.Start(next)
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C():
}
}
}
// PermanentError signals that the operation should not be retried.
type PermanentError struct {
Err error
}
func (e *PermanentError) Error() string {
return e.Err.Error()
}
func (e *PermanentError) Unwrap() error {
return e.Err
}
func (e *PermanentError) Is(target error) bool {
_, ok := target.(*PermanentError)
return ok
}
// Permanent wraps the given err in a *PermanentError.
func Permanent(err error) error {
if err == nil {
return nil
}
return &PermanentError{
Err: err,
}
}

97
vendor/github.com/cenkalti/backoff/v4/ticker.go generated vendored Normal file
View File

@@ -0,0 +1,97 @@
package backoff
import (
"context"
"sync"
"time"
)
// Ticker holds a channel that delivers `ticks' of a clock at times reported by a BackOff.
//
// Ticks will continue to arrive when the previous operation is still running,
// so operations that take a while to fail could run in quick succession.
type Ticker struct {
C <-chan time.Time
c chan time.Time
b BackOff
ctx context.Context
timer Timer
stop chan struct{}
stopOnce sync.Once
}
// NewTicker returns a new Ticker containing a channel that will send
// the time at times specified by the BackOff argument. Ticker is
// guaranteed to tick at least once. The channel is closed when Stop
// method is called or BackOff stops. It is not safe to manipulate the
// provided backoff policy (notably calling NextBackOff or Reset)
// while the ticker is running.
func NewTicker(b BackOff) *Ticker {
return NewTickerWithTimer(b, &defaultTimer{})
}
// NewTickerWithTimer returns a new Ticker with a custom timer.
// A default timer that uses system timer is used when nil is passed.
func NewTickerWithTimer(b BackOff, timer Timer) *Ticker {
if timer == nil {
timer = &defaultTimer{}
}
c := make(chan time.Time)
t := &Ticker{
C: c,
c: c,
b: b,
ctx: getContext(b),
timer: timer,
stop: make(chan struct{}),
}
t.b.Reset()
go t.run()
return t
}
// Stop turns off a ticker. After Stop, no more ticks will be sent.
func (t *Ticker) Stop() {
t.stopOnce.Do(func() { close(t.stop) })
}
func (t *Ticker) run() {
c := t.c
defer close(c)
// Ticker is guaranteed to tick at least once.
afterC := t.send(time.Now())
for {
if afterC == nil {
return
}
select {
case tick := <-afterC:
afterC = t.send(tick)
case <-t.stop:
t.c = nil // Prevent future ticks from being sent to the channel.
return
case <-t.ctx.Done():
return
}
}
}
func (t *Ticker) send(tick time.Time) <-chan time.Time {
select {
case t.c <- tick:
case <-t.stop:
return nil
}
next := t.b.NextBackOff()
if next == Stop {
t.Stop()
return nil
}
t.timer.Start(next)
return t.timer.C()
}

35
vendor/github.com/cenkalti/backoff/v4/timer.go generated vendored Normal file
View File

@@ -0,0 +1,35 @@
package backoff
import "time"
type Timer interface {
Start(duration time.Duration)
Stop()
C() <-chan time.Time
}
// defaultTimer implements Timer interface using time.Timer
type defaultTimer struct {
timer *time.Timer
}
// C returns the timers channel which receives the current time when the timer fires.
func (t *defaultTimer) C() <-chan time.Time {
return t.timer.C
}
// Start starts the timer to fire after the given duration
func (t *defaultTimer) Start(duration time.Duration) {
if t.timer == nil {
t.timer = time.NewTimer(duration)
} else {
t.timer.Reset(duration)
}
}
// Stop is called when the timer is not used anymore and resources may be freed.
func (t *defaultTimer) Stop() {
if t.timer != nil {
t.timer.Stop()
}
}

38
vendor/github.com/cenkalti/backoff/v4/tries.go generated vendored Normal file
View File

@@ -0,0 +1,38 @@
package backoff
import "time"
/*
WithMaxRetries creates a wrapper around another BackOff, which will
return Stop if NextBackOff() has been called too many times since
the last time Reset() was called
Note: Implementation is not thread-safe.
*/
func WithMaxRetries(b BackOff, max uint64) BackOff {
return &backOffTries{delegate: b, maxTries: max}
}
type backOffTries struct {
delegate BackOff
maxTries uint64
numTries uint64
}
func (b *backOffTries) NextBackOff() time.Duration {
if b.maxTries == 0 {
return Stop
}
if b.maxTries > 0 {
if b.maxTries <= b.numTries {
return Stop
}
b.numTries++
}
return b.delegate.NextBackOff()
}
func (b *backOffTries) Reset() {
b.numTries = 0
b.delegate.Reset()
}

View File

@@ -1,5 +1,9 @@
# Change history of go-restful
## [v3.9.0] - 20221-07-21
- add support for http.Handler implementations to work as FilterFunction, issue #504 (thanks to https://github.com/ggicci)
## [v3.8.0] - 20221-06-06
- use exact matching of allowed domain entries, issue #489 (#493)

View File

@@ -84,6 +84,7 @@ func (u UserResource) findUser(request *restful.Request, response *restful.Respo
- Route errors produce HTTP 404/405/406/415 errors, customizable using ServiceErrorHandler(...)
- Configurable (trace) logging
- Customizable gzip/deflate readers and writers using CompressorProvider registration
- Inject your own http.Handler using the `HttpMiddlewareHandlerToFilter` function
## How to customize
There are several hooks to customize the behavior of the go-restful package.
@@ -94,7 +95,7 @@ There are several hooks to customize the behavior of the go-restful package.
- Trace logging
- Compression
- Encoders for other serializers
- Use [jsoniter](https://github.com/json-iterator/go) by build this package using a tag, e.g. `go build -tags=jsoniter .`
- Use [jsoniter](https://github.com/json-iterator/go) by building this package using a build tag, e.g. `go build -tags=jsoniter .`
## Resources

View File

@@ -0,0 +1,21 @@
package restful
import (
"net/http"
)
// HttpMiddlewareHandler is a function that takes a http.Handler and returns a http.Handler
type HttpMiddlewareHandler func(http.Handler) http.Handler
// HttpMiddlewareHandlerToFilter converts a HttpMiddlewareHandler to a FilterFunction.
func HttpMiddlewareHandlerToFilter(middleware HttpMiddlewareHandler) FilterFunction {
return func(req *Request, resp *Response, chain *FilterChain) {
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
req.Request = r
resp.ResponseWriter = rw
chain.ProcessFilter(req, resp)
})
middleware(next).ServeHTTP(resp.ResponseWriter, req.Request)
}
}

View File

@@ -22,6 +22,9 @@ const (
// FormParameterKind = indicator of Request parameter type "form"
FormParameterKind
// MultiPartFormParameterKind = indicator of Request parameter type "multipart/form-data"
MultiPartFormParameterKind
// CollectionFormatCSV comma separated values `foo,bar`
CollectionFormatCSV = CollectionFormat("csv")
@@ -108,6 +111,11 @@ func (p *Parameter) beForm() *Parameter {
return p
}
func (p *Parameter) beMultiPartForm() *Parameter {
p.data.Kind = MultiPartFormParameterKind
return p
}
// Required sets the required field and returns the receiver
func (p *Parameter) Required(required bool) *Parameter {
p.data.Required = required

View File

@@ -165,6 +165,18 @@ func FormParameter(name, description string) *Parameter {
return p
}
// MultiPartFormParameter creates a new Parameter of kind Form (using multipart/form-data) for documentation purposes.
// It is initialized as required with string as its DataType.
func (w *WebService) MultiPartFormParameter(name, description string) *Parameter {
return MultiPartFormParameter(name, description)
}
func MultiPartFormParameter(name, description string) *Parameter {
p := &Parameter{&ParameterData{Name: name, Description: description, Required: false, DataType: "string"}}
p.beMultiPartForm()
return p
}
// Route creates a new Route using the RouteBuilder and add to the ordered list of Routes.
func (w *WebService) Route(builder *RouteBuilder) *WebService {
w.routesLock.Lock()

View File

@@ -65,7 +65,8 @@ being called, or called more than once, as well as concurrent calls to
Unfortunately this package is not perfect either. It's possible that it is
still missing some interfaces provided by the go core (let me know if you find
one), and it won't work for applications adding their own interfaces into the
mix.
mix. You can however use `httpsnoop.Unwrap(w)` to access the underlying
`http.ResponseWriter` and type-assert the result to its other interfaces.
However, hopefully the explanation above has sufficiently scared you of rolling
your own solution to this problem. httpsnoop may still break your application,

View File

@@ -3,7 +3,6 @@ package httpsnoop
import (
"io"
"net/http"
"sync"
"time"
)
@@ -36,17 +35,23 @@ func CaptureMetrics(hnd http.Handler, w http.ResponseWriter, r *http.Request) Me
// sugar on top of this func), but is a more usable interface if your
// application doesn't use the Go http.Handler interface.
func CaptureMetricsFn(w http.ResponseWriter, fn func(http.ResponseWriter)) Metrics {
m := Metrics{Code: http.StatusOK}
m.CaptureMetrics(w, fn)
return m
}
// CaptureMetrics wraps w and calls fn with the wrapped w and updates
// Metrics m with the resulting metrics. This is similar to CaptureMetricsFn,
// but allows one to customize starting Metrics object.
func (m *Metrics) CaptureMetrics(w http.ResponseWriter, fn func(http.ResponseWriter)) {
var (
start = time.Now()
m = Metrics{Code: http.StatusOK}
headerWritten bool
lock sync.Mutex
hooks = Hooks{
WriteHeader: func(next WriteHeaderFunc) WriteHeaderFunc {
return func(code int) {
next(code)
lock.Lock()
defer lock.Unlock()
if !headerWritten {
m.Code = code
headerWritten = true
@@ -57,8 +62,7 @@ func CaptureMetricsFn(w http.ResponseWriter, fn func(http.ResponseWriter)) Metri
Write: func(next WriteFunc) WriteFunc {
return func(p []byte) (int, error) {
n, err := next(p)
lock.Lock()
defer lock.Unlock()
m.Written += int64(n)
headerWritten = true
return n, err
@@ -68,8 +72,7 @@ func CaptureMetricsFn(w http.ResponseWriter, fn func(http.ResponseWriter)) Metri
ReadFrom: func(next ReadFromFunc) ReadFromFunc {
return func(src io.Reader) (int64, error) {
n, err := next(src)
lock.Lock()
defer lock.Unlock()
headerWritten = true
m.Written += n
return n, err
@@ -79,6 +82,5 @@ func CaptureMetricsFn(w http.ResponseWriter, fn func(http.ResponseWriter)) Metri
)
fn(Wrap(w, hooks))
m.Duration = time.Since(start)
return m
m.Duration += time.Since(start)
}

View File

@@ -74,243 +74,275 @@ func Wrap(w http.ResponseWriter, hooks Hooks) http.ResponseWriter {
// combination 1/32
case !i0 && !i1 && !i2 && !i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
}{rw}
}{rw, rw}
// combination 2/32
case !i0 && !i1 && !i2 && !i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Pusher
}{rw, rw}
}{rw, rw, rw}
// combination 3/32
case !i0 && !i1 && !i2 && i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
io.ReaderFrom
}{rw, rw}
}{rw, rw, rw}
// combination 4/32
case !i0 && !i1 && !i2 && i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
io.ReaderFrom
http.Pusher
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 5/32
case !i0 && !i1 && i2 && !i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Hijacker
}{rw, rw}
}{rw, rw, rw}
// combination 6/32
case !i0 && !i1 && i2 && !i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Hijacker
http.Pusher
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 7/32
case !i0 && !i1 && i2 && i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Hijacker
io.ReaderFrom
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 8/32
case !i0 && !i1 && i2 && i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Hijacker
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw}
}{rw, rw, rw, rw, rw}
// combination 9/32
case !i0 && i1 && !i2 && !i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
}{rw, rw}
}{rw, rw, rw}
// combination 10/32
case !i0 && i1 && !i2 && !i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
http.Pusher
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 11/32
case !i0 && i1 && !i2 && i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
io.ReaderFrom
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 12/32
case !i0 && i1 && !i2 && i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw}
}{rw, rw, rw, rw, rw}
// combination 13/32
case !i0 && i1 && i2 && !i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
http.Hijacker
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 14/32
case !i0 && i1 && i2 && !i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
http.Hijacker
http.Pusher
}{rw, rw, rw, rw}
}{rw, rw, rw, rw, rw}
// combination 15/32
case !i0 && i1 && i2 && i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw}
}{rw, rw, rw, rw, rw}
// combination 16/32
case !i0 && i1 && i2 && i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
http.Hijacker
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw, rw}
// combination 17/32
case i0 && !i1 && !i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
}{rw, rw}
// combination 18/32
case i0 && !i1 && !i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
http.Pusher
}{rw, rw, rw}
// combination 19/32
case i0 && !i1 && !i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
io.ReaderFrom
}{rw, rw, rw}
// combination 20/32
case i0 && !i1 && !i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw}
// combination 21/32
case i0 && !i1 && i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
}{rw, rw, rw}
// combination 22/32
case i0 && !i1 && i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
http.Pusher
}{rw, rw, rw, rw}
// combination 23/32
case i0 && !i1 && i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw}
// combination 24/32
case i0 && !i1 && i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw, rw}
// combination 25/32
case i0 && i1 && !i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
}{rw, rw, rw}
// combination 26/32
case i0 && i1 && !i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Pusher
}{rw, rw, rw, rw}
// combination 27/32
case i0 && i1 && !i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
io.ReaderFrom
}{rw, rw, rw, rw}
// combination 28/32
case i0 && i1 && !i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw, rw}
// combination 29/32
case i0 && i1 && i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
}{rw, rw, rw, rw}
// combination 30/32
case i0 && i1 && i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
http.Pusher
}{rw, rw, rw, rw, rw}
// combination 31/32
case i0 && i1 && i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw, rw}
// combination 32/32
case i0 && i1 && i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw, rw, rw}
// combination 17/32
case i0 && !i1 && !i2 && !i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
}{rw, rw, rw}
// combination 18/32
case i0 && !i1 && !i2 && !i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.Pusher
}{rw, rw, rw, rw}
// combination 19/32
case i0 && !i1 && !i2 && i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
io.ReaderFrom
}{rw, rw, rw, rw}
// combination 20/32
case i0 && !i1 && !i2 && i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw, rw}
// combination 21/32
case i0 && !i1 && i2 && !i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.Hijacker
}{rw, rw, rw, rw}
// combination 22/32
case i0 && !i1 && i2 && !i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.Hijacker
http.Pusher
}{rw, rw, rw, rw, rw}
// combination 23/32
case i0 && !i1 && i2 && i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw, rw}
// combination 24/32
case i0 && !i1 && i2 && i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.Hijacker
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw, rw, rw}
// combination 25/32
case i0 && i1 && !i2 && !i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
}{rw, rw, rw, rw}
// combination 26/32
case i0 && i1 && !i2 && !i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Pusher
}{rw, rw, rw, rw, rw}
// combination 27/32
case i0 && i1 && !i2 && i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
io.ReaderFrom
}{rw, rw, rw, rw, rw}
// combination 28/32
case i0 && i1 && !i2 && i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw, rw, rw}
// combination 29/32
case i0 && i1 && i2 && !i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
}{rw, rw, rw, rw, rw}
// combination 30/32
case i0 && i1 && i2 && !i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
http.Pusher
}{rw, rw, rw, rw, rw, rw}
// combination 31/32
case i0 && i1 && i2 && i3 && !i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw, rw, rw}
// combination 32/32
case i0 && i1 && i2 && i3 && i4:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
io.ReaderFrom
http.Pusher
}{rw, rw, rw, rw, rw, rw, rw}
}
panic("unreachable")
}
@@ -320,6 +352,10 @@ type rw struct {
h Hooks
}
func (w *rw) Unwrap() http.ResponseWriter {
return w.w
}
func (w *rw) Header() http.Header {
f := w.w.(http.ResponseWriter).Header
if w.h.Header != nil {
@@ -383,3 +419,18 @@ func (w *rw) Push(target string, opts *http.PushOptions) error {
}
return f(target, opts)
}
type Unwrapper interface {
Unwrap() http.ResponseWriter
}
// Unwrap returns the underlying http.ResponseWriter from within zero or more
// layers of httpsnoop wrappers.
func Unwrap(w http.ResponseWriter) http.ResponseWriter {
if rw, ok := w.(Unwrapper); ok {
// recurse until rw.Unwrap() returns a non-Unwrapper
return Unwrap(rw.Unwrap())
} else {
return w
}
}

View File

@@ -68,115 +68,131 @@ func Wrap(w http.ResponseWriter, hooks Hooks) http.ResponseWriter {
// combination 1/16
case !i0 && !i1 && !i2 && !i3:
return struct {
Unwrapper
http.ResponseWriter
}{rw}
}{rw, rw}
// combination 2/16
case !i0 && !i1 && !i2 && i3:
return struct {
Unwrapper
http.ResponseWriter
io.ReaderFrom
}{rw, rw}
}{rw, rw, rw}
// combination 3/16
case !i0 && !i1 && i2 && !i3:
return struct {
Unwrapper
http.ResponseWriter
http.Hijacker
}{rw, rw}
}{rw, rw, rw}
// combination 4/16
case !i0 && !i1 && i2 && i3:
return struct {
Unwrapper
http.ResponseWriter
http.Hijacker
io.ReaderFrom
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 5/16
case !i0 && i1 && !i2 && !i3:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
}{rw, rw}
}{rw, rw, rw}
// combination 6/16
case !i0 && i1 && !i2 && i3:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
io.ReaderFrom
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 7/16
case !i0 && i1 && i2 && !i3:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
http.Hijacker
}{rw, rw, rw}
}{rw, rw, rw, rw}
// combination 8/16
case !i0 && i1 && i2 && i3:
return struct {
Unwrapper
http.ResponseWriter
http.CloseNotifier
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw}
// combination 9/16
case i0 && !i1 && !i2 && !i3:
return struct {
http.ResponseWriter
http.Flusher
}{rw, rw}
// combination 10/16
case i0 && !i1 && !i2 && i3:
return struct {
http.ResponseWriter
http.Flusher
io.ReaderFrom
}{rw, rw, rw}
// combination 11/16
case i0 && !i1 && i2 && !i3:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
}{rw, rw, rw}
// combination 12/16
case i0 && !i1 && i2 && i3:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw}
// combination 13/16
case i0 && i1 && !i2 && !i3:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
}{rw, rw, rw}
// combination 14/16
case i0 && i1 && !i2 && i3:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
io.ReaderFrom
}{rw, rw, rw, rw}
// combination 15/16
case i0 && i1 && i2 && !i3:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
}{rw, rw, rw, rw}
// combination 16/16
case i0 && i1 && i2 && i3:
return struct {
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw, rw}
// combination 9/16
case i0 && !i1 && !i2 && !i3:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
}{rw, rw, rw}
// combination 10/16
case i0 && !i1 && !i2 && i3:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
io.ReaderFrom
}{rw, rw, rw, rw}
// combination 11/16
case i0 && !i1 && i2 && !i3:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.Hijacker
}{rw, rw, rw, rw}
// combination 12/16
case i0 && !i1 && i2 && i3:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw, rw}
// combination 13/16
case i0 && i1 && !i2 && !i3:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
}{rw, rw, rw, rw}
// combination 14/16
case i0 && i1 && !i2 && i3:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
io.ReaderFrom
}{rw, rw, rw, rw, rw}
// combination 15/16
case i0 && i1 && i2 && !i3:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
}{rw, rw, rw, rw, rw}
// combination 16/16
case i0 && i1 && i2 && i3:
return struct {
Unwrapper
http.ResponseWriter
http.Flusher
http.CloseNotifier
http.Hijacker
io.ReaderFrom
}{rw, rw, rw, rw, rw, rw}
}
panic("unreachable")
}
@@ -186,6 +202,10 @@ type rw struct {
h Hooks
}
func (w *rw) Unwrap() http.ResponseWriter {
return w.w
}
func (w *rw) Header() http.Header {
f := w.w.(http.ResponseWriter).Header
if w.h.Header != nil {
@@ -241,3 +261,18 @@ func (w *rw) ReadFrom(src io.Reader) (int64, error) {
}
return f(src)
}
type Unwrapper interface {
Unwrap() http.ResponseWriter
}
// Unwrap returns the underlying http.ResponseWriter from within zero or more
// layers of httpsnoop wrappers.
func Unwrap(w http.ResponseWriter) http.ResponseWriter {
if rw, ok := w.(Unwrapper); ok {
// recurse until rw.Unwrap() returns a non-Unwrapper
return Unwrap(rw.Unwrap())
} else {
return w
}
}

787
vendor/github.com/go-logr/logr/funcr/funcr.go generated vendored Normal file
View File

@@ -0,0 +1,787 @@
/*
Copyright 2021 The logr Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package funcr implements formatting of structured log messages and
// optionally captures the call site and timestamp.
//
// The simplest way to use it is via its implementation of a
// github.com/go-logr/logr.LogSink with output through an arbitrary
// "write" function. See New and NewJSON for details.
//
// Custom LogSinks
//
// For users who need more control, a funcr.Formatter can be embedded inside
// your own custom LogSink implementation. This is useful when the LogSink
// needs to implement additional methods, for example.
//
// Formatting
//
// This will respect logr.Marshaler, fmt.Stringer, and error interfaces for
// values which are being logged. When rendering a struct, funcr will use Go's
// standard JSON tags (all except "string").
package funcr
import (
"bytes"
"encoding"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"time"
"github.com/go-logr/logr"
)
// New returns a logr.Logger which is implemented by an arbitrary function.
func New(fn func(prefix, args string), opts Options) logr.Logger {
return logr.New(newSink(fn, NewFormatter(opts)))
}
// NewJSON returns a logr.Logger which is implemented by an arbitrary function
// and produces JSON output.
func NewJSON(fn func(obj string), opts Options) logr.Logger {
fnWrapper := func(_, obj string) {
fn(obj)
}
return logr.New(newSink(fnWrapper, NewFormatterJSON(opts)))
}
// Underlier exposes access to the underlying logging function. Since
// callers only have a logr.Logger, they have to know which
// implementation is in use, so this interface is less of an
// abstraction and more of a way to test type conversion.
type Underlier interface {
GetUnderlying() func(prefix, args string)
}
func newSink(fn func(prefix, args string), formatter Formatter) logr.LogSink {
l := &fnlogger{
Formatter: formatter,
write: fn,
}
// For skipping fnlogger.Info and fnlogger.Error.
l.Formatter.AddCallDepth(1)
return l
}
// Options carries parameters which influence the way logs are generated.
type Options struct {
// LogCaller tells funcr to add a "caller" key to some or all log lines.
// This has some overhead, so some users might not want it.
LogCaller MessageClass
// LogCallerFunc tells funcr to also log the calling function name. This
// has no effect if caller logging is not enabled (see Options.LogCaller).
LogCallerFunc bool
// LogTimestamp tells funcr to add a "ts" key to log lines. This has some
// overhead, so some users might not want it.
LogTimestamp bool
// TimestampFormat tells funcr how to render timestamps when LogTimestamp
// is enabled. If not specified, a default format will be used. For more
// details, see docs for Go's time.Layout.
TimestampFormat string
// Verbosity tells funcr which V logs to produce. Higher values enable
// more logs. Info logs at or below this level will be written, while logs
// above this level will be discarded.
Verbosity int
// RenderBuiltinsHook allows users to mutate the list of key-value pairs
// while a log line is being rendered. The kvList argument follows logr
// conventions - each pair of slice elements is comprised of a string key
// and an arbitrary value (verified and sanitized before calling this
// hook). The value returned must follow the same conventions. This hook
// can be used to audit or modify logged data. For example, you might want
// to prefix all of funcr's built-in keys with some string. This hook is
// only called for built-in (provided by funcr itself) key-value pairs.
// Equivalent hooks are offered for key-value pairs saved via
// logr.Logger.WithValues or Formatter.AddValues (see RenderValuesHook) and
// for user-provided pairs (see RenderArgsHook).
RenderBuiltinsHook func(kvList []interface{}) []interface{}
// RenderValuesHook is the same as RenderBuiltinsHook, except that it is
// only called for key-value pairs saved via logr.Logger.WithValues. See
// RenderBuiltinsHook for more details.
RenderValuesHook func(kvList []interface{}) []interface{}
// RenderArgsHook is the same as RenderBuiltinsHook, except that it is only
// called for key-value pairs passed directly to Info and Error. See
// RenderBuiltinsHook for more details.
RenderArgsHook func(kvList []interface{}) []interface{}
// MaxLogDepth tells funcr how many levels of nested fields (e.g. a struct
// that contains a struct, etc.) it may log. Every time it finds a struct,
// slice, array, or map the depth is increased by one. When the maximum is
// reached, the value will be converted to a string indicating that the max
// depth has been exceeded. If this field is not specified, a default
// value will be used.
MaxLogDepth int
}
// MessageClass indicates which category or categories of messages to consider.
type MessageClass int
const (
// None ignores all message classes.
None MessageClass = iota
// All considers all message classes.
All
// Info only considers info messages.
Info
// Error only considers error messages.
Error
)
// fnlogger inherits some of its LogSink implementation from Formatter
// and just needs to add some glue code.
type fnlogger struct {
Formatter
write func(prefix, args string)
}
func (l fnlogger) WithName(name string) logr.LogSink {
l.Formatter.AddName(name)
return &l
}
func (l fnlogger) WithValues(kvList ...interface{}) logr.LogSink {
l.Formatter.AddValues(kvList)
return &l
}
func (l fnlogger) WithCallDepth(depth int) logr.LogSink {
l.Formatter.AddCallDepth(depth)
return &l
}
func (l fnlogger) Info(level int, msg string, kvList ...interface{}) {
prefix, args := l.FormatInfo(level, msg, kvList)
l.write(prefix, args)
}
func (l fnlogger) Error(err error, msg string, kvList ...interface{}) {
prefix, args := l.FormatError(err, msg, kvList)
l.write(prefix, args)
}
func (l fnlogger) GetUnderlying() func(prefix, args string) {
return l.write
}
// Assert conformance to the interfaces.
var _ logr.LogSink = &fnlogger{}
var _ logr.CallDepthLogSink = &fnlogger{}
var _ Underlier = &fnlogger{}
// NewFormatter constructs a Formatter which emits a JSON-like key=value format.
func NewFormatter(opts Options) Formatter {
return newFormatter(opts, outputKeyValue)
}
// NewFormatterJSON constructs a Formatter which emits strict JSON.
func NewFormatterJSON(opts Options) Formatter {
return newFormatter(opts, outputJSON)
}
// Defaults for Options.
const defaultTimestampFormat = "2006-01-02 15:04:05.000000"
const defaultMaxLogDepth = 16
func newFormatter(opts Options, outfmt outputFormat) Formatter {
if opts.TimestampFormat == "" {
opts.TimestampFormat = defaultTimestampFormat
}
if opts.MaxLogDepth == 0 {
opts.MaxLogDepth = defaultMaxLogDepth
}
f := Formatter{
outputFormat: outfmt,
prefix: "",
values: nil,
depth: 0,
opts: opts,
}
return f
}
// Formatter is an opaque struct which can be embedded in a LogSink
// implementation. It should be constructed with NewFormatter. Some of
// its methods directly implement logr.LogSink.
type Formatter struct {
outputFormat outputFormat
prefix string
values []interface{}
valuesStr string
depth int
opts Options
}
// outputFormat indicates which outputFormat to use.
type outputFormat int
const (
// outputKeyValue emits a JSON-like key=value format, but not strict JSON.
outputKeyValue outputFormat = iota
// outputJSON emits strict JSON.
outputJSON
)
// PseudoStruct is a list of key-value pairs that gets logged as a struct.
type PseudoStruct []interface{}
// render produces a log line, ready to use.
func (f Formatter) render(builtins, args []interface{}) string {
// Empirically bytes.Buffer is faster than strings.Builder for this.
buf := bytes.NewBuffer(make([]byte, 0, 1024))
if f.outputFormat == outputJSON {
buf.WriteByte('{')
}
vals := builtins
if hook := f.opts.RenderBuiltinsHook; hook != nil {
vals = hook(f.sanitize(vals))
}
f.flatten(buf, vals, false, false) // keys are ours, no need to escape
continuing := len(builtins) > 0
if len(f.valuesStr) > 0 {
if continuing {
if f.outputFormat == outputJSON {
buf.WriteByte(',')
} else {
buf.WriteByte(' ')
}
}
continuing = true
buf.WriteString(f.valuesStr)
}
vals = args
if hook := f.opts.RenderArgsHook; hook != nil {
vals = hook(f.sanitize(vals))
}
f.flatten(buf, vals, continuing, true) // escape user-provided keys
if f.outputFormat == outputJSON {
buf.WriteByte('}')
}
return buf.String()
}
// flatten renders a list of key-value pairs into a buffer. If continuing is
// true, it assumes that the buffer has previous values and will emit a
// separator (which depends on the output format) before the first pair it
// writes. If escapeKeys is true, the keys are assumed to have
// non-JSON-compatible characters in them and must be evaluated for escapes.
//
// This function returns a potentially modified version of kvList, which
// ensures that there is a value for every key (adding a value if needed) and
// that each key is a string (substituting a key if needed).
func (f Formatter) flatten(buf *bytes.Buffer, kvList []interface{}, continuing bool, escapeKeys bool) []interface{} {
// This logic overlaps with sanitize() but saves one type-cast per key,
// which can be measurable.
if len(kvList)%2 != 0 {
kvList = append(kvList, noValue)
}
for i := 0; i < len(kvList); i += 2 {
k, ok := kvList[i].(string)
if !ok {
k = f.nonStringKey(kvList[i])
kvList[i] = k
}
v := kvList[i+1]
if i > 0 || continuing {
if f.outputFormat == outputJSON {
buf.WriteByte(',')
} else {
// In theory the format could be something we don't understand. In
// practice, we control it, so it won't be.
buf.WriteByte(' ')
}
}
if escapeKeys {
buf.WriteString(prettyString(k))
} else {
// this is faster
buf.WriteByte('"')
buf.WriteString(k)
buf.WriteByte('"')
}
if f.outputFormat == outputJSON {
buf.WriteByte(':')
} else {
buf.WriteByte('=')
}
buf.WriteString(f.pretty(v))
}
return kvList
}
func (f Formatter) pretty(value interface{}) string {
return f.prettyWithFlags(value, 0, 0)
}
const (
flagRawStruct = 0x1 // do not print braces on structs
)
// TODO: This is not fast. Most of the overhead goes here.
func (f Formatter) prettyWithFlags(value interface{}, flags uint32, depth int) string {
if depth > f.opts.MaxLogDepth {
return `"<max-log-depth-exceeded>"`
}
// Handle types that take full control of logging.
if v, ok := value.(logr.Marshaler); ok {
// Replace the value with what the type wants to get logged.
// That then gets handled below via reflection.
value = invokeMarshaler(v)
}
// Handle types that want to format themselves.
switch v := value.(type) {
case fmt.Stringer:
value = invokeStringer(v)
case error:
value = invokeError(v)
}
// Handling the most common types without reflect is a small perf win.
switch v := value.(type) {
case bool:
return strconv.FormatBool(v)
case string:
return prettyString(v)
case int:
return strconv.FormatInt(int64(v), 10)
case int8:
return strconv.FormatInt(int64(v), 10)
case int16:
return strconv.FormatInt(int64(v), 10)
case int32:
return strconv.FormatInt(int64(v), 10)
case int64:
return strconv.FormatInt(int64(v), 10)
case uint:
return strconv.FormatUint(uint64(v), 10)
case uint8:
return strconv.FormatUint(uint64(v), 10)
case uint16:
return strconv.FormatUint(uint64(v), 10)
case uint32:
return strconv.FormatUint(uint64(v), 10)
case uint64:
return strconv.FormatUint(v, 10)
case uintptr:
return strconv.FormatUint(uint64(v), 10)
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case complex64:
return `"` + strconv.FormatComplex(complex128(v), 'f', -1, 64) + `"`
case complex128:
return `"` + strconv.FormatComplex(v, 'f', -1, 128) + `"`
case PseudoStruct:
buf := bytes.NewBuffer(make([]byte, 0, 1024))
v = f.sanitize(v)
if flags&flagRawStruct == 0 {
buf.WriteByte('{')
}
for i := 0; i < len(v); i += 2 {
if i > 0 {
buf.WriteByte(',')
}
k, _ := v[i].(string) // sanitize() above means no need to check success
// arbitrary keys might need escaping
buf.WriteString(prettyString(k))
buf.WriteByte(':')
buf.WriteString(f.prettyWithFlags(v[i+1], 0, depth+1))
}
if flags&flagRawStruct == 0 {
buf.WriteByte('}')
}
return buf.String()
}
buf := bytes.NewBuffer(make([]byte, 0, 256))
t := reflect.TypeOf(value)
if t == nil {
return "null"
}
v := reflect.ValueOf(value)
switch t.Kind() {
case reflect.Bool:
return strconv.FormatBool(v.Bool())
case reflect.String:
return prettyString(v.String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(int64(v.Int()), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return strconv.FormatUint(uint64(v.Uint()), 10)
case reflect.Float32:
return strconv.FormatFloat(float64(v.Float()), 'f', -1, 32)
case reflect.Float64:
return strconv.FormatFloat(v.Float(), 'f', -1, 64)
case reflect.Complex64:
return `"` + strconv.FormatComplex(complex128(v.Complex()), 'f', -1, 64) + `"`
case reflect.Complex128:
return `"` + strconv.FormatComplex(v.Complex(), 'f', -1, 128) + `"`
case reflect.Struct:
if flags&flagRawStruct == 0 {
buf.WriteByte('{')
}
for i := 0; i < t.NumField(); i++ {
fld := t.Field(i)
if fld.PkgPath != "" {
// reflect says this field is only defined for non-exported fields.
continue
}
if !v.Field(i).CanInterface() {
// reflect isn't clear exactly what this means, but we can't use it.
continue
}
name := ""
omitempty := false
if tag, found := fld.Tag.Lookup("json"); found {
if tag == "-" {
continue
}
if comma := strings.Index(tag, ","); comma != -1 {
if n := tag[:comma]; n != "" {
name = n
}
rest := tag[comma:]
if strings.Contains(rest, ",omitempty,") || strings.HasSuffix(rest, ",omitempty") {
omitempty = true
}
} else {
name = tag
}
}
if omitempty && isEmpty(v.Field(i)) {
continue
}
if i > 0 {
buf.WriteByte(',')
}
if fld.Anonymous && fld.Type.Kind() == reflect.Struct && name == "" {
buf.WriteString(f.prettyWithFlags(v.Field(i).Interface(), flags|flagRawStruct, depth+1))
continue
}
if name == "" {
name = fld.Name
}
// field names can't contain characters which need escaping
buf.WriteByte('"')
buf.WriteString(name)
buf.WriteByte('"')
buf.WriteByte(':')
buf.WriteString(f.prettyWithFlags(v.Field(i).Interface(), 0, depth+1))
}
if flags&flagRawStruct == 0 {
buf.WriteByte('}')
}
return buf.String()
case reflect.Slice, reflect.Array:
buf.WriteByte('[')
for i := 0; i < v.Len(); i++ {
if i > 0 {
buf.WriteByte(',')
}
e := v.Index(i)
buf.WriteString(f.prettyWithFlags(e.Interface(), 0, depth+1))
}
buf.WriteByte(']')
return buf.String()
case reflect.Map:
buf.WriteByte('{')
// This does not sort the map keys, for best perf.
it := v.MapRange()
i := 0
for it.Next() {
if i > 0 {
buf.WriteByte(',')
}
// If a map key supports TextMarshaler, use it.
keystr := ""
if m, ok := it.Key().Interface().(encoding.TextMarshaler); ok {
txt, err := m.MarshalText()
if err != nil {
keystr = fmt.Sprintf("<error-MarshalText: %s>", err.Error())
} else {
keystr = string(txt)
}
keystr = prettyString(keystr)
} else {
// prettyWithFlags will produce already-escaped values
keystr = f.prettyWithFlags(it.Key().Interface(), 0, depth+1)
if t.Key().Kind() != reflect.String {
// JSON only does string keys. Unlike Go's standard JSON, we'll
// convert just about anything to a string.
keystr = prettyString(keystr)
}
}
buf.WriteString(keystr)
buf.WriteByte(':')
buf.WriteString(f.prettyWithFlags(it.Value().Interface(), 0, depth+1))
i++
}
buf.WriteByte('}')
return buf.String()
case reflect.Ptr, reflect.Interface:
if v.IsNil() {
return "null"
}
return f.prettyWithFlags(v.Elem().Interface(), 0, depth)
}
return fmt.Sprintf(`"<unhandled-%s>"`, t.Kind().String())
}
func prettyString(s string) string {
// Avoid escaping (which does allocations) if we can.
if needsEscape(s) {
return strconv.Quote(s)
}
b := bytes.NewBuffer(make([]byte, 0, 1024))
b.WriteByte('"')
b.WriteString(s)
b.WriteByte('"')
return b.String()
}
// needsEscape determines whether the input string needs to be escaped or not,
// without doing any allocations.
func needsEscape(s string) bool {
for _, r := range s {
if !strconv.IsPrint(r) || r == '\\' || r == '"' {
return true
}
}
return false
}
func isEmpty(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Complex64, reflect.Complex128:
return v.Complex() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
}
return false
}
func invokeMarshaler(m logr.Marshaler) (ret interface{}) {
defer func() {
if r := recover(); r != nil {
ret = fmt.Sprintf("<panic: %s>", r)
}
}()
return m.MarshalLog()
}
func invokeStringer(s fmt.Stringer) (ret string) {
defer func() {
if r := recover(); r != nil {
ret = fmt.Sprintf("<panic: %s>", r)
}
}()
return s.String()
}
func invokeError(e error) (ret string) {
defer func() {
if r := recover(); r != nil {
ret = fmt.Sprintf("<panic: %s>", r)
}
}()
return e.Error()
}
// Caller represents the original call site for a log line, after considering
// logr.Logger.WithCallDepth and logr.Logger.WithCallStackHelper. The File and
// Line fields will always be provided, while the Func field is optional.
// Users can set the render hook fields in Options to examine logged key-value
// pairs, one of which will be {"caller", Caller} if the Options.LogCaller
// field is enabled for the given MessageClass.
type Caller struct {
// File is the basename of the file for this call site.
File string `json:"file"`
// Line is the line number in the file for this call site.
Line int `json:"line"`
// Func is the function name for this call site, or empty if
// Options.LogCallerFunc is not enabled.
Func string `json:"function,omitempty"`
}
func (f Formatter) caller() Caller {
// +1 for this frame, +1 for Info/Error.
pc, file, line, ok := runtime.Caller(f.depth + 2)
if !ok {
return Caller{"<unknown>", 0, ""}
}
fn := ""
if f.opts.LogCallerFunc {
if fp := runtime.FuncForPC(pc); fp != nil {
fn = fp.Name()
}
}
return Caller{filepath.Base(file), line, fn}
}
const noValue = "<no-value>"
func (f Formatter) nonStringKey(v interface{}) string {
return fmt.Sprintf("<non-string-key: %s>", f.snippet(v))
}
// snippet produces a short snippet string of an arbitrary value.
func (f Formatter) snippet(v interface{}) string {
const snipLen = 16
snip := f.pretty(v)
if len(snip) > snipLen {
snip = snip[:snipLen]
}
return snip
}
// sanitize ensures that a list of key-value pairs has a value for every key
// (adding a value if needed) and that each key is a string (substituting a key
// if needed).
func (f Formatter) sanitize(kvList []interface{}) []interface{} {
if len(kvList)%2 != 0 {
kvList = append(kvList, noValue)
}
for i := 0; i < len(kvList); i += 2 {
_, ok := kvList[i].(string)
if !ok {
kvList[i] = f.nonStringKey(kvList[i])
}
}
return kvList
}
// Init configures this Formatter from runtime info, such as the call depth
// imposed by logr itself.
// Note that this receiver is a pointer, so depth can be saved.
func (f *Formatter) Init(info logr.RuntimeInfo) {
f.depth += info.CallDepth
}
// Enabled checks whether an info message at the given level should be logged.
func (f Formatter) Enabled(level int) bool {
return level <= f.opts.Verbosity
}
// GetDepth returns the current depth of this Formatter. This is useful for
// implementations which do their own caller attribution.
func (f Formatter) GetDepth() int {
return f.depth
}
// FormatInfo renders an Info log message into strings. The prefix will be
// empty when no names were set (via AddNames), or when the output is
// configured for JSON.
func (f Formatter) FormatInfo(level int, msg string, kvList []interface{}) (prefix, argsStr string) {
args := make([]interface{}, 0, 64) // using a constant here impacts perf
prefix = f.prefix
if f.outputFormat == outputJSON {
args = append(args, "logger", prefix)
prefix = ""
}
if f.opts.LogTimestamp {
args = append(args, "ts", time.Now().Format(f.opts.TimestampFormat))
}
if policy := f.opts.LogCaller; policy == All || policy == Info {
args = append(args, "caller", f.caller())
}
args = append(args, "level", level, "msg", msg)
return prefix, f.render(args, kvList)
}
// FormatError renders an Error log message into strings. The prefix will be
// empty when no names were set (via AddNames), or when the output is
// configured for JSON.
func (f Formatter) FormatError(err error, msg string, kvList []interface{}) (prefix, argsStr string) {
args := make([]interface{}, 0, 64) // using a constant here impacts perf
prefix = f.prefix
if f.outputFormat == outputJSON {
args = append(args, "logger", prefix)
prefix = ""
}
if f.opts.LogTimestamp {
args = append(args, "ts", time.Now().Format(f.opts.TimestampFormat))
}
if policy := f.opts.LogCaller; policy == All || policy == Error {
args = append(args, "caller", f.caller())
}
args = append(args, "msg", msg)
var loggableErr interface{}
if err != nil {
loggableErr = err.Error()
}
args = append(args, "error", loggableErr)
return f.prefix, f.render(args, kvList)
}
// AddName appends the specified name. funcr uses '/' characters to separate
// name elements. Callers should not pass '/' in the provided name string, but
// this library does not actually enforce that.
func (f *Formatter) AddName(name string) {
if len(f.prefix) > 0 {
f.prefix += "/"
}
f.prefix += name
}
// AddValues adds key-value pairs to the set of saved values to be logged with
// each log line.
func (f *Formatter) AddValues(kvList []interface{}) {
// Three slice args forces a copy.
n := len(f.values)
f.values = append(f.values[:n:n], kvList...)
vals := f.values
if hook := f.opts.RenderValuesHook; hook != nil {
vals = hook(f.sanitize(vals))
}
// Pre-render values, so we don't have to do it on each Info/Error call.
buf := bytes.NewBuffer(make([]byte, 0, 1024))
f.flatten(buf, vals, false, true) // escape user-provided keys
f.valuesStr = buf.String()
}
// AddCallDepth increases the number of stack-frames to skip when attributing
// the log line to a file and line.
func (f *Formatter) AddCallDepth(depth int) {
f.depth += depth
}

6
vendor/github.com/go-logr/stdr/README.md generated vendored Normal file
View File

@@ -0,0 +1,6 @@
# Minimal Go logging using logr and Go's standard library
[![Go Reference](https://pkg.go.dev/badge/github.com/go-logr/stdr.svg)](https://pkg.go.dev/github.com/go-logr/stdr)
This package implements the [logr interface](https://github.com/go-logr/logr)
in terms of Go's standard log package(https://pkg.go.dev/log).

170
vendor/github.com/go-logr/stdr/stdr.go generated vendored Normal file
View File

@@ -0,0 +1,170 @@
/*
Copyright 2019 The logr Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package stdr implements github.com/go-logr/logr.Logger in terms of
// Go's standard log package.
package stdr
import (
"log"
"os"
"github.com/go-logr/logr"
"github.com/go-logr/logr/funcr"
)
// The global verbosity level. See SetVerbosity().
var globalVerbosity int
// SetVerbosity sets the global level against which all info logs will be
// compared. If this is greater than or equal to the "V" of the logger, the
// message will be logged. A higher value here means more logs will be written.
// The previous verbosity value is returned. This is not concurrent-safe -
// callers must be sure to call it from only one goroutine.
func SetVerbosity(v int) int {
old := globalVerbosity
globalVerbosity = v
return old
}
// New returns a logr.Logger which is implemented by Go's standard log package,
// or something like it. If std is nil, this will use a default logger
// instead.
//
// Example: stdr.New(log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile)))
func New(std StdLogger) logr.Logger {
return NewWithOptions(std, Options{})
}
// NewWithOptions returns a logr.Logger which is implemented by Go's standard
// log package, or something like it. See New for details.
func NewWithOptions(std StdLogger, opts Options) logr.Logger {
if std == nil {
// Go's log.Default() is only available in 1.16 and higher.
std = log.New(os.Stderr, "", log.LstdFlags)
}
if opts.Depth < 0 {
opts.Depth = 0
}
fopts := funcr.Options{
LogCaller: funcr.MessageClass(opts.LogCaller),
}
sl := &logger{
Formatter: funcr.NewFormatter(fopts),
std: std,
}
// For skipping our own logger.Info/Error.
sl.Formatter.AddCallDepth(1 + opts.Depth)
return logr.New(sl)
}
// Options carries parameters which influence the way logs are generated.
type Options struct {
// Depth biases the assumed number of call frames to the "true" caller.
// This is useful when the calling code calls a function which then calls
// stdr (e.g. a logging shim to another API). Values less than zero will
// be treated as zero.
Depth int
// LogCaller tells stdr to add a "caller" key to some or all log lines.
// Go's log package has options to log this natively, too.
LogCaller MessageClass
// TODO: add an option to log the date/time
}
// MessageClass indicates which category or categories of messages to consider.
type MessageClass int
const (
// None ignores all message classes.
None MessageClass = iota
// All considers all message classes.
All
// Info only considers info messages.
Info
// Error only considers error messages.
Error
)
// StdLogger is the subset of the Go stdlib log.Logger API that is needed for
// this adapter.
type StdLogger interface {
// Output is the same as log.Output and log.Logger.Output.
Output(calldepth int, logline string) error
}
type logger struct {
funcr.Formatter
std StdLogger
}
var _ logr.LogSink = &logger{}
var _ logr.CallDepthLogSink = &logger{}
func (l logger) Enabled(level int) bool {
return globalVerbosity >= level
}
func (l logger) Info(level int, msg string, kvList ...interface{}) {
prefix, args := l.FormatInfo(level, msg, kvList)
if prefix != "" {
args = prefix + ": " + args
}
_ = l.std.Output(l.Formatter.GetDepth()+1, args)
}
func (l logger) Error(err error, msg string, kvList ...interface{}) {
prefix, args := l.FormatError(err, msg, kvList)
if prefix != "" {
args = prefix + ": " + args
}
_ = l.std.Output(l.Formatter.GetDepth()+1, args)
}
func (l logger) WithName(name string) logr.LogSink {
l.Formatter.AddName(name)
return &l
}
func (l logger) WithValues(kvList ...interface{}) logr.LogSink {
l.Formatter.AddValues(kvList)
return &l
}
func (l logger) WithCallDepth(depth int) logr.LogSink {
l.Formatter.AddCallDepth(depth)
return &l
}
// Underlier exposes access to the underlying logging implementation. Since
// callers only have a logr.Logger, they have to know which implementation is
// in use, so this interface is less of an abstraction and more of way to test
// type conversion.
type Underlier interface {
GetUnderlying() StdLogger
}
// GetUnderlying returns the StdLogger underneath this logger. Since StdLogger
// is itself an interface, the result may or may not be a Go log.Logger.
func (l logger) GetUnderlying() StdLogger {
return l.std
}

View File

@@ -42,6 +42,7 @@ func isEmpty(x, y interface{}) bool {
// The fraction and margin must be non-negative.
//
// The mathematical expression used is equivalent to:
//
// |x-y| ≤ max(fraction*min(|x|, |y|), margin)
//
// EquateApprox can be used in conjunction with EquateNaNs.

View File

@@ -18,9 +18,9 @@ import (
// sort any slice with element type V that is assignable to T.
//
// The less function must be:
// Deterministic: less(x, y) == less(x, y)
// Irreflexive: !less(x, x)
// Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
// - Deterministic: less(x, y) == less(x, y)
// - Irreflexive: !less(x, x)
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
//
// The less function does not have to be "total". That is, if !less(x, y) and
// !less(y, x) for two elements x and y, their relative order is maintained.
@@ -91,10 +91,10 @@ func (ss sliceSorter) less(v reflect.Value, i, j int) bool {
// use Comparers on K or the K.Equal method if it exists.
//
// The less function must be:
// Deterministic: less(x, y) == less(x, y)
// Irreflexive: !less(x, x)
// Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
// Total: if x != y, then either less(x, y) or less(y, x)
// - Deterministic: less(x, y) == less(x, y)
// - Irreflexive: !less(x, x)
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
// - Total: if x != y, then either less(x, y) or less(y, x)
//
// SortMaps can be used in conjunction with EquateEmpty.
func SortMaps(lessFunc interface{}) cmp.Option {

View File

@@ -67,12 +67,14 @@ func (sf structFilter) filter(p cmp.Path) bool {
// fieldTree represents a set of dot-separated identifiers.
//
// For example, inserting the following selectors:
//
// Foo
// Foo.Bar.Baz
// Foo.Buzz
// Nuka.Cola.Quantum
//
// Results in a tree of the form:
//
// {sub: {
// "Foo": {ok: true, sub: {
// "Bar": {sub: {

View File

@@ -23,6 +23,7 @@ func (xf xformFilter) filter(p cmp.Path) bool {
// that the transformer cannot be recursively applied upon its own output.
//
// An example use case is a transformer that splits a string by lines:
//
// AcyclicTransformer("SplitLines", func(s string) []string{
// return strings.Split(s, "\n")
// })

View File

@@ -13,21 +13,21 @@
//
// The primary features of cmp are:
//
// When the default behavior of equality does not suit the needs of the test,
// custom equality functions can override the equality operation.
// For example, an equality function may report floats as equal so long as they
// are within some tolerance of each other.
// - When the default behavior of equality does not suit the test's needs,
// custom equality functions can override the equality operation.
// For example, an equality function may report floats as equal so long as
// they are within some tolerance of each other.
//
// Types that have an Equal method may use that method to determine equality.
// This allows package authors to determine the equality operation for the types
// that they define.
// - Types with an Equal method may use that method to determine equality.
// This allows package authors to determine the equality operation
// for the types that they define.
//
// If no custom equality functions are used and no Equal method is defined,
// equality is determined by recursively comparing the primitive kinds on both
// values, much like reflect.DeepEqual. Unlike reflect.DeepEqual, unexported
// fields are not compared by default; they result in panics unless suppressed
// by using an Ignore option (see cmpopts.IgnoreUnexported) or explicitly
// compared using the Exporter option.
// - If no custom equality functions are used and no Equal method is defined,
// equality is determined by recursively comparing the primitive kinds on
// both values, much like reflect.DeepEqual. Unlike reflect.DeepEqual,
// unexported fields are not compared by default; they result in panics
// unless suppressed by using an Ignore option (see cmpopts.IgnoreUnexported)
// or explicitly compared using the Exporter option.
package cmp
import (
@@ -45,25 +45,25 @@ import (
// Equal reports whether x and y are equal by recursively applying the
// following rules in the given order to x and y and all of their sub-values:
//
// Let S be the set of all Ignore, Transformer, and Comparer options that
// remain after applying all path filters, value filters, and type filters.
// If at least one Ignore exists in S, then the comparison is ignored.
// If the number of Transformer and Comparer options in S is greater than one,
// then Equal panics because it is ambiguous which option to use.
// If S contains a single Transformer, then use that to transform the current
// values and recursively call Equal on the output values.
// If S contains a single Comparer, then use that to compare the current values.
// Otherwise, evaluation proceeds to the next rule.
// - Let S be the set of all Ignore, Transformer, and Comparer options that
// remain after applying all path filters, value filters, and type filters.
// If at least one Ignore exists in S, then the comparison is ignored.
// If the number of Transformer and Comparer options in S is non-zero,
// then Equal panics because it is ambiguous which option to use.
// If S contains a single Transformer, then use that to transform
// the current values and recursively call Equal on the output values.
// If S contains a single Comparer, then use that to compare the current values.
// Otherwise, evaluation proceeds to the next rule.
//
// If the values have an Equal method of the form "(T) Equal(T) bool" or
// "(T) Equal(I) bool" where T is assignable to I, then use the result of
// x.Equal(y) even if x or y is nil. Otherwise, no such method exists and
// evaluation proceeds to the next rule.
// - If the values have an Equal method of the form "(T) Equal(T) bool" or
// "(T) Equal(I) bool" where T is assignable to I, then use the result of
// x.Equal(y) even if x or y is nil. Otherwise, no such method exists and
// evaluation proceeds to the next rule.
//
// Lastly, try to compare x and y based on their basic kinds.
// Simple kinds like booleans, integers, floats, complex numbers, strings, and
// channels are compared using the equivalent of the == operator in Go.
// Functions are only equal if they are both nil, otherwise they are unequal.
// - Lastly, try to compare x and y based on their basic kinds.
// Simple kinds like booleans, integers, floats, complex numbers, strings,
// and channels are compared using the equivalent of the == operator in Go.
// Functions are only equal if they are both nil, otherwise they are unequal.
//
// Structs are equal if recursively calling Equal on all fields report equal.
// If a struct contains unexported fields, Equal panics unless an Ignore option
@@ -144,7 +144,7 @@ func rootStep(x, y interface{}) PathStep {
// so that they have the same parent type.
var t reflect.Type
if !vx.IsValid() || !vy.IsValid() || vx.Type() != vy.Type() {
t = reflect.TypeOf((*interface{})(nil)).Elem()
t = anyType
if vx.IsValid() {
vvx := reflect.New(t).Elem()
vvx.Set(vx)
@@ -639,7 +639,9 @@ type dynChecker struct{ curr, next int }
// Next increments the state and reports whether a check should be performed.
//
// Checks occur every Nth function call, where N is a triangular number:
//
// 0 1 3 6 10 15 21 28 36 45 55 66 78 91 105 120 136 153 171 190 ...
//
// See https://en.wikipedia.org/wiki/Triangular_number
//
// This sequence ensures that the cost of checks drops significantly as

View File

@@ -127,9 +127,9 @@ var randBool = rand.New(rand.NewSource(time.Now().Unix())).Intn(2) == 0
// This function returns an edit-script, which is a sequence of operations
// needed to convert one list into the other. The following invariants for
// the edit-script are maintained:
// eq == (es.Dist()==0)
// nx == es.LenX()
// ny == es.LenY()
// - eq == (es.Dist()==0)
// - nx == es.LenX()
// - ny == es.LenY()
//
// This algorithm is not guaranteed to be an optimal solution (i.e., one that
// produces an edit-script with a minimal Levenshtein distance). This algorithm
@@ -169,12 +169,13 @@ func Difference(nx, ny int, f EqualFunc) (es EditScript) {
// A diagonal edge is equivalent to a matching symbol between both X and Y.
// Invariants:
// 0 ≤ fwdPath.X ≤ (fwdFrontier.X, revFrontier.X) ≤ revPath.X ≤ nx
// 0 ≤ fwdPath.Y ≤ (fwdFrontier.Y, revFrontier.Y) ≤ revPath.Y ≤ ny
// - 0 ≤ fwdPath.X ≤ (fwdFrontier.X, revFrontier.X) ≤ revPath.X ≤ nx
// - 0 ≤ fwdPath.Y ≤ (fwdFrontier.Y, revFrontier.Y) ≤ revPath.Y ≤ ny
//
// In general:
// fwdFrontier.X < revFrontier.X
// fwdFrontier.Y < revFrontier.Y
// - fwdFrontier.X < revFrontier.X
// - fwdFrontier.Y < revFrontier.Y
//
// Unless, it is time for the algorithm to terminate.
fwdPath := path{+1, point{0, 0}, make(EditScript, 0, (nx+ny)/2)}
revPath := path{-1, point{nx, ny}, make(EditScript, 0)}
@@ -195,19 +196,21 @@ func Difference(nx, ny int, f EqualFunc) (es EditScript) {
// computing sub-optimal edit-scripts between two lists.
//
// The algorithm is approximately as follows:
// Searching for differences switches back-and-forth between
// a search that starts at the beginning (the top-left corner), and
// a search that starts at the end (the bottom-right corner). The goal of
// the search is connect with the search from the opposite corner.
// • As we search, we build a path in a greedy manner, where the first
// match seen is added to the path (this is sub-optimal, but provides a
// decent result in practice). When matches are found, we try the next pair
// of symbols in the lists and follow all matches as far as possible.
// • When searching for matches, we search along a diagonal going through
// through the "frontier" point. If no matches are found, we advance the
// frontier towards the opposite corner.
// • This algorithm terminates when either the X coordinates or the
// Y coordinates of the forward and reverse frontier points ever intersect.
// - Searching for differences switches back-and-forth between
// a search that starts at the beginning (the top-left corner), and
// a search that starts at the end (the bottom-right corner).
// The goal of the search is connect with the search
// from the opposite corner.
// - As we search, we build a path in a greedy manner,
// where the first match seen is added to the path (this is sub-optimal,
// but provides a decent result in practice). When matches are found,
// we try the next pair of symbols in the lists and follow all matches
// as far as possible.
// - When searching for matches, we search along a diagonal going through
// through the "frontier" point. If no matches are found,
// we advance the frontier towards the opposite corner.
// - This algorithm terminates when either the X coordinates or the
// Y coordinates of the forward and reverse frontier points ever intersect.
// This algorithm is correct even if searching only in the forward direction
// or in the reverse direction. We do both because it is commonly observed
@@ -389,6 +392,7 @@ type point struct{ X, Y int }
func (p *point) add(dx, dy int) { p.X += dx; p.Y += dy }
// zigzag maps a consecutive sequence of integers to a zig-zag sequence.
//
// [0 1 2 3 4 5 ...] => [0 -1 +1 -2 +2 ...]
func zigzag(x int) int {
if x&1 != 0 {

View File

@@ -1,48 +0,0 @@
// Copyright 2017, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package value
import (
"math"
"reflect"
)
// IsZero reports whether v is the zero value.
// This does not rely on Interface and so can be used on unexported fields.
func IsZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return v.Bool() == false
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return math.Float64bits(v.Float()) == 0
case reflect.Complex64, reflect.Complex128:
return math.Float64bits(real(v.Complex())) == 0 && math.Float64bits(imag(v.Complex())) == 0
case reflect.String:
return v.String() == ""
case reflect.UnsafePointer:
return v.Pointer() == 0
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice:
return v.IsNil()
case reflect.Array:
for i := 0; i < v.Len(); i++ {
if !IsZero(v.Index(i)) {
return false
}
}
return true
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
if !IsZero(v.Field(i)) {
return false
}
}
return true
}
return false
}

View File

@@ -33,6 +33,7 @@ type Option interface {
}
// applicableOption represents the following types:
//
// Fundamental: ignore | validator | *comparer | *transformer
// Grouping: Options
type applicableOption interface {
@@ -43,6 +44,7 @@ type applicableOption interface {
}
// coreOption represents the following types:
//
// Fundamental: ignore | validator | *comparer | *transformer
// Filters: *pathFilter | *valuesFilter
type coreOption interface {
@@ -336,9 +338,9 @@ func (tr transformer) String() string {
// both implement T.
//
// The equality function must be:
// Symmetric: equal(x, y) == equal(y, x)
// Deterministic: equal(x, y) == equal(x, y)
// Pure: equal(x, y) does not modify x or y
// - Symmetric: equal(x, y) == equal(y, x)
// - Deterministic: equal(x, y) == equal(x, y)
// - Pure: equal(x, y) does not modify x or y
func Comparer(f interface{}) Option {
v := reflect.ValueOf(f)
if !function.IsType(v.Type(), function.Equal) || v.IsNil() {
@@ -430,7 +432,7 @@ func AllowUnexported(types ...interface{}) Option {
}
// Result represents the comparison result for a single node and
// is provided by cmp when calling Result (see Reporter).
// is provided by cmp when calling Report (see Reporter).
type Result struct {
_ [0]func() // Make Result incomparable
flags resultFlags

View File

@@ -41,13 +41,13 @@ type PathStep interface {
// The type of each valid value is guaranteed to be identical to Type.
//
// In some cases, one or both may be invalid or have restrictions:
// For StructField, both are not interface-able if the current field
// is unexported and the struct type is not explicitly permitted by
// an Exporter to traverse unexported fields.
// For SliceIndex, one may be invalid if an element is missing from
// either the x or y slice.
// For MapIndex, one may be invalid if an entry is missing from
// either the x or y map.
// - For StructField, both are not interface-able if the current field
// is unexported and the struct type is not explicitly permitted by
// an Exporter to traverse unexported fields.
// - For SliceIndex, one may be invalid if an element is missing from
// either the x or y slice.
// - For MapIndex, one may be invalid if an entry is missing from
// either the x or y map.
//
// The provided values must not be mutated.
Values() (vx, vy reflect.Value)
@@ -94,6 +94,7 @@ func (pa Path) Index(i int) PathStep {
// The simplified path only contains struct field accesses.
//
// For example:
//
// MyMap.MySlices.MyField
func (pa Path) String() string {
var ss []string
@@ -108,6 +109,7 @@ func (pa Path) String() string {
// GoString returns the path to a specific node using Go syntax.
//
// For example:
//
// (*root.MyMap["key"].(*mypkg.MyStruct).MySlices)[2][3].MyField
func (pa Path) GoString() string {
var ssPre, ssPost []string
@@ -159,7 +161,7 @@ func (ps pathStep) String() string {
if ps.typ == nil {
return "<nil>"
}
s := ps.typ.String()
s := value.TypeString(ps.typ, false)
if s == "" || strings.ContainsAny(s, "{}\n") {
return "root" // Type too simple or complex to print
}
@@ -282,7 +284,7 @@ type typeAssertion struct {
func (ta TypeAssertion) Type() reflect.Type { return ta.typ }
func (ta TypeAssertion) Values() (vx, vy reflect.Value) { return ta.vx, ta.vy }
func (ta TypeAssertion) String() string { return fmt.Sprintf(".(%v)", ta.typ) }
func (ta TypeAssertion) String() string { return fmt.Sprintf(".(%v)", value.TypeString(ta.typ, false)) }
// Transform is a transformation from the parent type to the current type.
type Transform struct{ *transform }

View File

@@ -7,8 +7,6 @@ package cmp
import (
"fmt"
"reflect"
"github.com/google/go-cmp/cmp/internal/value"
)
// numContextRecords is the number of surrounding equal records to print.
@@ -117,7 +115,7 @@ func (opts formatOptions) FormatDiff(v *valueNode, ptrs *pointerReferences) (out
// For leaf nodes, format the value based on the reflect.Values alone.
// As a special case, treat equal []byte as a leaf nodes.
isBytes := v.Type.Kind() == reflect.Slice && v.Type.Elem() == reflect.TypeOf(byte(0))
isBytes := v.Type.Kind() == reflect.Slice && v.Type.Elem() == byteType
isEqualBytes := isBytes && v.NumDiff+v.NumIgnored+v.NumTransformed == 0
if v.MaxDepth == 0 || isEqualBytes {
switch opts.DiffMode {
@@ -248,11 +246,11 @@ func (opts formatOptions) formatDiffList(recs []reportRecord, k reflect.Kind, pt
var isZero bool
switch opts.DiffMode {
case diffIdentical:
isZero = value.IsZero(r.Value.ValueX) || value.IsZero(r.Value.ValueY)
isZero = r.Value.ValueX.IsZero() || r.Value.ValueY.IsZero()
case diffRemoved:
isZero = value.IsZero(r.Value.ValueX)
isZero = r.Value.ValueX.IsZero()
case diffInserted:
isZero = value.IsZero(r.Value.ValueY)
isZero = r.Value.ValueY.IsZero()
}
if isZero {
continue

View File

@@ -16,6 +16,13 @@ import (
"github.com/google/go-cmp/cmp/internal/value"
)
var (
anyType = reflect.TypeOf((*interface{})(nil)).Elem()
stringType = reflect.TypeOf((*string)(nil)).Elem()
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
byteType = reflect.TypeOf((*byte)(nil)).Elem()
)
type formatValueOptions struct {
// AvoidStringer controls whether to avoid calling custom stringer
// methods like error.Error or fmt.Stringer.String.
@@ -184,7 +191,7 @@ func (opts formatOptions) FormatValue(v reflect.Value, parentKind reflect.Kind,
}
for i := 0; i < v.NumField(); i++ {
vv := v.Field(i)
if value.IsZero(vv) {
if vv.IsZero() {
continue // Elide fields with zero values
}
if len(list) == maxLen {
@@ -205,7 +212,7 @@ func (opts formatOptions) FormatValue(v reflect.Value, parentKind reflect.Kind,
}
// Check whether this is a []byte of text data.
if t.Elem() == reflect.TypeOf(byte(0)) {
if t.Elem() == byteType {
b := v.Bytes()
isPrintSpace := func(r rune) bool { return unicode.IsPrint(r) || unicode.IsSpace(r) }
if len(b) > 0 && utf8.Valid(b) && len(bytes.TrimFunc(b, isPrintSpace)) == 0 {

View File

@@ -104,7 +104,7 @@ func (opts formatOptions) FormatDiffSlice(v *valueNode) textNode {
case t.Kind() == reflect.String:
sx, sy = vx.String(), vy.String()
isString = true
case t.Kind() == reflect.Slice && t.Elem() == reflect.TypeOf(byte(0)):
case t.Kind() == reflect.Slice && t.Elem() == byteType:
sx, sy = string(vx.Bytes()), string(vy.Bytes())
isString = true
case t.Kind() == reflect.Array:
@@ -147,7 +147,10 @@ func (opts formatOptions) FormatDiffSlice(v *valueNode) textNode {
})
efficiencyLines := float64(esLines.Dist()) / float64(len(esLines))
efficiencyBytes := float64(esBytes.Dist()) / float64(len(esBytes))
isPureLinedText = efficiencyLines < 4*efficiencyBytes
quotedLength := len(strconv.Quote(sx + sy))
unquotedLength := len(sx) + len(sy)
escapeExpansionRatio := float64(quotedLength) / float64(unquotedLength)
isPureLinedText = efficiencyLines < 4*efficiencyBytes || escapeExpansionRatio > 1.1
}
}
@@ -171,12 +174,13 @@ func (opts formatOptions) FormatDiffSlice(v *valueNode) textNode {
// differences in a string literal. This format is more readable,
// but has edge-cases where differences are visually indistinguishable.
// This format is avoided under the following conditions:
// A line starts with `"""`
// A line starts with "..."
// A line contains non-printable characters
// Adjacent different lines differ only by whitespace
// - A line starts with `"""`
// - A line starts with "..."
// - A line contains non-printable characters
// - Adjacent different lines differ only by whitespace
//
// For example:
//
// """
// ... // 3 identical lines
// foo
@@ -231,7 +235,7 @@ func (opts formatOptions) FormatDiffSlice(v *valueNode) textNode {
var out textNode = &textWrap{Prefix: "(", Value: list2, Suffix: ")"}
switch t.Kind() {
case reflect.String:
if t != reflect.TypeOf(string("")) {
if t != stringType {
out = opts.FormatType(t, out)
}
case reflect.Slice:
@@ -326,12 +330,12 @@ func (opts formatOptions) FormatDiffSlice(v *valueNode) textNode {
switch t.Kind() {
case reflect.String:
out = &textWrap{Prefix: "strings.Join(", Value: out, Suffix: fmt.Sprintf(", %q)", delim)}
if t != reflect.TypeOf(string("")) {
if t != stringType {
out = opts.FormatType(t, out)
}
case reflect.Slice:
out = &textWrap{Prefix: "bytes.Join(", Value: out, Suffix: fmt.Sprintf(", %q)", delim)}
if t != reflect.TypeOf([]byte(nil)) {
if t != bytesType {
out = opts.FormatType(t, out)
}
}
@@ -446,7 +450,6 @@ func (opts formatOptions) formatDiffSlice(
// {NumIdentical: 3},
// {NumInserted: 1},
// ]
//
func coalesceAdjacentEdits(name string, es diff.EditScript) (groups []diffStats) {
var prevMode byte
lastStats := func(mode byte) *diffStats {
@@ -503,7 +506,6 @@ func coalesceAdjacentEdits(name string, es diff.EditScript) (groups []diffStats)
// {NumIdentical: 8, NumRemoved: 12, NumInserted: 3},
// {NumIdentical: 63},
// ]
//
func coalesceInterveningIdentical(groups []diffStats, windowSize int) []diffStats {
groups, groupsOrig := groups[:0], groups
for i, ds := range groupsOrig {
@@ -548,7 +550,6 @@ func coalesceInterveningIdentical(groups []diffStats, windowSize int) []diffStat
// {NumRemoved: 9},
// {NumIdentical: 64}, // incremented by 10
// ]
//
func cleanupSurroundingIdentical(groups []diffStats, eq func(i, j int) bool) []diffStats {
var ix, iy int // indexes into sequence x and y
for i, ds := range groups {

View File

@@ -393,6 +393,7 @@ func (s diffStats) Append(ds diffStats) diffStats {
// String prints a humanly-readable summary of coalesced records.
//
// Example:
//
// diffStats{Name: "Field", NumIgnored: 5}.String() => "5 ignored fields"
func (s diffStats) String() string {
var ss []string

View File

@@ -0,0 +1,27 @@
Copyright (c) 2015, Gengo, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Gengo, Inc. nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,35 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "httprule",
srcs = [
"compile.go",
"parse.go",
"types.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule",
deps = ["//utilities"],
)
go_test(
name = "httprule_test",
size = "small",
srcs = [
"compile_test.go",
"parse_test.go",
"types_test.go",
],
embed = [":httprule"],
deps = [
"//utilities",
"@com_github_golang_glog//:glog",
],
)
alias(
name = "go_default_library",
actual = ":httprule",
visibility = ["//:__subpackages__"],
)

View File

@@ -0,0 +1,121 @@
package httprule
import (
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
)
const (
opcodeVersion = 1
)
// Template is a compiled representation of path templates.
type Template struct {
// Version is the version number of the format.
Version int
// OpCodes is a sequence of operations.
OpCodes []int
// Pool is a constant pool
Pool []string
// Verb is a VERB part in the template.
Verb string
// Fields is a list of field paths bound in this template.
Fields []string
// Original template (example: /v1/a_bit_of_everything)
Template string
}
// Compiler compiles utilities representation of path templates into marshallable operations.
// They can be unmarshalled by runtime.NewPattern.
type Compiler interface {
Compile() Template
}
type op struct {
// code is the opcode of the operation
code utilities.OpCode
// str is a string operand of the code.
// num is ignored if str is not empty.
str string
// num is a numeric operand of the code.
num int
}
func (w wildcard) compile() []op {
return []op{
{code: utilities.OpPush},
}
}
func (w deepWildcard) compile() []op {
return []op{
{code: utilities.OpPushM},
}
}
func (l literal) compile() []op {
return []op{
{
code: utilities.OpLitPush,
str: string(l),
},
}
}
func (v variable) compile() []op {
var ops []op
for _, s := range v.segments {
ops = append(ops, s.compile()...)
}
ops = append(ops, op{
code: utilities.OpConcatN,
num: len(v.segments),
}, op{
code: utilities.OpCapture,
str: v.path,
})
return ops
}
func (t template) Compile() Template {
var rawOps []op
for _, s := range t.segments {
rawOps = append(rawOps, s.compile()...)
}
var (
ops []int
pool []string
fields []string
)
consts := make(map[string]int)
for _, op := range rawOps {
ops = append(ops, int(op.code))
if op.str == "" {
ops = append(ops, op.num)
} else {
// eof segment literal represents the "/" path pattern
if op.str == eof {
op.str = ""
}
if _, ok := consts[op.str]; !ok {
consts[op.str] = len(pool)
pool = append(pool, op.str)
}
ops = append(ops, consts[op.str])
}
if op.code == utilities.OpCapture {
fields = append(fields, op.str)
}
}
return Template{
Version: opcodeVersion,
OpCodes: ops,
Pool: pool,
Verb: t.verb,
Fields: fields,
Template: t.template,
}
}

View File

@@ -0,0 +1,11 @@
// +build gofuzz
package httprule
func Fuzz(data []byte) int {
_, err := Parse(string(data))
if err != nil {
return 0
}
return 0
}

View File

@@ -0,0 +1,368 @@
package httprule
import (
"fmt"
"strings"
)
// InvalidTemplateError indicates that the path template is not valid.
type InvalidTemplateError struct {
tmpl string
msg string
}
func (e InvalidTemplateError) Error() string {
return fmt.Sprintf("%s: %s", e.msg, e.tmpl)
}
// Parse parses the string representation of path template
func Parse(tmpl string) (Compiler, error) {
if !strings.HasPrefix(tmpl, "/") {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: "no leading /"}
}
tokens, verb := tokenize(tmpl[1:])
p := parser{tokens: tokens}
segs, err := p.topLevelSegments()
if err != nil {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: err.Error()}
}
return template{
segments: segs,
verb: verb,
template: tmpl,
}, nil
}
func tokenize(path string) (tokens []string, verb string) {
if path == "" {
return []string{eof}, ""
}
const (
init = iota
field
nested
)
st := init
for path != "" {
var idx int
switch st {
case init:
idx = strings.IndexAny(path, "/{")
case field:
idx = strings.IndexAny(path, ".=}")
case nested:
idx = strings.IndexAny(path, "/}")
}
if idx < 0 {
tokens = append(tokens, path)
break
}
switch r := path[idx]; r {
case '/', '.':
case '{':
st = field
case '=':
st = nested
case '}':
st = init
}
if idx == 0 {
tokens = append(tokens, path[idx:idx+1])
} else {
tokens = append(tokens, path[:idx], path[idx:idx+1])
}
path = path[idx+1:]
}
l := len(tokens)
// See
// https://github.com/grpc-ecosystem/grpc-gateway/pull/1947#issuecomment-774523693 ;
// although normal and backwards-compat logic here is to use the last index
// of a colon, if the final segment is a variable followed by a colon, the
// part following the colon must be a verb. Hence if the previous token is
// an end var marker, we switch the index we're looking for to Index instead
// of LastIndex, so that we correctly grab the remaining part of the path as
// the verb.
var penultimateTokenIsEndVar bool
switch l {
case 0, 1:
// Not enough to be variable so skip this logic and don't result in an
// invalid index
default:
penultimateTokenIsEndVar = tokens[l-2] == "}"
}
t := tokens[l-1]
var idx int
if penultimateTokenIsEndVar {
idx = strings.Index(t, ":")
} else {
idx = strings.LastIndex(t, ":")
}
if idx == 0 {
tokens, verb = tokens[:l-1], t[1:]
} else if idx > 0 {
tokens[l-1], verb = t[:idx], t[idx+1:]
}
tokens = append(tokens, eof)
return tokens, verb
}
// parser is a parser of the template syntax defined in github.com/googleapis/googleapis/google/api/http.proto.
type parser struct {
tokens []string
accepted []string
}
// topLevelSegments is the target of this parser.
func (p *parser) topLevelSegments() ([]segment, error) {
if _, err := p.accept(typeEOF); err == nil {
p.tokens = p.tokens[:0]
return []segment{literal(eof)}, nil
}
segs, err := p.segments()
if err != nil {
return nil, err
}
if _, err := p.accept(typeEOF); err != nil {
return nil, fmt.Errorf("unexpected token %q after segments %q", p.tokens[0], strings.Join(p.accepted, ""))
}
return segs, nil
}
func (p *parser) segments() ([]segment, error) {
s, err := p.segment()
if err != nil {
return nil, err
}
segs := []segment{s}
for {
if _, err := p.accept("/"); err != nil {
return segs, nil
}
s, err := p.segment()
if err != nil {
return segs, err
}
segs = append(segs, s)
}
}
func (p *parser) segment() (segment, error) {
if _, err := p.accept("*"); err == nil {
return wildcard{}, nil
}
if _, err := p.accept("**"); err == nil {
return deepWildcard{}, nil
}
if l, err := p.literal(); err == nil {
return l, nil
}
v, err := p.variable()
if err != nil {
return nil, fmt.Errorf("segment neither wildcards, literal or variable: %v", err)
}
return v, err
}
func (p *parser) literal() (segment, error) {
lit, err := p.accept(typeLiteral)
if err != nil {
return nil, err
}
return literal(lit), nil
}
func (p *parser) variable() (segment, error) {
if _, err := p.accept("{"); err != nil {
return nil, err
}
path, err := p.fieldPath()
if err != nil {
return nil, err
}
var segs []segment
if _, err := p.accept("="); err == nil {
segs, err = p.segments()
if err != nil {
return nil, fmt.Errorf("invalid segment in variable %q: %v", path, err)
}
} else {
segs = []segment{wildcard{}}
}
if _, err := p.accept("}"); err != nil {
return nil, fmt.Errorf("unterminated variable segment: %s", path)
}
return variable{
path: path,
segments: segs,
}, nil
}
func (p *parser) fieldPath() (string, error) {
c, err := p.accept(typeIdent)
if err != nil {
return "", err
}
components := []string{c}
for {
if _, err = p.accept("."); err != nil {
return strings.Join(components, "."), nil
}
c, err := p.accept(typeIdent)
if err != nil {
return "", fmt.Errorf("invalid field path component: %v", err)
}
components = append(components, c)
}
}
// A termType is a type of terminal symbols.
type termType string
// These constants define some of valid values of termType.
// They improve readability of parse functions.
//
// You can also use "/", "*", "**", "." or "=" as valid values.
const (
typeIdent = termType("ident")
typeLiteral = termType("literal")
typeEOF = termType("$")
)
const (
// eof is the terminal symbol which always appears at the end of token sequence.
eof = "\u0000"
)
// accept tries to accept a token in "p".
// This function consumes a token and returns it if it matches to the specified "term".
// If it doesn't match, the function does not consume any tokens and return an error.
func (p *parser) accept(term termType) (string, error) {
t := p.tokens[0]
switch term {
case "/", "*", "**", ".", "=", "{", "}":
if t != string(term) && t != "/" {
return "", fmt.Errorf("expected %q but got %q", term, t)
}
case typeEOF:
if t != eof {
return "", fmt.Errorf("expected EOF but got %q", t)
}
case typeIdent:
if err := expectIdent(t); err != nil {
return "", err
}
case typeLiteral:
if err := expectPChars(t); err != nil {
return "", err
}
default:
return "", fmt.Errorf("unknown termType %q", term)
}
p.tokens = p.tokens[1:]
p.accepted = append(p.accepted, t)
return t, nil
}
// expectPChars determines if "t" consists of only pchars defined in RFC3986.
//
// https://www.ietf.org/rfc/rfc3986.txt, P.49
// pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
// / "*" / "+" / "," / ";" / "="
// pct-encoded = "%" HEXDIG HEXDIG
func expectPChars(t string) error {
const (
init = iota
pct1
pct2
)
st := init
for _, r := range t {
if st != init {
if !isHexDigit(r) {
return fmt.Errorf("invalid hexdigit: %c(%U)", r, r)
}
switch st {
case pct1:
st = pct2
case pct2:
st = init
}
continue
}
// unreserved
switch {
case 'A' <= r && r <= 'Z':
continue
case 'a' <= r && r <= 'z':
continue
case '0' <= r && r <= '9':
continue
}
switch r {
case '-', '.', '_', '~':
// unreserved
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=':
// sub-delims
case ':', '@':
// rest of pchar
case '%':
// pct-encoded
st = pct1
default:
return fmt.Errorf("invalid character in path segment: %q(%U)", r, r)
}
}
if st != init {
return fmt.Errorf("invalid percent-encoding in %q", t)
}
return nil
}
// expectIdent determines if "ident" is a valid identifier in .proto schema ([[:alpha:]_][[:alphanum:]_]*).
func expectIdent(ident string) error {
if ident == "" {
return fmt.Errorf("empty identifier")
}
for pos, r := range ident {
switch {
case '0' <= r && r <= '9':
if pos == 0 {
return fmt.Errorf("identifier starting with digit: %s", ident)
}
continue
case 'A' <= r && r <= 'Z':
continue
case 'a' <= r && r <= 'z':
continue
case r == '_':
continue
default:
return fmt.Errorf("invalid character %q(%U) in identifier: %s", r, r, ident)
}
}
return nil
}
func isHexDigit(r rune) bool {
switch {
case '0' <= r && r <= '9':
return true
case 'A' <= r && r <= 'F':
return true
case 'a' <= r && r <= 'f':
return true
}
return false
}

View File

@@ -0,0 +1,60 @@
package httprule
import (
"fmt"
"strings"
)
type template struct {
segments []segment
verb string
template string
}
type segment interface {
fmt.Stringer
compile() (ops []op)
}
type wildcard struct{}
type deepWildcard struct{}
type literal string
type variable struct {
path string
segments []segment
}
func (wildcard) String() string {
return "*"
}
func (deepWildcard) String() string {
return "**"
}
func (l literal) String() string {
return string(l)
}
func (v variable) String() string {
var segs []string
for _, s := range v.segments {
segs = append(segs, s.String())
}
return fmt.Sprintf("{%s=%s}", v.path, strings.Join(segs, "/"))
}
func (t template) String() string {
var segs []string
for _, s := range t.segments {
segs = append(segs, s.String())
}
str := strings.Join(segs, "/")
if t.verb != "" {
str = fmt.Sprintf("%s:%s", str, t.verb)
}
return "/" + str
}

View File

@@ -0,0 +1,91 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "runtime",
srcs = [
"context.go",
"convert.go",
"doc.go",
"errors.go",
"fieldmask.go",
"handler.go",
"marshal_httpbodyproto.go",
"marshal_json.go",
"marshal_jsonpb.go",
"marshal_proto.go",
"marshaler.go",
"marshaler_registry.go",
"mux.go",
"pattern.go",
"proto2_convert.go",
"query.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/runtime",
deps = [
"//internal/httprule",
"//utilities",
"@go_googleapis//google/api:httpbody_go_proto",
"@io_bazel_rules_go//proto/wkt:field_mask_go_proto",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//grpclog",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/protojson",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//reflect/protoregistry",
"@org_golang_google_protobuf//types/known/durationpb",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_google_protobuf//types/known/wrapperspb",
],
)
go_test(
name = "runtime_test",
size = "small",
srcs = [
"context_test.go",
"convert_test.go",
"errors_test.go",
"fieldmask_test.go",
"handler_test.go",
"marshal_httpbodyproto_test.go",
"marshal_json_test.go",
"marshal_jsonpb_test.go",
"marshal_proto_test.go",
"marshaler_registry_test.go",
"mux_test.go",
"pattern_test.go",
"query_test.go",
],
embed = [":runtime"],
deps = [
"//runtime/internal/examplepb",
"//utilities",
"@com_github_google_go_cmp//cmp",
"@com_github_google_go_cmp//cmp/cmpopts",
"@go_googleapis//google/api:httpbody_go_proto",
"@go_googleapis//google/rpc:errdetails_go_proto",
"@go_googleapis//google/rpc:status_go_proto",
"@io_bazel_rules_go//proto/wkt:field_mask_go_proto",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/protojson",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//testing/protocmp",
"@org_golang_google_protobuf//types/known/durationpb",
"@org_golang_google_protobuf//types/known/emptypb",
"@org_golang_google_protobuf//types/known/structpb",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_google_protobuf//types/known/wrapperspb",
],
)
alias(
name = "go_default_library",
actual = ":runtime",
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,345 @@
package runtime
import (
"context"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// MetadataHeaderPrefix is the http prefix that represents custom metadata
// parameters to or from a gRPC call.
const MetadataHeaderPrefix = "Grpc-Metadata-"
// MetadataPrefix is prepended to permanent HTTP header keys (as specified
// by the IANA) when added to the gRPC context.
const MetadataPrefix = "grpcgateway-"
// MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
// HTTP headers in a response handled by grpc-gateway
const MetadataTrailerPrefix = "Grpc-Trailer-"
const metadataGrpcTimeout = "Grpc-Timeout"
const metadataHeaderBinarySuffix = "-Bin"
const xForwardedFor = "X-Forwarded-For"
const xForwardedHost = "X-Forwarded-Host"
var (
// DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
// header isn't present. If the value is 0 the sent `context` will not have a timeout.
DefaultContextTimeout = 0 * time.Second
)
type (
rpcMethodKey struct{}
httpPathPatternKey struct{}
AnnotateContextOption func(ctx context.Context) context.Context
)
func WithHTTPPathPattern(pattern string) AnnotateContextOption {
return func(ctx context.Context) context.Context {
return withHTTPPathPattern(ctx, pattern)
}
}
func decodeBinHeader(v string) ([]byte, error) {
if len(v)%4 == 0 {
// Input was padded, or padding was not necessary.
return base64.StdEncoding.DecodeString(v)
}
return base64.RawStdEncoding.DecodeString(v)
}
/*
AnnotateContext adds context information such as metadata from the request.
At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
except that the forwarded destination is not another HTTP service but rather
a gRPC service.
*/
func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
if err != nil {
return nil, err
}
if md == nil {
return ctx, nil
}
return metadata.NewOutgoingContext(ctx, md), nil
}
// AnnotateIncomingContext adds context information such as metadata from the request.
// Attach metadata as incoming context.
func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
if err != nil {
return nil, err
}
if md == nil {
return ctx, nil
}
return metadata.NewIncomingContext(ctx, md), nil
}
func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) {
ctx = withRPCMethod(ctx, rpcMethodName)
for _, o := range options {
ctx = o(ctx)
}
var pairs []string
timeout := DefaultContextTimeout
if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
var err error
timeout, err = timeoutDecode(tm)
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
}
}
for key, vals := range req.Header {
key = textproto.CanonicalMIMEHeaderKey(key)
for _, val := range vals {
// For backwards-compatibility, pass through 'authorization' header with no prefix.
if key == "Authorization" {
pairs = append(pairs, "authorization", val)
}
if h, ok := mux.incomingHeaderMatcher(key); ok {
// Handles "-bin" metadata in grpc, since grpc will do another base64
// encode before sending to server, we need to decode it first.
if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
b, err := decodeBinHeader(val)
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
}
val = string(b)
}
pairs = append(pairs, h, val)
}
}
}
if host := req.Header.Get(xForwardedHost); host != "" {
pairs = append(pairs, strings.ToLower(xForwardedHost), host)
} else if req.Host != "" {
pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
}
if addr := req.RemoteAddr; addr != "" {
if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
if fwd := req.Header.Get(xForwardedFor); fwd == "" {
pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
} else {
pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
}
}
}
if timeout != 0 {
//nolint:govet // The context outlives this function
ctx, _ = context.WithTimeout(ctx, timeout)
}
if len(pairs) == 0 {
return ctx, nil, nil
}
md := metadata.Pairs(pairs...)
for _, mda := range mux.metadataAnnotators {
md = metadata.Join(md, mda(ctx, req))
}
return ctx, md, nil
}
// ServerMetadata consists of metadata sent from gRPC server.
type ServerMetadata struct {
HeaderMD metadata.MD
TrailerMD metadata.MD
}
type serverMetadataKey struct{}
// NewServerMetadataContext creates a new context with ServerMetadata
func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
return context.WithValue(ctx, serverMetadataKey{}, md)
}
// ServerMetadataFromContext returns the ServerMetadata in ctx
func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
return
}
// ServerTransportStream implements grpc.ServerTransportStream.
// It should only be used by the generated files to support grpc.SendHeader
// outside of gRPC server use.
type ServerTransportStream struct {
mu sync.Mutex
header metadata.MD
trailer metadata.MD
}
// Method returns the method for the stream.
func (s *ServerTransportStream) Method() string {
return ""
}
// Header returns the header metadata of the stream.
func (s *ServerTransportStream) Header() metadata.MD {
s.mu.Lock()
defer s.mu.Unlock()
return s.header.Copy()
}
// SetHeader sets the header metadata.
func (s *ServerTransportStream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.mu.Lock()
s.header = metadata.Join(s.header, md)
s.mu.Unlock()
return nil
}
// SendHeader sets the header metadata.
func (s *ServerTransportStream) SendHeader(md metadata.MD) error {
return s.SetHeader(md)
}
// Trailer returns the cached trailer metadata.
func (s *ServerTransportStream) Trailer() metadata.MD {
s.mu.Lock()
defer s.mu.Unlock()
return s.trailer.Copy()
}
// SetTrailer sets the trailer metadata.
func (s *ServerTransportStream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.mu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.mu.Unlock()
return nil
}
func timeoutDecode(s string) (time.Duration, error) {
size := len(s)
if size < 2 {
return 0, fmt.Errorf("timeout string is too short: %q", s)
}
d, ok := timeoutUnitToDuration(s[size-1])
if !ok {
return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
}
t, err := strconv.ParseInt(s[:size-1], 10, 64)
if err != nil {
return 0, err
}
return d * time.Duration(t), nil
}
func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
switch u {
case 'H':
return time.Hour, true
case 'M':
return time.Minute, true
case 'S':
return time.Second, true
case 'm':
return time.Millisecond, true
case 'u':
return time.Microsecond, true
case 'n':
return time.Nanosecond, true
default:
}
return
}
// isPermanentHTTPHeader checks whether hdr belongs to the list of
// permanent request headers maintained by IANA.
// http://www.iana.org/assignments/message-headers/message-headers.xml
func isPermanentHTTPHeader(hdr string) bool {
switch hdr {
case
"Accept",
"Accept-Charset",
"Accept-Language",
"Accept-Ranges",
"Authorization",
"Cache-Control",
"Content-Type",
"Cookie",
"Date",
"Expect",
"From",
"Host",
"If-Match",
"If-Modified-Since",
"If-None-Match",
"If-Schedule-Tag-Match",
"If-Unmodified-Since",
"Max-Forwards",
"Origin",
"Pragma",
"Referer",
"User-Agent",
"Via",
"Warning":
return true
}
return false
}
// RPCMethod returns the method string for the server context. The returned
// string is in the format of "/package.service/method".
func RPCMethod(ctx context.Context) (string, bool) {
m := ctx.Value(rpcMethodKey{})
if m == nil {
return "", false
}
ms, ok := m.(string)
if !ok {
return "", false
}
return ms, true
}
func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context {
return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName)
}
// HTTPPathPattern returns the HTTP path pattern string relating to the HTTP handler, if one exists.
// The format of the returned string is defined by the google.api.http path template type.
func HTTPPathPattern(ctx context.Context) (string, bool) {
m := ctx.Value(httpPathPatternKey{})
if m == nil {
return "", false
}
ms, ok := m.(string)
if !ok {
return "", false
}
return ms, true
}
func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
}

View File

@@ -0,0 +1,322 @@
package runtime
import (
"encoding/base64"
"fmt"
"strconv"
"strings"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// String just returns the given string.
// It is just for compatibility to other types.
func String(val string) (string, error) {
return val, nil
}
// StringSlice converts 'val' where individual strings are separated by
// 'sep' into a string slice.
func StringSlice(val, sep string) ([]string, error) {
return strings.Split(val, sep), nil
}
// Bool converts the given string representation of a boolean value into bool.
func Bool(val string) (bool, error) {
return strconv.ParseBool(val)
}
// BoolSlice converts 'val' where individual booleans are separated by
// 'sep' into a bool slice.
func BoolSlice(val, sep string) ([]bool, error) {
s := strings.Split(val, sep)
values := make([]bool, len(s))
for i, v := range s {
value, err := Bool(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Float64 converts the given string representation into representation of a floating point number into float64.
func Float64(val string) (float64, error) {
return strconv.ParseFloat(val, 64)
}
// Float64Slice converts 'val' where individual floating point numbers are separated by
// 'sep' into a float64 slice.
func Float64Slice(val, sep string) ([]float64, error) {
s := strings.Split(val, sep)
values := make([]float64, len(s))
for i, v := range s {
value, err := Float64(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Float32 converts the given string representation of a floating point number into float32.
func Float32(val string) (float32, error) {
f, err := strconv.ParseFloat(val, 32)
if err != nil {
return 0, err
}
return float32(f), nil
}
// Float32Slice converts 'val' where individual floating point numbers are separated by
// 'sep' into a float32 slice.
func Float32Slice(val, sep string) ([]float32, error) {
s := strings.Split(val, sep)
values := make([]float32, len(s))
for i, v := range s {
value, err := Float32(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Int64 converts the given string representation of an integer into int64.
func Int64(val string) (int64, error) {
return strconv.ParseInt(val, 0, 64)
}
// Int64Slice converts 'val' where individual integers are separated by
// 'sep' into a int64 slice.
func Int64Slice(val, sep string) ([]int64, error) {
s := strings.Split(val, sep)
values := make([]int64, len(s))
for i, v := range s {
value, err := Int64(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Int32 converts the given string representation of an integer into int32.
func Int32(val string) (int32, error) {
i, err := strconv.ParseInt(val, 0, 32)
if err != nil {
return 0, err
}
return int32(i), nil
}
// Int32Slice converts 'val' where individual integers are separated by
// 'sep' into a int32 slice.
func Int32Slice(val, sep string) ([]int32, error) {
s := strings.Split(val, sep)
values := make([]int32, len(s))
for i, v := range s {
value, err := Int32(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Uint64 converts the given string representation of an integer into uint64.
func Uint64(val string) (uint64, error) {
return strconv.ParseUint(val, 0, 64)
}
// Uint64Slice converts 'val' where individual integers are separated by
// 'sep' into a uint64 slice.
func Uint64Slice(val, sep string) ([]uint64, error) {
s := strings.Split(val, sep)
values := make([]uint64, len(s))
for i, v := range s {
value, err := Uint64(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Uint32 converts the given string representation of an integer into uint32.
func Uint32(val string) (uint32, error) {
i, err := strconv.ParseUint(val, 0, 32)
if err != nil {
return 0, err
}
return uint32(i), nil
}
// Uint32Slice converts 'val' where individual integers are separated by
// 'sep' into a uint32 slice.
func Uint32Slice(val, sep string) ([]uint32, error) {
s := strings.Split(val, sep)
values := make([]uint32, len(s))
for i, v := range s {
value, err := Uint32(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Bytes converts the given string representation of a byte sequence into a slice of bytes
// A bytes sequence is encoded in URL-safe base64 without padding
func Bytes(val string) ([]byte, error) {
b, err := base64.StdEncoding.DecodeString(val)
if err != nil {
b, err = base64.URLEncoding.DecodeString(val)
if err != nil {
return nil, err
}
}
return b, nil
}
// BytesSlice converts 'val' where individual bytes sequences, encoded in URL-safe
// base64 without padding, are separated by 'sep' into a slice of bytes slices slice.
func BytesSlice(val, sep string) ([][]byte, error) {
s := strings.Split(val, sep)
values := make([][]byte, len(s))
for i, v := range s {
value, err := Bytes(v)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
// Timestamp converts the given RFC3339 formatted string into a timestamp.Timestamp.
func Timestamp(val string) (*timestamppb.Timestamp, error) {
var r timestamppb.Timestamp
val = strconv.Quote(strings.Trim(val, `"`))
unmarshaler := &protojson.UnmarshalOptions{}
err := unmarshaler.Unmarshal([]byte(val), &r)
if err != nil {
return nil, err
}
return &r, nil
}
// Duration converts the given string into a timestamp.Duration.
func Duration(val string) (*durationpb.Duration, error) {
var r durationpb.Duration
val = strconv.Quote(strings.Trim(val, `"`))
unmarshaler := &protojson.UnmarshalOptions{}
err := unmarshaler.Unmarshal([]byte(val), &r)
if err != nil {
return nil, err
}
return &r, nil
}
// Enum converts the given string into an int32 that should be type casted into the
// correct enum proto type.
func Enum(val string, enumValMap map[string]int32) (int32, error) {
e, ok := enumValMap[val]
if ok {
return e, nil
}
i, err := Int32(val)
if err != nil {
return 0, fmt.Errorf("%s is not valid", val)
}
for _, v := range enumValMap {
if v == i {
return i, nil
}
}
return 0, fmt.Errorf("%s is not valid", val)
}
// EnumSlice converts 'val' where individual enums are separated by 'sep'
// into a int32 slice. Each individual int32 should be type casted into the
// correct enum proto type.
func EnumSlice(val, sep string, enumValMap map[string]int32) ([]int32, error) {
s := strings.Split(val, sep)
values := make([]int32, len(s))
for i, v := range s {
value, err := Enum(v, enumValMap)
if err != nil {
return values, err
}
values[i] = value
}
return values, nil
}
/*
Support fot google.protobuf.wrappers on top of primitive types
*/
// StringValue well-known type support as wrapper around string type
func StringValue(val string) (*wrapperspb.StringValue, error) {
return &wrapperspb.StringValue{Value: val}, nil
}
// FloatValue well-known type support as wrapper around float32 type
func FloatValue(val string) (*wrapperspb.FloatValue, error) {
parsedVal, err := Float32(val)
return &wrapperspb.FloatValue{Value: parsedVal}, err
}
// DoubleValue well-known type support as wrapper around float64 type
func DoubleValue(val string) (*wrapperspb.DoubleValue, error) {
parsedVal, err := Float64(val)
return &wrapperspb.DoubleValue{Value: parsedVal}, err
}
// BoolValue well-known type support as wrapper around bool type
func BoolValue(val string) (*wrapperspb.BoolValue, error) {
parsedVal, err := Bool(val)
return &wrapperspb.BoolValue{Value: parsedVal}, err
}
// Int32Value well-known type support as wrapper around int32 type
func Int32Value(val string) (*wrapperspb.Int32Value, error) {
parsedVal, err := Int32(val)
return &wrapperspb.Int32Value{Value: parsedVal}, err
}
// UInt32Value well-known type support as wrapper around uint32 type
func UInt32Value(val string) (*wrapperspb.UInt32Value, error) {
parsedVal, err := Uint32(val)
return &wrapperspb.UInt32Value{Value: parsedVal}, err
}
// Int64Value well-known type support as wrapper around int64 type
func Int64Value(val string) (*wrapperspb.Int64Value, error) {
parsedVal, err := Int64(val)
return &wrapperspb.Int64Value{Value: parsedVal}, err
}
// UInt64Value well-known type support as wrapper around uint64 type
func UInt64Value(val string) (*wrapperspb.UInt64Value, error) {
parsedVal, err := Uint64(val)
return &wrapperspb.UInt64Value{Value: parsedVal}, err
}
// BytesValue well-known type support as wrapper around bytes[] type
func BytesValue(val string) (*wrapperspb.BytesValue, error) {
parsedVal, err := Bytes(val)
return &wrapperspb.BytesValue{Value: parsedVal}, err
}

View File

@@ -0,0 +1,5 @@
/*
Package runtime contains runtime helper functions used by
servers which protoc-gen-grpc-gateway generates.
*/
package runtime

View File

@@ -0,0 +1,180 @@
package runtime
import (
"context"
"errors"
"io"
"net/http"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
)
// ErrorHandlerFunc is the signature used to configure error handling.
type ErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error)
// StreamErrorHandlerFunc is the signature used to configure stream error handling.
type StreamErrorHandlerFunc func(context.Context, error) *status.Status
// RoutingErrorHandlerFunc is the signature used to configure error handling for routing errors.
type RoutingErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, int)
// HTTPStatusError is the error to use when needing to provide a different HTTP status code for an error
// passed to the DefaultRoutingErrorHandler.
type HTTPStatusError struct {
HTTPStatus int
Err error
}
func (e *HTTPStatusError) Error() string {
return e.Err.Error()
}
// HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status.
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
func HTTPStatusFromCode(code codes.Code) int {
switch code {
case codes.OK:
return http.StatusOK
case codes.Canceled:
return http.StatusRequestTimeout
case codes.Unknown:
return http.StatusInternalServerError
case codes.InvalidArgument:
return http.StatusBadRequest
case codes.DeadlineExceeded:
return http.StatusGatewayTimeout
case codes.NotFound:
return http.StatusNotFound
case codes.AlreadyExists:
return http.StatusConflict
case codes.PermissionDenied:
return http.StatusForbidden
case codes.Unauthenticated:
return http.StatusUnauthorized
case codes.ResourceExhausted:
return http.StatusTooManyRequests
case codes.FailedPrecondition:
// Note, this deliberately doesn't translate to the similarly named '412 Precondition Failed' HTTP response status.
return http.StatusBadRequest
case codes.Aborted:
return http.StatusConflict
case codes.OutOfRange:
return http.StatusBadRequest
case codes.Unimplemented:
return http.StatusNotImplemented
case codes.Internal:
return http.StatusInternalServerError
case codes.Unavailable:
return http.StatusServiceUnavailable
case codes.DataLoss:
return http.StatusInternalServerError
}
grpclog.Infof("Unknown gRPC error code: %v", code)
return http.StatusInternalServerError
}
// HTTPError uses the mux-configured error handler.
func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
mux.errorHandler(ctx, mux, marshaler, w, r, err)
}
// DefaultHTTPErrorHandler is the default error handler.
// If "err" is a gRPC Status, the function replies with the status code mapped by HTTPStatusFromCode.
// If "err" is a HTTPStatusError, the function replies with the status code provide by that struct. This is
// intended to allow passing through of specific statuses via the function set via WithRoutingErrorHandler
// for the ServeMux constructor to handle edge cases which the standard mappings in HTTPStatusFromCode
// are insufficient for.
// If otherwise, it replies with http.StatusInternalServerError.
//
// The response body written by this function is a Status message marshaled by the Marshaler.
func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
// return Internal when Marshal failed
const fallback = `{"code": 13, "message": "failed to marshal error message"}`
var customStatus *HTTPStatusError
if errors.As(err, &customStatus) {
err = customStatus.Err
}
s := status.Convert(err)
pb := s.Proto()
w.Header().Del("Trailer")
w.Header().Del("Transfer-Encoding")
contentType := marshaler.ContentType(pb)
w.Header().Set("Content-Type", contentType)
if s.Code() == codes.Unauthenticated {
w.Header().Set("WWW-Authenticate", s.Message())
}
buf, merr := marshaler.Marshal(pb)
if merr != nil {
grpclog.Infof("Failed to marshal error message %q: %v", s, merr)
w.WriteHeader(http.StatusInternalServerError)
if _, err := io.WriteString(w, fallback); err != nil {
grpclog.Infof("Failed to write response: %v", err)
}
return
}
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Infof("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, mux, md)
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
// Unless the request includes a TE header field indicating "trailers"
// is acceptable, as described in Section 4.3, a server SHOULD NOT
// generate trailer fields that it believes are necessary for the user
// agent to receive.
doForwardTrailers := requestAcceptsTrailers(r)
if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
w.Header().Set("Transfer-Encoding", "chunked")
}
st := HTTPStatusFromCode(s.Code())
if customStatus != nil {
st = customStatus.HTTPStatus
}
w.WriteHeader(st)
if _, err := w.Write(buf); err != nil {
grpclog.Infof("Failed to write response: %v", err)
}
if doForwardTrailers {
handleForwardResponseTrailer(w, md)
}
}
func DefaultStreamErrorHandler(_ context.Context, err error) *status.Status {
return status.Convert(err)
}
// DefaultRoutingErrorHandler is our default handler for routing errors.
// By default http error codes mapped on the following error codes:
// NotFound -> grpc.NotFound
// StatusBadRequest -> grpc.InvalidArgument
// MethodNotAllowed -> grpc.Unimplemented
// Other -> grpc.Internal, method is not expecting to be called for anything else
func DefaultRoutingErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) {
sterr := status.Error(codes.Internal, "Unexpected routing error")
switch httpStatus {
case http.StatusBadRequest:
sterr = status.Error(codes.InvalidArgument, http.StatusText(httpStatus))
case http.StatusMethodNotAllowed:
sterr = status.Error(codes.Unimplemented, http.StatusText(httpStatus))
case http.StatusNotFound:
sterr = status.Error(codes.NotFound, http.StatusText(httpStatus))
}
mux.errorHandler(ctx, mux, marshaler, w, r, sterr)
}

View File

@@ -0,0 +1,165 @@
package runtime
import (
"encoding/json"
"fmt"
"io"
"sort"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
func getFieldByName(fields protoreflect.FieldDescriptors, name string) protoreflect.FieldDescriptor {
fd := fields.ByName(protoreflect.Name(name))
if fd != nil {
return fd
}
return fields.ByJSONName(name)
}
// FieldMaskFromRequestBody creates a FieldMask printing all complete paths from the JSON body.
func FieldMaskFromRequestBody(r io.Reader, msg proto.Message) (*field_mask.FieldMask, error) {
fm := &field_mask.FieldMask{}
var root interface{}
if err := json.NewDecoder(r).Decode(&root); err != nil {
if err == io.EOF {
return fm, nil
}
return nil, err
}
queue := []fieldMaskPathItem{{node: root, msg: msg.ProtoReflect()}}
for len(queue) > 0 {
// dequeue an item
item := queue[0]
queue = queue[1:]
m, ok := item.node.(map[string]interface{})
switch {
case ok:
// if the item is an object, then enqueue all of its children
for k, v := range m {
if item.msg == nil {
return nil, fmt.Errorf("JSON structure did not match request type")
}
fd := getFieldByName(item.msg.Descriptor().Fields(), k)
if fd == nil {
return nil, fmt.Errorf("could not find field %q in %q", k, item.msg.Descriptor().FullName())
}
if isDynamicProtoMessage(fd.Message()) {
for _, p := range buildPathsBlindly(k, v) {
newPath := p
if item.path != "" {
newPath = item.path + "." + newPath
}
queue = append(queue, fieldMaskPathItem{path: newPath})
}
continue
}
if isProtobufAnyMessage(fd.Message()) {
_, hasTypeField := v.(map[string]interface{})["@type"]
if hasTypeField {
queue = append(queue, fieldMaskPathItem{path: k})
continue
} else {
return nil, fmt.Errorf("could not find field @type in %q in message %q", k, item.msg.Descriptor().FullName())
}
}
child := fieldMaskPathItem{
node: v,
}
if item.path == "" {
child.path = string(fd.FullName().Name())
} else {
child.path = item.path + "." + string(fd.FullName().Name())
}
switch {
case fd.IsList(), fd.IsMap():
// As per: https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/field_mask.proto#L85-L86
// Do not recurse into repeated fields. The repeated field goes on the end of the path and we stop.
fm.Paths = append(fm.Paths, child.path)
case fd.Message() != nil:
child.msg = item.msg.Get(fd).Message()
fallthrough
default:
queue = append(queue, child)
}
}
case len(item.path) > 0:
// otherwise, it's a leaf node so print its path
fm.Paths = append(fm.Paths, item.path)
}
}
// Sort for deterministic output in the presence
// of repeated fields.
sort.Strings(fm.Paths)
return fm, nil
}
func isProtobufAnyMessage(md protoreflect.MessageDescriptor) bool {
return md != nil && (md.FullName() == "google.protobuf.Any")
}
func isDynamicProtoMessage(md protoreflect.MessageDescriptor) bool {
return md != nil && (md.FullName() == "google.protobuf.Struct" || md.FullName() == "google.protobuf.Value")
}
// buildPathsBlindly does not attempt to match proto field names to the
// json value keys. Instead it relies completely on the structure of
// the unmarshalled json contained within in.
// Returns a slice containing all subpaths with the root at the
// passed in name and json value.
func buildPathsBlindly(name string, in interface{}) []string {
m, ok := in.(map[string]interface{})
if !ok {
return []string{name}
}
var paths []string
queue := []fieldMaskPathItem{{path: name, node: m}}
for len(queue) > 0 {
cur := queue[0]
queue = queue[1:]
m, ok := cur.node.(map[string]interface{})
if !ok {
// This should never happen since we should always check that we only add
// nodes of type map[string]interface{} to the queue.
continue
}
for k, v := range m {
if mi, ok := v.(map[string]interface{}); ok {
queue = append(queue, fieldMaskPathItem{path: cur.path + "." + k, node: mi})
} else {
// This is not a struct, so there are no more levels to descend.
curPath := cur.path + "." + k
paths = append(paths, curPath)
}
}
}
return paths
}
// fieldMaskPathItem stores a in-progress deconstruction of a path for a fieldmask
type fieldMaskPathItem struct {
// the list of prior fields leading up to node connected by dots
path string
// a generic decoded json object the current item to inspect for further path extraction
node interface{}
// parent message
msg protoreflect.Message
}

View File

@@ -0,0 +1,223 @@
package runtime
import (
"context"
"fmt"
"io"
"net/http"
"net/textproto"
"strings"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// ForwardResponseStream forwards the stream from gRPC server to REST client.
func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
f, ok := w.(http.Flusher)
if !ok {
grpclog.Infof("Flush not supported in %T", w)
http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
return
}
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Infof("Failed to extract ServerMetadata from context")
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
handleForwardResponseServerMetadata(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var delimiter []byte
if d, ok := marshaler.(Delimited); ok {
delimiter = d.Delimiter()
} else {
delimiter = []byte("\n")
}
var wroteHeader bool
for {
resp, err := recv()
if err == io.EOF {
return
}
if err != nil {
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(resp))
}
var buf []byte
httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
switch {
case resp == nil:
buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
case isHTTPBody:
buf = httpBody.GetData()
default:
result := map[string]interface{}{"result": resp}
if rb, ok := resp.(responseBody); ok {
result["result"] = rb.XXX_ResponseBody()
}
buf, err = marshaler.Marshal(result)
}
if err != nil {
grpclog.Infof("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if _, err = w.Write(buf); err != nil {
grpclog.Infof("Failed to send response chunk: %v", err)
return
}
wroteHeader = true
if _, err = w.Write(delimiter); err != nil {
grpclog.Infof("Failed to send delimiter chunk: %v", err)
return
}
f.Flush()
}
}
func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.HeaderMD {
if h, ok := mux.outgoingHeaderMatcher(k); ok {
for _, v := range vs {
w.Header().Add(h, v)
}
}
}
}
func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
for k := range md.TrailerMD {
tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
w.Header().Add("Trailer", tKey)
}
}
func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
for k, vs := range md.TrailerMD {
tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
for _, v := range vs {
w.Header().Add(tKey, v)
}
}
}
// responseBody interface contains method for getting field for marshaling to the response body
// this method is generated for response struct from the value of `response_body` in the `google.api.HttpRule`
type responseBody interface {
XXX_ResponseBody() interface{}
}
// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
md, ok := ServerMetadataFromContext(ctx)
if !ok {
grpclog.Infof("Failed to extract ServerMetadata from context")
}
handleForwardResponseServerMetadata(w, mux, md)
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
// Unless the request includes a TE header field indicating "trailers"
// is acceptable, as described in Section 4.3, a server SHOULD NOT
// generate trailer fields that it believes are necessary for the user
// agent to receive.
doForwardTrailers := requestAcceptsTrailers(req)
if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
w.Header().Set("Transfer-Encoding", "chunked")
}
handleForwardResponseTrailerHeader(w, md)
contentType := marshaler.ContentType(resp)
w.Header().Set("Content-Type", contentType)
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var buf []byte
var err error
if rb, ok := resp.(responseBody); ok {
buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
} else {
buf, err = marshaler.Marshal(resp)
}
if err != nil {
grpclog.Infof("Marshal error: %v", err)
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
if _, err = w.Write(buf); err != nil {
grpclog.Infof("Failed to write response: %v", err)
}
if doForwardTrailers {
handleForwardResponseTrailer(w, md)
}
}
func requestAcceptsTrailers(req *http.Request) bool {
te := req.Header.Get("TE")
return strings.Contains(strings.ToLower(te), "trailers")
}
func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
if len(opts) == 0 {
return nil
}
for _, opt := range opts {
if err := opt(ctx, w, resp); err != nil {
grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
return err
}
}
return nil
}
func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
st := mux.streamErrorHandler(ctx, err)
msg := errorChunk(st)
if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(msg))
w.WriteHeader(HTTPStatusFromCode(st.Code()))
}
buf, merr := marshaler.Marshal(msg)
if merr != nil {
grpclog.Infof("Failed to marshal an error: %v", merr)
return
}
if _, werr := w.Write(buf); werr != nil {
grpclog.Infof("Failed to notify error to client: %v", werr)
return
}
}
func errorChunk(st *status.Status) map[string]proto.Message {
return map[string]proto.Message{"error": st.Proto()}
}

View File

@@ -0,0 +1,32 @@
package runtime
import (
"google.golang.org/genproto/googleapis/api/httpbody"
)
// HTTPBodyMarshaler is a Marshaler which supports marshaling of a
// google.api.HttpBody message as the full response body if it is
// the actual message used as the response. If not, then this will
// simply fallback to the Marshaler specified as its default Marshaler.
type HTTPBodyMarshaler struct {
Marshaler
}
// ContentType returns its specified content type in case v is a
// google.api.HttpBody message, otherwise it will fall back to the default Marshalers
// content type.
func (h *HTTPBodyMarshaler) ContentType(v interface{}) string {
if httpBody, ok := v.(*httpbody.HttpBody); ok {
return httpBody.GetContentType()
}
return h.Marshaler.ContentType(v)
}
// Marshal marshals "v" by returning the body bytes if v is a
// google.api.HttpBody message, otherwise it falls back to the default Marshaler.
func (h *HTTPBodyMarshaler) Marshal(v interface{}) ([]byte, error) {
if httpBody, ok := v.(*httpbody.HttpBody); ok {
return httpBody.Data, nil
}
return h.Marshaler.Marshal(v)
}

View File

@@ -0,0 +1,45 @@
package runtime
import (
"encoding/json"
"io"
)
// JSONBuiltin is a Marshaler which marshals/unmarshals into/from JSON
// with the standard "encoding/json" package of Golang.
// Although it is generally faster for simple proto messages than JSONPb,
// it does not support advanced features of protobuf, e.g. map, oneof, ....
//
// The NewEncoder and NewDecoder types return *json.Encoder and
// *json.Decoder respectively.
type JSONBuiltin struct{}
// ContentType always Returns "application/json".
func (*JSONBuiltin) ContentType(_ interface{}) string {
return "application/json"
}
// Marshal marshals "v" into JSON
func (j *JSONBuiltin) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// Unmarshal unmarshals JSON data into "v".
func (j *JSONBuiltin) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONBuiltin) NewDecoder(r io.Reader) Decoder {
return json.NewDecoder(r)
}
// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONBuiltin) NewEncoder(w io.Writer) Encoder {
return json.NewEncoder(w)
}
// Delimiter for newline encoded JSON streams.
func (j *JSONBuiltin) Delimiter() []byte {
return []byte("\n")
}

View File

@@ -0,0 +1,344 @@
package runtime
import (
"bytes"
"encoding/json"
"fmt"
"io"
"reflect"
"strconv"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// JSONPb is a Marshaler which marshals/unmarshals into/from JSON
// with the "google.golang.org/protobuf/encoding/protojson" marshaler.
// It supports the full functionality of protobuf unlike JSONBuiltin.
//
// The NewDecoder method returns a DecoderWrapper, so the underlying
// *json.Decoder methods can be used.
type JSONPb struct {
protojson.MarshalOptions
protojson.UnmarshalOptions
}
// ContentType always returns "application/json".
func (*JSONPb) ContentType(_ interface{}) string {
return "application/json"
}
// Marshal marshals "v" into JSON.
func (j *JSONPb) Marshal(v interface{}) ([]byte, error) {
if _, ok := v.(proto.Message); !ok {
return j.marshalNonProtoField(v)
}
var buf bytes.Buffer
if err := j.marshalTo(&buf, v); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
buf, err := j.marshalNonProtoField(v)
if err != nil {
return err
}
_, err = w.Write(buf)
return err
}
b, err := j.MarshalOptions.Marshal(p)
if err != nil {
return err
}
_, err = w.Write(b)
return err
}
var (
// protoMessageType is stored to prevent constant lookup of the same type at runtime.
protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
)
// marshalNonProto marshals a non-message field of a protobuf message.
// This function does not correctly marshal arbitrary data structures into JSON,
// it is only capable of marshaling non-message field values of protobuf,
// i.e. primitive types, enums; pointers to primitives or enums; maps from
// integer/string types to primitives/enums/pointers to messages.
func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {
if v == nil {
return []byte("null"), nil
}
rv := reflect.ValueOf(v)
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
return []byte("null"), nil
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Slice {
if rv.IsNil() {
if j.EmitUnpopulated {
return []byte("[]"), nil
}
return []byte("null"), nil
}
if rv.Type().Elem().Implements(protoMessageType) {
var buf bytes.Buffer
err := buf.WriteByte('[')
if err != nil {
return nil, err
}
for i := 0; i < rv.Len(); i++ {
if i != 0 {
err = buf.WriteByte(',')
if err != nil {
return nil, err
}
}
if err = j.marshalTo(&buf, rv.Index(i).Interface().(proto.Message)); err != nil {
return nil, err
}
}
err = buf.WriteByte(']')
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
if rv.Type().Elem().Implements(typeProtoEnum) {
var buf bytes.Buffer
err := buf.WriteByte('[')
if err != nil {
return nil, err
}
for i := 0; i < rv.Len(); i++ {
if i != 0 {
err = buf.WriteByte(',')
if err != nil {
return nil, err
}
}
if j.UseEnumNumbers {
_, err = buf.WriteString(strconv.FormatInt(rv.Index(i).Int(), 10))
} else {
_, err = buf.WriteString("\"" + rv.Index(i).Interface().(protoEnum).String() + "\"")
}
if err != nil {
return nil, err
}
}
err = buf.WriteByte(']')
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
}
if rv.Kind() == reflect.Map {
m := make(map[string]*json.RawMessage)
for _, k := range rv.MapKeys() {
buf, err := j.Marshal(rv.MapIndex(k).Interface())
if err != nil {
return nil, err
}
m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf)
}
if j.Indent != "" {
return json.MarshalIndent(m, "", j.Indent)
}
return json.Marshal(m)
}
if enum, ok := rv.Interface().(protoEnum); ok && !j.UseEnumNumbers {
return json.Marshal(enum.String())
}
return json.Marshal(rv.Interface())
}
// Unmarshal unmarshals JSON "data" into "v"
func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
return unmarshalJSONPb(data, j.UnmarshalOptions, v)
}
// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONPb) NewDecoder(r io.Reader) Decoder {
d := json.NewDecoder(r)
return DecoderWrapper{
Decoder: d,
UnmarshalOptions: j.UnmarshalOptions,
}
}
// DecoderWrapper is a wrapper around a *json.Decoder that adds
// support for protos to the Decode method.
type DecoderWrapper struct {
*json.Decoder
protojson.UnmarshalOptions
}
// Decode wraps the embedded decoder's Decode method to support
// protos using a jsonpb.Unmarshaler.
func (d DecoderWrapper) Decode(v interface{}) error {
return decodeJSONPb(d.Decoder, d.UnmarshalOptions, v)
}
// NewEncoder returns an Encoder which writes JSON stream into "w".
func (j *JSONPb) NewEncoder(w io.Writer) Encoder {
return EncoderFunc(func(v interface{}) error {
if err := j.marshalTo(w, v); err != nil {
return err
}
// mimic json.Encoder by adding a newline (makes output
// easier to read when it contains multiple encoded items)
_, err := w.Write(j.Delimiter())
return err
})
}
func unmarshalJSONPb(data []byte, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
d := json.NewDecoder(bytes.NewReader(data))
return decodeJSONPb(d, unmarshaler, v)
}
func decodeJSONPb(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
return decodeNonProtoField(d, unmarshaler, v)
}
// Decode into bytes for marshalling
var b json.RawMessage
err := d.Decode(&b)
if err != nil {
return err
}
return unmarshaler.Unmarshal([]byte(b), p)
}
func decodeNonProtoField(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("%T is not a pointer", v)
}
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
if rv.Type().ConvertibleTo(typeProtoMessage) {
// Decode into bytes for marshalling
var b json.RawMessage
err := d.Decode(&b)
if err != nil {
return err
}
return unmarshaler.Unmarshal([]byte(b), rv.Interface().(proto.Message))
}
rv = rv.Elem()
}
if rv.Kind() == reflect.Map {
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
conv, ok := convFromType[rv.Type().Key().Kind()]
if !ok {
return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key())
}
m := make(map[string]*json.RawMessage)
if err := d.Decode(&m); err != nil {
return err
}
for k, v := range m {
result := conv.Call([]reflect.Value{reflect.ValueOf(k)})
if err := result[1].Interface(); err != nil {
return err.(error)
}
bk := result[0]
bv := reflect.New(rv.Type().Elem())
if v == nil {
null := json.RawMessage("null")
v = &null
}
if err := unmarshalJSONPb([]byte(*v), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.SetMapIndex(bk, bv.Elem())
}
return nil
}
if rv.Kind() == reflect.Slice {
var sl []json.RawMessage
if err := d.Decode(&sl); err != nil {
return err
}
if sl != nil {
rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
}
for _, item := range sl {
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(item), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.Set(reflect.Append(rv, bv.Elem()))
}
return nil
}
if _, ok := rv.Interface().(protoEnum); ok {
var repr interface{}
if err := d.Decode(&repr); err != nil {
return err
}
switch v := repr.(type) {
case string:
// TODO(yugui) Should use proto.StructProperties?
return fmt.Errorf("unmarshaling of symbolic enum %q not supported: %T", repr, rv.Interface())
case float64:
rv.Set(reflect.ValueOf(int32(v)).Convert(rv.Type()))
return nil
default:
return fmt.Errorf("cannot assign %#v into Go type %T", repr, rv.Interface())
}
}
return d.Decode(v)
}
type protoEnum interface {
fmt.Stringer
EnumDescriptor() ([]byte, []int)
}
var typeProtoEnum = reflect.TypeOf((*protoEnum)(nil)).Elem()
var typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
// Delimiter for newline encoded JSON streams.
func (j *JSONPb) Delimiter() []byte {
return []byte("\n")
}
var (
convFromType = map[reflect.Kind]reflect.Value{
reflect.String: reflect.ValueOf(String),
reflect.Bool: reflect.ValueOf(Bool),
reflect.Float64: reflect.ValueOf(Float64),
reflect.Float32: reflect.ValueOf(Float32),
reflect.Int64: reflect.ValueOf(Int64),
reflect.Int32: reflect.ValueOf(Int32),
reflect.Uint64: reflect.ValueOf(Uint64),
reflect.Uint32: reflect.ValueOf(Uint32),
reflect.Slice: reflect.ValueOf(Bytes),
}
)

View File

@@ -0,0 +1,63 @@
package runtime
import (
"io"
"errors"
"io/ioutil"
"google.golang.org/protobuf/proto"
)
// ProtoMarshaller is a Marshaller which marshals/unmarshals into/from serialize proto bytes
type ProtoMarshaller struct{}
// ContentType always returns "application/octet-stream".
func (*ProtoMarshaller) ContentType(_ interface{}) string {
return "application/octet-stream"
}
// Marshal marshals "value" into Proto
func (*ProtoMarshaller) Marshal(value interface{}) ([]byte, error) {
message, ok := value.(proto.Message)
if !ok {
return nil, errors.New("unable to marshal non proto field")
}
return proto.Marshal(message)
}
// Unmarshal unmarshals proto "data" into "value"
func (*ProtoMarshaller) Unmarshal(data []byte, value interface{}) error {
message, ok := value.(proto.Message)
if !ok {
return errors.New("unable to unmarshal non proto field")
}
return proto.Unmarshal(data, message)
}
// NewDecoder returns a Decoder which reads proto stream from "reader".
func (marshaller *ProtoMarshaller) NewDecoder(reader io.Reader) Decoder {
return DecoderFunc(func(value interface{}) error {
buffer, err := ioutil.ReadAll(reader)
if err != nil {
return err
}
return marshaller.Unmarshal(buffer, value)
})
}
// NewEncoder returns an Encoder which writes proto stream into "writer".
func (marshaller *ProtoMarshaller) NewEncoder(writer io.Writer) Encoder {
return EncoderFunc(func(value interface{}) error {
buffer, err := marshaller.Marshal(value)
if err != nil {
return err
}
_, err = writer.Write(buffer)
if err != nil {
return err
}
return nil
})
}

View File

@@ -0,0 +1,50 @@
package runtime
import (
"io"
)
// Marshaler defines a conversion between byte sequence and gRPC payloads / fields.
type Marshaler interface {
// Marshal marshals "v" into byte sequence.
Marshal(v interface{}) ([]byte, error)
// Unmarshal unmarshals "data" into "v".
// "v" must be a pointer value.
Unmarshal(data []byte, v interface{}) error
// NewDecoder returns a Decoder which reads byte sequence from "r".
NewDecoder(r io.Reader) Decoder
// NewEncoder returns an Encoder which writes bytes sequence into "w".
NewEncoder(w io.Writer) Encoder
// ContentType returns the Content-Type which this marshaler is responsible for.
// The parameter describes the type which is being marshalled, which can sometimes
// affect the content type returned.
ContentType(v interface{}) string
}
// Decoder decodes a byte sequence
type Decoder interface {
Decode(v interface{}) error
}
// Encoder encodes gRPC payloads / fields into byte sequence.
type Encoder interface {
Encode(v interface{}) error
}
// DecoderFunc adapts an decoder function into Decoder.
type DecoderFunc func(v interface{}) error
// Decode delegates invocations to the underlying function itself.
func (f DecoderFunc) Decode(v interface{}) error { return f(v) }
// EncoderFunc adapts an encoder function into Encoder
type EncoderFunc func(v interface{}) error
// Encode delegates invocations to the underlying function itself.
func (f EncoderFunc) Encode(v interface{}) error { return f(v) }
// Delimited defines the streaming delimiter.
type Delimited interface {
// Delimiter returns the record separator for the stream.
Delimiter() []byte
}

View File

@@ -0,0 +1,109 @@
package runtime
import (
"errors"
"mime"
"net/http"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/encoding/protojson"
)
// MIMEWildcard is the fallback MIME type used for requests which do not match
// a registered MIME type.
const MIMEWildcard = "*"
var (
acceptHeader = http.CanonicalHeaderKey("Accept")
contentTypeHeader = http.CanonicalHeaderKey("Content-Type")
defaultMarshaler = &HTTPBodyMarshaler{
Marshaler: &JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
}
)
// MarshalerForRequest returns the inbound/outbound marshalers for this request.
// It checks the registry on the ServeMux for the MIME type set by the Content-Type header.
// If it isn't set (or the request Content-Type is empty), checks for "*".
// If there are multiple Content-Type headers set, choose the first one that it can
// exactly match in the registry.
// Otherwise, it follows the above logic for "*"/InboundMarshaler/OutboundMarshaler.
func MarshalerForRequest(mux *ServeMux, r *http.Request) (inbound Marshaler, outbound Marshaler) {
for _, acceptVal := range r.Header[acceptHeader] {
if m, ok := mux.marshalers.mimeMap[acceptVal]; ok {
outbound = m
break
}
}
for _, contentTypeVal := range r.Header[contentTypeHeader] {
contentType, _, err := mime.ParseMediaType(contentTypeVal)
if err != nil {
grpclog.Infof("Failed to parse Content-Type %s: %v", contentTypeVal, err)
continue
}
if m, ok := mux.marshalers.mimeMap[contentType]; ok {
inbound = m
break
}
}
if inbound == nil {
inbound = mux.marshalers.mimeMap[MIMEWildcard]
}
if outbound == nil {
outbound = inbound
}
return inbound, outbound
}
// marshalerRegistry is a mapping from MIME types to Marshalers.
type marshalerRegistry struct {
mimeMap map[string]Marshaler
}
// add adds a marshaler for a case-sensitive MIME type string ("*" to match any
// MIME type).
func (m marshalerRegistry) add(mime string, marshaler Marshaler) error {
if len(mime) == 0 {
return errors.New("empty MIME type")
}
m.mimeMap[mime] = marshaler
return nil
}
// makeMarshalerMIMERegistry returns a new registry of marshalers.
// It allows for a mapping of case-sensitive Content-Type MIME type string to runtime.Marshaler interfaces.
//
// For example, you could allow the client to specify the use of the runtime.JSONPb marshaler
// with a "application/jsonpb" Content-Type and the use of the runtime.JSONBuiltin marshaler
// with a "application/json" Content-Type.
// "*" can be used to match any Content-Type.
// This can be attached to a ServerMux with the marshaler option.
func makeMarshalerMIMERegistry() marshalerRegistry {
return marshalerRegistry{
mimeMap: map[string]Marshaler{
MIMEWildcard: defaultMarshaler,
},
}
}
// WithMarshalerOption returns a ServeMuxOption which associates inbound and outbound
// Marshalers to a MIME type in mux.
func WithMarshalerOption(mime string, marshaler Marshaler) ServeMuxOption {
return func(mux *ServeMux) {
if err := mux.marshalers.add(mime, marshaler); err != nil {
panic(err)
}
}
}

View File

@@ -0,0 +1,356 @@
package runtime
import (
"context"
"errors"
"fmt"
"net/http"
"net/textproto"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
type UnescapingMode int
const (
// UnescapingModeLegacy is the default V2 behavior, which escapes the entire
// path string before doing any routing.
UnescapingModeLegacy UnescapingMode = iota
// EscapingTypeExceptReserved unescapes all path parameters except RFC 6570
// reserved characters.
UnescapingModeAllExceptReserved
// EscapingTypeExceptSlash unescapes URL path parameters except path
// seperators, which will be left as "%2F".
UnescapingModeAllExceptSlash
// URL path parameters will be fully decoded.
UnescapingModeAllCharacters
// UnescapingModeDefault is the default escaping type.
// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
// reference implementation
UnescapingModeDefault = UnescapingModeLegacy
)
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
errorHandler ErrorHandlerFunc
streamErrorHandler StreamErrorHandlerFunc
routingErrorHandler RoutingErrorHandlerFunc
disablePathLengthFallback bool
unescapingMode UnescapingMode
}
// ServeMuxOption is an option that can be given to a ServeMux on construction.
type ServeMuxOption func(*ServeMux)
// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
//
// forwardResponseOption is an option that will be called on the relevant context.Context,
// http.ResponseWriter, and proto.Message before every forwarded response.
//
// The message may be nil in the case where just a header is being sent.
func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
}
}
// WithEscapingType sets the escaping type. See the definitions of UnescapingMode
// for more information.
func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.unescapingMode = mode
}
}
// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
return func(serveMux *ServeMux) {
currentQueryParser = queryParameterParser
}
}
// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
type HeaderMatcherFunc func(string) (string, bool)
// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
// keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
// 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
func DefaultHeaderMatcher(key string) (string, bool) {
key = textproto.CanonicalMIMEHeaderKey(key)
if isPermanentHTTPHeader(key) {
return MetadataPrefix + key, true
} else if strings.HasPrefix(key, MetadataHeaderPrefix) {
return key[len(MetadataHeaderPrefix):], true
}
return "", false
}
// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
//
// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.incomingHeaderMatcher = fn
}
}
// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
//
// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return modified header.
func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.outgoingHeaderMatcher = fn
}
}
// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
//
// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
// is reading token from cookie and adding it in gRPC context.
func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
}
}
// WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
//
// This can be used to configure a custom error response.
func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.errorHandler = fn
}
}
// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
// error handler, which allows for customizing the error trailer for server-streaming
// calls.
//
// For stream errors that occur before any response has been written, the mux's
// ErrorHandler will be invoked. However, once data has been written, the errors must
// be handled differently: they must be included in the response body. The response body's
// final message will include the error details returned by the stream error handler.
func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.streamErrorHandler = fn
}
}
// WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors.
//
// Method called for errors which can happen before gRPC route selected or executed.
// The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest
func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.routingErrorHandler = fn
}
}
// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
func WithDisablePathLengthFallback() ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.disablePathLengthFallback = true
}
}
// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
handlers: make(map[string][]handler),
forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
marshalers: makeMarshalerMIMERegistry(),
errorHandler: DefaultHTTPErrorHandler,
streamErrorHandler: DefaultStreamErrorHandler,
routingErrorHandler: DefaultRoutingErrorHandler,
unescapingMode: UnescapingModeDefault,
}
for _, opt := range opts {
opt(serveMux)
}
if serveMux.incomingHeaderMatcher == nil {
serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
}
if serveMux.outgoingHeaderMatcher == nil {
serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
}
}
return serveMux
}
// Handle associates "h" to the pair of HTTP method and path pattern.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
}
// HandlePath allows users to configure custom path handlers.
// refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
compiler, err := httprule.Parse(pathPattern)
if err != nil {
return fmt.Errorf("parsing path pattern: %w", err)
}
tp := compiler.Compile()
pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
if err != nil {
return fmt.Errorf("creating new pattern: %w", err)
}
s.Handle(meth, pattern, h)
return nil
}
// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
path := r.URL.Path
if !strings.HasPrefix(path, "/") {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
return
}
// TODO(v3): remove UnescapingModeLegacy
if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
path = r.URL.RawPath
}
components := strings.Split(path[1:], "/")
if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
r.Method = strings.ToUpper(override)
if err := r.ParseForm(); err != nil {
_, outboundMarshaler := MarshalerForRequest(s, r)
sterr := status.Error(codes.InvalidArgument, err.Error())
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
}
// Verb out here is to memoize for the fallback case below
var verb string
for _, h := range s.handlers[r.Method] {
// If the pattern has a verb, explicitly look for a suffix in the last
// component that matches a colon plus the verb. This allows us to
// handle some cases that otherwise can't be correctly handled by the
// former LastIndex case, such as when the verb literal itself contains
// a colon. This should work for all cases that have run through the
// parser because we know what verb we're looking for, however, there
// are still some cases that the parser itself cannot disambiguate. See
// the comment there if interested.
patVerb := h.pat.Verb()
l := len(components)
lastComponent := components[l-1]
var idx int = -1
if patVerb != "" && strings.HasSuffix(lastComponent, ":"+patVerb) {
idx = len(lastComponent) - len(patVerb) - 1
}
if idx == 0 {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
return
}
if idx > 0 {
components[l-1], verb = lastComponent[:idx], lastComponent[idx+1:]
}
pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode)
if err != nil {
var mse MalformedSequenceError
if ok := errors.As(err, &mse); ok {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
HTTPStatus: http.StatusBadRequest,
Err: mse,
})
}
continue
}
h.h(w, r, pathParams)
return
}
// lookup other methods to handle fallback from GET to POST and
// to determine if it is NotImplemented or NotFound.
for m, handlers := range s.handlers {
if m == r.Method {
continue
}
for _, h := range handlers {
pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode)
if err != nil {
var mse MalformedSequenceError
if ok := errors.As(err, &mse); ok {
_, outboundMarshaler := MarshalerForRequest(s, r)
s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
HTTPStatus: http.StatusBadRequest,
Err: mse,
})
}
continue
}
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
if s.isPathLengthFallback(r) {
if err := r.ParseForm(); err != nil {
_, outboundMarshaler := MarshalerForRequest(s, r)
sterr := status.Error(codes.InvalidArgument, err.Error())
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
h.h(w, r, pathParams)
return
}
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
return
}
}
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
}
// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
return s.forwardResponseOptions
}
func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
}
type handler struct {
pat Pattern
h HandlerFunc
}

View File

@@ -0,0 +1,383 @@
package runtime
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/grpc/grpclog"
)
var (
// ErrNotMatch indicates that the given HTTP request path does not match to the pattern.
ErrNotMatch = errors.New("not match to the path pattern")
// ErrInvalidPattern indicates that the given definition of Pattern is not valid.
ErrInvalidPattern = errors.New("invalid pattern")
// ErrMalformedSequence indicates that an escape sequence was malformed.
ErrMalformedSequence = errors.New("malformed escape sequence")
)
type MalformedSequenceError string
func (e MalformedSequenceError) Error() string {
return "malformed path escape " + strconv.Quote(string(e))
}
type op struct {
code utilities.OpCode
operand int
}
// Pattern is a template pattern of http request paths defined in
// https://github.com/googleapis/googleapis/blob/master/google/api/http.proto
type Pattern struct {
// ops is a list of operations
ops []op
// pool is a constant pool indexed by the operands or vars.
pool []string
// vars is a list of variables names to be bound by this pattern
vars []string
// stacksize is the max depth of the stack
stacksize int
// tailLen is the length of the fixed-size segments after a deep wildcard
tailLen int
// verb is the VERB part of the path pattern. It is empty if the pattern does not have VERB part.
verb string
}
// NewPattern returns a new Pattern from the given definition values.
// "ops" is a sequence of op codes. "pool" is a constant pool.
// "verb" is the verb part of the pattern. It is empty if the pattern does not have the part.
// "version" must be 1 for now.
// It returns an error if the given definition is invalid.
func NewPattern(version int, ops []int, pool []string, verb string) (Pattern, error) {
if version != 1 {
grpclog.Infof("unsupported version: %d", version)
return Pattern{}, ErrInvalidPattern
}
l := len(ops)
if l%2 != 0 {
grpclog.Infof("odd number of ops codes: %d", l)
return Pattern{}, ErrInvalidPattern
}
var (
typedOps []op
stack, maxstack int
tailLen int
pushMSeen bool
vars []string
)
for i := 0; i < l; i += 2 {
op := op{code: utilities.OpCode(ops[i]), operand: ops[i+1]}
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush:
if pushMSeen {
tailLen++
}
stack++
case utilities.OpPushM:
if pushMSeen {
grpclog.Infof("pushM appears twice")
return Pattern{}, ErrInvalidPattern
}
pushMSeen = true
stack++
case utilities.OpLitPush:
if op.operand < 0 || len(pool) <= op.operand {
grpclog.Infof("negative literal index: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
if pushMSeen {
tailLen++
}
stack++
case utilities.OpConcatN:
if op.operand <= 0 {
grpclog.Infof("negative concat size: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
stack -= op.operand
if stack < 0 {
grpclog.Info("stack underflow")
return Pattern{}, ErrInvalidPattern
}
stack++
case utilities.OpCapture:
if op.operand < 0 || len(pool) <= op.operand {
grpclog.Infof("variable name index out of bound: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
v := pool[op.operand]
op.operand = len(vars)
vars = append(vars, v)
stack--
if stack < 0 {
grpclog.Infof("stack underflow")
return Pattern{}, ErrInvalidPattern
}
default:
grpclog.Infof("invalid opcode: %d", op.code)
return Pattern{}, ErrInvalidPattern
}
if maxstack < stack {
maxstack = stack
}
typedOps = append(typedOps, op)
}
return Pattern{
ops: typedOps,
pool: pool,
vars: vars,
stacksize: maxstack,
tailLen: tailLen,
verb: verb,
}, nil
}
// MustPattern is a helper function which makes it easier to call NewPattern in variable initialization.
func MustPattern(p Pattern, err error) Pattern {
if err != nil {
grpclog.Fatalf("Pattern initialization failed: %v", err)
}
return p
}
// MatchAndEscape examines components to determine if they match to a Pattern.
// MatchAndEscape will return an error if no Patterns matched or if a pattern
// matched but contained malformed escape sequences. If successful, the function
// returns a mapping from field paths to their captured values.
func (p Pattern) MatchAndEscape(components []string, verb string, unescapingMode UnescapingMode) (map[string]string, error) {
if p.verb != verb {
if p.verb != "" {
return nil, ErrNotMatch
}
if len(components) == 0 {
components = []string{":" + verb}
} else {
components = append([]string{}, components...)
components[len(components)-1] += ":" + verb
}
}
var pos int
stack := make([]string, 0, p.stacksize)
captured := make([]string, len(p.vars))
l := len(components)
for _, op := range p.ops {
var err error
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush, utilities.OpLitPush:
if pos >= l {
return nil, ErrNotMatch
}
c := components[pos]
if op.code == utilities.OpLitPush {
if lit := p.pool[op.operand]; c != lit {
return nil, ErrNotMatch
}
} else if op.code == utilities.OpPush {
if c, err = unescape(c, unescapingMode, false); err != nil {
return nil, err
}
}
stack = append(stack, c)
pos++
case utilities.OpPushM:
end := len(components)
if end < pos+p.tailLen {
return nil, ErrNotMatch
}
end -= p.tailLen
c := strings.Join(components[pos:end], "/")
if c, err = unescape(c, unescapingMode, true); err != nil {
return nil, err
}
stack = append(stack, c)
pos = end
case utilities.OpConcatN:
n := op.operand
l := len(stack) - n
stack = append(stack[:l], strings.Join(stack[l:], "/"))
case utilities.OpCapture:
n := len(stack) - 1
captured[op.operand] = stack[n]
stack = stack[:n]
}
}
if pos < l {
return nil, ErrNotMatch
}
bindings := make(map[string]string)
for i, val := range captured {
bindings[p.vars[i]] = val
}
return bindings, nil
}
// MatchAndEscape examines components to determine if they match to a Pattern.
// It will never perform per-component unescaping (see: UnescapingModeLegacy).
// MatchAndEscape will return an error if no Patterns matched. If successful,
// the function returns a mapping from field paths to their captured values.
//
// Deprecated: Use MatchAndEscape.
func (p Pattern) Match(components []string, verb string) (map[string]string, error) {
return p.MatchAndEscape(components, verb, UnescapingModeDefault)
}
// Verb returns the verb part of the Pattern.
func (p Pattern) Verb() string { return p.verb }
func (p Pattern) String() string {
var stack []string
for _, op := range p.ops {
switch op.code {
case utilities.OpNop:
continue
case utilities.OpPush:
stack = append(stack, "*")
case utilities.OpLitPush:
stack = append(stack, p.pool[op.operand])
case utilities.OpPushM:
stack = append(stack, "**")
case utilities.OpConcatN:
n := op.operand
l := len(stack) - n
stack = append(stack[:l], strings.Join(stack[l:], "/"))
case utilities.OpCapture:
n := len(stack) - 1
stack[n] = fmt.Sprintf("{%s=%s}", p.vars[op.operand], stack[n])
}
}
segs := strings.Join(stack, "/")
if p.verb != "" {
return fmt.Sprintf("/%s:%s", segs, p.verb)
}
return "/" + segs
}
/*
* The following code is adopted and modified from Go's standard library
* and carries the attached license.
*
* Copyright 2009 The Go Authors. All rights reserved.
* Use of this source code is governed by a BSD-style
* license that can be found in the LICENSE file.
*/
// ishex returns whether or not the given byte is a valid hex character
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
func isRFC6570Reserved(c byte) bool {
switch c {
case '!', '#', '$', '&', '\'', '(', ')', '*',
'+', ',', '/', ':', ';', '=', '?', '@', '[', ']':
return true
default:
return false
}
}
// unhex converts a hex point to the bit representation
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}
// shouldUnescapeWithMode returns true if the character is escapable with the
// given mode
func shouldUnescapeWithMode(c byte, mode UnescapingMode) bool {
switch mode {
case UnescapingModeAllExceptReserved:
if isRFC6570Reserved(c) {
return false
}
case UnescapingModeAllExceptSlash:
if c == '/' {
return false
}
case UnescapingModeAllCharacters:
return true
}
return true
}
// unescape unescapes a path string using the provided mode
func unescape(s string, mode UnescapingMode, multisegment bool) (string, error) {
// TODO(v3): remove UnescapingModeLegacy
if mode == UnescapingModeLegacy {
return s, nil
}
if !multisegment {
mode = UnescapingModeAllCharacters
}
// Count %, check that they're well-formed.
n := 0
for i := 0; i < len(s); {
if s[i] == '%' {
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}
return "", MalformedSequenceError(s)
}
i += 3
} else {
i++
}
}
if n == 0 {
return s, nil
}
var t strings.Builder
t.Grow(len(s))
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
c := unhex(s[i+1])<<4 | unhex(s[i+2])
if shouldUnescapeWithMode(c, mode) {
t.WriteByte(c)
i += 2
continue
}
fallthrough
default:
t.WriteByte(s[i])
}
}
return t.String(), nil
}

View File

@@ -0,0 +1,80 @@
package runtime
import (
"google.golang.org/protobuf/proto"
)
// StringP returns a pointer to a string whose pointee is same as the given string value.
func StringP(val string) (*string, error) {
return proto.String(val), nil
}
// BoolP parses the given string representation of a boolean value,
// and returns a pointer to a bool whose value is same as the parsed value.
func BoolP(val string) (*bool, error) {
b, err := Bool(val)
if err != nil {
return nil, err
}
return proto.Bool(b), nil
}
// Float64P parses the given string representation of a floating point number,
// and returns a pointer to a float64 whose value is same as the parsed number.
func Float64P(val string) (*float64, error) {
f, err := Float64(val)
if err != nil {
return nil, err
}
return proto.Float64(f), nil
}
// Float32P parses the given string representation of a floating point number,
// and returns a pointer to a float32 whose value is same as the parsed number.
func Float32P(val string) (*float32, error) {
f, err := Float32(val)
if err != nil {
return nil, err
}
return proto.Float32(f), nil
}
// Int64P parses the given string representation of an integer
// and returns a pointer to a int64 whose value is same as the parsed integer.
func Int64P(val string) (*int64, error) {
i, err := Int64(val)
if err != nil {
return nil, err
}
return proto.Int64(i), nil
}
// Int32P parses the given string representation of an integer
// and returns a pointer to a int32 whose value is same as the parsed integer.
func Int32P(val string) (*int32, error) {
i, err := Int32(val)
if err != nil {
return nil, err
}
return proto.Int32(i), err
}
// Uint64P parses the given string representation of an integer
// and returns a pointer to a uint64 whose value is same as the parsed integer.
func Uint64P(val string) (*uint64, error) {
i, err := Uint64(val)
if err != nil {
return nil, err
}
return proto.Uint64(i), err
}
// Uint32P parses the given string representation of an integer
// and returns a pointer to a uint32 whose value is same as the parsed integer.
func Uint32P(val string) (*uint32, error) {
i, err := Uint32(val)
if err != nil {
return nil, err
}
return proto.Uint32(i), err
}

View File

@@ -0,0 +1,329 @@
package runtime
import (
"encoding/base64"
"errors"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
var currentQueryParser QueryParameterParser = &defaultQueryParser{}
// QueryParameterParser defines interface for all query parameter parsers
type QueryParameterParser interface {
Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
}
// PopulateQueryParameters parses query parameters
// into "msg" using current query parser
func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
return currentQueryParser.Parse(msg, values, filter)
}
type defaultQueryParser struct{}
// Parse populates "values" into "msg".
// A value is ignored if its key starts with one of the elements in "filter".
func (*defaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
for key, values := range values {
match := valuesKeyRegexp.FindStringSubmatch(key)
if len(match) == 3 {
key = match[1]
values = append([]string{match[2]}, values...)
}
fieldPath := strings.Split(key, ".")
if filter.HasCommonPrefix(fieldPath) {
continue
}
if err := populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, values); err != nil {
return err
}
}
return nil
}
// PopulateFieldFromPath sets a value in a nested Protobuf structure.
func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
fieldPath := strings.Split(fieldPathString, ".")
return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value})
}
func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
if len(fieldPath) < 1 {
return errors.New("no field path")
}
if len(values) < 1 {
return errors.New("no value provided")
}
var fieldDescriptor protoreflect.FieldDescriptor
for i, fieldName := range fieldPath {
fields := msgValue.Descriptor().Fields()
// Get field by name
fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
if fieldDescriptor == nil {
fieldDescriptor = fields.ByJSONName(fieldName)
if fieldDescriptor == nil {
// We're not returning an error here because this could just be
// an extra query parameter that isn't part of the request.
grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
return nil
}
}
// If this is the last element, we're done
if i == len(fieldPath)-1 {
break
}
// Only singular message fields are allowed
if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
return fmt.Errorf("invalid path: %q is not a message", fieldName)
}
// Get the nested message
msgValue = msgValue.Mutable(fieldDescriptor).Message()
}
// Check if oneof already set
if of := fieldDescriptor.ContainingOneof(); of != nil {
if f := msgValue.WhichOneof(of); f != nil {
return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
}
}
switch {
case fieldDescriptor.IsList():
return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
case fieldDescriptor.IsMap():
return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
}
if len(values) > 1 {
return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
}
return populateField(fieldDescriptor, msgValue, values[0])
}
func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
v, err := parseField(fieldDescriptor, value)
if err != nil {
return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
}
msgValue.Set(fieldDescriptor, v)
return nil
}
func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
for _, value := range values {
v, err := parseField(fieldDescriptor, value)
if err != nil {
return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
}
list.Append(v)
}
return nil
}
func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
if len(values) != 2 {
return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
}
key, err := parseField(fieldDescriptor.MapKey(), values[0])
if err != nil {
return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
}
value, err := parseField(fieldDescriptor.MapValue(), values[1])
if err != nil {
return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
}
mp.Set(key.MapKey(), value)
return nil
}
func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
switch fieldDescriptor.Kind() {
case protoreflect.BoolKind:
v, err := strconv.ParseBool(value)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfBool(v), nil
case protoreflect.EnumKind:
enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
switch {
case errors.Is(err, protoregistry.NotFound):
return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
case err != nil:
return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
}
// Look for enum by name
v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
if v == nil {
i, err := strconv.Atoi(value)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
}
// Look for enum by number
v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))
if v == nil {
return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
}
}
return protoreflect.ValueOfEnum(v.Number()), nil
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
v, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfInt32(int32(v)), nil
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfInt64(v), nil
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
v, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfUint32(uint32(v)), nil
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfUint64(v), nil
case protoreflect.FloatKind:
v, err := strconv.ParseFloat(value, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfFloat32(float32(v)), nil
case protoreflect.DoubleKind:
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfFloat64(v), nil
case protoreflect.StringKind:
return protoreflect.ValueOfString(value), nil
case protoreflect.BytesKind:
v, err := base64.URLEncoding.DecodeString(value)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfBytes(v), nil
case protoreflect.MessageKind, protoreflect.GroupKind:
return parseMessage(fieldDescriptor.Message(), value)
default:
panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
}
}
func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
var msg proto.Message
switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp":
if value == "null" {
break
}
t, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return protoreflect.Value{}, err
}
msg = timestamppb.New(t)
case "google.protobuf.Duration":
if value == "null" {
break
}
d, err := time.ParseDuration(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = durationpb.New(d)
case "google.protobuf.DoubleValue":
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.DoubleValue{Value: v}
case "google.protobuf.FloatValue":
v, err := strconv.ParseFloat(value, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.FloatValue{Value: float32(v)}
case "google.protobuf.Int64Value":
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.Int64Value{Value: v}
case "google.protobuf.Int32Value":
v, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.Int32Value{Value: int32(v)}
case "google.protobuf.UInt64Value":
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.UInt64Value{Value: v}
case "google.protobuf.UInt32Value":
v, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.UInt32Value{Value: uint32(v)}
case "google.protobuf.BoolValue":
v, err := strconv.ParseBool(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.BoolValue{Value: v}
case "google.protobuf.StringValue":
msg = &wrapperspb.StringValue{Value: value}
case "google.protobuf.BytesValue":
v, err := base64.URLEncoding.DecodeString(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = &wrapperspb.BytesValue{Value: v}
case "google.protobuf.FieldMask":
fm := &field_mask.FieldMask{}
fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
msg = fm
default:
return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
}
return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
}

View File

@@ -0,0 +1,27 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(default_visibility = ["//visibility:public"])
go_library(
name = "utilities",
srcs = [
"doc.go",
"pattern.go",
"readerfactory.go",
"trie.go",
],
importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/utilities",
)
go_test(
name = "utilities_test",
size = "small",
srcs = ["trie_test.go"],
deps = [":utilities"],
)
alias(
name = "go_default_library",
actual = ":utilities",
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,2 @@
// Package utilities provides members for internal use in grpc-gateway.
package utilities

View File

@@ -0,0 +1,22 @@
package utilities
// An OpCode is a opcode of compiled path patterns.
type OpCode int
// These constants are the valid values of OpCode.
const (
// OpNop does nothing
OpNop = OpCode(iota)
// OpPush pushes a component to stack
OpPush
// OpLitPush pushes a component to stack if it matches to the literal
OpLitPush
// OpPushM concatenates the remaining components and pushes it to stack
OpPushM
// OpConcatN pops N items from stack, concatenates them and pushes it back to stack
OpConcatN
// OpCapture pops an item and binds it to the variable
OpCapture
// OpEnd is the least positive invalid opcode.
OpEnd
)

View File

@@ -0,0 +1,20 @@
package utilities
import (
"bytes"
"io"
"io/ioutil"
)
// IOReaderFactory takes in an io.Reader and returns a function that will allow you to create a new reader that begins
// at the start of the stream
func IOReaderFactory(r io.Reader) (func() io.Reader, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
return func() io.Reader {
return bytes.NewReader(b)
}, nil
}

View File

@@ -0,0 +1,174 @@
package utilities
import (
"sort"
)
// DoubleArray is a Double Array implementation of trie on sequences of strings.
type DoubleArray struct {
// Encoding keeps an encoding from string to int
Encoding map[string]int
// Base is the base array of Double Array
Base []int
// Check is the check array of Double Array
Check []int
}
// NewDoubleArray builds a DoubleArray from a set of sequences of strings.
func NewDoubleArray(seqs [][]string) *DoubleArray {
da := &DoubleArray{Encoding: make(map[string]int)}
if len(seqs) == 0 {
return da
}
encoded := registerTokens(da, seqs)
sort.Sort(byLex(encoded))
root := node{row: -1, col: -1, left: 0, right: len(encoded)}
addSeqs(da, encoded, 0, root)
for i := len(da.Base); i > 0; i-- {
if da.Check[i-1] != 0 {
da.Base = da.Base[:i]
da.Check = da.Check[:i]
break
}
}
return da
}
func registerTokens(da *DoubleArray, seqs [][]string) [][]int {
var result [][]int
for _, seq := range seqs {
var encoded []int
for _, token := range seq {
if _, ok := da.Encoding[token]; !ok {
da.Encoding[token] = len(da.Encoding)
}
encoded = append(encoded, da.Encoding[token])
}
result = append(result, encoded)
}
for i := range result {
result[i] = append(result[i], len(da.Encoding))
}
return result
}
type node struct {
row, col int
left, right int
}
func (n node) value(seqs [][]int) int {
return seqs[n.row][n.col]
}
func (n node) children(seqs [][]int) []*node {
var result []*node
lastVal := int(-1)
last := new(node)
for i := n.left; i < n.right; i++ {
if lastVal == seqs[i][n.col+1] {
continue
}
last.right = i
last = &node{
row: i,
col: n.col + 1,
left: i,
}
result = append(result, last)
}
last.right = n.right
return result
}
func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) {
ensureSize(da, pos)
children := n.children(seqs)
var i int
for i = 1; ; i++ {
ok := func() bool {
for _, child := range children {
code := child.value(seqs)
j := i + code
ensureSize(da, j)
if da.Check[j] != 0 {
return false
}
}
return true
}()
if ok {
break
}
}
da.Base[pos] = i
for _, child := range children {
code := child.value(seqs)
j := i + code
da.Check[j] = pos + 1
}
terminator := len(da.Encoding)
for _, child := range children {
code := child.value(seqs)
if code == terminator {
continue
}
j := i + code
addSeqs(da, seqs, j, *child)
}
}
func ensureSize(da *DoubleArray, i int) {
for i >= len(da.Base) {
da.Base = append(da.Base, make([]int, len(da.Base)+1)...)
da.Check = append(da.Check, make([]int, len(da.Check)+1)...)
}
}
type byLex [][]int
func (l byLex) Len() int { return len(l) }
func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
func (l byLex) Less(i, j int) bool {
si := l[i]
sj := l[j]
var k int
for k = 0; k < len(si) && k < len(sj); k++ {
if si[k] < sj[k] {
return true
}
if si[k] > sj[k] {
return false
}
}
return k < len(sj)
}
// HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence.
func (da *DoubleArray) HasCommonPrefix(seq []string) bool {
if len(da.Base) == 0 {
return false
}
var i int
for _, t := range seq {
code, ok := da.Encoding[t]
if !ok {
break
}
j := da.Base[i] + code
if len(da.Check) <= j || da.Check[j] != i+1 {
break
}
i = j
}
j := da.Base[i] + len(da.Encoding)
if len(da.Check) <= j || da.Check[j] != i+1 {
return false
}
return true
}

View File

@@ -1,30 +0,0 @@
language: go
go:
- "1.10.x"
- "1.11.x"
- "1.12.x"
- master
matrix:
allow_failures:
- go: master
fast_finish: true
env:
global:
- CC_TEST_REPORTER_ID=68feaa3410049ce73e145287acbcdacc525087a30627f96f04e579e75bd71c00
before_script:
- curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter
- chmod +x ./cc-test-reporter
- ./cc-test-reporter before-build
install:
- curl -sL https://taskfile.dev/install.sh | sh
script:
- diff -u <(echo -n) <(./bin/task lint)
- ./bin/task test-coverage
after_script:
- ./cc-test-reporter after-build --exit-code $TRAVIS_TEST_RESULT

View File

@@ -1,6 +1,7 @@
package objx
import (
"reflect"
"regexp"
"strconv"
"strings"
@@ -16,11 +17,18 @@ const (
// arrayAccesRegexString is the regex used to extract the array number
// from the access path
arrayAccesRegexString = `^(.+)\[([0-9]+)\]$`
// mapAccessRegexString is the regex used to extract the map key
// from the access path
mapAccessRegexString = `^([^\[]*)\[([^\]]+)\](.*)$`
)
// arrayAccesRegex is the compiled arrayAccesRegexString
var arrayAccesRegex = regexp.MustCompile(arrayAccesRegexString)
// mapAccessRegex is the compiled mapAccessRegexString
var mapAccessRegex = regexp.MustCompile(mapAccessRegexString)
// Get gets the value using the specified selector and
// returns it inside a new Obj object.
//
@@ -70,15 +78,53 @@ func getIndex(s string) (int, string) {
return -1, s
}
// getKey returns the key which is held in s by two brackets.
// It also returns the next selector.
func getKey(s string) (string, string) {
selSegs := strings.SplitN(s, PathSeparator, 2)
thisSel := selSegs[0]
nextSel := ""
if len(selSegs) > 1 {
nextSel = selSegs[1]
}
mapMatches := mapAccessRegex.FindStringSubmatch(s)
if len(mapMatches) > 0 {
if _, err := strconv.Atoi(mapMatches[2]); err != nil {
thisSel = mapMatches[1]
nextSel = "[" + mapMatches[2] + "]" + mapMatches[3]
if thisSel == "" {
thisSel = mapMatches[2]
nextSel = mapMatches[3]
}
if nextSel == "" {
selSegs = []string{"", ""}
} else if nextSel[0] == '.' {
nextSel = nextSel[1:]
}
}
}
return thisSel, nextSel
}
// access accesses the object using the selector and performs the
// appropriate action.
func access(current interface{}, selector string, value interface{}, isSet bool) interface{} {
selSegs := strings.SplitN(selector, PathSeparator, 2)
thisSel := selSegs[0]
index := -1
thisSel, nextSel := getKey(selector)
if strings.Contains(thisSel, "[") {
indexes := []int{}
for strings.Contains(thisSel, "[") {
prevSel := thisSel
index := -1
index, thisSel = getIndex(thisSel)
indexes = append(indexes, index)
if prevSel == thisSel {
break
}
}
if curMap, ok := current.(Map); ok {
@@ -88,13 +134,17 @@ func access(current interface{}, selector string, value interface{}, isSet bool)
switch current.(type) {
case map[string]interface{}:
curMSI := current.(map[string]interface{})
if len(selSegs) <= 1 && isSet {
if nextSel == "" && isSet {
curMSI[thisSel] = value
return nil
}
_, ok := curMSI[thisSel].(map[string]interface{})
if (curMSI[thisSel] == nil || !ok) && index == -1 && isSet {
if !ok {
_, ok = curMSI[thisSel].(Map)
}
if (curMSI[thisSel] == nil || !ok) && len(indexes) == 0 && isSet {
curMSI[thisSel] = map[string]interface{}{}
}
@@ -102,18 +152,46 @@ func access(current interface{}, selector string, value interface{}, isSet bool)
default:
current = nil
}
// do we need to access the item of an array?
if index > -1 {
if array, ok := current.([]interface{}); ok {
if index < len(array) {
current = array[index]
} else {
current = nil
if len(indexes) > 0 {
num := len(indexes)
for num > 0 {
num--
index := indexes[num]
indexes = indexes[:num]
if array, ok := interSlice(current); ok {
if index < len(array) {
current = array[index]
} else {
current = nil
break
}
}
}
}
if len(selSegs) > 1 {
current = access(current, selSegs[1], value, isSet)
if nextSel != "" {
current = access(current, nextSel, value, isSet)
}
return current
}
func interSlice(slice interface{}) ([]interface{}, bool) {
if array, ok := slice.([]interface{}); ok {
return array, ok
}
s := reflect.ValueOf(slice)
if s.Kind() != reflect.Slice {
return nil, false
}
ret := make([]interface{}, s.Len())
for i := 0; i < s.Len(); i++ {
ret[i] = s.Index(i).Interface()
}
return ret, true
}

View File

@@ -92,6 +92,18 @@ func MustFromJSON(jsonString string) Map {
return o
}
// MustFromJSONSlice creates a new slice of Map containing the data specified in the
// jsonString. Works with jsons with a top level array
//
// Panics if the JSON is invalid.
func MustFromJSONSlice(jsonString string) []Map {
slice, err := FromJSONSlice(jsonString)
if err != nil {
panic("objx: MustFromJSONSlice failed with error: " + err.Error())
}
return slice
}
// FromJSON creates a new Map containing the data specified in the
// jsonString.
//
@@ -102,45 +114,20 @@ func FromJSON(jsonString string) (Map, error) {
if err != nil {
return Nil, err
}
m.tryConvertFloat64()
return m, nil
}
func (m Map) tryConvertFloat64() {
for k, v := range m {
switch v.(type) {
case float64:
f := v.(float64)
if float64(int(f)) == f {
m[k] = int(f)
}
case map[string]interface{}:
t := New(v)
t.tryConvertFloat64()
m[k] = t
case []interface{}:
m[k] = tryConvertFloat64InSlice(v.([]interface{}))
}
// FromJSONSlice creates a new slice of Map containing the data specified in the
// jsonString. Works with jsons with a top level array
//
// Returns an error if the JSON is invalid.
func FromJSONSlice(jsonString string) ([]Map, error) {
var slice []Map
err := json.Unmarshal([]byte(jsonString), &slice)
if err != nil {
return nil, err
}
}
func tryConvertFloat64InSlice(s []interface{}) []interface{} {
for k, v := range s {
switch v.(type) {
case float64:
f := v.(float64)
if float64(int(f)) == f {
s[k] = int(f)
}
case map[string]interface{}:
t := New(v)
t.tryConvertFloat64()
s[k] = t
case []interface{}:
s[k] = tryConvertFloat64InSlice(v.([]interface{}))
}
}
return s
return slice, nil
}
// FromBase64 creates a new Obj containing the data specified

View File

@@ -385,6 +385,11 @@ func (v *Value) Int(optionalDefault ...int) int {
if s, ok := v.data.(int); ok {
return s
}
if s, ok := v.data.(float64); ok {
if float64(int(s)) == s {
return int(s)
}
}
if len(optionalDefault) == 1 {
return optionalDefault[0]
}
@@ -395,6 +400,11 @@ func (v *Value) Int(optionalDefault ...int) int {
//
// Panics if the object is not a int.
func (v *Value) MustInt() int {
if s, ok := v.data.(float64); ok {
if float64(int(s)) == s {
return int(s)
}
}
return v.data.(int)
}

View File

@@ -1,8 +1,10 @@
package assert
import (
"bytes"
"fmt"
"reflect"
"time"
)
type CompareType int
@@ -30,6 +32,9 @@ var (
float64Type = reflect.TypeOf(float64(1))
stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
bytesType = reflect.TypeOf([]byte{})
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
@@ -299,6 +304,47 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return compareLess, true
}
}
// Check for known struct types we can check for compare results.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) {
break
}
// time.Time can compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}
timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
case reflect.Slice:
{
// We only care about the []byte type.
if !canConvert(obj1Value, bytesType) {
break
}
// []byte can be compared!
bytesObj1, ok := obj1.([]byte)
if !ok {
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
}
bytesObj2, ok := obj2.([]byte)
if !ok {
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
}
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
}
}
return compareEqual, false
@@ -310,7 +356,10 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// GreaterOrEqual asserts that the first element is greater than or equal to the second
@@ -320,7 +369,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// Less asserts that the first element is less than the second
@@ -329,7 +381,10 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// LessOrEqual asserts that the first element is less than or equal to the second
@@ -339,7 +394,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}
// Positive asserts that the specified element is positive
@@ -347,8 +405,11 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
// assert.Positive(t, 1)
// assert.Positive(t, 1.23)
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs)
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
}
// Negative asserts that the specified element is negative
@@ -356,8 +417,11 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
// assert.Negative(t, -1)
// assert.Negative(t, -1.23)
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs)
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
}
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {

View File

@@ -0,0 +1,16 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View File

@@ -0,0 +1,16 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View File

@@ -123,6 +123,18 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
}
// ErrorIsf asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
@@ -724,6 +736,16 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return WithinRange(t, actual, start, end, append([]interface{}{msg}, args...)...)
}
// YAMLEqf asserts that two YAML strings are equivalent.
func YAMLEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {

View File

@@ -222,6 +222,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
return ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool {
@@ -1437,6 +1461,26 @@ func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta
return WithinDurationf(a.t, expected, actual, delta, msg, args...)
}
// WithinRange asserts that a time is within a time range (inclusive).
//
// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return WithinRange(a.t, actual, start, end, msgAndArgs...)
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func (a *Assertions) WithinRangef(actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return WithinRangef(a.t, actual, start, end, msg, args...)
}
// YAMLEq asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {

View File

@@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
// assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// IsNonIncreasing asserts that the collection is not increasing
@@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// IsDecreasing asserts that the collection is decreasing
@@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{})
// assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// IsNonDecreasing asserts that the collection is not decreasing
@@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"math"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
@@ -144,7 +145,8 @@ func CallerInfo() []string {
if len(parts) > 1 {
dir := parts[len(parts)-2]
if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" {
callers = append(callers, fmt.Sprintf("%s:%d", file, line))
path, _ := filepath.Abs(file)
callers = append(callers, fmt.Sprintf("%s:%d", path, line))
}
}
@@ -563,16 +565,17 @@ func isEmpty(object interface{}) bool {
switch objValue.Kind() {
// collection types are empty when they have no element
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
case reflect.Chan, reflect.Map, reflect.Slice:
return objValue.Len() == 0
// pointers are empty if nil or if the value they point to is empty
// pointers are empty if nil or if the value they point to is empty
case reflect.Ptr:
if objValue.IsNil() {
return true
}
deref := objValue.Elem().Interface()
return isEmpty(deref)
// for all other types, compare against the zero value
// for all other types, compare against the zero value
// array types are empty when they match their zero-initialized state
default:
zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface())
@@ -718,10 +721,14 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
// return (false, false) if impossible.
// return (true, false) if element was not found.
// return (true, true) if element was found.
func includeElement(list interface{}, element interface{}) (ok, found bool) {
func containsElement(list interface{}, element interface{}) (ok, found bool) {
listValue := reflect.ValueOf(list)
listKind := reflect.TypeOf(list).Kind()
listType := reflect.TypeOf(list)
if listType == nil {
return false, false
}
listKind := listType.Kind()
defer func() {
if e := recover(); e != nil {
ok = false
@@ -764,7 +771,7 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
}
@@ -787,7 +794,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
}
@@ -811,7 +818,6 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true // we consider nil to be equal to the nil set
}
subsetValue := reflect.ValueOf(subset)
defer func() {
if e := recover(); e != nil {
ok = false
@@ -821,17 +827,35 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind()
if listKind != reflect.Array && listKind != reflect.Slice {
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
}
if subsetKind != reflect.Array && subsetKind != reflect.Slice {
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()
for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()
if !ObjectsAreEqual(subsetElement, listElement) {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...)
}
}
return true
}
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -852,10 +876,9 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
h.Helper()
}
if subset == nil {
return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...)
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
defer func() {
if e := recover(); e != nil {
ok = false
@@ -865,17 +888,35 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind()
if listKind != reflect.Array && listKind != reflect.Slice {
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
}
if subsetKind != reflect.Array && subsetKind != reflect.Slice {
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()
for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()
if !ObjectsAreEqual(subsetElement, listElement) {
return true
}
}
return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...)
}
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -1000,27 +1041,21 @@ func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
type PanicTestFunc func()
// didPanic returns true if the function passed to it panics. Otherwise, it returns false.
func didPanic(f PanicTestFunc) (bool, interface{}, string) {
didPanic := false
var message interface{}
var stack string
func() {
defer func() {
if message = recover(); message != nil {
didPanic = true
stack = string(debug.Stack())
}
}()
// call the target function
f()
func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string) {
didPanic = true
defer func() {
message = recover()
if didPanic {
stack = string(debug.Stack())
}
}()
return didPanic, message, stack
// call the target function
f()
didPanic = false
return
}
// Panics asserts that the code inside the specified PanicTestFunc panics.
@@ -1111,6 +1146,27 @@ func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration,
return true
}
// WithinRange asserts that a time is within a time range (inclusive).
//
// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func WithinRange(t TestingT, actual, start, end time.Time, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if end.Before(start) {
return Fail(t, "Start should be before end", msgAndArgs...)
}
if actual.Before(start) {
return Fail(t, fmt.Sprintf("Time %v expected to be in time range %v to %v, but is before the range", actual, start, end), msgAndArgs...)
} else if actual.After(end) {
return Fail(t, fmt.Sprintf("Time %v expected to be in time range %v to %v, but is after the range", actual, start, end), msgAndArgs...)
}
return true
}
func toFloat(x interface{}) (float64, bool) {
var xf float64
xok := true
@@ -1161,11 +1217,15 @@ func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs
bf, bok := toFloat(actual)
if !aok || !bok {
return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...)
return Fail(t, "Parameters must be numerical", msgAndArgs...)
}
if math.IsNaN(af) && math.IsNaN(bf) {
return true
}
if math.IsNaN(af) {
return Fail(t, fmt.Sprintf("Expected must not be NaN"), msgAndArgs...)
return Fail(t, "Expected must not be NaN", msgAndArgs...)
}
if math.IsNaN(bf) {
@@ -1188,7 +1248,7 @@ func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAn
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1250,8 +1310,12 @@ func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, m
func calcRelativeError(expected, actual interface{}) (float64, error) {
af, aok := toFloat(expected)
if !aok {
return 0, fmt.Errorf("expected value %q cannot be converted to float", expected)
bf, bok := toFloat(actual)
if !aok || !bok {
return 0, fmt.Errorf("Parameters must be numerical")
}
if math.IsNaN(af) && math.IsNaN(bf) {
return 0, nil
}
if math.IsNaN(af) {
return 0, errors.New("expected value must not be NaN")
@@ -1259,10 +1323,6 @@ func calcRelativeError(expected, actual interface{}) (float64, error) {
if af == 0 {
return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error")
}
bf, bok := toFloat(actual)
if !bok {
return 0, fmt.Errorf("actual value %q cannot be converted to float", actual)
}
if math.IsNaN(bf) {
return 0, errors.New("actual value must not be NaN")
}
@@ -1298,7 +1358,7 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1375,6 +1435,27 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
return true
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !Error(t, theError, msgAndArgs...) {
return false
}
actual := theError.Error()
if !strings.Contains(actual, contains) {
return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...)
}
return true
}
// matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool {
@@ -1588,12 +1669,17 @@ func diff(expected interface{}, actual interface{}) string {
}
var e, a string
if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
} else {
switch et {
case reflect.TypeOf(""):
e = reflect.ValueOf(expected).String()
a = reflect.ValueOf(actual).String()
case reflect.TypeOf(time.Time{}):
e = spewConfigStringerEnabled.Sdump(expected)
a = spewConfigStringerEnabled.Sdump(actual)
default:
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
}
diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
@@ -1625,6 +1711,14 @@ var spewConfig = spew.ConfigState{
MaxDepth: 10,
}
var spewConfigStringerEnabled = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
MaxDepth: 10,
}
type tHelper interface {
Helper()
}

View File

@@ -70,6 +70,9 @@ type Call struct {
// if the PanicMsg is set to a non nil string the function call will panic
// irrespective of other settings
PanicMsg *string
// Calls which must be satisfied before this call can be
requires []*Call
}
func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call {
@@ -199,6 +202,64 @@ func (c *Call) On(methodName string, arguments ...interface{}) *Call {
return c.Parent.On(methodName, arguments...)
}
// Unset removes a mock handler from being called.
// test.On("func", mock.Anything).Unset()
func (c *Call) Unset() *Call {
var unlockOnce sync.Once
for _, arg := range c.Arguments {
if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
}
}
c.lock()
defer unlockOnce.Do(c.unlock)
foundMatchingCall := false
for i, call := range c.Parent.ExpectedCalls {
if call.Method == c.Method {
_, diffCount := call.Arguments.Diff(c.Arguments)
if diffCount == 0 {
foundMatchingCall = true
// Remove from ExpectedCalls
c.Parent.ExpectedCalls = append(c.Parent.ExpectedCalls[:i], c.Parent.ExpectedCalls[i+1:]...)
}
}
}
if !foundMatchingCall {
unlockOnce.Do(c.unlock)
c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n",
callString(c.Method, c.Arguments, true),
)
}
return c
}
// NotBefore indicates that the mock should only be called after the referenced
// calls have been called as expected. The referenced calls may be from the
// same mock instance and/or other mock instances.
//
// Mock.On("Do").Return(nil).Notbefore(
// Mock.On("Init").Return(nil)
// )
func (c *Call) NotBefore(calls ...*Call) *Call {
c.lock()
defer c.unlock()
for _, call := range calls {
if call.Parent == nil {
panic("not before calls must be created with Mock.On()")
}
}
c.requires = append(c.requires, calls...)
return c
}
// Mock is the workhorse used to track activity on another object.
// For an example of its usage, refer to the "Example Usage" section at the top
// of this document.
@@ -221,10 +282,17 @@ type Mock struct {
mutex sync.Mutex
}
// String provides a %v format string for Mock.
// Note: this is used implicitly by Arguments.Diff if a Mock is passed.
// It exists because go's default %v formatting traverses the struct
// without acquiring the mutex, which is detected by go test -race.
func (m *Mock) String() string {
return fmt.Sprintf("%[1]T<%[1]p>", m)
}
// TestData holds any data that might be useful for testing. Testify ignores
// this data completely allowing you to do whatever you like with it.
func (m *Mock) TestData() objx.Map {
if m.testData == nil {
m.testData = make(objx.Map)
}
@@ -346,7 +414,6 @@ func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call,
}
func callString(method string, arguments Arguments, includeArgumentValues bool) string {
var argValsString string
if includeArgumentValues {
var argVals []string
@@ -370,10 +437,10 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
panic("Couldn't get the caller information")
}
functionPath := runtime.FuncForPC(pc).Name()
//Next four lines are required to use GCCGO function naming conventions.
//For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
//uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
//With GCCGO we need to remove interface information starting from pN<dd>.
// Next four lines are required to use GCCGO function naming conventions.
// For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
// uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
// With GCCGO we need to remove interface information starting from pN<dd>.
re := regexp.MustCompile("\\.pN\\d+_")
if re.MatchString(functionPath) {
functionPath = re.Split(functionPath, -1)[0]
@@ -389,7 +456,7 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
// If Call.WaitFor is set, blocks until the channel is closed or receives a message.
func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments {
m.mutex.Lock()
//TODO: could combine expected and closes in single loop
// TODO: could combine expected and closes in single loop
found, call := m.findExpectedCall(methodName, arguments...)
if found < 0 {
@@ -419,6 +486,25 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen
}
}
for _, requirement := range call.requires {
if satisfied, _ := requirement.Parent.checkExpectation(requirement); !satisfied {
m.mutex.Unlock()
m.fail("mock: Unexpected Method Call\n-----------------------------\n\n%s\n\nMust not be called before%s:\n\n%s",
callString(call.Method, call.Arguments, true),
func() (s string) {
if requirement.totalCalls > 0 {
s = " another call of"
}
if call.Parent != requirement.Parent {
s += " method from another mock instance"
}
return
}(),
callString(requirement.Method, requirement.Arguments, true),
)
}
}
if call.Repeatability == 1 {
call.Repeatability = -1
} else if call.Repeatability > 1 {
@@ -476,9 +562,9 @@ func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
h.Helper()
}
for _, obj := range testObjects {
if m, ok := obj.(Mock); ok {
if m, ok := obj.(*Mock); ok {
t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)")
obj = &m
obj = m
}
m := obj.(assertExpectationser)
if !m.AssertExpectations(t) {
@@ -495,34 +581,36 @@ func (m *Mock) AssertExpectations(t TestingT) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
m.mutex.Lock()
defer m.mutex.Unlock()
var somethingMissing bool
var failedExpectations int
// iterate through each expectation
expectedCalls := m.expectedCalls()
for _, expectedCall := range expectedCalls {
if !expectedCall.optional && !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) && expectedCall.totalCalls == 0 {
somethingMissing = true
satisfied, reason := m.checkExpectation(expectedCall)
if !satisfied {
failedExpectations++
t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo)
} else {
if expectedCall.Repeatability > 0 {
somethingMissing = true
failedExpectations++
t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo)
} else {
t.Logf("PASS:\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
}
}
t.Logf(reason)
}
if somethingMissing {
if failedExpectations != 0 {
t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
}
return !somethingMissing
return failedExpectations == 0
}
func (m *Mock) checkExpectation(call *Call) (bool, string) {
if !call.optional && !m.methodWasCalled(call.Method, call.Arguments) && call.totalCalls == 0 {
return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
}
if call.Repeatability > 0 {
return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
}
return true, fmt.Sprintf("PASS:\t%s(%s)", call.Method, call.Arguments.String())
}
// AssertNumberOfCalls asserts that the method was called expectedCalls times.
@@ -720,7 +808,7 @@ func (f argumentMatcher) Matches(argument interface{}) bool {
}
func (f argumentMatcher) String() string {
return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name())
return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String())
}
// MatchedBy can be used to match a mock call based on only certain properties
@@ -773,12 +861,12 @@ func (args Arguments) Is(objects ...interface{}) bool {
//
// Returns the diff string and number of differences found.
func (args Arguments) Diff(objects []interface{}) (string, int) {
//TODO: could return string as error and nil for No difference
// TODO: could return string as error and nil for No difference
var output = "\n"
output := "\n"
var differences int
var maxArgCount = len(args)
maxArgCount := len(args)
if len(objects) > maxArgCount {
maxArgCount = len(objects)
}
@@ -804,21 +892,28 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
}
if matcher, ok := expected.(argumentMatcher); ok {
if matcher.Matches(actual) {
var matches bool
func() {
defer func() {
if r := recover(); r != nil {
actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
}
}()
matches = matcher.Matches(actual)
}()
if matches {
output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
} else {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
// type checking
if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) {
// not match
differences++
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*IsTypeArgument)(nil)) {
t := expected.(*IsTypeArgument).t
if reflect.TypeOf(t) != reflect.TypeOf(actual) {
@@ -826,7 +921,6 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, reflect.TypeOf(t).Name(), reflect.TypeOf(actual).Name(), actualFmt)
}
} else {
// normal checking
if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
@@ -846,7 +940,6 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
}
return output, differences
}
// Assert compares the arguments with the specified objects and fails if
@@ -868,7 +961,6 @@ func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
t.Errorf("%sArguments do not match.", assert.CallerInfo())
return false
}
// String gets the argument at the specified index. Panics if there is no argument, or
@@ -877,7 +969,6 @@ func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
// If no index is provided, String() returns a complete string representation
// of the arguments.
func (args Arguments) String(indexOrNil ...int) string {
if len(indexOrNil) == 0 {
// normal String() method - return a string representation of the args
var argsStr []string
@@ -887,7 +978,7 @@ func (args Arguments) String(indexOrNil ...int) string {
return strings.Join(argsStr, ",")
} else if len(indexOrNil) == 1 {
// Index has been specified - get the argument at that index
var index = indexOrNil[0]
index := indexOrNil[0]
var s string
var ok bool
if s, ok = args.Get(index).(string); !ok {
@@ -897,7 +988,6 @@ func (args Arguments) String(indexOrNil ...int) string {
}
panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
}
// Int gets the argument at the specified index. Panics if there is no argument, or

View File

@@ -280,6 +280,36 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
t.FailNow()
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContains(t, theError, contains, msgAndArgs...) {
return
}
t.FailNow()
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContainsf(t, theError, contains, msg, args...) {
return
}
t.FailNow()
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) {
@@ -1834,6 +1864,32 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim
t.FailNow()
}
// WithinRange asserts that a time is within a time range (inclusive).
//
// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func WithinRange(t TestingT, actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.WithinRange(t, actual, start, end, msgAndArgs...) {
return
}
t.FailNow()
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.WithinRangef(t, actual, start, end, msg, args...) {
return
}
t.FailNow()
}
// YAMLEq asserts that two YAML strings are equivalent.
func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {

View File

@@ -223,6 +223,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) {
@@ -1438,6 +1462,26 @@ func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta
WithinDurationf(a.t, expected, actual, delta, msg, args...)
}
// WithinRange asserts that a time is within a time range (inclusive).
//
// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
WithinRange(a.t, actual, start, end, msgAndArgs...)
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func (a *Assertions) WithinRangef(actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
WithinRangef(a.t, actual, start, end, msg, args...)
}
// YAMLEq asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {

View File

@@ -26,7 +26,7 @@ import (
var (
// MinClusterVersion is the min cluster version this etcd binary is compatible with.
MinClusterVersion = "3.0.0"
Version = "3.5.4"
Version = "3.5.5"
APIVersion = "unknown"
// Git SHA Value will be set during build

View File

@@ -0,0 +1,60 @@
// Copyright 2022 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fileutil
import (
"bufio"
"io"
"io/fs"
"os"
)
// FileReader is a wrapper of io.Reader. It also provides file info.
type FileReader interface {
io.Reader
FileInfo() (fs.FileInfo, error)
}
type fileReader struct {
*os.File
}
func NewFileReader(f *os.File) FileReader {
return &fileReader{f}
}
func (fr *fileReader) FileInfo() (fs.FileInfo, error) {
return fr.Stat()
}
// FileBufReader is a wrapper of bufio.Reader. It also provides file info.
type FileBufReader struct {
*bufio.Reader
fi fs.FileInfo
}
func NewFileBufReader(fr FileReader) *FileBufReader {
bufReader := bufio.NewReader(fr)
fi, err := fr.FileInfo()
if err != nil {
// This should never happen.
panic(err)
}
return &FileBufReader{bufReader, fi}
}
func (fbr *FileBufReader) FileInfo() fs.FileInfo {
return fbr.fi
}

View File

@@ -21,26 +21,29 @@ import (
"time"
)
type keepAliveConn interface {
SetKeepAlive(bool) error
SetKeepAlivePeriod(d time.Duration) error
}
// NewKeepAliveListener returns a listener that listens on the given address.
// Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil.
// Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake.
// http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
//
// Note(ahrtr):
// only `net.TCPConn` supports `SetKeepAlive` and `SetKeepAlivePeriod`
// by default, so if you want to wrap multiple layers of net.Listener,
// the `keepaliveListener` should be the one which is closest to the
// original `net.Listener` implementation, namely `TCPListener`.
func NewKeepAliveListener(l net.Listener, scheme string, tlscfg *tls.Config) (net.Listener, error) {
kal := &keepaliveListener{
Listener: l,
}
if scheme == "https" {
if tlscfg == nil {
return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented")
}
return newTLSKeepaliveListener(l, tlscfg), nil
return newTLSKeepaliveListener(kal, tlscfg), nil
}
return &keepaliveListener{
Listener: l,
}, nil
return kal, nil
}
type keepaliveListener struct{ net.Listener }
@@ -50,13 +53,43 @@ func (kln *keepaliveListener) Accept() (net.Conn, error) {
if err != nil {
return nil, err
}
kac := c.(keepAliveConn)
kac, err := createKeepaliveConn(c)
if err != nil {
return nil, fmt.Errorf("create keepalive connection failed, %w", err)
}
// detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl
// default on linux: 30 + 8 * 30
// default on osx: 30 + 8 * 75
kac.SetKeepAlive(true)
kac.SetKeepAlivePeriod(30 * time.Second)
return c, nil
if err := kac.SetKeepAlive(true); err != nil {
return nil, fmt.Errorf("SetKeepAlive failed, %w", err)
}
if err := kac.SetKeepAlivePeriod(30 * time.Second); err != nil {
return nil, fmt.Errorf("SetKeepAlivePeriod failed, %w", err)
}
return kac, nil
}
func createKeepaliveConn(c net.Conn) (*keepAliveConn, error) {
tcpc, ok := c.(*net.TCPConn)
if !ok {
return nil, ErrNotTCP
}
return &keepAliveConn{tcpc}, nil
}
type keepAliveConn struct {
*net.TCPConn
}
// SetKeepAlive sets keepalive
func (l *keepAliveConn) SetKeepAlive(doKeepAlive bool) error {
return l.TCPConn.SetKeepAlive(doKeepAlive)
}
// SetKeepAlivePeriod sets keepalive period
func (l *keepAliveConn) SetKeepAlivePeriod(d time.Duration) error {
return l.TCPConn.SetKeepAlivePeriod(d)
}
// A tlsKeepaliveListener implements a network listener (net.Listener) for TLS connections.
@@ -72,12 +105,6 @@ func (l *tlsKeepaliveListener) Accept() (c net.Conn, err error) {
if err != nil {
return
}
kac := c.(keepAliveConn)
// detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl
// default on linux: 30 + 8 * 30
// default on osx: 30 + 8 * 75
kac.SetKeepAlive(true)
kac.SetKeepAlivePeriod(30 * time.Second)
c = tls.Server(c, l.config)
return c, nil
}

View File

@@ -63,6 +63,9 @@ func (l *limitListenerConn) Close() error {
return err
}
// SetKeepAlive sets keepalive
//
// Deprecated: use (*keepAliveConn) SetKeepAlive instead.
func (l *limitListenerConn) SetKeepAlive(doKeepAlive bool) error {
tcpc, ok := l.Conn.(*net.TCPConn)
if !ok {
@@ -71,6 +74,9 @@ func (l *limitListenerConn) SetKeepAlive(doKeepAlive bool) error {
return tcpc.SetKeepAlive(doKeepAlive)
}
// SetKeepAlivePeriod sets keepalive period
//
// Deprecated: use (*keepAliveConn) SetKeepAlivePeriod instead.
func (l *limitListenerConn) SetKeepAlivePeriod(d time.Duration) error {
tcpc, ok := l.Conn.(*net.TCPConn)
if !ok {

View File

@@ -68,7 +68,7 @@ func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, err
fallthrough
case lnOpts.IsTimeout(), lnOpts.IsSocketOpts():
// timeout listener with socket options.
ln, err := lnOpts.ListenConfig.Listen(context.TODO(), "tcp", addr)
ln, err := newKeepAliveListener(&lnOpts.ListenConfig, addr)
if err != nil {
return nil, err
}
@@ -78,7 +78,7 @@ func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, err
writeTimeout: lnOpts.writeTimeout,
}
case lnOpts.IsTimeout():
ln, err := net.Listen("tcp", addr)
ln, err := newKeepAliveListener(nil, addr)
if err != nil {
return nil, err
}
@@ -88,7 +88,7 @@ func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, err
writeTimeout: lnOpts.writeTimeout,
}
default:
ln, err := net.Listen("tcp", addr)
ln, err := newKeepAliveListener(nil, addr)
if err != nil {
return nil, err
}
@@ -102,6 +102,19 @@ func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, err
return wrapTLS(scheme, lnOpts.tlsInfo, lnOpts.Listener)
}
func newKeepAliveListener(cfg *net.ListenConfig, addr string) (ln net.Listener, err error) {
if cfg != nil {
ln, err = cfg.Listen(context.TODO(), "tcp", addr)
} else {
ln, err = net.Listen("tcp", addr)
}
if err != nil {
return
}
return NewKeepAliveListener(ln, "tcp", nil)
}
func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
if scheme != "https" && scheme != "unixs" {
return l, nil

View File

@@ -286,8 +286,7 @@ func (c *Client) dial(creds grpccredentials.TransportCredentials, dopts ...grpc.
if err != nil {
return nil, fmt.Errorf("failed to configure dialer: %v", err)
}
if c.Username != "" && c.Password != "" {
c.authTokenBundle = credentials.NewBundle(credentials.Config{})
if c.authTokenBundle != nil {
opts = append(opts, grpc.WithPerRPCCredentials(c.authTokenBundle.PerRPCCredentials()))
}
@@ -383,6 +382,7 @@ func newClient(cfg *Config) (*Client, error) {
if cfg.Username != "" && cfg.Password != "" {
client.Username = cfg.Username
client.Password = cfg.Password
client.authTokenBundle = credentials.NewBundle(credentials.Config{})
}
if cfg.MaxCallSendMsgSize > 0 || cfg.MaxCallRecvMsgSize > 0 {
if cfg.MaxCallRecvMsgSize > 0 && cfg.MaxCallSendMsgSize > cfg.MaxCallRecvMsgSize {

View File

@@ -397,7 +397,7 @@ func (l *lessor) closeRequireLeader() {
}
}
func (l *lessor) keepAliveOnce(ctx context.Context, id LeaseID) (*LeaseKeepAliveResponse, error) {
func (l *lessor) keepAliveOnce(ctx context.Context, id LeaseID) (karesp *LeaseKeepAliveResponse, ferr error) {
cctx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -406,6 +406,15 @@ func (l *lessor) keepAliveOnce(ctx context.Context, id LeaseID) (*LeaseKeepAlive
return nil, toErr(ctx, err)
}
defer func() {
if err := stream.CloseSend(); err != nil {
if ferr == nil {
ferr = toErr(ctx, err)
}
return
}
}()
err = stream.Send(&pb.LeaseKeepAliveRequest{ID: int64(id)})
if err != nil {
return nil, toErr(ctx, err)
@@ -416,7 +425,7 @@ func (l *lessor) keepAliveOnce(ctx context.Context, id LeaseID) (*LeaseKeepAlive
return nil, toErr(ctx, rerr)
}
karesp := &LeaseKeepAliveResponse{
karesp = &LeaseKeepAliveResponse{
ResponseHeader: resp.GetHeader(),
ID: LeaseID(resp.ID),
TTL: resp.TTL,

View File

@@ -51,8 +51,8 @@ func etcdClientDebugLevel() zapcore.Level {
return zapcore.InfoLevel
}
var l zapcore.Level
if err := l.Set(envLevel); err == nil {
log.Printf("Deprecated env ETCD_CLIENT_DEBUG value. Using default level: 'info'")
if err := l.Set(envLevel); err != nil {
log.Printf("Invalid value for environment variable 'ETCD_CLIENT_DEBUG'. Using default level: 'info'")
return zapcore.InfoLevel
}
return l

View File

@@ -389,12 +389,12 @@ func getPrefix(key []byte) []byte {
// can return 'foo1', 'foo2', and so on.
func WithPrefix() OpOption {
return func(op *Op) {
op.isOptsWithPrefix = true
if len(op.key) == 0 {
op.key, op.end = []byte{0}, []byte{0}
return
}
op.end = getPrefix(op.key)
op.isOptsWithPrefix = true
}
}

45
vendor/go.etcd.io/etcd/pkg/v3/flags/uint32.go generated vendored Normal file
View File

@@ -0,0 +1,45 @@
// Copyright 2022 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package flags
import (
"flag"
"strconv"
)
type uint32Value uint32
// NewUint32Value creates an uint32 instance with the provided value.
func NewUint32Value(v uint32) *uint32Value {
val := new(uint32Value)
*val = uint32Value(v)
return val
}
// Set parses a command line uint32 value.
// Implements "flag.Value" interface.
func (i *uint32Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 32)
*i = uint32Value(v)
return err
}
func (i *uint32Value) String() string { return strconv.FormatUint(uint64(*i), 10) }
// Uint32FromFlag return the uint32 value of a flag with the given name
func Uint32FromFlag(fs *flag.FlagSet, name string) uint32 {
val := *fs.Lookup(name).Value.(*uint32Value)
return uint32(val)
}

View File

@@ -105,34 +105,48 @@ func checkKeyPoint(lg *zap.Logger, cachedPerms *unifiedRangePermissions, key []b
return false
}
func (as *authStore) isRangeOpPermitted(tx backend.ReadTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool {
// assumption: tx is Lock()ed
_, ok := as.rangePermCache[userName]
func (as *authStore) isRangeOpPermitted(userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool {
as.rangePermCacheMu.RLock()
defer as.rangePermCacheMu.RUnlock()
rangePerm, ok := as.rangePermCache[userName]
if !ok {
as.lg.Error(
"user doesn't exist",
zap.String("user-name", userName),
)
return false
}
if len(rangeEnd) == 0 {
return checkKeyPoint(as.lg, rangePerm, key, permtyp)
}
return checkKeyInterval(as.lg, rangePerm, key, rangeEnd, permtyp)
}
func (as *authStore) refreshRangePermCache(tx backend.ReadTx) {
// Note that every authentication configuration update calls this method and it invalidates the entire
// rangePermCache and reconstruct it based on information of users and roles stored in the backend.
// This can be a costly operation.
as.rangePermCacheMu.Lock()
defer as.rangePermCacheMu.Unlock()
as.rangePermCache = make(map[string]*unifiedRangePermissions)
users := getAllUsers(as.lg, tx)
for _, user := range users {
userName := string(user.Name)
perms := getMergedPerms(as.lg, tx, userName)
if perms == nil {
as.lg.Error(
"failed to create a merged permission",
zap.String("user-name", userName),
)
return false
continue
}
as.rangePermCache[userName] = perms
}
if len(rangeEnd) == 0 {
return checkKeyPoint(as.lg, as.rangePermCache[userName], key, permtyp)
}
return checkKeyInterval(as.lg, as.rangePermCache[userName], key, rangeEnd, permtyp)
}
func (as *authStore) clearCachedPerm() {
as.rangePermCache = make(map[string]*unifiedRangePermissions)
}
func (as *authStore) invalidateCachedPerm(userName string) {
delete(as.rangePermCache, userName)
}
type unifiedRangePermissions struct {

View File

@@ -156,6 +156,11 @@ func (t *tokenSimple) invalidateUser(username string) {
}
func (t *tokenSimple) enable() {
t.simpleTokensMu.Lock()
defer t.simpleTokensMu.Unlock()
if t.simpleTokenKeeper != nil { // already enabled
return
}
if t.simpleTokenTTL <= 0 {
t.simpleTokenTTL = simpleTokenTTLDefault
}

View File

@@ -208,7 +208,14 @@ type authStore struct {
enabled bool
enabledMu sync.RWMutex
rangePermCache map[string]*unifiedRangePermissions // username -> unifiedRangePermissions
// rangePermCache needs to be protected by rangePermCacheMu
// rangePermCacheMu needs to be write locked only in initialization phase or configuration changes
// Hot paths like Range(), needs to acquire read lock for improving performance
//
// Note that BatchTx and ReadTx cannot be a mutex for rangePermCache because they are independent resources
// see also: https://github.com/etcd-io/etcd/pull/13920#discussion_r849114855
rangePermCache map[string]*unifiedRangePermissions // username -> unifiedRangePermissions
rangePermCacheMu sync.RWMutex
tokenProvider TokenProvider
bcryptCost int // the algorithm cost / strength for hashing auth passwords
@@ -243,7 +250,7 @@ func (as *authStore) AuthEnable() error {
as.enabled = true
as.tokenProvider.enable()
as.rangePermCache = make(map[string]*unifiedRangePermissions)
as.refreshRangePermCache(tx)
as.setRevision(getRevision(tx))
@@ -368,6 +375,9 @@ func (as *authStore) Recover(be backend.Backend) {
as.enabledMu.Lock()
as.enabled = enabled
if enabled {
as.tokenProvider.enable()
}
as.enabledMu.Unlock()
}
@@ -419,6 +429,7 @@ func (as *authStore) UserAdd(r *pb.AuthUserAddRequest) (*pb.AuthUserAddResponse,
putUser(as.lg, tx, newUser)
as.commitRevision(tx)
as.refreshRangePermCache(tx)
as.lg.Info("added a user", zap.String("user-name", r.Name))
return &pb.AuthUserAddResponse{}, nil
@@ -442,8 +453,8 @@ func (as *authStore) UserDelete(r *pb.AuthUserDeleteRequest) (*pb.AuthUserDelete
delUser(tx, r.Name)
as.commitRevision(tx)
as.refreshRangePermCache(tx)
as.invalidateCachedPerm(r.Name)
as.tokenProvider.invalidateUser(r.Name)
as.lg.Info(
@@ -484,8 +495,8 @@ func (as *authStore) UserChangePassword(r *pb.AuthUserChangePasswordRequest) (*p
putUser(as.lg, tx, updatedUser)
as.commitRevision(tx)
as.refreshRangePermCache(tx)
as.invalidateCachedPerm(r.Name)
as.tokenProvider.invalidateUser(r.Name)
as.lg.Info(
@@ -529,9 +540,8 @@ func (as *authStore) UserGrantRole(r *pb.AuthUserGrantRoleRequest) (*pb.AuthUser
putUser(as.lg, tx, user)
as.invalidateCachedPerm(r.User)
as.commitRevision(tx)
as.refreshRangePermCache(tx)
as.lg.Info(
"granted a role to a user",
@@ -607,9 +617,8 @@ func (as *authStore) UserRevokeRole(r *pb.AuthUserRevokeRoleRequest) (*pb.AuthUs
putUser(as.lg, tx, updatedUser)
as.invalidateCachedPerm(r.Name)
as.commitRevision(tx)
as.refreshRangePermCache(tx)
as.lg.Info(
"revoked a role from a user",
@@ -675,11 +684,8 @@ func (as *authStore) RoleRevokePermission(r *pb.AuthRoleRevokePermissionRequest)
putRole(as.lg, tx, updatedRole)
// TODO(mitake): currently single role update invalidates every cache
// It should be optimized.
as.clearCachedPerm()
as.commitRevision(tx)
as.refreshRangePermCache(tx)
as.lg.Info(
"revoked a permission on range",
@@ -727,10 +733,10 @@ func (as *authStore) RoleDelete(r *pb.AuthRoleDeleteRequest) (*pb.AuthRoleDelete
putUser(as.lg, tx, updatedUser)
as.invalidateCachedPerm(string(user.Name))
}
as.commitRevision(tx)
as.refreshRangePermCache(tx)
as.lg.Info("deleted a role", zap.String("role-name", r.Role))
return &pb.AuthRoleDeleteResponse{}, nil
@@ -815,11 +821,8 @@ func (as *authStore) RoleGrantPermission(r *pb.AuthRoleGrantPermissionRequest) (
putRole(as.lg, tx, role)
// TODO(mitake): currently single role update invalidates every cache
// It should be optimized.
as.clearCachedPerm()
as.commitRevision(tx)
as.refreshRangePermCache(tx)
as.lg.Info(
"granted/updated a permission to a user",
@@ -864,7 +867,7 @@ func (as *authStore) isOpPermitted(userName string, revision uint64, key, rangeE
return nil
}
if as.isRangeOpPermitted(tx, userName, key, rangeEnd, permTyp) {
if as.isRangeOpPermitted(userName, key, rangeEnd, permTyp) {
return nil
}
@@ -926,7 +929,15 @@ func getUser(lg *zap.Logger, tx backend.ReadTx, username string) *authpb.User {
}
func getAllUsers(lg *zap.Logger, tx backend.ReadTx) []*authpb.User {
_, vs := tx.UnsafeRange(buckets.AuthUsers, []byte{0}, []byte{0xff}, -1)
var vs [][]byte
err := tx.UnsafeForEach(buckets.AuthUsers, func(k []byte, v []byte) error {
vs = append(vs, v)
return nil
})
if err != nil {
lg.Panic("failed to get users",
zap.Error(err))
}
if len(vs) == 0 {
return nil
}
@@ -1062,6 +1073,8 @@ func NewAuthStore(lg *zap.Logger, be backend.Backend, tp TokenProvider, bcryptCo
as.setupMetricsReporter()
as.refreshRangePermCache(tx)
tx.Unlock()
be.ForceCommit()

View File

@@ -120,6 +120,10 @@ type ServerConfig struct {
// MaxRequestBytes is the maximum request size to send over raft.
MaxRequestBytes uint
// MaxConcurrentStreams specifies the maximum number of concurrent
// streams that each client can open at a time.
MaxConcurrentStreams uint32
WarningApplyDuration time.Duration
StrictReconfigCheck bool
@@ -133,8 +137,10 @@ type ServerConfig struct {
// InitialCorruptCheck is true to check data corruption on boot
// before serving any peer/client traffic.
InitialCorruptCheck bool
CorruptCheckTime time.Duration
InitialCorruptCheck bool
CorruptCheckTime time.Duration
CompactHashCheckEnabled bool
CompactHashCheckTime time.Duration
// PreVote is true to enable Raft Pre-Vote.
PreVote bool

View File

@@ -17,6 +17,7 @@ package embed
import (
"fmt"
"io/ioutil"
"math"
"net"
"net/http"
"net/url"
@@ -55,6 +56,7 @@ const (
DefaultMaxTxnOps = uint(128)
DefaultWarningApplyDuration = 100 * time.Millisecond
DefaultMaxRequestBytes = 1.5 * 1024 * 1024
DefaultMaxConcurrentStreams = math.MaxUint32
DefaultGRPCKeepAliveMinTime = 5 * time.Second
DefaultGRPCKeepAliveInterval = 2 * time.Hour
DefaultGRPCKeepAliveTimeout = 20 * time.Second
@@ -198,6 +200,10 @@ type Config struct {
MaxTxnOps uint `json:"max-txn-ops"`
MaxRequestBytes uint `json:"max-request-bytes"`
// MaxConcurrentStreams specifies the maximum number of concurrent
// streams that each client can open at a time.
MaxConcurrentStreams uint32 `json:"max-concurrent-streams"`
LPUrls, LCUrls []url.URL
APUrls, ACUrls []url.URL
ClientTLSInfo transport.TLSInfo
@@ -305,11 +311,13 @@ type Config struct {
AuthToken string `json:"auth-token"`
BcryptCost uint `json:"bcrypt-cost"`
//The AuthTokenTTL in seconds of the simple token
// AuthTokenTTL specifies the TTL in seconds of the simple token
AuthTokenTTL uint `json:"auth-token-ttl"`
ExperimentalInitialCorruptCheck bool `json:"experimental-initial-corrupt-check"`
ExperimentalCorruptCheckTime time.Duration `json:"experimental-corrupt-check-time"`
ExperimentalInitialCorruptCheck bool `json:"experimental-initial-corrupt-check"`
ExperimentalCorruptCheckTime time.Duration `json:"experimental-corrupt-check-time"`
ExperimentalCompactHashCheckEnabled bool `json:"experimental-compact-hash-check-enabled"`
ExperimentalCompactHashCheckTime time.Duration `json:"experimental-compact-hash-check-time"`
// ExperimentalEnableV2V3 configures URLs that expose deprecated V2 API working on V3 store.
// Deprecated in v3.5.
// TODO: Delete in v3.6 (https://github.com/etcd-io/etcd/issues/12913)
@@ -448,6 +456,7 @@ func NewConfig() *Config {
MaxTxnOps: DefaultMaxTxnOps,
MaxRequestBytes: DefaultMaxRequestBytes,
MaxConcurrentStreams: DefaultMaxConcurrentStreams,
ExperimentalWarningApplyDuration: DefaultWarningApplyDuration,
GRPCKeepAliveMinTime: DefaultGRPCKeepAliveMinTime,
@@ -494,6 +503,9 @@ func NewConfig() *Config {
ExperimentalMemoryMlock: false,
ExperimentalTxnModeWriteWithSharedBuffer: true,
ExperimentalCompactHashCheckEnabled: false,
ExperimentalCompactHashCheckTime: time.Minute,
V2Deprecation: config.V2_DEPR_DEFAULT,
}
cfg.InitialCluster = cfg.InitialClusterFromName(cfg.Name)
@@ -691,6 +703,10 @@ func (cfg *Config) Validate() error {
return fmt.Errorf("setting experimental-enable-lease-checkpoint-persist requires experimental-enable-lease-checkpoint")
}
if cfg.ExperimentalCompactHashCheckTime <= 0 {
return fmt.Errorf("--experimental-compact-hash-check-time must be >0 (set to %v)", cfg.ExperimentalCompactHashCheckTime)
}
return nil
}

View File

@@ -0,0 +1,92 @@
// Copyright 2022 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package embed
import (
"context"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
tracesdk "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"go.uber.org/zap"
)
func setupTracingExporter(ctx context.Context, cfg *Config) (exporter tracesdk.SpanExporter, options []otelgrpc.Option, err error) {
exporter, err = otlptracegrpc.New(ctx,
otlptracegrpc.WithInsecure(),
otlptracegrpc.WithEndpoint(cfg.ExperimentalDistributedTracingAddress),
)
if err != nil {
return nil, nil, err
}
res, err := resource.New(ctx,
resource.WithAttributes(
semconv.ServiceNameKey.String(cfg.ExperimentalDistributedTracingServiceName),
),
)
if err != nil {
return nil, nil, err
}
if resWithIDKey := determineResourceWithIDKey(cfg.ExperimentalDistributedTracingServiceInstanceID); resWithIDKey != nil {
// Merge resources into a new
// resource in case of duplicates.
res, err = resource.Merge(res, resWithIDKey)
if err != nil {
return nil, nil, err
}
}
options = append(options,
otelgrpc.WithPropagators(
propagation.NewCompositeTextMapPropagator(
propagation.TraceContext{},
propagation.Baggage{},
),
),
otelgrpc.WithTracerProvider(
tracesdk.NewTracerProvider(
tracesdk.WithBatcher(exporter),
tracesdk.WithResource(res),
tracesdk.WithSampler(tracesdk.ParentBased(tracesdk.NeverSample())),
),
),
)
cfg.logger.Debug(
"distributed tracing enabled",
zap.String("address", cfg.ExperimentalDistributedTracingAddress),
zap.String("service-name", cfg.ExperimentalDistributedTracingServiceName),
zap.String("service-instance-id", cfg.ExperimentalDistributedTracingServiceInstanceID),
)
return exporter, options, err
}
// As Tracing service Instance ID must be unique, it should
// never use the empty default string value, it's set if
// if it's a non empty string.
func determineResourceWithIDKey(serviceInstanceID string) *resource.Resource {
if serviceInstanceID != "" {
return resource.NewSchemaless(
(semconv.ServiceInstanceIDKey.String(serviceInstanceID)),
)
}
return nil
}

View File

@@ -46,13 +46,6 @@ import (
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/soheilhy/cmux"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/otel/exporters/otlp"
"go.opentelemetry.io/otel/exporters/otlp/otlpgrpc"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
tracesdk "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/semconv"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@@ -199,6 +192,7 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
BackendBatchInterval: cfg.BackendBatchInterval,
MaxTxnOps: cfg.MaxTxnOps,
MaxRequestBytes: cfg.MaxRequestBytes,
MaxConcurrentStreams: cfg.MaxConcurrentStreams,
SocketOpts: cfg.SocketOpts,
StrictReconfigCheck: cfg.StrictReconfigCheck,
ClientCertAuthEnabled: cfg.ClientTLSInfo.ClientCertAuth,
@@ -209,6 +203,8 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
HostWhitelist: cfg.HostWhitelist,
InitialCorruptCheck: cfg.ExperimentalInitialCorruptCheck,
CorruptCheckTime: cfg.ExperimentalCorruptCheckTime,
CompactHashCheckEnabled: cfg.ExperimentalCompactHashCheckEnabled,
CompactHashCheckTime: cfg.ExperimentalCompactHashCheckTime,
PreVote: cfg.PreVote,
Logger: cfg.logger,
ForceNewCluster: cfg.ForceNewCluster,
@@ -229,7 +225,7 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
if srvcfg.ExperimentalEnableDistributedTracing {
tctx := context.Background()
tracingExporter, opts, err := e.setupTracing(tctx)
tracingExporter, opts, err := setupTracingExporter(tctx, cfg)
if err != nil {
return e, err
}
@@ -238,6 +234,8 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
}
e.tracingExporterShutdown = func() { tracingExporter.Shutdown(tctx) }
srvcfg.ExperimentalTracerOptions = opts
e.cfg.logger.Info("distributed tracing setup enabled")
}
print(e.cfg.logger, *cfg, srvcfg, memberInitialized)
@@ -251,8 +249,8 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
// newly started member ("memberInitialized==false")
// does not need corruption check
if memberInitialized {
if err = e.Server.CheckInitialHashKV(); err != nil {
if memberInitialized && srvcfg.InitialCorruptCheck {
if err = e.Server.CorruptionChecker().InitialCheck(); err != nil {
// set "EtcdServer" to nil, so that it does not block on "EtcdServer.Close()"
// (nothing to close since rafthttp transports have not been started)
@@ -336,10 +334,15 @@ func print(lg *zap.Logger, ec Config, sc config.ServerConfig, memberInitialized
zap.String("initial-cluster", sc.InitialPeerURLsMap.String()),
zap.String("initial-cluster-state", ec.ClusterState),
zap.String("initial-cluster-token", sc.InitialClusterToken),
zap.Int64("quota-size-bytes", quota),
zap.Int64("quota-backend-bytes", quota),
zap.Uint("max-request-bytes", sc.MaxRequestBytes),
zap.Uint32("max-concurrent-streams", sc.MaxConcurrentStreams),
zap.Bool("pre-vote", sc.PreVote),
zap.Bool("initial-corrupt-check", sc.InitialCorruptCheck),
zap.String("corrupt-check-time-interval", sc.CorruptCheckTime.String()),
zap.Bool("compact-check-time-enabled", sc.CompactHashCheckEnabled),
zap.Duration("compact-check-time-interval", sc.CompactHashCheckTime),
zap.String("auto-compaction-mode", sc.AutoCompactionMode),
zap.Duration("auto-compaction-retention", sc.AutoCompactionRetention),
zap.String("auto-compaction-interval", sc.AutoCompactionRetention.String()),
@@ -651,12 +654,6 @@ func configureClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err erro
sctx.l = transport.LimitListener(sctx.l, int(fdLimit-reservedInternalFDNum))
}
if network == "tcp" {
if sctx.l, err = transport.NewKeepAliveListener(sctx.l, network, nil); err != nil {
return nil, err
}
}
defer func(u url.URL) {
if err == nil {
return
@@ -809,52 +806,3 @@ func parseCompactionRetention(mode, retention string) (ret time.Duration, err er
}
return ret, nil
}
func (e *Etcd) setupTracing(ctx context.Context) (exporter tracesdk.SpanExporter, options []otelgrpc.Option, err error) {
exporter, err = otlp.NewExporter(ctx,
otlpgrpc.NewDriver(
otlpgrpc.WithEndpoint(e.cfg.ExperimentalDistributedTracingAddress),
otlpgrpc.WithInsecure(),
))
if err != nil {
return nil, nil, err
}
res := resource.NewWithAttributes(
semconv.ServiceNameKey.String(e.cfg.ExperimentalDistributedTracingServiceName),
)
// As Tracing service Instance ID must be unique, it should
// never use the empty default string value, so we only set it
// if it's a non empty string.
if e.cfg.ExperimentalDistributedTracingServiceInstanceID != "" {
resWithIDKey := resource.NewWithAttributes(
(semconv.ServiceInstanceIDKey.String(e.cfg.ExperimentalDistributedTracingServiceInstanceID)),
)
// Merge resources to combine into a new
// resource in case of duplicates.
res = resource.Merge(res, resWithIDKey)
}
options = append(options,
otelgrpc.WithPropagators(
propagation.NewCompositeTextMapPropagator(
propagation.TraceContext{},
propagation.Baggage{},
),
),
otelgrpc.WithTracerProvider(
tracesdk.NewTracerProvider(
tracesdk.WithBatcher(exporter),
tracesdk.WithResource(res),
),
),
)
e.cfg.logger.Info(
"distributed tracing enabled",
zap.String("distributed-tracing-address", e.cfg.ExperimentalDistributedTracingAddress),
zap.String("distributed-tracing-service-name", e.cfg.ExperimentalDistributedTracingServiceName),
zap.String("distributed-tracing-service-instance-id", e.cfg.ExperimentalDistributedTracingServiceInstanceID),
)
return exporter, options, err
}

View File

@@ -29,6 +29,7 @@ import (
"go.etcd.io/etcd/client/v3/credentials"
"go.etcd.io/etcd/pkg/v3/debugutil"
"go.etcd.io/etcd/pkg/v3/httputil"
"go.etcd.io/etcd/server/v3/config"
"go.etcd.io/etcd/server/v3/etcdserver"
"go.etcd.io/etcd/server/v3/etcdserver/api/v3client"
"go.etcd.io/etcd/server/v3/etcdserver/api/v3election"
@@ -43,6 +44,7 @@ import (
"github.com/soheilhy/cmux"
"github.com/tmc/grpc-websocket-proxy/wsproxy"
"go.uber.org/zap"
"golang.org/x/net/http2"
"golang.org/x/net/trace"
"google.golang.org/grpc"
)
@@ -133,6 +135,10 @@ func (sctx *serveCtx) serve(
Handler: createAccessController(sctx.lg, s, httpmux),
ErrorLog: logger, // do not log user error
}
if err := configureHttpServer(srvhttp, s.Cfg); err != nil {
sctx.lg.Error("Configure http server failed", zap.Error(err))
return err
}
httpl := m.Match(cmux.HTTP1())
go func() { errHandler(srvhttp.Serve(httpl)) }()
@@ -182,6 +188,10 @@ func (sctx *serveCtx) serve(
TLSConfig: tlscfg,
ErrorLog: logger, // do not log user error
}
if err := configureHttpServer(srv, s.Cfg); err != nil {
sctx.lg.Error("Configure https server failed", zap.Error(err))
return err
}
go func() { errHandler(srv.Serve(tlsl)) }()
sctx.serversC <- &servers{secure: true, grpc: gs, http: srv}
@@ -195,6 +205,13 @@ func (sctx *serveCtx) serve(
return m.Serve()
}
func configureHttpServer(srv *http.Server, cfg config.ServerConfig) error {
// todo (ahrtr): should we support configuring other parameters in the future as well?
return http2.ConfigureServer(srv, &http2.Server{
MaxConcurrentStreams: cfg.MaxConcurrentStreams,
})
}
// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC
// connections or otherHandler otherwise. Given in gRPC docs.
func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler {

View File

@@ -32,7 +32,6 @@ import (
const (
grpcOverheadBytes = 512 * 1024
maxStreams = math.MaxUint32
maxSendBytes = math.MaxInt32
)
@@ -68,7 +67,7 @@ func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnarySer
opts = append(opts, grpc.MaxRecvMsgSize(int(s.Cfg.MaxRequestBytes+grpcOverheadBytes)))
opts = append(opts, grpc.MaxSendMsgSize(maxSendBytes))
opts = append(opts, grpc.MaxConcurrentStreams(maxStreams))
opts = append(opts, grpc.MaxConcurrentStreams(s.Cfg.MaxConcurrentStreams))
grpcServer := grpc.NewServer(append(opts, gopts...)...)

View File

@@ -66,19 +66,20 @@ type ClusterStatusGetter interface {
}
type maintenanceServer struct {
lg *zap.Logger
rg etcdserver.RaftStatusGetter
kg KVGetter
bg BackendGetter
a Alarmer
lt LeaderTransferrer
hdr header
cs ClusterStatusGetter
d Downgrader
lg *zap.Logger
rg etcdserver.RaftStatusGetter
hasher mvcc.HashStorage
kg KVGetter
bg BackendGetter
a Alarmer
lt LeaderTransferrer
hdr header
cs ClusterStatusGetter
d Downgrader
}
func NewMaintenanceServer(s *etcdserver.EtcdServer) pb.MaintenanceServer {
srv := &maintenanceServer{lg: s.Cfg.Logger, rg: s, kg: s, bg: s, a: s, lt: s, hdr: newHeader(s), cs: s, d: s}
srv := &maintenanceServer{lg: s.Cfg.Logger, rg: s, hasher: s.KV().HashStorage(), kg: s, bg: s, a: s, lt: s, hdr: newHeader(s), cs: s, d: s}
if srv.lg == nil {
srv.lg = zap.NewNop()
}
@@ -180,7 +181,7 @@ func (ms *maintenanceServer) Snapshot(sr *pb.SnapshotRequest, srv pb.Maintenance
}
func (ms *maintenanceServer) Hash(ctx context.Context, r *pb.HashRequest) (*pb.HashResponse, error) {
h, rev, err := ms.kg.KV().Hash()
h, rev, err := ms.hasher.Hash()
if err != nil {
return nil, togRPCError(err)
}
@@ -190,12 +191,12 @@ func (ms *maintenanceServer) Hash(ctx context.Context, r *pb.HashRequest) (*pb.H
}
func (ms *maintenanceServer) HashKV(ctx context.Context, r *pb.HashKVRequest) (*pb.HashKVResponse, error) {
h, rev, compactRev, err := ms.kg.KV().HashByRev(r.Revision)
h, rev, err := ms.hasher.HashByRev(r.Revision)
if err != nil {
return nil, togRPCError(err)
}
resp := &pb.HashKVResponse{Header: &pb.ResponseHeader{Revision: rev}, Hash: h, CompactRevision: compactRev}
resp := &pb.HashKVResponse{Header: &pb.ResponseHeader{Revision: rev}, Hash: h.Hash, CompactRevision: h.CompactRevision}
ms.hdr.fill(resp.Header)
return resp, nil
}

View File

@@ -428,6 +428,7 @@ func (a *applierV3backend) Range(ctx context.Context, txn mvcc.TxnRead, r *pb.Ra
}
func (a *applierV3backend) Txn(ctx context.Context, rt *pb.TxnRequest) (*pb.TxnResponse, *traceutil.Trace, error) {
lg := a.s.Logger()
trace := traceutil.Get(ctx)
if trace.IsEmpty() {
trace = traceutil.New("transaction", a.s.Logger())
@@ -474,7 +475,18 @@ func (a *applierV3backend) Txn(ctx context.Context, rt *pb.TxnRequest) (*pb.TxnR
txn.End()
txn = a.s.KV().Write(trace)
}
a.applyTxn(ctx, txn, rt, txnPath, txnResp)
_, err := a.applyTxn(ctx, txn, rt, txnPath, txnResp)
if err != nil {
if isWrite {
// end txn to release locks before panic
txn.End()
// When txn with write operations starts it has to be successful
// We don't have a way to recover state in case of write failure
lg.Panic("unexpected error during txn with writes", zap.Error(err))
} else {
lg.Error("unexpected error during readonly txn", zap.Error(err))
}
}
rev := txn.Rev()
if len(txn.Changes()) != 0 {
rev++
@@ -486,7 +498,7 @@ func (a *applierV3backend) Txn(ctx context.Context, rt *pb.TxnRequest) (*pb.TxnR
traceutil.Field{Key: "number_of_response", Value: len(txnResp.Responses)},
traceutil.Field{Key: "response_revision", Value: txnResp.Header.Revision},
)
return txnResp, trace, nil
return txnResp, trace, err
}
// newTxnResp allocates a txn response for a txn request given a path.
@@ -617,14 +629,13 @@ func compareKV(c *pb.Compare, ckv mvccpb.KeyValue) bool {
return true
}
func (a *applierV3backend) applyTxn(ctx context.Context, txn mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int) {
func (a *applierV3backend) applyTxn(ctx context.Context, txn mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int, err error) {
trace := traceutil.Get(ctx)
reqs := rt.Success
if !txnPath[0] {
reqs = rt.Failure
}
lg := a.s.Logger()
for i, req := range reqs {
respi := tresp.Responses[i].Response
switch tv := req.Request.(type) {
@@ -635,7 +646,7 @@ func (a *applierV3backend) applyTxn(ctx context.Context, txn mvcc.TxnWrite, rt *
traceutil.Field{Key: "range_end", Value: string(tv.RequestRange.RangeEnd)})
resp, err := a.Range(ctx, txn, tv.RequestRange)
if err != nil {
lg.Panic("unexpected error during txn", zap.Error(err))
return 0, fmt.Errorf("applyTxn: failed Range: %w", err)
}
respi.(*pb.ResponseOp_ResponseRange).ResponseRange = resp
trace.StopSubTrace()
@@ -646,26 +657,30 @@ func (a *applierV3backend) applyTxn(ctx context.Context, txn mvcc.TxnWrite, rt *
traceutil.Field{Key: "req_size", Value: tv.RequestPut.Size()})
resp, _, err := a.Put(ctx, txn, tv.RequestPut)
if err != nil {
lg.Panic("unexpected error during txn", zap.Error(err))
return 0, fmt.Errorf("applyTxn: failed Put: %w", err)
}
respi.(*pb.ResponseOp_ResponsePut).ResponsePut = resp
trace.StopSubTrace()
case *pb.RequestOp_RequestDeleteRange:
resp, err := a.DeleteRange(txn, tv.RequestDeleteRange)
if err != nil {
lg.Panic("unexpected error during txn", zap.Error(err))
return 0, fmt.Errorf("applyTxn: failed DeleteRange: %w", err)
}
respi.(*pb.ResponseOp_ResponseDeleteRange).ResponseDeleteRange = resp
case *pb.RequestOp_RequestTxn:
resp := respi.(*pb.ResponseOp_ResponseTxn).ResponseTxn
applyTxns := a.applyTxn(ctx, txn, tv.RequestTxn, txnPath[1:], resp)
applyTxns, err := a.applyTxn(ctx, txn, tv.RequestTxn, txnPath[1:], resp)
if err != nil {
// don't wrap the error. It's a recursive call and err should be already wrapped
return 0, err
}
txns += applyTxns + 1
txnPath = txnPath[applyTxns+1:]
default:
// empty union
}
}
return txns
return txns, nil
}
func (a *applierV3backend) Compaction(compaction *pb.CompactionRequest) (*pb.CompactionResponse, <-chan struct{}, *traceutil.Trace, error) {

Some files were not shown because too many files have changed in this diff Show More