Update Godeps

This commit is contained in:
Adam Sunderland
2015-05-14 14:33:55 -07:00
parent ff9996c100
commit 6355aa2e51
70 changed files with 28851 additions and 8762 deletions

38
Godeps/Godeps.json generated
View File

@@ -1,6 +1,6 @@
{
"ImportPath": "github.com/GoogleCloudPlatform/kubernetes",
"GoVersion": "go1.4.2",
"GoVersion": "go1.3.3",
"Packages": [
"./..."
],
@@ -48,6 +48,34 @@
"Comment": "v0.5.1-55-g87808a3",
"Rev": "87808a37061a4a2e6204ccea5fd2fc930576db94"
},
{
"ImportPath": "github.com/awslabs/aws-sdk-go/aws",
"Rev": "c0a38f106248742920a2b786dcae81457af003d3"
},
{
"ImportPath": "github.com/awslabs/aws-sdk-go/internal/endpoints",
"Rev": "c0a38f106248742920a2b786dcae81457af003d3"
},
{
"ImportPath": "github.com/awslabs/aws-sdk-go/internal/protocol/ec2query",
"Rev": "c0a38f106248742920a2b786dcae81457af003d3"
},
{
"ImportPath": "github.com/awslabs/aws-sdk-go/internal/protocol/query/queryutil",
"Rev": "c0a38f106248742920a2b786dcae81457af003d3"
},
{
"ImportPath": "github.com/awslabs/aws-sdk-go/internal/protocol/xml/xmlutil",
"Rev": "c0a38f106248742920a2b786dcae81457af003d3"
},
{
"ImportPath": "github.com/awslabs/aws-sdk-go/internal/signer/v4",
"Rev": "c0a38f106248742920a2b786dcae81457af003d3"
},
{
"ImportPath": "github.com/awslabs/aws-sdk-go/service/ec2",
"Rev": "c0a38f106248742920a2b786dcae81457af003d3"
},
{
"ImportPath": "github.com/beorn7/perks/quantile",
"Rev": "b965b613227fddccbfffe13eae360ed3fa822f8d"
@@ -336,14 +364,6 @@
"ImportPath": "github.com/miekg/dns",
"Rev": "3f504e8dabd5d562e997d19ce0200aa41973e1b2"
},
{
"ImportPath": "github.com/mitchellh/goamz/aws",
"Rev": "703cfb45985762869e465f37ed030ff01615ff1e"
},
{
"ImportPath": "github.com/mitchellh/goamz/ec2",
"Rev": "703cfb45985762869e465f37ed030ff01615ff1e"
},
{
"ImportPath": "github.com/mitchellh/mapstructure",
"Rev": "740c764bc6149d3f1806231418adb9f52c11bcbf"

View File

@@ -0,0 +1,69 @@
package awsutil
import (
"io"
"reflect"
)
// Copy deeply copies a src structure to dst. Useful for copying request and
// response structures.
func Copy(dst, src interface{}) {
rcopy(reflect.ValueOf(dst), reflect.ValueOf(src))
}
// CopyOf returns a copy of src while also allocating the memory for dst.
// src must be a pointer type or this operation will fail.
func CopyOf(src interface{}) (dst interface{}) {
dsti := reflect.New(reflect.TypeOf(src).Elem())
dst = dsti.Interface()
rcopy(dsti, reflect.ValueOf(src))
return
}
// rcopy performs a recursive copy of values from the source to destination.
func rcopy(dst, src reflect.Value) {
if !src.IsValid() {
return
}
switch src.Kind() {
case reflect.Ptr:
if _, ok := src.Interface().(io.Reader); ok {
if dst.Kind() == reflect.Ptr && dst.Elem().CanSet() {
dst.Elem().Set(src)
} else if dst.CanSet() {
dst.Set(src)
}
} else {
e := src.Type().Elem()
if dst.CanSet() {
dst.Set(reflect.New(e))
}
if src.Elem().IsValid() {
rcopy(dst.Elem(), src.Elem())
}
}
case reflect.Struct:
dst.Set(reflect.New(src.Type()).Elem())
for i := 0; i < src.NumField(); i++ {
rcopy(dst.Field(i), src.Field(i))
}
case reflect.Slice:
s := reflect.MakeSlice(src.Type(), src.Len(), src.Cap())
dst.Set(s)
for i := 0; i < src.Len(); i++ {
rcopy(dst.Index(i), src.Index(i))
}
case reflect.Map:
s := reflect.MakeMap(src.Type())
dst.Set(s)
for _, k := range src.MapKeys() {
v := src.MapIndex(k)
v2 := reflect.New(v.Type()).Elem()
rcopy(v2, v)
dst.SetMapIndex(k, v2)
}
default:
dst.Set(src)
}
}

View File

@@ -0,0 +1,130 @@
package awsutil_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"testing"
"github.com/awslabs/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func ExampleCopy() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Print the result
fmt.Println(awsutil.StringValue(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}
func TestCopy(t *testing.T) {
type Foo struct {
A int
B []*string
C map[string]*int
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &Foo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
// But pointers are not!
str3 := "nothello"
int3 := 57
f2.A = 100
f2.B[0] = &str3
f2.C["B"] = &int3
assert.NotEqual(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.C, f1.C)
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
assert.Equal(t, "hello", s)
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
assert.Equal(t, "", s)
}
func TestCopyReader(t *testing.T) {
var buf io.Reader = bytes.NewReader([]byte("hello world"))
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world"), b)
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(""), b)
}
func ExampleCopyOf() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
v := awsutil.CopyOf(f1)
var f2 *Foo = v.(*Foo)
// Print the result
fmt.Println(awsutil.StringValue(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}

View File

@@ -0,0 +1,144 @@
package awsutil
import (
"reflect"
"regexp"
"strconv"
"strings"
)
var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`)
// rValuesAtPath returns a slice of values found in value v. The values
// in v are explored recursively so all nested values are collected.
func rValuesAtPath(v interface{}, path string, create bool) []reflect.Value {
pathparts := strings.Split(path, "||")
if len(pathparts) > 1 {
for _, pathpart := range pathparts {
vals := rValuesAtPath(v, pathpart, create)
if vals != nil && len(vals) > 0 {
return vals
}
}
return nil
}
values := []reflect.Value{reflect.Indirect(reflect.ValueOf(v))}
components := strings.Split(path, ".")
for len(values) > 0 && len(components) > 0 {
var index *int64
var indexStar bool
c := strings.TrimSpace(components[0])
if c == "" { // no actual component, illegal syntax
return nil
} else if c != "*" && strings.ToLower(c[0:1]) == c[0:1] {
// TODO normalize case for user
return nil // don't support unexported fields
}
// parse this component
if m := indexRe.FindStringSubmatch(c); m != nil {
c = m[1]
if m[2] == "" {
index = nil
indexStar = true
} else {
i, _ := strconv.ParseInt(m[2], 10, 32)
index = &i
indexStar = false
}
}
nextvals := []reflect.Value{}
for _, value := range values {
// pull component name out of struct member
if value.Kind() != reflect.Struct {
continue
}
if c == "*" { // pull all members
for i := 0; i < value.NumField(); i++ {
if f := reflect.Indirect(value.Field(i)); f.IsValid() {
nextvals = append(nextvals, f)
}
}
continue
}
value = value.FieldByName(c)
if create && value.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
value = value.Elem()
} else {
value = reflect.Indirect(value)
}
if value.IsValid() {
nextvals = append(nextvals, value)
}
}
values = nextvals
if indexStar || index != nil {
nextvals = []reflect.Value{}
for _, value := range values {
value := reflect.Indirect(value)
if value.Kind() != reflect.Slice {
continue
}
if indexStar { // grab all indices
for i := 0; i < value.Len(); i++ {
idx := reflect.Indirect(value.Index(i))
if idx.IsValid() {
nextvals = append(nextvals, idx)
}
}
continue
}
// pull out index
i := int(*index)
if i >= value.Len() { // check out of bounds
if create {
// TODO resize slice
} else {
continue
}
} else if i < 0 { // support negative indexing
i = value.Len() + i
}
value = reflect.Indirect(value.Index(i))
if value.IsValid() {
nextvals = append(nextvals, value)
}
}
values = nextvals
}
components = components[1:]
}
return values
}
// ValuesAtPath returns a list of objects at the lexical path inside of a structure
func ValuesAtPath(i interface{}, path string) []interface{} {
if rvals := rValuesAtPath(i, path, false); rvals != nil {
vals := make([]interface{}, len(rvals))
for i, rval := range rvals {
vals[i] = rval.Interface()
}
return vals
}
return nil
}
// SetValueAtPath sets an object at the lexical path inside of a structure
func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true); rvals != nil {
for _, rval := range rvals {
rval.Set(reflect.ValueOf(v))
}
}
}

View File

@@ -0,0 +1,60 @@
package awsutil_test
import (
"testing"
"github.com/awslabs/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
type Struct struct {
A []Struct
a []Struct
B *Struct
D *Struct
C string
}
var data = Struct{
A: []Struct{Struct{C: "value1"}, Struct{C: "value2"}, Struct{C: "value3"}},
a: []Struct{Struct{C: "value1"}, Struct{C: "value2"}, Struct{C: "value3"}},
B: &Struct{B: &Struct{C: "terminal"}, D: &Struct{C: "terminal2"}},
C: "initial",
}
func TestValueAtPathSuccess(t *testing.T) {
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "C"))
assert.Equal(t, []interface{}{"value1"}, awsutil.ValuesAtPath(data, "A[0].C"))
assert.Equal(t, []interface{}{"value2"}, awsutil.ValuesAtPath(data, "A[1].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[2].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[-1].C"))
assert.Equal(t, []interface{}{"value1", "value2", "value3"}, awsutil.ValuesAtPath(data, "A[].C"))
assert.Equal(t, []interface{}{"terminal"}, awsutil.ValuesAtPath(data, "B . B . C"))
assert.Equal(t, []interface{}{"terminal", "terminal2"}, awsutil.ValuesAtPath(data, "B.*.C"))
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "A.D.X || C"))
}
func TestValueAtPathFailure(t *testing.T) {
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "C.x"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, ".x"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "X.Y.Z"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[100].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[3].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "B.B.C.Z"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "a[-1].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(nil, "A.B.C"))
}
func TestSetValueAtPathSuccess(t *testing.T) {
var s Struct
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
assert.Equal(t, "test1", s.C)
assert.Equal(t, "test2", s.B.B.C)
assert.Equal(t, "test3", s.B.D.C)
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
assert.Equal(t, "test0", s.B.B.C)
assert.Equal(t, "test0", s.B.D.C)
}

View File

@@ -0,0 +1,103 @@
package awsutil
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
)
// StringValue returns the string representation of a value.
func StringValue(i interface{}) string {
var buf bytes.Buffer
stringValue(reflect.ValueOf(i), 0, &buf)
return buf.String()
}
// stringValue will recursively walk value v to build a textual
// representation of the value.
func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Struct:
strtype := v.Type().String()
if strtype == "time.Time" {
fmt.Fprintf(buf, "%s", v.Interface())
break
} else if strings.HasPrefix(strtype, "io.") {
buf.WriteString("<buffer>")
break
}
buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
f := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
continue // ignore unexported fields
}
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() {
continue // ignore unset fields
}
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
stringValue(val, indent+2, buf)
if i < len(names)-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
case reflect.Slice:
nl, id, id2 := "", "", ""
if v.Len() > 3 {
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
}
buf.WriteString("[" + nl)
for i := 0; i < v.Len(); i++ {
buf.WriteString(id2)
stringValue(v.Index(i), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString("," + nl)
}
}
buf.WriteString(nl + id + "]")
case reflect.Map:
buf.WriteString("{\n")
for i, k := range v.MapKeys() {
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(k.String() + ": ")
stringValue(v.MapIndex(k), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default:
format := "%v"
switch v.Interface().(type) {
case string:
format = "%q"
case io.ReadSeeker, io.Reader:
format = "buffer(%p)"
}
fmt.Fprintf(buf, format, v.Interface())
}
}

View File

@@ -0,0 +1,167 @@
package aws
import (
"io"
"net/http"
"os"
"github.com/awslabs/aws-sdk-go/aws/credentials"
)
// DefaultChainCredentials is a Credentials which will find the first available
// credentials Value from the list of Providers.
//
// This should be used in the default case. Once the type of credentials are
// known switching to the specific Credentials will be more efficient.
var DefaultChainCredentials = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&credentials.EC2RoleProvider{},
})
// The default number of retries for a service. The value of -1 indicates that
// the service specific retry default will be used.
const DefaultRetries = -1
// DefaultConfig is the default all service configuration will be based off of.
var DefaultConfig = &Config{
Credentials: DefaultChainCredentials,
Endpoint: "",
Region: os.Getenv("AWS_REGION"),
DisableSSL: false,
ManualSend: false,
HTTPClient: http.DefaultClient,
LogHTTPBody: false,
LogLevel: 0,
Logger: os.Stdout,
MaxRetries: DefaultRetries,
DisableParamValidation: false,
DisableComputeChecksums: false,
S3ForcePathStyle: false,
}
// A Config provides service configuration
type Config struct {
Credentials *credentials.Credentials
Endpoint string
Region string
DisableSSL bool
ManualSend bool
HTTPClient *http.Client
LogHTTPBody bool
LogLevel uint
Logger io.Writer
MaxRetries int
DisableParamValidation bool
DisableComputeChecksums bool
S3ForcePathStyle bool
}
// Copy will return a shallow copy of the Config object.
func (c Config) Copy() Config {
dst := Config{}
dst.Credentials = c.Credentials
dst.Endpoint = c.Endpoint
dst.Region = c.Region
dst.DisableSSL = c.DisableSSL
dst.ManualSend = c.ManualSend
dst.HTTPClient = c.HTTPClient
dst.LogLevel = c.LogLevel
dst.Logger = c.Logger
dst.MaxRetries = c.MaxRetries
dst.DisableParamValidation = c.DisableParamValidation
dst.DisableComputeChecksums = c.DisableComputeChecksums
dst.S3ForcePathStyle = c.S3ForcePathStyle
return dst
}
// Merge merges the newcfg attribute values into this Config. Each attribute
// will be merged into this config if the newcfg attribute's value is non-zero.
// Due to this, newcfg attributes with zero values cannot be merged in. For
// example bool attributes cannot be cleared using Merge, and must be explicitly
// set on the Config structure.
func (c Config) Merge(newcfg *Config) *Config {
cfg := Config{}
if newcfg != nil && newcfg.Credentials != nil {
cfg.Credentials = newcfg.Credentials
} else {
cfg.Credentials = c.Credentials
}
if newcfg != nil && newcfg.Endpoint != "" {
cfg.Endpoint = newcfg.Endpoint
} else {
cfg.Endpoint = c.Endpoint
}
if newcfg != nil && newcfg.Region != "" {
cfg.Region = newcfg.Region
} else {
cfg.Region = c.Region
}
if newcfg != nil && newcfg.DisableSSL {
cfg.DisableSSL = newcfg.DisableSSL
} else {
cfg.DisableSSL = c.DisableSSL
}
if newcfg != nil && newcfg.ManualSend {
cfg.ManualSend = newcfg.ManualSend
} else {
cfg.ManualSend = c.ManualSend
}
if newcfg != nil && newcfg.HTTPClient != nil {
cfg.HTTPClient = newcfg.HTTPClient
} else {
cfg.HTTPClient = c.HTTPClient
}
if newcfg != nil && newcfg.LogHTTPBody {
cfg.LogHTTPBody = newcfg.LogHTTPBody
} else {
cfg.LogHTTPBody = c.LogHTTPBody
}
if newcfg != nil && newcfg.LogLevel != 0 {
cfg.LogLevel = newcfg.LogLevel
} else {
cfg.LogLevel = c.LogLevel
}
if newcfg != nil && newcfg.Logger != nil {
cfg.Logger = newcfg.Logger
} else {
cfg.Logger = c.Logger
}
if newcfg != nil && newcfg.MaxRetries != DefaultRetries {
cfg.MaxRetries = newcfg.MaxRetries
} else {
cfg.MaxRetries = c.MaxRetries
}
if newcfg != nil && newcfg.DisableParamValidation {
cfg.DisableParamValidation = newcfg.DisableParamValidation
} else {
cfg.DisableParamValidation = c.DisableParamValidation
}
if newcfg != nil && newcfg.DisableComputeChecksums {
cfg.DisableComputeChecksums = newcfg.DisableComputeChecksums
} else {
cfg.DisableComputeChecksums = c.DisableComputeChecksums
}
if newcfg != nil && newcfg.S3ForcePathStyle {
cfg.S3ForcePathStyle = newcfg.S3ForcePathStyle
} else {
cfg.S3ForcePathStyle = c.S3ForcePathStyle
}
return &cfg
}

View File

@@ -0,0 +1,81 @@
package credentials
import (
"fmt"
)
var (
// ErrNoValidProvidersFoundInChain Is returned when there are no valid
// providers in the ChainProvider.
ErrNoValidProvidersFoundInChain = fmt.Errorf("no valid providers in chain")
)
// A ChainProvider will search for a provider which returns credentials
// and cache that provider until Retrieve is called again.
//
// The ChainProvider provides a way of chaining multiple providers together
// which will pick the first available using priority order of the Providers
// in the list.
//
// If none of the Providers retrieve valid credentials Value, ChainProvider's
// Retrieve() will return the error ErrNoValidProvidersFoundInChain.
//
// If a Provider is found which returns valid credentials Value ChainProvider
// will cache that Provider for all calls to IsExpired(), until Retrieve is
// called again.
//
// Example of ChainProvider to be used with an EnvProvider and EC2RoleProvider.
// In this example EnvProvider will first check if any credentials are available
// vai the environment variables. If there are none ChainProvider will check
// the next Provider in the list, EC2RoleProvider in this case. If EC2RoleProvider
// does not return any credentials ChainProvider will return the error
// ErrNoValidProvidersFoundInChain
//
// creds := NewChainCredentials(
// []Provider{
// &EnvProvider{},
// &EC2RoleProvider{},
// })
// creds.Retrieve()
//
type ChainProvider struct {
Providers []Provider
curr Provider
}
// NewChainCredentials returns a pointer to a new Credentials object
// wrapping a chain of providers.
func NewChainCredentials(providers []Provider) *Credentials {
return NewCredentials(&ChainProvider{
Providers: append([]Provider{}, providers...),
})
}
// Retrieve returns the credentials value or error if no provider returned
// without error.
//
// If a provider is found it will be cached and any calls to IsExpired()
// will return the expired state of the cached provider.
func (c *ChainProvider) Retrieve() (Value, error) {
for _, p := range c.Providers {
if creds, err := p.Retrieve(); err == nil {
c.curr = p
return creds, nil
}
}
c.curr = nil
// TODO better error reporting. maybe report error for each failed retrieve?
return Value{}, ErrNoValidProvidersFoundInChain
}
// IsExpired will returned the expired state of the currently cached provider
// if there is one. If there is no current provider, true will be returned.
func (c *ChainProvider) IsExpired() bool {
if c.curr != nil {
return c.curr.IsExpired()
}
return true
}

View File

@@ -0,0 +1,72 @@
package credentials
import (
"errors"
"github.com/stretchr/testify/assert"
"testing"
)
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: errors.New("first provider error")},
&stubProvider{err: errors.New("second provider error")},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderIsExpired(t *testing.T) {
stubProvider := &stubProvider{expired: true}
p := &ChainProvider{
Providers: []Provider{
stubProvider,
},
}
assert.True(t, p.IsExpired(), "Expect expired to be true before any Retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
stubProvider.expired = true
assert.True(t, p.IsExpired(), "Expect return of expired provider")
_, err = p.Retrieve()
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
}
func TestChainProviderWithNoProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: errors.New("first provider error")},
&stubProvider{err: errors.New("second provider error")},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Contains(t, "no valid providers in chain", err.Error(), "Expect no providers error returned")
}

View File

@@ -0,0 +1,178 @@
// Package credentials provides credential retrieval and management
//
// The Credentials is the primary method of getting access to and managing
// credentials Values. Using dependency injection retrieval of the credential
// values is handled by a object which satisfies the Provider interface.
//
// By default the Credentials.Get() will cache the successful result of a
// Provider's Retrieve() until Provider.IsExpired() returns true. At which
// point Credentials will call Provider's Retrieve() to get new credential Value.
//
// The Provider is responsible for determining when credentials Value have expired.
// It is also important to note that Credentials will always call Retrieve the
// first time Credentials.Get() is called.
//
// Example of using the environment variable credentials.
//
// creds := NewEnvCredentials()
//
// // Retrieve the credentials value
// credValue, err := creds.Get()
// if err != nil {
// // handle error
// }
//
// Example of forcing credentials to expire and be refreshed on the next Get().
// This may be helpful to proactively expire credentials and refresh them sooner
// than they would naturally expire on their own.
//
// creds := NewCredentials(&EC2RoleProvider{})
// creds.Expire()
// credsValue, err := creds.Get()
// // New credentials will be retrieved instead of from cache.
//
//
// Custom Provider
//
// Each Provider built into this package also provides a helper method to generate
// a Credentials pointer setup with the provider. To use a custom Provider just
// create a type which satisfies the Provider interface and pass it to the
// NewCredentials method.
//
// type MyProvider struct{}
// func (m *MyProvider) Retrieve() (Value, error) {...}
// func (m *MyProvider) IsExpired() bool {...}
//
// creds := NewCredentials(&MyProvider{})
// credValue, err := creds.Get()
//
package credentials
import (
"sync"
"time"
)
// Create an empty Credential object that can be used as dummy placeholder
// credentials for requests that do not need signed.
//
// This Credentials can be used to configure a service to not sign requests
// when making service API calls. For example, when accessing public
// s3 buckets.
//
// svc := s3.New(&aws.Config{Credentials: AnonymousCredentials})
// // Access public S3 buckets.
//
var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields.
type Value struct {
// AWS Access key ID
AccessKeyID string
// AWS Secret Access Key
SecretAccessKey string
// AWS Session Token
SessionToken string
}
// A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what to
// be expired means.
//
// The Provider should not need to implement its own mutexes, because
// that will be managed by Credentials.
type Provider interface {
// Refresh returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
Retrieve() (Value, error)
// IsExpired returns if the credentials are no longer valid, and need
// to be retrieved.
IsExpired() bool
}
// A Credentials provides synchronous safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
//
// Credentials is safe to use across multiple goroutines and will manage the
// synchronous state so the Providers do not need to implement their own
// synchronization.
//
// The first Credentials.Get() will always call Provider.Retrieve() to get the
// first instance of the credentials Value. All calls to Get() after that
// will return the cached credentials Value until IsExpired() returns true.
type Credentials struct {
creds Value
forceRefresh bool
m sync.Mutex
provider Provider
}
// NewCredentials returns a pointer to a new Credentials with the provider set.
func NewCredentials(provider Provider) *Credentials {
return &Credentials{
provider: provider,
forceRefresh: true,
}
}
// Get returns the credentials value, or error if the credentials Value failed
// to be retrieved.
//
// Will return the cached credentials Value if it has not expired. If the
// credentials Value has expired the Provider's Retrieve() will be called
// to refresh the credentials.
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) Get() (Value, error) {
c.m.Lock()
defer c.m.Unlock()
if c.isExpired() {
creds, err := c.provider.Retrieve()
if err != nil {
return Value{}, err
}
c.creds = creds
c.forceRefresh = false
}
return c.creds, nil
}
// Expire expires the credentials and forces them to be retrieved on the
// next call to Get().
//
// This will override the Provider's expired state, and force Credentials
// to call the Provider's Retrieve().
func (c *Credentials) Expire() {
c.m.Lock()
defer c.m.Unlock()
c.forceRefresh = true
}
// IsExpired returns if the credentials are no longer valid, and need
// to be retrieved.
//
// If the Credentials were forced to be expired with Expire() this will
// reflect that override.
func (c *Credentials) IsExpired() bool {
c.m.Lock()
defer c.m.Unlock()
return c.isExpired()
}
// isExpired helper method wrapping the definition of expired credentials.
func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired()
}
// Provide a stub-able time.Now for unit tests so expiry can be tested.
var currentTime = time.Now

View File

@@ -0,0 +1,61 @@
package credentials
import (
"errors"
"github.com/stretchr/testify/assert"
"testing"
)
type stubProvider struct {
creds Value
expired bool
err error
}
func (s *stubProvider) Retrieve() (Value, error) {
s.expired = false
return s.creds, s.err
}
func (s *stubProvider) IsExpired() bool {
return s.expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: errors.New("provider error"), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.Error(), "Expected provider error")
}
func TestCredentialsExpire(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
stub.expired = false
assert.True(t, c.IsExpired(), "Expected to start out expired")
c.Expire()
assert.True(t, c.IsExpired(), "Expected to be expired")
c.forceRefresh = false
assert.False(t, c.IsExpired(), "Expected not to be expired")
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
}

View File

@@ -0,0 +1,167 @@
package credentials
import (
"bufio"
"encoding/json"
"fmt"
"net/http"
"time"
)
const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
// Example how to configure the EC2RoleProvider with custom http Client, Endpoint
// or ExpiryWindow
//
// p := &credentials.EC2RoleProvider{
// // Pass in a custom timeout to be used when requesting
// // IAM EC2 Role credentials.
// Client: &http.Client{
// Timeout: 10 * time.Second,
// },
// // Use default EC2 Role metadata endpoint, Alternate endpoints can be
// // specified setting Endpoint to something else.
// Endpoint: "",
// // Do not use early expiry of credentials. If a non zero value is
// // specified the credentials will be expired early
// ExpiryWindow: 0,
// }
//
type EC2RoleProvider struct {
// Endpoint must be fully quantified URL
Endpoint string
// HTTP client to use when connecting to EC2 service
Client *http.Client
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// The date/time at which the credentials expire.
expiresOn time.Time
}
// NewEC2RoleCredentials returns a pointer to a new Credentials object
// wrapping the EC2RoleProvider.
//
// Takes a custom http.Client which can be configured for custom handling of
// things such as timeout.
//
// Endpoint is the URL that the EC2RoleProvider will connect to when retrieving
// role and credentials.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewEC2RoleCredentials(client *http.Client, endpoint string, window time.Duration) *Credentials {
return NewCredentials(&EC2RoleProvider{
Endpoint: endpoint,
Client: client,
ExpiryWindow: window,
})
}
// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (Value, error) {
if m.Client == nil {
m.Client = http.DefaultClient
}
if m.Endpoint == "" {
m.Endpoint = metadataCredentialsEndpoint
}
credsList, err := requestCredList(m.Client, m.Endpoint)
if err != nil {
return Value{}, err
}
if len(credsList) == 0 {
return Value{}, fmt.Errorf("empty MetadataService credentials list")
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, m.Endpoint, credsName)
if err != nil {
return Value{}, err
}
m.expiresOn = roleCreds.Expiration
if m.ExpiryWindow > 0 {
// Offset based on expiry window if set.
m.expiresOn = m.expiresOn.Add(-m.ExpiryWindow)
}
return Value{
AccessKeyID: roleCreds.AccessKeyID,
SecretAccessKey: roleCreds.SecretAccessKey,
SessionToken: roleCreds.Token,
}, nil
}
// IsExpired returns if the credentials are expired.
func (m *EC2RoleProvider) IsExpired() bool {
return m.expiresOn.Before(currentTime())
}
// A ec2RoleCredRespBody provides the shape for deserializing credential
// request responses.
type ec2RoleCredRespBody struct {
Expiration time.Time
AccessKeyID string
SecretAccessKey string
Token string
}
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *http.Client, endpoint string) ([]string, error) {
resp, err := client.Get(endpoint)
if err != nil {
return nil, fmt.Errorf("%s listing MetadataService credentials", err)
}
defer resp.Body.Close()
credsList := []string{}
s := bufio.NewScanner(resp.Body)
for s.Scan() {
credsList = append(credsList, s.Text())
}
if err := s.Err(); err != nil {
return nil, fmt.Errorf("%s reading list of MetadataService credentials", err)
}
return credsList, nil
}
// requestCred requests the credentials for a specific credentials from the EC2 service.
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *http.Client, endpoint, credsName string) (*ec2RoleCredRespBody, error) {
resp, err := client.Get(endpoint + credsName)
if err != nil {
return nil, fmt.Errorf("getting %s MetadataService credentials", credsName)
}
defer resp.Body.Close()
respCreds := &ec2RoleCredRespBody{}
if err := json.NewDecoder(resp.Body).Decode(respCreds); err != nil {
return nil, fmt.Errorf("decoding %s MetadataService credentials", credsName)
}
return respCreds, nil
}

View File

@@ -0,0 +1,114 @@
package credentials
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func initTestServer(expireOn string) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/" {
fmt.Fprintln(w, "/creds")
} else {
fmt.Fprintf(w, `{
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s"
}`, expireOn)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
defer func() {
currentTime = time.Now
}()
currentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
currentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL, ExpiryWindow: time.Hour * 1}
defer func() {
currentTime = time.Now
}()
currentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
currentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC2RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@@ -0,0 +1,66 @@
package credentials
import (
"fmt"
"os"
)
var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment.
ErrAccessKeyIDNotFound = fmt.Errorf("AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment")
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment.
ErrSecretAccessKeyNotFound = fmt.Errorf("AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment")
)
// A EnvProvider retrieves credentials from the environment variables of the
// running process. Environment credentials never expire.
//
// Environment variables used:
// - Access Key ID: AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY
// - Secret Access Key: AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY
type EnvProvider struct {
retrieved bool
}
// NewEnvCredentials returns a pointer to a new Credentials object
// wrapping the environment variable provider.
func NewEnvCredentials() *Credentials {
return NewCredentials(&EnvProvider{})
}
// Retrieve retrieves the keys from the environment.
func (e *EnvProvider) Retrieve() (Value, error) {
e.retrieved = false
id := os.Getenv("AWS_ACCESS_KEY_ID")
if id == "" {
id = os.Getenv("AWS_ACCESS_KEY")
}
secret := os.Getenv("AWS_SECRET_ACCESS_KEY")
if secret == "" {
secret = os.Getenv("AWS_SECRET_KEY")
}
if id == "" {
return Value{}, ErrAccessKeyIDNotFound
}
if secret == "" {
return Value{}, ErrSecretAccessKeyNotFound
}
e.retrieved = true
return Value{
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: os.Getenv("AWS_SESSION_TOKEN"),
}, nil
}
// IsExpired returns if the credentials have been retrieved.
func (e *EnvProvider) IsExpired() bool {
return !e.retrieved
}

View File

@@ -0,0 +1,70 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestEnvProviderRetrieve(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error", err)
assert.Equal(t, "access", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEnvProviderIsExpired(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
assert.True(t, e.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := e.Retrieve()
assert.Nil(t, err, "Expect no error", err)
assert.False(t, e.IsExpired(), "Expect creds to not be expired after retrieve.")
}
func TestEnvProviderNoAccessKeyID(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrAccessKeyIDNotFound, err, "ErrAccessKeyIDNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderNoSecretAccessKey(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrSecretAccessKeyNotFound, err, "ErrSecretAccessKeyNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderAlternateNames(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY", "access")
os.Setenv("AWS_SECRET_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expected access key ID")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expected secret access key")
assert.Empty(t, creds.SessionToken, "Expected no token")
}

View File

@@ -0,0 +1,8 @@
[default]
aws_access_key_id = accessKey
aws_secret_access_key = secret
aws_session_token = token
[no_token]
aws_access_key_id = accessKey
aws_secret_access_key = secret

View File

@@ -0,0 +1,127 @@
package credentials
import (
"fmt"
"os"
"path/filepath"
"github.com/vaughan0/go-ini"
)
var (
// ErrSharedCredentialsHomeNotFound is emitted when the user directory cannot be found.
ErrSharedCredentialsHomeNotFound = fmt.Errorf("User home directory not found.")
)
// A SharedCredentialsProvider retrieves credentials from the current user's home
// directory, and keeps track if those credentials are expired.
//
// Profile ini file example: $HOME/.aws/credentials
type SharedCredentialsProvider struct {
// Path to the shared credentials file. If empty will default to current user's
// home directory.
Filename string
// AWS Profile to extract credentials from the shared credentials file. If empty
// will default to environment variable "AWS_PROFILE" or "default" if
// environment variable is also not set.
Profile string
// retrieved states if the credentials have been successfully retrieved.
retrieved bool
}
// NewSharedCredentials returns a pointer to a new Credentials object
// wrapping the Profile file provider.
func NewSharedCredentials(filename, profile string) *Credentials {
return NewCredentials(&SharedCredentialsProvider{
Filename: filename,
Profile: profile,
})
}
// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *SharedCredentialsProvider) Retrieve() (Value, error) {
p.retrieved = false
filename, err := p.filename()
if err != nil {
return Value{}, err
}
creds, err := loadProfile(filename, p.profile())
if err != nil {
return Value{}, err
}
p.retrieved = true
return creds, nil
}
// IsExpired returns if the shared credentials have expired.
func (p *SharedCredentialsProvider) IsExpired() bool {
return !p.retrieved
}
// loadProfiles loads from the file pointed to by shared credentials filename for profile.
// The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) {
config, err := ini.LoadFile(filename)
if err != nil {
return Value{}, err
}
iniProfile := config.Section(profile)
id, ok := iniProfile["aws_access_key_id"]
if !ok {
return Value{}, fmt.Errorf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename)
}
secret, ok := iniProfile["aws_secret_access_key"]
if !ok {
return Value{}, fmt.Errorf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename)
}
token := iniProfile["aws_session_token"]
return Value{
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
}, nil
}
// filename returns the filename to use to read AWS shared credentials.
//
// Will return an error if the user's home directory path cannot be found.
func (p *SharedCredentialsProvider) filename() (string, error) {
if p.Filename == "" {
homeDir := os.Getenv("HOME") // *nix
if homeDir == "" { // Windows
homeDir = os.Getenv("USERPROFILE")
}
if homeDir == "" {
return "", ErrSharedCredentialsHomeNotFound
}
p.Filename = filepath.Join(homeDir, ".aws", "credentials")
}
return p.Filename, nil
}
// profile returns the AWS shared credentials profile. If empty will read
// environment variable "AWS_PROFILE". If that is not set profile will
// return "default".
func (p *SharedCredentialsProvider) profile() string {
if p.Profile == "" {
p.Profile = os.Getenv("AWS_PROFILE")
}
if p.Profile == "" {
p.Profile = "default"
}
return p.Profile
}

View File

@@ -0,0 +1,77 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestSharedCredentialsProvider(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderIsExpired(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "no_token"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@@ -0,0 +1,42 @@
package credentials
import (
"fmt"
)
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
ErrStaticCredentialsEmpty = fmt.Errorf("static credentials are empty")
)
// A StaticProvider is a set of credentials which are set pragmatically,
// and will never expire.
type StaticProvider struct {
Value
}
// NewStaticCredentials returns a pointer to a new Credentials object
// wrapping a static credentials value provider.
func NewStaticCredentials(id, secret, token string) *Credentials {
return NewCredentials(&StaticProvider{Value: Value{
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
}})
}
// Retrieve returns the credentials or error if the credentials are invalid.
func (s *StaticProvider) Retrieve() (Value, error) {
if s.AccessKeyID == "" || s.SecretAccessKey == "" {
return Value{}, ErrStaticCredentialsEmpty
}
return s.Value, nil
}
// IsExpired returns if the credentials are expired.
//
// For StaticProvider, the credentials never expired.
func (s *StaticProvider) IsExpired() bool {
return false
}

View File

@@ -0,0 +1,34 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestStaticProviderGet(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
creds, err := s.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no session token")
}
func TestStaticProviderIsExpired(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
assert.False(t, s.IsExpired(), "Expect static credentials to never expire")
}

View File

@@ -0,0 +1,26 @@
package aws
// An APIError is an error returned by an AWS API.
type APIError struct {
StatusCode int // HTTP status code e.g. 200
Code string
Message string
RequestID string
}
// Error returns the error as a string. Satisfies error interface.
func (e APIError) Error() string {
return e.Code + ": " + e.Message
}
// Error returns an APIError pointer if the error e is an APIError type.
// If the error is not an APIError nil will be returned.
func Error(e error) *APIError {
if err, ok := e.(*APIError); ok {
return err
} else if err, ok := e.(APIError); ok {
return &err
} else {
return nil
}
}

View File

@@ -0,0 +1,136 @@
package aws
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strconv"
"time"
)
var sleepDelay = func(delay time.Duration) {
time.Sleep(delay)
}
// Interface for matching types which also have a Len method.
type lener interface {
Len() int
}
// BuildContentLength builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
func BuildContentLength(r *Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
return
}
var length int64
switch body := r.Body.(type) {
case nil:
length = 0
case lener:
length = int64(body.Len())
case io.Seeker:
r.bodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.bodyStart, 0) // make sure to seek back to original location
length = end - r.bodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
}
// UserAgentHandler is a request handler for injecting User agent into requests.
func UserAgentHandler(r *Request) {
r.HTTPRequest.Header.Set("User-Agent", SDKName+"/"+SDKVersion)
}
var reStatusCode = regexp.MustCompile(`^(\d+)`)
// SendHandler is a request handler to send service request using HTTP client.
func SendHandler(r *Request) {
r.HTTPResponse, r.Error = r.Service.Config.HTTPClient.Do(r.HTTPRequest)
if r.Error != nil {
if e, ok := r.Error.(*url.Error); ok {
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.Error = nil
r.HTTPResponse = &http.Response{
StatusCode: int(code),
Status: http.StatusText(int(code)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
}
}
}
// ValidateResponseHandler is a request handler to validate service response.
func ValidateResponseHandler(r *Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = &APIError{
StatusCode: r.HTTPResponse.StatusCode,
Code: "UnknownError",
Message: "unknown error",
}
}
}
// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
func AfterRetryHandler(r *Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if !r.Retryable.IsSet() {
r.Retryable.Set(r.Service.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.Service.RetryRules(r)
sleepDelay(r.RetryDelay)
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if err := Error(r.Error); err != nil && err.Code == "ExpiredTokenException" {
r.Config.Credentials.Expire()
// The credentials will need to be resigned with new credentials
r.signed = false
}
r.RetryCount++
r.Error = nil
}
}
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
ErrMissingRegion = fmt.Errorf("could not find region configuration")
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
ErrMissingEndpoint = fmt.Errorf("`Endpoint' configuration is required for this service")
)
// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
// region is not valid.
func ValidateEndpointHandler(r *Request) {
if r.Service.SigningRegion == "" && r.Service.Config.Region == "" {
r.Error = ErrMissingRegion
} else if r.Service.Endpoint == "" {
r.Error = ErrMissingEndpoint
}
}

View File

@@ -0,0 +1,80 @@
package aws
import (
"net/http"
"os"
"testing"
"github.com/awslabs/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := NewService(&Config{Region: "us-west-2"})
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := NewService(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retreiveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retreiveCalled = true
return credentials.Value{}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := NewService(&Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: 1})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *Request) {
r.Error = &APIError{Code: "UnknownError"}
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *Request) {
r.Error = &APIError{Code: "ExpiredTokenException"}
})
svc.Handlers.AfterRetry.PushBack(func(r *Request) {
AfterRetryHandler(r)
})
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retreiveCalled)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retreiveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retreiveCalled)
}

View File

@@ -0,0 +1,85 @@
package aws
// A Handlers provides a collection of request handlers for various
// stages of handling requests.
type Handlers struct {
Validate HandlerList
Build HandlerList
Sign HandlerList
Send HandlerList
ValidateResponse HandlerList
Unmarshal HandlerList
UnmarshalMeta HandlerList
UnmarshalError HandlerList
Retry HandlerList
AfterRetry HandlerList
}
// copy returns of this handler's lists.
func (h *Handlers) copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
Build: h.Build.copy(),
Sign: h.Sign.copy(),
Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(),
Unmarshal: h.Unmarshal.copy(),
UnmarshalError: h.UnmarshalError.copy(),
UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(),
}
}
// Clear removes callback functions for all handlers
func (h *Handlers) Clear() {
h.Validate.Clear()
h.Build.Clear()
h.Send.Clear()
h.Sign.Clear()
h.Unmarshal.Clear()
h.UnmarshalMeta.Clear()
h.UnmarshalError.Clear()
h.ValidateResponse.Clear()
h.Retry.Clear()
h.AfterRetry.Clear()
}
// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []func(*Request)
}
// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]func(*Request){}, l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []func(*Request){}
}
// Len returns the number of handlers in the list.
func (l *HandlerList) Len() int {
return len(l.list)
}
// PushBack pushes handlers f to the back of the handler list.
func (l *HandlerList) PushBack(f ...func(*Request)) {
l.list = append(l.list, f...)
}
// PushFront pushes handlers f to the front of the handler list.
func (l *HandlerList) PushFront(f ...func(*Request)) {
l.list = append(f, l.list...)
}
// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f(r)
}
}

View File

@@ -0,0 +1,31 @@
package aws
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) { r.Data = nil })
l.PushFront(func(r *Request) { r.Data = Boolean(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}

View File

@@ -0,0 +1,87 @@
package aws
import (
"fmt"
"reflect"
"strings"
)
// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
func ValidateParameters(r *Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = APIError{Code: "InvalidParameter", Message: msg}
}
}
}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
notset := false
if f.Tag.Get("required") != "" {
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice:
if fvalue.IsNil() {
notset = true
}
default:
if !fvalue.IsValid() {
notset = true
}
}
}
if notset {
msg := "missing required parameter: " + path + prefix + f.Name
v.errors = append(v.errors, msg)
} else {
v.validateAny(fvalue, path+prefix+f.Name)
}
}
}

View File

@@ -0,0 +1,85 @@
package aws_test
import (
"testing"
"github.com/awslabs/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"
)
var service = func() *aws.Service {
s := &aws.Service{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
}
return s
}()
type StructShape struct {
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap *map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
metadataStructureShape
}
type metadataStructureShape struct {
SDKShapeTraits bool
}
type ConditionalStructShape struct {
Name *string `required:"true"`
SDKShapeTraits bool
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: &map[string]*ConditionalStructShape{
"key1": &ConditionalStructShape{Name: aws.String("Name")},
"key2": &ConditionalStructShape{Name: aws.String("Name")},
},
RequiredBool: aws.Boolean(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
err := aws.Error(req.Error)
assert.Error(t, err)
assert.Equal(t, "InvalidParameter", err.Code)
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", err.Message)
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{&ConditionalStructShape{}},
RequiredMap: &map[string]*ConditionalStructShape{
"key1": &ConditionalStructShape{Name: aws.String("Name")},
"key2": &ConditionalStructShape{},
},
RequiredBool: aws.Boolean(true),
OptionalStruct: &ConditionalStructShape{},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
err := aws.Error(req.Error)
assert.Error(t, err)
assert.Equal(t, "InvalidParameter", err.Code)
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", err.Message)
}

View File

@@ -0,0 +1,219 @@
package aws
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strings"
"time"
)
// A Request is the service request to be made.
type Request struct {
*Service
Handlers Handlers
Time time.Time
ExpireTime time.Duration
Operation *Operation
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
bodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
Data interface{}
RequestID string
RetryCount uint
Retryable SettableBool
RetryDelay time.Duration
built bool
signed bool
}
// An Operation is the service API operation to be made.
type Operation struct {
Name string
HTTPMethod string
HTTPPath string
}
// NewRequest returns a new Request pointer for the service API
// operation and parameters.
//
// Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response
// payload will be deserialized to.
func NewRequest(service *Service, operation *Operation, params interface{}, data interface{}) *Request {
method := operation.HTTPMethod
if method == "" {
method = "POST"
}
p := operation.HTTPPath
if p == "" {
p = "/"
}
httpReq, _ := http.NewRequest(method, "", nil)
httpReq.URL, _ = url.Parse(service.Endpoint + p)
r := &Request{
Service: service,
Handlers: service.Handlers.copy(),
Time: time.Now(),
ExpireTime: 0,
Operation: operation,
HTTPRequest: httpReq,
Body: nil,
Params: params,
Error: nil,
Data: data,
}
r.SetBufferBody([]byte{})
return r
}
// WillRetry returns if the request's can be retried.
func (r *Request) WillRetry() bool {
return r.Error != nil && r.Retryable.Get() && r.RetryCount < r.Service.MaxRetries()
}
// ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are
// provided or invalid.
func (r *Request) ParamsFilled() bool {
return r.Params != nil && reflect.ValueOf(r.Params).Elem().IsValid()
}
// DataFilled returns true if the request's data for response deserialization
// target has been set and is a valid. False is returned if data is not
// set, or is invalid.
func (r *Request) DataFilled() bool {
return r.Data != nil && reflect.ValueOf(r.Data).Elem().IsValid()
}
// SetBufferBody will set the request's body bytes that will be sent to
// the service API.
func (r *Request) SetBufferBody(buf []byte) {
r.SetReaderBody(bytes.NewReader(buf))
}
// SetStringBody sets the body of the request to be backed by a string.
func (r *Request) SetStringBody(s string) {
r.SetReaderBody(strings.NewReader(s))
}
// SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.HTTPRequest.Body = ioutil.NopCloser(reader)
r.Body = reader
}
// Presign returns the request's signed URL. Error will be returned
// if the signing fails.
func (r *Request) Presign(expireTime time.Duration) (string, error) {
r.ExpireTime = expireTime
r.Sign()
if r.Error != nil {
return "", r.Error
}
return r.HTTPRequest.URL.String(), nil
}
// Build will build the request's object so it can be signed and sent
// to the service. Build will also validate all the request's parameters.
// Anny additional build Handlers set on this request will be run
// in the order they were set.
//
// The request will only be built once. Multiple calls to build will have
// no effect.
//
// If any Validate or Build errors occur the build will stop and the error
// which occurred will be returned.
func (r *Request) Build() error {
if !r.built {
r.Error = nil
r.Handlers.Validate.Run(r)
if r.Error != nil {
return r.Error
}
r.Handlers.Build.Run(r)
r.built = true
}
return r.Error
}
// Sign will sign the request retuning error if errors are encountered.
//
// Send will build the request prior to signing. All Sign Handlers will
// be executed in the order they were set.
func (r *Request) Sign() error {
if r.signed {
return r.Error
}
r.Build()
if r.Error != nil {
return r.Error
}
r.Handlers.Sign.Run(r)
r.signed = r.Error != nil
return r.Error
}
// Send will send the request returning error if errors are encountered.
//
// Send will sign the request prior to sending. All Send Handlers will
// be executed in the order they were set.
func (r *Request) Send() error {
for {
r.Sign()
if r.Error != nil {
return r.Error
}
if r.Retryable.Get() {
// Re-seek the body back to the original point in for a retry so that
// send will send the body's contents again in the upcoming request.
r.Body.Seek(r.bodyStart, 0)
}
r.Retryable.Reset()
r.Handlers.Send.Run(r)
if r.Error != nil {
return r.Error
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
r.Handlers.UnmarshalError.Run(r)
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
return r.Error
}
continue
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
return r.Error
}
continue
}
break
}
return nil
}

View File

@@ -0,0 +1,218 @@
package aws
import (
"bytes"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"reflect"
"testing"
"time"
"github.com/awslabs/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
type testData struct {
Data string
}
func body(str string) io.ReadCloser {
return ioutil.NopCloser(bytes.NewReader([]byte(str)))
}
func unmarshal(req *Request) {
defer req.HTTPResponse.Body.Close()
if req.Data != nil {
json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data)
}
return
}
func unmarshalError(req *Request) {
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = err
return
}
if len(bodyBytes) == 0 {
req.Error = APIError{
StatusCode: req.HTTPResponse.StatusCode,
Message: req.HTTPResponse.Status,
}
return
}
var jsonErr jsonErrorResponse
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
req.Error = err
return
}
req.Error = APIError{
StatusCode: req.HTTPResponse.StatusCode,
Code: jsonErr.Code,
Message: jsonErr.Message,
}
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}
// test that retries occur for 5xx status codes
func TestRequestRecoverRetry5xx(t *testing.T) {
reqNum := 0
reqs := []http.Response{
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
// test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry`
func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
reqNum := 0
reqs := []http.Response{
http.Response{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
http.Response{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
http.Response{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
// test that retries don't occur for 4xx status codes with a response type that can't be retried
func TestRequest4xxUnretryable(t *testing.T) {
s := NewService(&Config{MaxRetries: 10})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)}
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
apiErr := Error(err)
assert.NotNil(t, err)
assert.NotNil(t, apiErr)
assert.Equal(t, 401, apiErr.StatusCode)
assert.Equal(t, "SignatureDoesNotMatch", apiErr.Code)
assert.Equal(t, "Signature does not match.", apiErr.Message)
assert.Equal(t, 0, int(r.RetryCount))
}
func TestRequestExhaustRetries(t *testing.T) {
delays := []time.Duration{}
sleepDelay = func(delay time.Duration) {
delays = append(delays, delay)
}
reqNum := 0
reqs := []http.Response{
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
}
s := NewService(&Config{MaxRetries: -1})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
r := NewRequest(s, &Operation{Name: "Operation"}, nil, nil)
err := r.Send()
apiErr := Error(err)
assert.NotNil(t, err)
assert.NotNil(t, apiErr)
assert.Equal(t, 500, apiErr.StatusCode)
assert.Equal(t, "UnknownError", apiErr.Code)
assert.Equal(t, "An error occurred.", apiErr.Message)
assert.Equal(t, 3, int(r.RetryCount))
assert.True(t, reflect.DeepEqual([]time.Duration{30 * time.Millisecond, 60 * time.Millisecond, 120 * time.Millisecond}, delays))
}
// test that the request is retried after the credentials are expired.
func TestRequestRecoverExpiredCreds(t *testing.T) {
reqNum := 0
reqs := []http.Response{
http.Response{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
http.Response{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10, Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
credExpiredBeforeRetry := false
credExpiredAfterRetry := false
s.Handlers.Retry.PushBack(func(r *Request) {
if err := Error(r.Error); err != nil && err.Code == "ExpiredTokenException" {
credExpiredBeforeRetry = r.Config.Credentials.IsExpired()
}
})
s.Handlers.AfterRetry.PushBack(func(r *Request) {
credExpiredAfterRetry = r.Config.Credentials.IsExpired()
})
s.Handlers.Sign.Clear()
s.Handlers.Sign.PushBack(func(r *Request) {
r.Config.Credentials.Get()
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.False(t, credExpiredBeforeRetry, "Expect valid creds before retry check")
assert.True(t, credExpiredAfterRetry, "Expect expired creds after retry check")
assert.False(t, s.Config.Credentials.IsExpired(), "Expect valid creds after cred expired recovery")
assert.Equal(t, 1, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}

View File

@@ -0,0 +1,154 @@
package aws
import (
"fmt"
"math"
"net/http"
"net/http/httputil"
"regexp"
"time"
"github.com/awslabs/aws-sdk-go/internal/endpoints"
)
// A Service implements the base service request and response handling
// used by all services.
type Service struct {
Config *Config
Handlers Handlers
ManualSend bool
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
RetryRules func(*Request) time.Duration
ShouldRetry func(*Request) bool
DefaultMaxRetries uint
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// NewService will return a pointer to a new Server object initialized.
func NewService(config *Config) *Service {
svc := &Service{Config: config}
svc.Initialize()
return svc
}
// Initialize initializes the service.
func (s *Service) Initialize() {
if s.Config == nil {
s.Config = &Config{}
}
if s.Config.HTTPClient == nil {
s.Config.HTTPClient = http.DefaultClient
}
if s.RetryRules == nil {
s.RetryRules = retryRules
}
if s.ShouldRetry == nil {
s.ShouldRetry = shouldRetry
}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBack(ValidateEndpointHandler)
s.Handlers.Build.PushBack(UserAgentHandler)
s.Handlers.Sign.PushBack(BuildContentLength)
s.Handlers.Send.PushBack(SendHandler)
s.Handlers.AfterRetry.PushBack(AfterRetryHandler)
s.Handlers.ValidateResponse.PushBack(ValidateResponseHandler)
s.AddDebugHandlers()
s.buildEndpoint()
if !s.Config.DisableParamValidation {
s.Handlers.Validate.PushBack(ValidateParameters)
}
}
// buildEndpoint builds the endpoint values the service will use to make requests with.
func (s *Service) buildEndpoint() {
if s.Config.Endpoint != "" {
s.Endpoint = s.Config.Endpoint
} else {
s.Endpoint, s.SigningRegion =
endpoints.EndpointForRegion(s.ServiceName, s.Config.Region)
}
if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) {
scheme := "https"
if s.Config.DisableSSL {
scheme = "http"
}
s.Endpoint = scheme + "://" + s.Endpoint
}
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (s *Service) AddDebugHandlers() {
out := s.Config.Logger
if s.Config.LogLevel == 0 {
return
}
s.Handlers.Send.PushFront(func(r *Request) {
logBody := r.Config.LogHTTPBody
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
fmt.Fprintf(out, "---[ REQUEST POST-SIGN ]-----------------------------\n")
fmt.Fprintf(out, "%s\n", string(dumpedBody))
fmt.Fprintf(out, "-----------------------------------------------------\n")
})
s.Handlers.Send.PushBack(func(r *Request) {
fmt.Fprintf(out, "---[ RESPONSE ]--------------------------------------\n")
if r.HTTPResponse != nil {
logBody := r.Config.LogHTTPBody
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
fmt.Fprintf(out, "%s\n", string(dumpedBody))
} else if r.Error != nil {
fmt.Fprintf(out, "%s\n", r.Error)
}
fmt.Fprintf(out, "-----------------------------------------------------\n")
})
}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (s *Service) MaxRetries() uint {
if s.Config.MaxRetries < 0 {
return s.DefaultMaxRetries
}
return uint(s.Config.MaxRetries)
}
// retryRules returns the delay duration before retrying this request again
func retryRules(r *Request) time.Duration {
delay := time.Duration(math.Pow(2, float64(r.RetryCount))) * 30
return delay * time.Millisecond
}
// Collection of service response codes which are generically
// retryable for all services.
var retryableCodes = map[string]struct{}{
"ExpiredTokenException": struct{}{},
"ProvisionedThroughputExceededException": struct{}{},
"Throttling": struct{}{},
}
// shouldRetry returns if the request should be retried.
func shouldRetry(r *Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
if err := Error(r.Error); err != nil {
if _, ok := retryableCodes[err.Code]; ok {
return true
}
}
return false
}

View File

@@ -0,0 +1,131 @@
package aws
import (
"fmt"
"io"
"time"
)
// String converts a Go string into a string pointer.
func String(v string) *string {
return &v
}
// Boolean converts a Go bool into a boolean pointer.
func Boolean(v bool) *bool {
return &v
}
// Long converts a Go int64 into a long pointer.
func Long(v int64) *int64 {
return &v
}
// Double converts a Go float64 into a double pointer.
func Double(v float64) *float64 {
return &v
}
// Time converts a Go Time into a Time pointer
func Time(t time.Time) *time.Time {
return &t
}
// ReadSeekCloser wraps a io.Reader returning a ReaderSeakerCloser
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r}
}
// ReaderSeekerCloser represents a reader that can also delegate io.Seeker and
// io.Closer interfaces to the underlying object if they are available.
type ReaderSeekerCloser struct {
r io.Reader
}
// Read reads from the reader up to size of p. The number of bytes read, and
// error if it occurred will be returned.
//
// If the reader is not an io.Reader zero bytes read, and nil error will be returned.
//
// Performs the same functionality as io.Reader Read
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
switch t := r.r.(type) {
case io.Reader:
return t.Read(p)
}
return 0, nil
}
// Seek sets the offset for the next Read to offset, interpreted according to
// whence: 0 means relative to the origin of the file, 1 means relative to the
// current offset, and 2 means relative to the end. Seek returns the new offset
// and an error, if any.
//
// If the ReaderSeekerCloser is not an io.Seeker nothing will be done.
func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) {
switch t := r.r.(type) {
case io.Seeker:
return t.Seek(offset, whence)
}
return int64(0), nil
}
// Close closes the ReaderSeekerCloser.
//
// If the ReaderSeekerCloser is not an io.Closer nothing will be done.
func (r ReaderSeekerCloser) Close() error {
switch t := r.r.(type) {
case io.Closer:
return t.Close()
}
return nil
}
// A SettableBool provides a boolean value which includes the state if
// the value was set or unset. The set state is in addition to the value's
// value(true|false)
type SettableBool struct {
value bool
set bool
}
// SetBool returns a SettableBool with a value set
func SetBool(value bool) SettableBool {
return SettableBool{value: value, set: true}
}
// Get returns the value. Will always be false if the SettableBool was not set.
func (b *SettableBool) Get() bool {
if !b.set {
return false
}
return b.value
}
// Set sets the value and updates the state that the value has been set.
func (b *SettableBool) Set(value bool) {
b.value = value
b.set = true
}
// IsSet returns if the value has been set
func (b *SettableBool) IsSet() bool {
return b.set
}
// Reset resets the state and value of the SettableBool to its initial default
// state of not set and zero value.
func (b *SettableBool) Reset() {
b.value = false
b.set = false
}
// String returns the string representation of the value if set. Zero if not set.
func (b *SettableBool) String() string {
return fmt.Sprintf("%t", b.Get())
}
// GoString returns the string representation of the SettableBool value and state
func (b *SettableBool) GoString() string {
return fmt.Sprintf("Bool{value:%t, set:%t}", b.value, b.set)
}

View File

@@ -0,0 +1,8 @@
// Package aws provides core functionality for making requests to AWS services.
package aws
// SDKName is the name of this AWS SDK
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "0.5.0"

View File

@@ -0,0 +1,29 @@
package endpoints
//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
import "strings"
// EndpointForRegion returns an endpoint and its signing region for a service and region.
// if the service and region pair are not found endpoint and signingRegion will be empty.
func EndpointForRegion(svcName, region string) (endpoint, signingRegion string) {
derivedKeys := []string{
region + "/" + svcName,
region + "/*",
"*/" + svcName,
"*/*",
}
for _, key := range derivedKeys {
if val, ok := endpointsMap.Endpoints[key]; ok {
ep := val.Endpoint
ep = strings.Replace(ep, "{region}", region, -1)
ep = strings.Replace(ep, "{service}", svcName, -1)
endpoint = ep
signingRegion = val.SigningRegion
return
}
}
return
}

View File

@@ -0,0 +1,77 @@
{
"version": 2,
"endpoints": {
"*/*": {
"endpoint": "{service}.{region}.amazonaws.com"
},
"cn-north-1/*": {
"endpoint": "{service}.{region}.amazonaws.com.cn",
"signatureVersion": "v4"
},
"us-gov-west-1/iam": {
"endpoint": "iam.us-gov.amazonaws.com"
},
"us-gov-west-1/sts": {
"endpoint": "sts.us-gov-west-1.amazonaws.com"
},
"us-gov-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"*/cloudfront": {
"endpoint": "cloudfront.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/cloudsearchdomain": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/iam": {
"endpoint": "iam.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/importexport": {
"endpoint": "importexport.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/route53": {
"endpoint": "route53.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/sts": {
"endpoint": "sts.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/sdb": {
"endpoint": "sdb.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/s3": {
"endpoint": "s3.amazonaws.com"
},
"us-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"us-west-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-northeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"sa-east-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-central-1/s3": {
"endpoint": "{service}.{region}.amazonaws.com",
"signatureVersion": "v4"
}
}
}

View File

@@ -0,0 +1,89 @@
package endpoints
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
type endpointStruct struct {
Version int
Endpoints map[string]endpointEntry
}
type endpointEntry struct {
Endpoint string
SigningRegion string
}
var endpointsMap = endpointStruct{
Version: 2,
Endpoints: map[string]endpointEntry{
"*/*": endpointEntry{
Endpoint: "{service}.{region}.amazonaws.com",
},
"*/cloudfront": endpointEntry{
Endpoint: "cloudfront.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudsearchdomain": endpointEntry{
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/iam": endpointEntry{
Endpoint: "iam.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/importexport": endpointEntry{
Endpoint: "importexport.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/route53": endpointEntry{
Endpoint: "route53.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/sts": endpointEntry{
Endpoint: "sts.amazonaws.com",
SigningRegion: "us-east-1",
},
"ap-northeast-1/s3": endpointEntry{
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-1/s3": endpointEntry{
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-2/s3": endpointEntry{
Endpoint: "s3-{region}.amazonaws.com",
},
"cn-north-1/*": endpointEntry{
Endpoint: "{service}.{region}.amazonaws.com.cn",
},
"eu-central-1/s3": endpointEntry{
Endpoint: "{service}.{region}.amazonaws.com",
},
"eu-west-1/s3": endpointEntry{
Endpoint: "s3-{region}.amazonaws.com",
},
"sa-east-1/s3": endpointEntry{
Endpoint: "s3-{region}.amazonaws.com",
},
"us-east-1/s3": endpointEntry{
Endpoint: "s3.amazonaws.com",
},
"us-east-1/sdb": endpointEntry{
Endpoint: "sdb.amazonaws.com",
SigningRegion: "us-east-1",
},
"us-gov-west-1/iam": endpointEntry{
Endpoint: "iam.us-gov.amazonaws.com",
},
"us-gov-west-1/s3": endpointEntry{
Endpoint: "s3-{region}.amazonaws.com",
},
"us-gov-west-1/sts": endpointEntry{
Endpoint: "sts.us-gov-west-1.amazonaws.com",
},
"us-west-1/s3": endpointEntry{
Endpoint: "s3-{region}.amazonaws.com",
},
"us-west-2/s3": endpointEntry{
Endpoint: "s3-{region}.amazonaws.com",
},
},
}

View File

@@ -0,0 +1,28 @@
package endpoints
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGlobalEndpoints(t *testing.T) {
region := "mock-region-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts"}
for _, name := range svcs {
ep, sr := EndpointForRegion(name, region)
assert.Equal(t, name+".amazonaws.com", ep)
assert.Equal(t, "us-east-1", sr)
}
}
func TestServicesInCN(t *testing.T) {
region := "cn-north-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts", "s3"}
for _, name := range svcs {
ep, _ := EndpointForRegion(name, region)
assert.Equal(t, name+"."+region+".amazonaws.com.cn", ep)
}
}

View File

@@ -0,0 +1,31 @@
package ec2query
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/input/ec2.json build_test.go
import (
"net/url"
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/internal/protocol/query/queryutil"
)
// Build builds a request for the EC2 protocol.
func Build(r *aws.Request) {
body := url.Values{
"Action": {r.Operation.Name},
"Version": {r.Service.APIVersion},
}
if err := queryutil.Parse(body, r.Params, true); err != nil {
r.Error = err
return
}
if r.ExpireTime == 0 {
r.HTTPRequest.Method = "POST"
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
r.SetBufferBody([]byte(body.Encode()))
} else { // This is a pre-signed request
r.HTTPRequest.Method = "GET"
r.HTTPRequest.URL.RawQuery = body.Encode()
}
}

View File

@@ -0,0 +1,933 @@
package ec2query_test
import (
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/internal/protocol/ec2query"
"github.com/awslabs/aws-sdk-go/internal/signer/v4"
"bytes"
"encoding/json"
"encoding/xml"
"github.com/awslabs/aws-sdk-go/internal/protocol/xml/xmlutil"
"github.com/awslabs/aws-sdk-go/internal/util"
"github.com/stretchr/testify/assert"
"io"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
)
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
// InputService1ProtocolTest is a client for InputService1ProtocolTest.
type InputService1ProtocolTest struct {
*aws.Service
}
// New returns a new InputService1ProtocolTest client.
func NewInputService1ProtocolTest(config *aws.Config) *InputService1ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "inputservice1protocoltest",
APIVersion: "2014-01-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService1ProtocolTest{service}
}
// newRequest creates a new request for a InputService1ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService1ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// InputService1TestCaseOperation1Request generates a request for the InputService1TestCaseOperation1 operation.
func (c *InputService1ProtocolTest) InputService1TestCaseOperation1Request(input *InputService1TestShapeInputShape) (req *aws.Request, output *InputService1TestShapeInputService1TestCaseOperation1Output) {
if opInputService1TestCaseOperation1 == nil {
opInputService1TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &InputService1TestShapeInputShape{}
}
req = c.newRequest(opInputService1TestCaseOperation1, input, output)
output = &InputService1TestShapeInputService1TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService1ProtocolTest) InputService1TestCaseOperation1(input *InputService1TestShapeInputShape) (output *InputService1TestShapeInputService1TestCaseOperation1Output, err error) {
req, out := c.InputService1TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opInputService1TestCaseOperation1 *aws.Operation
type InputService1TestShapeInputService1TestCaseOperation1Output struct {
metadataInputService1TestShapeInputService1TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService1TestShapeInputService1TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService1TestShapeInputShape struct {
Bar *string `type:"string"`
Foo *string `type:"string"`
metadataInputService1TestShapeInputShape `json:"-" xml:"-"`
}
type metadataInputService1TestShapeInputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// InputService2ProtocolTest is a client for InputService2ProtocolTest.
type InputService2ProtocolTest struct {
*aws.Service
}
// New returns a new InputService2ProtocolTest client.
func NewInputService2ProtocolTest(config *aws.Config) *InputService2ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "inputservice2protocoltest",
APIVersion: "2014-01-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService2ProtocolTest{service}
}
// newRequest creates a new request for a InputService2ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService2ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// InputService2TestCaseOperation1Request generates a request for the InputService2TestCaseOperation1 operation.
func (c *InputService2ProtocolTest) InputService2TestCaseOperation1Request(input *InputService2TestShapeInputShape) (req *aws.Request, output *InputService2TestShapeInputService2TestCaseOperation1Output) {
if opInputService2TestCaseOperation1 == nil {
opInputService2TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &InputService2TestShapeInputShape{}
}
req = c.newRequest(opInputService2TestCaseOperation1, input, output)
output = &InputService2TestShapeInputService2TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService2ProtocolTest) InputService2TestCaseOperation1(input *InputService2TestShapeInputShape) (output *InputService2TestShapeInputService2TestCaseOperation1Output, err error) {
req, out := c.InputService2TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opInputService2TestCaseOperation1 *aws.Operation
type InputService2TestShapeInputService2TestCaseOperation1Output struct {
metadataInputService2TestShapeInputService2TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService2TestShapeInputService2TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService2TestShapeInputShape struct {
Bar *string `locationName:"barLocationName" type:"string"`
Foo *string `type:"string"`
Yuck *string `locationName:"yuckLocationName" queryName:"yuckQueryName" type:"string"`
metadataInputService2TestShapeInputShape `json:"-" xml:"-"`
}
type metadataInputService2TestShapeInputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// InputService3ProtocolTest is a client for InputService3ProtocolTest.
type InputService3ProtocolTest struct {
*aws.Service
}
// New returns a new InputService3ProtocolTest client.
func NewInputService3ProtocolTest(config *aws.Config) *InputService3ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "inputservice3protocoltest",
APIVersion: "2014-01-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService3ProtocolTest{service}
}
// newRequest creates a new request for a InputService3ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService3ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// InputService3TestCaseOperation1Request generates a request for the InputService3TestCaseOperation1 operation.
func (c *InputService3ProtocolTest) InputService3TestCaseOperation1Request(input *InputService3TestShapeInputShape) (req *aws.Request, output *InputService3TestShapeInputService3TestCaseOperation1Output) {
if opInputService3TestCaseOperation1 == nil {
opInputService3TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &InputService3TestShapeInputShape{}
}
req = c.newRequest(opInputService3TestCaseOperation1, input, output)
output = &InputService3TestShapeInputService3TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService3ProtocolTest) InputService3TestCaseOperation1(input *InputService3TestShapeInputShape) (output *InputService3TestShapeInputService3TestCaseOperation1Output, err error) {
req, out := c.InputService3TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opInputService3TestCaseOperation1 *aws.Operation
type InputService3TestShapeInputService3TestCaseOperation1Output struct {
metadataInputService3TestShapeInputService3TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService3TestShapeInputService3TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService3TestShapeInputShape struct {
StructArg *InputService3TestShapeStructType `locationName:"Struct" type:"structure"`
metadataInputService3TestShapeInputShape `json:"-" xml:"-"`
}
type metadataInputService3TestShapeInputShape struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService3TestShapeStructType struct {
ScalarArg *string `locationName:"Scalar" type:"string"`
metadataInputService3TestShapeStructType `json:"-" xml:"-"`
}
type metadataInputService3TestShapeStructType struct {
SDKShapeTraits bool `type:"structure"`
}
// InputService4ProtocolTest is a client for InputService4ProtocolTest.
type InputService4ProtocolTest struct {
*aws.Service
}
// New returns a new InputService4ProtocolTest client.
func NewInputService4ProtocolTest(config *aws.Config) *InputService4ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "inputservice4protocoltest",
APIVersion: "2014-01-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService4ProtocolTest{service}
}
// newRequest creates a new request for a InputService4ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService4ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// InputService4TestCaseOperation1Request generates a request for the InputService4TestCaseOperation1 operation.
func (c *InputService4ProtocolTest) InputService4TestCaseOperation1Request(input *InputService4TestShapeInputShape) (req *aws.Request, output *InputService4TestShapeInputService4TestCaseOperation1Output) {
if opInputService4TestCaseOperation1 == nil {
opInputService4TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &InputService4TestShapeInputShape{}
}
req = c.newRequest(opInputService4TestCaseOperation1, input, output)
output = &InputService4TestShapeInputService4TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService4ProtocolTest) InputService4TestCaseOperation1(input *InputService4TestShapeInputShape) (output *InputService4TestShapeInputService4TestCaseOperation1Output, err error) {
req, out := c.InputService4TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opInputService4TestCaseOperation1 *aws.Operation
type InputService4TestShapeInputService4TestCaseOperation1Output struct {
metadataInputService4TestShapeInputService4TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService4TestShapeInputService4TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService4TestShapeInputShape struct {
ListArg []*string `type:"list"`
metadataInputService4TestShapeInputShape `json:"-" xml:"-"`
}
type metadataInputService4TestShapeInputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// InputService5ProtocolTest is a client for InputService5ProtocolTest.
type InputService5ProtocolTest struct {
*aws.Service
}
// New returns a new InputService5ProtocolTest client.
func NewInputService5ProtocolTest(config *aws.Config) *InputService5ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "inputservice5protocoltest",
APIVersion: "2014-01-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService5ProtocolTest{service}
}
// newRequest creates a new request for a InputService5ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService5ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// InputService5TestCaseOperation1Request generates a request for the InputService5TestCaseOperation1 operation.
func (c *InputService5ProtocolTest) InputService5TestCaseOperation1Request(input *InputService5TestShapeInputShape) (req *aws.Request, output *InputService5TestShapeInputService5TestCaseOperation1Output) {
if opInputService5TestCaseOperation1 == nil {
opInputService5TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &InputService5TestShapeInputShape{}
}
req = c.newRequest(opInputService5TestCaseOperation1, input, output)
output = &InputService5TestShapeInputService5TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService5ProtocolTest) InputService5TestCaseOperation1(input *InputService5TestShapeInputShape) (output *InputService5TestShapeInputService5TestCaseOperation1Output, err error) {
req, out := c.InputService5TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opInputService5TestCaseOperation1 *aws.Operation
type InputService5TestShapeInputService5TestCaseOperation1Output struct {
metadataInputService5TestShapeInputService5TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService5TestShapeInputService5TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService5TestShapeInputShape struct {
ListArg []*string `locationName:"ListMemberName" locationNameList:"item" type:"list"`
metadataInputService5TestShapeInputShape `json:"-" xml:"-"`
}
type metadataInputService5TestShapeInputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// InputService6ProtocolTest is a client for InputService6ProtocolTest.
type InputService6ProtocolTest struct {
*aws.Service
}
// New returns a new InputService6ProtocolTest client.
func NewInputService6ProtocolTest(config *aws.Config) *InputService6ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "inputservice6protocoltest",
APIVersion: "2014-01-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService6ProtocolTest{service}
}
// newRequest creates a new request for a InputService6ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService6ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// InputService6TestCaseOperation1Request generates a request for the InputService6TestCaseOperation1 operation.
func (c *InputService6ProtocolTest) InputService6TestCaseOperation1Request(input *InputService6TestShapeInputShape) (req *aws.Request, output *InputService6TestShapeInputService6TestCaseOperation1Output) {
if opInputService6TestCaseOperation1 == nil {
opInputService6TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &InputService6TestShapeInputShape{}
}
req = c.newRequest(opInputService6TestCaseOperation1, input, output)
output = &InputService6TestShapeInputService6TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService6ProtocolTest) InputService6TestCaseOperation1(input *InputService6TestShapeInputShape) (output *InputService6TestShapeInputService6TestCaseOperation1Output, err error) {
req, out := c.InputService6TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opInputService6TestCaseOperation1 *aws.Operation
type InputService6TestShapeInputService6TestCaseOperation1Output struct {
metadataInputService6TestShapeInputService6TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService6TestShapeInputService6TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService6TestShapeInputShape struct {
ListArg []*string `locationName:"ListMemberName" queryName:"ListQueryName" locationNameList:"item" type:"list"`
metadataInputService6TestShapeInputShape `json:"-" xml:"-"`
}
type metadataInputService6TestShapeInputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// InputService7ProtocolTest is a client for InputService7ProtocolTest.
type InputService7ProtocolTest struct {
*aws.Service
}
// New returns a new InputService7ProtocolTest client.
func NewInputService7ProtocolTest(config *aws.Config) *InputService7ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "inputservice7protocoltest",
APIVersion: "2014-01-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService7ProtocolTest{service}
}
// newRequest creates a new request for a InputService7ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService7ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// InputService7TestCaseOperation1Request generates a request for the InputService7TestCaseOperation1 operation.
func (c *InputService7ProtocolTest) InputService7TestCaseOperation1Request(input *InputService7TestShapeInputShape) (req *aws.Request, output *InputService7TestShapeInputService7TestCaseOperation1Output) {
if opInputService7TestCaseOperation1 == nil {
opInputService7TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &InputService7TestShapeInputShape{}
}
req = c.newRequest(opInputService7TestCaseOperation1, input, output)
output = &InputService7TestShapeInputService7TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService7ProtocolTest) InputService7TestCaseOperation1(input *InputService7TestShapeInputShape) (output *InputService7TestShapeInputService7TestCaseOperation1Output, err error) {
req, out := c.InputService7TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opInputService7TestCaseOperation1 *aws.Operation
type InputService7TestShapeInputService7TestCaseOperation1Output struct {
metadataInputService7TestShapeInputService7TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService7TestShapeInputService7TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService7TestShapeInputShape struct {
BlobArg []byte `type:"blob"`
metadataInputService7TestShapeInputShape `json:"-" xml:"-"`
}
type metadataInputService7TestShapeInputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// InputService8ProtocolTest is a client for InputService8ProtocolTest.
type InputService8ProtocolTest struct {
*aws.Service
}
// New returns a new InputService8ProtocolTest client.
func NewInputService8ProtocolTest(config *aws.Config) *InputService8ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "inputservice8protocoltest",
APIVersion: "2014-01-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService8ProtocolTest{service}
}
// newRequest creates a new request for a InputService8ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService8ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// InputService8TestCaseOperation1Request generates a request for the InputService8TestCaseOperation1 operation.
func (c *InputService8ProtocolTest) InputService8TestCaseOperation1Request(input *InputService8TestShapeInputShape) (req *aws.Request, output *InputService8TestShapeInputService8TestCaseOperation1Output) {
if opInputService8TestCaseOperation1 == nil {
opInputService8TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &InputService8TestShapeInputShape{}
}
req = c.newRequest(opInputService8TestCaseOperation1, input, output)
output = &InputService8TestShapeInputService8TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService8ProtocolTest) InputService8TestCaseOperation1(input *InputService8TestShapeInputShape) (output *InputService8TestShapeInputService8TestCaseOperation1Output, err error) {
req, out := c.InputService8TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opInputService8TestCaseOperation1 *aws.Operation
type InputService8TestShapeInputService8TestCaseOperation1Output struct {
metadataInputService8TestShapeInputService8TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService8TestShapeInputService8TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService8TestShapeInputShape struct {
TimeArg *time.Time `type:"timestamp" timestampFormat:"iso8601"`
metadataInputService8TestShapeInputShape `json:"-" xml:"-"`
}
type metadataInputService8TestShapeInputShape struct {
SDKShapeTraits bool `type:"structure"`
}
//
// Tests begin here
//
func TestInputService1ProtocolTestScalarMembersCase1(t *testing.T) {
svc := NewInputService1ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService1TestShapeInputShape{
Bar: aws.String("val2"),
Foo: aws.String("val1"),
}
req, _ := svc.InputService1TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
assert.Equal(t, util.Trim(`Action=OperationName&Bar=val2&Foo=val1&Version=2014-01-01`), util.Trim(string(body)))
// assert URL
assert.Equal(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService2ProtocolTestStructureWithLocationNameAndQueryNameAppliedToMembersCase1(t *testing.T) {
svc := NewInputService2ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService2TestShapeInputShape{
Bar: aws.String("val2"),
Foo: aws.String("val1"),
Yuck: aws.String("val3"),
}
req, _ := svc.InputService2TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
assert.Equal(t, util.Trim(`Action=OperationName&BarLocationName=val2&Foo=val1&Version=2014-01-01&yuckQueryName=val3`), util.Trim(string(body)))
// assert URL
assert.Equal(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService3ProtocolTestNestedStructureMembersCase1(t *testing.T) {
svc := NewInputService3ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService3TestShapeInputShape{
StructArg: &InputService3TestShapeStructType{
ScalarArg: aws.String("foo"),
},
}
req, _ := svc.InputService3TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
assert.Equal(t, util.Trim(`Action=OperationName&Struct.Scalar=foo&Version=2014-01-01`), util.Trim(string(body)))
// assert URL
assert.Equal(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService4ProtocolTestListTypesCase1(t *testing.T) {
svc := NewInputService4ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService4TestShapeInputShape{
ListArg: []*string{
aws.String("foo"),
aws.String("bar"),
aws.String("baz"),
},
}
req, _ := svc.InputService4TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
assert.Equal(t, util.Trim(`Action=OperationName&ListArg.1=foo&ListArg.2=bar&ListArg.3=baz&Version=2014-01-01`), util.Trim(string(body)))
// assert URL
assert.Equal(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService5ProtocolTestListWithLocationNameAppliedToMemberCase1(t *testing.T) {
svc := NewInputService5ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService5TestShapeInputShape{
ListArg: []*string{
aws.String("a"),
aws.String("b"),
aws.String("c"),
},
}
req, _ := svc.InputService5TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
assert.Equal(t, util.Trim(`Action=OperationName&ListMemberName.1=a&ListMemberName.2=b&ListMemberName.3=c&Version=2014-01-01`), util.Trim(string(body)))
// assert URL
assert.Equal(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService6ProtocolTestListWithLocationNameAndQueryNameCase1(t *testing.T) {
svc := NewInputService6ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService6TestShapeInputShape{
ListArg: []*string{
aws.String("a"),
aws.String("b"),
aws.String("c"),
},
}
req, _ := svc.InputService6TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
assert.Equal(t, util.Trim(`Action=OperationName&ListQueryName.1=a&ListQueryName.2=b&ListQueryName.3=c&Version=2014-01-01`), util.Trim(string(body)))
// assert URL
assert.Equal(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService7ProtocolTestBase64EncodedBlobsCase1(t *testing.T) {
svc := NewInputService7ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService7TestShapeInputShape{
BlobArg: []byte("foo"),
}
req, _ := svc.InputService7TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
assert.Equal(t, util.Trim(`Action=OperationName&BlobArg=Zm9v&Version=2014-01-01`), util.Trim(string(body)))
// assert URL
assert.Equal(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService8ProtocolTestTimestampValuesCase1(t *testing.T) {
svc := NewInputService8ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService8TestShapeInputShape{
TimeArg: aws.Time(time.Unix(1422172800, 0)),
}
req, _ := svc.InputService8TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
assert.Equal(t, util.Trim(`Action=OperationName&TimeArg=2015-01-25T08%3A00%3A00Z&Version=2014-01-01`), util.Trim(string(body)))
// assert URL
assert.Equal(t, "https://test/", r.URL.String())
// assert headers
}

View File

@@ -0,0 +1,53 @@
package ec2query
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/output/ec2.json unmarshal_test.go
import (
"encoding/xml"
"io"
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/internal/protocol/xml/xmlutil"
)
// Unmarshal unmarshals a response body for the EC2 protocol.
func Unmarshal(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, "")
if err != nil {
r.Error = err
return
}
}
}
// UnmarshalMeta unmarshals response headers for the EC2 protocol.
func UnmarshalMeta(r *aws.Request) {
// TODO implement unmarshaling of request IDs
}
type xmlErrorResponse struct {
XMLName xml.Name `xml:"Response"`
Code string `xml:"Errors>Error>Code"`
Message string `xml:"Errors>Error>Message"`
RequestID string `xml:"RequestId"`
}
// UnmarshalError unmarshals a response error for the EC2 protocol.
func UnmarshalError(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = err
} else {
r.Error = aws.APIError{
StatusCode: r.HTTPResponse.StatusCode,
Code: resp.Code,
Message: resp.Message,
}
}
}

View File

@@ -0,0 +1,889 @@
package ec2query_test
import (
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/internal/protocol/ec2query"
"github.com/awslabs/aws-sdk-go/internal/signer/v4"
"bytes"
"encoding/json"
"encoding/xml"
"github.com/awslabs/aws-sdk-go/internal/protocol/xml/xmlutil"
"github.com/awslabs/aws-sdk-go/internal/util"
"github.com/stretchr/testify/assert"
"io"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
)
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
// OutputService1ProtocolTest is a client for OutputService1ProtocolTest.
type OutputService1ProtocolTest struct {
*aws.Service
}
// New returns a new OutputService1ProtocolTest client.
func NewOutputService1ProtocolTest(config *aws.Config) *OutputService1ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "outputservice1protocoltest",
APIVersion: "",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService1ProtocolTest{service}
}
// newRequest creates a new request for a OutputService1ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService1ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// OutputService1TestCaseOperation1Request generates a request for the OutputService1TestCaseOperation1 operation.
func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation1Request(input *OutputService1TestShapeOutputService1TestCaseOperation1Input) (req *aws.Request, output *OutputService1TestShapeOutputShape) {
if opOutputService1TestCaseOperation1 == nil {
opOutputService1TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &OutputService1TestShapeOutputService1TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService1TestCaseOperation1, input, output)
output = &OutputService1TestShapeOutputShape{}
req.Data = output
return
}
func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation1(input *OutputService1TestShapeOutputService1TestCaseOperation1Input) (output *OutputService1TestShapeOutputShape, err error) {
req, out := c.OutputService1TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opOutputService1TestCaseOperation1 *aws.Operation
type OutputService1TestShapeOutputService1TestCaseOperation1Input struct {
metadataOutputService1TestShapeOutputService1TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService1TestShapeOutputService1TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService1TestShapeOutputShape struct {
Char *string `type:"character"`
Double *float64 `type:"double"`
FalseBool *bool `type:"boolean"`
Float *float64 `type:"float"`
Long *int64 `type:"long"`
Num *int64 `locationName:"FooNum" type:"integer"`
Str *string `type:"string"`
TrueBool *bool `type:"boolean"`
metadataOutputService1TestShapeOutputShape `json:"-" xml:"-"`
}
type metadataOutputService1TestShapeOutputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// OutputService2ProtocolTest is a client for OutputService2ProtocolTest.
type OutputService2ProtocolTest struct {
*aws.Service
}
// New returns a new OutputService2ProtocolTest client.
func NewOutputService2ProtocolTest(config *aws.Config) *OutputService2ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "outputservice2protocoltest",
APIVersion: "",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService2ProtocolTest{service}
}
// newRequest creates a new request for a OutputService2ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService2ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// OutputService2TestCaseOperation1Request generates a request for the OutputService2TestCaseOperation1 operation.
func (c *OutputService2ProtocolTest) OutputService2TestCaseOperation1Request(input *OutputService2TestShapeOutputService2TestCaseOperation1Input) (req *aws.Request, output *OutputService2TestShapeOutputShape) {
if opOutputService2TestCaseOperation1 == nil {
opOutputService2TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &OutputService2TestShapeOutputService2TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService2TestCaseOperation1, input, output)
output = &OutputService2TestShapeOutputShape{}
req.Data = output
return
}
func (c *OutputService2ProtocolTest) OutputService2TestCaseOperation1(input *OutputService2TestShapeOutputService2TestCaseOperation1Input) (output *OutputService2TestShapeOutputShape, err error) {
req, out := c.OutputService2TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opOutputService2TestCaseOperation1 *aws.Operation
type OutputService2TestShapeOutputService2TestCaseOperation1Input struct {
metadataOutputService2TestShapeOutputService2TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService2TestShapeOutputService2TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService2TestShapeOutputShape struct {
Blob []byte `type:"blob"`
metadataOutputService2TestShapeOutputShape `json:"-" xml:"-"`
}
type metadataOutputService2TestShapeOutputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// OutputService3ProtocolTest is a client for OutputService3ProtocolTest.
type OutputService3ProtocolTest struct {
*aws.Service
}
// New returns a new OutputService3ProtocolTest client.
func NewOutputService3ProtocolTest(config *aws.Config) *OutputService3ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "outputservice3protocoltest",
APIVersion: "",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService3ProtocolTest{service}
}
// newRequest creates a new request for a OutputService3ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService3ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// OutputService3TestCaseOperation1Request generates a request for the OutputService3TestCaseOperation1 operation.
func (c *OutputService3ProtocolTest) OutputService3TestCaseOperation1Request(input *OutputService3TestShapeOutputService3TestCaseOperation1Input) (req *aws.Request, output *OutputService3TestShapeOutputShape) {
if opOutputService3TestCaseOperation1 == nil {
opOutputService3TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &OutputService3TestShapeOutputService3TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService3TestCaseOperation1, input, output)
output = &OutputService3TestShapeOutputShape{}
req.Data = output
return
}
func (c *OutputService3ProtocolTest) OutputService3TestCaseOperation1(input *OutputService3TestShapeOutputService3TestCaseOperation1Input) (output *OutputService3TestShapeOutputShape, err error) {
req, out := c.OutputService3TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opOutputService3TestCaseOperation1 *aws.Operation
type OutputService3TestShapeOutputService3TestCaseOperation1Input struct {
metadataOutputService3TestShapeOutputService3TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService3TestShapeOutputService3TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService3TestShapeOutputShape struct {
ListMember []*string `type:"list"`
metadataOutputService3TestShapeOutputShape `json:"-" xml:"-"`
}
type metadataOutputService3TestShapeOutputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// OutputService4ProtocolTest is a client for OutputService4ProtocolTest.
type OutputService4ProtocolTest struct {
*aws.Service
}
// New returns a new OutputService4ProtocolTest client.
func NewOutputService4ProtocolTest(config *aws.Config) *OutputService4ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "outputservice4protocoltest",
APIVersion: "",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService4ProtocolTest{service}
}
// newRequest creates a new request for a OutputService4ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService4ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// OutputService4TestCaseOperation1Request generates a request for the OutputService4TestCaseOperation1 operation.
func (c *OutputService4ProtocolTest) OutputService4TestCaseOperation1Request(input *OutputService4TestShapeOutputService4TestCaseOperation1Input) (req *aws.Request, output *OutputService4TestShapeOutputShape) {
if opOutputService4TestCaseOperation1 == nil {
opOutputService4TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &OutputService4TestShapeOutputService4TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService4TestCaseOperation1, input, output)
output = &OutputService4TestShapeOutputShape{}
req.Data = output
return
}
func (c *OutputService4ProtocolTest) OutputService4TestCaseOperation1(input *OutputService4TestShapeOutputService4TestCaseOperation1Input) (output *OutputService4TestShapeOutputShape, err error) {
req, out := c.OutputService4TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opOutputService4TestCaseOperation1 *aws.Operation
type OutputService4TestShapeOutputService4TestCaseOperation1Input struct {
metadataOutputService4TestShapeOutputService4TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService4TestShapeOutputService4TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService4TestShapeOutputShape struct {
ListMember []*string `locationNameList:"item" type:"list"`
metadataOutputService4TestShapeOutputShape `json:"-" xml:"-"`
}
type metadataOutputService4TestShapeOutputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// OutputService5ProtocolTest is a client for OutputService5ProtocolTest.
type OutputService5ProtocolTest struct {
*aws.Service
}
// New returns a new OutputService5ProtocolTest client.
func NewOutputService5ProtocolTest(config *aws.Config) *OutputService5ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "outputservice5protocoltest",
APIVersion: "",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService5ProtocolTest{service}
}
// newRequest creates a new request for a OutputService5ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService5ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// OutputService5TestCaseOperation1Request generates a request for the OutputService5TestCaseOperation1 operation.
func (c *OutputService5ProtocolTest) OutputService5TestCaseOperation1Request(input *OutputService5TestShapeOutputService5TestCaseOperation1Input) (req *aws.Request, output *OutputService5TestShapeOutputShape) {
if opOutputService5TestCaseOperation1 == nil {
opOutputService5TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &OutputService5TestShapeOutputService5TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService5TestCaseOperation1, input, output)
output = &OutputService5TestShapeOutputShape{}
req.Data = output
return
}
func (c *OutputService5ProtocolTest) OutputService5TestCaseOperation1(input *OutputService5TestShapeOutputService5TestCaseOperation1Input) (output *OutputService5TestShapeOutputShape, err error) {
req, out := c.OutputService5TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opOutputService5TestCaseOperation1 *aws.Operation
type OutputService5TestShapeOutputService5TestCaseOperation1Input struct {
metadataOutputService5TestShapeOutputService5TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService5TestShapeOutputService5TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService5TestShapeOutputShape struct {
ListMember []*string `type:"list" flattened:"true"`
metadataOutputService5TestShapeOutputShape `json:"-" xml:"-"`
}
type metadataOutputService5TestShapeOutputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// OutputService6ProtocolTest is a client for OutputService6ProtocolTest.
type OutputService6ProtocolTest struct {
*aws.Service
}
// New returns a new OutputService6ProtocolTest client.
func NewOutputService6ProtocolTest(config *aws.Config) *OutputService6ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "outputservice6protocoltest",
APIVersion: "",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService6ProtocolTest{service}
}
// newRequest creates a new request for a OutputService6ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService6ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// OutputService6TestCaseOperation1Request generates a request for the OutputService6TestCaseOperation1 operation.
func (c *OutputService6ProtocolTest) OutputService6TestCaseOperation1Request(input *OutputService6TestShapeOutputService6TestCaseOperation1Input) (req *aws.Request, output *OutputService6TestShapeOutputShape) {
if opOutputService6TestCaseOperation1 == nil {
opOutputService6TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &OutputService6TestShapeOutputService6TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService6TestCaseOperation1, input, output)
output = &OutputService6TestShapeOutputShape{}
req.Data = output
return
}
func (c *OutputService6ProtocolTest) OutputService6TestCaseOperation1(input *OutputService6TestShapeOutputService6TestCaseOperation1Input) (output *OutputService6TestShapeOutputShape, err error) {
req, out := c.OutputService6TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opOutputService6TestCaseOperation1 *aws.Operation
type OutputService6TestShapeOutputService6TestCaseOperation1Input struct {
metadataOutputService6TestShapeOutputService6TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService6TestShapeOutputService6TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService6TestShapeOutputShape struct {
Map *map[string]*OutputService6TestShapeStructureType `type:"map"`
metadataOutputService6TestShapeOutputShape `json:"-" xml:"-"`
}
type metadataOutputService6TestShapeOutputShape struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService6TestShapeStructureType struct {
Foo *string `locationName:"foo" type:"string"`
metadataOutputService6TestShapeStructureType `json:"-" xml:"-"`
}
type metadataOutputService6TestShapeStructureType struct {
SDKShapeTraits bool `type:"structure"`
}
// OutputService7ProtocolTest is a client for OutputService7ProtocolTest.
type OutputService7ProtocolTest struct {
*aws.Service
}
// New returns a new OutputService7ProtocolTest client.
func NewOutputService7ProtocolTest(config *aws.Config) *OutputService7ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "outputservice7protocoltest",
APIVersion: "",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService7ProtocolTest{service}
}
// newRequest creates a new request for a OutputService7ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService7ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// OutputService7TestCaseOperation1Request generates a request for the OutputService7TestCaseOperation1 operation.
func (c *OutputService7ProtocolTest) OutputService7TestCaseOperation1Request(input *OutputService7TestShapeOutputService7TestCaseOperation1Input) (req *aws.Request, output *OutputService7TestShapeOutputShape) {
if opOutputService7TestCaseOperation1 == nil {
opOutputService7TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &OutputService7TestShapeOutputService7TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService7TestCaseOperation1, input, output)
output = &OutputService7TestShapeOutputShape{}
req.Data = output
return
}
func (c *OutputService7ProtocolTest) OutputService7TestCaseOperation1(input *OutputService7TestShapeOutputService7TestCaseOperation1Input) (output *OutputService7TestShapeOutputShape, err error) {
req, out := c.OutputService7TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opOutputService7TestCaseOperation1 *aws.Operation
type OutputService7TestShapeOutputService7TestCaseOperation1Input struct {
metadataOutputService7TestShapeOutputService7TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService7TestShapeOutputService7TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService7TestShapeOutputShape struct {
Map *map[string]*string `type:"map" flattened:"true"`
metadataOutputService7TestShapeOutputShape `json:"-" xml:"-"`
}
type metadataOutputService7TestShapeOutputShape struct {
SDKShapeTraits bool `type:"structure"`
}
// OutputService8ProtocolTest is a client for OutputService8ProtocolTest.
type OutputService8ProtocolTest struct {
*aws.Service
}
// New returns a new OutputService8ProtocolTest client.
func NewOutputService8ProtocolTest(config *aws.Config) *OutputService8ProtocolTest {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "outputservice8protocoltest",
APIVersion: "",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService8ProtocolTest{service}
}
// newRequest creates a new request for a OutputService8ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService8ProtocolTest) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
return req
}
// OutputService8TestCaseOperation1Request generates a request for the OutputService8TestCaseOperation1 operation.
func (c *OutputService8ProtocolTest) OutputService8TestCaseOperation1Request(input *OutputService8TestShapeOutputService8TestCaseOperation1Input) (req *aws.Request, output *OutputService8TestShapeOutputShape) {
if opOutputService8TestCaseOperation1 == nil {
opOutputService8TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
}
if input == nil {
input = &OutputService8TestShapeOutputService8TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService8TestCaseOperation1, input, output)
output = &OutputService8TestShapeOutputShape{}
req.Data = output
return
}
func (c *OutputService8ProtocolTest) OutputService8TestCaseOperation1(input *OutputService8TestShapeOutputService8TestCaseOperation1Input) (output *OutputService8TestShapeOutputShape, err error) {
req, out := c.OutputService8TestCaseOperation1Request(input)
output = out
err = req.Send()
return
}
var opOutputService8TestCaseOperation1 *aws.Operation
type OutputService8TestShapeOutputService8TestCaseOperation1Input struct {
metadataOutputService8TestShapeOutputService8TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService8TestShapeOutputService8TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService8TestShapeOutputShape struct {
Map *map[string]*string `locationNameKey:"foo" locationNameValue:"bar" type:"map" flattened:"true"`
metadataOutputService8TestShapeOutputShape `json:"-" xml:"-"`
}
type metadataOutputService8TestShapeOutputShape struct {
SDKShapeTraits bool `type:"structure"`
}
//
// Tests begin here
//
func TestOutputService1ProtocolTestScalarMembersCase1(t *testing.T) {
svc := NewOutputService1ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Str>myname</Str><FooNum>123</FooNum><FalseBool>false</FalseBool><TrueBool>true</TrueBool><Float>1.2</Float><Double>1.3</Double><Long>200</Long><Char>a</Char><RequestId>request-id</RequestId></OperationNameResponse>"))
req, out := svc.OutputService1TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "a", *out.Char)
assert.Equal(t, 1.3, *out.Double)
assert.Equal(t, false, *out.FalseBool)
assert.Equal(t, 1.2, *out.Float)
assert.Equal(t, int64(200), *out.Long)
assert.Equal(t, int64(123), *out.Num)
assert.Equal(t, "myname", *out.Str)
assert.Equal(t, true, *out.TrueBool)
}
func TestOutputService2ProtocolTestBlobCase1(t *testing.T) {
svc := NewOutputService2ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Blob>dmFsdWU=</Blob><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService2TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "value", string(out.Blob))
}
func TestOutputService3ProtocolTestListsCase1(t *testing.T) {
svc := NewOutputService3ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><ListMember><member>abc</member><member>123</member></ListMember><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService3TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "abc", *out.ListMember[0])
assert.Equal(t, "123", *out.ListMember[1])
}
func TestOutputService4ProtocolTestListWithCustomMemberNameCase1(t *testing.T) {
svc := NewOutputService4ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><ListMember><item>abc</item><item>123</item></ListMember><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService4TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "abc", *out.ListMember[0])
assert.Equal(t, "123", *out.ListMember[1])
}
func TestOutputService5ProtocolTestFlattenedListCase1(t *testing.T) {
svc := NewOutputService5ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><ListMember>abc</ListMember><ListMember>123</ListMember><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService5TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "abc", *out.ListMember[0])
assert.Equal(t, "123", *out.ListMember[1])
}
func TestOutputService6ProtocolTestNormalMapCase1(t *testing.T) {
svc := NewOutputService6ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Map><entry><key>qux</key><value><foo>bar</foo></value></entry><entry><key>baz</key><value><foo>bam</foo></value></entry></Map><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService6TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "bam", *(*out.Map)["baz"].Foo)
assert.Equal(t, "bar", *(*out.Map)["qux"].Foo)
}
func TestOutputService7ProtocolTestFlattenedMapCase1(t *testing.T) {
svc := NewOutputService7ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Map><key>qux</key><value>bar</value></Map><Map><key>baz</key><value>bam</value></Map><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService7TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "bam", *(*out.Map)["baz"])
assert.Equal(t, "bar", *(*out.Map)["qux"])
}
func TestOutputService8ProtocolTestNamedMapCase1(t *testing.T) {
svc := NewOutputService8ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Map><foo>qux</foo><bar>bar</bar></Map><Map><foo>baz</foo><bar>bam</bar></Map><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService8TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "bam", *(*out.Map)["baz"])
assert.Equal(t, "bar", *(*out.Map)["qux"])
}

View File

@@ -0,0 +1,223 @@
package queryutil
import (
"encoding/base64"
"fmt"
"net/url"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// Parse parses an object i and fills a url.Values object. The isEC2 flag
// indicates if this is the EC2 Query sub-protocol.
func Parse(body url.Values, i interface{}, isEC2 bool) error {
q := queryParser{isEC2: isEC2}
return q.parseValue(body, reflect.ValueOf(i), "", "")
}
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}
type queryParser struct {
isEC2 bool
}
func (q *queryParser) parseValue(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
value = elemOf(value)
// no need to handle zero values
if !value.IsValid() {
return nil
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
return q.parseStruct(v, value, prefix)
case "list":
return q.parseList(v, value, prefix, tag)
case "map":
return q.parseMap(v, value, prefix, tag)
default:
return q.parseScalar(v, value, prefix, tag)
}
}
func (q *queryParser) parseStruct(v url.Values, value reflect.Value, prefix string) error {
if !value.IsValid() {
return nil
}
t := value.Type()
for i := 0; i < value.NumField(); i++ {
if c := t.Field(i).Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
value := elemOf(value.Field(i))
field := t.Field(i)
var name string
if q.isEC2 {
name = field.Tag.Get("queryName")
}
if name == "" {
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
name = field.Tag.Get("locationNameList")
} else if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
if name != "" && q.isEC2 {
name = strings.ToUpper(name[0:1]) + name[1:]
}
}
if name == "" {
name = field.Name
}
if prefix != "" {
name = prefix + "." + name
}
if err := q.parseValue(v, value, name, field.Tag); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value
if !value.IsNil() && value.Len() == 0 {
v.Set(prefix, "")
return nil
}
// check for unflattened list member
if !q.isEC2 && tag.Get("flattened") == "" {
prefix += ".member"
}
for i := 0; i < value.Len(); i++ {
slicePrefix := prefix
if slicePrefix == "" {
slicePrefix = strconv.Itoa(i + 1)
} else {
slicePrefix = slicePrefix + "." + strconv.Itoa(i+1)
}
if err := q.parseValue(v, value.Index(i), slicePrefix, ""); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseMap(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value
if !value.IsNil() && value.Len() == 0 {
v.Set(prefix, "")
return nil
}
// check for unflattened list member
if !q.isEC2 && tag.Get("flattened") == "" {
prefix += ".entry"
}
// sort keys for improved serialization consistency.
// this is not strictly necessary for protocol support.
mapKeyValues := value.MapKeys()
mapKeys := map[string]reflect.Value{}
mapKeyNames := make([]string, len(mapKeyValues))
for i, mapKey := range mapKeyValues {
name := mapKey.String()
mapKeys[name] = mapKey
mapKeyNames[i] = name
}
sort.Strings(mapKeyNames)
for i, mapKeyName := range mapKeyNames {
mapKey := mapKeys[mapKeyName]
mapValue := value.MapIndex(mapKey)
kname := tag.Get("locationNameKey")
if kname == "" {
kname = "key"
}
vname := tag.Get("locationNameValue")
if vname == "" {
vname = "value"
}
// serialize key
var keyName string
if prefix == "" {
keyName = strconv.Itoa(i+1) + "." + kname
} else {
keyName = prefix + "." + strconv.Itoa(i+1) + "." + kname
}
if err := q.parseValue(v, mapKey, keyName, ""); err != nil {
return err
}
// serialize value
var valueName string
if prefix == "" {
valueName = strconv.Itoa(i+1) + "." + vname
} else {
valueName = prefix + "." + strconv.Itoa(i+1) + "." + vname
}
if err := q.parseValue(v, mapValue, valueName, ""); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, tag reflect.StructTag) error {
switch value := r.Interface().(type) {
case string:
v.Set(name, value)
case []byte:
if !r.IsNil() {
v.Set(name, base64.StdEncoding.EncodeToString(value))
}
case bool:
v.Set(name, strconv.FormatBool(value))
case int64:
v.Set(name, strconv.FormatInt(value, 10))
case int:
v.Set(name, strconv.Itoa(value))
case float64:
v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
case float32:
v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
v.Set(name, value.UTC().Format(ISO8601UTC))
default:
return fmt.Errorf("unsupported value for param %s: %v (%s)", name, r.Interface(), r.Type().Name())
}
return nil
}

View File

@@ -0,0 +1,286 @@
package xmlutil
import (
"encoding/base64"
"encoding/xml"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// BuildXML will serialize params into an xml.Encoder.
// Error will be returned if the serialization of any of the params or nested values fails.
func BuildXML(params interface{}, e *xml.Encoder) error {
b := xmlBuilder{encoder: e, namespaces: map[string]string{}}
root := NewXMLElement(xml.Name{})
if err := b.buildValue(reflect.ValueOf(params), root, ""); err != nil {
return err
}
for _, c := range root.Children {
for _, v := range c {
return StructToXML(e, v, false)
}
}
return nil
}
// Returns the reflection element of a value, if it is a pointer.
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}
// A xmlBuilder serializes values from Go code to XML
type xmlBuilder struct {
encoder *xml.Encoder
namespaces map[string]string
}
// buildValue generic XMLNode builder for any type. Will build value for their specific type
// struct, list, map, scalar.
//
// Also takes a "type" tag value to set what type a value should be converted to XMLNode as. If
// type is not provided reflect will be used to determine the value's type.
func (b *xmlBuilder) buildValue(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
value = elemOf(value)
if !value.IsValid() { // no need to handle zero values
return nil
} else if tag.Get("location") != "" { // don't handle non-body location values
return nil
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := value.Type().FieldByName("SDKShapeTraits"); ok {
tag = tag + reflect.StructTag(" ") + field.Tag
}
return b.buildStruct(value, current, tag)
case "list":
return b.buildList(value, current, tag)
case "map":
return b.buildMap(value, current, tag)
default:
return b.buildScalar(value, current, tag)
}
}
// buildStruct adds a struct and its fields to the current XMLNode. All fields any any nested
// types are converted to XMLNodes also.
func (b *xmlBuilder) buildStruct(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if !value.IsValid() {
return nil
}
fieldAdded := false
// unwrap payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := value.Type().FieldByName(payload)
tag = field.Tag
value = elemOf(value.FieldByName(payload))
if !value.IsValid() {
return nil
}
}
child := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
// there is an xmlNamespace associated with this struct
if prefix, uri := tag.Get("xmlPrefix"), tag.Get("xmlURI"); uri != "" {
ns := xml.Attr{
Name: xml.Name{Local: "xmlns"},
Value: uri,
}
if prefix != "" {
b.namespaces[prefix] = uri // register the namespace
ns.Name.Local = "xmlns:" + prefix
}
child.Attr = append(child.Attr, ns)
}
t := value.Type()
for i := 0; i < value.NumField(); i++ {
if c := t.Field(i).Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
member := elemOf(value.Field(i))
field := t.Field(i)
mTag := field.Tag
if mTag.Get("location") != "" { // skip non-body members
continue
}
memberName := mTag.Get("locationName")
if memberName == "" {
memberName = field.Name
mTag = reflect.StructTag(string(mTag) + ` locationName:"` + memberName + `"`)
}
if err := b.buildValue(member, child, mTag); err != nil {
return err
}
fieldAdded = true
}
if fieldAdded { // only append this child if we have one ore more valid members
current.AddChild(child)
}
return nil
}
// buildList adds the value's list items to the current XMLNode as children nodes. All
// nested values in the list are converted to XMLNodes also.
func (b *xmlBuilder) buildList(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if value.IsNil() { // don't build omitted lists
return nil
}
// check for unflattened list member
flattened := tag.Get("flattened") != ""
xname := xml.Name{Local: tag.Get("locationName")}
if flattened {
for i := 0; i < value.Len(); i++ {
child := NewXMLElement(xname)
current.AddChild(child)
if err := b.buildValue(value.Index(i), child, ""); err != nil {
return err
}
}
} else {
list := NewXMLElement(xname)
current.AddChild(list)
for i := 0; i < value.Len(); i++ {
iname := tag.Get("locationNameList")
if iname == "" {
iname = "member"
}
child := NewXMLElement(xml.Name{Local: iname})
list.AddChild(child)
if err := b.buildValue(value.Index(i), child, ""); err != nil {
return err
}
}
}
return nil
}
// buildMap adds the value's key/value pairs to the current XMLNode as children nodes. All
// nested values in the map are converted to XMLNodes also.
//
// Error will be returned if it is unable to build the map's values into XMLNodes
func (b *xmlBuilder) buildMap(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if value.IsNil() { // don't build omitted maps
return nil
}
maproot := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
current.AddChild(maproot)
current = maproot
kname, vname := "key", "value"
if n := tag.Get("locationNameKey"); n != "" {
kname = n
}
if n := tag.Get("locationNameValue"); n != "" {
vname = n
}
// sorting is not required for compliance, but it makes testing easier
keys := make([]string, value.Len())
for i, k := range value.MapKeys() {
keys[i] = k.String()
}
sort.Strings(keys)
for _, k := range keys {
v := value.MapIndex(reflect.ValueOf(k))
mapcur := current
if tag.Get("flattened") == "" { // add "entry" tag to non-flat maps
child := NewXMLElement(xml.Name{Local: "entry"})
mapcur.AddChild(child)
mapcur = child
}
kchild := NewXMLElement(xml.Name{Local: kname})
kchild.Text = k
vchild := NewXMLElement(xml.Name{Local: vname})
mapcur.AddChild(kchild)
mapcur.AddChild(vchild)
if err := b.buildValue(v, vchild, ""); err != nil {
return err
}
}
return nil
}
// buildScalar will convert the value into a string and append it as a attribute or child
// of the current XMLNode.
//
// The value will be added as an attribute if tag contains a "xmlAttribute" attribute value.
//
// Error will be returned if the value type is unsupported.
func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
var str string
switch converted := value.Interface().(type) {
case string:
str = converted
case []byte:
if !value.IsNil() {
str = base64.StdEncoding.EncodeToString(converted)
}
case bool:
str = strconv.FormatBool(converted)
case int64:
str = strconv.FormatInt(converted, 10)
case int:
str = strconv.Itoa(converted)
case float64:
str = strconv.FormatFloat(converted, 'f', -1, 64)
case float32:
str = strconv.FormatFloat(float64(converted), 'f', -1, 32)
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
str = converted.UTC().Format(ISO8601UTC)
default:
return fmt.Errorf("unsupported value for param %s: %v (%s)",
tag.Get("locationName"), value.Interface(), value.Type().Name())
}
xname := xml.Name{Local: tag.Get("locationName")}
if tag.Get("xmlAttribute") != "" { // put into current node's attribute list
attr := xml.Attr{Name: xname, Value: str}
current.Attr = append(current.Attr, attr)
} else { // regular text node
current.AddChild(&XMLNode{Name: xname, Text: str})
}
return nil
}

View File

@@ -0,0 +1,267 @@
package xmlutil
import (
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"reflect"
"strconv"
"strings"
"time"
)
// UnmarshalXML deserializes an xml.Decoder into the container v. V
// needs to match the shape of the XML expected to be decoded.
// If the shape doesn't match unmarshaling will fail.
func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
n, _ := XMLToStruct(d, nil)
if n.Children != nil {
for _, root := range n.Children {
for _, c := range root {
if wrappedChild, ok := c.Children[wrapper]; ok {
c = wrappedChild[0] // pull out wrapped element
}
err := parse(reflect.ValueOf(v), c, "")
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
return nil
}
return nil
}
// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
// will be used to determine the type from r.
func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
rtype := r.Type()
if rtype.Kind() == reflect.Ptr {
rtype = rtype.Elem() // check kind of actual element type
}
t := tag.Get("type")
if t == "" {
switch rtype.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := rtype.FieldByName("SDKShapeTraits"); ok {
tag = field.Tag
}
return parseStruct(r, node, tag)
case "list":
return parseList(r, node, tag)
case "map":
return parseMap(r, node, tag)
default:
return parseScalar(r, node, tag)
}
}
// parseStruct deserializes a structure and its fields from an XMLNode. Any nested
// types in the structure will also be deserialized.
func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if r.Kind() == reflect.Ptr {
if r.IsNil() { // create the structure if it's nil
s := reflect.New(r.Type().Elem())
r.Set(s)
r = s
}
r = r.Elem()
t = t.Elem()
}
// unwrap any payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := t.FieldByName(payload)
return parseStruct(r.FieldByName(payload), node, field.Tag)
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if c := field.Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
// figure out what this field is called
name := field.Name
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
name = field.Tag.Get("locationNameList")
} else if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
// try to find the field by name in elements
elems := node.Children[name]
if elems == nil { // try to find the field in attributes
for _, a := range node.Attr {
if name == a.Name.Local {
// turn this into a text node for de-serializing
elems = []*XMLNode{&XMLNode{Text: a.Value}}
}
}
}
member := r.FieldByName(field.Name)
for _, elem := range elems {
err := parse(member, elem, field.Tag)
if err != nil {
return err
}
}
}
return nil
}
// parseList deserializes a list of values from an XML node. Each list entry
// will also be deserialized.
func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if tag.Get("flattened") == "" { // look at all item entries
mname := "member"
if name := tag.Get("locationNameList"); name != "" {
mname = name
}
if Children, ok := node.Children[mname]; ok {
if r.IsNil() {
r.Set(reflect.MakeSlice(t, len(Children), len(Children)))
}
for i, c := range Children {
err := parse(r.Index(i), c, "")
if err != nil {
return err
}
}
}
} else { // flattened list means this is a single element
if r.IsNil() {
r.Set(reflect.MakeSlice(t, 0, 0))
}
childR := reflect.Zero(t.Elem())
r.Set(reflect.Append(r, childR))
err := parse(r.Index(r.Len()-1), node, "")
if err != nil {
return err
}
}
return nil
}
// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
// will also be deserialized as map entries.
func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if r.Kind() == reflect.Ptr {
t = t.Elem()
if r.IsNil() {
r.Set(reflect.New(t))
r.Elem().Set(reflect.MakeMap(t))
}
r = r.Elem()
}
if tag.Get("flattened") == "" { // look at all child entries
for _, entry := range node.Children["entry"] {
parseMapEntry(r, entry, tag)
}
} else { // this element is itself an entry
parseMapEntry(r, node, tag)
}
return nil
}
// parseMapEntry deserializes a map entry from a XML node.
func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
kname, vname := "key", "value"
if n := tag.Get("locationNameKey"); n != "" {
kname = n
}
if n := tag.Get("locationNameValue"); n != "" {
vname = n
}
keys, ok := node.Children[kname]
values := node.Children[vname]
if ok {
for i, key := range keys {
keyR := reflect.ValueOf(key.Text)
value := values[i]
valueR := reflect.New(r.Type().Elem()).Elem()
parse(valueR, value, "")
r.SetMapIndex(keyR, valueR)
}
}
return nil
}
// parseScaller deserializes an XMLNode value into a concrete type based on the
// interface type of r.
//
// Error is returned if the deserialization fails due to invalid type conversion,
// or unsupported interface type.
func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
switch r.Interface().(type) {
case *string:
r.Set(reflect.ValueOf(&node.Text))
return nil
case []byte:
b, err := base64.StdEncoding.DecodeString(node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(b))
case *bool:
v, err := strconv.ParseBool(node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *int64:
v, err := strconv.ParseInt(node.Text, 10, 64)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *float64:
v, err := strconv.ParseFloat(node.Text, 64)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
t, err := time.Parse(ISO8601UTC, node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&t))
default:
return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type())
}
return nil
}

View File

@@ -0,0 +1,105 @@
package xmlutil
import (
"encoding/xml"
"io"
"sort"
)
// A XMLNode contains the values to be encoded or decoded.
type XMLNode struct {
Name xml.Name `json:",omitempty"`
Children map[string][]*XMLNode `json:",omitempty"`
Text string `json:",omitempty"`
Attr []xml.Attr `json:",omitempty"`
}
// NewXMLElement returns a pointer to a new XMLNode initialized to default values.
func NewXMLElement(name xml.Name) *XMLNode {
return &XMLNode{
Name: name,
Children: map[string][]*XMLNode{},
Attr: []xml.Attr{},
}
}
// AddChild adds child to the XMLNode.
func (n *XMLNode) AddChild(child *XMLNode) {
if _, ok := n.Children[child.Name.Local]; !ok {
n.Children[child.Name.Local] = []*XMLNode{}
}
n.Children[child.Name.Local] = append(n.Children[child.Name.Local], child)
}
// XMLToStruct converts a xml.Decoder stream to XMLNode with nested values.
func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) {
out := &XMLNode{}
for {
tok, err := d.Token()
if tok == nil || err == io.EOF {
break
}
if err != nil {
return out, err
}
switch typed := tok.(type) {
case xml.CharData:
out.Text = string(typed.Copy())
case xml.StartElement:
el := typed.Copy()
out.Attr = el.Attr
if out.Children == nil {
out.Children = map[string][]*XMLNode{}
}
name := typed.Name.Local
slice := out.Children[name]
if slice == nil {
slice = []*XMLNode{}
}
node, e := XMLToStruct(d, &el)
if e != nil {
return out, e
}
node.Name = typed.Name
slice = append(slice, node)
out.Children[name] = slice
case xml.EndElement:
if s != nil && s.Name.Local == typed.Name.Local { // matching end token
return out, nil
}
}
}
return out, nil
}
// StructToXML writes an XMLNode to a xml.Encoder as tokens.
func StructToXML(e *xml.Encoder, node *XMLNode, sorted bool) error {
e.EncodeToken(xml.StartElement{Name: node.Name, Attr: node.Attr})
if node.Text != "" {
e.EncodeToken(xml.CharData([]byte(node.Text)))
} else if sorted {
sortedNames := []string{}
for k := range node.Children {
sortedNames = append(sortedNames, k)
}
sort.Strings(sortedNames)
for _, k := range sortedNames {
for _, v := range node.Children[k] {
StructToXML(e, v, sorted)
}
}
} else {
for _, c := range node.Children {
for _, v := range c {
StructToXML(e, v, sorted)
}
}
}
e.EncodeToken(xml.EndElement{Name: node.Name})
return e.Flush()
}

View File

@@ -0,0 +1,45 @@
// +build !integration
package v4_test
import (
"net/url"
"testing"
"time"
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/internal/test/unit"
"github.com/awslabs/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestPresignHandler(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, err := req.Presign(5 * time.Minute)
assert.NoError(t, err)
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-acl"
expectedSig := "7edcb4e3a1bf12f4989018d75acbe3a7f03df24bd6f3112602d59fc551f0e4e2"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
u, _ := url.Parse(urlstr)
urlQ := u.Query()
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
assert.NotContains(t, urlstr, "+") // + encoded as %20
}

View File

@@ -0,0 +1,320 @@
package v4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/awslabs/aws-sdk-go/aws/credentials"
"github.com/awslabs/aws-sdk-go/aws"
)
const (
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
)
var ignoredHeaders = map[string]bool{
"Authorization": true,
"Content-Type": true,
"Content-Length": true,
"User-Agent": true,
}
type signer struct {
Request *http.Request
Time time.Time
ExpireTime time.Duration
ServiceName string
Region string
AccessKeyID string
SecretAccessKey string
SessionToken string
Query url.Values
Body io.ReadSeeker
Debug uint
Logger io.Writer
isPresign bool
formattedTime string
formattedShortTime string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
authorization string
}
// Sign requests with signature version 4.
//
// Will sign the requests with the service config's Credentials object
// Signing is skipped if the credentials is the credentials.AnonymousCredentials
// object.
func Sign(req *aws.Request) {
// If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used.
if req.Service.Config.Credentials == credentials.AnonymousCredentials {
return
}
creds, err := req.Service.Config.Credentials.Get()
if err != nil {
req.Error = err
return
}
region := req.Service.SigningRegion
if region == "" {
region = req.Service.Config.Region
}
name := req.Service.SigningName
if name == "" {
name = req.Service.ServiceName
}
s := signer{
Request: req.HTTPRequest,
Time: req.Time,
ExpireTime: req.ExpireTime,
Query: req.HTTPRequest.URL.Query(),
Body: req.Body,
ServiceName: name,
Region: region,
AccessKeyID: creds.AccessKeyID,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: creds.SessionToken,
Debug: req.Service.Config.LogLevel,
Logger: req.Service.Config.Logger,
}
s.sign()
return
}
func (v4 *signer) sign() {
if v4.ExpireTime != 0 {
v4.isPresign = true
}
if v4.isPresign {
v4.Query.Set("X-Amz-Algorithm", authHeaderPrefix)
if v4.SessionToken != "" {
v4.Query.Set("X-Amz-Security-Token", v4.SessionToken)
} else {
v4.Query.Del("X-Amz-Security-Token")
}
} else if v4.SessionToken != "" {
v4.Request.Header.Set("X-Amz-Security-Token", v4.SessionToken)
}
v4.build()
if v4.Debug > 0 {
out := v4.Logger
fmt.Fprintf(out, "---[ CANONICAL STRING ]-----------------------------\n")
fmt.Fprintln(out, v4.canonicalString)
fmt.Fprintf(out, "---[ STRING TO SIGN ]--------------------------------\n")
fmt.Fprintln(out, v4.stringToSign)
if v4.isPresign {
fmt.Fprintf(out, "---[ SIGNED URL ]--------------------------------\n")
fmt.Fprintln(out, v4.Request.URL)
}
fmt.Fprintf(out, "-----------------------------------------------------\n")
}
}
func (v4 *signer) build() {
v4.buildTime() // no depends
v4.buildCredentialString() // no depends
if v4.isPresign {
v4.buildQuery() // no depends
}
v4.buildCanonicalHeaders() // depends on cred string
v4.buildCanonicalString() // depends on canon headers / signed headers
v4.buildStringToSign() // depends on canon string
v4.buildSignature() // depends on string to sign
if v4.isPresign {
v4.Request.URL.RawQuery += "&X-Amz-Signature=" + v4.signature
} else {
parts := []string{
authHeaderPrefix + " Credential=" + v4.AccessKeyID + "/" + v4.credentialString,
"SignedHeaders=" + v4.signedHeaders,
"Signature=" + v4.signature,
}
v4.Request.Header.Set("Authorization", strings.Join(parts, ", "))
}
}
func (v4 *signer) buildTime() {
v4.formattedTime = v4.Time.UTC().Format(timeFormat)
v4.formattedShortTime = v4.Time.UTC().Format(shortTimeFormat)
if v4.isPresign {
duration := int64(v4.ExpireTime / time.Second)
v4.Query.Set("X-Amz-Date", v4.formattedTime)
v4.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else {
v4.Request.Header.Set("X-Amz-Date", v4.formattedTime)
}
}
func (v4 *signer) buildCredentialString() {
v4.credentialString = strings.Join([]string{
v4.formattedShortTime,
v4.Region,
v4.ServiceName,
"aws4_request",
}, "/")
if v4.isPresign {
v4.Query.Set("X-Amz-Credential", v4.AccessKeyID+"/"+v4.credentialString)
}
}
func (v4 *signer) buildQuery() {
for k, h := range v4.Request.Header {
if strings.HasPrefix(http.CanonicalHeaderKey(k), "X-Amz-") {
continue // never hoist x-amz-* headers, they must be signed
}
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // never hoist ignored headers
}
v4.Request.Header.Del(k)
v4.Query.Del(k)
for _, v := range h {
v4.Query.Add(k, v)
}
}
}
func (v4 *signer) buildCanonicalHeaders() {
var headers []string
headers = append(headers, "host")
for k := range v4.Request.Header {
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // ignored header
}
headers = append(headers, strings.ToLower(k))
}
sort.Strings(headers)
v4.signedHeaders = strings.Join(headers, ";")
if v4.isPresign {
v4.Query.Set("X-Amz-SignedHeaders", v4.signedHeaders)
}
headerValues := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
headerValues[i] = "host:" + v4.Request.URL.Host
} else {
headerValues[i] = k + ":" +
strings.Join(v4.Request.Header[http.CanonicalHeaderKey(k)], ",")
}
}
v4.canonicalHeaders = strings.Join(headerValues, "\n")
}
func (v4 *signer) buildCanonicalString() {
v4.Request.URL.RawQuery = strings.Replace(v4.Query.Encode(), "+", "%20", -1)
uri := v4.Request.URL.Opaque
if uri != "" {
uri = "/" + strings.Join(strings.Split(uri, "/")[3:], "/")
} else {
uri = v4.Request.URL.Path
}
if uri == "" {
uri = "/"
}
v4.canonicalString = strings.Join([]string{
v4.Request.Method,
uri,
v4.Request.URL.RawQuery,
v4.canonicalHeaders + "\n",
v4.signedHeaders,
v4.bodyDigest(),
}, "\n")
}
func (v4 *signer) buildStringToSign() {
v4.stringToSign = strings.Join([]string{
authHeaderPrefix,
v4.formattedTime,
v4.credentialString,
hex.EncodeToString(makeSha256([]byte(v4.canonicalString))),
}, "\n")
}
func (v4 *signer) buildSignature() {
secret := v4.SecretAccessKey
date := makeHmac([]byte("AWS4"+secret), []byte(v4.formattedShortTime))
region := makeHmac(date, []byte(v4.Region))
service := makeHmac(region, []byte(v4.ServiceName))
credentials := makeHmac(service, []byte("aws4_request"))
signature := makeHmac(credentials, []byte(v4.stringToSign))
v4.signature = hex.EncodeToString(signature)
}
func (v4 *signer) bodyDigest() string {
hash := v4.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if v4.isPresign && v4.ServiceName == "s3" {
hash = "UNSIGNED-PAYLOAD"
} else if v4.Body == nil {
hash = hex.EncodeToString(makeSha256([]byte{}))
} else {
hash = hex.EncodeToString(makeSha256Reader(v4.Body))
}
v4.Request.Header.Add("X-Amz-Content-Sha256", hash)
}
return hash
}
func makeHmac(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) []byte {
packet := make([]byte, 4096)
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
for {
n, err := reader.Read(packet)
if n > 0 {
hash.Write(packet[0:n])
}
if err == io.EOF || n == 0 {
break
}
}
return hash.Sum(nil)
}

View File

@@ -0,0 +1,156 @@
package v4
import (
"net/http"
"strings"
"testing"
"time"
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func buildSigner(serviceName string, region string, signTime time.Time, expireTime time.Duration, body string) signer {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
reader := strings.NewReader(body)
req, _ := http.NewRequest("POST", endpoint, reader)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Add("X-Amz-Target", "prefix.Operation")
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
req.Header.Add("Content-Length", string(len(body)))
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
return signer{
Request: req,
Time: signTime,
ExpireTime: expireTime,
Query: req.URL.Query(),
Body: reader,
ServiceName: serviceName,
Region: region,
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "SESSION",
}
}
func removeWS(text string) string {
text = strings.Replace(text, " ", "", -1)
text = strings.Replace(text, "\n", "", -1)
text = strings.Replace(text, "\t", "", -1)
return text
}
func assertEqual(t *testing.T, expected, given string) {
if removeWS(expected) != removeWS(given) {
t.Errorf("\nExpected: %s\nGiven: %s", expected, given)
}
}
func TestPresignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 300*time.Second, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-meta-other-header;x-amz-target"
expectedSig := "4c86bacebb78e458e8e898e93547513079d67ab95d3d5c02236889e21ea0df2b"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
q := signer.Request.URL.Query()
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 0, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-meta-other-header;x-amz-security-token;x-amz-target, Signature=d4df864b291252d2c9206be75963db54db09ca70265622cc787cbe50ca0efcc8"
q := signer.Request.Header
assert.Equal(t, expectedSig, q.Get("Authorization"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignEmptyBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "")
signer.Body = nil
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash)
}
func TestSignBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
}
func TestSignSeekedBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, " hello")
signer.Body.Read(make([]byte, 3)) // consume first 3 bytes so body is now "hello"
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
start, _ := signer.Body.Seek(0, 1)
assert.Equal(t, int64(3), start)
}
func TestPresignEmptyBodyS3(t *testing.T) {
signer := buildSigner("s3", "us-east-1", time.Now(), 5*time.Minute, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "UNSIGNED-PAYLOAD", hash)
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.Request.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "PRECOMPUTED", hash)
}
func TestAnonymousCredentials(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: credentials.AnonymousCredentials}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
urlQ := r.HTTPRequest.URL.Query()
assert.Empty(t, urlQ.Get("X-Amz-Signature"))
assert.Empty(t, urlQ.Get("X-Amz-Credential"))
assert.Empty(t, urlQ.Get("X-Amz-SignedHeaders"))
assert.Empty(t, urlQ.Get("X-Amz-Date"))
hQ := r.HTTPRequest.Header
assert.Empty(t, hQ.Get("Authorization"))
assert.Empty(t, hQ.Get("X-Amz-Date"))
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 300*time.Second, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,57 @@
package ec2
import (
"time"
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/aws/awsutil"
)
func init() {
initRequest = func(r *aws.Request) {
if r.Operation == opCopySnapshot { // fill the PresignedURL parameter
r.Handlers.Build.PushFront(fillPresignedURL)
}
}
}
func fillPresignedURL(r *aws.Request) {
if !r.ParamsFilled() {
return
}
params := r.Params.(*CopySnapshotInput)
// Stop if PresignedURL/DestinationRegion is set
if params.PresignedURL != nil || params.DestinationRegion != nil {
return
}
// First generate a copy of parameters
r.Params = awsutil.CopyOf(r.Params)
params = r.Params.(*CopySnapshotInput)
// Set destination region. Avoids infinite handler loop.
// Also needed to sign sub-request.
params.DestinationRegion = &r.Service.Config.Region
// Create a new client pointing at source region.
// We will use this to presign the CopySnapshot request against
// the source region
config := r.Service.Config.Copy()
config.Endpoint = ""
config.Region = *params.SourceRegion
client := New(&config)
// Presign a CopySnapshot request with modified params
req, _ := client.CopySnapshotRequest(params)
url, err := req.Presign(300 * time.Second) // 5 minutes should be enough.
if err != nil { // bubble error back up to original request
r.Error = err
}
// We have our URL, set it on params
params.PresignedURL = &url
}

View File

@@ -0,0 +1,38 @@
// +build !integration
package ec2_test
import (
"io/ioutil"
"net/url"
"testing"
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/internal/test/unit"
"github.com/awslabs/aws-sdk-go/service/ec2"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestCopySnapshotPresignedURL(t *testing.T) {
svc := ec2.New(&aws.Config{Region: "us-west-2"})
assert.NotPanics(t, func() {
// Doesn't panic on nil input
req, _ := svc.CopySnapshotRequest(nil)
req.Sign()
})
req, _ := svc.CopySnapshotRequest(&ec2.CopySnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceSnapshotID: aws.String("snap-id"),
})
req.Sign()
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
q, _ := url.ParseQuery(string(b))
url, _ := url.QueryUnescape(q.Get("PresignedUrl"))
assert.Equal(t, "us-west-2", q.Get("DestinationRegion"))
assert.Regexp(t, `^https://ec2\.us-west-1\.amazon.+&DestinationRegion=us-west-2`, url)
}

View File

@@ -0,0 +1,337 @@
package ec2iface
import (
"github.com/awslabs/aws-sdk-go/service/ec2"
)
type EC2API interface {
AcceptVPCPeeringConnection(*ec2.AcceptVPCPeeringConnectionInput) (*ec2.AcceptVPCPeeringConnectionOutput, error)
AllocateAddress(*ec2.AllocateAddressInput) (*ec2.AllocateAddressOutput, error)
AssignPrivateIPAddresses(*ec2.AssignPrivateIPAddressesInput) (*ec2.AssignPrivateIPAddressesOutput, error)
AssociateAddress(*ec2.AssociateAddressInput) (*ec2.AssociateAddressOutput, error)
AssociateDHCPOptions(*ec2.AssociateDHCPOptionsInput) (*ec2.AssociateDHCPOptionsOutput, error)
AssociateRouteTable(*ec2.AssociateRouteTableInput) (*ec2.AssociateRouteTableOutput, error)
AttachClassicLinkVPC(*ec2.AttachClassicLinkVPCInput) (*ec2.AttachClassicLinkVPCOutput, error)
AttachInternetGateway(*ec2.AttachInternetGatewayInput) (*ec2.AttachInternetGatewayOutput, error)
AttachNetworkInterface(*ec2.AttachNetworkInterfaceInput) (*ec2.AttachNetworkInterfaceOutput, error)
AttachVPNGateway(*ec2.AttachVPNGatewayInput) (*ec2.AttachVPNGatewayOutput, error)
AttachVolume(*ec2.AttachVolumeInput) (*ec2.VolumeAttachment, error)
AuthorizeSecurityGroupEgress(*ec2.AuthorizeSecurityGroupEgressInput) (*ec2.AuthorizeSecurityGroupEgressOutput, error)
AuthorizeSecurityGroupIngress(*ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error)
BundleInstance(*ec2.BundleInstanceInput) (*ec2.BundleInstanceOutput, error)
CancelBundleTask(*ec2.CancelBundleTaskInput) (*ec2.CancelBundleTaskOutput, error)
CancelConversionTask(*ec2.CancelConversionTaskInput) (*ec2.CancelConversionTaskOutput, error)
CancelExportTask(*ec2.CancelExportTaskInput) (*ec2.CancelExportTaskOutput, error)
CancelImportTask(*ec2.CancelImportTaskInput) (*ec2.CancelImportTaskOutput, error)
CancelReservedInstancesListing(*ec2.CancelReservedInstancesListingInput) (*ec2.CancelReservedInstancesListingOutput, error)
CancelSpotInstanceRequests(*ec2.CancelSpotInstanceRequestsInput) (*ec2.CancelSpotInstanceRequestsOutput, error)
ConfirmProductInstance(*ec2.ConfirmProductInstanceInput) (*ec2.ConfirmProductInstanceOutput, error)
CopyImage(*ec2.CopyImageInput) (*ec2.CopyImageOutput, error)
CopySnapshot(*ec2.CopySnapshotInput) (*ec2.CopySnapshotOutput, error)
CreateCustomerGateway(*ec2.CreateCustomerGatewayInput) (*ec2.CreateCustomerGatewayOutput, error)
CreateDHCPOptions(*ec2.CreateDHCPOptionsInput) (*ec2.CreateDHCPOptionsOutput, error)
CreateImage(*ec2.CreateImageInput) (*ec2.CreateImageOutput, error)
CreateInstanceExportTask(*ec2.CreateInstanceExportTaskInput) (*ec2.CreateInstanceExportTaskOutput, error)
CreateInternetGateway(*ec2.CreateInternetGatewayInput) (*ec2.CreateInternetGatewayOutput, error)
CreateKeyPair(*ec2.CreateKeyPairInput) (*ec2.CreateKeyPairOutput, error)
CreateNetworkACL(*ec2.CreateNetworkACLInput) (*ec2.CreateNetworkACLOutput, error)
CreateNetworkACLEntry(*ec2.CreateNetworkACLEntryInput) (*ec2.CreateNetworkACLEntryOutput, error)
CreateNetworkInterface(*ec2.CreateNetworkInterfaceInput) (*ec2.CreateNetworkInterfaceOutput, error)
CreatePlacementGroup(*ec2.CreatePlacementGroupInput) (*ec2.CreatePlacementGroupOutput, error)
CreateReservedInstancesListing(*ec2.CreateReservedInstancesListingInput) (*ec2.CreateReservedInstancesListingOutput, error)
CreateRoute(*ec2.CreateRouteInput) (*ec2.CreateRouteOutput, error)
CreateRouteTable(*ec2.CreateRouteTableInput) (*ec2.CreateRouteTableOutput, error)
CreateSecurityGroup(*ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error)
CreateSnapshot(*ec2.CreateSnapshotInput) (*ec2.Snapshot, error)
CreateSpotDatafeedSubscription(*ec2.CreateSpotDatafeedSubscriptionInput) (*ec2.CreateSpotDatafeedSubscriptionOutput, error)
CreateSubnet(*ec2.CreateSubnetInput) (*ec2.CreateSubnetOutput, error)
CreateTags(*ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error)
CreateVPC(*ec2.CreateVPCInput) (*ec2.CreateVPCOutput, error)
CreateVPCPeeringConnection(*ec2.CreateVPCPeeringConnectionInput) (*ec2.CreateVPCPeeringConnectionOutput, error)
CreateVPNConnection(*ec2.CreateVPNConnectionInput) (*ec2.CreateVPNConnectionOutput, error)
CreateVPNConnectionRoute(*ec2.CreateVPNConnectionRouteInput) (*ec2.CreateVPNConnectionRouteOutput, error)
CreateVPNGateway(*ec2.CreateVPNGatewayInput) (*ec2.CreateVPNGatewayOutput, error)
CreateVolume(*ec2.CreateVolumeInput) (*ec2.Volume, error)
DeleteCustomerGateway(*ec2.DeleteCustomerGatewayInput) (*ec2.DeleteCustomerGatewayOutput, error)
DeleteDHCPOptions(*ec2.DeleteDHCPOptionsInput) (*ec2.DeleteDHCPOptionsOutput, error)
DeleteInternetGateway(*ec2.DeleteInternetGatewayInput) (*ec2.DeleteInternetGatewayOutput, error)
DeleteKeyPair(*ec2.DeleteKeyPairInput) (*ec2.DeleteKeyPairOutput, error)
DeleteNetworkACL(*ec2.DeleteNetworkACLInput) (*ec2.DeleteNetworkACLOutput, error)
DeleteNetworkACLEntry(*ec2.DeleteNetworkACLEntryInput) (*ec2.DeleteNetworkACLEntryOutput, error)
DeleteNetworkInterface(*ec2.DeleteNetworkInterfaceInput) (*ec2.DeleteNetworkInterfaceOutput, error)
DeletePlacementGroup(*ec2.DeletePlacementGroupInput) (*ec2.DeletePlacementGroupOutput, error)
DeleteRoute(*ec2.DeleteRouteInput) (*ec2.DeleteRouteOutput, error)
DeleteRouteTable(*ec2.DeleteRouteTableInput) (*ec2.DeleteRouteTableOutput, error)
DeleteSecurityGroup(*ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error)
DeleteSnapshot(*ec2.DeleteSnapshotInput) (*ec2.DeleteSnapshotOutput, error)
DeleteSpotDatafeedSubscription(*ec2.DeleteSpotDatafeedSubscriptionInput) (*ec2.DeleteSpotDatafeedSubscriptionOutput, error)
DeleteSubnet(*ec2.DeleteSubnetInput) (*ec2.DeleteSubnetOutput, error)
DeleteTags(*ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error)
DeleteVPC(*ec2.DeleteVPCInput) (*ec2.DeleteVPCOutput, error)
DeleteVPCPeeringConnection(*ec2.DeleteVPCPeeringConnectionInput) (*ec2.DeleteVPCPeeringConnectionOutput, error)
DeleteVPNConnection(*ec2.DeleteVPNConnectionInput) (*ec2.DeleteVPNConnectionOutput, error)
DeleteVPNConnectionRoute(*ec2.DeleteVPNConnectionRouteInput) (*ec2.DeleteVPNConnectionRouteOutput, error)
DeleteVPNGateway(*ec2.DeleteVPNGatewayInput) (*ec2.DeleteVPNGatewayOutput, error)
DeleteVolume(*ec2.DeleteVolumeInput) (*ec2.DeleteVolumeOutput, error)
DeregisterImage(*ec2.DeregisterImageInput) (*ec2.DeregisterImageOutput, error)
DescribeAccountAttributes(*ec2.DescribeAccountAttributesInput) (*ec2.DescribeAccountAttributesOutput, error)
DescribeAddresses(*ec2.DescribeAddressesInput) (*ec2.DescribeAddressesOutput, error)
DescribeAvailabilityZones(*ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error)
DescribeBundleTasks(*ec2.DescribeBundleTasksInput) (*ec2.DescribeBundleTasksOutput, error)
DescribeClassicLinkInstances(*ec2.DescribeClassicLinkInstancesInput) (*ec2.DescribeClassicLinkInstancesOutput, error)
DescribeConversionTasks(*ec2.DescribeConversionTasksInput) (*ec2.DescribeConversionTasksOutput, error)
DescribeCustomerGateways(*ec2.DescribeCustomerGatewaysInput) (*ec2.DescribeCustomerGatewaysOutput, error)
DescribeDHCPOptions(*ec2.DescribeDHCPOptionsInput) (*ec2.DescribeDHCPOptionsOutput, error)
DescribeExportTasks(*ec2.DescribeExportTasksInput) (*ec2.DescribeExportTasksOutput, error)
DescribeImageAttribute(*ec2.DescribeImageAttributeInput) (*ec2.DescribeImageAttributeOutput, error)
DescribeImages(*ec2.DescribeImagesInput) (*ec2.DescribeImagesOutput, error)
DescribeImportImageTasks(*ec2.DescribeImportImageTasksInput) (*ec2.DescribeImportImageTasksOutput, error)
DescribeImportSnapshotTasks(*ec2.DescribeImportSnapshotTasksInput) (*ec2.DescribeImportSnapshotTasksOutput, error)
DescribeInstanceAttribute(*ec2.DescribeInstanceAttributeInput) (*ec2.DescribeInstanceAttributeOutput, error)
DescribeInstanceStatus(*ec2.DescribeInstanceStatusInput) (*ec2.DescribeInstanceStatusOutput, error)
DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
DescribeInternetGateways(*ec2.DescribeInternetGatewaysInput) (*ec2.DescribeInternetGatewaysOutput, error)
DescribeKeyPairs(*ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error)
DescribeNetworkACLs(*ec2.DescribeNetworkACLsInput) (*ec2.DescribeNetworkACLsOutput, error)
DescribeNetworkInterfaceAttribute(*ec2.DescribeNetworkInterfaceAttributeInput) (*ec2.DescribeNetworkInterfaceAttributeOutput, error)
DescribeNetworkInterfaces(*ec2.DescribeNetworkInterfacesInput) (*ec2.DescribeNetworkInterfacesOutput, error)
DescribePlacementGroups(*ec2.DescribePlacementGroupsInput) (*ec2.DescribePlacementGroupsOutput, error)
DescribeRegions(*ec2.DescribeRegionsInput) (*ec2.DescribeRegionsOutput, error)
DescribeReservedInstances(*ec2.DescribeReservedInstancesInput) (*ec2.DescribeReservedInstancesOutput, error)
DescribeReservedInstancesListings(*ec2.DescribeReservedInstancesListingsInput) (*ec2.DescribeReservedInstancesListingsOutput, error)
DescribeReservedInstancesModifications(*ec2.DescribeReservedInstancesModificationsInput) (*ec2.DescribeReservedInstancesModificationsOutput, error)
DescribeReservedInstancesOfferings(*ec2.DescribeReservedInstancesOfferingsInput) (*ec2.DescribeReservedInstancesOfferingsOutput, error)
DescribeRouteTables(*ec2.DescribeRouteTablesInput) (*ec2.DescribeRouteTablesOutput, error)
DescribeSecurityGroups(*ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error)
DescribeSnapshotAttribute(*ec2.DescribeSnapshotAttributeInput) (*ec2.DescribeSnapshotAttributeOutput, error)
DescribeSnapshots(*ec2.DescribeSnapshotsInput) (*ec2.DescribeSnapshotsOutput, error)
DescribeSpotDatafeedSubscription(*ec2.DescribeSpotDatafeedSubscriptionInput) (*ec2.DescribeSpotDatafeedSubscriptionOutput, error)
DescribeSpotInstanceRequests(*ec2.DescribeSpotInstanceRequestsInput) (*ec2.DescribeSpotInstanceRequestsOutput, error)
DescribeSpotPriceHistory(*ec2.DescribeSpotPriceHistoryInput) (*ec2.DescribeSpotPriceHistoryOutput, error)
DescribeSubnets(*ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error)
DescribeTags(*ec2.DescribeTagsInput) (*ec2.DescribeTagsOutput, error)
DescribeVPCAttribute(*ec2.DescribeVPCAttributeInput) (*ec2.DescribeVPCAttributeOutput, error)
DescribeVPCClassicLink(*ec2.DescribeVPCClassicLinkInput) (*ec2.DescribeVPCClassicLinkOutput, error)
DescribeVPCPeeringConnections(*ec2.DescribeVPCPeeringConnectionsInput) (*ec2.DescribeVPCPeeringConnectionsOutput, error)
DescribeVPCs(*ec2.DescribeVPCsInput) (*ec2.DescribeVPCsOutput, error)
DescribeVPNConnections(*ec2.DescribeVPNConnectionsInput) (*ec2.DescribeVPNConnectionsOutput, error)
DescribeVPNGateways(*ec2.DescribeVPNGatewaysInput) (*ec2.DescribeVPNGatewaysOutput, error)
DescribeVolumeAttribute(*ec2.DescribeVolumeAttributeInput) (*ec2.DescribeVolumeAttributeOutput, error)
DescribeVolumeStatus(*ec2.DescribeVolumeStatusInput) (*ec2.DescribeVolumeStatusOutput, error)
DescribeVolumes(*ec2.DescribeVolumesInput) (*ec2.DescribeVolumesOutput, error)
DetachClassicLinkVPC(*ec2.DetachClassicLinkVPCInput) (*ec2.DetachClassicLinkVPCOutput, error)
DetachInternetGateway(*ec2.DetachInternetGatewayInput) (*ec2.DetachInternetGatewayOutput, error)
DetachNetworkInterface(*ec2.DetachNetworkInterfaceInput) (*ec2.DetachNetworkInterfaceOutput, error)
DetachVPNGateway(*ec2.DetachVPNGatewayInput) (*ec2.DetachVPNGatewayOutput, error)
DetachVolume(*ec2.DetachVolumeInput) (*ec2.VolumeAttachment, error)
DisableVGWRoutePropagation(*ec2.DisableVGWRoutePropagationInput) (*ec2.DisableVGWRoutePropagationOutput, error)
DisableVPCClassicLink(*ec2.DisableVPCClassicLinkInput) (*ec2.DisableVPCClassicLinkOutput, error)
DisassociateAddress(*ec2.DisassociateAddressInput) (*ec2.DisassociateAddressOutput, error)
DisassociateRouteTable(*ec2.DisassociateRouteTableInput) (*ec2.DisassociateRouteTableOutput, error)
EnableVGWRoutePropagation(*ec2.EnableVGWRoutePropagationInput) (*ec2.EnableVGWRoutePropagationOutput, error)
EnableVPCClassicLink(*ec2.EnableVPCClassicLinkInput) (*ec2.EnableVPCClassicLinkOutput, error)
EnableVolumeIO(*ec2.EnableVolumeIOInput) (*ec2.EnableVolumeIOOutput, error)
GetConsoleOutput(*ec2.GetConsoleOutputInput) (*ec2.GetConsoleOutputOutput, error)
GetPasswordData(*ec2.GetPasswordDataInput) (*ec2.GetPasswordDataOutput, error)
ImportImage(*ec2.ImportImageInput) (*ec2.ImportImageOutput, error)
ImportInstance(*ec2.ImportInstanceInput) (*ec2.ImportInstanceOutput, error)
ImportKeyPair(*ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error)
ImportSnapshot(*ec2.ImportSnapshotInput) (*ec2.ImportSnapshotOutput, error)
ImportVolume(*ec2.ImportVolumeInput) (*ec2.ImportVolumeOutput, error)
ModifyImageAttribute(*ec2.ModifyImageAttributeInput) (*ec2.ModifyImageAttributeOutput, error)
ModifyInstanceAttribute(*ec2.ModifyInstanceAttributeInput) (*ec2.ModifyInstanceAttributeOutput, error)
ModifyNetworkInterfaceAttribute(*ec2.ModifyNetworkInterfaceAttributeInput) (*ec2.ModifyNetworkInterfaceAttributeOutput, error)
ModifyReservedInstances(*ec2.ModifyReservedInstancesInput) (*ec2.ModifyReservedInstancesOutput, error)
ModifySnapshotAttribute(*ec2.ModifySnapshotAttributeInput) (*ec2.ModifySnapshotAttributeOutput, error)
ModifySubnetAttribute(*ec2.ModifySubnetAttributeInput) (*ec2.ModifySubnetAttributeOutput, error)
ModifyVPCAttribute(*ec2.ModifyVPCAttributeInput) (*ec2.ModifyVPCAttributeOutput, error)
ModifyVolumeAttribute(*ec2.ModifyVolumeAttributeInput) (*ec2.ModifyVolumeAttributeOutput, error)
MonitorInstances(*ec2.MonitorInstancesInput) (*ec2.MonitorInstancesOutput, error)
PurchaseReservedInstancesOffering(*ec2.PurchaseReservedInstancesOfferingInput) (*ec2.PurchaseReservedInstancesOfferingOutput, error)
RebootInstances(*ec2.RebootInstancesInput) (*ec2.RebootInstancesOutput, error)
RegisterImage(*ec2.RegisterImageInput) (*ec2.RegisterImageOutput, error)
RejectVPCPeeringConnection(*ec2.RejectVPCPeeringConnectionInput) (*ec2.RejectVPCPeeringConnectionOutput, error)
ReleaseAddress(*ec2.ReleaseAddressInput) (*ec2.ReleaseAddressOutput, error)
ReplaceNetworkACLAssociation(*ec2.ReplaceNetworkACLAssociationInput) (*ec2.ReplaceNetworkACLAssociationOutput, error)
ReplaceNetworkACLEntry(*ec2.ReplaceNetworkACLEntryInput) (*ec2.ReplaceNetworkACLEntryOutput, error)
ReplaceRoute(*ec2.ReplaceRouteInput) (*ec2.ReplaceRouteOutput, error)
ReplaceRouteTableAssociation(*ec2.ReplaceRouteTableAssociationInput) (*ec2.ReplaceRouteTableAssociationOutput, error)
ReportInstanceStatus(*ec2.ReportInstanceStatusInput) (*ec2.ReportInstanceStatusOutput, error)
RequestSpotInstances(*ec2.RequestSpotInstancesInput) (*ec2.RequestSpotInstancesOutput, error)
ResetImageAttribute(*ec2.ResetImageAttributeInput) (*ec2.ResetImageAttributeOutput, error)
ResetInstanceAttribute(*ec2.ResetInstanceAttributeInput) (*ec2.ResetInstanceAttributeOutput, error)
ResetNetworkInterfaceAttribute(*ec2.ResetNetworkInterfaceAttributeInput) (*ec2.ResetNetworkInterfaceAttributeOutput, error)
ResetSnapshotAttribute(*ec2.ResetSnapshotAttributeInput) (*ec2.ResetSnapshotAttributeOutput, error)
RevokeSecurityGroupEgress(*ec2.RevokeSecurityGroupEgressInput) (*ec2.RevokeSecurityGroupEgressOutput, error)
RevokeSecurityGroupIngress(*ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error)
RunInstances(*ec2.RunInstancesInput) (*ec2.Reservation, error)
StartInstances(*ec2.StartInstancesInput) (*ec2.StartInstancesOutput, error)
StopInstances(*ec2.StopInstancesInput) (*ec2.StopInstancesOutput, error)
TerminateInstances(*ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error)
UnassignPrivateIPAddresses(*ec2.UnassignPrivateIPAddressesInput) (*ec2.UnassignPrivateIPAddressesOutput, error)
UnmonitorInstances(*ec2.UnmonitorInstancesInput) (*ec2.UnmonitorInstancesOutput, error)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,45 @@
// +build integration
package ec2_test
import (
"testing"
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/internal/test/integration"
"github.com/awslabs/aws-sdk-go/internal/util/utilassert"
"github.com/awslabs/aws-sdk-go/service/ec2"
"github.com/stretchr/testify/assert"
)
var (
_ = assert.Equal
_ = utilassert.Match
_ = integration.Imported
)
func TestMakingABasicRequest(t *testing.T) {
client := ec2.New(nil)
resp, e := client.DescribeRegions(&ec2.DescribeRegionsInput{})
err := aws.Error(e)
_, _, _ = resp, e, err // avoid unused warnings
assert.NoError(t, nil, err)
}
func TestErrorHandling(t *testing.T) {
client := ec2.New(nil)
resp, e := client.DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIDs: []*string{
aws.String("i-12345678"),
},
})
err := aws.Error(e)
_, _, _ = resp, e, err // avoid unused warnings
assert.NotEqual(t, nil, err)
assert.Equal(t, "InvalidInstanceID.NotFound", err.Code)
utilassert.Match(t, "The instance ID 'i-12345678' does not exist", err.Message)
}

View File

@@ -0,0 +1,59 @@
package ec2
import (
"github.com/awslabs/aws-sdk-go/aws"
"github.com/awslabs/aws-sdk-go/internal/protocol/ec2query"
"github.com/awslabs/aws-sdk-go/internal/signer/v4"
)
// EC2 is a client for Amazon EC2.
type EC2 struct {
*aws.Service
}
// Used for custom service initialization logic
var initService func(*aws.Service)
// Used for custom request initialization logic
var initRequest func(*aws.Request)
// New returns a new EC2 client.
func New(config *aws.Config) *EC2 {
if config == nil {
config = &aws.Config{}
}
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "ec2",
APIVersion: "2015-03-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
// Run custom service initialization if present
if initService != nil {
initService(service)
}
return &EC2{service}
}
// newRequest creates a new request for a EC2 operation and runs any
// custom request initialization.
func (c *EC2) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
// Run custom request initialization if present
if initRequest != nil {
initRequest(req)
}
return req
}

View File

@@ -1,74 +0,0 @@
package aws
import (
"time"
)
// AttemptStrategy represents a strategy for waiting for an action
// to complete successfully. This is an internal type used by the
// implementation of other goamz packages.
type AttemptStrategy struct {
Total time.Duration // total duration of attempt.
Delay time.Duration // interval between each try in the burst.
Min int // minimum number of retries; overrides Total
}
type Attempt struct {
strategy AttemptStrategy
last time.Time
end time.Time
force bool
count int
}
// Start begins a new sequence of attempts for the given strategy.
func (s AttemptStrategy) Start() *Attempt {
now := time.Now()
return &Attempt{
strategy: s,
last: now,
end: now.Add(s.Total),
force: true,
}
}
// Next waits until it is time to perform the next attempt or returns
// false if it is time to stop trying.
func (a *Attempt) Next() bool {
now := time.Now()
sleep := a.nextSleep(now)
if !a.force && !now.Add(sleep).Before(a.end) && a.strategy.Min <= a.count {
return false
}
a.force = false
if sleep > 0 && a.count > 0 {
time.Sleep(sleep)
now = time.Now()
}
a.count++
a.last = now
return true
}
func (a *Attempt) nextSleep(now time.Time) time.Duration {
sleep := a.strategy.Delay - now.Sub(a.last)
if sleep < 0 {
return 0
}
return sleep
}
// HasNext returns whether another attempt will be made if the current
// one fails. If it returns true, the following call to Next is
// guaranteed to return true.
func (a *Attempt) HasNext() bool {
if a.force || a.strategy.Min > a.count {
return true
}
now := time.Now()
if now.Add(a.nextSleep(now)).Before(a.end) {
a.force = true
return true
}
return false
}

View File

@@ -1,57 +0,0 @@
package aws_test
import (
"github.com/mitchellh/goamz/aws"
. "github.com/motain/gocheck"
"time"
)
func (S) TestAttemptTiming(c *C) {
testAttempt := aws.AttemptStrategy{
Total: 0.25e9,
Delay: 0.1e9,
}
want := []time.Duration{0, 0.1e9, 0.2e9, 0.2e9}
got := make([]time.Duration, 0, len(want)) // avoid allocation when testing timing
t0 := time.Now()
for a := testAttempt.Start(); a.Next(); {
got = append(got, time.Now().Sub(t0))
}
got = append(got, time.Now().Sub(t0))
c.Assert(got, HasLen, len(want))
const margin = 0.01e9
for i, got := range want {
lo := want[i] - margin
hi := want[i] + margin
if got < lo || got > hi {
c.Errorf("attempt %d want %g got %g", i, want[i].Seconds(), got.Seconds())
}
}
}
func (S) TestAttemptNextHasNext(c *C) {
a := aws.AttemptStrategy{}.Start()
c.Assert(a.Next(), Equals, true)
c.Assert(a.Next(), Equals, false)
a = aws.AttemptStrategy{}.Start()
c.Assert(a.Next(), Equals, true)
c.Assert(a.HasNext(), Equals, false)
c.Assert(a.Next(), Equals, false)
a = aws.AttemptStrategy{Total: 2e8}.Start()
c.Assert(a.Next(), Equals, true)
c.Assert(a.HasNext(), Equals, true)
time.Sleep(2e8)
c.Assert(a.HasNext(), Equals, true)
c.Assert(a.Next(), Equals, true)
c.Assert(a.Next(), Equals, false)
a = aws.AttemptStrategy{Total: 1e8, Min: 2}.Start()
time.Sleep(1e8)
c.Assert(a.Next(), Equals, true)
c.Assert(a.HasNext(), Equals, true)
c.Assert(a.Next(), Equals, true)
c.Assert(a.HasNext(), Equals, false)
c.Assert(a.Next(), Equals, false)
}

View File

@@ -1,445 +0,0 @@
//
// goamz - Go packages to interact with the Amazon Web Services.
//
// https://wiki.ubuntu.com/goamz
//
// Copyright (c) 2011 Canonical Ltd.
//
// Written by Gustavo Niemeyer <gustavo.niemeyer@canonical.com>
//
package aws
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"github.com/vaughan0/go-ini"
)
// Region defines the URLs where AWS services may be accessed.
//
// See http://goo.gl/d8BP1 for more details.
type Region struct {
Name string // the canonical name of this region.
EC2Endpoint string
S3Endpoint string
S3BucketEndpoint string // Not needed by AWS S3. Use ${bucket} for bucket name.
S3LocationConstraint bool // true if this region requires a LocationConstraint declaration.
S3LowercaseBucket bool // true if the region requires bucket names to be lower case.
SDBEndpoint string
SNSEndpoint string
SQSEndpoint string
IAMEndpoint string
ELBEndpoint string
AutoScalingEndpoint string
RdsEndpoint string
Route53Endpoint string
}
var USGovWest = Region{
"us-gov-west-1",
"https://ec2.us-gov-west-1.amazonaws.com",
"https://s3-fips-us-gov-west-1.amazonaws.com",
"",
true,
true,
"",
"https://sns.us-gov-west-1.amazonaws.com",
"https://sqs.us-gov-west-1.amazonaws.com",
"https://iam.us-gov.amazonaws.com",
"https://elasticloadbalancing.us-gov-west-1.amazonaws.com",
"https://autoscaling.us-gov-west-1.amazonaws.com",
"https://rds.us-gov-west-1.amazonaws.com",
"https://route53.amazonaws.com",
}
var USEast = Region{
"us-east-1",
"https://ec2.us-east-1.amazonaws.com",
"https://s3.amazonaws.com",
"",
false,
false,
"https://sdb.amazonaws.com",
"https://sns.us-east-1.amazonaws.com",
"https://sqs.us-east-1.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.us-east-1.amazonaws.com",
"https://autoscaling.us-east-1.amazonaws.com",
"https://rds.us-east-1.amazonaws.com",
"https://route53.amazonaws.com",
}
var USWest = Region{
"us-west-1",
"https://ec2.us-west-1.amazonaws.com",
"https://s3-us-west-1.amazonaws.com",
"",
true,
true,
"https://sdb.us-west-1.amazonaws.com",
"https://sns.us-west-1.amazonaws.com",
"https://sqs.us-west-1.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.us-west-1.amazonaws.com",
"https://autoscaling.us-west-1.amazonaws.com",
"https://rds.us-west-1.amazonaws.com",
"https://route53.amazonaws.com",
}
var USWest2 = Region{
"us-west-2",
"https://ec2.us-west-2.amazonaws.com",
"https://s3-us-west-2.amazonaws.com",
"",
true,
true,
"https://sdb.us-west-2.amazonaws.com",
"https://sns.us-west-2.amazonaws.com",
"https://sqs.us-west-2.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.us-west-2.amazonaws.com",
"https://autoscaling.us-west-2.amazonaws.com",
"https://rds.us-west-2.amazonaws.com",
"https://route53.amazonaws.com",
}
var EUWest = Region{
"eu-west-1",
"https://ec2.eu-west-1.amazonaws.com",
"https://s3-eu-west-1.amazonaws.com",
"",
true,
true,
"https://sdb.eu-west-1.amazonaws.com",
"https://sns.eu-west-1.amazonaws.com",
"https://sqs.eu-west-1.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.eu-west-1.amazonaws.com",
"https://autoscaling.eu-west-1.amazonaws.com",
"https://rds.eu-west-1.amazonaws.com",
"https://route53.amazonaws.com",
}
var EUCentral = Region{
"eu-central-1",
"https://ec2.eu-central-1.amazonaws.com",
"https://s3-eu-central-1.amazonaws.com",
"",
true,
true,
"",
"https://sns.eu-central-1.amazonaws.com",
"https://sqs.eu-central-1.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.eu-central-1.amazonaws.com",
"https://autoscaling.eu-central-1.amazonaws.com",
"https://rds.eu-central-1.amazonaws.com",
"https://route53.amazonaws.com",
}
var APSoutheast = Region{
"ap-southeast-1",
"https://ec2.ap-southeast-1.amazonaws.com",
"https://s3-ap-southeast-1.amazonaws.com",
"",
true,
true,
"https://sdb.ap-southeast-1.amazonaws.com",
"https://sns.ap-southeast-1.amazonaws.com",
"https://sqs.ap-southeast-1.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.ap-southeast-1.amazonaws.com",
"https://autoscaling.ap-southeast-1.amazonaws.com",
"https://rds.ap-southeast-1.amazonaws.com",
"https://route53.amazonaws.com",
}
var APSoutheast2 = Region{
"ap-southeast-2",
"https://ec2.ap-southeast-2.amazonaws.com",
"https://s3-ap-southeast-2.amazonaws.com",
"",
true,
true,
"https://sdb.ap-southeast-2.amazonaws.com",
"https://sns.ap-southeast-2.amazonaws.com",
"https://sqs.ap-southeast-2.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.ap-southeast-2.amazonaws.com",
"https://autoscaling.ap-southeast-2.amazonaws.com",
"https://rds.ap-southeast-2.amazonaws.com",
"https://route53.amazonaws.com",
}
var APNortheast = Region{
"ap-northeast-1",
"https://ec2.ap-northeast-1.amazonaws.com",
"https://s3-ap-northeast-1.amazonaws.com",
"",
true,
true,
"https://sdb.ap-northeast-1.amazonaws.com",
"https://sns.ap-northeast-1.amazonaws.com",
"https://sqs.ap-northeast-1.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.ap-northeast-1.amazonaws.com",
"https://autoscaling.ap-northeast-1.amazonaws.com",
"https://rds.ap-northeast-1.amazonaws.com",
"https://route53.amazonaws.com",
}
var SAEast = Region{
"sa-east-1",
"https://ec2.sa-east-1.amazonaws.com",
"https://s3-sa-east-1.amazonaws.com",
"",
true,
true,
"https://sdb.sa-east-1.amazonaws.com",
"https://sns.sa-east-1.amazonaws.com",
"https://sqs.sa-east-1.amazonaws.com",
"https://iam.amazonaws.com",
"https://elasticloadbalancing.sa-east-1.amazonaws.com",
"https://autoscaling.sa-east-1.amazonaws.com",
"https://rds.sa-east-1.amazonaws.com",
"https://route53.amazonaws.com",
}
var CNNorth = Region{
"cn-north-1",
"https://ec2.cn-north-1.amazonaws.com.cn",
"https://s3.cn-north-1.amazonaws.com.cn",
"",
true,
true,
"",
"https://sns.cn-north-1.amazonaws.com.cn",
"https://sqs.cn-north-1.amazonaws.com.cn",
"https://iam.cn-north-1.amazonaws.com.cn",
"https://elasticloadbalancing.cn-north-1.amazonaws.com.cn",
"https://autoscaling.cn-north-1.amazonaws.com.cn",
"https://rds.cn-north-1.amazonaws.com.cn",
"https://route53.amazonaws.com",
}
var Regions = map[string]Region{
APNortheast.Name: APNortheast,
APSoutheast.Name: APSoutheast,
APSoutheast2.Name: APSoutheast2,
EUWest.Name: EUWest,
EUCentral.Name: EUCentral,
USEast.Name: USEast,
USWest.Name: USWest,
USWest2.Name: USWest2,
SAEast.Name: SAEast,
USGovWest.Name: USGovWest,
CNNorth.Name: CNNorth,
}
type Auth struct {
AccessKey, SecretKey, Token string
}
var unreserved = make([]bool, 128)
var hex = "0123456789ABCDEF"
func init() {
// RFC3986
u := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz01234567890-_.~"
for _, c := range u {
unreserved[c] = true
}
}
type credentials struct {
Code string
LastUpdated string
Type string
AccessKeyId string
SecretAccessKey string
Token string
Expiration string
}
// GetMetaData retrieves instance metadata about the current machine.
//
// See http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AESDG-chapter-instancedata.html for more details.
func GetMetaData(path string) (contents []byte, err error) {
url := "http://169.254.169.254/latest/meta-data/" + path
resp, err := RetryingClient.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
err = fmt.Errorf("Code %d returned for url %s", resp.StatusCode, url)
return
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
return []byte(body), err
}
func getInstanceCredentials() (cred credentials, err error) {
credentialPath := "iam/security-credentials/"
// Get the instance role
role, err := GetMetaData(credentialPath)
if err != nil {
return
}
// Get the instance role credentials
credentialJSON, err := GetMetaData(credentialPath + string(role))
if err != nil {
return
}
err = json.Unmarshal([]byte(credentialJSON), &cred)
return
}
// GetAuth creates an Auth based on either passed in credentials,
// environment information or instance based role credentials.
func GetAuth(accessKey string, secretKey string) (auth Auth, err error) {
// First try passed in credentials
if accessKey != "" && secretKey != "" {
return Auth{accessKey, secretKey, ""}, nil
}
// Next try to get auth from the environment
auth, err = SharedAuth()
if err == nil {
// Found auth, return
return
}
// Next try to get auth from the environment
auth, err = EnvAuth()
if err == nil {
// Found auth, return
return
}
// Next try getting auth from the instance role
cred, err := getInstanceCredentials()
if err == nil {
// Found auth, return
auth.AccessKey = cred.AccessKeyId
auth.SecretKey = cred.SecretAccessKey
auth.Token = cred.Token
return
}
err = errors.New("No valid AWS authentication found")
return
}
// SharedAuth creates an Auth based on shared credentials stored in
// $HOME/.aws/credentials. The AWS_PROFILE environment variables is used to
// select the profile.
func SharedAuth() (auth Auth, err error) {
var profileName = os.Getenv("AWS_PROFILE")
if profileName == "" {
profileName = "default"
}
var credentialsFile = os.Getenv("AWS_CREDENTIAL_FILE")
if credentialsFile == "" {
var homeDir = os.Getenv("HOME")
if homeDir == "" {
err = errors.New("Could not get HOME")
return
}
credentialsFile = homeDir + "/.aws/credentials"
}
file, err := ini.LoadFile(credentialsFile)
if err != nil {
err = errors.New("Couldn't parse AWS credentials file")
return
}
var profile = file[profileName]
if profile == nil {
err = errors.New("Couldn't find profile in AWS credentials file")
return
}
auth.AccessKey = profile["aws_access_key_id"]
auth.SecretKey = profile["aws_secret_access_key"]
if auth.AccessKey == "" {
err = errors.New("AWS_ACCESS_KEY_ID not found in environment in credentials file")
}
if auth.SecretKey == "" {
err = errors.New("AWS_SECRET_ACCESS_KEY not found in credentials file")
}
return
}
// EnvAuth creates an Auth based on environment information.
// The AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment
// For accounts that require a security token, it is read from AWS_SECURITY_TOKEN
// variables are used.
func EnvAuth() (auth Auth, err error) {
auth.AccessKey = os.Getenv("AWS_ACCESS_KEY_ID")
if auth.AccessKey == "" {
auth.AccessKey = os.Getenv("AWS_ACCESS_KEY")
}
auth.SecretKey = os.Getenv("AWS_SECRET_ACCESS_KEY")
if auth.SecretKey == "" {
auth.SecretKey = os.Getenv("AWS_SECRET_KEY")
}
auth.Token = os.Getenv("AWS_SECURITY_TOKEN")
if auth.AccessKey == "" {
err = errors.New("AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment")
}
if auth.SecretKey == "" {
err = errors.New("AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment")
}
return
}
// Encode takes a string and URI-encodes it in a way suitable
// to be used in AWS signatures.
func Encode(s string) string {
encode := false
for i := 0; i != len(s); i++ {
c := s[i]
if c > 127 || !unreserved[c] {
encode = true
break
}
}
if !encode {
return s
}
e := make([]byte, len(s)*3)
ei := 0
for i := 0; i != len(s); i++ {
c := s[i]
if c > 127 || !unreserved[c] {
e[ei] = '%'
e[ei+1] = hex[c>>4]
e[ei+2] = hex[c&0xF]
ei += 3
} else {
e[ei] = c
ei += 1
}
}
return string(e[:ei])
}

View File

@@ -1,203 +0,0 @@
package aws_test
import (
"github.com/mitchellh/goamz/aws"
. "github.com/motain/gocheck"
"io/ioutil"
"os"
"strings"
"testing"
)
func Test(t *testing.T) {
TestingT(t)
}
var _ = Suite(&S{})
type S struct {
environ []string
}
func (s *S) SetUpSuite(c *C) {
s.environ = os.Environ()
}
func (s *S) TearDownTest(c *C) {
os.Clearenv()
for _, kv := range s.environ {
l := strings.SplitN(kv, "=", 2)
os.Setenv(l[0], l[1])
}
}
func (s *S) TestSharedAuthNoHome(c *C) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "foo")
_, err := aws.SharedAuth()
c.Assert(err, ErrorMatches, "Could not get HOME")
}
func (s *S) TestSharedAuthNoCredentialsFile(c *C) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "foo")
os.Setenv("HOME", "/tmp")
_, err := aws.SharedAuth()
c.Assert(err, ErrorMatches, "Couldn't parse AWS credentials file")
}
func (s *S) TestSharedAuthNoProfileInFile(c *C) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "foo")
d, err := ioutil.TempDir("", "")
if err != nil {
panic(err)
}
defer os.RemoveAll(d)
err = os.Mkdir(d+"/.aws", 0755)
if err != nil {
panic(err)
}
ioutil.WriteFile(d+"/.aws/credentials", []byte("[bar]\n"), 0644)
os.Setenv("HOME", d)
_, err = aws.SharedAuth()
c.Assert(err, ErrorMatches, "Couldn't find profile in AWS credentials file")
}
func (s *S) TestSharedAuthNoKeysInProfile(c *C) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "bar")
d, err := ioutil.TempDir("", "")
if err != nil {
panic(err)
}
defer os.RemoveAll(d)
err = os.Mkdir(d+"/.aws", 0755)
if err != nil {
panic(err)
}
ioutil.WriteFile(d+"/.aws/credentials", []byte("[bar]\nawsaccesskeyid = AK.."), 0644)
os.Setenv("HOME", d)
_, err = aws.SharedAuth()
c.Assert(err, ErrorMatches, "AWS_SECRET_ACCESS_KEY not found in credentials file")
}
func (s *S) TestSharedAuthDefaultCredentials(c *C) {
os.Clearenv()
d, err := ioutil.TempDir("", "")
if err != nil {
panic(err)
}
defer os.RemoveAll(d)
err = os.Mkdir(d+"/.aws", 0755)
if err != nil {
panic(err)
}
ioutil.WriteFile(d+"/.aws/credentials", []byte("[default]\naws_access_key_id = access\naws_secret_access_key = secret\n"), 0644)
os.Setenv("HOME", d)
auth, err := aws.SharedAuth()
c.Assert(err, IsNil)
c.Assert(auth, Equals, aws.Auth{SecretKey: "secret", AccessKey: "access"})
}
func (s *S) TestSharedAuth(c *C) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "bar")
d, err := ioutil.TempDir("", "")
if err != nil {
panic(err)
}
defer os.RemoveAll(d)
err = os.Mkdir(d+"/.aws", 0755)
if err != nil {
panic(err)
}
ioutil.WriteFile(d+"/.aws/credentials", []byte("[bar]\naws_access_key_id = access\naws_secret_access_key = secret\n"), 0644)
os.Setenv("HOME", d)
auth, err := aws.SharedAuth()
c.Assert(err, IsNil)
c.Assert(auth, Equals, aws.Auth{SecretKey: "secret", AccessKey: "access"})
}
func (s *S) TestEnvAuthNoSecret(c *C) {
os.Clearenv()
_, err := aws.EnvAuth()
c.Assert(err, ErrorMatches, "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment")
}
func (s *S) TestEnvAuthNoAccess(c *C) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "foo")
_, err := aws.EnvAuth()
c.Assert(err, ErrorMatches, "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment")
}
func (s *S) TestEnvAuth(c *C) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_ACCESS_KEY_ID", "access")
auth, err := aws.EnvAuth()
c.Assert(err, IsNil)
c.Assert(auth, Equals, aws.Auth{SecretKey: "secret", AccessKey: "access"})
}
func (s *S) TestEnvAuthWithToken(c *C) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECURITY_TOKEN", "token")
auth, err := aws.EnvAuth()
c.Assert(err, IsNil)
c.Assert(auth, Equals, aws.Auth{SecretKey: "secret", AccessKey: "access", Token: "token"})
}
func (s *S) TestEnvAuthAlt(c *C) {
os.Clearenv()
os.Setenv("AWS_SECRET_KEY", "secret")
os.Setenv("AWS_ACCESS_KEY", "access")
auth, err := aws.EnvAuth()
c.Assert(err, IsNil)
c.Assert(auth, Equals, aws.Auth{SecretKey: "secret", AccessKey: "access"})
}
func (s *S) TestGetAuthStatic(c *C) {
auth, err := aws.GetAuth("access", "secret")
c.Assert(err, IsNil)
c.Assert(auth, Equals, aws.Auth{SecretKey: "secret", AccessKey: "access"})
}
func (s *S) TestGetAuthEnv(c *C) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_ACCESS_KEY_ID", "access")
auth, err := aws.GetAuth("", "")
c.Assert(err, IsNil)
c.Assert(auth, Equals, aws.Auth{SecretKey: "secret", AccessKey: "access"})
}
func (s *S) TestEncode(c *C) {
c.Assert(aws.Encode("foo"), Equals, "foo")
c.Assert(aws.Encode("/"), Equals, "%2F")
}
func (s *S) TestRegionsAreNamed(c *C) {
for n, r := range aws.Regions {
c.Assert(n, Equals, r.Name)
}
}

View File

@@ -1,125 +0,0 @@
package aws
import (
"math"
"net"
"net/http"
"time"
)
type RetryableFunc func(*http.Request, *http.Response, error) bool
type WaitFunc func(try int)
type DeadlineFunc func() time.Time
type ResilientTransport struct {
// Timeout is the maximum amount of time a dial will wait for
// a connect to complete.
//
// The default is no timeout.
//
// With or without a timeout, the operating system may impose
// its own earlier timeout. For instance, TCP timeouts are
// often around 3 minutes.
DialTimeout time.Duration
// MaxTries, if non-zero, specifies the number of times we will retry on
// failure. Retries are only attempted for temporary network errors or known
// safe failures.
MaxTries int
Deadline DeadlineFunc
ShouldRetry RetryableFunc
Wait WaitFunc
transport *http.Transport
}
// Convenience method for creating an http client
func NewClient(rt *ResilientTransport) *http.Client {
rt.transport = &http.Transport{
Dial: func(netw, addr string) (net.Conn, error) {
c, err := net.DialTimeout(netw, addr, rt.DialTimeout)
if err != nil {
return nil, err
}
c.SetDeadline(rt.Deadline())
return c, nil
},
DisableKeepAlives: true,
Proxy: http.ProxyFromEnvironment,
}
// TODO: Would be nice is ResilientTransport allowed clients to initialize
// with http.Transport attributes.
return &http.Client{
Transport: rt,
}
}
var retryingTransport = &ResilientTransport{
Deadline: func() time.Time {
return time.Now().Add(5 * time.Second)
},
DialTimeout: 10 * time.Second,
MaxTries: 3,
ShouldRetry: awsRetry,
Wait: ExpBackoff,
}
// Exported default client
var RetryingClient = NewClient(retryingTransport)
func (t *ResilientTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.tries(req)
}
// Retry a request a maximum of t.MaxTries times.
// We'll only retry if the proper criteria are met.
// If a wait function is specified, wait that amount of time
// In between requests.
func (t *ResilientTransport) tries(req *http.Request) (res *http.Response, err error) {
for try := 0; try < t.MaxTries; try += 1 {
res, err = t.transport.RoundTrip(req)
if !t.ShouldRetry(req, res, err) {
break
}
if res != nil {
res.Body.Close()
}
if t.Wait != nil {
t.Wait(try)
}
}
return
}
func ExpBackoff(try int) {
time.Sleep(100 * time.Millisecond *
time.Duration(math.Exp2(float64(try))))
}
func LinearBackoff(try int) {
time.Sleep(time.Duration(try*100) * time.Millisecond)
}
// Decide if we should retry a request.
// In general, the criteria for retrying a request is described here
// http://docs.aws.amazon.com/general/latest/gr/api-retries.html
func awsRetry(req *http.Request, res *http.Response, err error) bool {
retry := false
// Retry if there's a temporary network error.
if neterr, ok := err.(net.Error); ok {
if neterr.Temporary() {
retry = true
}
}
// Retry if we get a 5xx series error.
if res != nil {
if res.StatusCode >= 500 && res.StatusCode < 600 {
retry = true
}
}
return retry
}

View File

@@ -1,121 +0,0 @@
package aws_test
import (
"fmt"
"github.com/mitchellh/goamz/aws"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// Retrieve the response from handler using aws.RetryingClient
func serveAndGet(handler http.HandlerFunc) (body string, err error) {
ts := httptest.NewServer(handler)
defer ts.Close()
resp, err := aws.RetryingClient.Get(ts.URL)
if err != nil {
return
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("Bad status code: %d", resp.StatusCode)
}
greeting, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
return strings.TrimSpace(string(greeting)), nil
}
func TestClient_expected(t *testing.T) {
body := "foo bar"
resp, err := serveAndGet(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, body)
})
if err != nil {
t.Fatal(err)
}
if resp != body {
t.Fatal("Body not as expected.")
}
}
func TestClient_delay(t *testing.T) {
body := "baz"
wait := 4
resp, err := serveAndGet(func(w http.ResponseWriter, r *http.Request) {
if wait < 0 {
// If we dipped to zero delay and still failed.
t.Fatal("Never succeeded.")
}
wait -= 1
time.Sleep(time.Second * time.Duration(wait))
fmt.Fprintln(w, body)
})
if err != nil {
t.Fatal(err)
}
if resp != body {
t.Fatal("Body not as expected.", resp)
}
}
func TestClient_no4xxRetry(t *testing.T) {
tries := 0
// Fail once before succeeding.
_, err := serveAndGet(func(w http.ResponseWriter, r *http.Request) {
tries += 1
http.Error(w, "error", 404)
})
if err == nil {
t.Fatal("should have error")
}
if tries != 1 {
t.Fatalf("should only try once: %d", tries)
}
}
func TestClient_retries(t *testing.T) {
body := "biz"
failed := false
// Fail once before succeeding.
resp, err := serveAndGet(func(w http.ResponseWriter, r *http.Request) {
if !failed {
http.Error(w, "error", 500)
failed = true
} else {
fmt.Fprintln(w, body)
}
})
if failed != true {
t.Error("We didn't retry!")
}
if err != nil {
t.Fatal(err)
}
if resp != body {
t.Fatal("Body not as expected.")
}
}
func TestClient_fails(t *testing.T) {
tries := 0
// Fail 3 times and return the last error.
_, err := serveAndGet(func(w http.ResponseWriter, r *http.Request) {
tries += 1
http.Error(w, "error", 500)
})
if err == nil {
t.Fatal(err)
}
if tries != 3 {
t.Fatal("Didn't retry enough")
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,204 +0,0 @@
package ec2_test
import (
"crypto/rand"
"fmt"
"github.com/mitchellh/goamz/aws"
"github.com/mitchellh/goamz/ec2"
"github.com/mitchellh/goamz/testutil"
. "github.com/motain/gocheck"
)
// AmazonServer represents an Amazon EC2 server.
type AmazonServer struct {
auth aws.Auth
}
func (s *AmazonServer) SetUp(c *C) {
auth, err := aws.EnvAuth()
if err != nil {
c.Fatal(err.Error())
}
s.auth = auth
}
// Suite cost per run: 0.02 USD
var _ = Suite(&AmazonClientSuite{})
// AmazonClientSuite tests the client against a live EC2 server.
type AmazonClientSuite struct {
srv AmazonServer
ClientTests
}
func (s *AmazonClientSuite) SetUpSuite(c *C) {
if !testutil.Amazon {
c.Skip("AmazonClientSuite tests not enabled")
}
s.srv.SetUp(c)
s.ec2 = ec2.NewWithClient(s.srv.auth, aws.USEast, testutil.DefaultClient)
}
// ClientTests defines integration tests designed to test the client.
// It is not used as a test suite in itself, but embedded within
// another type.
type ClientTests struct {
ec2 *ec2.EC2
}
var imageId = "ami-ccf405a5" // Ubuntu Maverick, i386, EBS store
// Cost: 0.00 USD
func (s *ClientTests) TestRunInstancesError(c *C) {
options := ec2.RunInstances{
ImageId: "ami-a6f504cf", // Ubuntu Maverick, i386, instance store
InstanceType: "t1.micro", // Doesn't work with micro, results in 400.
}
resp, err := s.ec2.RunInstances(&options)
c.Assert(resp, IsNil)
c.Assert(err, ErrorMatches, "AMI.*root device.*not supported.*")
ec2err, ok := err.(*ec2.Error)
c.Assert(ok, Equals, true)
c.Assert(ec2err.StatusCode, Equals, 400)
c.Assert(ec2err.Code, Equals, "UnsupportedOperation")
c.Assert(ec2err.Message, Matches, "AMI.*root device.*not supported.*")
c.Assert(ec2err.RequestId, Matches, ".+")
}
// Cost: 0.02 USD
func (s *ClientTests) TestRunAndTerminate(c *C) {
options := ec2.RunInstances{
ImageId: imageId,
InstanceType: "t1.micro",
}
resp1, err := s.ec2.RunInstances(&options)
c.Assert(err, IsNil)
c.Check(resp1.ReservationId, Matches, "r-[0-9a-f]*")
c.Check(resp1.OwnerId, Matches, "[0-9]+")
c.Check(resp1.Instances, HasLen, 1)
c.Check(resp1.Instances[0].InstanceType, Equals, "t1.micro")
instId := resp1.Instances[0].InstanceId
resp2, err := s.ec2.Instances([]string{instId}, nil)
c.Assert(err, IsNil)
if c.Check(resp2.Reservations, HasLen, 1) && c.Check(len(resp2.Reservations[0].Instances), Equals, 1) {
inst := resp2.Reservations[0].Instances[0]
c.Check(inst.InstanceId, Equals, instId)
}
resp3, err := s.ec2.TerminateInstances([]string{instId})
c.Assert(err, IsNil)
c.Check(resp3.StateChanges, HasLen, 1)
c.Check(resp3.StateChanges[0].InstanceId, Equals, instId)
c.Check(resp3.StateChanges[0].CurrentState.Name, Equals, "shutting-down")
c.Check(resp3.StateChanges[0].CurrentState.Code, Equals, 32)
}
// Cost: 0.00 USD
func (s *ClientTests) TestSecurityGroups(c *C) {
name := "goamz-test"
descr := "goamz security group for tests"
// Clean it up, if a previous test left it around and avoid leaving it around.
s.ec2.DeleteSecurityGroup(ec2.SecurityGroup{Name: name})
defer s.ec2.DeleteSecurityGroup(ec2.SecurityGroup{Name: name})
resp1, err := s.ec2.CreateSecurityGroup(ec2.SecurityGroup{Name: name, Description: descr})
c.Assert(err, IsNil)
c.Assert(resp1.RequestId, Matches, ".+")
c.Assert(resp1.Name, Equals, name)
c.Assert(resp1.Id, Matches, ".+")
resp1, err = s.ec2.CreateSecurityGroup(ec2.SecurityGroup{Name: name, Description: descr})
ec2err, _ := err.(*ec2.Error)
c.Assert(resp1, IsNil)
c.Assert(ec2err, NotNil)
c.Assert(ec2err.Code, Equals, "InvalidGroup.Duplicate")
perms := []ec2.IPPerm{{
Protocol: "tcp",
FromPort: 0,
ToPort: 1024,
SourceIPs: []string{"127.0.0.1/24"},
}}
resp2, err := s.ec2.AuthorizeSecurityGroup(ec2.SecurityGroup{Name: name}, perms)
c.Assert(err, IsNil)
c.Assert(resp2.RequestId, Matches, ".+")
resp3, err := s.ec2.SecurityGroups(ec2.SecurityGroupNames(name), nil)
c.Assert(err, IsNil)
c.Assert(resp3.RequestId, Matches, ".+")
c.Assert(resp3.Groups, HasLen, 1)
g0 := resp3.Groups[0]
c.Assert(g0.Name, Equals, name)
c.Assert(g0.Description, Equals, descr)
c.Assert(g0.IPPerms, HasLen, 1)
c.Assert(g0.IPPerms[0].Protocol, Equals, "tcp")
c.Assert(g0.IPPerms[0].FromPort, Equals, 0)
c.Assert(g0.IPPerms[0].ToPort, Equals, 1024)
c.Assert(g0.IPPerms[0].SourceIPs, DeepEquals, []string{"127.0.0.1/24"})
resp2, err = s.ec2.DeleteSecurityGroup(ec2.SecurityGroup{Name: name})
c.Assert(err, IsNil)
c.Assert(resp2.RequestId, Matches, ".+")
}
var sessionId = func() string {
buf := make([]byte, 8)
// if we have no randomness, we'll just make do, so ignore the error.
rand.Read(buf)
return fmt.Sprintf("%x", buf)
}()
// sessionName reutrns a name that is probably
// unique to this test session.
func sessionName(prefix string) string {
return prefix + "-" + sessionId
}
var allRegions = []aws.Region{
aws.USEast,
aws.USWest,
aws.EUWest,
aws.EUCentral,
aws.APSoutheast,
aws.APNortheast,
}
// Communicate with all EC2 endpoints to see if they are alive.
func (s *ClientTests) TestRegions(c *C) {
name := sessionName("goamz-region-test")
perms := []ec2.IPPerm{{
Protocol: "tcp",
FromPort: 80,
ToPort: 80,
SourceIPs: []string{"127.0.0.1/32"},
}}
errs := make(chan error, len(allRegions))
for _, region := range allRegions {
go func(r aws.Region) {
e := ec2.NewWithClient(s.ec2.Auth, r, testutil.DefaultClient)
_, err := e.AuthorizeSecurityGroup(ec2.SecurityGroup{Name: name}, perms)
errs <- err
}(region)
}
for _ = range allRegions {
err := <-errs
if err != nil {
ec2_err, ok := err.(*ec2.Error)
if ok {
c.Check(ec2_err.Code, Matches, "InvalidGroup.NotFound")
} else {
c.Errorf("Non-EC2 error: %s", err)
}
} else {
c.Errorf("Test should have errored but it seems to have succeeded")
}
}
}

View File

@@ -1,580 +0,0 @@
package ec2_test
import (
"fmt"
"github.com/mitchellh/goamz/aws"
"github.com/mitchellh/goamz/ec2"
"github.com/mitchellh/goamz/ec2/ec2test"
"github.com/mitchellh/goamz/testutil"
. "github.com/motain/gocheck"
"regexp"
"sort"
)
// LocalServer represents a local ec2test fake server.
type LocalServer struct {
auth aws.Auth
region aws.Region
srv *ec2test.Server
}
func (s *LocalServer) SetUp(c *C) {
srv, err := ec2test.NewServer()
c.Assert(err, IsNil)
c.Assert(srv, NotNil)
s.srv = srv
s.region = aws.Region{EC2Endpoint: srv.URL()}
}
// LocalServerSuite defines tests that will run
// against the local ec2test server. It includes
// selected tests from ClientTests;
// when the ec2test functionality is sufficient, it should
// include all of them, and ClientTests can be simply embedded.
type LocalServerSuite struct {
srv LocalServer
ServerTests
clientTests ClientTests
}
var _ = Suite(&LocalServerSuite{})
func (s *LocalServerSuite) SetUpSuite(c *C) {
s.srv.SetUp(c)
s.ServerTests.ec2 = ec2.NewWithClient(s.srv.auth, s.srv.region, testutil.DefaultClient)
s.clientTests.ec2 = ec2.NewWithClient(s.srv.auth, s.srv.region, testutil.DefaultClient)
}
func (s *LocalServerSuite) TestRunAndTerminate(c *C) {
s.clientTests.TestRunAndTerminate(c)
}
func (s *LocalServerSuite) TestSecurityGroups(c *C) {
s.clientTests.TestSecurityGroups(c)
}
// TestUserData is not defined on ServerTests because it
// requires the ec2test server to function.
func (s *LocalServerSuite) TestUserData(c *C) {
data := make([]byte, 256)
for i := range data {
data[i] = byte(i)
}
inst, err := s.ec2.RunInstances(&ec2.RunInstances{
ImageId: imageId,
InstanceType: "t1.micro",
UserData: data,
})
c.Assert(err, IsNil)
c.Assert(inst, NotNil)
c.Assert(inst.Instances[0].DNSName, Equals, inst.Instances[0].InstanceId+".example.com")
id := inst.Instances[0].InstanceId
defer s.ec2.TerminateInstances([]string{id})
tinst := s.srv.srv.Instance(id)
c.Assert(tinst, NotNil)
c.Assert(tinst.UserData, DeepEquals, data)
}
// AmazonServerSuite runs the ec2test server tests against a live EC2 server.
// It will only be activated if the -all flag is specified.
type AmazonServerSuite struct {
srv AmazonServer
ServerTests
}
var _ = Suite(&AmazonServerSuite{})
func (s *AmazonServerSuite) SetUpSuite(c *C) {
if !testutil.Amazon {
c.Skip("AmazonServerSuite tests not enabled")
}
s.srv.SetUp(c)
s.ServerTests.ec2 = ec2.NewWithClient(s.srv.auth, aws.USEast, testutil.DefaultClient)
}
// ServerTests defines a set of tests designed to test
// the ec2test local fake ec2 server.
// It is not used as a test suite in itself, but embedded within
// another type.
type ServerTests struct {
ec2 *ec2.EC2
}
func terminateInstances(c *C, e *ec2.EC2, insts []*ec2.Instance) {
var ids []string
for _, inst := range insts {
if inst != nil {
ids = append(ids, inst.InstanceId)
}
}
_, err := e.TerminateInstances(ids)
c.Check(err, IsNil, Commentf("%d INSTANCES LEFT RUNNING!!!", len(ids)))
}
func (s *ServerTests) makeTestGroup(c *C, name, descr string) ec2.SecurityGroup {
// Clean it up if a previous test left it around.
_, err := s.ec2.DeleteSecurityGroup(ec2.SecurityGroup{Name: name})
if err != nil && err.(*ec2.Error).Code != "InvalidGroup.NotFound" {
c.Fatalf("delete security group: %v", err)
}
resp, err := s.ec2.CreateSecurityGroup(ec2.SecurityGroup{Name: name, Description: descr})
c.Assert(err, IsNil)
c.Assert(resp.Name, Equals, name)
return resp.SecurityGroup
}
func (s *ServerTests) TestIPPerms(c *C) {
g0 := s.makeTestGroup(c, "goamz-test0", "ec2test group 0")
defer s.ec2.DeleteSecurityGroup(g0)
g1 := s.makeTestGroup(c, "goamz-test1", "ec2test group 1")
defer s.ec2.DeleteSecurityGroup(g1)
resp, err := s.ec2.SecurityGroups([]ec2.SecurityGroup{g0, g1}, nil)
c.Assert(err, IsNil)
c.Assert(resp.Groups, HasLen, 2)
c.Assert(resp.Groups[0].IPPerms, HasLen, 0)
c.Assert(resp.Groups[1].IPPerms, HasLen, 0)
ownerId := resp.Groups[0].OwnerId
// test some invalid parameters
// TODO more
_, err = s.ec2.AuthorizeSecurityGroup(g0, []ec2.IPPerm{{
Protocol: "tcp",
FromPort: 0,
ToPort: 1024,
SourceIPs: []string{"z127.0.0.1/24"},
}})
c.Assert(err, NotNil)
c.Check(err.(*ec2.Error).Code, Equals, "InvalidPermission.Malformed")
// Check that AuthorizeSecurityGroup adds the correct authorizations.
_, err = s.ec2.AuthorizeSecurityGroup(g0, []ec2.IPPerm{{
Protocol: "tcp",
FromPort: 2000,
ToPort: 2001,
SourceIPs: []string{"127.0.0.0/24"},
SourceGroups: []ec2.UserSecurityGroup{{
Name: g1.Name,
}, {
Id: g0.Id,
}},
}, {
Protocol: "tcp",
FromPort: 2000,
ToPort: 2001,
SourceIPs: []string{"200.1.1.34/32"},
}})
c.Assert(err, IsNil)
resp, err = s.ec2.SecurityGroups([]ec2.SecurityGroup{g0}, nil)
c.Assert(err, IsNil)
c.Assert(resp.Groups, HasLen, 1)
c.Assert(resp.Groups[0].IPPerms, HasLen, 1)
perm := resp.Groups[0].IPPerms[0]
srcg := perm.SourceGroups
c.Assert(srcg, HasLen, 2)
// Normalize so we don't care about returned order.
if srcg[0].Name == g1.Name {
srcg[0], srcg[1] = srcg[1], srcg[0]
}
c.Check(srcg[0].Name, Equals, g0.Name)
c.Check(srcg[0].Id, Equals, g0.Id)
c.Check(srcg[0].OwnerId, Equals, ownerId)
c.Check(srcg[1].Name, Equals, g1.Name)
c.Check(srcg[1].Id, Equals, g1.Id)
c.Check(srcg[1].OwnerId, Equals, ownerId)
sort.Strings(perm.SourceIPs)
c.Check(perm.SourceIPs, DeepEquals, []string{"127.0.0.0/24", "200.1.1.34/32"})
// Check that we can't delete g1 (because g0 is using it)
_, err = s.ec2.DeleteSecurityGroup(g1)
c.Assert(err, NotNil)
c.Check(err.(*ec2.Error).Code, Equals, "InvalidGroup.InUse")
_, err = s.ec2.RevokeSecurityGroup(g0, []ec2.IPPerm{{
Protocol: "tcp",
FromPort: 2000,
ToPort: 2001,
SourceGroups: []ec2.UserSecurityGroup{{Id: g1.Id}},
}, {
Protocol: "tcp",
FromPort: 2000,
ToPort: 2001,
SourceIPs: []string{"200.1.1.34/32"},
}})
c.Assert(err, IsNil)
resp, err = s.ec2.SecurityGroups([]ec2.SecurityGroup{g0}, nil)
c.Assert(err, IsNil)
c.Assert(resp.Groups, HasLen, 1)
c.Assert(resp.Groups[0].IPPerms, HasLen, 1)
perm = resp.Groups[0].IPPerms[0]
srcg = perm.SourceGroups
c.Assert(srcg, HasLen, 1)
c.Check(srcg[0].Name, Equals, g0.Name)
c.Check(srcg[0].Id, Equals, g0.Id)
c.Check(srcg[0].OwnerId, Equals, ownerId)
c.Check(perm.SourceIPs, DeepEquals, []string{"127.0.0.0/24"})
// We should be able to delete g1 now because we've removed its only use.
_, err = s.ec2.DeleteSecurityGroup(g1)
c.Assert(err, IsNil)
_, err = s.ec2.DeleteSecurityGroup(g0)
c.Assert(err, IsNil)
f := ec2.NewFilter()
f.Add("group-id", g0.Id, g1.Id)
resp, err = s.ec2.SecurityGroups(nil, f)
c.Assert(err, IsNil)
c.Assert(resp.Groups, HasLen, 0)
}
func (s *ServerTests) TestDuplicateIPPerm(c *C) {
name := "goamz-test"
descr := "goamz security group for tests"
// Clean it up, if a previous test left it around and avoid leaving it around.
s.ec2.DeleteSecurityGroup(ec2.SecurityGroup{Name: name})
defer s.ec2.DeleteSecurityGroup(ec2.SecurityGroup{Name: name})
resp1, err := s.ec2.CreateSecurityGroup(ec2.SecurityGroup{Name: name, Description: descr})
c.Assert(err, IsNil)
c.Assert(resp1.Name, Equals, name)
perms := []ec2.IPPerm{{
Protocol: "tcp",
FromPort: 200,
ToPort: 1024,
SourceIPs: []string{"127.0.0.1/24"},
}, {
Protocol: "tcp",
FromPort: 0,
ToPort: 100,
SourceIPs: []string{"127.0.0.1/24"},
}}
_, err = s.ec2.AuthorizeSecurityGroup(ec2.SecurityGroup{Name: name}, perms[0:1])
c.Assert(err, IsNil)
_, err = s.ec2.AuthorizeSecurityGroup(ec2.SecurityGroup{Name: name}, perms[0:2])
c.Assert(err, ErrorMatches, `.*\(InvalidPermission.Duplicate\)`)
}
type filterSpec struct {
name string
values []string
}
func (s *ServerTests) TestInstanceFiltering(c *C) {
groupResp, err := s.ec2.CreateSecurityGroup(ec2.SecurityGroup{Name: sessionName("testgroup1"), Description: "testgroup one description"})
c.Assert(err, IsNil)
group1 := groupResp.SecurityGroup
defer s.ec2.DeleteSecurityGroup(group1)
groupResp, err = s.ec2.CreateSecurityGroup(ec2.SecurityGroup{Name: sessionName("testgroup2"), Description: "testgroup two description"})
c.Assert(err, IsNil)
group2 := groupResp.SecurityGroup
defer s.ec2.DeleteSecurityGroup(group2)
insts := make([]*ec2.Instance, 3)
inst, err := s.ec2.RunInstances(&ec2.RunInstances{
MinCount: 2,
ImageId: imageId,
InstanceType: "t1.micro",
SecurityGroups: []ec2.SecurityGroup{group1},
})
c.Assert(err, IsNil)
insts[0] = &inst.Instances[0]
insts[1] = &inst.Instances[1]
defer terminateInstances(c, s.ec2, insts)
imageId2 := "ami-e358958a" // Natty server, i386, EBS store
inst, err = s.ec2.RunInstances(&ec2.RunInstances{
ImageId: imageId2,
InstanceType: "t1.micro",
SecurityGroups: []ec2.SecurityGroup{group2},
})
c.Assert(err, IsNil)
insts[2] = &inst.Instances[0]
ids := func(indices ...int) (instIds []string) {
for _, index := range indices {
instIds = append(instIds, insts[index].InstanceId)
}
return
}
tests := []struct {
about string
instanceIds []string // instanceIds argument to Instances method.
filters []filterSpec // filters argument to Instances method.
resultIds []string // set of instance ids of expected results.
allowExtra bool // resultIds may be incomplete.
err string // expected error.
}{
{
about: "check that Instances returns all instances",
resultIds: ids(0, 1, 2),
allowExtra: true,
}, {
about: "check that specifying two instance ids returns them",
instanceIds: ids(0, 2),
resultIds: ids(0, 2),
}, {
about: "check that specifying a non-existent instance id gives an error",
instanceIds: append(ids(0), "i-deadbeef"),
err: `.*\(InvalidInstanceID\.NotFound\)`,
}, {
about: "check that a filter allowed both instances returns both of them",
filters: []filterSpec{
{"instance-id", ids(0, 2)},
},
resultIds: ids(0, 2),
}, {
about: "check that a filter allowing only one instance returns it",
filters: []filterSpec{
{"instance-id", ids(1)},
},
resultIds: ids(1),
}, {
about: "check that a filter allowing no instances returns none",
filters: []filterSpec{
{"instance-id", []string{"i-deadbeef12345"}},
},
}, {
about: "check that filtering on group id works",
filters: []filterSpec{
{"group-id", []string{group1.Id}},
},
resultIds: ids(0, 1),
}, {
about: "check that filtering on group name works",
filters: []filterSpec{
{"group-name", []string{group1.Name}},
},
resultIds: ids(0, 1),
}, {
about: "check that filtering on image id works",
filters: []filterSpec{
{"image-id", []string{imageId}},
},
resultIds: ids(0, 1),
allowExtra: true,
}, {
about: "combination filters 1",
filters: []filterSpec{
{"image-id", []string{imageId, imageId2}},
{"group-name", []string{group1.Name}},
},
resultIds: ids(0, 1),
}, {
about: "combination filters 2",
filters: []filterSpec{
{"image-id", []string{imageId2}},
{"group-name", []string{group1.Name}},
},
},
}
for i, t := range tests {
c.Logf("%d. %s", i, t.about)
var f *ec2.Filter
if t.filters != nil {
f = ec2.NewFilter()
for _, spec := range t.filters {
f.Add(spec.name, spec.values...)
}
}
resp, err := s.ec2.Instances(t.instanceIds, f)
if t.err != "" {
c.Check(err, ErrorMatches, t.err)
continue
}
c.Assert(err, IsNil)
insts := make(map[string]*ec2.Instance)
for _, r := range resp.Reservations {
for j := range r.Instances {
inst := &r.Instances[j]
c.Check(insts[inst.InstanceId], IsNil, Commentf("duplicate instance id: %q", inst.InstanceId))
insts[inst.InstanceId] = inst
}
}
if !t.allowExtra {
c.Check(insts, HasLen, len(t.resultIds), Commentf("expected %d instances got %#v", len(t.resultIds), insts))
}
for j, id := range t.resultIds {
c.Check(insts[id], NotNil, Commentf("instance id %d (%q) not found; got %#v", j, id, insts))
}
}
}
func idsOnly(gs []ec2.SecurityGroup) []ec2.SecurityGroup {
for i := range gs {
gs[i].Name = ""
}
return gs
}
func namesOnly(gs []ec2.SecurityGroup) []ec2.SecurityGroup {
for i := range gs {
gs[i].Id = ""
}
return gs
}
func (s *ServerTests) TestGroupFiltering(c *C) {
g := make([]ec2.SecurityGroup, 4)
for i := range g {
resp, err := s.ec2.CreateSecurityGroup(ec2.SecurityGroup{Name: sessionName(fmt.Sprintf("testgroup%d", i)), Description: fmt.Sprintf("testdescription%d", i)})
c.Assert(err, IsNil)
g[i] = resp.SecurityGroup
c.Logf("group %d: %v", i, g[i])
defer s.ec2.DeleteSecurityGroup(g[i])
}
perms := [][]ec2.IPPerm{
{{
Protocol: "tcp",
FromPort: 100,
ToPort: 200,
SourceIPs: []string{"1.2.3.4/32"},
}},
{{
Protocol: "tcp",
FromPort: 200,
ToPort: 300,
SourceGroups: []ec2.UserSecurityGroup{{Id: g[1].Id}},
}},
{{
Protocol: "udp",
FromPort: 200,
ToPort: 400,
SourceGroups: []ec2.UserSecurityGroup{{Id: g[1].Id}},
}},
}
for i, ps := range perms {
_, err := s.ec2.AuthorizeSecurityGroup(g[i], ps)
c.Assert(err, IsNil)
}
groups := func(indices ...int) (gs []ec2.SecurityGroup) {
for _, index := range indices {
gs = append(gs, g[index])
}
return
}
type groupTest struct {
about string
groups []ec2.SecurityGroup // groupIds argument to SecurityGroups method.
filters []filterSpec // filters argument to SecurityGroups method.
results []ec2.SecurityGroup // set of expected result groups.
allowExtra bool // specified results may be incomplete.
err string // expected error.
}
filterCheck := func(name, val string, gs []ec2.SecurityGroup) groupTest {
return groupTest{
about: "filter check " + name,
filters: []filterSpec{{name, []string{val}}},
results: gs,
allowExtra: true,
}
}
tests := []groupTest{
{
about: "check that SecurityGroups returns all groups",
results: groups(0, 1, 2, 3),
allowExtra: true,
}, {
about: "check that specifying two group ids returns them",
groups: idsOnly(groups(0, 2)),
results: groups(0, 2),
}, {
about: "check that specifying names only works",
groups: namesOnly(groups(0, 2)),
results: groups(0, 2),
}, {
about: "check that specifying a non-existent group id gives an error",
groups: append(groups(0), ec2.SecurityGroup{Id: "sg-eeeeeeeee"}),
err: `.*\(InvalidGroup\.NotFound\)`,
}, {
about: "check that a filter allowed two groups returns both of them",
filters: []filterSpec{
{"group-id", []string{g[0].Id, g[2].Id}},
},
results: groups(0, 2),
},
{
about: "check that the previous filter works when specifying a list of ids",
groups: groups(1, 2),
filters: []filterSpec{
{"group-id", []string{g[0].Id, g[2].Id}},
},
results: groups(2),
}, {
about: "check that a filter allowing no groups returns none",
filters: []filterSpec{
{"group-id", []string{"sg-eeeeeeeee"}},
},
},
filterCheck("description", "testdescription1", groups(1)),
filterCheck("group-name", g[2].Name, groups(2)),
filterCheck("ip-permission.cidr", "1.2.3.4/32", groups(0)),
filterCheck("ip-permission.group-name", g[1].Name, groups(1, 2)),
filterCheck("ip-permission.protocol", "udp", groups(2)),
filterCheck("ip-permission.from-port", "200", groups(1, 2)),
filterCheck("ip-permission.to-port", "200", groups(0)),
// TODO owner-id
}
for i, t := range tests {
c.Logf("%d. %s", i, t.about)
var f *ec2.Filter
if t.filters != nil {
f = ec2.NewFilter()
for _, spec := range t.filters {
f.Add(spec.name, spec.values...)
}
}
resp, err := s.ec2.SecurityGroups(t.groups, f)
if t.err != "" {
c.Check(err, ErrorMatches, t.err)
continue
}
c.Assert(err, IsNil)
groups := make(map[string]*ec2.SecurityGroup)
for j := range resp.Groups {
group := &resp.Groups[j].SecurityGroup
c.Check(groups[group.Id], IsNil, Commentf("duplicate group id: %q", group.Id))
groups[group.Id] = group
}
// If extra groups may be returned, eliminate all groups that
// we did not create in this session apart from the default group.
if t.allowExtra {
namePat := regexp.MustCompile(sessionName("testgroup[0-9]"))
for id, g := range groups {
if !namePat.MatchString(g.Name) {
delete(groups, id)
}
}
}
c.Check(groups, HasLen, len(t.results))
for j, g := range t.results {
rg := groups[g.Id]
c.Assert(rg, NotNil, Commentf("group %d (%v) not found; got %#v", j, g, groups))
c.Check(rg.Name, Equals, g.Name, Commentf("group %d (%v)", j, g))
}
}
}

View File

@@ -1,84 +0,0 @@
package ec2test
import (
"fmt"
"net/url"
"strings"
)
// filter holds an ec2 filter. A filter maps an attribute to a set of
// possible values for that attribute. For an item to pass through the
// filter, every attribute of the item mentioned in the filter must match
// at least one of its given values.
type filter map[string][]string
// newFilter creates a new filter from the Filter fields in the url form.
//
// The filtering is specified through a map of name=>values, where the
// name is a well-defined key identifying the data to be matched,
// and the list of values holds the possible values the filtered
// item can take for the key to be included in the
// result set. For example:
//
// Filter.1.Name=instance-type
// Filter.1.Value.1=m1.small
// Filter.1.Value.2=m1.large
//
func newFilter(form url.Values) filter {
// TODO return an error if the fields are not well formed?
names := make(map[int]string)
values := make(map[int][]string)
maxId := 0
for name, fvalues := range form {
var rest string
var id int
if x, _ := fmt.Sscanf(name, "Filter.%d.%s", &id, &rest); x != 2 {
continue
}
if id > maxId {
maxId = id
}
if rest == "Name" {
names[id] = fvalues[0]
continue
}
if !strings.HasPrefix(rest, "Value.") {
continue
}
values[id] = append(values[id], fvalues[0])
}
f := make(filter)
for id, name := range names {
f[name] = values[id]
}
return f
}
func notDigit(r rune) bool {
return r < '0' || r > '9'
}
// filterable represents an object that can be passed through a filter.
type filterable interface {
// matchAttr returns true if given attribute of the
// object matches value. It returns an error if the
// attribute is not recognised or the value is malformed.
matchAttr(attr, value string) (bool, error)
}
// ok returns true if x passes through the filter.
func (f filter) ok(x filterable) (bool, error) {
next:
for a, vs := range f {
for _, v := range vs {
if ok, err := x.matchAttr(a, v); ok {
continue next
} else if err != nil {
return false, fmt.Errorf("bad attribute or value %q=%q for type %T: %v", a, v, x, err)
}
}
return false, nil
}
return true, nil
}

View File

@@ -1,993 +0,0 @@
// The ec2test package implements a fake EC2 provider with
// the capability of inducing errors on any given operation,
// and retrospectively determining what operations have been
// carried out.
package ec2test
import (
"encoding/base64"
"encoding/xml"
"fmt"
"github.com/mitchellh/goamz/ec2"
"io"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"sync"
)
var b64 = base64.StdEncoding
// Action represents a request that changes the ec2 state.
type Action struct {
RequestId string
// Request holds the requested action as a url.Values instance
Request url.Values
// If the action succeeded, Response holds the value that
// was marshalled to build the XML response for the request.
Response interface{}
// If the action failed, Err holds an error giving details of the failure.
Err *ec2.Error
}
// TODO possible other things:
// - some virtual time stamp interface, so a client
// can ask for all actions after a certain virtual time.
// Server implements an EC2 simulator for use in testing.
type Server struct {
url string
listener net.Listener
mu sync.Mutex
reqs []*Action
instances map[string]*Instance // id -> instance
reservations map[string]*reservation // id -> reservation
groups map[string]*securityGroup // id -> group
maxId counter
reqId counter
reservationId counter
groupId counter
initialInstanceState ec2.InstanceState
}
// reservation holds a simulated ec2 reservation.
type reservation struct {
id string
instances map[string]*Instance
groups []*securityGroup
}
// instance holds a simulated ec2 instance
type Instance struct {
// UserData holds the data that was passed to the RunInstances request
// when the instance was started.
UserData []byte
id string
imageId string
reservation *reservation
instType string
state ec2.InstanceState
}
// permKey represents permission for a given security
// group or IP address (but not both) to access a given range of
// ports. Equality of permKeys is used in the implementation of
// permission sets, relying on the uniqueness of securityGroup
// instances.
type permKey struct {
protocol string
fromPort int
toPort int
group *securityGroup
ipAddr string
}
// securityGroup holds a simulated ec2 security group.
// Instances of securityGroup should only be created through
// Server.createSecurityGroup to ensure that groups can be
// compared by pointer value.
type securityGroup struct {
id string
name string
description string
perms map[permKey]bool
}
func (g *securityGroup) ec2SecurityGroup() ec2.SecurityGroup {
return ec2.SecurityGroup{
Name: g.name,
Id: g.id,
}
}
func (g *securityGroup) matchAttr(attr, value string) (ok bool, err error) {
switch attr {
case "description":
return g.description == value, nil
case "group-id":
return g.id == value, nil
case "group-name":
return g.name == value, nil
case "ip-permission.cidr":
return g.hasPerm(func(k permKey) bool { return k.ipAddr == value }), nil
case "ip-permission.group-name":
return g.hasPerm(func(k permKey) bool {
return k.group != nil && k.group.name == value
}), nil
case "ip-permission.from-port":
port, err := strconv.Atoi(value)
if err != nil {
return false, err
}
return g.hasPerm(func(k permKey) bool { return k.fromPort == port }), nil
case "ip-permission.to-port":
port, err := strconv.Atoi(value)
if err != nil {
return false, err
}
return g.hasPerm(func(k permKey) bool { return k.toPort == port }), nil
case "ip-permission.protocol":
return g.hasPerm(func(k permKey) bool { return k.protocol == value }), nil
case "owner-id":
return value == ownerId, nil
}
return false, fmt.Errorf("unknown attribute %q", attr)
}
func (g *securityGroup) hasPerm(test func(k permKey) bool) bool {
for k := range g.perms {
if test(k) {
return true
}
}
return false
}
// ec2Perms returns the list of EC2 permissions granted
// to g. It groups permissions by port range and protocol.
func (g *securityGroup) ec2Perms() (perms []ec2.IPPerm) {
// The grouping is held in result. We use permKey for convenience,
// (ensuring that the group and ipAddr of each key is zero). For
// each protocol/port range combination, we build up the permission
// set in the associated value.
result := make(map[permKey]*ec2.IPPerm)
for k := range g.perms {
groupKey := k
groupKey.group = nil
groupKey.ipAddr = ""
ec2p := result[groupKey]
if ec2p == nil {
ec2p = &ec2.IPPerm{
Protocol: k.protocol,
FromPort: k.fromPort,
ToPort: k.toPort,
}
result[groupKey] = ec2p
}
if k.group != nil {
ec2p.SourceGroups = append(ec2p.SourceGroups,
ec2.UserSecurityGroup{
Id: k.group.id,
Name: k.group.name,
OwnerId: ownerId,
})
} else {
ec2p.SourceIPs = append(ec2p.SourceIPs, k.ipAddr)
}
}
for _, ec2p := range result {
perms = append(perms, *ec2p)
}
return
}
var actions = map[string]func(*Server, http.ResponseWriter, *http.Request, string) interface{}{
"RunInstances": (*Server).runInstances,
"TerminateInstances": (*Server).terminateInstances,
"DescribeInstances": (*Server).describeInstances,
"CreateSecurityGroup": (*Server).createSecurityGroup,
"DescribeSecurityGroups": (*Server).describeSecurityGroups,
"DeleteSecurityGroup": (*Server).deleteSecurityGroup,
"AuthorizeSecurityGroupIngress": (*Server).authorizeSecurityGroupIngress,
"RevokeSecurityGroupIngress": (*Server).revokeSecurityGroupIngress,
}
const ownerId = "9876"
// newAction allocates a new action and adds it to the
// recorded list of server actions.
func (srv *Server) newAction() *Action {
srv.mu.Lock()
defer srv.mu.Unlock()
a := new(Action)
srv.reqs = append(srv.reqs, a)
return a
}
// NewServer returns a new server.
func NewServer() (*Server, error) {
srv := &Server{
instances: make(map[string]*Instance),
groups: make(map[string]*securityGroup),
reservations: make(map[string]*reservation),
initialInstanceState: Pending,
}
// Add default security group.
g := &securityGroup{
name: "default",
description: "default group",
id: fmt.Sprintf("sg-%d", srv.groupId.next()),
}
g.perms = map[permKey]bool{
permKey{
protocol: "icmp",
fromPort: -1,
toPort: -1,
group: g,
}: true,
permKey{
protocol: "tcp",
fromPort: 0,
toPort: 65535,
group: g,
}: true,
permKey{
protocol: "udp",
fromPort: 0,
toPort: 65535,
group: g,
}: true,
}
srv.groups[g.id] = g
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("cannot listen on localhost: %v", err)
}
srv.listener = l
srv.url = "http://" + l.Addr().String()
// we use HandlerFunc rather than *Server directly so that we
// can avoid exporting HandlerFunc from *Server.
go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
srv.serveHTTP(w, req)
}))
return srv, nil
}
// Quit closes down the server.
func (srv *Server) Quit() {
srv.listener.Close()
}
// SetInitialInstanceState sets the state that any new instances will be started in.
func (srv *Server) SetInitialInstanceState(state ec2.InstanceState) {
srv.mu.Lock()
srv.initialInstanceState = state
srv.mu.Unlock()
}
// URL returns the URL of the server.
func (srv *Server) URL() string {
return srv.url
}
// serveHTTP serves the EC2 protocol.
func (srv *Server) serveHTTP(w http.ResponseWriter, req *http.Request) {
req.ParseForm()
a := srv.newAction()
a.RequestId = fmt.Sprintf("req%d", srv.reqId.next())
a.Request = req.Form
// Methods on Server that deal with parsing user data
// may fail. To save on error handling code, we allow these
// methods to call fatalf, which will panic with an *ec2.Error
// which will be caught here and returned
// to the client as a properly formed EC2 error.
defer func() {
switch err := recover().(type) {
case *ec2.Error:
a.Err = err
err.RequestId = a.RequestId
writeError(w, err)
case nil:
default:
panic(err)
}
}()
f := actions[req.Form.Get("Action")]
if f == nil {
fatalf(400, "InvalidParameterValue", "Unrecognized Action")
}
response := f(srv, w, req, a.RequestId)
a.Response = response
w.Header().Set("Content-Type", `xml version="1.0" encoding="UTF-8"`)
xmlMarshal(w, response)
}
// Instance returns the instance for the given instance id.
// It returns nil if there is no such instance.
func (srv *Server) Instance(id string) *Instance {
srv.mu.Lock()
defer srv.mu.Unlock()
return srv.instances[id]
}
// writeError writes an appropriate error response.
// TODO how should we deal with errors when the
// error itself is potentially generated by backend-agnostic
// code?
func writeError(w http.ResponseWriter, err *ec2.Error) {
// Error encapsulates an error returned by EC2.
// TODO merge with ec2.Error when xml supports ignoring a field.
type ec2error struct {
Code string // EC2 error code ("UnsupportedOperation", ...)
Message string // The human-oriented error message
RequestId string
}
type Response struct {
RequestId string
Errors []ec2error `xml:"Errors>Error"`
}
w.Header().Set("Content-Type", `xml version="1.0" encoding="UTF-8"`)
w.WriteHeader(err.StatusCode)
xmlMarshal(w, Response{
RequestId: err.RequestId,
Errors: []ec2error{{
Code: err.Code,
Message: err.Message,
}},
})
}
// xmlMarshal is the same as xml.Marshal except that
// it panics on error. The marshalling should not fail,
// but we want to know if it does.
func xmlMarshal(w io.Writer, x interface{}) {
if err := xml.NewEncoder(w).Encode(x); err != nil {
panic(fmt.Errorf("error marshalling %#v: %v", x, err))
}
}
// formToGroups parses a set of SecurityGroup form values
// as found in a RunInstances request, and returns the resulting
// slice of security groups.
// It calls fatalf if a group is not found.
func (srv *Server) formToGroups(form url.Values) []*securityGroup {
var groups []*securityGroup
for name, values := range form {
switch {
case strings.HasPrefix(name, "SecurityGroupId."):
if g := srv.groups[values[0]]; g != nil {
groups = append(groups, g)
} else {
fatalf(400, "InvalidGroup.NotFound", "unknown group id %q", values[0])
}
case strings.HasPrefix(name, "SecurityGroup."):
var found *securityGroup
for _, g := range srv.groups {
if g.name == values[0] {
found = g
}
}
if found == nil {
fatalf(400, "InvalidGroup.NotFound", "unknown group name %q", values[0])
}
groups = append(groups, found)
}
}
return groups
}
// runInstances implements the EC2 RunInstances entry point.
func (srv *Server) runInstances(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
min := atoi(req.Form.Get("MinCount"))
max := atoi(req.Form.Get("MaxCount"))
if min < 0 || max < 1 {
fatalf(400, "InvalidParameterValue", "bad values for MinCount or MaxCount")
}
if min > max {
fatalf(400, "InvalidParameterCombination", "MinCount is greater than MaxCount")
}
var userData []byte
if data := req.Form.Get("UserData"); data != "" {
var err error
userData, err = b64.DecodeString(data)
if err != nil {
fatalf(400, "InvalidParameterValue", "bad UserData value: %v", err)
}
}
// TODO attributes still to consider:
// ImageId: accept anything, we can verify later
// KeyName ?
// InstanceType ?
// KernelId ?
// RamdiskId ?
// AvailZone ?
// GroupName tag
// Monitoring ignore?
// SubnetId ?
// DisableAPITermination bool
// ShutdownBehavior string
// PrivateIPAddress string
srv.mu.Lock()
defer srv.mu.Unlock()
// make sure that form fields are correct before creating the reservation.
instType := req.Form.Get("InstanceType")
imageId := req.Form.Get("ImageId")
r := srv.newReservation(srv.formToGroups(req.Form))
var resp ec2.RunInstancesResp
resp.RequestId = reqId
resp.ReservationId = r.id
resp.OwnerId = ownerId
for i := 0; i < max; i++ {
inst := srv.newInstance(r, instType, imageId, srv.initialInstanceState)
inst.UserData = userData
resp.Instances = append(resp.Instances, inst.ec2instance())
}
return &resp
}
func (srv *Server) group(group ec2.SecurityGroup) *securityGroup {
if group.Id != "" {
return srv.groups[group.Id]
}
for _, g := range srv.groups {
if g.name == group.Name {
return g
}
}
return nil
}
// NewInstances creates n new instances in srv with the given instance type,
// image ID, initial state and security groups. If any group does not already
// exist, it will be created. NewInstances returns the ids of the new instances.
func (srv *Server) NewInstances(n int, instType string, imageId string, state ec2.InstanceState, groups []ec2.SecurityGroup) []string {
srv.mu.Lock()
defer srv.mu.Unlock()
rgroups := make([]*securityGroup, len(groups))
for i, group := range groups {
g := srv.group(group)
if g == nil {
fatalf(400, "InvalidGroup.NotFound", "no such group %v", g)
}
rgroups[i] = g
}
r := srv.newReservation(rgroups)
ids := make([]string, n)
for i := 0; i < n; i++ {
inst := srv.newInstance(r, instType, imageId, state)
ids[i] = inst.id
}
return ids
}
func (srv *Server) newInstance(r *reservation, instType string, imageId string, state ec2.InstanceState) *Instance {
inst := &Instance{
id: fmt.Sprintf("i-%d", srv.maxId.next()),
instType: instType,
imageId: imageId,
state: state,
reservation: r,
}
srv.instances[inst.id] = inst
r.instances[inst.id] = inst
return inst
}
func (srv *Server) newReservation(groups []*securityGroup) *reservation {
r := &reservation{
id: fmt.Sprintf("r-%d", srv.reservationId.next()),
instances: make(map[string]*Instance),
groups: groups,
}
srv.reservations[r.id] = r
return r
}
func (srv *Server) terminateInstances(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
srv.mu.Lock()
defer srv.mu.Unlock()
var resp ec2.TerminateInstancesResp
resp.RequestId = reqId
var insts []*Instance
for attr, vals := range req.Form {
if strings.HasPrefix(attr, "InstanceId.") {
id := vals[0]
inst := srv.instances[id]
if inst == nil {
fatalf(400, "InvalidInstanceID.NotFound", "no such instance id %q", id)
}
insts = append(insts, inst)
}
}
for _, inst := range insts {
resp.StateChanges = append(resp.StateChanges, inst.terminate())
}
return &resp
}
func (inst *Instance) terminate() (d ec2.InstanceStateChange) {
d.PreviousState = inst.state
inst.state = ShuttingDown
d.CurrentState = inst.state
d.InstanceId = inst.id
return d
}
func (inst *Instance) ec2instance() ec2.Instance {
return ec2.Instance{
InstanceId: inst.id,
InstanceType: inst.instType,
ImageId: inst.imageId,
DNSName: fmt.Sprintf("%s.example.com", inst.id),
// TODO the rest
}
}
func (inst *Instance) matchAttr(attr, value string) (ok bool, err error) {
switch attr {
case "architecture":
return value == "i386", nil
case "instance-id":
return inst.id == value, nil
case "group-id":
for _, g := range inst.reservation.groups {
if g.id == value {
return true, nil
}
}
return false, nil
case "group-name":
for _, g := range inst.reservation.groups {
if g.name == value {
return true, nil
}
}
return false, nil
case "image-id":
return value == inst.imageId, nil
case "instance-state-code":
code, err := strconv.Atoi(value)
if err != nil {
return false, err
}
return code&0xff == inst.state.Code, nil
case "instance-state-name":
return value == inst.state.Name, nil
}
return false, fmt.Errorf("unknown attribute %q", attr)
}
var (
Pending = ec2.InstanceState{0, "pending"}
Running = ec2.InstanceState{16, "running"}
ShuttingDown = ec2.InstanceState{32, "shutting-down"}
Terminated = ec2.InstanceState{16, "terminated"}
Stopped = ec2.InstanceState{16, "stopped"}
)
func (srv *Server) createSecurityGroup(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
name := req.Form.Get("GroupName")
if name == "" {
fatalf(400, "InvalidParameterValue", "empty security group name")
}
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.group(ec2.SecurityGroup{Name: name}) != nil {
fatalf(400, "InvalidGroup.Duplicate", "group %q already exists", name)
}
g := &securityGroup{
name: name,
description: req.Form.Get("GroupDescription"),
id: fmt.Sprintf("sg-%d", srv.groupId.next()),
perms: make(map[permKey]bool),
}
srv.groups[g.id] = g
// we define a local type for this because ec2.CreateSecurityGroupResp
// contains SecurityGroup, but the response to this request
// should not contain the security group name.
type CreateSecurityGroupResponse struct {
RequestId string `xml:"requestId"`
Return bool `xml:"return"`
GroupId string `xml:"groupId"`
}
r := &CreateSecurityGroupResponse{
RequestId: reqId,
Return: true,
GroupId: g.id,
}
return r
}
func (srv *Server) notImplemented(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
fatalf(500, "InternalError", "not implemented")
panic("not reached")
}
func (srv *Server) describeInstances(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
srv.mu.Lock()
defer srv.mu.Unlock()
insts := make(map[*Instance]bool)
for name, vals := range req.Form {
if !strings.HasPrefix(name, "InstanceId.") {
continue
}
inst := srv.instances[vals[0]]
if inst == nil {
fatalf(400, "InvalidInstanceID.NotFound", "instance %q not found", vals[0])
}
insts[inst] = true
}
f := newFilter(req.Form)
var resp ec2.InstancesResp
resp.RequestId = reqId
for _, r := range srv.reservations {
var instances []ec2.Instance
for _, inst := range r.instances {
if len(insts) > 0 && !insts[inst] {
continue
}
ok, err := f.ok(inst)
if ok {
instances = append(instances, inst.ec2instance())
} else if err != nil {
fatalf(400, "InvalidParameterValue", "describe instances: %v", err)
}
}
if len(instances) > 0 {
var groups []ec2.SecurityGroup
for _, g := range r.groups {
groups = append(groups, g.ec2SecurityGroup())
}
resp.Reservations = append(resp.Reservations, ec2.Reservation{
ReservationId: r.id,
OwnerId: ownerId,
Instances: instances,
SecurityGroups: groups,
})
}
}
return &resp
}
func (srv *Server) describeSecurityGroups(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
// BUG similar bug to describeInstances, but for GroupName and GroupId
srv.mu.Lock()
defer srv.mu.Unlock()
var groups []*securityGroup
for name, vals := range req.Form {
var g ec2.SecurityGroup
switch {
case strings.HasPrefix(name, "GroupName."):
g.Name = vals[0]
case strings.HasPrefix(name, "GroupId."):
g.Id = vals[0]
default:
continue
}
sg := srv.group(g)
if sg == nil {
fatalf(400, "InvalidGroup.NotFound", "no such group %v", g)
}
groups = append(groups, sg)
}
if len(groups) == 0 {
for _, g := range srv.groups {
groups = append(groups, g)
}
}
f := newFilter(req.Form)
var resp ec2.SecurityGroupsResp
resp.RequestId = reqId
for _, group := range groups {
ok, err := f.ok(group)
if ok {
resp.Groups = append(resp.Groups, ec2.SecurityGroupInfo{
OwnerId: ownerId,
SecurityGroup: group.ec2SecurityGroup(),
Description: group.description,
IPPerms: group.ec2Perms(),
})
} else if err != nil {
fatalf(400, "InvalidParameterValue", "describe security groups: %v", err)
}
}
return &resp
}
func (srv *Server) authorizeSecurityGroupIngress(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
srv.mu.Lock()
defer srv.mu.Unlock()
g := srv.group(ec2.SecurityGroup{
Name: req.Form.Get("GroupName"),
Id: req.Form.Get("GroupId"),
})
if g == nil {
fatalf(400, "InvalidGroup.NotFound", "group not found")
}
perms := srv.parsePerms(req)
for _, p := range perms {
if g.perms[p] {
fatalf(400, "InvalidPermission.Duplicate", "Permission has already been authorized on the specified group")
}
}
for _, p := range perms {
g.perms[p] = true
}
return &ec2.SimpleResp{
XMLName: xml.Name{"", "AuthorizeSecurityGroupIngressResponse"},
RequestId: reqId,
}
}
func (srv *Server) revokeSecurityGroupIngress(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
srv.mu.Lock()
defer srv.mu.Unlock()
g := srv.group(ec2.SecurityGroup{
Name: req.Form.Get("GroupName"),
Id: req.Form.Get("GroupId"),
})
if g == nil {
fatalf(400, "InvalidGroup.NotFound", "group not found")
}
perms := srv.parsePerms(req)
// Note EC2 does not give an error if asked to revoke an authorization
// that does not exist.
for _, p := range perms {
delete(g.perms, p)
}
return &ec2.SimpleResp{
XMLName: xml.Name{"", "RevokeSecurityGroupIngressResponse"},
RequestId: reqId,
}
}
var secGroupPat = regexp.MustCompile(`^sg-[a-z0-9]+$`)
var ipPat = regexp.MustCompile(`^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+/[0-9]+$`)
var ownerIdPat = regexp.MustCompile(`^[0-9]+$`)
// parsePerms returns a slice of permKey values extracted
// from the permission fields in req.
func (srv *Server) parsePerms(req *http.Request) []permKey {
// perms maps an index found in the form to its associated
// IPPerm. For instance, the form value with key
// "IpPermissions.3.FromPort" will be stored in perms[3].FromPort
perms := make(map[int]ec2.IPPerm)
type subgroupKey struct {
id1, id2 int
}
// Each IPPerm can have many source security groups. The form key
// for a source security group contains two indices: the index
// of the IPPerm and the sub-index of the security group. The
// sourceGroups map maps from a subgroupKey containing these
// two indices to the associated security group. For instance,
// the form value with key "IPPermissions.3.Groups.2.GroupName"
// will be stored in sourceGroups[subgroupKey{3, 2}].Name.
sourceGroups := make(map[subgroupKey]ec2.UserSecurityGroup)
// For each value in the form we store its associated information in the
// above maps. The maps are necessary because the form keys may
// arrive in any order, and the indices are not
// necessarily sequential or even small.
for name, vals := range req.Form {
val := vals[0]
var id1 int
var rest string
if x, _ := fmt.Sscanf(name, "IpPermissions.%d.%s", &id1, &rest); x != 2 {
continue
}
ec2p := perms[id1]
switch {
case rest == "FromPort":
ec2p.FromPort = atoi(val)
case rest == "ToPort":
ec2p.ToPort = atoi(val)
case rest == "IpProtocol":
switch val {
case "tcp", "udp", "icmp":
ec2p.Protocol = val
default:
// check it's a well formed number
atoi(val)
ec2p.Protocol = val
}
case strings.HasPrefix(rest, "Groups."):
k := subgroupKey{id1: id1}
if x, _ := fmt.Sscanf(rest[len("Groups."):], "%d.%s", &k.id2, &rest); x != 2 {
continue
}
g := sourceGroups[k]
switch rest {
case "UserId":
// BUG if the user id is blank, this does not conform to the
// way that EC2 handles it - a specified but blank owner id
// can cause RevokeSecurityGroupIngress to fail with
// "group not found" even if the security group id has been
// correctly specified.
// By failing here, we ensure that we fail early in this case.
if !ownerIdPat.MatchString(val) {
fatalf(400, "InvalidUserID.Malformed", "Invalid user ID: %q", val)
}
g.OwnerId = val
case "GroupName":
g.Name = val
case "GroupId":
if !secGroupPat.MatchString(val) {
fatalf(400, "InvalidGroupId.Malformed", "Invalid group ID: %q", val)
}
g.Id = val
default:
fatalf(400, "UnknownParameter", "unknown parameter %q", name)
}
sourceGroups[k] = g
case strings.HasPrefix(rest, "IpRanges."):
var id2 int
if x, _ := fmt.Sscanf(rest[len("IpRanges."):], "%d.%s", &id2, &rest); x != 2 {
continue
}
switch rest {
case "CidrIp":
if !ipPat.MatchString(val) {
fatalf(400, "InvalidPermission.Malformed", "Invalid IP range: %q", val)
}
ec2p.SourceIPs = append(ec2p.SourceIPs, val)
default:
fatalf(400, "UnknownParameter", "unknown parameter %q", name)
}
default:
fatalf(400, "UnknownParameter", "unknown parameter %q", name)
}
perms[id1] = ec2p
}
// Associate each set of source groups with its IPPerm.
for k, g := range sourceGroups {
p := perms[k.id1]
p.SourceGroups = append(p.SourceGroups, g)
perms[k.id1] = p
}
// Now that we have built up the IPPerms we need, we check for
// parameter errors and build up a permKey for each permission,
// looking up security groups from srv as we do so.
var result []permKey
for _, p := range perms {
if p.FromPort > p.ToPort {
fatalf(400, "InvalidParameterValue", "invalid port range")
}
k := permKey{
protocol: p.Protocol,
fromPort: p.FromPort,
toPort: p.ToPort,
}
for _, g := range p.SourceGroups {
if g.OwnerId != "" && g.OwnerId != ownerId {
fatalf(400, "InvalidGroup.NotFound", "group %q not found", g.Name)
}
var ec2g ec2.SecurityGroup
switch {
case g.Id != "":
ec2g.Id = g.Id
case g.Name != "":
ec2g.Name = g.Name
}
k.group = srv.group(ec2g)
if k.group == nil {
fatalf(400, "InvalidGroup.NotFound", "group %v not found", g)
}
result = append(result, k)
}
k.group = nil
for _, ip := range p.SourceIPs {
k.ipAddr = ip
result = append(result, k)
}
}
return result
}
func (srv *Server) deleteSecurityGroup(w http.ResponseWriter, req *http.Request, reqId string) interface{} {
srv.mu.Lock()
defer srv.mu.Unlock()
g := srv.group(ec2.SecurityGroup{
Name: req.Form.Get("GroupName"),
Id: req.Form.Get("GroupId"),
})
if g == nil {
fatalf(400, "InvalidGroup.NotFound", "group not found")
}
for _, r := range srv.reservations {
for _, h := range r.groups {
if h == g && r.hasRunningMachine() {
fatalf(500, "InvalidGroup.InUse", "group is currently in use by a running instance")
}
}
}
for _, sg := range srv.groups {
// If a group refers to itself, it's ok to delete it.
if sg == g {
continue
}
for k := range sg.perms {
if k.group == g {
fatalf(500, "InvalidGroup.InUse", "group is currently in use by group %q", sg.id)
}
}
}
delete(srv.groups, g.id)
return &ec2.SimpleResp{
XMLName: xml.Name{"", "DeleteSecurityGroupResponse"},
RequestId: reqId,
}
}
func (r *reservation) hasRunningMachine() bool {
for _, inst := range r.instances {
if inst.state.Code != ShuttingDown.Code && inst.state.Code != Terminated.Code {
return true
}
}
return false
}
type counter int
func (c *counter) next() (i int) {
i = int(*c)
(*c)++
return
}
// atoi is like strconv.Atoi but is fatal if the
// string is not well formed.
func atoi(s string) int {
i, err := strconv.Atoi(s)
if err != nil {
fatalf(400, "InvalidParameterValue", "bad number: %v", err)
}
return i
}
func fatalf(statusCode int, code string, f string, a ...interface{}) {
panic(&ec2.Error{
StatusCode: statusCode,
Code: code,
Message: fmt.Sprintf(f, a...),
})
}

View File

@@ -1,22 +0,0 @@
package ec2
import (
"github.com/mitchellh/goamz/aws"
"time"
)
func Sign(auth aws.Auth, method, path string, params map[string]string, host string) {
sign(auth, method, path, params, host)
}
func fixedTime() time.Time {
return time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC)
}
func FakeTime(fakeIt bool) {
if fakeIt {
timeNow = fixedTime
} else {
timeNow = time.Now
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,45 +0,0 @@
package ec2
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"github.com/mitchellh/goamz/aws"
"sort"
"strings"
)
// ----------------------------------------------------------------------------
// EC2 signing (http://goo.gl/fQmAN)
var b64 = base64.StdEncoding
func sign(auth aws.Auth, method, path string, params map[string]string, host string) {
params["AWSAccessKeyId"] = auth.AccessKey
params["SignatureVersion"] = "2"
params["SignatureMethod"] = "HmacSHA256"
if auth.Token != "" {
params["SecurityToken"] = auth.Token
}
// AWS specifies that the parameters in a signed request must
// be provided in the natural order of the keys. This is distinct
// from the natural order of the encoded value of key=value.
// Percent and equals affect the sorting order.
var keys, sarray []string
for k, _ := range params {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
sarray = append(sarray, aws.Encode(k)+"="+aws.Encode(params[k]))
}
joined := strings.Join(sarray, "&")
payload := method + "\n" + host + "\n" + path + "\n" + joined
hash := hmac.New(sha256.New, []byte(auth.SecretKey))
hash.Write([]byte(payload))
signature := make([]byte, b64.EncodedLen(hash.Size()))
b64.Encode(signature, hash.Sum(nil))
params["Signature"] = string(signature)
}

View File

@@ -1,68 +0,0 @@
package ec2_test
import (
"github.com/mitchellh/goamz/aws"
"github.com/mitchellh/goamz/ec2"
. "github.com/motain/gocheck"
)
// EC2 ReST authentication docs: http://goo.gl/fQmAN
var testAuth = aws.Auth{"user", "secret", ""}
func (s *S) TestBasicSignature(c *C) {
params := map[string]string{}
ec2.Sign(testAuth, "GET", "/path", params, "localhost")
c.Assert(params["SignatureVersion"], Equals, "2")
c.Assert(params["SignatureMethod"], Equals, "HmacSHA256")
expected := "6lSe5QyXum0jMVc7cOUz32/52ZnL7N5RyKRk/09yiK4="
c.Assert(params["Signature"], Equals, expected)
}
func (s *S) TestParamSignature(c *C) {
params := map[string]string{
"param1": "value1",
"param2": "value2",
"param3": "value3",
}
ec2.Sign(testAuth, "GET", "/path", params, "localhost")
expected := "XWOR4+0lmK8bD8CGDGZ4kfuSPbb2JibLJiCl/OPu1oU="
c.Assert(params["Signature"], Equals, expected)
}
func (s *S) TestManyParams(c *C) {
params := map[string]string{
"param1": "value10",
"param2": "value2",
"param3": "value3",
"param4": "value4",
"param5": "value5",
"param6": "value6",
"param7": "value7",
"param8": "value8",
"param9": "value9",
"param10": "value1",
}
ec2.Sign(testAuth, "GET", "/path", params, "localhost")
expected := "di0sjxIvezUgQ1SIL6i+C/H8lL+U0CQ9frLIak8jkVg="
c.Assert(params["Signature"], Equals, expected)
}
func (s *S) TestEscaping(c *C) {
params := map[string]string{"Nonce": "+ +"}
ec2.Sign(testAuth, "GET", "/path", params, "localhost")
c.Assert(params["Nonce"], Equals, "+ +")
expected := "bqffDELReIqwjg/W0DnsnVUmfLK4wXVLO4/LuG+1VFA="
c.Assert(params["Signature"], Equals, expected)
}
func (s *S) TestSignatureExample1(c *C) {
params := map[string]string{
"Timestamp": "2009-02-01T12:53:20+00:00",
"Version": "2007-11-07",
"Action": "ListDomains",
}
ec2.Sign(aws.Auth{"access", "secret", ""}, "GET", "/", params, "sdb.amazonaws.com")
expected := "okj96/5ucWBSc1uR2zXVfm6mDHtgfNv657rRtt/aunQ="
c.Assert(params["Signature"], Equals, expected)
}