Merge pull request #87253 from SaranBalaji90/update-aws-sdk

Update aws-sdk-go dependency to v1.28.2
This commit is contained in:
Kubernetes Prow Robot
2020-01-20 09:05:36 -08:00
committed by GitHub
137 changed files with 43354 additions and 5454 deletions

4
go.mod
View File

@@ -20,7 +20,7 @@ require (
github.com/Rican7/retry v0.1.0 // indirect github.com/Rican7/retry v0.1.0 // indirect
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e
github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7 // indirect github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7 // indirect
github.com/aws/aws-sdk-go v1.16.26 github.com/aws/aws-sdk-go v1.28.2
github.com/bazelbuild/bazel-gazelle v0.19.1-0.20191105222053-70208cbdc798 github.com/bazelbuild/bazel-gazelle v0.19.1-0.20191105222053-70208cbdc798
github.com/bazelbuild/buildtools v0.0.0-20190917191645-69366ca98f89 github.com/bazelbuild/buildtools v0.0.0-20190917191645-69366ca98f89
github.com/blang/semver v3.5.0+incompatible github.com/blang/semver v3.5.0+incompatible
@@ -204,7 +204,7 @@ replace (
github.com/armon/consul-api => github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6 github.com/armon/consul-api => github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6
github.com/asaskevich/govalidator => github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a github.com/asaskevich/govalidator => github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a
github.com/auth0/go-jwt-middleware => github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7 github.com/auth0/go-jwt-middleware => github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7
github.com/aws/aws-sdk-go => github.com/aws/aws-sdk-go v1.16.26 github.com/aws/aws-sdk-go => github.com/aws/aws-sdk-go v1.28.2
github.com/bazelbuild/bazel-gazelle => github.com/bazelbuild/bazel-gazelle v0.19.1-0.20191105222053-70208cbdc798 github.com/bazelbuild/bazel-gazelle => github.com/bazelbuild/bazel-gazelle v0.19.1-0.20191105222053-70208cbdc798
github.com/bazelbuild/buildtools => github.com/bazelbuild/buildtools v0.0.0-20190917191645-69366ca98f89 github.com/bazelbuild/buildtools => github.com/bazelbuild/buildtools v0.0.0-20190917191645-69366ca98f89
github.com/bazelbuild/rules_go => github.com/bazelbuild/rules_go v0.0.0-20190719190356-6dae44dc5cab github.com/bazelbuild/rules_go => github.com/bazelbuild/rules_go v0.0.0-20190719190356-6dae44dc5cab

4
go.sum
View File

@@ -58,8 +58,8 @@ github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7 h1:irR1cO6eek3n5uquIVaRAsQmZnlsfPuHNz31cXo4eyk= github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7 h1:irR1cO6eek3n5uquIVaRAsQmZnlsfPuHNz31cXo4eyk=
github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7/go.mod h1:LWMyo4iOLWXHGdBki7NIht1kHru/0wM179h+d3g8ATM= github.com/auth0/go-jwt-middleware v0.0.0-20170425171159-5493cabe49f7/go.mod h1:LWMyo4iOLWXHGdBki7NIht1kHru/0wM179h+d3g8ATM=
github.com/aws/aws-sdk-go v1.16.26 h1:GWkl3rkRO/JGRTWoLLIqwf7AWC4/W/1hMOUZqmX0js4= github.com/aws/aws-sdk-go v1.28.2 h1:j5IXG9CdyLfcVfICqo1PXVv+rua+QQHbkXuvuU/JF+8=
github.com/aws/aws-sdk-go v1.16.26/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.28.2/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/bazelbuild/bazel-gazelle v0.19.1-0.20191105222053-70208cbdc798 h1:1pcd3bq1rC5X4nz56fdK0oQrz9CZup6DGZl8isCPthk= github.com/bazelbuild/bazel-gazelle v0.19.1-0.20191105222053-70208cbdc798 h1:1pcd3bq1rC5X4nz56fdK0oQrz9CZup6DGZl8isCPthk=
github.com/bazelbuild/bazel-gazelle v0.19.1-0.20191105222053-70208cbdc798/go.mod h1:rPwzNHUqEzngx1iVBfO/2X2npKaT3tqPqqHW6rVsn/A= github.com/bazelbuild/bazel-gazelle v0.19.1-0.20191105222053-70208cbdc798/go.mod h1:rPwzNHUqEzngx1iVBfO/2X2npKaT3tqPqqHW6rVsn/A=
github.com/bazelbuild/buildtools v0.0.0-20190917191645-69366ca98f89 h1:3B/ZE1a6eEJ/4Jf/M6RM2KBouN8yKCUcMmXzSyWqa3g= github.com/bazelbuild/buildtools v0.0.0-20190917191645-69366ca98f89 h1:3B/ZE1a6eEJ/4Jf/M6RM2KBouN8yKCUcMmXzSyWqa3g=

View File

@@ -13,7 +13,7 @@ require (
github.com/Azure/go-autorest/autorest/to v0.2.0 github.com/Azure/go-autorest/autorest/to v0.2.0
github.com/Azure/go-autorest/autorest/validation v0.1.0 // indirect github.com/Azure/go-autorest/autorest/validation v0.1.0 // indirect
github.com/GoogleCloudPlatform/k8s-cloud-provider v0.0.0-20190822182118-27a4ced34534 github.com/GoogleCloudPlatform/k8s-cloud-provider v0.0.0-20190822182118-27a4ced34534
github.com/aws/aws-sdk-go v1.16.26 github.com/aws/aws-sdk-go v1.28.2
github.com/dnaeon/go-vcr v1.0.1 // indirect github.com/dnaeon/go-vcr v1.0.1 // indirect
github.com/golang/mock v1.3.1 github.com/golang/mock v1.3.1
github.com/gophercloud/gophercloud v0.1.0 github.com/gophercloud/gophercloud v0.1.0

View File

@@ -33,8 +33,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdko
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/aws/aws-sdk-go v1.16.26 h1:GWkl3rkRO/JGRTWoLLIqwf7AWC4/W/1hMOUZqmX0js4= github.com/aws/aws-sdk-go v1.28.2 h1:j5IXG9CdyLfcVfICqo1PXVv+rua+QQHbkXuvuU/JF+8=
github.com/aws/aws-sdk-go v1.16.26/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.28.2/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=

2
vendor/BUILD vendored
View File

@@ -35,9 +35,11 @@ filegroup(
"//vendor/github.com/aws/aws-sdk-go/aws:all-srcs", "//vendor/github.com/aws/aws-sdk-go/aws:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/internal/ini:all-srcs", "//vendor/github.com/aws/aws-sdk-go/internal/ini:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkio:all-srcs", "//vendor/github.com/aws/aws-sdk-go/internal/sdkio:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkmath:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkrand:all-srcs", "//vendor/github.com/aws/aws-sdk-go/internal/sdkrand:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkuri:all-srcs", "//vendor/github.com/aws/aws-sdk-go/internal/sdkuri:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/internal/shareddefaults:all-srcs", "//vendor/github.com/aws/aws-sdk-go/internal/shareddefaults:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/internal/strings:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/private/protocol:all-srcs", "//vendor/github.com/aws/aws-sdk-go/private/protocol:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/service/autoscaling:all-srcs", "//vendor/github.com/aws/aws-sdk-go/service/autoscaling:all-srcs",
"//vendor/github.com/aws/aws-sdk-go/service/ec2:all-srcs", "//vendor/github.com/aws/aws-sdk-go/service/ec2:all-srcs",

View File

@@ -138,8 +138,27 @@ type RequestFailure interface {
RequestID() string RequestID() string
} }
// NewRequestFailure returns a new request error wrapper for the given Error // NewRequestFailure returns a wrapped error with additional information for
// provided. // request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure { func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
return newRequestError(err, statusCode, reqID) return newRequestError(err, statusCode, reqID)
} }
// UnmarshalError provides the interface for the SDK failing to unmarshal data.
type UnmarshalError interface {
awsError
Bytes() []byte
}
// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
// the bytes that fail to unmarshal to the error.
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
return &unmarshalError{
awsError: New("UnmarshalError", msg, err),
bytes: bytes,
}
}

View File

@@ -1,6 +1,9 @@
package awserr package awserr
import "fmt" import (
"encoding/hex"
"fmt"
)
// SprintError returns a string of the formatted error code. // SprintError returns a string of the formatted error code.
// //
@@ -119,6 +122,7 @@ type requestError struct {
awsError awsError
statusCode int statusCode int
requestID string requestID string
bytes []byte
} }
// newRequestError returns a wrapped error with additional information for // newRequestError returns a wrapped error with additional information for
@@ -170,6 +174,29 @@ func (r requestError) OrigErrs() []error {
return []error{r.OrigErr()} return []error{r.OrigErr()}
} }
type unmarshalError struct {
awsError
bytes []byte
}
// Error returns the string representation of the error.
// Satisfies the error interface.
func (e unmarshalError) Error() string {
extra := hex.Dump(e.bytes)
return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (e unmarshalError) String() string {
return e.Error()
}
// Bytes returns the bytes that failed to unmarshal.
func (e unmarshalError) Bytes() []byte {
return e.bytes
}
// An error list that satisfies the golang interface // An error list that satisfies the golang interface
type errorList []error type errorList []error
@@ -181,7 +208,7 @@ func (e errorList) Error() string {
// How do we want to handle the array size being zero // How do we want to handle the array size being zero
if size := len(e); size > 0 { if size := len(e); size > 0 {
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
msg += fmt.Sprintf("%s", e[i].Error()) msg += e[i].Error()
// We check the next index to see if it is within the slice. // We check the next index to see if it is within the slice.
// If it is, then we append a newline. We do this, because unit tests // If it is, then we append a newline. We do this, because unit tests
// could be broken with the additional '\n' // could be broken with the additional '\n'

View File

@@ -15,7 +15,7 @@ func DeepEqual(a, b interface{}) bool {
rb := reflect.Indirect(reflect.ValueOf(b)) rb := reflect.Indirect(reflect.ValueOf(b))
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid { if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
// If the elements are both nil, and of the same type the are equal // If the elements are both nil, and of the same type they are equal
// If they are of different types they are not equal // If they are of different types they are not equal
return reflect.TypeOf(a) == reflect.TypeOf(b) return reflect.TypeOf(a) == reflect.TypeOf(b)
} else if raValid != rbValid { } else if raValid != rbValid {

View File

@@ -70,7 +70,7 @@ func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTer
value = value.FieldByNameFunc(func(name string) bool { value = value.FieldByNameFunc(func(name string) bool {
if c == name { if c == name {
return true return true
} else if !caseSensitive && strings.ToLower(name) == strings.ToLower(c) { } else if !caseSensitive && strings.EqualFold(name, c) {
return true return true
} }
return false return false
@@ -185,14 +185,13 @@ func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
// SetValueAtPath sets a value at the case insensitive lexical path inside // SetValueAtPath sets a value at the case insensitive lexical path inside
// of a structure. // of a structure.
func SetValueAtPath(i interface{}, path string, v interface{}) { func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, false, v == nil); rvals != nil { rvals := rValuesAtPath(i, path, true, false, v == nil)
for _, rval := range rvals { for _, rval := range rvals {
if rval.Kind() == reflect.Ptr && rval.IsNil() { if rval.Kind() == reflect.Ptr && rval.IsNil() {
continue continue
} }
setValue(rval, v) setValue(rval, v)
} }
}
} }
func setValue(dstVal reflect.Value, src interface{}) { func setValue(dstVal reflect.Value, src interface{}) {

View File

@@ -6,6 +6,7 @@ go_library(
"client.go", "client.go",
"default_retryer.go", "default_retryer.go",
"logger.go", "logger.go",
"no_op_retryer.go",
], ],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/aws/client", importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/aws/client",
importpath = "github.com/aws/aws-sdk-go/aws/client", importpath = "github.com/aws/aws-sdk-go/aws/client",

View File

@@ -12,6 +12,7 @@ import (
type Config struct { type Config struct {
Config *aws.Config Config *aws.Config
Handlers request.Handlers Handlers request.Handlers
PartitionID string
Endpoint string Endpoint string
SigningRegion string SigningRegion string
SigningName string SigningName string
@@ -64,7 +65,7 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
default: default:
maxRetries := aws.IntValue(cfg.MaxRetries) maxRetries := aws.IntValue(cfg.MaxRetries)
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries { if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
maxRetries = 3 maxRetries = DefaultRetryerMaxNumRetries
} }
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries} svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
} }

View File

@@ -1,6 +1,7 @@
package client package client
import ( import (
"math"
"strconv" "strconv"
"time" "time"
@@ -9,82 +10,142 @@ import (
) )
// DefaultRetryer implements basic retry logic using exponential backoff for // DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, implement the // most services. If you want to implement custom retry logic, you can implement the
// request.Retryer interface or create a structure type that composes this // request.Retryer interface.
// struct and override the specific methods. For example, to override only
// the MaxRetries method:
// //
// type retryer struct {
// client.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() int { return 100 }
type DefaultRetryer struct { type DefaultRetryer struct {
// Num max Retries is the number of max retries that will be performed.
// By default, this is zero.
NumMaxRetries int NumMaxRetries int
// MinRetryDelay is the minimum retry delay after which retry will be performed.
// If not set, the value is 0ns.
MinRetryDelay time.Duration
// MinThrottleRetryDelay is the minimum retry delay when throttled.
// If not set, the value is 0ns.
MinThrottleDelay time.Duration
// MaxRetryDelay is the maximum retry delay before which retry must be performed.
// If not set, the value is 0ns.
MaxRetryDelay time.Duration
// MaxThrottleDelay is the maximum retry delay when throttled.
// If not set, the value is 0ns.
MaxThrottleDelay time.Duration
} }
const (
// DefaultRetryerMaxNumRetries sets maximum number of retries
DefaultRetryerMaxNumRetries = 3
// DefaultRetryerMinRetryDelay sets minimum retry delay
DefaultRetryerMinRetryDelay = 30 * time.Millisecond
// DefaultRetryerMinThrottleDelay sets minimum delay when throttled
DefaultRetryerMinThrottleDelay = 500 * time.Millisecond
// DefaultRetryerMaxRetryDelay sets maximum retry delay
DefaultRetryerMaxRetryDelay = 300 * time.Second
// DefaultRetryerMaxThrottleDelay sets maximum delay when throttled
DefaultRetryerMaxThrottleDelay = 300 * time.Second
)
// MaxRetries returns the number of maximum returns the service will use to make // MaxRetries returns the number of maximum returns the service will use to make
// an individual API request. // an individual API request.
func (d DefaultRetryer) MaxRetries() int { func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries return d.NumMaxRetries
} }
// setRetryerDefaults sets the default values of the retryer if not set
func (d *DefaultRetryer) setRetryerDefaults() {
if d.MinRetryDelay == 0 {
d.MinRetryDelay = DefaultRetryerMinRetryDelay
}
if d.MaxRetryDelay == 0 {
d.MaxRetryDelay = DefaultRetryerMaxRetryDelay
}
if d.MinThrottleDelay == 0 {
d.MinThrottleDelay = DefaultRetryerMinThrottleDelay
}
if d.MaxThrottleDelay == 0 {
d.MaxThrottleDelay = DefaultRetryerMaxThrottleDelay
}
}
// RetryRules returns the delay duration before retrying this request again // RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration { func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// Set the upper limit of delay in retrying at ~five minutes
minTime := 30 // if number of max retries is zero, no retries will be performed.
throttle := d.shouldThrottle(r) if d.NumMaxRetries == 0 {
if throttle { return 0
if delay, ok := getRetryDelay(r); ok {
return delay
} }
minTime = 500 // Sets default value for retryer members
d.setRetryerDefaults()
// minDelay is the minimum retryer delay
minDelay := d.MinRetryDelay
var initialDelay time.Duration
isThrottle := r.IsErrorThrottle()
if isThrottle {
if delay, ok := getRetryAfterDelay(r); ok {
initialDelay = delay
}
minDelay = d.MinThrottleDelay
} }
retryCount := r.RetryCount retryCount := r.RetryCount
if throttle && retryCount > 8 {
retryCount = 8 // maxDelay the maximum retryer delay
} else if retryCount > 13 { maxDelay := d.MaxRetryDelay
retryCount = 13
if isThrottle {
maxDelay = d.MaxThrottleDelay
} }
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime) var delay time.Duration
return time.Duration(delay) * time.Millisecond
// Logic to cap the retry count based on the minDelay provided
actualRetryCount := int(math.Log2(float64(minDelay))) + 1
if actualRetryCount < 63-retryCount {
delay = time.Duration(1<<uint64(retryCount)) * getJitterDelay(minDelay)
if delay > maxDelay {
delay = getJitterDelay(maxDelay / 2)
}
} else {
delay = getJitterDelay(maxDelay / 2)
}
return delay + initialDelay
}
// getJitterDelay returns a jittered delay for retry
func getJitterDelay(duration time.Duration) time.Duration {
return time.Duration(sdkrand.SeededRand.Int63n(int64(duration)) + int64(duration))
} }
// ShouldRetry returns true if the request should be retried. // ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool { func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// ShouldRetry returns false if number of max retries is 0.
if d.NumMaxRetries == 0 {
return false
}
// If one of the other handlers already set the retry state // If one of the other handlers already set the retry state
// we don't want to override it based on the service's state // we don't want to override it based on the service's state
if r.Retryable != nil { if r.Retryable != nil {
return *r.Retryable return *r.Retryable
} }
return r.IsErrorRetryable() || r.IsErrorThrottle()
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
return true
}
return r.IsErrorRetryable() || d.shouldThrottle(r)
}
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 502:
case 503:
case 504:
default:
return r.IsErrorThrottle()
}
return true
} }
// This will look in the Retry-After header, RFC 7231, for how long // This will look in the Retry-After header, RFC 7231, for how long
// it will wait before attempting another request // it will wait before attempting another request
func getRetryDelay(r *request.Request) (time.Duration, bool) { func getRetryAfterDelay(r *request.Request) (time.Duration, bool) {
if !canUseRetryAfterHeader(r) { if !canUseRetryAfterHeader(r) {
return 0, false return 0, false
} }

View File

@@ -67,10 +67,14 @@ func logRequest(r *request.Request) {
if !bodySeekable { if !bodySeekable {
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body)) r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
} }
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's // Reset the request body because dumpRequest will re-wrap the
// Body as a NoOpCloser and will not be reset after read by the HTTP // r.HTTPRequest's Body as a NoOpCloser and will not be reset after
// client reader. // read by the HTTP client reader.
r.ResetBody() if err := r.Error; err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
} }
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
@@ -118,6 +122,12 @@ var LogHTTPResponseHandler = request.NamedHandler{
func logResponse(r *request.Request) { func logResponse(r *request.Request) {
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)} lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
if r.HTTPResponse == nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, "request's HTTPResponse is nil"))
return
}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
if logBody { if logBody {
r.HTTPResponse.Body = &teeReaderCloser{ r.HTTPResponse.Body = &teeReaderCloser{

View File

@@ -5,6 +5,7 @@ type ClientInfo struct {
ServiceName string ServiceName string
ServiceID string ServiceID string
APIVersion string APIVersion string
PartitionID string
Endpoint string Endpoint string
SigningName string SigningName string
SigningRegion string SigningRegion string

View File

@@ -0,0 +1,28 @@
package client
import (
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
// NoOpRetryer provides a retryer that performs no retries.
// It should be used when we do not want retries to be performed.
type NoOpRetryer struct{}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
func (d NoOpRetryer) MaxRetries() int {
return 0
}
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
func (d NoOpRetryer) ShouldRetry(_ *request.Request) bool {
return false
}
// RetryRules returns the delay duration before retrying this request again;
// since NoOpRetryer does not retry, RetryRules always returns 0.
func (d NoOpRetryer) RetryRules(_ *request.Request) time.Duration {
return 0
}

View File

@@ -20,7 +20,7 @@ type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default, // A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig structure. // all clients will use the defaults.DefaultConfig structure.
// //
// // Create Session with MaxRetry configuration to be shared by multiple // // Create Session with MaxRetries configuration to be shared by multiple
// // service clients. // // service clients.
// sess := session.Must(session.NewSession(&aws.Config{ // sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3), // MaxRetries: aws.Int(3),
@@ -161,6 +161,17 @@ type Config struct {
// on GetObject API calls. // on GetObject API calls.
S3DisableContentMD5Validation *bool S3DisableContentMD5Validation *bool
// Set this to `true` to have the S3 service client to use the region specified
// in the ARN, when an ARN is provided as an argument to a bucket parameter.
S3UseARNRegion *bool
// Set this to `true` to enable the SDK to unmarshal API response header maps to
// normalized lower case map keys.
//
// For example S3's X-Amz-Meta prefixed header will be unmarshaled to lower case
// Metadata member's map keys. The value of the header in the map is unaffected.
LowerCaseHeaderMaps *bool
// Set this to `true` to disable the EC2Metadata client from overriding the // Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the // default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This options is only // EC2Metadata client to create a new http.Client. This options is only
@@ -246,12 +257,18 @@ type Config struct {
// Disabling this feature is useful when you want to use local endpoints // Disabling this feature is useful when you want to use local endpoints
// for testing that do not support the modeled host prefix pattern. // for testing that do not support the modeled host prefix pattern.
DisableEndpointHostPrefix *bool DisableEndpointHostPrefix *bool
// STSRegionalEndpoint will enable regional or legacy endpoint resolving
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// S3UsEast1RegionalEndpoint will enable regional or legacy endpoint resolving
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
} }
// NewConfig returns a new Config pointer that can be chained with builder // NewConfig returns a new Config pointer that can be chained with builder
// methods to set multiple configuration values inline without using pointers. // methods to set multiple configuration values inline without using pointers.
// //
// // Create Session with MaxRetry configuration to be shared by multiple // // Create Session with MaxRetries configuration to be shared by multiple
// // service clients. // // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig(). // sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3), // WithMaxRetries(3),
@@ -379,6 +396,13 @@ func (c *Config) WithS3DisableContentMD5Validation(enable bool) *Config {
} }
// WithS3UseARNRegion sets a config S3UseARNRegion value and
// returning a Config pointer for chaining
func (c *Config) WithS3UseARNRegion(enable bool) *Config {
c.S3UseARNRegion = &enable
return c
}
// WithUseDualStack sets a config UseDualStack value returning a Config // WithUseDualStack sets a config UseDualStack value returning a Config
// pointer for chaining. // pointer for chaining.
func (c *Config) WithUseDualStack(enable bool) *Config { func (c *Config) WithUseDualStack(enable bool) *Config {
@@ -420,6 +444,20 @@ func (c *Config) MergeIn(cfgs ...*Config) {
} }
} }
// WithSTSRegionalEndpoint will set whether or not to use regional endpoint flag
// when resolving the endpoint for a service
func (c *Config) WithSTSRegionalEndpoint(sre endpoints.STSRegionalEndpoint) *Config {
c.STSRegionalEndpoint = sre
return c
}
// WithS3UsEast1RegionalEndpoint will set whether or not to use regional endpoint flag
// when resolving the endpoint for a service
func (c *Config) WithS3UsEast1RegionalEndpoint(sre endpoints.S3UsEast1RegionalEndpoint) *Config {
c.S3UsEast1RegionalEndpoint = sre
return c
}
func mergeInConfig(dst *Config, other *Config) { func mergeInConfig(dst *Config, other *Config) {
if other == nil { if other == nil {
return return
@@ -493,6 +531,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation
} }
if other.S3UseARNRegion != nil {
dst.S3UseARNRegion = other.S3UseARNRegion
}
if other.UseDualStack != nil { if other.UseDualStack != nil {
dst.UseDualStack = other.UseDualStack dst.UseDualStack = other.UseDualStack
} }
@@ -520,6 +562,14 @@ func mergeInConfig(dst *Config, other *Config) {
if other.DisableEndpointHostPrefix != nil { if other.DisableEndpointHostPrefix != nil {
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
} }
if other.STSRegionalEndpoint != endpoints.UnsetSTSEndpoint {
dst.STSRegionalEndpoint = other.STSRegionalEndpoint
}
if other.S3UsEast1RegionalEndpoint != endpoints.UnsetS3UsEast1Endpoint {
dst.S3UsEast1RegionalEndpoint = other.S3UsEast1RegionalEndpoint
}
} }
// Copy will return a shallow copy of the Config object. If any additional // Copy will return a shallow copy of the Config object. If any additional

View File

@@ -179,6 +179,242 @@ func IntValueMap(src map[string]*int) map[string]int {
return dst return dst
} }
// Uint returns a pointer to the uint value passed in.
func Uint(v uint) *uint {
return &v
}
// UintValue returns the value of the uint pointer passed in or
// 0 if the pointer is nil.
func UintValue(v *uint) uint {
if v != nil {
return *v
}
return 0
}
// UintSlice converts a slice of uint values uinto a slice of
// uint pointers
func UintSlice(src []uint) []*uint {
dst := make([]*uint, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// UintValueSlice converts a slice of uint pointers uinto a slice of
// uint values
func UintValueSlice(src []*uint) []uint {
dst := make([]uint, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// UintMap converts a string map of uint values uinto a string
// map of uint pointers
func UintMap(src map[string]uint) map[string]*uint {
dst := make(map[string]*uint)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// UintValueMap converts a string map of uint pointers uinto a string
// map of uint values
func UintValueMap(src map[string]*uint) map[string]uint {
dst := make(map[string]uint)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int8 returns a pointer to the int8 value passed in.
func Int8(v int8) *int8 {
return &v
}
// Int8Value returns the value of the int8 pointer passed in or
// 0 if the pointer is nil.
func Int8Value(v *int8) int8 {
if v != nil {
return *v
}
return 0
}
// Int8Slice converts a slice of int8 values into a slice of
// int8 pointers
func Int8Slice(src []int8) []*int8 {
dst := make([]*int8, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int8ValueSlice converts a slice of int8 pointers into a slice of
// int8 values
func Int8ValueSlice(src []*int8) []int8 {
dst := make([]int8, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int8Map converts a string map of int8 values into a string
// map of int8 pointers
func Int8Map(src map[string]int8) map[string]*int8 {
dst := make(map[string]*int8)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int8ValueMap converts a string map of int8 pointers into a string
// map of int8 values
func Int8ValueMap(src map[string]*int8) map[string]int8 {
dst := make(map[string]int8)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int16 returns a pointer to the int16 value passed in.
func Int16(v int16) *int16 {
return &v
}
// Int16Value returns the value of the int16 pointer passed in or
// 0 if the pointer is nil.
func Int16Value(v *int16) int16 {
if v != nil {
return *v
}
return 0
}
// Int16Slice converts a slice of int16 values into a slice of
// int16 pointers
func Int16Slice(src []int16) []*int16 {
dst := make([]*int16, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int16ValueSlice converts a slice of int16 pointers into a slice of
// int16 values
func Int16ValueSlice(src []*int16) []int16 {
dst := make([]int16, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int16Map converts a string map of int16 values into a string
// map of int16 pointers
func Int16Map(src map[string]int16) map[string]*int16 {
dst := make(map[string]*int16)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int16ValueMap converts a string map of int16 pointers into a string
// map of int16 values
func Int16ValueMap(src map[string]*int16) map[string]int16 {
dst := make(map[string]int16)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int32 returns a pointer to the int32 value passed in.
func Int32(v int32) *int32 {
return &v
}
// Int32Value returns the value of the int32 pointer passed in or
// 0 if the pointer is nil.
func Int32Value(v *int32) int32 {
if v != nil {
return *v
}
return 0
}
// Int32Slice converts a slice of int32 values into a slice of
// int32 pointers
func Int32Slice(src []int32) []*int32 {
dst := make([]*int32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int32ValueSlice converts a slice of int32 pointers into a slice of
// int32 values
func Int32ValueSlice(src []*int32) []int32 {
dst := make([]int32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int32Map converts a string map of int32 values into a string
// map of int32 pointers
func Int32Map(src map[string]int32) map[string]*int32 {
dst := make(map[string]*int32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int32ValueMap converts a string map of int32 pointers into a string
// map of int32 values
func Int32ValueMap(src map[string]*int32) map[string]int32 {
dst := make(map[string]int32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int64 returns a pointer to the int64 value passed in. // Int64 returns a pointer to the int64 value passed in.
func Int64(v int64) *int64 { func Int64(v int64) *int64 {
return &v return &v
@@ -238,6 +474,301 @@ func Int64ValueMap(src map[string]*int64) map[string]int64 {
return dst return dst
} }
// Uint8 returns a pointer to the uint8 value passed in.
func Uint8(v uint8) *uint8 {
return &v
}
// Uint8Value returns the value of the uint8 pointer passed in or
// 0 if the pointer is nil.
func Uint8Value(v *uint8) uint8 {
if v != nil {
return *v
}
return 0
}
// Uint8Slice converts a slice of uint8 values into a slice of
// uint8 pointers
func Uint8Slice(src []uint8) []*uint8 {
dst := make([]*uint8, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint8ValueSlice converts a slice of uint8 pointers into a slice of
// uint8 values
func Uint8ValueSlice(src []*uint8) []uint8 {
dst := make([]uint8, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint8Map converts a string map of uint8 values into a string
// map of uint8 pointers
func Uint8Map(src map[string]uint8) map[string]*uint8 {
dst := make(map[string]*uint8)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint8ValueMap converts a string map of uint8 pointers into a string
// map of uint8 values
func Uint8ValueMap(src map[string]*uint8) map[string]uint8 {
dst := make(map[string]uint8)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint16 returns a pointer to the uint16 value passed in.
func Uint16(v uint16) *uint16 {
return &v
}
// Uint16Value returns the value of the uint16 pointer passed in or
// 0 if the pointer is nil.
func Uint16Value(v *uint16) uint16 {
if v != nil {
return *v
}
return 0
}
// Uint16Slice converts a slice of uint16 values into a slice of
// uint16 pointers
func Uint16Slice(src []uint16) []*uint16 {
dst := make([]*uint16, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint16ValueSlice converts a slice of uint16 pointers into a slice of
// uint16 values
func Uint16ValueSlice(src []*uint16) []uint16 {
dst := make([]uint16, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint16Map converts a string map of uint16 values into a string
// map of uint16 pointers
func Uint16Map(src map[string]uint16) map[string]*uint16 {
dst := make(map[string]*uint16)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint16ValueMap converts a string map of uint16 pointers into a string
// map of uint16 values
func Uint16ValueMap(src map[string]*uint16) map[string]uint16 {
dst := make(map[string]uint16)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint32 returns a pointer to the uint32 value passed in.
func Uint32(v uint32) *uint32 {
return &v
}
// Uint32Value returns the value of the uint32 pointer passed in or
// 0 if the pointer is nil.
func Uint32Value(v *uint32) uint32 {
if v != nil {
return *v
}
return 0
}
// Uint32Slice converts a slice of uint32 values into a slice of
// uint32 pointers
func Uint32Slice(src []uint32) []*uint32 {
dst := make([]*uint32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint32ValueSlice converts a slice of uint32 pointers into a slice of
// uint32 values
func Uint32ValueSlice(src []*uint32) []uint32 {
dst := make([]uint32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint32Map converts a string map of uint32 values into a string
// map of uint32 pointers
func Uint32Map(src map[string]uint32) map[string]*uint32 {
dst := make(map[string]*uint32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint32ValueMap converts a string map of uint32 pointers into a string
// map of uint32 values
func Uint32ValueMap(src map[string]*uint32) map[string]uint32 {
dst := make(map[string]uint32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint64 returns a pointer to the uint64 value passed in.
func Uint64(v uint64) *uint64 {
return &v
}
// Uint64Value returns the value of the uint64 pointer passed in or
// 0 if the pointer is nil.
func Uint64Value(v *uint64) uint64 {
if v != nil {
return *v
}
return 0
}
// Uint64Slice converts a slice of uint64 values into a slice of
// uint64 pointers
func Uint64Slice(src []uint64) []*uint64 {
dst := make([]*uint64, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint64ValueSlice converts a slice of uint64 pointers into a slice of
// uint64 values
func Uint64ValueSlice(src []*uint64) []uint64 {
dst := make([]uint64, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint64Map converts a string map of uint64 values into a string
// map of uint64 pointers
func Uint64Map(src map[string]uint64) map[string]*uint64 {
dst := make(map[string]*uint64)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint64ValueMap converts a string map of uint64 pointers into a string
// map of uint64 values
func Uint64ValueMap(src map[string]*uint64) map[string]uint64 {
dst := make(map[string]uint64)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float32 returns a pointer to the float32 value passed in.
func Float32(v float32) *float32 {
return &v
}
// Float32Value returns the value of the float32 pointer passed in or
// 0 if the pointer is nil.
func Float32Value(v *float32) float32 {
if v != nil {
return *v
}
return 0
}
// Float32Slice converts a slice of float32 values into a slice of
// float32 pointers
func Float32Slice(src []float32) []*float32 {
dst := make([]*float32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Float32ValueSlice converts a slice of float32 pointers into a slice of
// float32 values
func Float32ValueSlice(src []*float32) []float32 {
dst := make([]float32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Float32Map converts a string map of float32 values into a string
// map of float32 pointers
func Float32Map(src map[string]float32) map[string]*float32 {
dst := make(map[string]*float32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Float32ValueMap converts a string map of float32 pointers into a string
// map of float32 values
func Float32ValueMap(src map[string]*float32) map[string]float32 {
dst := make(map[string]float32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float64 returns a pointer to the float64 value passed in. // Float64 returns a pointer to the float64 value passed in.
func Float64(v float64) *float64 { func Float64(v float64) *float64 {
return &v return &v

View File

@@ -159,9 +159,9 @@ func handleSendError(r *request.Request, err error) {
Body: ioutil.NopCloser(bytes.NewReader([]byte{})), Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
} }
} }
// Catch all other request errors. // Catch all request errors, and let the default retrier determine
r.Error = awserr.New("RequestError", "send request failed", err) // if the error is retryable.
r.Retryable = aws.Bool(true) // network errors are retryable r.Error = awserr.New(request.ErrCodeRequestError, "send request failed", err)
// Override the error with a context canceled error, if that was canceled. // Override the error with a context canceled error, if that was canceled.
ctx := r.Context() ctx := r.Context()
@@ -184,7 +184,9 @@ var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseH
// AfterRetryHandler performs final checks to determine if the request should // AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay. // be retried and how long to delay.
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) { var AfterRetryHandler = request.NamedHandler{
Name: "core.AfterRetryHandler",
Fn: func(r *request.Request) {
// If one of the other handlers already set the retry state // If one of the other handlers already set the retry state
// we don't want to override it based on the service's state // we don't want to override it based on the service's state
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) { if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
@@ -214,7 +216,7 @@ var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn:
r.RetryCount++ r.RetryCount++
r.Error = nil r.Error = nil
} }
}} }}
// ValidateEndpointHandler is a request handler to validate a request had the // ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or // appropriate Region and Endpoint set. Will set r.Error if the endpoint or

View File

@@ -50,9 +50,10 @@ package credentials
import ( import (
"fmt" "fmt"
"github.com/aws/aws-sdk-go/aws/awserr"
"sync" "sync"
"time" "time"
"github.com/aws/aws-sdk-go/aws/awserr"
) )
// AnonymousCredentials is an empty Credential object that can be used as // AnonymousCredentials is an empty Credential object that can be used as
@@ -83,6 +84,12 @@ type Value struct {
ProviderName string ProviderName string
} }
// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}
// A Provider is the interface for any component which will provide credentials // A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what to // Value. A provider is required to manage its own Expired state, and what to
// be expired means. // be expired means.

View File

@@ -11,6 +11,7 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws/client:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/client:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkuri:go_default_library", "//vendor/github.com/aws/aws-sdk-go/internal/sdkuri:go_default_library",
], ],
) )

View File

@@ -11,6 +11,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri" "github.com/aws/aws-sdk-go/internal/sdkuri"
) )
@@ -142,7 +143,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
} }
if err := s.Err(); err != nil { if err := s.Err(); err != nil {
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err) return nil, awserr.New(request.ErrCodeSerialization,
"failed to read EC2 instance role from metadata service", err)
} }
return credsList, nil return credsList, nil
@@ -164,7 +166,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred
respCreds := ec2RoleCredRespBody{} respCreds := ec2RoleCredRespBody{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{}, return ec2RoleCredRespBody{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName), fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
err) err)
} }

View File

@@ -13,6 +13,7 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws/client/metadata:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/client/metadata:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/private/protocol/json/jsonutil:go_default_library",
], ],
) )

View File

@@ -39,6 +39,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client/metadata" "github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
) )
// ProviderName is the name of the credentials provider. // ProviderName is the name of the credentials provider.
@@ -97,8 +98,8 @@ func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint strin
return p return p
} }
// NewCredentialsClient returns a Credentials wrapper for retrieving credentials // NewCredentialsClient returns a pointer to a new Credentials object
// from an arbitrary endpoint concurrently. The client will request the // wrapping the endpoint credentials Provider.
func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials { func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...)) return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
} }
@@ -174,7 +175,7 @@ func unmarshalHandler(r *request.Request) {
out := r.Data.(*getCredentialsOutput) out := r.Data.(*getCredentialsOutput)
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil { if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
r.Error = awserr.New("SerializationError", r.Error = awserr.New(request.ErrCodeSerialization,
"failed to decode endpoint credentials", "failed to decode endpoint credentials",
err, err,
) )
@@ -185,11 +186,15 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
var errOut errorOutput var errOut errorOutput
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil { err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
r.Error = awserr.New("SerializationError", if err != nil {
"failed to decode endpoint credentials", r.Error = awserr.NewRequestFailure(
err, awserr.New(request.ErrCodeSerialization,
"failed to decode error message", err),
r.HTTPResponse.StatusCode,
r.RequestID,
) )
return
} }
// Response body format is not consistent between metadata endpoints. // Response body format is not consistent between metadata endpoints.

View File

@@ -9,6 +9,7 @@ go_library(
deps = [ deps = [
"//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkio:go_default_library",
], ],
) )

View File

@@ -90,6 +90,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/sdkio"
) )
const ( const (
@@ -142,7 +143,7 @@ const (
// DefaultBufSize limits buffer size from growing to an enormous // DefaultBufSize limits buffer size from growing to an enormous
// amount due to a faulty process. // amount due to a faulty process.
DefaultBufSize = 1024 DefaultBufSize = int(8 * sdkio.KibiByte)
// DefaultTimeout default limit on time a process can run. // DefaultTimeout default limit on time a process can run.
DefaultTimeout = time.Duration(1) * time.Minute DefaultTimeout = time.Duration(1) * time.Minute

View File

@@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library( go_library(
name = "go_default_library", name = "go_default_library",
srcs = ["assume_role_provider.go"], srcs = [
"assume_role_provider.go",
"web_identity_provider.go",
],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds", importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds",
importpath = "github.com/aws/aws-sdk-go/aws/credentials/stscreds", importpath = "github.com/aws/aws-sdk-go/aws/credentials/stscreds",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
@@ -11,7 +14,9 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/client:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/client:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkrand:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/sts:go_default_library", "//vendor/github.com/aws/aws-sdk-go/service/sts:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/sts/stsiface:go_default_library",
], ],
) )

View File

@@ -80,16 +80,18 @@ package stscreds
import ( import (
"fmt" "fmt"
"os"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts"
) )
// StdinTokenProvider will prompt on stdout and read from stdin for a string value. // StdinTokenProvider will prompt on stderr and read from stdin for a string value.
// An error is returned if reading from stdin fails. // An error is returned if reading from stdin fails.
// //
// Use this function go read MFA tokens from stdin. The function makes no attempt // Use this function go read MFA tokens from stdin. The function makes no attempt
@@ -102,7 +104,7 @@ import (
// Will wait forever until something is provided on the stdin. // Will wait forever until something is provided on the stdin.
func StdinTokenProvider() (string, error) { func StdinTokenProvider() (string, error) {
var v string var v string
fmt.Printf("Assume Role MFA token code: ") fmt.Fprintf(os.Stderr, "Assume Role MFA token code: ")
_, err := fmt.Scanln(&v) _, err := fmt.Scanln(&v)
return v, err return v, err
@@ -142,6 +144,13 @@ type AssumeRoleProvider struct {
// Session name, if you wish to reuse the credentials elsewhere. // Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string RoleSessionName string
// Optional, you can pass tag key-value pairs to your session. These tags are called session tags.
Tags []*sts.Tag
// A list of keys for session tags that you want to set as transitive.
// If you set a tag key as transitive, the corresponding key and value passes to subsequent sessions in a role chain.
TransitiveTagKeys []*string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set. // Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration Duration time.Duration
@@ -193,6 +202,18 @@ type AssumeRoleProvider struct {
// //
// If ExpiryWindow is 0 or less it will be ignored. // If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration ExpiryWindow time.Duration
// MaxJitterFrac reduces the effective Duration of each credential requested
// by a random percentage between 0 and MaxJitterFraction. MaxJitterFrac must
// have a value between 0 and 1. Any other value may lead to expected behavior.
// With a MaxJitterFrac value of 0, default) will no jitter will be used.
//
// For example, with a Duration of 30m and a MaxJitterFrac of 0.1, the
// AssumeRole call will be made with an arbitrary Duration between 27m and
// 30m.
//
// MaxJitterFrac should not be negative.
MaxJitterFrac float64
} }
// NewCredentials returns a pointer to a new Credentials object wrapping the // NewCredentials returns a pointer to a new Credentials object wrapping the
@@ -244,7 +265,6 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*
// Retrieve generates a new set of temporary credentials using STS. // Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set. // Apply defaults where parameters are not set.
if p.RoleSessionName == "" { if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique. // Try to work out a role name that will hopefully end up unique.
@@ -254,11 +274,14 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Expire as often as AWS permits. // Expire as often as AWS permits.
p.Duration = DefaultDuration p.Duration = DefaultDuration
} }
jitter := time.Duration(sdkrand.SeededRand.Float64() * p.MaxJitterFrac * float64(p.Duration))
input := &sts.AssumeRoleInput{ input := &sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)), DurationSeconds: aws.Int64(int64((p.Duration - jitter) / time.Second)),
RoleArn: aws.String(p.RoleARN), RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName), RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID, ExternalId: p.ExternalID,
Tags: p.Tags,
TransitiveTagKeys: p.TransitiveTagKeys,
} }
if p.Policy != nil { if p.Policy != nil {
input.Policy = p.Policy input.Policy = p.Policy

View File

@@ -0,0 +1,100 @@
package stscreds
import (
"fmt"
"io/ioutil"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)
const (
// ErrCodeWebIdentity will be used as an error code when constructing
// a new error to be returned during session creation or retrieval.
ErrCodeWebIdentity = "WebIdentityErr"
// WebIdentityProviderName is the web identity provider name
WebIdentityProviderName = "WebIdentityCredentials"
)
// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = time.Now
// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry
client stsiface.STSAPI
ExpiryWindow time.Duration
tokenFilePath string
roleARN string
roleSessionName string
}
// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
return credentials.NewCredentials(p)
}
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
client: svc,
tokenFilePath: path,
roleARN: roleARN,
roleSessionName: roleSessionName,
}
}
// Retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
b, err := ioutil.ReadFile(p.tokenFilePath)
if err != nil {
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err)
}
sessionName := p.roleSessionName
if len(sessionName) == 0 {
// session name is used to uniquely identify a session. This simply
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}
req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})
// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.
req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
if err := req.Send(); err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
}
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
value := credentials.Value{
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
ProviderName: WebIdentityProviderName,
}
return value, nil
}

View File

@@ -1,30 +1,61 @@
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics // Package csm provides the Client Side Monitoring (CSM) client which enables
// via UDP connection. Using the Start function will enable the reporting of // sending metrics via UDP connection to the CSM agent. This package provides
// metrics on a given port. If Start is called, with different parameters, again, // control options, and configuration for the CSM client. The client can be
// a panic will occur. // controlled manually, or automatically via the SDK's Session configuration.
// //
// Pause can be called to pause any metrics publishing on a given port. Sessions // Enabling CSM client via SDK's Session configuration
// that have had their handlers modified via InjectHandlers may still be used. //
// However, the handlers will act as a no-op meaning no metrics will be published. // The CSM client can be enabled automatically via SDK's Session configuration.
// The SDK's session configuration enables the CSM client if the AWS_CSM_PORT
// environment variable is set to a non-empty value.
//
// The configuration options for the CSM client via the SDK's session
// configuration are:
//
// * AWS_CSM_PORT=<port number>
// The port number the CSM agent will receive metrics on.
//
// * AWS_CSM_HOST=<hostname or ip>
// The hostname, or IP address the CSM agent will receive metrics on.
// Without port number.
//
// Manually enabling the CSM client
//
// The CSM client can be started, paused, and resumed manually. The Start
// function will enable the CSM client to publish metrics to the CSM agent. It
// is safe to call Start concurrently, but if Start is called additional times
// with different ClientID or address it will panic.
// //
// Example:
// r, err := csm.Start("clientID", ":31000") // r, err := csm.Start("clientID", ":31000")
// if err != nil { // if err != nil {
// panic(fmt.Errorf("failed starting CSM: %v", err)) // panic(fmt.Errorf("failed starting CSM: %v", err))
// } // }
// //
// When controlling the CSM client manually, you must also inject its request
// handlers into the SDK's Session configuration for the SDK's API clients to
// publish metrics.
//
// sess, err := session.NewSession(&aws.Config{}) // sess, err := session.NewSession(&aws.Config{})
// if err != nil { // if err != nil {
// panic(fmt.Errorf("failed loading session: %v", err)) // panic(fmt.Errorf("failed loading session: %v", err))
// } // }
// //
// // Add CSM client's metric publishing request handlers to the SDK's
// // Session Configuration.
// r.InjectHandlers(&sess.Handlers) // r.InjectHandlers(&sess.Handlers)
// //
// client := s3.New(sess) // Controlling CSM client
// resp, err := client.GetObject(&s3.GetObjectInput{ //
// Bucket: aws.String("bucket"), // Once the CSM client has been enabled the Get function will return a Reporter
// Key: aws.String("key"), // value that you can use to pause and resume the metrics published to the CSM
// }) // agent. If Get function is called before the reporter is enabled with the
// Start function or via SDK's Session configuration nil will be returned.
//
// The Pause method can be called to stop the CSM client publishing metrics to
// the CSM agent. The Continue method will resume metric publishing.
//
// // Get the CSM client Reporter.
// r := csm.Get()
// //
// // Will pause monitoring // // Will pause monitoring
// r.Pause() // r.Pause()
@@ -35,12 +66,4 @@
// //
// // Resume monitoring // // Resume monitoring
// r.Continue() // r.Continue()
//
// Start returns a Reporter that is used to enable or disable monitoring. If
// access to the Reporter is required later, calling Get will return the Reporter
// singleton.
//
// Example:
// r := csm.Get()
// r.Continue()
package csm package csm

View File

@@ -2,6 +2,7 @@ package csm
import ( import (
"fmt" "fmt"
"strings"
"sync" "sync"
) )
@@ -9,19 +10,40 @@ var (
lock sync.Mutex lock sync.Mutex
) )
// Client side metric handler names
const ( const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric" // DefaultPort is used when no port is specified.
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric" DefaultPort = "31000"
// DefaultHost is the host that will be used when none is specified.
DefaultHost = "127.0.0.1"
) )
// Start will start the a long running go routine to capture // AddressWithDefaults returns a CSM address built from the host and port
// values. If the host or port is not set, default values will be used
// instead. If host is "localhost" it will be replaced with "127.0.0.1".
func AddressWithDefaults(host, port string) string {
if len(host) == 0 || strings.EqualFold(host, "localhost") {
host = DefaultHost
}
if len(port) == 0 {
port = DefaultPort
}
// Only IP6 host can contain a colon
if strings.Contains(host, ":") {
return "[" + host + "]:" + port
}
return host + ":" + port
}
// Start will start a long running go routine to capture
// client side metrics. Calling start multiple time will only // client side metrics. Calling start multiple time will only
// start the metric listener once and will panic if a different // start the metric listener once and will panic if a different
// client ID or port is passed in. // client ID or port is passed in.
// //
// Example: // r, err := csm.Start("clientID", "127.0.0.1:31000")
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil { // if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err)) // panic(fmt.Errorf("expected no error, but received %v", err))
// } // }

View File

@@ -16,25 +16,26 @@ var (
type metricChan struct { type metricChan struct {
ch chan metric ch chan metric
paused int64 paused *int64
} }
func newMetricChan(size int) metricChan { func newMetricChan(size int) metricChan {
return metricChan{ return metricChan{
ch: make(chan metric, size), ch: make(chan metric, size),
paused: new(int64),
} }
} }
func (ch *metricChan) Pause() { func (ch *metricChan) Pause() {
atomic.StoreInt64(&ch.paused, pausedEnum) atomic.StoreInt64(ch.paused, pausedEnum)
} }
func (ch *metricChan) Continue() { func (ch *metricChan) Continue() {
atomic.StoreInt64(&ch.paused, runningEnum) atomic.StoreInt64(ch.paused, runningEnum)
} }
func (ch *metricChan) IsPaused() bool { func (ch *metricChan) IsPaused() bool {
v := atomic.LoadInt64(&ch.paused) v := atomic.LoadInt64(ch.paused)
return v == pausedEnum return v == pausedEnum
} }

View File

@@ -10,11 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
) )
const (
// DefaultPort is used when no port is specified
DefaultPort = "31000"
)
// Reporter will gather metrics of API requests made and // Reporter will gather metrics of API requests made and
// send those metrics to the CSM endpoint. // send those metrics to the CSM endpoint.
type Reporter struct { type Reporter struct {
@@ -71,7 +66,6 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
XAmzRequestID: aws.String(r.RequestID), XAmzRequestID: aws.String(r.RequestID),
AttemptCount: aws.Int(r.RetryCount + 1),
AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))), AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))),
AccessKey: aws.String(creds.AccessKeyID), AccessKey: aws.String(creds.AccessKeyID),
} }
@@ -95,8 +89,8 @@ func getMetricException(err awserr.Error) metricException {
code := err.Code() code := err.Code()
switch code { switch code {
case "RequestError", case request.ErrCodeRequestError,
"SerializationError", request.ErrCodeSerialization,
request.CanceledErrorCode: request.CanceledErrorCode:
return sdkException{ return sdkException{
requestException{exception: code, message: msg}, requestException{exception: code, message: msg},
@@ -123,7 +117,7 @@ func (rep *Reporter) sendAPICallMetric(r *request.Request) {
Type: aws.String("ApiCall"), Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1), AttemptCount: aws.Int(r.RetryCount + 1),
Region: r.Config.Region, Region: r.Config.Region,
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)), Latency: aws.Int(int(time.Since(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID), XAmzRequestID: aws.String(r.RequestID),
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())), MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
} }
@@ -190,8 +184,9 @@ func (rep *Reporter) start() {
} }
} }
// Pause will pause the metric channel preventing any new metrics from // Pause will pause the metric channel preventing any new metrics from being
// being added. // added. It is safe to call concurrently with other calls to Pause, but if
// called concurently with Continue can lead to unexpected state.
func (rep *Reporter) Pause() { func (rep *Reporter) Pause() {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
@@ -203,8 +198,9 @@ func (rep *Reporter) Pause() {
rep.close() rep.close()
} }
// Continue will reopen the metric channel and allow for monitoring // Continue will reopen the metric channel and allow for monitoring to be
// to be resumed. // resumed. It is safe to call concurrently with other calls to Continue, but
// if called concurently with Pause can lead to unexpected state.
func (rep *Reporter) Continue() { func (rep *Reporter) Continue() {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
@@ -219,10 +215,18 @@ func (rep *Reporter) Continue() {
rep.metricsCh.Continue() rep.metricsCh.Continue()
} }
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
)
// InjectHandlers will will enable client side metrics and inject the proper // InjectHandlers will will enable client side metrics and inject the proper
// handlers to handle how metrics are sent. // handlers to handle how metrics are sent.
// //
// Example: // InjectHandlers is NOT safe to call concurrently. Calling InjectHandlers
// multiple times may lead to unexpected behavior, (e.g. duplicate metrics).
//
// // Start must be called in order to inject the correct handlers // // Start must be called in order to inject the correct handlers
// r, err := csm.Start("clientID", "127.0.0.1:8094") // r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil { // if err != nil {

View File

@@ -5,6 +5,7 @@ go_library(
srcs = [ srcs = [
"api.go", "api.go",
"service.go", "service.go",
"token_provider.go",
], ],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/aws/ec2metadata", importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/aws/ec2metadata",
importpath = "github.com/aws/aws-sdk-go/aws/ec2metadata", importpath = "github.com/aws/aws-sdk-go/aws/ec2metadata",
@@ -15,6 +16,7 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws/client:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/client:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/client/metadata:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/client/metadata:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/corehandlers:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/corehandlers:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkuri:go_default_library", "//vendor/github.com/aws/aws-sdk-go/internal/sdkuri:go_default_library",
], ],

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
@@ -12,8 +13,41 @@ import (
"github.com/aws/aws-sdk-go/internal/sdkuri" "github.com/aws/aws-sdk-go/internal/sdkuri"
) )
// getToken uses the duration to return a token for EC2 metadata service,
// or an error if the request failed.
func (c *EC2Metadata) getToken(duration time.Duration) (tokenOutput, error) {
op := &request.Operation{
Name: "GetToken",
HTTPMethod: "PUT",
HTTPPath: "/api/token",
}
var output tokenOutput
req := c.NewRequest(op, nil, &output)
// remove the fetch token handler from the request handlers to avoid infinite recursion
req.Handlers.Sign.RemoveByName(fetchTokenHandlerName)
// Swap the unmarshalMetadataHandler with unmarshalTokenHandler on this request.
req.Handlers.Unmarshal.Swap(unmarshalMetadataHandlerName, unmarshalTokenHandler)
ttl := strconv.FormatInt(int64(duration/time.Second), 10)
req.HTTPRequest.Header.Set(ttlHeader, ttl)
err := req.Send()
// Errors with bad request status should be returned.
if err != nil {
err = awserr.NewRequestFailure(
awserr.New(req.HTTPResponse.Status, http.StatusText(req.HTTPResponse.StatusCode), err),
req.HTTPResponse.StatusCode, req.RequestID)
}
return output, err
}
// GetMetadata uses the path provided to request information from the EC2 // GetMetadata uses the path provided to request information from the EC2
// instance metdata service. The content will be returned as a string, or // instance metadata service. The content will be returned as a string, or
// error if the request failed. // error if the request failed.
func (c *EC2Metadata) GetMetadata(p string) (string, error) { func (c *EC2Metadata) GetMetadata(p string) (string, error) {
op := &request.Operation{ op := &request.Operation{
@@ -21,11 +55,12 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
HTTPMethod: "GET", HTTPMethod: "GET",
HTTPPath: sdkuri.PathJoin("/meta-data", p), HTTPPath: sdkuri.PathJoin("/meta-data", p),
} }
output := &metadataOutput{} output := &metadataOutput{}
req := c.NewRequest(op, nil, output) req := c.NewRequest(op, nil, output)
return output.Content, req.Send() err := req.Send()
return output.Content, err
} }
// GetUserData returns the userdata that was configured for the service. If // GetUserData returns the userdata that was configured for the service. If
@@ -40,13 +75,9 @@ func (c *EC2Metadata) GetUserData() (string, error) {
output := &metadataOutput{} output := &metadataOutput{}
req := c.NewRequest(op, nil, output) req := c.NewRequest(op, nil, output)
req.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
if r.HTTPResponse.StatusCode == http.StatusNotFound {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})
return output.Content, req.Send() err := req.Send()
return output.Content, err
} }
// GetDynamicData uses the path provided to request information from the EC2 // GetDynamicData uses the path provided to request information from the EC2
@@ -62,7 +93,8 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
output := &metadataOutput{} output := &metadataOutput{}
req := c.NewRequest(op, nil, output) req := c.NewRequest(op, nil, output)
return output.Content, req.Send() err := req.Send()
return output.Content, err
} }
// GetInstanceIdentityDocument retrieves an identity document describing an // GetInstanceIdentityDocument retrieves an identity document describing an
@@ -79,7 +111,7 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
doc := EC2InstanceIdentityDocument{} doc := EC2InstanceIdentityDocument{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
return EC2InstanceIdentityDocument{}, return EC2InstanceIdentityDocument{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 instance identity document", err) "failed to decode EC2 instance identity document", err)
} }
@@ -98,7 +130,7 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
info := EC2IAMInfo{} info := EC2IAMInfo{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
return EC2IAMInfo{}, return EC2IAMInfo{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 IAM info", err) "failed to decode EC2 IAM info", err)
} }
@@ -113,17 +145,17 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
// Region returns the region the instance is running in. // Region returns the region the instance is running in.
func (c *EC2Metadata) Region() (string, error) { func (c *EC2Metadata) Region() (string, error) {
resp, err := c.GetMetadata("placement/availability-zone") ec2InstanceIdentityDocument, err := c.GetInstanceIdentityDocument()
if err != nil { if err != nil {
return "", err return "", err
} }
// extract region from the ec2InstanceIdentityDocument
if len(resp) == 0 { region := ec2InstanceIdentityDocument.Region
return "", awserr.New("EC2MetadataError", "invalid Region response", nil) if len(region) == 0 {
return "", awserr.New("EC2MetadataError", "invalid region received for ec2metadata instance", nil)
} }
// returns region
// returns region without the suffix. Eg: us-west-2a becomes us-west-2 return region, nil
return resp[:len(resp)-1], nil
} }
// Available returns if the application has access to the EC2 Metadata service. // Available returns if the application has access to the EC2 Metadata service.
@@ -150,6 +182,7 @@ type EC2IAMInfo struct {
// an instance identity document // an instance identity document
type EC2InstanceIdentityDocument struct { type EC2InstanceIdentityDocument struct {
DevpayProductCodes []string `json:"devpayProductCodes"` DevpayProductCodes []string `json:"devpayProductCodes"`
MarketplaceProductCodes []string `json:"marketplaceProductCodes"`
AvailabilityZone string `json:"availabilityZone"` AvailabilityZone string `json:"availabilityZone"`
PrivateIP string `json:"privateIp"` PrivateIP string `json:"privateIp"`
Version string `json:"version"` Version string `json:"version"`

View File

@@ -13,6 +13,7 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
@@ -24,9 +25,25 @@ import (
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
) )
// ServiceName is the name of the service. const (
const ServiceName = "ec2metadata" // ServiceName is the name of the service.
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED" ServiceName = "ec2metadata"
disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
// Headers for Token and TTL
ttlHeader = "x-aws-ec2-metadata-token-ttl-seconds"
tokenHeader = "x-aws-ec2-metadata-token"
// Named Handler constants
fetchTokenHandlerName = "FetchTokenHandler"
unmarshalMetadataHandlerName = "unmarshalMetadataHandler"
unmarshalTokenHandlerName = "unmarshalTokenHandler"
enableTokenProviderHandlerName = "enableTokenProviderHandler"
// TTL constants
defaultTTL = 21600 * time.Second
ttlExpirationWindow = 30 * time.Second
)
// A EC2Metadata is an EC2 Metadata service Client. // A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct { type EC2Metadata struct {
@@ -63,8 +80,10 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
// use a shorter timeout than default because the metadata // use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster // service is local if it is running, and to fail faster
// if not running on an ec2 instance. // if not running on an ec2 instance.
Timeout: 5 * time.Second, Timeout: 1 * time.Second,
} }
// max number of retries on the client operation
cfg.MaxRetries = aws.Int(2)
} }
svc := &EC2Metadata{ svc := &EC2Metadata{
@@ -80,18 +99,35 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
), ),
} }
svc.Handlers.Unmarshal.PushBack(unmarshalHandler) // token provider instance
tp := newTokenProvider(svc, defaultTTL)
// NamedHandler for fetching token
svc.Handlers.Sign.PushBackNamed(request.NamedHandler{
Name: fetchTokenHandlerName,
Fn: tp.fetchTokenHandler,
})
// NamedHandler for enabling token provider
svc.Handlers.Complete.PushBackNamed(request.NamedHandler{
Name: enableTokenProviderHandlerName,
Fn: tp.enableTokenProviderHandler,
})
svc.Handlers.Unmarshal.PushBackNamed(unmarshalHandler)
svc.Handlers.UnmarshalError.PushBack(unmarshalError) svc.Handlers.UnmarshalError.PushBack(unmarshalError)
svc.Handlers.Validate.Clear() svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler) svc.Handlers.Validate.PushBack(validateEndpointHandler)
// Disable the EC2 Metadata service if the environment variable is set. // Disable the EC2 Metadata service if the environment variable is set.
// This shortcirctes the service's functionality to always fail to send // This short-circuits the service's functionality to always fail to send
// requests. // requests.
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" { if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
svc.Handlers.Send.SwapNamed(request.NamedHandler{ svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name, Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) { Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
Header: http.Header{},
}
r.Error = awserr.New( r.Error = awserr.New(
request.CanceledErrorCode, request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var", "EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
@@ -104,7 +140,6 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
for _, option := range opts { for _, option := range opts {
option(svc.Client) option(svc.Client)
} }
return svc return svc
} }
@@ -116,30 +151,74 @@ type metadataOutput struct {
Content string Content string
} }
func unmarshalHandler(r *request.Request) { type tokenOutput struct {
Token string
TTL time.Duration
}
// unmarshal token handler is used to parse the response of a getToken operation
var unmarshalTokenHandler = request.NamedHandler{
Name: unmarshalTokenHandlerName,
Fn: func(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{} var b bytes.Buffer
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil { if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err) r.Error = awserr.NewRequestFailure(awserr.New(request.ErrCodeSerialization,
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}
v := r.HTTPResponse.Header.Get(ttlHeader)
data, ok := r.Data.(*tokenOutput)
if !ok {
return
}
data.Token = b.String()
// TTL is in seconds
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(request.ParamFormatErrCode,
"unable to parse EC2 token TTL response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}
t := time.Duration(i) * time.Second
data.TTL = t
},
}
var unmarshalHandler = request.NamedHandler{
Name: unmarshalMetadataHandlerName,
Fn: func(r *request.Request) {
defer r.HTTPResponse.Body.Close()
var b bytes.Buffer
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(request.ErrCodeSerialization,
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
return return
} }
if data, ok := r.Data.(*metadataOutput); ok { if data, ok := r.Data.(*metadataOutput); ok {
data.Content = b.String() data.Content = b.String()
} }
},
} }
func unmarshalError(r *request.Request) { func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{} var b bytes.Buffer
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err) if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error response", err),
r.HTTPResponse.StatusCode, r.RequestID)
return return
} }
// Response body format is not consistent between metadata endpoints. // Response body format is not consistent between metadata endpoints.
// Grab the error message as a string and include that as the source error // Grab the error message as a string and include that as the source error
r.Error = awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String())) r.Error = awserr.NewRequestFailure(awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String())),
r.HTTPResponse.StatusCode, r.RequestID)
} }
func validateEndpointHandler(r *request.Request) { func validateEndpointHandler(r *request.Request) {

View File

@@ -0,0 +1,92 @@
package ec2metadata
import (
"net/http"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
// A tokenProvider struct provides access to EC2Metadata client
// and atomic instance of a token, along with configuredTTL for it.
// tokenProvider also provides an atomic flag to disable the
// fetch token operation.
// The disabled member will use 0 as false, and 1 as true.
type tokenProvider struct {
client *EC2Metadata
token atomic.Value
configuredTTL time.Duration
disabled uint32
}
// A ec2Token struct helps use of token in EC2 Metadata service ops
type ec2Token struct {
token string
credentials.Expiry
}
// newTokenProvider provides a pointer to a tokenProvider instance
func newTokenProvider(c *EC2Metadata, duration time.Duration) *tokenProvider {
return &tokenProvider{client: c, configuredTTL: duration}
}
// fetchTokenHandler fetches token for EC2Metadata service client by default.
func (t *tokenProvider) fetchTokenHandler(r *request.Request) {
// short-circuits to insecure data flow if tokenProvider is disabled.
if v := atomic.LoadUint32(&t.disabled); v == 1 {
return
}
if ec2Token, ok := t.token.Load().(ec2Token); ok && !ec2Token.IsExpired() {
r.HTTPRequest.Header.Set(tokenHeader, ec2Token.token)
return
}
output, err := t.client.getToken(t.configuredTTL)
if err != nil {
// change the disabled flag on token provider to true,
// when error is request timeout error.
if requestFailureError, ok := err.(awserr.RequestFailure); ok {
switch requestFailureError.StatusCode() {
case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed:
atomic.StoreUint32(&t.disabled, 1)
case http.StatusBadRequest:
r.Error = requestFailureError
}
// Check if request timed out while waiting for response
if e, ok := requestFailureError.OrigErr().(awserr.Error); ok {
if e.Code() == request.ErrCodeRequestError {
atomic.StoreUint32(&t.disabled, 1)
}
}
}
return
}
newToken := ec2Token{
token: output.Token,
}
newToken.SetExpiration(time.Now().Add(output.TTL), ttlExpirationWindow)
t.token.Store(newToken)
// Inject token header to the request.
if ec2Token, ok := t.token.Load().(ec2Token); ok {
r.HTTPRequest.Header.Set(tokenHeader, ec2Token.token)
}
}
// enableTokenProviderHandler enables the token provider
func (t *tokenProvider) enableTokenProviderHandler(r *request.Request) {
// If the error code status is 401, we enable the token provider
if e, ok := r.Error.(awserr.RequestFailure); ok && e != nil &&
e.StatusCode() == http.StatusUnauthorized {
atomic.StoreUint32(&t.disabled, 0)
}
}

View File

@@ -8,6 +8,7 @@ go_library(
"dep_service_ids.go", "dep_service_ids.go",
"doc.go", "doc.go",
"endpoints.go", "endpoints.go",
"legacy_regions.go",
"v3model.go", "v3model.go",
], ],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/aws/endpoints", importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/aws/endpoints",

View File

@@ -83,6 +83,7 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
p := &ps[i] p := &ps[i]
custAddEC2Metadata(p) custAddEC2Metadata(p)
custAddS3DualStack(p) custAddS3DualStack(p)
custRegionalS3(p)
custRmIotDataService(p) custRmIotDataService(p)
custFixAppAutoscalingChina(p) custFixAppAutoscalingChina(p)
custFixAppAutoscalingUsGov(p) custFixAppAutoscalingUsGov(p)
@@ -100,6 +101,33 @@ func custAddS3DualStack(p *partition) {
custAddDualstack(p, "s3-control") custAddDualstack(p, "s3-control")
} }
func custRegionalS3(p *partition) {
if p.ID != "aws" {
return
}
service, ok := p.Services["s3"]
if !ok {
return
}
// If global endpoint already exists no customization needed.
if _, ok := service.Endpoints["aws-global"]; ok {
return
}
service.PartitionEndpoint = "aws-global"
service.Endpoints["us-east-1"] = endpoint{}
service.Endpoints["aws-global"] = endpoint{
Hostname: "s3.amazonaws.com",
CredentialScope: credentialScope{
Region: "us-east-1",
},
}
p.Services["s3"] = service
}
func custAddDualstack(p *partition, svcName string) { func custAddDualstack(p *partition, svcName string) {
s, ok := p.Services[svcName] s, ok := p.Services[svcName]
if !ok { if !ok {

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,7 @@ package endpoints
// Service identifiers // Service identifiers
// //
// Deprecated: Use client package's EndpointID value instead of these // Deprecated: Use client package's EndpointsID value instead of these
// ServiceIDs. These IDs are not maintained, and are out of date. // ServiceIDs. These IDs are not maintained, and are out of date.
const ( const (
A4bServiceID = "a4b" // A4b. A4bServiceID = "a4b" // A4b.

View File

@@ -3,6 +3,7 @@ package endpoints
import ( import (
"fmt" "fmt"
"regexp" "regexp"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
) )
@@ -46,6 +47,108 @@ type Options struct {
// //
// This option is ignored if StrictMatching is enabled. // This option is ignored if StrictMatching is enabled.
ResolveUnknownService bool ResolveUnknownService bool
// STS Regional Endpoint flag helps with resolving the STS endpoint
STSRegionalEndpoint STSRegionalEndpoint
// S3 Regional Endpoint flag helps with resolving the S3 endpoint
S3UsEast1RegionalEndpoint S3UsEast1RegionalEndpoint
}
// STSRegionalEndpoint is an enum for the states of the STS Regional Endpoint
// options.
type STSRegionalEndpoint int
func (e STSRegionalEndpoint) String() string {
switch e {
case LegacySTSEndpoint:
return "legacy"
case RegionalSTSEndpoint:
return "regional"
case UnsetSTSEndpoint:
return ""
default:
return "unknown"
}
}
const (
// UnsetSTSEndpoint represents that STS Regional Endpoint flag is not specified.
UnsetSTSEndpoint STSRegionalEndpoint = iota
// LegacySTSEndpoint represents when STS Regional Endpoint flag is specified
// to use legacy endpoints.
LegacySTSEndpoint
// RegionalSTSEndpoint represents when STS Regional Endpoint flag is specified
// to use regional endpoints.
RegionalSTSEndpoint
)
// GetSTSRegionalEndpoint function returns the STSRegionalEndpointFlag based
// on the input string provided in env config or shared config by the user.
//
// `legacy`, `regional` are the only case-insensitive valid strings for
// resolving the STS regional Endpoint flag.
func GetSTSRegionalEndpoint(s string) (STSRegionalEndpoint, error) {
switch {
case strings.EqualFold(s, "legacy"):
return LegacySTSEndpoint, nil
case strings.EqualFold(s, "regional"):
return RegionalSTSEndpoint, nil
default:
return UnsetSTSEndpoint, fmt.Errorf("unable to resolve the value of STSRegionalEndpoint for %v", s)
}
}
// S3UsEast1RegionalEndpoint is an enum for the states of the S3 us-east-1
// Regional Endpoint options.
type S3UsEast1RegionalEndpoint int
func (e S3UsEast1RegionalEndpoint) String() string {
switch e {
case LegacyS3UsEast1Endpoint:
return "legacy"
case RegionalS3UsEast1Endpoint:
return "regional"
case UnsetS3UsEast1Endpoint:
return ""
default:
return "unknown"
}
}
const (
// UnsetS3UsEast1Endpoint represents that S3 Regional Endpoint flag is not
// specified.
UnsetS3UsEast1Endpoint S3UsEast1RegionalEndpoint = iota
// LegacyS3UsEast1Endpoint represents when S3 Regional Endpoint flag is
// specified to use legacy endpoints.
LegacyS3UsEast1Endpoint
// RegionalS3UsEast1Endpoint represents when S3 Regional Endpoint flag is
// specified to use regional endpoints.
RegionalS3UsEast1Endpoint
)
// GetS3UsEast1RegionalEndpoint function returns the S3UsEast1RegionalEndpointFlag based
// on the input string provided in env config or shared config by the user.
//
// `legacy`, `regional` are the only case-insensitive valid strings for
// resolving the S3 regional Endpoint flag.
func GetS3UsEast1RegionalEndpoint(s string) (S3UsEast1RegionalEndpoint, error) {
switch {
case strings.EqualFold(s, "legacy"):
return LegacyS3UsEast1Endpoint, nil
case strings.EqualFold(s, "regional"):
return RegionalS3UsEast1Endpoint, nil
default:
return UnsetS3UsEast1Endpoint,
fmt.Errorf("unable to resolve the value of S3UsEast1RegionalEndpoint for %v", s)
}
} }
// Set combines all of the option functions together. // Set combines all of the option functions together.
@@ -79,6 +182,12 @@ func ResolveUnknownServiceOption(o *Options) {
o.ResolveUnknownService = true o.ResolveUnknownService = true
} }
// STSRegionalEndpointOption enables the STS endpoint resolver behavior to resolve
// STS endpoint to their regional endpoint, instead of the global endpoint.
func STSRegionalEndpointOption(o *Options) {
o.STSRegionalEndpoint = RegionalSTSEndpoint
}
// A Resolver provides the interface for functionality to resolve endpoints. // A Resolver provides the interface for functionality to resolve endpoints.
// The build in Partition and DefaultResolver return value satisfy this interface. // The build in Partition and DefaultResolver return value satisfy this interface.
type Resolver interface { type Resolver interface {
@@ -170,10 +279,13 @@ func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) {
// A Partition provides the ability to enumerate the partition's regions // A Partition provides the ability to enumerate the partition's regions
// and services. // and services.
type Partition struct { type Partition struct {
id string id, dnsSuffix string
p *partition p *partition
} }
// DNSSuffix returns the base domain name of the partition.
func (p Partition) DNSSuffix() string { return p.dnsSuffix }
// ID returns the identifier of the partition. // ID returns the identifier of the partition.
func (p Partition) ID() string { return p.id } func (p Partition) ID() string { return p.id }
@@ -191,7 +303,7 @@ func (p Partition) ID() string { return p.id }
// require the provided service and region to be known by the partition. // require the provided service and region to be known by the partition.
// If the endpoint cannot be strictly resolved an error will be returned. This // If the endpoint cannot be strictly resolved an error will be returned. This
// mode is useful to ensure the endpoint resolved is valid. Without // mode is useful to ensure the endpoint resolved is valid. Without
// StrictMatching enabled the endpoint returned my look valid but may not work. // StrictMatching enabled the endpoint returned may look valid but may not work.
// StrictMatching requires the SDK to be updated if you want to take advantage // StrictMatching requires the SDK to be updated if you want to take advantage
// of new regions and services expansions. // of new regions and services expansions.
// //
@@ -205,7 +317,7 @@ func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (
// Regions returns a map of Regions indexed by their ID. This is useful for // Regions returns a map of Regions indexed by their ID. This is useful for
// enumerating over the regions in a partition. // enumerating over the regions in a partition.
func (p Partition) Regions() map[string]Region { func (p Partition) Regions() map[string]Region {
rs := map[string]Region{} rs := make(map[string]Region, len(p.p.Regions))
for id, r := range p.p.Regions { for id, r := range p.p.Regions {
rs[id] = Region{ rs[id] = Region{
id: id, id: id,
@@ -220,7 +332,7 @@ func (p Partition) Regions() map[string]Region {
// Services returns a map of Service indexed by their ID. This is useful for // Services returns a map of Service indexed by their ID. This is useful for
// enumerating over the services in a partition. // enumerating over the services in a partition.
func (p Partition) Services() map[string]Service { func (p Partition) Services() map[string]Service {
ss := map[string]Service{} ss := make(map[string]Service, len(p.p.Services))
for id := range p.p.Services { for id := range p.p.Services {
ss[id] = Service{ ss[id] = Service{
id: id, id: id,
@@ -307,7 +419,7 @@ func (s Service) Regions() map[string]Region {
// A region is the AWS region the service exists in. Whereas a Endpoint is // A region is the AWS region the service exists in. Whereas a Endpoint is
// an URL that can be resolved to a instance of a service. // an URL that can be resolved to a instance of a service.
func (s Service) Endpoints() map[string]Endpoint { func (s Service) Endpoints() map[string]Endpoint {
es := map[string]Endpoint{} es := make(map[string]Endpoint, len(s.p.Services[s.id].Endpoints))
for id := range s.p.Services[s.id].Endpoints { for id := range s.p.Services[s.id].Endpoints {
es[id] = Endpoint{ es[id] = Endpoint{
id: id, id: id,
@@ -347,6 +459,9 @@ type ResolvedEndpoint struct {
// The endpoint URL // The endpoint URL
URL string URL string
// The endpoint partition
PartitionID string
// The region that should be used for signing requests. // The region that should be used for signing requests.
SigningRegion string SigningRegion string

View File

@@ -0,0 +1,24 @@
package endpoints
var legacyGlobalRegions = map[string]map[string]struct{}{
"sts": {
"ap-northeast-1": {},
"ap-south-1": {},
"ap-southeast-1": {},
"ap-southeast-2": {},
"ca-central-1": {},
"eu-central-1": {},
"eu-north-1": {},
"eu-west-1": {},
"eu-west-2": {},
"eu-west-3": {},
"sa-east-1": {},
"us-east-1": {},
"us-east-2": {},
"us-west-1": {},
"us-west-2": {},
},
"s3": {
"us-east-1": {},
},
}

View File

@@ -54,6 +54,7 @@ type partition struct {
func (p partition) Partition() Partition { func (p partition) Partition() Partition {
return Partition{ return Partition{
dnsSuffix: p.DNSSuffix,
id: p.ID, id: p.ID,
p: &p, p: &p,
} }
@@ -74,24 +75,56 @@ func (p partition) canResolveEndpoint(service, region string, strictMatch bool)
return p.RegionRegex.MatchString(region) return p.RegionRegex.MatchString(region)
} }
func allowLegacyEmptyRegion(service string) bool {
legacy := map[string]struct{}{
"budgets": {},
"ce": {},
"chime": {},
"cloudfront": {},
"ec2metadata": {},
"iam": {},
"importexport": {},
"organizations": {},
"route53": {},
"sts": {},
"support": {},
"waf": {},
}
_, allowed := legacy[service]
return allowed
}
func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) { func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) {
var opt Options var opt Options
opt.Set(opts...) opt.Set(opts...)
s, hasService := p.Services[service] s, hasService := p.Services[service]
if !(hasService || opt.ResolveUnknownService) { if len(service) == 0 || !(hasService || opt.ResolveUnknownService) {
// Only return error if the resolver will not fallback to creating // Only return error if the resolver will not fallback to creating
// endpoint based on service endpoint ID passed in. // endpoint based on service endpoint ID passed in.
return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services)) return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services))
} }
if len(region) == 0 && allowLegacyEmptyRegion(service) && len(s.PartitionEndpoint) != 0 {
region = s.PartitionEndpoint
}
if (service == "sts" && opt.STSRegionalEndpoint != RegionalSTSEndpoint) ||
(service == "s3" && opt.S3UsEast1RegionalEndpoint != RegionalS3UsEast1Endpoint) {
if _, ok := legacyGlobalRegions[service][region]; ok {
region = "aws-global"
}
}
e, hasEndpoint := s.endpointForRegion(region) e, hasEndpoint := s.endpointForRegion(region)
if !hasEndpoint && opt.StrictMatching { if len(region) == 0 || (!hasEndpoint && opt.StrictMatching) {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints)) return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
} }
defs := []endpoint{p.Defaults, s.Defaults} defs := []endpoint{p.Defaults, s.Defaults}
return e.resolve(service, region, p.DNSSuffix, defs, opt), nil
return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt), nil
} }
func serviceList(ss services) []string { func serviceList(ss services) []string {
@@ -200,7 +233,7 @@ func getByPriority(s []string, p []string, def string) string {
return s[0] return s[0]
} }
func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint { func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
var merged endpoint var merged endpoint
for _, def := range defs { for _, def := range defs {
merged.mergeIn(def) merged.mergeIn(def)
@@ -208,20 +241,6 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
merged.mergeIn(e) merged.mergeIn(e)
e = merged e = merged
hostname := e.Hostname
// Offset the hostname for dualstack if enabled
if opts.UseDualStack && e.HasDualStack == boxedTrue {
hostname = e.DualStackHostname
}
u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u)
signingRegion := e.CredentialScope.Region signingRegion := e.CredentialScope.Region
if len(signingRegion) == 0 { if len(signingRegion) == 0 {
signingRegion = region signingRegion = region
@@ -234,8 +253,23 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
signingNameDerived = true signingNameDerived = true
} }
hostname := e.Hostname
// Offset the hostname for dualstack if enabled
if opts.UseDualStack && e.HasDualStack == boxedTrue {
hostname = e.DualStackHostname
region = signingRegion
}
u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u)
return ResolvedEndpoint{ return ResolvedEndpoint{
URL: u, URL: u,
PartitionID: partitionID,
SigningRegion: signingRegion, SigningRegion: signingRegion,
SigningName: signingName, SigningName: signingName,
SigningNameDerived: signingNameDerived, SigningNameDerived: signingNameDerived,

View File

@@ -4,7 +4,6 @@ go_library(
name = "go_default_library", name = "go_default_library",
srcs = [ srcs = [
"connection_reset_error.go", "connection_reset_error.go",
"connection_reset_error_other.go",
"handlers.go", "handlers.go",
"http_request.go", "http_request.go",
"offset_reader.go", "offset_reader.go",

View File

@@ -1,18 +1,17 @@
// +build !appengine,!plan9
package request package request
import ( import (
"net" "strings"
"os"
"syscall"
) )
func isErrConnectionReset(err error) bool { func isErrConnectionReset(err error) bool {
if opErr, ok := err.(*net.OpError); ok { if strings.Contains(err.Error(), "read: connection reset") {
if sysErr, ok := opErr.Err.(*os.SyscallError); ok { return false
return sysErr.Err == syscall.ECONNRESET
} }
if strings.Contains(err.Error(), "connection reset") ||
strings.Contains(err.Error(), "broken pipe") {
return true
} }
return false return false

View File

@@ -1,11 +0,0 @@
// +build appengine plan9
package request
import (
"strings"
)
func isErrConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset")
}

View File

@@ -10,6 +10,7 @@ import (
type Handlers struct { type Handlers struct {
Validate HandlerList Validate HandlerList
Build HandlerList Build HandlerList
BuildStream HandlerList
Sign HandlerList Sign HandlerList
Send HandlerList Send HandlerList
ValidateResponse HandlerList ValidateResponse HandlerList
@@ -23,11 +24,12 @@ type Handlers struct {
Complete HandlerList Complete HandlerList
} }
// Copy returns of this handler's lists. // Copy returns a copy of this handler's lists.
func (h *Handlers) Copy() Handlers { func (h *Handlers) Copy() Handlers {
return Handlers{ return Handlers{
Validate: h.Validate.copy(), Validate: h.Validate.copy(),
Build: h.Build.copy(), Build: h.Build.copy(),
BuildStream: h.BuildStream.copy(),
Sign: h.Sign.copy(), Sign: h.Sign.copy(),
Send: h.Send.copy(), Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(), ValidateResponse: h.ValidateResponse.copy(),
@@ -42,10 +44,11 @@ func (h *Handlers) Copy() Handlers {
} }
} }
// Clear removes callback functions for all handlers // Clear removes callback functions for all handlers.
func (h *Handlers) Clear() { func (h *Handlers) Clear() {
h.Validate.Clear() h.Validate.Clear()
h.Build.Clear() h.Build.Clear()
h.BuildStream.Clear()
h.Send.Clear() h.Send.Clear()
h.Sign.Clear() h.Sign.Clear()
h.Unmarshal.Clear() h.Unmarshal.Clear()
@@ -59,6 +62,54 @@ func (h *Handlers) Clear() {
h.Complete.Clear() h.Complete.Clear()
} }
// IsEmpty returns if there are no handlers in any of the handlerlists.
func (h *Handlers) IsEmpty() bool {
if h.Validate.Len() != 0 {
return false
}
if h.Build.Len() != 0 {
return false
}
if h.BuildStream.Len() != 0 {
return false
}
if h.Send.Len() != 0 {
return false
}
if h.Sign.Len() != 0 {
return false
}
if h.Unmarshal.Len() != 0 {
return false
}
if h.UnmarshalStream.Len() != 0 {
return false
}
if h.UnmarshalMeta.Len() != 0 {
return false
}
if h.UnmarshalError.Len() != 0 {
return false
}
if h.ValidateResponse.Len() != 0 {
return false
}
if h.Retry.Len() != 0 {
return false
}
if h.AfterRetry.Len() != 0 {
return false
}
if h.CompleteAttempt.Len() != 0 {
return false
}
if h.Complete.Len() != 0 {
return false
}
return true
}
// A HandlerListRunItem represents an entry in the HandlerList which // A HandlerListRunItem represents an entry in the HandlerList which
// is being run. // is being run.
type HandlerListRunItem struct { type HandlerListRunItem struct {
@@ -275,3 +326,18 @@ func MakeAddToUserAgentFreeFormHandler(s string) func(*Request) {
AddToUserAgent(r, s) AddToUserAgent(r, s)
} }
} }
// WithSetRequestHeaders updates the operation request's HTTP header to contain
// the header key value pairs provided. If the header key already exists in the
// request's HTTP header set, the existing value(s) will be replaced.
func WithSetRequestHeaders(h map[string]string) Option {
return withRequestHeader(h).SetRequestHeaders
}
type withRequestHeader map[string]string
func (h withRequestHeader) SetRequestHeaders(r *Request) {
for k, v := range h {
r.HTTPRequest.Header[k] = []string{v}
}
}

View File

@@ -15,12 +15,15 @@ type offsetReader struct {
closed bool closed bool
} }
func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader { func newOffsetReader(buf io.ReadSeeker, offset int64) (*offsetReader, error) {
reader := &offsetReader{} reader := &offsetReader{}
buf.Seek(offset, sdkio.SeekStart) _, err := buf.Seek(offset, sdkio.SeekStart)
if err != nil {
return nil, err
}
reader.buf = buf reader.buf = buf
return reader return reader, nil
} }
// Close will close the instance of the offset reader's access to // Close will close the instance of the offset reader's access to
@@ -54,7 +57,9 @@ func (o *offsetReader) Seek(offset int64, whence int) (int64, error) {
// CloseAndCopy will return a new offsetReader with a copy of the old buffer // CloseAndCopy will return a new offsetReader with a copy of the old buffer
// and close the old buffer. // and close the old buffer.
func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader { func (o *offsetReader) CloseAndCopy(offset int64) (*offsetReader, error) {
o.Close() if err := o.Close(); err != nil {
return nil, err
}
return newOffsetReader(o.buf, offset) return newOffsetReader(o.buf, offset)
} }

View File

@@ -36,6 +36,10 @@ const (
// API request that was canceled. Requests given a aws.Context may // API request that was canceled. Requests given a aws.Context may
// return this error when canceled. // return this error when canceled.
CanceledErrorCode = "RequestCanceled" CanceledErrorCode = "RequestCanceled"
// ErrCodeRequestError is an error preventing the SDK from continuing to
// process the request.
ErrCodeRequestError = "RequestError"
) )
// A Request is the service request to be made. // A Request is the service request to be made.
@@ -51,6 +55,7 @@ type Request struct {
HTTPRequest *http.Request HTTPRequest *http.Request
HTTPResponse *http.Response HTTPResponse *http.Response
Body io.ReadSeeker Body io.ReadSeeker
streamingBody io.ReadCloser
BodyStart int64 // offset from beginning of Body that the request body starts BodyStart int64 // offset from beginning of Body that the request body starts
Params interface{} Params interface{}
Error error Error error
@@ -64,6 +69,15 @@ type Request struct {
LastSignedAt time.Time LastSignedAt time.Time
DisableFollowRedirects bool DisableFollowRedirects bool
// Additional API error codes that should be retried. IsErrorRetryable
// will consider these codes in addition to its built in cases.
RetryErrorCodes []string
// Additional API error codes that should be retried with throttle backoff
// delay. IsErrorThrottle will consider these codes in addition to its
// built in cases.
ThrottleErrorCodes []string
// A value greater than 0 instructs the request to be signed as Presigned URL // A value greater than 0 instructs the request to be signed as Presigned URL
// You should not set this field directly. Instead use Request's // You should not set this field directly. Instead use Request's
// Presign or PresignRequest methods. // Presign or PresignRequest methods.
@@ -90,8 +104,12 @@ type Operation struct {
BeforePresignFn func(r *Request) error BeforePresignFn func(r *Request) error
} }
// New returns a new Request pointer for the service API // New returns a new Request pointer for the service API operation and
// operation and parameters. // parameters.
//
// A Retryer should be provided to direct how the request is retried. If
// Retryer is nil, a default no retry value will be used. You can use
// NoOpRetryer in the Client package to disable retry behavior directly.
// //
// Params is any value of input parameters to be the request payload. // Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response // Data is pointer value to an object which the request's response
@@ -99,6 +117,10 @@ type Operation struct {
func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers, func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request { retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request {
if retryer == nil {
retryer = noOpRetryer{}
}
method := operation.HTTPMethod method := operation.HTTPMethod
if method == "" { if method == "" {
method = "POST" method = "POST"
@@ -231,6 +253,10 @@ func (r *Request) WillRetry() bool {
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries() return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
} }
func fmtAttemptCount(retryCount, maxRetries int) string {
return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries)
}
// ParamsFilled returns if the request's parameters have been populated // ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are // and the parameters are valid. False is returned if no parameters are
// provided or invalid. // provided or invalid.
@@ -259,10 +285,28 @@ func (r *Request) SetStringBody(s string) {
// SetReaderBody will set the request's body reader. // SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) { func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.Body = reader r.Body = reader
r.BodyStart, _ = reader.Seek(0, sdkio.SeekCurrent) // Get the Bodies current offset.
if aws.IsReaderSeekable(reader) {
var err error
// Get the Bodies current offset so retries will start from the same
// initial position.
r.BodyStart, err = reader.Seek(0, sdkio.SeekCurrent)
if err != nil {
r.Error = awserr.New(ErrCodeSerialization,
"failed to determine start of request body", err)
return
}
}
r.ResetBody() r.ResetBody()
} }
// SetStreamingBody set the reader to be used for the request that will stream
// bytes to the server. Request's Body must not be set to any reader.
func (r *Request) SetStreamingBody(reader io.ReadCloser) {
r.streamingBody = reader
r.SetReaderBody(aws.ReadSeekCloser(reader))
}
// Presign returns the request's signed URL. Error will be returned // Presign returns the request's signed URL. Error will be returned
// if the signing fails. The expire parameter is only used for presigned Amazon // if the signing fails. The expire parameter is only used for presigned Amazon
// S3 API requests. All other AWS services will use a fixed expiration // S3 API requests. All other AWS services will use a fixed expiration
@@ -330,16 +374,15 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
} }
func debugLogReqError(r *Request, stage string, retrying bool, err error) { const (
notRetrying = "not retrying"
)
func debugLogReqError(r *Request, stage, retryStr string, err error) {
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) { if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return return
} }
retryStr := "not retrying"
if retrying {
retryStr = "will retry"
}
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v", r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err)) stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
} }
@@ -358,12 +401,12 @@ func (r *Request) Build() error {
if !r.built { if !r.built {
r.Handlers.Validate.Run(r) r.Handlers.Validate.Run(r)
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Validate Request", false, r.Error) debugLogReqError(r, "Validate Request", notRetrying, r.Error)
return r.Error return r.Error
} }
r.Handlers.Build.Run(r) r.Handlers.Build.Run(r)
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error) debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error return r.Error
} }
r.built = true r.built = true
@@ -379,7 +422,7 @@ func (r *Request) Build() error {
func (r *Request) Sign() error { func (r *Request) Sign() error {
r.Build() r.Build()
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error) debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error return r.Error
} }
@@ -387,12 +430,20 @@ func (r *Request) Sign() error {
return r.Error return r.Error
} }
func (r *Request) getNextRequestBody() (io.ReadCloser, error) { func (r *Request) getNextRequestBody() (body io.ReadCloser, err error) {
if r.streamingBody != nil {
return r.streamingBody, nil
}
if r.safeBody != nil { if r.safeBody != nil {
r.safeBody.Close() r.safeBody.Close()
} }
r.safeBody = newOffsetReader(r.Body, r.BodyStart) r.safeBody, err = newOffsetReader(r.Body, r.BodyStart)
if err != nil {
return nil, awserr.New(ErrCodeSerialization,
"failed to get next request body reader", err)
}
// Go 1.8 tightened and clarified the rules code needs to use when building // Go 1.8 tightened and clarified the rules code needs to use when building
// requests with the http package. Go 1.8 removed the automatic detection // requests with the http package. Go 1.8 removed the automatic detection
@@ -409,10 +460,10 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
// Related golang/go#18257 // Related golang/go#18257
l, err := aws.SeekerLen(r.Body) l, err := aws.SeekerLen(r.Body)
if err != nil { if err != nil {
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err) return nil, awserr.New(ErrCodeSerialization,
"failed to compute request body size", err)
} }
var body io.ReadCloser
if l == 0 { if l == 0 {
body = NoBody body = NoBody
} else if l > 0 { } else if l > 0 {
@@ -473,15 +524,13 @@ func (r *Request) Send() error {
r.AttemptTime = time.Now() r.AttemptTime = time.Now()
if err := r.Sign(); err != nil { if err := r.Sign(); err != nil {
debugLogReqError(r, "Sign Request", false, err) debugLogReqError(r, "Sign Request", notRetrying, err)
return err return err
} }
if err := r.sendRequest(); err == nil { if err := r.sendRequest(); err == nil {
return nil return nil
} else if !shouldRetryCancel(r.Error) { }
return err
} else {
r.Handlers.Retry.Run(r) r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r) r.Handlers.AfterRetry.Run(r)
@@ -489,13 +538,14 @@ func (r *Request) Send() error {
return r.Error return r.Error
} }
r.prepareRetry() if err := r.prepareRetry(); err != nil {
continue r.Error = err
return err
} }
} }
} }
func (r *Request) prepareRetry() { func (r *Request) prepareRetry() error {
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) { if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d", r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount)) r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
@@ -506,12 +556,19 @@ func (r *Request) prepareRetry() {
// the request's body even though the Client's Do returned. // the request's body even though the Client's Do returned.
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil) r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
r.ResetBody() r.ResetBody()
if err := r.Error; err != nil {
return awserr.New(ErrCodeSerialization,
"failed to prepare body for retry", err)
}
// Closing response body to ensure that no response body is leaked // Closing response body to ensure that no response body is leaked
// between retry attempts. // between retry attempts.
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil { if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
r.HTTPResponse.Body.Close() r.HTTPResponse.Body.Close()
} }
return nil
} }
func (r *Request) sendRequest() (sendErr error) { func (r *Request) sendRequest() (sendErr error) {
@@ -520,7 +577,9 @@ func (r *Request) sendRequest() (sendErr error) {
r.Retryable = nil r.Retryable = nil
r.Handlers.Send.Run(r) r.Handlers.Send.Run(r)
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Send Request", r.WillRetry(), r.Error) debugLogReqError(r, "Send Request",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error return r.Error
} }
@@ -528,13 +587,17 @@ func (r *Request) sendRequest() (sendErr error) {
r.Handlers.ValidateResponse.Run(r) r.Handlers.ValidateResponse.Run(r)
if r.Error != nil { if r.Error != nil {
r.Handlers.UnmarshalError.Run(r) r.Handlers.UnmarshalError.Run(r)
debugLogReqError(r, "Validate Response", r.WillRetry(), r.Error) debugLogReqError(r, "Validate Response",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error return r.Error
} }
r.Handlers.Unmarshal.Run(r) r.Handlers.Unmarshal.Run(r)
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Unmarshal Response", r.WillRetry(), r.Error) debugLogReqError(r, "Unmarshal Response",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error return r.Error
} }
@@ -561,48 +624,6 @@ func AddToUserAgent(r *Request, s string) {
r.HTTPRequest.Header.Set("User-Agent", s) r.HTTPRequest.Header.Set("User-Agent", s)
} }
type temporary interface {
Temporary() bool
}
func shouldRetryCancel(err error) bool {
switch err := err.(type) {
case awserr.Error:
if err.Code() == CanceledErrorCode {
return false
}
return shouldRetryCancel(err.OrigErr())
case *url.Error:
if strings.Contains(err.Error(), "connection refused") {
// Refused connections should be retried as the service may not yet
// be running on the port. Go TCP dial considers refused
// connections as not temporary.
return true
}
// *url.Error only implements Temporary after golang 1.6 but since
// url.Error only wraps the error:
return shouldRetryCancel(err.Err)
case temporary:
// If the error is temporary, we want to allow continuation of the
// retry process
return err.Temporary()
case nil:
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
// because we don't know the cause, it is marked as retriable. See
// TestRequest4xxUnretryable for an example.
return true
default:
switch err.Error() {
case "net/http: request canceled",
"net/http: request canceled while waiting for connection":
// known 1.5 error case when an http request is cancelled
return false
}
// here we don't know the error; so we allow a retry.
return true
}
}
// SanitizeHostForHeader removes default port from host and updates request.Host // SanitizeHostForHeader removes default port from host and updates request.Host
func SanitizeHostForHeader(r *http.Request) { func SanitizeHostForHeader(r *http.Request) {
host := getHost(r) host := getHost(r)

View File

@@ -4,6 +4,8 @@ package request
import ( import (
"net/http" "net/http"
"github.com/aws/aws-sdk-go/aws/awserr"
) )
// NoBody is a http.NoBody reader instructing Go HTTP client to not include // NoBody is a http.NoBody reader instructing Go HTTP client to not include
@@ -24,7 +26,8 @@ var NoBody = http.NoBody
func (r *Request) ResetBody() { func (r *Request) ResetBody() {
body, err := r.getNextRequestBody() body, err := r.getNextRequestBody()
if err != nil { if err != nil {
r.Error = err r.Error = awserr.New(ErrCodeSerialization,
"failed to reset request body", err)
return return
} }

View File

@@ -17,11 +17,13 @@ import (
// does the pagination between API operations, and Paginator defines the // does the pagination between API operations, and Paginator defines the
// configuration that will be used per page request. // configuration that will be used per page request.
// //
// cont := true // for p.Next() {
// for p.Next() && cont {
// data := p.Page().(*s3.ListObjectsOutput) // data := p.Page().(*s3.ListObjectsOutput)
// // process the page's data // // process the page's data
// // ...
// // break out of loop to stop fetching additional pages
// } // }
//
// return p.Err() // return p.Err()
// //
// See service client API operation Pages methods for examples how the SDK will // See service client API operation Pages methods for examples how the SDK will
@@ -146,7 +148,7 @@ func (r *Request) nextPageTokens() []interface{} {
return nil return nil
} }
case bool: case bool:
if v == false { if !v {
return nil return nil
} }
} }

View File

@@ -1,32 +1,81 @@
package request package request
import ( import (
"net"
"net/url"
"strings"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
) )
// Retryer is an interface to control retry logic for a given service. // Retryer provides the interface drive the SDK's request retry behavior. The
// The default implementation used by most services is the client.DefaultRetryer // Retryer implementation is responsible for implementing exponential backoff,
// structure, which contains basic retry logic using exponential backoff. // and determine if a request API error should be retried.
//
// client.DefaultRetryer is the SDK's default implementation of the Retryer. It
// uses the which uses the Request.IsErrorRetryable and Request.IsErrorThrottle
// methods to determine if the request is retried.
type Retryer interface { type Retryer interface {
// RetryRules return the retry delay that should be used by the SDK before
// making another request attempt for the failed request.
RetryRules(*Request) time.Duration RetryRules(*Request) time.Duration
// ShouldRetry returns if the failed request is retryable.
//
// Implementations may consider request attempt count when determining if a
// request is retryable, but the SDK will use MaxRetries to limit the
// number of attempts a request are made.
ShouldRetry(*Request) bool ShouldRetry(*Request) bool
// MaxRetries is the number of times a request may be retried before
// failing.
MaxRetries() int MaxRetries() int
} }
// WithRetryer sets a config Retryer value to the given Config returning it // WithRetryer sets a Retryer value to the given Config returning the Config
// for chaining. // value for chaining. The value must not be nil.
func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config { func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config {
if retryer == nil {
if cfg.Logger != nil {
cfg.Logger.Log("ERROR: Request.WithRetryer called with nil retryer. Replacing with retry disabled Retryer.")
}
retryer = noOpRetryer{}
}
cfg.Retryer = retryer cfg.Retryer = retryer
return cfg return cfg
}
// noOpRetryer is a internal no op retryer used when a request is created
// without a retryer.
//
// Provides a retryer that performs no retries.
// It should be used when we do not want retries to be performed.
type noOpRetryer struct{}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
func (d noOpRetryer) MaxRetries() int {
return 0
}
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
func (d noOpRetryer) ShouldRetry(_ *Request) bool {
return false
}
// RetryRules returns the delay duration before retrying this request again;
// since NoOpRetryer does not retry, RetryRules always returns 0.
func (d noOpRetryer) RetryRules(_ *Request) time.Duration {
return 0
} }
// retryableCodes is a collection of service response codes which are retry-able // retryableCodes is a collection of service response codes which are retry-able
// without any further action. // without any further action.
var retryableCodes = map[string]struct{}{ var retryableCodes = map[string]struct{}{
"RequestError": {}, ErrCodeRequestError: {},
"RequestTimeout": {}, "RequestTimeout": {},
ErrCodeResponseTimeout: {}, ErrCodeResponseTimeout: {},
"RequestTimeoutException": {}, // Glacier's flavor of RequestTimeout "RequestTimeoutException": {}, // Glacier's flavor of RequestTimeout
@@ -34,10 +83,12 @@ var retryableCodes = map[string]struct{}{
var throttleCodes = map[string]struct{}{ var throttleCodes = map[string]struct{}{
"ProvisionedThroughputExceededException": {}, "ProvisionedThroughputExceededException": {},
"ThrottledException": {}, // SNS, XRay, ResourceGroupsTagging API
"Throttling": {}, "Throttling": {},
"ThrottlingException": {}, "ThrottlingException": {},
"RequestLimitExceeded": {}, "RequestLimitExceeded": {},
"RequestThrottled": {}, "RequestThrottled": {},
"RequestThrottledException": {},
"TooManyRequestsException": {}, // Lambda functions "TooManyRequestsException": {}, // Lambda functions
"PriorRequestNotComplete": {}, // Route53 "PriorRequestNotComplete": {}, // Route53
"TransactionInProgressException": {}, "TransactionInProgressException": {},
@@ -75,10 +126,6 @@ var validParentCodes = map[string]struct{}{
ErrCodeRead: {}, ErrCodeRead: {},
} }
type temporaryError interface {
Temporary() bool
}
func isNestedErrorRetryable(parentErr awserr.Error) bool { func isNestedErrorRetryable(parentErr awserr.Error) bool {
if parentErr == nil { if parentErr == nil {
return false return false
@@ -97,7 +144,7 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
return isCodeRetryable(aerr.Code()) return isCodeRetryable(aerr.Code())
} }
if t, ok := err.(temporaryError); ok { if t, ok := err.(temporary); ok {
return t.Temporary() || isErrConnectionReset(err) return t.Temporary() || isErrConnectionReset(err)
} }
@@ -107,33 +154,91 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
// IsErrorRetryable returns whether the error is retryable, based on its Code. // IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if error is nil. // Returns false if error is nil.
func IsErrorRetryable(err error) bool { func IsErrorRetryable(err error) bool {
if err != nil { if err == nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr)
}
}
return false return false
}
return shouldRetryError(err)
}
type temporary interface {
Temporary() bool
}
func shouldRetryError(origErr error) bool {
switch err := origErr.(type) {
case awserr.Error:
if err.Code() == CanceledErrorCode {
return false
}
if isNestedErrorRetryable(err) {
return true
}
origErr := err.OrigErr()
var shouldRetry bool
if origErr != nil {
shouldRetry = shouldRetryError(origErr)
if err.Code() == ErrCodeRequestError && !shouldRetry {
return false
}
}
if isCodeRetryable(err.Code()) {
return true
}
return shouldRetry
case *url.Error:
if strings.Contains(err.Error(), "connection refused") {
// Refused connections should be retried as the service may not yet
// be running on the port. Go TCP dial considers refused
// connections as not temporary.
return true
}
// *url.Error only implements Temporary after golang 1.6 but since
// url.Error only wraps the error:
return shouldRetryError(err.Err)
case temporary:
if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" {
return true
}
// If the error is temporary, we want to allow continuation of the
// retry process
return err.Temporary() || isErrConnectionReset(origErr)
case nil:
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
// because we don't know the cause, it is marked as retryable. See
// TestRequest4xxUnretryable for an example.
return true
default:
switch err.Error() {
case "net/http: request canceled",
"net/http: request canceled while waiting for connection":
// known 1.5 error case when an http request is cancelled
return false
}
// here we don't know the error; so we allow a retry.
return true
}
} }
// IsErrorThrottle returns whether the error is to be throttled based on its code. // IsErrorThrottle returns whether the error is to be throttled based on its code.
// Returns false if error is nil. // Returns false if error is nil.
func IsErrorThrottle(err error) bool { func IsErrorThrottle(err error) bool {
if err != nil { if aerr, ok := err.(awserr.Error); ok && aerr != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeThrottle(aerr.Code()) return isCodeThrottle(aerr.Code())
} }
}
return false return false
} }
// IsErrorExpiredCreds returns whether the error code is a credential expiry error. // IsErrorExpiredCreds returns whether the error code is a credential expiry
// Returns false if error is nil. // error. Returns false if error is nil.
func IsErrorExpiredCreds(err error) bool { func IsErrorExpiredCreds(err error) bool {
if err != nil { if aerr, ok := err.(awserr.Error); ok && aerr != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeExpiredCreds(aerr.Code()) return isCodeExpiredCreds(aerr.Code())
} }
}
return false return false
} }
@@ -142,17 +247,58 @@ func IsErrorExpiredCreds(err error) bool {
// //
// Alias for the utility function IsErrorRetryable // Alias for the utility function IsErrorRetryable
func (r *Request) IsErrorRetryable() bool { func (r *Request) IsErrorRetryable() bool {
if isErrCode(r.Error, r.RetryErrorCodes) {
return true
}
// HTTP response status code 501 should not be retried.
// 501 represents Not Implemented which means the request method is not
// supported by the server and cannot be handled.
if r.HTTPResponse != nil {
// HTTP response status code 500 represents internal server error and
// should be retried without any throttle.
if r.HTTPResponse.StatusCode == 500 {
return true
}
}
return IsErrorRetryable(r.Error) return IsErrorRetryable(r.Error)
} }
// IsErrorThrottle returns whether the error is to be throttled based on its code. // IsErrorThrottle returns whether the error is to be throttled based on its
// Returns false if the request has no Error set // code. Returns false if the request has no Error set.
// //
// Alias for the utility function IsErrorThrottle // Alias for the utility function IsErrorThrottle
func (r *Request) IsErrorThrottle() bool { func (r *Request) IsErrorThrottle() bool {
if isErrCode(r.Error, r.ThrottleErrorCodes) {
return true
}
if r.HTTPResponse != nil {
switch r.HTTPResponse.StatusCode {
case
429, // error caused due to too many requests
502, // Bad Gateway error should be throttled
503, // caused when service is unavailable
504: // error occurred due to gateway timeout
return true
}
}
return IsErrorThrottle(r.Error) return IsErrorThrottle(r.Error)
} }
func isErrCode(err error, codes []string) bool {
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
for _, code := range codes {
if code == aerr.Code() {
return true
}
}
}
return false
}
// IsErrorExpired returns whether the error code is a credential expiry error. // IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set. // Returns false if the request has no Error set.
// //

View File

@@ -3,6 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library( go_library(
name = "go_default_library", name = "go_default_library",
srcs = [ srcs = [
"cabundle_transport.go",
"cabundle_transport_1_5.go",
"cabundle_transport_1_6.go",
"credentials.go",
"doc.go", "doc.go",
"env_config.go", "env_config.go",
"session.go", "session.go",

View File

@@ -0,0 +1,26 @@
// +build go1.7
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View File

@@ -0,0 +1,22 @@
// +build !go1.6,go1.5
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
}
}

View File

@@ -0,0 +1,23 @@
// +build !go1.7,go1.6
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View File

@@ -0,0 +1,259 @@
package session
import (
"fmt"
"os"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func resolveCredentials(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (*credentials.Credentials, error) {
switch {
case len(sessOpts.Profile) != 0:
// User explicitly provided an Profile in the session's configuration
// so load that profile from shared config first.
// Github(aws/aws-sdk-go#2727)
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
case envCfg.Creds.HasKeys():
// Environment credentials
return credentials.NewStaticCredentialsFromCreds(envCfg.Creds), nil
case len(envCfg.WebIdentityTokenFilePath) != 0:
// Web identity token from environment, RoleARN required to also be
// set.
return assumeWebIdentity(cfg, handlers,
envCfg.WebIdentityTokenFilePath,
envCfg.RoleARN,
envCfg.RoleSessionName,
)
default:
// Fallback to the "default" credential resolution chain.
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
}
}
// WebIdentityEmptyRoleARNErr will occur if 'AWS_WEB_IDENTITY_TOKEN_FILE' was set but
// 'AWS_ROLE_ARN' was not set.
var WebIdentityEmptyRoleARNErr = awserr.New(stscreds.ErrCodeWebIdentity, "role ARN is not set", nil)
// WebIdentityEmptyTokenFilePathErr will occur if 'AWS_ROLE_ARN' was set but
// 'AWS_WEB_IDENTITY_TOKEN_FILE' was not set.
var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "token file path is not set", nil)
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
filepath string,
roleARN, sessionName string,
) (*credentials.Credentials, error) {
if len(filepath) == 0 {
return nil, WebIdentityEmptyTokenFilePathErr
}
if len(roleARN) == 0 {
return nil, WebIdentityEmptyRoleARNErr
}
creds := stscreds.NewWebIdentityCredentials(
&Session{
Config: cfg,
Handlers: handlers.Copy(),
},
roleARN,
sessionName,
filepath,
)
return creds, nil
}
func resolveCredsFromProfile(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch {
case sharedCfg.SourceProfile != nil:
// Assume IAM role with credentials source from a different profile.
creds, err = resolveCredsFromProfile(cfg, envCfg,
*sharedCfg.SourceProfile, handlers, sessOpts,
)
case sharedCfg.Creds.HasKeys():
// Static Credentials from Shared Config/Credentials file.
creds = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
case len(sharedCfg.CredentialProcess) != 0:
// Get credentials from CredentialProcess
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
case len(sharedCfg.CredentialSource) != 0:
creds, err = resolveCredsFromSource(cfg, envCfg,
sharedCfg, handlers, sessOpts,
)
case len(sharedCfg.WebIdentityTokenFile) != 0:
// Credentials from Assume Web Identity token require an IAM Role, and
// that roll will be assumed. May be wrapped with another assume role
// via SourceProfile.
return assumeWebIdentity(cfg, handlers,
sharedCfg.WebIdentityTokenFile,
sharedCfg.RoleARN,
sharedCfg.RoleSessionName,
)
default:
// Fallback to default credentials provider, include mock errors for
// the credential chain so user can identify why credentials failed to
// be retrieved.
creds = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{
Err: awserr.New("EnvAccessKeyNotFound",
"failed to find credentials in the environment.", nil),
},
&credProviderError{
Err: awserr.New("SharedCredsLoad",
fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil),
},
defaults.RemoteCredProvider(*cfg, handlers),
},
})
}
if err != nil {
return nil, err
}
if len(sharedCfg.RoleARN) > 0 {
cfgCp := *cfg
cfgCp.Credentials = creds
return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts)
}
return creds, nil
}
// valid credential source values
const (
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
)
func resolveCredsFromSource(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch sharedCfg.CredentialSource {
case credSourceEc2Metadata:
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
case credSourceEnvironment:
creds = credentials.NewStaticCredentialsFromCreds(envCfg.Creds)
case credSourceECSContainer:
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
return nil, ErrSharedConfigECSContainerEnvVarEmpty
}
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
default:
return nil, ErrSharedConfigInvalidCredSource
}
return creds, nil
}
func credsFromAssumeRole(cfg aws.Config,
handlers request.Handlers,
sharedCfg sharedConfig,
sessOpts Options,
) (*credentials.Credentials, error) {
if len(sharedCfg.MFASerial) != 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return nil, AssumeRoleTokenProviderNotSetError{}
}
return stscreds.NewCredentials(
&Session{
Config: &cfg,
Handlers: handlers.Copy(),
},
sharedCfg.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.RoleSessionName
opt.Duration = sessOpts.AssumeRoleDuration
// Assume role with external ID
if len(sharedCfg.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.ExternalID)
}
// Assume role with MFA
if len(sharedCfg.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
), nil
}
// AssumeRoleTokenProviderNotSetError is an error returned when creating a
// session when the MFAToken option is not set when shared config is configured
// load assume a role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
}

View File

@@ -1,97 +1,93 @@
/* /*
Package session provides configuration for the SDK's service clients. Package session provides configuration for the SDK's service clients. Sessions
can be shared across service clients that share the same base configuration.
Sessions can be shared across all service clients that share the same base
configuration. The Session is built from the SDK's default configuration and
request handlers.
Sessions should be cached when possible, because creating a new Session will
load all configuration values from the environment, and config files each time
the Session is created. Sharing the Session value across all of your service
clients will ensure the configuration is loaded the fewest number of times possible.
Concurrency
Sessions are safe to use concurrently as long as the Session is not being Sessions are safe to use concurrently as long as the Session is not being
modified. The SDK will not modify the Session once the Session has been created. modified. Sessions should be cached when possible, because creating a new
Creating service clients concurrently from a shared Session is safe. Session will load all configuration values from the environment, and config
files each time the Session is created. Sharing the Session value across all of
your service clients will ensure the configuration is loaded the fewest number
of times possible.
Sessions from Shared Config Sessions options from Shared Config
Sessions can be created using the method above that will only load the
additional config if the AWS_SDK_LOAD_CONFIG environment variable is set.
Alternatively you can explicitly create a Session with shared config enabled.
To do this you can use NewSessionWithOptions to configure how the Session will
be created. Using the NewSessionWithOptions with SharedConfigState set to
SharedConfigEnable will create the session as if the AWS_SDK_LOAD_CONFIG
environment variable was set.
Creating Sessions
When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
By default NewSession will only load credentials from the shared credentials By default NewSession will only load credentials from the shared credentials
file (~/.aws/credentials). If the AWS_SDK_LOAD_CONFIG environment variable is file (~/.aws/credentials). If the AWS_SDK_LOAD_CONFIG environment variable is
set to a truthy value the Session will be created from the configuration set to a truthy value the Session will be created from the configuration
values from the shared config (~/.aws/config) and shared credentials values from the shared config (~/.aws/config) and shared credentials
(~/.aws/credentials) files. See the section Sessions from Shared Config for (~/.aws/credentials) files. Using the NewSessionWithOptions with
more information. SharedConfigState set to SharedConfigEnable will create the session as if the
AWS_SDK_LOAD_CONFIG environment variable was set.
Create a Session with the default config and request handlers. With credentials Credential and config loading order
region, and profile loaded from the environment and shared config automatically.
Requires the AWS_PROFILE to be set, or "default" is used. The Session will attempt to load configuration and credentials from the
environment, configuration files, and other credential sources. The order
configuration is loaded in is:
* Environment Variables
* Shared Credentials file
* Shared Configuration file (if SharedConfig is enabled)
* EC2 Instance Metadata (credentials only)
The Environment variables for credentials will have precedence over shared
config even if SharedConfig is enabled. To override this behavior, and use
shared config credentials instead specify the session.Options.Profile, (e.g.
when using credential_source=Environment to assume a role).
sess, err := session.NewSessionWithOptions(session.Options{
Profile: "myProfile",
})
Creating Sessions
Creating a Session without additional options will load credentials region, and
profile loaded from the environment and shared config automatically. See,
"Environment Variables" section for information on environment variables used
by Session.
// Create Session // Create Session
sess := session.Must(session.NewSession()) sess, err := session.NewSession()
When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded, config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
// Create a Session with a custom region // Create a Session with a custom region
sess := session.Must(session.NewSession(&aws.Config{ sess, err := session.NewSession(&aws.Config{
Region: aws.String("us-east-1"), Region: aws.String("us-west-2"),
})) })
// Create a S3 client instance from a session Use NewSessionWithOptions to provide additional configuration driving how the
sess := session.Must(session.NewSession()) Session's configuration will be loaded. Such as, specifying shared config
profile, or override the shared config state, (AWS_SDK_LOAD_CONFIG).
svc := s3.New(sess)
Create Session With Option Overrides
In addition to NewSession, Sessions can be created using NewSessionWithOptions.
This func allows you to control and override how the Session will be created
through code instead of being driven by environment variables only.
Use NewSessionWithOptions when you want to provide the config profile, or
override the shared config state (AWS_SDK_LOAD_CONFIG).
// Equivalent to session.NewSession() // Equivalent to session.NewSession()
sess := session.Must(session.NewSessionWithOptions(session.Options{ sess, err := session.NewSessionWithOptions(session.Options{
// Options // Options
})) })
sess, err := session.NewSessionWithOptions(session.Options{
// Specify profile to load for the session's config // Specify profile to load for the session's config
sess := session.Must(session.NewSessionWithOptions(session.Options{
Profile: "profile_name", Profile: "profile_name",
}))
// Specify profile for config and region for requests // Provide SDK Config options, such as Region.
sess := session.Must(session.NewSessionWithOptions(session.Options{ Config: aws.Config{
Config: aws.Config{Region: aws.String("us-east-1")}, Region: aws.String("us-west-2"),
Profile: "profile_name", },
}))
// Force enable Shared Config support // Force enable Shared Config support
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable, SharedConfigState: session.SharedConfigEnable,
})) })
Adding Handlers Adding Handlers
You can add handlers to a session for processing HTTP requests. All service You can add handlers to a session to decorate API operation, (e.g. adding HTTP
clients that use the session inherit the handlers. For example, the following headers). All clients that use the Session receive a copy of the Session's
handler logs every request and its payload made by a service client: handlers. For example, the following request handler added to the Session logs
every requests made.
// Create a session, and add additional handlers for all service // Create a session, and add additional handlers for all service
// clients created with the Session to inherit. Adds logging handler. // clients created with the Session to inherit. Adds logging handler.
@@ -99,22 +95,15 @@ handler logs every request and its payload made by a service client:
sess.Handlers.Send.PushFront(func(r *request.Request) { sess.Handlers.Send.PushFront(func(r *request.Request) {
// Log every request made and its payload // Log every request made and its payload
logger.Printf("Request: %s/%s, Payload: %s", logger.Printf("Request: %s/%s, Params: %s",
r.ClientInfo.ServiceName, r.Operation, r.Params) r.ClientInfo.ServiceName, r.Operation, r.Params)
}) })
Deprecated "New" function
The New session function has been deprecated because it does not provide good
way to return errors that occur when loading the configuration files and values.
Because of this, NewSession was created so errors can be retrieved when
creating a session fails.
Shared Config Fields Shared Config Fields
By default the SDK will only load the shared credentials file's (~/.aws/credentials) By default the SDK will only load the shared credentials file's
credentials values, and all other config is provided by the environment variables, (~/.aws/credentials) credentials values, and all other config is provided by
SDK defaults, and user provided aws.Config values. the environment variables, SDK defaults, and user provided aws.Config values.
If the AWS_SDK_LOAD_CONFIG environment variable is set, or SharedConfigEnable If the AWS_SDK_LOAD_CONFIG environment variable is set, or SharedConfigEnable
option is used to create the Session the full shared config values will be option is used to create the Session the full shared config values will be
@@ -125,24 +114,31 @@ files have the same format.
If both config files are present the configuration from both files will be If both config files are present the configuration from both files will be
read. The Session will be created from configuration values from the shared read. The Session will be created from configuration values from the shared
credentials file (~/.aws/credentials) over those in the shared config file (~/.aws/config). credentials file (~/.aws/credentials) over those in the shared config file
(~/.aws/config).
Credentials are the values the SDK should use for authenticating requests with Credentials are the values the SDK uses to authenticating requests with AWS
AWS Services. They are from a configuration file will need to include both Services. When specified in a file, both aws_access_key_id and
aws_access_key_id and aws_secret_access_key must be provided together in the aws_secret_access_key must be provided together in the same file to be
same file to be considered valid. The values will be ignored if not a complete considered valid. They will be ignored if both are not present.
group. aws_session_token is an optional field that can be provided if both of aws_session_token is an optional field that can be provided in addition to the
the other two fields are also provided. other two fields.
aws_access_key_id = AKID aws_access_key_id = AKID
aws_secret_access_key = SECRET aws_secret_access_key = SECRET
aws_session_token = TOKEN aws_session_token = TOKEN
Assume Role values allow you to configure the SDK to assume an IAM role using ; region only supported if SharedConfigEnabled.
a set of credentials provided in a config file via the source_profile field. region = us-east-1
Both "role_arn" and "source_profile" are required. The SDK supports assuming
a role with MFA token if the session option AssumeRoleTokenProvider Assume Role configuration
is set.
The role_arn field allows you to configure the SDK to assume an IAM role using
a set of credentials from another source. Such as when paired with static
credentials, "profile_source", "credential_process", or "credential_source"
fields. If "role_arn" is provided, a source of credentials must also be
specified, such as "source_profile", "credential_source", or
"credential_process".
role_arn = arn:aws:iam::<account_number>:role/<role_name> role_arn = arn:aws:iam::<account_number>:role/<role_name>
source_profile = profile_with_creds source_profile = profile_with_creds
@@ -150,40 +146,16 @@ is set.
mfa_serial = <serial or mfa arn> mfa_serial = <serial or mfa arn>
role_session_name = session_name role_session_name = session_name
Region is the region the SDK should use for looking up AWS service endpoints
and signing requests.
region = us-east-1 The SDK supports assuming a role with MFA token. If "mfa_serial" is set, you
must also set the Session Option.AssumeRoleTokenProvider. The Session will fail
Assume Role with MFA token to load if the AssumeRoleTokenProvider is not specified.
To create a session with support for assuming an IAM role with MFA set the
session option AssumeRoleTokenProvider to a function that will prompt for the
MFA token code when the SDK assumes the role and refreshes the role's credentials.
This allows you to configure the SDK via the shared config to assumea role
with MFA tokens.
In order for the SDK to assume a role with MFA the SharedConfigState
session option must be set to SharedConfigEnable, or AWS_SDK_LOAD_CONFIG
environment variable set.
The shared configuration instructs the SDK to assume an IAM role with MFA
when the mfa_serial configuration field is set in the shared config
(~/.aws/config) or shared credentials (~/.aws/credentials) file.
If mfa_serial is set in the configuration, the SDK will assume the role, and
the AssumeRoleTokenProvider session option is not set an an error will
be returned when creating the session.
sess := session.Must(session.NewSessionWithOptions(session.Options{ sess := session.Must(session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: stscreds.StdinTokenProvider, AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
})) }))
// Create service client value configured for credentials To setup Assume Role outside of a session see the stscreds.AssumeRoleProvider
// from assumed role.
svc := s3.New(sess)
To setup assume role outside of a session see the stscreds.AssumeRoleProvider
documentation. documentation.
Environment Variables Environment Variables

View File

@@ -1,12 +1,15 @@
package session package session
import ( import (
"fmt"
"os" "os"
"strconv" "strconv"
"strings"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
) )
// EnvProviderName provides a name of the provider when config is loaded from environment. // EnvProviderName provides a name of the provider when config is loaded from environment.
@@ -99,21 +102,61 @@ type envConfig struct {
CustomCABundle string CustomCABundle string
csmEnabled string csmEnabled string
CSMEnabled bool CSMEnabled *bool
CSMPort string CSMPort string
CSMHost string
CSMClientID string CSMClientID string
enableEndpointDiscovery string
// Enables endpoint discovery via environment variables. // Enables endpoint discovery via environment variables.
// //
// AWS_ENABLE_ENDPOINT_DISCOVERY=true // AWS_ENABLE_ENDPOINT_DISCOVERY=true
EnableEndpointDiscovery *bool EnableEndpointDiscovery *bool
enableEndpointDiscovery string
// Specifies the WebIdentity token the SDK should use to assume a role
// with.
//
// AWS_WEB_IDENTITY_TOKEN_FILE=file_path
WebIdentityTokenFilePath string
// Specifies the IAM role arn to use when assuming an role.
//
// AWS_ROLE_ARN=role_arn
RoleARN string
// Specifies the IAM role session name to use when assuming a role.
//
// AWS_ROLE_SESSION_NAME=session_name
RoleSessionName string
// Specifies the STS Regional Endpoint flag for the SDK to resolve the endpoint
// for a service.
//
// AWS_STS_REGIONAL_ENDPOINTS=regional
// This can take value as `regional` or `legacy`
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// Specifies the S3 Regional Endpoint flag for the SDK to resolve the
// endpoint for a service.
//
// AWS_S3_US_EAST_1_REGIONAL_ENDPOINT=regional
// This can take value as `regional` or `legacy`
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
// Specifies if the S3 service should allow ARNs to direct the region
// the client's requests are sent to.
//
// AWS_S3_USE_ARN_REGION=true
S3UseARNRegion bool
} }
var ( var (
csmEnabledEnvKey = []string{ csmEnabledEnvKey = []string{
"AWS_CSM_ENABLED", "AWS_CSM_ENABLED",
} }
csmHostEnvKey = []string{
"AWS_CSM_HOST",
}
csmPortEnvKey = []string{ csmPortEnvKey = []string{
"AWS_CSM_PORT", "AWS_CSM_PORT",
} }
@@ -150,6 +193,24 @@ var (
sharedConfigFileEnvKey = []string{ sharedConfigFileEnvKey = []string{
"AWS_CONFIG_FILE", "AWS_CONFIG_FILE",
} }
webIdentityTokenFilePathEnvKey = []string{
"AWS_WEB_IDENTITY_TOKEN_FILE",
}
roleARNEnvKey = []string{
"AWS_ROLE_ARN",
}
roleSessionNameEnvKey = []string{
"AWS_ROLE_SESSION_NAME",
}
stsRegionalEndpointKey = []string{
"AWS_STS_REGIONAL_ENDPOINTS",
}
s3UsEast1RegionalEndpoint = []string{
"AWS_S3_US_EAST_1_REGIONAL_ENDPOINT",
}
s3UseARNRegionEnvKey = []string{
"AWS_S3_USE_ARN_REGION",
}
) )
// loadEnvConfig retrieves the SDK's environment configuration. // loadEnvConfig retrieves the SDK's environment configuration.
@@ -158,7 +219,7 @@ var (
// If the environment variable `AWS_SDK_LOAD_CONFIG` is set to a truthy value // If the environment variable `AWS_SDK_LOAD_CONFIG` is set to a truthy value
// the shared SDK config will be loaded in addition to the SDK's specific // the shared SDK config will be loaded in addition to the SDK's specific
// configuration values. // configuration values.
func loadEnvConfig() envConfig { func loadEnvConfig() (envConfig, error) {
enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG")) enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG"))
return envConfigLoad(enableSharedConfig) return envConfigLoad(enableSharedConfig)
} }
@@ -169,30 +230,42 @@ func loadEnvConfig() envConfig {
// Loads the shared configuration in addition to the SDK's specific configuration. // Loads the shared configuration in addition to the SDK's specific configuration.
// This will load the same values as `loadEnvConfig` if the `AWS_SDK_LOAD_CONFIG` // This will load the same values as `loadEnvConfig` if the `AWS_SDK_LOAD_CONFIG`
// environment variable is set. // environment variable is set.
func loadSharedEnvConfig() envConfig { func loadSharedEnvConfig() (envConfig, error) {
return envConfigLoad(true) return envConfigLoad(true)
} }
func envConfigLoad(enableSharedConfig bool) envConfig { func envConfigLoad(enableSharedConfig bool) (envConfig, error) {
cfg := envConfig{} cfg := envConfig{}
cfg.EnableSharedConfig = enableSharedConfig cfg.EnableSharedConfig = enableSharedConfig
setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey) // Static environment credentials
setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey) var creds credentials.Value
setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey) setFromEnvVal(&creds.AccessKeyID, credAccessEnvKey)
setFromEnvVal(&creds.SecretAccessKey, credSecretEnvKey)
setFromEnvVal(&creds.SessionToken, credSessionEnvKey)
if creds.HasKeys() {
// Require logical grouping of credentials
creds.ProviderName = EnvProviderName
cfg.Creds = creds
}
// Role Metadata
setFromEnvVal(&cfg.RoleARN, roleARNEnvKey)
setFromEnvVal(&cfg.RoleSessionName, roleSessionNameEnvKey)
// Web identity environment variables
setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)
// CSM environment variables // CSM environment variables
setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey) setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
setFromEnvVal(&cfg.CSMHost, csmHostEnvKey)
setFromEnvVal(&cfg.CSMPort, csmPortEnvKey) setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey) setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
cfg.CSMEnabled = len(cfg.csmEnabled) > 0
// Require logical grouping of credentials if len(cfg.csmEnabled) != 0 {
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 { v, _ := strconv.ParseBool(cfg.csmEnabled)
cfg.Creds = credentials.Value{} cfg.CSMEnabled = &v
} else {
cfg.Creds.ProviderName = EnvProviderName
} }
regionKeys := regionEnvKeys regionKeys := regionEnvKeys
@@ -223,12 +296,48 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE") cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE")
return cfg var err error
// STS Regional Endpoint variable
for _, k := range stsRegionalEndpointKey {
if v := os.Getenv(k); len(v) != 0 {
cfg.STSRegionalEndpoint, err = endpoints.GetSTSRegionalEndpoint(v)
if err != nil {
return cfg, fmt.Errorf("failed to load, %v from env config, %v", k, err)
}
}
}
// S3 Regional Endpoint variable
for _, k := range s3UsEast1RegionalEndpoint {
if v := os.Getenv(k); len(v) != 0 {
cfg.S3UsEast1RegionalEndpoint, err = endpoints.GetS3UsEast1RegionalEndpoint(v)
if err != nil {
return cfg, fmt.Errorf("failed to load, %v from env config, %v", k, err)
}
}
}
var s3UseARNRegion string
setFromEnvVal(&s3UseARNRegion, s3UseARNRegionEnvKey)
if len(s3UseARNRegion) != 0 {
switch {
case strings.EqualFold(s3UseARNRegion, "false"):
cfg.S3UseARNRegion = false
case strings.EqualFold(s3UseARNRegion, "true"):
cfg.S3UseARNRegion = true
default:
return envConfig{}, fmt.Errorf(
"invalid value for environment variable, %s=%s, need true or false",
s3UseARNRegionEnvKey[0], s3UseARNRegion)
}
}
return cfg, nil
} }
func setFromEnvVal(dst *string, keys []string) { func setFromEnvVal(dst *string, keys []string) {
for _, k := range keys { for _, k := range keys {
if v := os.Getenv(k); len(v) > 0 { if v := os.Getenv(k); len(v) != 0 {
*dst = v *dst = v
break break
} }

View File

@@ -8,19 +8,17 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/corehandlers" "github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/csm" "github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
) )
const ( const (
@@ -75,7 +73,7 @@ type Session struct {
// func is called instead of waiting to receive an error until a request is made. // func is called instead of waiting to receive an error until a request is made.
func New(cfgs ...*aws.Config) *Session { func New(cfgs ...*aws.Config) *Session {
// load initial config from environment // load initial config from environment
envCfg := loadEnvConfig() envCfg, envErr := loadEnvConfig()
if envCfg.EnableSharedConfig { if envCfg.EnableSharedConfig {
var cfg aws.Config var cfg aws.Config
@@ -95,19 +93,28 @@ func New(cfgs ...*aws.Config) *Session {
// Session creation failed, need to report the error and prevent // Session creation failed, need to report the error and prevent
// any requests from succeeding. // any requests from succeeding.
s = &Session{Config: defaults.Config()} s = &Session{Config: defaults.Config()}
s.Config.MergeIn(cfgs...) s.logDeprecatedNewSessionError(msg, err, cfgs)
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
} }
return s return s
} }
s := deprecatedNewSession(cfgs...) s := deprecatedNewSession(cfgs...)
if envCfg.CSMEnabled { if envErr != nil {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger) msg := "failed to load env config"
s.logDeprecatedNewSessionError(msg, envErr, cfgs)
}
if csmCfg, err := loadCSMConfig(envCfg, []string{}); err != nil {
if l := s.Config.Logger; l != nil {
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
}
} else if csmCfg.Enabled {
err := enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
if err != nil {
msg := "failed to enable CSM"
s.logDeprecatedNewSessionError(msg, err, cfgs)
}
} }
return s return s
@@ -126,7 +133,7 @@ func New(cfgs ...*aws.Config) *Session {
// to be built with retrieving credentials with AssumeRole set in the config. // to be built with retrieving credentials with AssumeRole set in the config.
// //
// See the NewSessionWithOptions func for information on how to override or // See the NewSessionWithOptions func for information on how to override or
// control through code how the Session will be created. Such as specifying the // control through code how the Session will be created, such as specifying the
// config profile, and controlling if shared config is enabled or not. // config profile, and controlling if shared config is enabled or not.
func NewSession(cfgs ...*aws.Config) (*Session, error) { func NewSession(cfgs ...*aws.Config) (*Session, error) {
opts := Options{} opts := Options{}
@@ -210,6 +217,12 @@ type Options struct {
// the config enables assume role wit MFA via the mfa_serial field. // the config enables assume role wit MFA via the mfa_serial field.
AssumeRoleTokenProvider func() (string, error) AssumeRoleTokenProvider func() (string, error)
// When the SDK's shared config is configured to assume a role this option
// may be provided to set the expiry duration of the STS credentials.
// Defaults to 15 minutes if not set as documented in the
// stscreds.AssumeRoleProvider.
AssumeRoleDuration time.Duration
// Reader for a custom Credentials Authority (CA) bundle in PEM format that // Reader for a custom Credentials Authority (CA) bundle in PEM format that
// the SDK will use instead of the default system's root CA bundle. Use this // the SDK will use instead of the default system's root CA bundle. Use this
// only if you want to replace the CA bundle the SDK uses for TLS requests. // only if you want to replace the CA bundle the SDK uses for TLS requests.
@@ -224,6 +237,12 @@ type Options struct {
// to also enable this feature. CustomCABundle session option field has priority // to also enable this feature. CustomCABundle session option field has priority
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set. // over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
CustomCABundle io.Reader CustomCABundle io.Reader
// The handlers that the session and all API clients will be created with.
// This must be a complete set of handlers. Use the defaults.Handlers()
// function to initialize this value before changing the handlers to be
// used by the SDK.
Handlers request.Handlers
} }
// NewSessionWithOptions returns a new Session created from SDK defaults, config files, // NewSessionWithOptions returns a new Session created from SDK defaults, config files,
@@ -257,13 +276,20 @@ type Options struct {
// })) // }))
func NewSessionWithOptions(opts Options) (*Session, error) { func NewSessionWithOptions(opts Options) (*Session, error) {
var envCfg envConfig var envCfg envConfig
var err error
if opts.SharedConfigState == SharedConfigEnable { if opts.SharedConfigState == SharedConfigEnable {
envCfg = loadSharedEnvConfig() envCfg, err = loadSharedEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to load shared config, %v", err)
}
} else { } else {
envCfg = loadEnvConfig() envCfg, err = loadEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to load environment config, %v", err)
}
} }
if len(opts.Profile) > 0 { if len(opts.Profile) != 0 {
envCfg.Profile = opts.Profile envCfg.Profile = opts.Profile
} }
@@ -329,27 +355,33 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session {
return s return s
} }
func enableCSM(handlers *request.Handlers, clientID string, port string, logger aws.Logger) { func enableCSM(handlers *request.Handlers, cfg csmConfig, logger aws.Logger) error {
if logger != nil {
logger.Log("Enabling CSM") logger.Log("Enabling CSM")
if len(port) == 0 {
port = csm.DefaultPort
} }
r, err := csm.Start(clientID, "127.0.0.1:"+port) r, err := csm.Start(cfg.ClientID, csm.AddressWithDefaults(cfg.Host, cfg.Port))
if err != nil { if err != nil {
return return err
} }
r.InjectHandlers(handlers) r.InjectHandlers(handlers)
return nil
} }
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) { func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
cfg := defaults.Config() cfg := defaults.Config()
handlers := defaults.Handlers()
handlers := opts.Handlers
if handlers.IsEmpty() {
handlers = defaults.Handlers()
}
// Get a merged version of the user provided config to determine if // Get a merged version of the user provided config to determine if
// credentials were. // credentials were.
userCfg := &aws.Config{} userCfg := &aws.Config{}
userCfg.MergeIn(cfgs...) userCfg.MergeIn(cfgs...)
cfg.MergeIn(userCfg)
// Ordered config files will be loaded in with later files overwriting // Ordered config files will be loaded in with later files overwriting
// previous config file values. // previous config file values.
@@ -366,10 +398,18 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
} }
// Load additional config from file(s) // Load additional config from file(s)
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles) sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles, envCfg.EnableSharedConfig)
if err != nil { if err != nil {
if len(envCfg.Profile) == 0 && !envCfg.EnableSharedConfig && (envCfg.Creds.HasKeys() || userCfg.Credentials != nil) {
// Special case where the user has not explicitly specified an AWS_PROFILE,
// or session.Options.profile, shared config is not enabled, and the
// environment has credentials, allow the shared config file to fail to
// load since the user has already provided credentials, and nothing else
// is required to be read file. Github(aws/aws-sdk-go#2455)
} else if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return nil, err return nil, err
} }
}
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil { if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
return nil, err return nil, err
@@ -381,8 +421,16 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
} }
initHandlers(s) initHandlers(s)
if envCfg.CSMEnabled {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger) if csmCfg, err := loadCSMConfig(envCfg, cfgFiles); err != nil {
if l := s.Config.Logger; l != nil {
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
}
} else if csmCfg.Enabled {
err = enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
if err != nil {
return nil, err
}
} }
// Setup HTTP client with custom cert bundle if enabled // Setup HTTP client with custom cert bundle if enabled
@@ -395,6 +443,46 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
return s, nil return s, nil
} }
type csmConfig struct {
Enabled bool
Host string
Port string
ClientID string
}
var csmProfileName = "aws_csm"
func loadCSMConfig(envCfg envConfig, cfgFiles []string) (csmConfig, error) {
if envCfg.CSMEnabled != nil {
if *envCfg.CSMEnabled {
return csmConfig{
Enabled: true,
ClientID: envCfg.CSMClientID,
Host: envCfg.CSMHost,
Port: envCfg.CSMPort,
}, nil
}
return csmConfig{}, nil
}
sharedCfg, err := loadSharedConfig(csmProfileName, cfgFiles, false)
if err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return csmConfig{}, err
}
}
if sharedCfg.CSMEnabled != nil && *sharedCfg.CSMEnabled == true {
return csmConfig{
Enabled: true,
ClientID: sharedCfg.CSMClientID,
Host: sharedCfg.CSMHost,
Port: sharedCfg.CSMPort,
}, nil
}
return csmConfig{}, nil
}
func loadCustomCABundle(s *Session, bundle io.Reader) error { func loadCustomCABundle(s *Session, bundle io.Reader) error {
var t *http.Transport var t *http.Transport
switch v := s.Config.HTTPClient.Transport.(type) { switch v := s.Config.HTTPClient.Transport.(type) {
@@ -407,7 +495,10 @@ func loadCustomCABundle(s *Session, bundle io.Reader) error {
} }
} }
if t == nil { if t == nil {
t = &http.Transport{} // Nil transport implies `http.DefaultTransport` should be used. Since
// the SDK cannot modify, nor copy the `DefaultTransport` specifying
// the values the next closest behavior.
t = getCABundleTransport()
} }
p, err := loadCertPool(bundle) p, err := loadCertPool(bundle)
@@ -440,9 +531,11 @@ func loadCertPool(r io.Reader) (*x509.CertPool, error) {
return p, nil return p, nil
} }
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error { func mergeConfigSrcs(cfg, userCfg *aws.Config,
// Merge in user provided configuration envCfg envConfig, sharedCfg sharedConfig,
cfg.MergeIn(userCfg) handlers request.Handlers,
sessOpts Options,
) error {
// Region if not already set by user // Region if not already set by user
if len(aws.StringValue(cfg.Region)) == 0 { if len(aws.StringValue(cfg.Region)) == 0 {
@@ -461,162 +554,59 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
} }
} }
// Configure credentials if not already set // Regional Endpoint flag for STS endpoint resolving
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil { mergeSTSRegionalEndpointConfig(cfg, []endpoints.STSRegionalEndpoint{
userCfg.STSRegionalEndpoint,
// inspect the profile to see if a credential source has been specified. envCfg.STSRegionalEndpoint,
if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.CredentialSource) > 0 { sharedCfg.STSRegionalEndpoint,
endpoints.LegacySTSEndpoint,
// if both credential_source and source_profile have been set, return an error
// as this is undefined behavior.
if len(sharedCfg.AssumeRole.SourceProfile) > 0 {
return ErrSharedConfigSourceCollision
}
// valid credential source values
const (
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
)
switch sharedCfg.AssumeRole.CredentialSource {
case credSourceEc2Metadata:
cfgCp := *cfg
p := defaults.RemoteCredProvider(cfgCp, handlers)
cfgCp.Credentials = credentials.NewCredentials(p)
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return AssumeRoleTokenProviderNotSetError{}
}
cfg.Credentials = assumeRoleCredentials(cfgCp, handlers, sharedCfg, sessOpts)
case credSourceEnvironment:
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
envCfg.Creds,
)
case credSourceECSContainer:
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
return ErrSharedConfigECSContainerEnvVarEmpty
}
cfgCp := *cfg
p := defaults.RemoteCredProvider(cfgCp, handlers)
creds := credentials.NewCredentials(p)
cfg.Credentials = creds
default:
return ErrSharedConfigInvalidCredSource
}
return nil
}
if len(envCfg.Creds.AccessKeyID) > 0 {
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
envCfg.Creds,
)
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
cfgCp := *cfg
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.AssumeRoleSource.Creds,
)
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return AssumeRoleTokenProviderNotSetError{}
}
cfg.Credentials = assumeRoleCredentials(cfgCp, handlers, sharedCfg, sessOpts)
} else if len(sharedCfg.Creds.AccessKeyID) > 0 {
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
} else if len(sharedCfg.CredentialProcess) > 0 {
cfg.Credentials = processcreds.NewCredentials(
sharedCfg.CredentialProcess,
)
} else {
// Fallback to default credentials provider, include mock errors
// for the credential chain so user can identify why credentials
// failed to be retrieved.
cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)},
&credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)},
defaults.RemoteCredProvider(*cfg, handlers),
},
}) })
// Regional Endpoint flag for S3 endpoint resolving
mergeS3UsEast1RegionalEndpointConfig(cfg, []endpoints.S3UsEast1RegionalEndpoint{
userCfg.S3UsEast1RegionalEndpoint,
envCfg.S3UsEast1RegionalEndpoint,
sharedCfg.S3UsEast1RegionalEndpoint,
endpoints.LegacyS3UsEast1Endpoint,
})
// Configure credentials if not already set by the user when creating the
// Session.
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
creds, err := resolveCredentials(cfg, envCfg, sharedCfg, handlers, sessOpts)
if err != nil {
return err
} }
cfg.Credentials = creds
}
cfg.S3UseARNRegion = userCfg.S3UseARNRegion
if cfg.S3UseARNRegion == nil {
cfg.S3UseARNRegion = &envCfg.S3UseARNRegion
}
if cfg.S3UseARNRegion == nil {
cfg.S3UseARNRegion = &sharedCfg.S3UseARNRegion
} }
return nil return nil
} }
func assumeRoleCredentials(cfg aws.Config, handlers request.Handlers, sharedCfg sharedConfig, sessOpts Options) *credentials.Credentials { func mergeSTSRegionalEndpointConfig(cfg *aws.Config, values []endpoints.STSRegionalEndpoint) {
return stscreds.NewCredentials( for _, v := range values {
&Session{ if v != endpoints.UnsetSTSEndpoint {
Config: &cfg, cfg.STSRegionalEndpoint = v
Handlers: handlers.Copy(), break
},
sharedCfg.AssumeRole.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName
// Assume role with external ID
if len(sharedCfg.AssumeRole.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID)
} }
// Assume role with MFA
if len(sharedCfg.AssumeRole.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
} }
},
)
} }
// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the func mergeS3UsEast1RegionalEndpointConfig(cfg *aws.Config, values []endpoints.S3UsEast1RegionalEndpoint) {
// MFAToken option is not set when shared config is configured load assume a for _, v := range values {
// role with an MFA token. if v != endpoints.UnsetS3UsEast1Endpoint {
type AssumeRoleTokenProviderNotSetError struct{} cfg.S3UsEast1RegionalEndpoint = v
break
// Code is the short id of the error. }
func (e AssumeRoleTokenProviderNotSetError) Code() string { }
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
var emptyCreds = credentials.Value{}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
} }
func initHandlers(s *Session) { func initHandlers(s *Session) {
@@ -627,7 +617,7 @@ func initHandlers(s *Session) {
} }
} }
// Copy creates and returns a copy of the current Session, coping the config // Copy creates and returns a copy of the current Session, copying the config
// and handlers. If any additional configs are provided they will be merged // and handlers. If any additional configs are provided they will be merged
// on top of the Session's copied config. // on top of the Session's copied config.
// //
@@ -647,47 +637,67 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session {
// ClientConfig satisfies the client.ConfigProvider interface and is used to // ClientConfig satisfies the client.ConfigProvider interface and is used to
// configure the service client instances. Passing the Session to the service // configure the service client instances. Passing the Session to the service
// client's constructor (New) will use this method to configure the client. // client's constructor (New) will use this method to configure the client.
func (s *Session) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config { func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Config {
// Backwards compatibility, the error will be eaten if user calls ClientConfig
// directly. All SDK services will use ClientconfigWithError.
cfg, _ := s.clientConfigWithErr(serviceName, cfgs...)
return cfg
}
func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (client.Config, error) {
s = s.Copy(cfgs...) s = s.Copy(cfgs...)
var resolved endpoints.ResolvedEndpoint
var err error
region := aws.StringValue(s.Config.Region) region := aws.StringValue(s.Config.Region)
resolved, err := s.resolveEndpoint(service, region, s.Config)
if err != nil {
s.Handlers.Validate.PushBack(func(r *request.Request) {
if len(r.ClientInfo.Endpoint) != 0 {
// Error occurred while resolving endpoint, but the request
// being invoked has had an endpoint specified after the client
// was created.
return
}
r.Error = err
})
}
if endpoint := aws.StringValue(s.Config.Endpoint); len(endpoint) != 0 { return client.Config{
resolved.URL = endpoints.AddScheme(endpoint, aws.BoolValue(s.Config.DisableSSL)) Config: s.Config,
resolved.SigningRegion = region Handlers: s.Handlers,
} else { PartitionID: resolved.PartitionID,
resolved, err = s.Config.EndpointResolver.EndpointFor( Endpoint: resolved.URL,
serviceName, region, SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName,
}
}
func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endpoints.ResolvedEndpoint, error) {
if ep := aws.StringValue(cfg.Endpoint); len(ep) != 0 {
return endpoints.ResolvedEndpoint{
URL: endpoints.AddScheme(ep, aws.BoolValue(cfg.DisableSSL)),
SigningRegion: region,
}, nil
}
resolved, err := cfg.EndpointResolver.EndpointFor(service, region,
func(opt *endpoints.Options) { func(opt *endpoints.Options) {
opt.DisableSSL = aws.BoolValue(s.Config.DisableSSL) opt.DisableSSL = aws.BoolValue(cfg.DisableSSL)
opt.UseDualStack = aws.BoolValue(s.Config.UseDualStack) opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)
// Support for STSRegionalEndpoint where the STSRegionalEndpoint is
// provided in envConfig or sharedConfig with envConfig getting
// precedence.
opt.STSRegionalEndpoint = cfg.STSRegionalEndpoint
// Support for S3UsEast1RegionalEndpoint where the S3UsEast1RegionalEndpoint is
// provided in envConfig or sharedConfig with envConfig getting
// precedence.
opt.S3UsEast1RegionalEndpoint = cfg.S3UsEast1RegionalEndpoint
// Support the condition where the service is modeled but its // Support the condition where the service is modeled but its
// endpoint metadata is not available. // endpoint metadata is not available.
opt.ResolveUnknownService = true opt.ResolveUnknownService = true
}, },
) )
if err != nil {
return endpoints.ResolvedEndpoint{}, err
} }
return client.Config{ return resolved, nil
Config: s.Config,
Handlers: s.Handlers,
Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName,
}, err
} }
// ClientConfigNoResolveEndpoint is the same as ClientConfig with the exception // ClientConfigNoResolveEndpoint is the same as ClientConfig with the exception
@@ -697,12 +707,9 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
s = s.Copy(cfgs...) s = s.Copy(cfgs...)
var resolved endpoints.ResolvedEndpoint var resolved endpoints.ResolvedEndpoint
region := aws.StringValue(s.Config.Region)
if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 { if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 {
resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL)) resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL))
resolved.SigningRegion = region resolved.SigningRegion = aws.StringValue(s.Config.Region)
} }
return client.Config{ return client.Config{
@@ -714,3 +721,14 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
SigningName: resolved.SigningName, SigningName: resolved.SigningName,
} }
} }
// logDeprecatedNewSessionError function enables error handling for session
func (s *Session) logDeprecatedNewSessionError(msg string, err error, cfgs []*aws.Config) {
// Session creation failed, need to report the error and prevent
// any requests from succeeding.
s.Config.MergeIn(cfgs...)
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/internal/ini" "github.com/aws/aws-sdk-go/internal/ini"
) )
@@ -23,50 +23,66 @@ const (
mfaSerialKey = `mfa_serial` // optional mfaSerialKey = `mfa_serial` // optional
roleSessionNameKey = `role_session_name` // optional roleSessionNameKey = `role_session_name` // optional
// CSM options
csmEnabledKey = `csm_enabled`
csmHostKey = `csm_host`
csmPortKey = `csm_port`
csmClientIDKey = `csm_client_id`
// Additional Config fields // Additional Config fields
regionKey = `region` regionKey = `region`
// endpoint discovery group // endpoint discovery group
enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional
// External Credential Process // External Credential Process
credentialProcessKey = `credential_process` credentialProcessKey = `credential_process` // optional
// Web Identity Token File
webIdentityTokenFileKey = `web_identity_token_file` // optional
// Additional config fields for regional or legacy endpoints
stsRegionalEndpointSharedKey = `sts_regional_endpoints`
// Additional config fields for regional or legacy endpoints
s3UsEast1RegionalSharedKey = `s3_us_east_1_regional_endpoint`
// DefaultSharedConfigProfile is the default profile to be used when // DefaultSharedConfigProfile is the default profile to be used when
// loading configuration from the config files if another profile name // loading configuration from the config files if another profile name
// is not provided. // is not provided.
DefaultSharedConfigProfile = `default` DefaultSharedConfigProfile = `default`
)
type assumeRoleConfig struct { // S3 ARN Region Usage
RoleARN string s3UseARNRegionKey = "s3_use_arn_region"
SourceProfile string )
CredentialSource string
ExternalID string
MFASerial string
RoleSessionName string
}
// sharedConfig represents the configuration fields of the SDK config files. // sharedConfig represents the configuration fields of the SDK config files.
type sharedConfig struct { type sharedConfig struct {
// Credentials values from the config file. Both aws_access_key_id // Credentials values from the config file. Both aws_access_key_id and
// and aws_secret_access_key must be provided together in the same file // aws_secret_access_key must be provided together in the same file to be
// to be considered valid. The values will be ignored if not a complete group. // considered valid. The values will be ignored if not a complete group.
// aws_session_token is an optional field that can be provided if both of the // aws_session_token is an optional field that can be provided if both of
// other two fields are also provided. // the other two fields are also provided.
// //
// aws_access_key_id // aws_access_key_id
// aws_secret_access_key // aws_secret_access_key
// aws_session_token // aws_session_token
Creds credentials.Value Creds credentials.Value
AssumeRole assumeRoleConfig CredentialSource string
AssumeRoleSource *sharedConfig
// An external process to request credentials
CredentialProcess string CredentialProcess string
WebIdentityTokenFile string
// Region is the region the SDK should use for looking up AWS service endpoints RoleARN string
// and signing requests. RoleSessionName string
ExternalID string
MFASerial string
SourceProfileName string
SourceProfile *sharedConfig
// Region is the region the SDK should use for looking up AWS service
// endpoints and signing requests.
// //
// region // region
Region string Region string
@@ -76,6 +92,30 @@ type sharedConfig struct {
// //
// endpoint_discovery_enabled = true // endpoint_discovery_enabled = true
EnableEndpointDiscovery *bool EnableEndpointDiscovery *bool
// CSM Options
CSMEnabled *bool
CSMHost string
CSMPort string
CSMClientID string
// Specifies the Regional Endpoint flag for the SDK to resolve the endpoint for a service
//
// sts_regional_endpoints = regional
// This can take value as `LegacySTSEndpoint` or `RegionalSTSEndpoint`
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// Specifies the Regional Endpoint flag for the SDK to resolve the endpoint for a service
//
// s3_us_east_1_regional_endpoint = regional
// This can take value as `LegacyS3UsEast1Endpoint` or `RegionalS3UsEast1Endpoint`
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
// Specifies if the S3 service should allow ARNs to direct the region
// the client's requests are sent to.
//
// s3_use_arn_region=true
S3UseARNRegion bool
} }
type sharedConfigFile struct { type sharedConfigFile struct {
@@ -83,17 +123,18 @@ type sharedConfigFile struct {
IniData ini.Sections IniData ini.Sections
} }
// loadSharedConfig retrieves the configuration from the list of files // loadSharedConfig retrieves the configuration from the list of files using
// using the profile provided. The order the files are listed will determine // the profile provided. The order the files are listed will determine
// precedence. Values in subsequent files will overwrite values defined in // precedence. Values in subsequent files will overwrite values defined in
// earlier files. // earlier files.
// //
// For example, given two files A and B. Both define credentials. If the order // For example, given two files A and B. Both define credentials. If the order
// of the files are A then B, B's credential values will be used instead of A's. // of the files are A then B, B's credential values will be used instead of
// A's.
// //
// See sharedConfig.setFromFile for information how the config files // See sharedConfig.setFromFile for information how the config files
// will be loaded. // will be loaded.
func loadSharedConfig(profile string, filenames []string) (sharedConfig, error) { func loadSharedConfig(profile string, filenames []string, exOpts bool) (sharedConfig, error) {
if len(profile) == 0 { if len(profile) == 0 {
profile = DefaultSharedConfigProfile profile = DefaultSharedConfigProfile
} }
@@ -104,16 +145,11 @@ func loadSharedConfig(profile string, filenames []string) (sharedConfig, error)
} }
cfg := sharedConfig{} cfg := sharedConfig{}
if err = cfg.setFromIniFiles(profile, files); err != nil { profiles := map[string]struct{}{}
if err = cfg.setFromIniFiles(profiles, profile, files, exOpts); err != nil {
return sharedConfig{}, err return sharedConfig{}, err
} }
if len(cfg.AssumeRole.SourceProfile) > 0 {
if err := cfg.setAssumeRoleSource(profile, files); err != nil {
return sharedConfig{}, err
}
}
return cfg, nil return cfg, nil
} }
@@ -137,60 +173,88 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
return files, nil return files, nil
} }
func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedConfigFile) error { func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error {
var assumeRoleSrc sharedConfig
if len(cfg.AssumeRole.CredentialSource) > 0 {
// setAssumeRoleSource is only called when source_profile is found.
// If both source_profile and credential_source are set, then
// ErrSharedConfigSourceCollision will be returned
return ErrSharedConfigSourceCollision
}
// Multiple level assume role chains are not support
if cfg.AssumeRole.SourceProfile == origProfile {
assumeRoleSrc = *cfg
assumeRoleSrc.AssumeRole = assumeRoleConfig{}
} else {
err := assumeRoleSrc.setFromIniFiles(cfg.AssumeRole.SourceProfile, files)
if err != nil {
return err
}
}
if len(assumeRoleSrc.Creds.AccessKeyID) == 0 {
return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN}
}
cfg.AssumeRoleSource = &assumeRoleSrc
return nil
}
func (cfg *sharedConfig) setFromIniFiles(profile string, files []sharedConfigFile) error {
// Trim files from the list that don't exist. // Trim files from the list that don't exist.
var skippedFiles int
var profileNotFoundErr error
for _, f := range files { for _, f := range files {
if err := cfg.setFromIniFile(profile, f); err != nil { if err := cfg.setFromIniFile(profile, f, exOpts); err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); ok { if _, ok := err.(SharedConfigProfileNotExistsError); ok {
// Ignore proviles missings // Ignore profiles not defined in individual files.
profileNotFoundErr = err
skippedFiles++
continue continue
} }
return err return err
} }
} }
if skippedFiles == len(files) {
// If all files were skipped because the profile is not found, return
// the original profile not found error.
return profileNotFoundErr
}
if _, ok := profiles[profile]; ok {
// if this is the second instance of the profile the Assume Role
// options must be cleared because they are only valid for the
// first reference of a profile. The self linked instance of the
// profile only have credential provider options.
cfg.clearAssumeRoleOptions()
} else {
// First time a profile has been seen, It must either be a assume role
// or credentials. Assert if the credential type requires a role ARN,
// the ARN is also set.
if err := cfg.validateCredentialsRequireARN(profile); err != nil {
return err
}
}
profiles[profile] = struct{}{}
if err := cfg.validateCredentialType(); err != nil {
return err
}
// Link source profiles for assume roles
if len(cfg.SourceProfileName) != 0 {
// Linked profile via source_profile ignore credential provider
// options, the source profile must provide the credentials.
cfg.clearCredentialOptions()
srcCfg := &sharedConfig{}
err := srcCfg.setFromIniFiles(profiles, cfg.SourceProfileName, files, exOpts)
if err != nil {
// SourceProfile that doesn't exist is an error in configuration.
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
err = SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
return err
}
if !srcCfg.hasCredentials() {
return SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
cfg.SourceProfile = srcCfg
}
return nil return nil
} }
// setFromFile loads the configuration from the file using // setFromFile loads the configuration from the file using the profile
// the profile provided. A sharedConfig pointer type value is used so that // provided. A sharedConfig pointer type value is used so that multiple config
// multiple config file loadings can be chained. // file loadings can be chained.
// //
// Only loads complete logically grouped values, and will not set fields in cfg // Only loads complete logically grouped values, and will not set fields in cfg
// for incomplete grouped values in the config. Such as credentials. For example // for incomplete grouped values in the config. Such as credentials. For
// if a config file only includes aws_access_key_id but no aws_secret_access_key // example if a config file only includes aws_access_key_id but no
// the aws_access_key_id will be ignored. // aws_secret_access_key the aws_access_key_id will be ignored.
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) error { func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, exOpts bool) error {
section, ok := file.IniData.GetSection(profile) section, ok := file.IniData.GetSection(profile)
if !ok { if !ok {
// Fallback to to alternate profile name: profile <name> // Fallback to to alternate profile name: profile <name>
@@ -200,53 +264,171 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) e
} }
} }
if exOpts {
// Assume Role Parameters
updateString(&cfg.RoleARN, section, roleArnKey)
updateString(&cfg.ExternalID, section, externalIDKey)
updateString(&cfg.MFASerial, section, mfaSerialKey)
updateString(&cfg.RoleSessionName, section, roleSessionNameKey)
updateString(&cfg.SourceProfileName, section, sourceProfileKey)
updateString(&cfg.CredentialSource, section, credentialSourceKey)
updateString(&cfg.Region, section, regionKey)
if v := section.String(stsRegionalEndpointSharedKey); len(v) != 0 {
sre, err := endpoints.GetSTSRegionalEndpoint(v)
if err != nil {
return fmt.Errorf("failed to load %s from shared config, %s, %v",
stsRegionalEndpointSharedKey, file.Filename, err)
}
cfg.STSRegionalEndpoint = sre
}
if v := section.String(s3UsEast1RegionalSharedKey); len(v) != 0 {
sre, err := endpoints.GetS3UsEast1RegionalEndpoint(v)
if err != nil {
return fmt.Errorf("failed to load %s from shared config, %s, %v",
s3UsEast1RegionalSharedKey, file.Filename, err)
}
cfg.S3UsEast1RegionalEndpoint = sre
}
}
updateString(&cfg.CredentialProcess, section, credentialProcessKey)
updateString(&cfg.WebIdentityTokenFile, section, webIdentityTokenFileKey)
// Shared Credentials // Shared Credentials
akid := section.String(accessKeyIDKey) creds := credentials.Value{
secret := section.String(secretAccessKey) AccessKeyID: section.String(accessKeyIDKey),
if len(akid) > 0 && len(secret) > 0 { SecretAccessKey: section.String(secretAccessKey),
cfg.Creds = credentials.Value{
AccessKeyID: akid,
SecretAccessKey: secret,
SessionToken: section.String(sessionTokenKey), SessionToken: section.String(sessionTokenKey),
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename), ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
} }
} if creds.HasKeys() {
cfg.Creds = creds
// Assume Role
roleArn := section.String(roleArnKey)
srcProfile := section.String(sourceProfileKey)
credentialSource := section.String(credentialSourceKey)
hasSource := len(srcProfile) > 0 || len(credentialSource) > 0
if len(roleArn) > 0 && hasSource {
cfg.AssumeRole = assumeRoleConfig{
RoleARN: roleArn,
SourceProfile: srcProfile,
CredentialSource: credentialSource,
ExternalID: section.String(externalIDKey),
MFASerial: section.String(mfaSerialKey),
RoleSessionName: section.String(roleSessionNameKey),
}
}
// `credential_process`
if credProc := section.String(credentialProcessKey); len(credProc) > 0 {
cfg.CredentialProcess = credProc
}
// Region
if v := section.String(regionKey); len(v) > 0 {
cfg.Region = v
} }
// Endpoint discovery // Endpoint discovery
if section.Has(enableEndpointDiscoveryKey) { updateBoolPtr(&cfg.EnableEndpointDiscovery, section, enableEndpointDiscoveryKey)
v := section.Bool(enableEndpointDiscoveryKey)
cfg.EnableEndpointDiscovery = &v // CSM options
updateBoolPtr(&cfg.CSMEnabled, section, csmEnabledKey)
updateString(&cfg.CSMHost, section, csmHostKey)
updateString(&cfg.CSMPort, section, csmPortKey)
updateString(&cfg.CSMClientID, section, csmClientIDKey)
updateBool(&cfg.S3UseARNRegion, section, s3UseARNRegionKey)
return nil
}
func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error {
var credSource string
switch {
case len(cfg.SourceProfileName) != 0:
credSource = sourceProfileKey
case len(cfg.CredentialSource) != 0:
credSource = credentialSourceKey
case len(cfg.WebIdentityTokenFile) != 0:
credSource = webIdentityTokenFileKey
}
if len(credSource) != 0 && len(cfg.RoleARN) == 0 {
return CredentialRequiresARNError{
Type: credSource,
Profile: profile,
}
} }
return nil return nil
} }
func (cfg *sharedConfig) validateCredentialType() error {
// Only one or no credential type can be defined.
if !oneOrNone(
len(cfg.SourceProfileName) != 0,
len(cfg.CredentialSource) != 0,
len(cfg.CredentialProcess) != 0,
len(cfg.WebIdentityTokenFile) != 0,
) {
return ErrSharedConfigSourceCollision
}
return nil
}
func (cfg *sharedConfig) hasCredentials() bool {
switch {
case len(cfg.SourceProfileName) != 0:
case len(cfg.CredentialSource) != 0:
case len(cfg.CredentialProcess) != 0:
case len(cfg.WebIdentityTokenFile) != 0:
case cfg.Creds.HasKeys():
default:
return false
}
return true
}
func (cfg *sharedConfig) clearCredentialOptions() {
cfg.CredentialSource = ""
cfg.CredentialProcess = ""
cfg.WebIdentityTokenFile = ""
cfg.Creds = credentials.Value{}
}
func (cfg *sharedConfig) clearAssumeRoleOptions() {
cfg.RoleARN = ""
cfg.ExternalID = ""
cfg.MFASerial = ""
cfg.RoleSessionName = ""
cfg.SourceProfileName = ""
}
func oneOrNone(bs ...bool) bool {
var count int
for _, b := range bs {
if b {
count++
if count > 1 {
return false
}
}
}
return true
}
// updateString will only update the dst with the value in the section key, key
// is present in the section.
func updateString(dst *string, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = section.String(key)
}
// updateBool will only update the dst with the value in the section key, key
// is present in the section.
func updateBool(dst *bool, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = section.Bool(key)
}
// updateBoolPtr will only update the dst with the value in the section key,
// key is present in the section.
func updateBoolPtr(dst **bool, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = new(bool)
**dst = section.Bool(key)
}
// SharedConfigLoadError is an error for the shared config file failed to load. // SharedConfigLoadError is an error for the shared config file failed to load.
type SharedConfigLoadError struct { type SharedConfigLoadError struct {
Filename string Filename string
@@ -305,6 +487,7 @@ func (e SharedConfigProfileNotExistsError) Error() string {
// or not complete. // or not complete.
type SharedConfigAssumeRoleError struct { type SharedConfigAssumeRoleError struct {
RoleARN string RoleARN string
SourceProfile string
} }
// Code is the short id of the error. // Code is the short id of the error.
@@ -314,8 +497,10 @@ func (e SharedConfigAssumeRoleError) Code() string {
// Message is the description of the error // Message is the description of the error
func (e SharedConfigAssumeRoleError) Message() string { func (e SharedConfigAssumeRoleError) Message() string {
return fmt.Sprintf("failed to load assume role for %s, source profile has no shared credentials", return fmt.Sprintf(
e.RoleARN) "failed to load assume role for %s, source profile %s has no shared credentials",
e.RoleARN, e.SourceProfile,
)
} }
// OrigErr is the underlying error that caused the failure. // OrigErr is the underlying error that caused the failure.
@@ -327,3 +512,36 @@ func (e SharedConfigAssumeRoleError) OrigErr() error {
func (e SharedConfigAssumeRoleError) Error() string { func (e SharedConfigAssumeRoleError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil) return awserr.SprintError(e.Code(), e.Message(), "", nil)
} }
// CredentialRequiresARNError provides the error for shared config credentials
// that are incorrectly configured in the shared config or credentials file.
type CredentialRequiresARNError struct {
// type of credentials that were configured.
Type string
// Profile name the credentials were in.
Profile string
}
// Code is the short id of the error.
func (e CredentialRequiresARNError) Code() string {
return "CredentialRequiresARNError"
}
// Message is the description of the error
func (e CredentialRequiresARNError) Message() string {
return fmt.Sprintf(
"credential type %s requires role_arn, profile %s",
e.Type, e.Profile,
)
}
// OrigErr is the underlying error that caused the failure.
func (e CredentialRequiresARNError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e CredentialRequiresARNError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}

View File

@@ -5,6 +5,7 @@ go_library(
srcs = [ srcs = [
"header_rules.go", "header_rules.go",
"options.go", "options.go",
"stream.go",
"uri_path.go", "uri_path.go",
"v4.go", "v4.go",
], ],
@@ -16,6 +17,7 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkio:go_default_library", "//vendor/github.com/aws/aws-sdk-go/internal/sdkio:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/internal/strings:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/private/protocol/rest:go_default_library", "//vendor/github.com/aws/aws-sdk-go/private/protocol/rest:go_default_library",
], ],
) )

View File

@@ -1,8 +1,7 @@
package v4 package v4
import ( import (
"net/http" "github.com/aws/aws-sdk-go/internal/strings"
"strings"
) )
// validator houses a set of rule needed for validation of a // validator houses a set of rule needed for validation of a
@@ -61,7 +60,7 @@ type patterns []string
// been found // been found
func (p patterns) IsValid(value string) bool { func (p patterns) IsValid(value string) bool {
for _, pattern := range p { for _, pattern := range p {
if strings.HasPrefix(http.CanonicalHeaderKey(value), pattern) { if strings.HasPrefixFold(value, pattern) {
return true return true
} }
} }

View File

@@ -0,0 +1,63 @@
package v4
import (
"encoding/hex"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
type credentialValueProvider interface {
Get() (credentials.Value, error)
}
// StreamSigner implements signing of event stream encoded payloads
type StreamSigner struct {
region string
service string
credentials credentialValueProvider
prevSig []byte
}
// NewStreamSigner creates a SigV4 signer used to sign Event Stream encoded messages
func NewStreamSigner(region, service string, seedSignature []byte, credentials *credentials.Credentials) *StreamSigner {
return &StreamSigner{
region: region,
service: service,
credentials: credentials,
prevSig: seedSignature,
}
}
// GetSignature takes an event stream encoded headers and payload and returns a signature
func (s *StreamSigner) GetSignature(headers, payload []byte, date time.Time) ([]byte, error) {
credValue, err := s.credentials.Get()
if err != nil {
return nil, err
}
sigKey := deriveSigningKey(s.region, s.service, credValue.SecretAccessKey, date)
keyPath := buildSigningScope(s.region, s.service, date)
stringToSign := buildEventStreamStringToSign(headers, payload, s.prevSig, keyPath, date)
signature := hmacSHA256(sigKey, []byte(stringToSign))
s.prevSig = signature
return signature, nil
}
func buildEventStreamStringToSign(headers, payload, prevSig []byte, scope string, date time.Time) string {
return strings.Join([]string{
"AWS4-HMAC-SHA256-PAYLOAD",
formatTime(date),
scope,
hex.EncodeToString(prevSig),
hex.EncodeToString(hashSHA256(headers)),
hex.EncodeToString(hashSHA256(payload)),
}, "\n")
}

View File

@@ -76,9 +76,14 @@ import (
) )
const ( const (
authorizationHeader = "Authorization"
authHeaderSignatureElem = "Signature="
signatureQueryKey = "X-Amz-Signature"
authHeaderPrefix = "AWS4-HMAC-SHA256" authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z" timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102" shortTimeFormat = "20060102"
awsV4Request = "aws4_request"
// emptyStringSHA256 is a SHA256 of an empty string // emptyStringSHA256 is a SHA256 of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855` emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
@@ -87,7 +92,7 @@ const (
var ignoredHeaders = rules{ var ignoredHeaders = rules{
blacklist{ blacklist{
mapRule{ mapRule{
"Authorization": struct{}{}, authorizationHeader: struct{}{},
"User-Agent": struct{}{}, "User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{}, "X-Amzn-Trace-Id": struct{}{},
}, },
@@ -231,8 +236,6 @@ type signingCtx struct {
credValues credentials.Value credValues credentials.Value
isPresign bool isPresign bool
formattedTime string
formattedShortTime string
unsignedPayload bool unsignedPayload bool
bodyDigest string bodyDigest string
@@ -532,39 +535,56 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) error {
ctx.buildSignature() // depends on string to sign ctx.buildSignature() // depends on string to sign
if ctx.isPresign { if ctx.isPresign {
ctx.Request.URL.RawQuery += "&X-Amz-Signature=" + ctx.signature ctx.Request.URL.RawQuery += "&" + signatureQueryKey + "=" + ctx.signature
} else { } else {
parts := []string{ parts := []string{
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString, authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
"SignedHeaders=" + ctx.signedHeaders, "SignedHeaders=" + ctx.signedHeaders,
"Signature=" + ctx.signature, authHeaderSignatureElem + ctx.signature,
} }
ctx.Request.Header.Set("Authorization", strings.Join(parts, ", ")) ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
} }
return nil return nil
} }
func (ctx *signingCtx) buildTime() { // GetSignedRequestSignature attempts to extract the signature of the request.
ctx.formattedTime = ctx.Time.UTC().Format(timeFormat) // Returning an error if the request is unsigned, or unable to extract the
ctx.formattedShortTime = ctx.Time.UTC().Format(shortTimeFormat) // signature.
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
ps := strings.Split(auth, ", ")
for _, p := range ps {
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
sig := p[len(authHeaderSignatureElem):]
if len(sig) == 0 {
return nil, fmt.Errorf("invalid request signature authorization header")
}
return hex.DecodeString(sig)
}
}
}
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
return hex.DecodeString(sig)
}
return nil, fmt.Errorf("request not signed")
}
func (ctx *signingCtx) buildTime() {
if ctx.isPresign { if ctx.isPresign {
duration := int64(ctx.ExpireTime / time.Second) duration := int64(ctx.ExpireTime / time.Second)
ctx.Query.Set("X-Amz-Date", ctx.formattedTime) ctx.Query.Set("X-Amz-Date", formatTime(ctx.Time))
ctx.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10)) ctx.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else { } else {
ctx.Request.Header.Set("X-Amz-Date", ctx.formattedTime) ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
} }
} }
func (ctx *signingCtx) buildCredentialString() { func (ctx *signingCtx) buildCredentialString() {
ctx.credentialString = strings.Join([]string{ ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
ctx.formattedShortTime,
ctx.Region,
ctx.ServiceName,
"aws4_request",
}, "/")
if ctx.isPresign { if ctx.isPresign {
ctx.Query.Set("X-Amz-Credential", ctx.credValues.AccessKeyID+"/"+ctx.credentialString) ctx.Query.Set("X-Amz-Credential", ctx.credValues.AccessKeyID+"/"+ctx.credentialString)
@@ -588,8 +608,7 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
var headers []string var headers []string
headers = append(headers, "host") headers = append(headers, "host")
for k, v := range header { for k, v := range header {
canonicalKey := http.CanonicalHeaderKey(k) if !r.IsValid(k) {
if !r.IsValid(canonicalKey) {
continue // ignored header continue // ignored header
} }
if ctx.SignedHeaderVals == nil { if ctx.SignedHeaderVals == nil {
@@ -653,19 +672,15 @@ func (ctx *signingCtx) buildCanonicalString() {
func (ctx *signingCtx) buildStringToSign() { func (ctx *signingCtx) buildStringToSign() {
ctx.stringToSign = strings.Join([]string{ ctx.stringToSign = strings.Join([]string{
authHeaderPrefix, authHeaderPrefix,
ctx.formattedTime, formatTime(ctx.Time),
ctx.credentialString, ctx.credentialString,
hex.EncodeToString(makeSha256([]byte(ctx.canonicalString))), hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
}, "\n") }, "\n")
} }
func (ctx *signingCtx) buildSignature() { func (ctx *signingCtx) buildSignature() {
secret := ctx.credValues.SecretAccessKey creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
date := makeHmac([]byte("AWS4"+secret), []byte(ctx.formattedShortTime)) signature := hmacSHA256(creds, []byte(ctx.stringToSign))
region := makeHmac(date, []byte(ctx.Region))
service := makeHmac(region, []byte(ctx.ServiceName))
credentials := makeHmac(service, []byte("aws4_request"))
signature := makeHmac(credentials, []byte(ctx.stringToSign))
ctx.signature = hex.EncodeToString(signature) ctx.signature = hex.EncodeToString(signature)
} }
@@ -687,7 +702,11 @@ func (ctx *signingCtx) buildBodyDigest() error {
if !aws.IsReaderSeekable(ctx.Body) { if !aws.IsReaderSeekable(ctx.Body) {
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body) return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
} }
hash = hex.EncodeToString(makeSha256Reader(ctx.Body)) hashBytes, err := makeSha256Reader(ctx.Body)
if err != nil {
return err
}
hash = hex.EncodeToString(hashBytes)
} }
if includeSHA256Header { if includeSHA256Header {
@@ -722,22 +741,28 @@ func (ctx *signingCtx) removePresign() {
ctx.Query.Del("X-Amz-SignedHeaders") ctx.Query.Del("X-Amz-SignedHeaders")
} }
func makeHmac(key []byte, data []byte) []byte { func hmacSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key) hash := hmac.New(sha256.New, key)
hash.Write(data) hash.Write(data)
return hash.Sum(nil) return hash.Sum(nil)
} }
func makeSha256(data []byte) []byte { func hashSHA256(data []byte) []byte {
hash := sha256.New() hash := sha256.New()
hash.Write(data) hash.Write(data)
return hash.Sum(nil) return hash.Sum(nil)
} }
func makeSha256Reader(reader io.ReadSeeker) []byte { func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
hash := sha256.New() hash := sha256.New()
start, _ := reader.Seek(0, sdkio.SeekCurrent) start, err := reader.Seek(0, sdkio.SeekCurrent)
defer reader.Seek(start, sdkio.SeekStart) if err != nil {
return nil, err
}
defer func() {
// ensure error is return if unable to seek back to start of payload.
_, err = reader.Seek(start, sdkio.SeekStart)
}()
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies // Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size. // smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
@@ -748,7 +773,7 @@ func makeSha256Reader(reader io.ReadSeeker) []byte {
io.CopyN(hash, reader, size) io.CopyN(hash, reader, size)
} }
return hash.Sum(nil) return hash.Sum(nil), nil
} }
const doubleSpace = " " const doubleSpace = " "
@@ -794,3 +819,28 @@ func stripExcessSpaces(vals []string) {
vals[i] = string(buf[:m]) vals[i] = string(buf[:m])
} }
} }
func buildSigningScope(region, service string, dt time.Time) string {
return strings.Join([]string{
formatShortTime(dt),
region,
service,
awsV4Request,
}, "/")
}
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
kDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(service))
signingKey := hmacSHA256(kService, []byte(awsV4Request))
return signingKey
}
func formatShortTime(dt time.Time) string {
return dt.UTC().Format(shortTimeFormat)
}
func formatTime(dt time.Time) string {
return dt.UTC().Format(timeFormat)
}

View File

@@ -2,18 +2,24 @@ package aws
import ( import (
"io" "io"
"strings"
"sync" "sync"
"github.com/aws/aws-sdk-go/internal/sdkio" "github.com/aws/aws-sdk-go/internal/sdkio"
) )
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should // ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the
// only be used with an io.Reader that is also an io.Seeker. Doing so may // SDK to accept an io.Reader that is not also an io.Seeker for unsigned
// cause request signature errors, or request body's not sent for GET, HEAD // streaming payload API operations.
// and DELETE HTTP methods.
// //
// Deprecated: Should only be used with io.ReadSeeker. If using for // A ReadSeekCloser wrapping an nonseekable io.Reader used in an API
// S3 PutObject to stream content use s3manager.Uploader instead. // operation's input will prevent that operation being retried in the case of
// network errors, and cause operation requests to fail if the operation
// requires payload signing.
//
// Note: If using With S3 PutObject to stream an object upload The SDK's S3
// Upload manager (s3manager.Uploader) provides support for streaming with the
// ability to retry network errors.
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser { func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r} return ReaderSeekerCloser{r}
} }
@@ -43,7 +49,8 @@ func IsReaderSeekable(r io.Reader) bool {
// Read reads from the reader up to size of p. The number of bytes read, and // Read reads from the reader up to size of p. The number of bytes read, and
// error if it occurred will be returned. // error if it occurred will be returned.
// //
// If the reader is not an io.Reader zero bytes read, and nil error will be returned. // If the reader is not an io.Reader zero bytes read, and nil error will be
// returned.
// //
// Performs the same functionality as io.Reader Read // Performs the same functionality as io.Reader Read
func (r ReaderSeekerCloser) Read(p []byte) (int, error) { func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
@@ -199,3 +206,36 @@ func (b *WriteAtBuffer) Bytes() []byte {
defer b.m.Unlock() defer b.m.Unlock()
return b.buf return b.buf
} }
// MultiCloser is a utility to close multiple io.Closers within a single
// statement.
type MultiCloser []io.Closer
// Close closes all of the io.Closers making up the MultiClosers. Any
// errors that occur while closing will be returned in the order they
// occur.
func (m MultiCloser) Close() error {
var errs errors
for _, c := range m {
err := c.Close()
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errs
}
return nil
}
type errors []error
func (es errors) Error() string {
var parts []string
for _, e := range es {
parts = append(parts, e.Error())
}
return strings.Join(parts, "\n")
}

View File

@@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go" const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK // SDKVersion is the version of this SDK
const SDKVersion = "1.16.26" const SDKVersion = "1.28.2"

View File

@@ -162,7 +162,7 @@ loop:
if len(tokens) == 0 { if len(tokens) == 0 {
break loop break loop
} }
// if should skip is true, we skip the tokens until should skip is set to false.
step = SkipTokenState step = SkipTokenState
} }
@@ -218,7 +218,7 @@ loop:
// S -> equal_expr' expr_stmt' // S -> equal_expr' expr_stmt'
switch k.Kind { switch k.Kind {
case ASTKindEqualExpr: case ASTKindEqualExpr:
// assiging a value to some key // assigning a value to some key
k.AppendChild(newExpression(tok)) k.AppendChild(newExpression(tok))
stack.Push(newExprStatement(k)) stack.Push(newExprStatement(k))
case ASTKindExpr: case ASTKindExpr:
@@ -250,6 +250,13 @@ loop:
if !runeCompare(tok.Raw(), openBrace) { if !runeCompare(tok.Raw(), openBrace) {
return nil, NewParseError("expected '['") return nil, NewParseError("expected '['")
} }
// If OpenScopeState is not at the start, we must mark the previous ast as complete
//
// for example: if previous ast was a skip statement;
// we should mark it as complete before we create a new statement
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
stmt := newStatement() stmt := newStatement()
stack.Push(stmt) stack.Push(stmt)
@@ -304,7 +311,9 @@ loop:
stmt := newCommentStatement(tok) stmt := newCommentStatement(tok)
stack.Push(stmt) stack.Push(stmt)
default: default:
return nil, NewParseError(fmt.Sprintf("invalid state with ASTKind %v and TokenType %v", k, tok)) return nil, NewParseError(
fmt.Sprintf("invalid state with ASTKind %v and TokenType %v",
k, tok.Type()))
} }
if len(tokens) > 0 { if len(tokens) > 0 {
@@ -314,7 +323,7 @@ loop:
// this occurs when a statement has not been completed // this occurs when a statement has not been completed
if stack.top > 1 { if stack.top > 1 {
return nil, NewParseError(fmt.Sprintf("incomplete expression: %v", stack.container)) return nil, NewParseError(fmt.Sprintf("incomplete ini expression"))
} }
// returns a sublist which excludes the start symbol // returns a sublist which excludes the start symbol

View File

@@ -22,24 +22,24 @@ func newSkipper() skipper {
} }
func (s *skipper) ShouldSkip(tok Token) bool { func (s *skipper) ShouldSkip(tok Token) bool {
// should skip state will be modified only if previous token was new line (NL);
// and the current token is not WhiteSpace (WS).
if s.shouldSkip && if s.shouldSkip &&
s.prevTok.Type() == TokenNL && s.prevTok.Type() == TokenNL &&
tok.Type() != TokenWS { tok.Type() != TokenWS {
s.Continue() s.Continue()
return false return false
} }
s.prevTok = tok s.prevTok = tok
return s.shouldSkip return s.shouldSkip
} }
func (s *skipper) Skip() { func (s *skipper) Skip() {
s.shouldSkip = true s.shouldSkip = true
s.prevTok = emptyToken
} }
func (s *skipper) Continue() { func (s *skipper) Continue() {
s.shouldSkip = false s.shouldSkip = false
// empty token is assigned as we return to default state, when should skip is false
s.prevTok = emptyToken s.prevTok = emptyToken
} }

View File

@@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library( go_library(
name = "go_default_library", name = "go_default_library",
srcs = [ srcs = [
"byte.go",
"io_go1.6.go", "io_go1.6.go",
"io_go1.7.go", "io_go1.7.go",
], ],

View File

@@ -0,0 +1,12 @@
package sdkio
const (
// Byte is 8 bits
Byte int64 = 1
// KibiByte (KiB) is 1024 Bytes
KibiByte = Byte * 1024
// MebiByte (MiB) is 1024 KiB
MebiByte = KibiByte * 1024
// GibiByte (GiB) is 1024 MiB
GibiByte = MebiByte * 1024
)

View File

@@ -0,0 +1,26 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"floor.go",
"floor_go1.9.go",
],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/internal/sdkmath",
importpath = "github.com/aws/aws-sdk-go/internal/sdkmath",
visibility = ["//vendor/github.com/aws/aws-sdk-go:__subpackages__"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,15 @@
// +build go1.10
package sdkmath
import "math"
// Round returns the nearest integer, rounding half away from zero.
//
// Special cases are:
// Round(±0) = ±0
// Round(±Inf) = ±Inf
// Round(NaN) = NaN
func Round(x float64) float64 {
return math.Round(x)
}

View File

@@ -0,0 +1,56 @@
// +build !go1.10
package sdkmath
import "math"
// Copied from the Go standard library's (Go 1.12) math/floor.go for use in
// Go version prior to Go 1.10.
const (
uvone = 0x3FF0000000000000
mask = 0x7FF
shift = 64 - 11 - 1
bias = 1023
signMask = 1 << 63
fracMask = 1<<shift - 1
)
// Round returns the nearest integer, rounding half away from zero.
//
// Special cases are:
// Round(±0) = ±0
// Round(±Inf) = ±Inf
// Round(NaN) = NaN
//
// Copied from the Go standard library's (Go 1.12) math/floor.go for use in
// Go version prior to Go 1.10.
func Round(x float64) float64 {
// Round is a faster implementation of:
//
// func Round(x float64) float64 {
// t := Trunc(x)
// if Abs(x-t) >= 0.5 {
// return t + Copysign(1, x)
// }
// return t
// }
bits := math.Float64bits(x)
e := uint(bits>>shift) & mask
if e < bias {
// Round abs(x) < 1 including denormals.
bits &= signMask // +-0
if e == bias-1 {
bits |= uvone // +-1
}
} else if e < bias+shift {
// Round any abs(x) >= 1 containing a fractional component [0,1).
//
// Numbers with larger exponents are returned unchanged since they
// must be either an integer, infinity, or NaN.
const half = 1 << (shift - 1)
e -= bias
bits += half >> e
bits &^= fracMask >> e
}
return math.Float64frombits(bits)
}

View File

@@ -2,7 +2,11 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library( go_library(
name = "go_default_library", name = "go_default_library",
srcs = ["locked_source.go"], srcs = [
"locked_source.go",
"read.go",
"read_1_5.go",
],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/internal/sdkrand", importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/internal/sdkrand",
importpath = "github.com/aws/aws-sdk-go/internal/sdkrand", importpath = "github.com/aws/aws-sdk-go/internal/sdkrand",
visibility = ["//vendor/github.com/aws/aws-sdk-go:__subpackages__"], visibility = ["//vendor/github.com/aws/aws-sdk-go:__subpackages__"],

View File

@@ -0,0 +1,11 @@
// +build go1.6
package sdkrand
import "math/rand"
// Read provides the stub for math.Rand.Read method support for go version's
// 1.6 and greater.
func Read(r *rand.Rand, p []byte) (int, error) {
return r.Read(p)
}

View File

@@ -0,0 +1,24 @@
// +build !go1.6
package sdkrand
import "math/rand"
// Read backfills Go 1.6's math.Rand.Reader for Go 1.5
func Read(r *rand.Rand, p []byte) (n int, err error) {
// Copy of Go standard libraries math package's read function not added to
// standard library until Go 1.6.
var pos int8
var val int64
for n = 0; n < len(p); n++ {
if pos == 0 {
val = r.Int63()
pos = 7
}
p[n] = byte(val)
val >>= 8
pos--
}
return n, err
}

View File

@@ -0,0 +1,23 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = ["strings.go"],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/internal/strings",
importpath = "github.com/aws/aws-sdk-go/internal/strings",
visibility = ["//vendor/github.com/aws/aws-sdk-go:__subpackages__"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,11 @@
package strings
import (
"strings"
)
// HasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
// under Unicode case-folding.
func HasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
}

View File

@@ -8,16 +8,20 @@ go_library(
"idempotency.go", "idempotency.go",
"jsonvalue.go", "jsonvalue.go",
"payload.go", "payload.go",
"protocol.go",
"timestamp.go", "timestamp.go",
"unmarshal.go", "unmarshal.go",
"unmarshal_error.go",
], ],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/private/protocol", importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/private/protocol",
importpath = "github.com/aws/aws-sdk-go/private/protocol", importpath = "github.com/aws/aws-sdk-go/private/protocol",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//vendor/github.com/aws/aws-sdk-go/aws:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/client/metadata:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/client/metadata:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/internal/sdkmath:go_default_library",
], ],
) )

View File

@@ -21,7 +21,8 @@ func Build(r *request.Request) {
"Version": {r.ClientInfo.APIVersion}, "Version": {r.ClientInfo.APIVersion},
} }
if err := queryutil.Parse(body, r.Params, true); err != nil { if err := queryutil.Parse(body, r.Params, true); err != nil {
r.Error = awserr.New("SerializationError", "failed encoding EC2 Query request", err) r.Error = awserr.New(request.ErrCodeSerialization,
"failed encoding EC2 Query request", err)
} }
if !r.IsPresigned() { if !r.IsPresigned() {

View File

@@ -4,7 +4,6 @@ package ec2query
import ( import (
"encoding/xml" "encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
@@ -28,7 +27,8 @@ func Unmarshal(r *request.Request) {
err := xmlutil.UnmarshalXML(r.Data, decoder, "") err := xmlutil.UnmarshalXML(r.Data, decoder, "")
if err != nil { if err != nil {
r.Error = awserr.NewRequestFailure( r.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", "failed decoding EC2 Query response", err), awserr.New(request.ErrCodeSerialization,
"failed decoding EC2 Query response", err),
r.HTTPResponse.StatusCode, r.HTTPResponse.StatusCode,
r.RequestID, r.RequestID,
) )
@@ -39,7 +39,11 @@ func Unmarshal(r *request.Request) {
// UnmarshalMeta unmarshals response headers for the EC2 protocol. // UnmarshalMeta unmarshals response headers for the EC2 protocol.
func UnmarshalMeta(r *request.Request) { func UnmarshalMeta(r *request.Request) {
// TODO implement unmarshaling of request IDs r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
if r.RequestID == "" {
// Alternative version of request id in the header
r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
}
} }
type xmlErrorResponse struct { type xmlErrorResponse struct {
@@ -53,19 +57,21 @@ type xmlErrorResponse struct {
func UnmarshalError(r *request.Request) { func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{} var respErr xmlErrorResponse
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp) err := xmlutil.UnmarshalXMLError(&respErr, r.HTTPResponse.Body)
if err != nil && err != io.EOF { if err != nil {
r.Error = awserr.NewRequestFailure( r.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", "failed decoding EC2 Query error response", err), awserr.New(request.ErrCodeSerialization,
"failed to unmarshal error message", err),
r.HTTPResponse.StatusCode, r.HTTPResponse.StatusCode,
r.RequestID, r.RequestID,
) )
} else { return
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)
} }
r.Error = awserr.NewRequestFailure(
awserr.New(respErr.Code, respErr.Message, nil),
r.HTTPResponse.StatusCode,
respErr.RequestID,
)
} }

View File

@@ -11,6 +11,7 @@ go_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//vendor/github.com/aws/aws-sdk-go/aws:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/private/protocol:go_default_library", "//vendor/github.com/aws/aws-sdk-go/private/protocol:go_default_library",
], ],
) )

View File

@@ -1,6 +1,7 @@
package jsonutil package jsonutil
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -9,9 +10,30 @@ import (
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/private/protocol" "github.com/aws/aws-sdk-go/private/protocol"
) )
// UnmarshalJSONError unmarshal's the reader's JSON document into the passed in
// type. The value to unmarshal the json document into must be a pointer to the
// type.
func UnmarshalJSONError(v interface{}, stream io.Reader) error {
var errBuf bytes.Buffer
body := io.TeeReader(stream, &errBuf)
err := json.NewDecoder(body).Decode(v)
if err != nil {
msg := "failed decoding error message"
if err == io.EOF {
msg = "error message missing"
err = nil
}
return awserr.NewUnmarshalError(err, msg, errBuf.Bytes())
}
return nil
}
// UnmarshalJSON reads a stream and unmarshals the results in object v. // UnmarshalJSON reads a stream and unmarshals the results in object v.
func UnmarshalJSON(v interface{}, stream io.Reader) error { func UnmarshalJSON(v interface{}, stream io.Reader) error {
var out interface{} var out interface{}

View File

@@ -2,13 +2,17 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library( go_library(
name = "go_default_library", name = "go_default_library",
srcs = ["jsonrpc.go"], srcs = [
"jsonrpc.go",
"unmarshal_error.go",
],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/private/protocol/jsonrpc", importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/private/protocol/jsonrpc",
importpath = "github.com/aws/aws-sdk-go/private/protocol/jsonrpc", importpath = "github.com/aws/aws-sdk-go/private/protocol/jsonrpc",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/private/protocol:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/private/protocol/json/jsonutil:go_default_library", "//vendor/github.com/aws/aws-sdk-go/private/protocol/json/jsonutil:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/private/protocol/rest:go_default_library", "//vendor/github.com/aws/aws-sdk-go/private/protocol/rest:go_default_library",
], ],

View File

@@ -6,10 +6,6 @@ package jsonrpc
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go //go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go
import ( import (
"encoding/json"
"io"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
@@ -18,17 +14,26 @@ import (
var emptyJSON = []byte("{}") var emptyJSON = []byte("{}")
// BuildHandler is a named request handler for building jsonrpc protocol requests // BuildHandler is a named request handler for building jsonrpc protocol
var BuildHandler = request.NamedHandler{Name: "awssdk.jsonrpc.Build", Fn: Build} // requests
var BuildHandler = request.NamedHandler{
Name: "awssdk.jsonrpc.Build",
Fn: Build,
}
// UnmarshalHandler is a named request handler for unmarshaling jsonrpc protocol requests // UnmarshalHandler is a named request handler for unmarshaling jsonrpc
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.jsonrpc.Unmarshal", Fn: Unmarshal} // protocol requests
var UnmarshalHandler = request.NamedHandler{
Name: "awssdk.jsonrpc.Unmarshal",
Fn: Unmarshal,
}
// UnmarshalMetaHandler is a named request handler for unmarshaling jsonrpc protocol request metadata // UnmarshalMetaHandler is a named request handler for unmarshaling jsonrpc
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.jsonrpc.UnmarshalMeta", Fn: UnmarshalMeta} // protocol request metadata
var UnmarshalMetaHandler = request.NamedHandler{
// UnmarshalErrorHandler is a named request handler for unmarshaling jsonrpc protocol request errors Name: "awssdk.jsonrpc.UnmarshalMeta",
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.jsonrpc.UnmarshalError", Fn: UnmarshalError} Fn: UnmarshalMeta,
}
// Build builds a JSON payload for a JSON RPC request. // Build builds a JSON payload for a JSON RPC request.
func Build(req *request.Request) { func Build(req *request.Request) {
@@ -37,7 +42,7 @@ func Build(req *request.Request) {
if req.ParamsFilled() { if req.ParamsFilled() {
buf, err = jsonutil.BuildJSON(req.Params) buf, err = jsonutil.BuildJSON(req.Params)
if err != nil { if err != nil {
req.Error = awserr.New("SerializationError", "failed encoding JSON RPC request", err) req.Error = awserr.New(request.ErrCodeSerialization, "failed encoding JSON RPC request", err)
return return
} }
} else { } else {
@@ -52,9 +57,12 @@ func Build(req *request.Request) {
target := req.ClientInfo.TargetPrefix + "." + req.Operation.Name target := req.ClientInfo.TargetPrefix + "." + req.Operation.Name
req.HTTPRequest.Header.Add("X-Amz-Target", target) req.HTTPRequest.Header.Add("X-Amz-Target", target)
} }
if req.ClientInfo.JSONVersion != "" {
// Only set the content type if one is not already specified and an
// JSONVersion is specified.
if ct, v := req.HTTPRequest.Header.Get("Content-Type"), req.ClientInfo.JSONVersion; len(ct) == 0 && len(v) != 0 {
jsonVersion := req.ClientInfo.JSONVersion jsonVersion := req.ClientInfo.JSONVersion
req.HTTPRequest.Header.Add("Content-Type", "application/x-amz-json-"+jsonVersion) req.HTTPRequest.Header.Set("Content-Type", "application/x-amz-json-"+jsonVersion)
} }
} }
@@ -65,7 +73,7 @@ func Unmarshal(req *request.Request) {
err := jsonutil.UnmarshalJSON(req.Data, req.HTTPResponse.Body) err := jsonutil.UnmarshalJSON(req.Data, req.HTTPResponse.Body)
if err != nil { if err != nil {
req.Error = awserr.NewRequestFailure( req.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", "failed decoding JSON RPC response", err), awserr.New(request.ErrCodeSerialization, "failed decoding JSON RPC response", err),
req.HTTPResponse.StatusCode, req.HTTPResponse.StatusCode,
req.RequestID, req.RequestID,
) )
@@ -78,38 +86,3 @@ func Unmarshal(req *request.Request) {
func UnmarshalMeta(req *request.Request) { func UnmarshalMeta(req *request.Request) {
rest.UnmarshalMeta(req) rest.UnmarshalMeta(req)
} }
// UnmarshalError unmarshals an error response for a JSON RPC service.
func UnmarshalError(req *request.Request) {
defer req.HTTPResponse.Body.Close()
var jsonErr jsonErrorResponse
err := json.NewDecoder(req.HTTPResponse.Body).Decode(&jsonErr)
if err == io.EOF {
req.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", req.HTTPResponse.Status, nil),
req.HTTPResponse.StatusCode,
req.RequestID,
)
return
} else if err != nil {
req.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", "failed decoding JSON RPC error response", err),
req.HTTPResponse.StatusCode,
req.RequestID,
)
return
}
codes := strings.SplitN(jsonErr.Code, "#", 2)
req.Error = awserr.NewRequestFailure(
awserr.New(codes[len(codes)-1], jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
req.RequestID,
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}

View File

@@ -0,0 +1,107 @@
package jsonrpc
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
)
// UnmarshalTypedError provides unmarshaling errors API response errors
// for both typed and untyped errors.
type UnmarshalTypedError struct {
exceptions map[string]func(protocol.ResponseMetadata) error
}
// NewUnmarshalTypedError returns an UnmarshalTypedError initialized for the
// set of exception names to the error unmarshalers
func NewUnmarshalTypedError(exceptions map[string]func(protocol.ResponseMetadata) error) *UnmarshalTypedError {
return &UnmarshalTypedError{
exceptions: exceptions,
}
}
// UnmarshalError attempts to unmarshal the HTTP response error as a known
// error type. If unable to unmarshal the error type, the generic SDK error
// type will be used.
func (u *UnmarshalTypedError) UnmarshalError(
resp *http.Response,
respMeta protocol.ResponseMetadata,
) (error, error) {
var buf bytes.Buffer
var jsonErr jsonErrorResponse
teeReader := io.TeeReader(resp.Body, &buf)
err := jsonutil.UnmarshalJSONError(&jsonErr, teeReader)
if err != nil {
return nil, err
}
body := ioutil.NopCloser(&buf)
// Code may be separated by hash(#), with the last element being the code
// used by the SDK.
codeParts := strings.SplitN(jsonErr.Code, "#", 2)
code := codeParts[len(codeParts)-1]
msg := jsonErr.Message
if fn, ok := u.exceptions[code]; ok {
// If exception code is know, use associated constructor to get a value
// for the exception that the JSON body can be unmarshaled into.
v := fn(respMeta)
err := jsonutil.UnmarshalJSON(v, body)
if err != nil {
return nil, err
}
return v, nil
}
// fallback to unmodeled generic exceptions
return awserr.NewRequestFailure(
awserr.New(code, msg, nil),
respMeta.StatusCode,
respMeta.RequestID,
), nil
}
// UnmarshalErrorHandler is a named request handler for unmarshaling jsonrpc
// protocol request errors
var UnmarshalErrorHandler = request.NamedHandler{
Name: "awssdk.jsonrpc.UnmarshalError",
Fn: UnmarshalError,
}
// UnmarshalError unmarshals an error response for a JSON RPC service.
func UnmarshalError(req *request.Request) {
defer req.HTTPResponse.Body.Close()
var jsonErr jsonErrorResponse
err := jsonutil.UnmarshalJSONError(&jsonErr, req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to unmarshal error message", err),
req.HTTPResponse.StatusCode,
req.RequestID,
)
return
}
codes := strings.SplitN(jsonErr.Code, "#", 2)
req.Error = awserr.NewRequestFailure(
awserr.New(codes[len(codes)-1], jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
req.RequestID,
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}

View File

@@ -64,7 +64,7 @@ func (h HandlerPayloadMarshal) MarshalPayload(w io.Writer, v interface{}) error
metadata.ClientInfo{}, metadata.ClientInfo{},
request.Handlers{}, request.Handlers{},
nil, nil,
&request.Operation{HTTPMethod: "GET"}, &request.Operation{HTTPMethod: "PUT"},
v, v,
nil, nil,
) )

View File

@@ -0,0 +1,49 @@
package protocol
import (
"fmt"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// RequireHTTPMinProtocol request handler is used to enforce that
// the target endpoint supports the given major and minor HTTP protocol version.
type RequireHTTPMinProtocol struct {
Major, Minor int
}
// Handler will mark the request.Request with an error if the
// target endpoint did not connect with the required HTTP protocol
// major and minor version.
func (p RequireHTTPMinProtocol) Handler(r *request.Request) {
if r.Error != nil || r.HTTPResponse == nil {
return
}
if !strings.HasPrefix(r.HTTPResponse.Proto, "HTTP") {
r.Error = newMinHTTPProtoError(p.Major, p.Minor, r)
}
if r.HTTPResponse.ProtoMajor < p.Major || r.HTTPResponse.ProtoMinor < p.Minor {
r.Error = newMinHTTPProtoError(p.Major, p.Minor, r)
}
}
// ErrCodeMinimumHTTPProtocolError error code is returned when the target endpoint
// did not match the required HTTP major and minor protocol version.
const ErrCodeMinimumHTTPProtocolError = "MinimumHTTPProtocolError"
func newMinHTTPProtoError(major, minor int, r *request.Request) error {
return awserr.NewRequestFailure(
awserr.New("MinimumHTTPProtocolError",
fmt.Sprintf(
"operation requires minimum HTTP protocol of HTTP/%d.%d, but was %s",
major, minor, r.HTTPResponse.Proto,
),
nil,
),
r.HTTPResponse.StatusCode, r.RequestID,
)
}

View File

@@ -21,7 +21,7 @@ func Build(r *request.Request) {
"Version": {r.ClientInfo.APIVersion}, "Version": {r.ClientInfo.APIVersion},
} }
if err := queryutil.Parse(body, r.Params, false); err != nil { if err := queryutil.Parse(body, r.Params, false); err != nil {
r.Error = awserr.New("SerializationError", "failed encoding Query request", err) r.Error = awserr.New(request.ErrCodeSerialization, "failed encoding Query request", err)
return return
} }

View File

@@ -24,7 +24,7 @@ func Unmarshal(r *request.Request) {
err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result") err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result")
if err != nil { if err != nil {
r.Error = awserr.NewRequestFailure( r.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", "failed decoding Query response", err), awserr.New(request.ErrCodeSerialization, "failed decoding Query response", err),
r.HTTPResponse.StatusCode, r.HTTPResponse.StatusCode,
r.RequestID, r.RequestID,
) )

View File

@@ -2,73 +2,68 @@ package query
import ( import (
"encoding/xml" "encoding/xml"
"io/ioutil" "fmt"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
) )
// UnmarshalErrorHandler is a name request handler to unmarshal request errors
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.query.UnmarshalError", Fn: UnmarshalError}
type xmlErrorResponse struct { type xmlErrorResponse struct {
XMLName xml.Name `xml:"ErrorResponse"`
Code string `xml:"Error>Code"` Code string `xml:"Error>Code"`
Message string `xml:"Error>Message"` Message string `xml:"Error>Message"`
RequestID string `xml:"RequestId"` RequestID string `xml:"RequestId"`
} }
type xmlServiceUnavailableResponse struct { type xmlResponseError struct {
XMLName xml.Name `xml:"ServiceUnavailableException"` xmlErrorResponse
} }
// UnmarshalErrorHandler is a name request handler to unmarshal request errors func (e *xmlResponseError) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.query.UnmarshalError", Fn: UnmarshalError} const svcUnavailableTagName = "ServiceUnavailableException"
const errorResponseTagName = "ErrorResponse"
switch start.Name.Local {
case svcUnavailableTagName:
e.Code = svcUnavailableTagName
e.Message = "service is unavailable"
return d.Skip()
case errorResponseTagName:
return d.DecodeElement(&e.xmlErrorResponse, &start)
default:
return fmt.Errorf("unknown error response tag, %v", start)
}
}
// UnmarshalError unmarshals an error response for an AWS Query service. // UnmarshalError unmarshals an error response for an AWS Query service.
func UnmarshalError(r *request.Request) { func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
bodyBytes, err := ioutil.ReadAll(r.HTTPResponse.Body) var respErr xmlResponseError
err := xmlutil.UnmarshalXMLError(&respErr, r.HTTPResponse.Body)
if err != nil { if err != nil {
r.Error = awserr.NewRequestFailure( r.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", "failed to read from query HTTP response body", err), awserr.New(request.ErrCodeSerialization,
"failed to unmarshal error message", err),
r.HTTPResponse.StatusCode, r.HTTPResponse.StatusCode,
r.RequestID, r.RequestID,
) )
return return
} }
// First check for specific error reqID := respErr.RequestID
resp := xmlErrorResponse{} if len(reqID) == 0 {
decodeErr := xml.Unmarshal(bodyBytes, &resp)
if decodeErr == nil {
reqID := resp.RequestID
if reqID == "" {
reqID = r.RequestID reqID = r.RequestID
} }
r.Error = awserr.NewRequestFailure( r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil), awserr.New(respErr.Code, respErr.Message, nil),
r.HTTPResponse.StatusCode, r.HTTPResponse.StatusCode,
reqID, reqID,
) )
return
}
// Check for unhandled error
servUnavailResp := xmlServiceUnavailableResponse{}
unavailErr := xml.Unmarshal(bodyBytes, &servUnavailResp)
if unavailErr == nil {
r.Error = awserr.NewRequestFailure(
awserr.New("ServiceUnavailableException", "service is unavailable", nil),
r.HTTPResponse.StatusCode,
r.RequestID,
)
return
}
// Failed to retrieve any error message from the response body
r.Error = awserr.NewRequestFailure(
awserr.New("SerializationError",
"failed to decode query XML error response", decodeErr),
r.HTTPResponse.StatusCode,
r.RequestID,
)
} }

View File

@@ -14,6 +14,7 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/internal/strings:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/private/protocol:go_default_library", "//vendor/github.com/aws/aws-sdk-go/private/protocol:go_default_library",
], ],
) )

View File

@@ -25,6 +25,8 @@ var noEscape [256]bool
var errValueNotSet = fmt.Errorf("value not set") var errValueNotSet = fmt.Errorf("value not set")
var byteSliceType = reflect.TypeOf([]byte{})
func init() { func init() {
for i := 0; i < len(noEscape); i++ { for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped // AWS expects every character except these to be escaped
@@ -94,6 +96,14 @@ func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bo
continue continue
} }
// Support the ability to customize values to be marshaled as a
// blob even though they were modeled as a string. Required for S3
// API operations like SSECustomerKey is modeled as stirng but
// required to be base64 encoded in request.
if field.Tag.Get("marshal-as") == "blob" {
m = m.Convert(byteSliceType)
}
var err error var err error
switch field.Tag.Get("location") { switch field.Tag.Get("location") {
case "headers": // header maps case "headers": // header maps
@@ -137,7 +147,7 @@ func buildBody(r *request.Request, v reflect.Value) {
case string: case string:
r.SetStringBody(reader) r.SetStringBody(reader)
default: default:
r.Error = awserr.New("SerializationError", r.Error = awserr.New(request.ErrCodeSerialization,
"failed to encode REST request", "failed to encode REST request",
fmt.Errorf("unknown payload type %s", payload.Type())) fmt.Errorf("unknown payload type %s", payload.Type()))
} }
@@ -152,9 +162,12 @@ func buildHeader(header *http.Header, v reflect.Value, name string, tag reflect.
if err == errValueNotSet { if err == errValueNotSet {
return nil return nil
} else if err != nil { } else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err) return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
} }
name = strings.TrimSpace(name)
str = strings.TrimSpace(str)
header.Add(name, str) header.Add(name, str)
return nil return nil
@@ -167,11 +180,13 @@ func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag)
if err == errValueNotSet { if err == errValueNotSet {
continue continue
} else if err != nil { } else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err) return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
} }
keyStr := strings.TrimSpace(key.String())
str = strings.TrimSpace(str)
header.Add(prefix+key.String(), str) header.Add(prefix+keyStr, str)
} }
return nil return nil
} }
@@ -181,7 +196,7 @@ func buildURI(u *url.URL, v reflect.Value, name string, tag reflect.StructTag) e
if err == errValueNotSet { if err == errValueNotSet {
return nil return nil
} else if err != nil { } else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err) return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
} }
u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1) u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1)
@@ -214,7 +229,7 @@ func buildQueryString(query url.Values, v reflect.Value, name string, tag reflec
if err == errValueNotSet { if err == errValueNotSet {
return nil return nil
} else if err != nil { } else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err) return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
} }
query.Set(name, str) query.Set(name, str)
} }

View File

@@ -15,6 +15,7 @@ import (
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
awsStrings "github.com/aws/aws-sdk-go/internal/strings"
"github.com/aws/aws-sdk-go/private/protocol" "github.com/aws/aws-sdk-go/private/protocol"
) )
@@ -28,7 +29,9 @@ var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta
func Unmarshal(r *request.Request) { func Unmarshal(r *request.Request) {
if r.DataFilled() { if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data)) v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalBody(r, v) if err := unmarshalBody(r, v); err != nil {
r.Error = err
}
} }
} }
@@ -40,12 +43,21 @@ func UnmarshalMeta(r *request.Request) {
r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id") r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
} }
if r.DataFilled() { if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data)) if err := UnmarshalResponse(r.HTTPResponse, r.Data, aws.BoolValue(r.Config.LowerCaseHeaderMaps)); err != nil {
unmarshalLocationElements(r, v) r.Error = err
}
} }
} }
func unmarshalBody(r *request.Request, v reflect.Value) { // UnmarshalResponse attempts to unmarshal the REST response headers to
// the data type passed in. The type must be a pointer. An error is returned
// with any error unmarshaling the response into the target datatype.
func UnmarshalResponse(resp *http.Response, data interface{}, lowerCaseHeaderMaps bool) error {
v := reflect.Indirect(reflect.ValueOf(data))
return unmarshalLocationElements(resp, v, lowerCaseHeaderMaps)
}
func unmarshalBody(r *request.Request, v reflect.Value) error {
if field, ok := v.Type().FieldByName("_"); ok { if field, ok := v.Type().FieldByName("_"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" { if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName) pfield, _ := v.Type().FieldByName(payloadName)
@@ -57,35 +69,38 @@ func unmarshalBody(r *request.Request, v reflect.Value) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
b, err := ioutil.ReadAll(r.HTTPResponse.Body) b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil { if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err) return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
} else {
payload.Set(reflect.ValueOf(b))
} }
payload.Set(reflect.ValueOf(b))
case *string: case *string:
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
b, err := ioutil.ReadAll(r.HTTPResponse.Body) b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil { if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err) return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
} else { }
str := string(b) str := string(b)
payload.Set(reflect.ValueOf(&str)) payload.Set(reflect.ValueOf(&str))
}
default: default:
switch payload.Type().String() { switch payload.Type().String() {
case "io.ReadCloser": case "io.ReadCloser":
payload.Set(reflect.ValueOf(r.HTTPResponse.Body)) payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
case "io.ReadSeeker": case "io.ReadSeeker":
b, err := ioutil.ReadAll(r.HTTPResponse.Body) b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil { if err != nil {
r.Error = awserr.New("SerializationError", return awserr.New(request.ErrCodeSerialization,
"failed to read response body", err) "failed to read response body", err)
return
} }
payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b)))) payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b))))
default: default:
io.Copy(ioutil.Discard, r.HTTPResponse.Body) io.Copy(ioutil.Discard, r.HTTPResponse.Body)
defer r.HTTPResponse.Body.Close() r.HTTPResponse.Body.Close()
r.Error = awserr.New("SerializationError", return awserr.New(request.ErrCodeSerialization,
"failed to decode REST response", "failed to decode REST response",
fmt.Errorf("unknown payload type %s", payload.Type())) fmt.Errorf("unknown payload type %s", payload.Type()))
} }
@@ -94,9 +109,11 @@ func unmarshalBody(r *request.Request, v reflect.Value) {
} }
} }
} }
return nil
} }
func unmarshalLocationElements(r *request.Request, v reflect.Value) { func unmarshalLocationElements(resp *http.Response, v reflect.Value, lowerCaseHeaderMaps bool) error {
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
m, field := v.Field(i), v.Type().Field(i) m, field := v.Field(i), v.Type().Field(i)
if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) { if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
@@ -111,26 +128,25 @@ func unmarshalLocationElements(r *request.Request, v reflect.Value) {
switch field.Tag.Get("location") { switch field.Tag.Get("location") {
case "statusCode": case "statusCode":
unmarshalStatusCode(m, r.HTTPResponse.StatusCode) unmarshalStatusCode(m, resp.StatusCode)
case "header": case "header":
err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name), field.Tag) err := unmarshalHeader(m, resp.Header.Get(name), field.Tag)
if err != nil { if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err) return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
break
} }
case "headers": case "headers":
prefix := field.Tag.Get("locationName") prefix := field.Tag.Get("locationName")
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix) err := unmarshalHeaderMap(m, resp.Header, prefix, lowerCaseHeaderMaps)
if err != nil { if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err) awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
break
} }
} }
} }
if r.Error != nil {
return
}
} }
return nil
} }
func unmarshalStatusCode(v reflect.Value, statusCode int) { func unmarshalStatusCode(v reflect.Value, statusCode int) {
@@ -145,30 +161,46 @@ func unmarshalStatusCode(v reflect.Value, statusCode int) {
} }
} }
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error { func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error {
if len(headers) == 0 {
return nil
}
switch r.Interface().(type) { switch r.Interface().(type) {
case map[string]*string: // we only support string map value types case map[string]*string: // we only support string map value types
out := map[string]*string{} out := map[string]*string{}
for k, v := range headers { for k, v := range headers {
if awsStrings.HasPrefixFold(k, prefix) {
if normalize == true {
k = strings.ToLower(k)
} else {
k = http.CanonicalHeaderKey(k) k = http.CanonicalHeaderKey(k)
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) { }
out[k[len(prefix):]] = &v[0] out[k[len(prefix):]] = &v[0]
} }
} }
if len(out) != 0 {
r.Set(reflect.ValueOf(out)) r.Set(reflect.ValueOf(out))
} }
}
return nil return nil
} }
func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error { func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error {
isJSONValue := tag.Get("type") == "jsonvalue" switch tag.Get("type") {
if isJSONValue { case "jsonvalue":
if len(header) == 0 { if len(header) == 0 {
return nil return nil
} }
} else if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) { case "blob":
if len(header) == 0 {
return nil return nil
} }
default:
if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
return nil
}
}
switch v.Interface().(type) { switch v.Interface().(type) {
case *string: case *string:
@@ -178,7 +210,7 @@ func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) erro
if err != nil { if err != nil {
return err return err
} }
v.Set(reflect.ValueOf(&b)) v.Set(reflect.ValueOf(b))
case *bool: case *bool:
b, err := strconv.ParseBool(header) b, err := strconv.ParseBool(header)
if err != nil { if err != nil {

View File

@@ -1,8 +1,11 @@
package protocol package protocol
import ( import (
"math"
"strconv" "strconv"
"time" "time"
"github.com/aws/aws-sdk-go/internal/sdkmath"
) )
// Names of time formats supported by the SDK // Names of time formats supported by the SDK
@@ -13,12 +16,19 @@ const (
) )
// Time formats supported by the SDK // Time formats supported by the SDK
// Output time is intended to not contain decimals
const ( const (
// RFC 7231#section-7.1.1.1 timetamp format. e.g Tue, 29 Apr 2014 18:30:38 GMT // RFC 7231#section-7.1.1.1 timetamp format. e.g Tue, 29 Apr 2014 18:30:38 GMT
RFC822TimeFormat = "Mon, 2 Jan 2006 15:04:05 GMT" RFC822TimeFormat = "Mon, 2 Jan 2006 15:04:05 GMT"
// This format is used for output time without seconds precision
RFC822OutputTimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
// RFC3339 a subset of the ISO8601 timestamp format. e.g 2014-04-29T18:30:38Z // RFC3339 a subset of the ISO8601 timestamp format. e.g 2014-04-29T18:30:38Z
ISO8601TimeFormat = "2006-01-02T15:04:05Z" ISO8601TimeFormat = "2006-01-02T15:04:05.999999999Z"
// This format is used for output time without seconds precision
ISO8601OutputTimeFormat = "2006-01-02T15:04:05Z"
) )
// IsKnownTimestampFormat returns if the timestamp format name // IsKnownTimestampFormat returns if the timestamp format name
@@ -42,9 +52,9 @@ func FormatTime(name string, t time.Time) string {
switch name { switch name {
case RFC822TimeFormatName: case RFC822TimeFormatName:
return t.Format(RFC822TimeFormat) return t.Format(RFC822OutputTimeFormat)
case ISO8601TimeFormatName: case ISO8601TimeFormatName:
return t.Format(ISO8601TimeFormat) return t.Format(ISO8601OutputTimeFormat)
case UnixTimeFormatName: case UnixTimeFormatName:
return strconv.FormatInt(t.Unix(), 10) return strconv.FormatInt(t.Unix(), 10)
default: default:
@@ -62,10 +72,12 @@ func ParseTime(formatName, value string) (time.Time, error) {
return time.Parse(ISO8601TimeFormat, value) return time.Parse(ISO8601TimeFormat, value)
case UnixTimeFormatName: case UnixTimeFormatName:
v, err := strconv.ParseFloat(value, 64) v, err := strconv.ParseFloat(value, 64)
_, dec := math.Modf(v)
dec = sdkmath.Round(dec*1e3) / 1e3 //Rounds 0.1229999 to 0.123
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
} }
return time.Unix(int64(v), 0), nil return time.Unix(int64(v), int64(dec*(1e9))), nil
default: default:
panic("unknown timestamp format name, " + formatName) panic("unknown timestamp format name, " + formatName)
} }

View File

@@ -19,3 +19,9 @@ func UnmarshalDiscardBody(r *request.Request) {
io.Copy(ioutil.Discard, r.HTTPResponse.Body) io.Copy(ioutil.Discard, r.HTTPResponse.Body)
r.HTTPResponse.Body.Close() r.HTTPResponse.Body.Close()
} }
// ResponseMetadata provides the SDK response metadata attributes.
type ResponseMetadata struct {
StatusCode int
RequestID string
}

View File

@@ -0,0 +1,65 @@
package protocol
import (
"net/http"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// UnmarshalErrorHandler provides unmarshaling errors API response errors for
// both typed and untyped errors.
type UnmarshalErrorHandler struct {
unmarshaler ErrorUnmarshaler
}
// ErrorUnmarshaler is an abstract interface for concrete implementations to
// unmarshal protocol specific response errors.
type ErrorUnmarshaler interface {
UnmarshalError(*http.Response, ResponseMetadata) (error, error)
}
// NewUnmarshalErrorHandler returns an UnmarshalErrorHandler
// initialized for the set of exception names to the error unmarshalers
func NewUnmarshalErrorHandler(unmarshaler ErrorUnmarshaler) *UnmarshalErrorHandler {
return &UnmarshalErrorHandler{
unmarshaler: unmarshaler,
}
}
// UnmarshalErrorHandlerName is the name of the named handler.
const UnmarshalErrorHandlerName = "awssdk.protocol.UnmarshalError"
// NamedHandler returns a NamedHandler for the unmarshaler using the set of
// errors the unmarshaler was initialized for.
func (u *UnmarshalErrorHandler) NamedHandler() request.NamedHandler {
return request.NamedHandler{
Name: UnmarshalErrorHandlerName,
Fn: u.UnmarshalError,
}
}
// UnmarshalError will attempt to unmarshal the API response's error message
// into either a generic SDK error type, or a typed error corresponding to the
// errors exception name.
func (u *UnmarshalErrorHandler) UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
respMeta := ResponseMetadata{
StatusCode: r.HTTPResponse.StatusCode,
RequestID: r.RequestID,
}
v, err := u.unmarshaler.UnmarshalError(r.HTTPResponse, respMeta)
if err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to unmarshal response error", err),
respMeta.StatusCode,
respMeta.RequestID,
)
return
}
r.Error = v
}

View File

@@ -4,13 +4,17 @@ go_library(
name = "go_default_library", name = "go_default_library",
srcs = [ srcs = [
"build.go", "build.go",
"sort.go",
"unmarshal.go", "unmarshal.go",
"xml_to_struct.go", "xml_to_struct.go",
], ],
importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil", importmap = "k8s.io/kubernetes/vendor/github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
importpath = "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil", importpath = "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = ["//vendor/github.com/aws/aws-sdk-go/private/protocol:go_default_library"], deps = [
"//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/private/protocol:go_default_library",
],
) )
filegroup( filegroup(

View File

@@ -0,0 +1,32 @@
package xmlutil
import (
"encoding/xml"
"strings"
)
type xmlAttrSlice []xml.Attr
func (x xmlAttrSlice) Len() int {
return len(x)
}
func (x xmlAttrSlice) Less(i, j int) bool {
spaceI, spaceJ := x[i].Name.Space, x[j].Name.Space
localI, localJ := x[i].Name.Local, x[j].Name.Local
valueI, valueJ := x[i].Value, x[j].Value
spaceCmp := strings.Compare(spaceI, spaceJ)
localCmp := strings.Compare(localI, localJ)
valueCmp := strings.Compare(valueI, valueJ)
if spaceCmp == -1 || (spaceCmp == 0 && (localCmp == -1 || (localCmp == 0 && valueCmp == -1))) {
return true
}
return false
}
func (x xmlAttrSlice) Swap(i, j int) {
x[i], x[j] = x[j], x[i]
}

Some files were not shown because too many files have changed in this diff Show More