Merge pull request #18478 from yifan-gu/bump_go_oidc

Auto commit by PR queue bot
This commit is contained in:
k8s-merge-robot 2015-12-15 20:46:19 -08:00
commit f20cad179f
11 changed files with 188 additions and 94 deletions

10
Godeps/Godeps.json generated
View File

@ -247,23 +247,23 @@
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/http", "ImportPath": "github.com/coreos/go-oidc/http",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "024cdeee09d02fb439eb55bc422e582ac115615b"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/jose", "ImportPath": "github.com/coreos/go-oidc/jose",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "024cdeee09d02fb439eb55bc422e582ac115615b"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/key", "ImportPath": "github.com/coreos/go-oidc/key",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "024cdeee09d02fb439eb55bc422e582ac115615b"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/oauth2", "ImportPath": "github.com/coreos/go-oidc/oauth2",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "024cdeee09d02fb439eb55bc422e582ac115615b"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/oidc", "ImportPath": "github.com/coreos/go-oidc/oidc",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "024cdeee09d02fb439eb55bc422e582ac115615b"
}, },
{ {
"ImportPath": "github.com/coreos/go-semver/semver", "ImportPath": "github.com/coreos/go-semver/semver",

View File

@ -26,6 +26,32 @@ func (c Claims) StringClaim(name string) (string, bool, error) {
return v, true, nil return v, true, nil
} }
func (c Claims) StringsClaim(name string) ([]string, bool, error) {
cl, ok := c[name]
if !ok {
return nil, false, nil
}
if v, ok := cl.([]string); ok {
return v, true, nil
}
// When unmarshaled, []string will become []interface{}.
if v, ok := cl.([]interface{}); ok {
var ret []string
for _, vv := range v {
str, ok := vv.(string)
if !ok {
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
ret = append(ret, str)
}
return ret, true, nil
}
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
func (c Claims) Int64Claim(name string) (int64, bool, error) { func (c Claims) Int64Claim(name string) (int64, bool, error) {
cl, ok := c[name] cl, ok := c[name]
if !ok { if !ok {

View File

@ -1,8 +1,6 @@
package jose package jose
import ( import "strings"
"strings"
)
type JWT JWS type JWT JWS
@ -63,13 +61,13 @@ func (j *JWT) Encode() string {
return strings.Join([]string{d, s}, ".") return strings.Join([]string{d, s}, ".")
} }
func NewSignedJWT(claims map[string]interface{}, s Signer) (*JWT, error) { func NewSignedJWT(claims Claims, s Signer) (*JWT, error) {
header := JOSEHeader{ header := JOSEHeader{
HeaderKeyAlgorithm: s.Alg(), HeaderKeyAlgorithm: s.Alg(),
HeaderKeyID: s.ID(), HeaderKeyID: s.ID(),
} }
jwt, err := NewJWT(header, Claims(claims)) jwt, err := NewJWT(header, claims)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/base64" "encoding/base64"
"encoding/json"
"math/big" "math/big"
"time" "time"
@ -18,6 +19,19 @@ type PublicKey struct {
jwk jose.JWK jwk jose.JWK
} }
func (k *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(k.jwk)
}
func (k *PublicKey) UnmarshalJSON(data []byte) error {
var jwk jose.JWK
if err := json.Unmarshal(data, &jwk); err != nil {
return err
}
k.jwk = jwk
return nil
}
func (k *PublicKey) ID() string { func (k *PublicKey) ID() string {
return k.jwk.ID return k.jwk.ID
} }

View File

@ -1,10 +1,5 @@
package oauth2 package oauth2
import (
"encoding/json"
"fmt"
)
const ( const (
ErrorAccessDenied = "access_denied" ErrorAccessDenied = "access_denied"
ErrorInvalidClient = "invalid_client" ErrorInvalidClient = "invalid_client"
@ -18,22 +13,17 @@ const (
type Error struct { type Error struct {
Type string `json:"error"` Type string `json:"error"`
Description string `json:"error_description,omitempty"`
State string `json:"state,omitempty"` State string `json:"state,omitempty"`
} }
func (e *Error) Error() string { func (e *Error) Error() string {
if e.Description != "" {
return e.Type + ": " + e.Description
}
return e.Type return e.Type
} }
func NewError(typ string) *Error { func NewError(typ string) *Error {
return &Error{Type: typ} return &Error{Type: typ}
} }
func unmarshalError(b []byte) error {
var oerr Error
err := json.Unmarshal(b, &oerr)
if err != nil {
return fmt.Errorf("unrecognized error: %s", string(b))
}
return &oerr
}

View File

@ -220,11 +220,7 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
if err != nil { if err != nil {
return return
} }
badStatusCode := resp.StatusCode < 200 || resp.StatusCode > 299
if resp.StatusCode < 200 || resp.StatusCode > 299 {
err = unmarshalError(body)
return
}
contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil { if err != nil {
@ -235,42 +231,69 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
RawBody: body, RawBody: body,
} }
newError := func(typ, desc, state string) error {
if typ == "" {
return fmt.Errorf("unrecognized error %s", body)
}
return &Error{typ, desc, state}
}
if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" { if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" {
var vals url.Values var vals url.Values
vals, err = url.ParseQuery(string(body)) vals, err = url.ParseQuery(string(body))
if err != nil { if err != nil {
return return
} }
if error := vals.Get("error"); error != "" || badStatusCode {
err = newError(error, vals.Get("error_description"), vals.Get("state"))
return
}
e := vals.Get("expires_in")
if e == "" {
e = vals.Get("expires")
}
if e != "" {
result.Expires, err = strconv.Atoi(e)
if err != nil {
return
}
}
result.AccessToken = vals.Get("access_token") result.AccessToken = vals.Get("access_token")
result.TokenType = vals.Get("token_type") result.TokenType = vals.Get("token_type")
result.IDToken = vals.Get("id_token") result.IDToken = vals.Get("id_token")
result.RefreshToken = vals.Get("refresh_token") result.RefreshToken = vals.Get("refresh_token")
result.Scope = vals.Get("scope") result.Scope = vals.Get("scope")
e := vals.Get("expires_in")
if e == "" {
e = vals.Get("expires")
}
result.Expires, err = strconv.Atoi(e)
if err != nil {
return
}
} else { } else {
b := make(map[string]interface{}) var r struct {
if err = json.Unmarshal(body, &b); err != nil { AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
State string `json:"state"`
ExpiresIn int `json:"expires_in"`
Expires int `json:"expires"`
Error string `json:"error"`
Desc string `json:"error_description"`
}
if err = json.Unmarshal(body, &r); err != nil {
return return
} }
result.AccessToken, _ = b["access_token"].(string) if r.Error != "" || badStatusCode {
result.TokenType, _ = b["token_type"].(string) err = newError(r.Error, r.Desc, r.State)
result.IDToken, _ = b["id_token"].(string) return
result.RefreshToken, _ = b["refresh_token"].(string) }
result.Scope, _ = b["scope"].(string) result.AccessToken = r.AccessToken
e, ok := b["expires_in"].(int) result.TokenType = r.TokenType
if !ok { result.IDToken = r.IDToken
e, _ = b["expires"].(int) result.RefreshToken = r.RefreshToken
result.Scope = r.Scope
if r.ExpiresIn == 0 {
result.Expires = r.Expires
} else {
result.Expires = r.ExpiresIn
} }
result.Expires = e
} }
return return
} }

View File

@ -78,7 +78,7 @@ func NewClient(cfg ClientConfig) (*Client, error) {
httpClient: cfg.HTTPClient, httpClient: cfg.HTTPClient,
scope: cfg.Scope, scope: cfg.Scope,
redirectURL: ru.String(), redirectURL: ru.String(),
providerConfig: cfg.ProviderConfig, providerConfig: newProviderConfigRepo(cfg.ProviderConfig),
keySet: cfg.KeySet, keySet: cfg.KeySet,
} }
@ -96,7 +96,7 @@ func NewClient(cfg ClientConfig) (*Client, error) {
type Client struct { type Client struct {
httpClient phttp.Client httpClient phttp.Client
providerConfig ProviderConfig providerConfig *providerConfigRepo
credentials ClientCredentials credentials ClientCredentials
redirectURL string redirectURL string
scope []string scope []string
@ -106,14 +106,39 @@ type Client struct {
lastKeySetSync time.Time lastKeySetSync time.Time
} }
type providerConfigRepo struct {
mu sync.RWMutex
config ProviderConfig // do not access directly, use Get()
}
func newProviderConfigRepo(pc ProviderConfig) *providerConfigRepo {
return &providerConfigRepo{sync.RWMutex{}, pc}
}
// returns an error to implement ProviderConfigSetter
func (r *providerConfigRepo) Set(cfg ProviderConfig) error {
r.mu.Lock()
defer r.mu.Unlock()
r.config = cfg
return nil
}
func (r *providerConfigRepo) Get() ProviderConfig {
r.mu.RLock()
defer r.mu.RUnlock()
return r.config
}
func (c *Client) Healthy() error { func (c *Client) Healthy() error {
now := time.Now().UTC() now := time.Now().UTC()
if c.providerConfig.Empty() { cfg := c.providerConfig.Get()
if cfg.Empty() {
return errors.New("oidc client provider config empty") return errors.New("oidc client provider config empty")
} }
if !c.providerConfig.ExpiresAt.IsZero() && c.providerConfig.ExpiresAt.Before(now) { if !cfg.ExpiresAt.IsZero() && cfg.ExpiresAt.Before(now) {
return errors.New("oidc client provider config expired") return errors.New("oidc client provider config expired")
} }
@ -121,7 +146,8 @@ func (c *Client) Healthy() error {
} }
func (c *Client) OAuthClient() (*oauth2.Client, error) { func (c *Client) OAuthClient() (*oauth2.Client, error) {
authMethod, err := c.chooseAuthMethod() cfg := c.providerConfig.Get()
authMethod, err := chooseAuthMethod(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,8 +155,8 @@ func (c *Client) OAuthClient() (*oauth2.Client, error) {
ocfg := oauth2.Config{ ocfg := oauth2.Config{
Credentials: oauth2.ClientCredentials(c.credentials), Credentials: oauth2.ClientCredentials(c.credentials),
RedirectURL: c.redirectURL, RedirectURL: c.redirectURL,
AuthURL: c.providerConfig.AuthEndpoint, AuthURL: cfg.AuthEndpoint,
TokenURL: c.providerConfig.TokenEndpoint, TokenURL: cfg.TokenEndpoint,
Scope: c.scope, Scope: c.scope,
AuthMethod: authMethod, AuthMethod: authMethod,
} }
@ -138,12 +164,12 @@ func (c *Client) OAuthClient() (*oauth2.Client, error) {
return oauth2.NewClient(c.httpClient, ocfg) return oauth2.NewClient(c.httpClient, ocfg)
} }
func (c *Client) chooseAuthMethod() (string, error) { func chooseAuthMethod(cfg ProviderConfig) (string, error) {
if len(c.providerConfig.TokenEndpointAuthMethodsSupported) == 0 { if len(cfg.TokenEndpointAuthMethodsSupported) == 0 {
return oauth2.AuthMethodClientSecretBasic, nil return oauth2.AuthMethodClientSecretBasic, nil
} }
for _, authMethod := range c.providerConfig.TokenEndpointAuthMethodsSupported { for _, authMethod := range cfg.TokenEndpointAuthMethodsSupported {
if _, ok := supportedAuthMethods[authMethod]; ok { if _, ok := supportedAuthMethods[authMethod]; ok {
return authMethod, nil return authMethod, nil
} }
@ -153,9 +179,8 @@ func (c *Client) chooseAuthMethod() (string, error) {
} }
func (c *Client) SyncProviderConfig(discoveryURL string) chan struct{} { func (c *Client) SyncProviderConfig(discoveryURL string) chan struct{} {
rp := &providerConfigRepo{c}
r := NewHTTPProviderConfigGetter(c.httpClient, discoveryURL) r := NewHTTPProviderConfigGetter(c.httpClient, discoveryURL)
return NewProviderConfigSyncer(r, rp).Run() return NewProviderConfigSyncer(r, c.providerConfig).Run()
} }
func (c *Client) maybeSyncKeys() error { func (c *Client) maybeSyncKeys() error {
@ -178,7 +203,8 @@ func (c *Client) maybeSyncKeys() error {
return nil return nil
} }
r := NewRemotePublicKeyRepo(c.httpClient, c.providerConfig.KeysEndpoint) cfg := c.providerConfig.Get()
r := NewRemotePublicKeyRepo(c.httpClient, cfg.KeysEndpoint)
w := &clientKeyRepo{client: c} w := &clientKeyRepo{client: c}
_, err := key.Sync(r, w) _, err := key.Sync(r, w)
c.lastKeySetSync = time.Now().UTC() c.lastKeySetSync = time.Now().UTC()
@ -186,15 +212,6 @@ func (c *Client) maybeSyncKeys() error {
return err return err
} }
type providerConfigRepo struct {
client *Client
}
func (r *providerConfigRepo) Set(cfg ProviderConfig) error {
r.client.providerConfig = cfg
return nil
}
type clientKeyRepo struct { type clientKeyRepo struct {
client *Client client *Client
} }
@ -209,7 +226,9 @@ func (r *clientKeyRepo) Set(ks key.KeySet) error {
} }
func (c *Client) ClientCredsToken(scope []string) (jose.JWT, error) { func (c *Client) ClientCredsToken(scope []string) (jose.JWT, error) {
if !c.providerConfig.SupportsGrantType(oauth2.GrantTypeClientCreds) { cfg := c.providerConfig.Get()
if !cfg.SupportsGrantType(oauth2.GrantTypeClientCreds) {
return jose.JWT{}, fmt.Errorf("%v grant type is not supported", oauth2.GrantTypeClientCreds) return jose.JWT{}, fmt.Errorf("%v grant type is not supported", oauth2.GrantTypeClientCreds)
} }
@ -280,7 +299,7 @@ func (c *Client) VerifyJWT(jwt jose.JWT) error {
} }
v := NewJWTVerifier( v := NewJWTVerifier(
c.providerConfig.Issuer, c.providerConfig.Get().Issuer,
c.credentials.ID, c.credentials.ID,
c.maybeSyncKeys, keysFunc) c.maybeSyncKeys, keysFunc)

View File

@ -25,6 +25,9 @@ const (
discoveryConfigPath = "/.well-known/openid-configuration" discoveryConfigPath = "/.well-known/openid-configuration"
) )
// internally configurable for tests
var minimumProviderConfigSyncInterval = MinimumProviderConfigSyncInterval
type ProviderConfig struct { type ProviderConfig struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
AuthEndpoint string `json:"authorization_endpoint"` AuthEndpoint string `json:"authorization_endpoint"`
@ -172,8 +175,8 @@ func nextSyncAfter(exp time.Time, clock clockwork.Clock) time.Duration {
t := exp.Sub(clock.Now()) / 2 t := exp.Sub(clock.Now()) / 2
if t > MaximumProviderConfigSyncInterval { if t > MaximumProviderConfigSyncInterval {
t = MaximumProviderConfigSyncInterval t = MaximumProviderConfigSyncInterval
} else if t < MinimumProviderConfigSyncInterval { } else if t < minimumProviderConfigSyncInterval {
t = MinimumProviderConfigSyncInterval t = minimumProviderConfigSyncInterval
} }
return t return t

View File

@ -53,7 +53,7 @@ func CookieTokenExtractor(cookieName string) RequestTokenExtractor {
} }
} }
func NewClaims(iss, sub, aud string, iat, exp time.Time) jose.Claims { func NewClaims(iss, sub string, aud interface{}, iat, exp time.Time) jose.Claims {
return jose.Claims{ return jose.Claims{
// required // required
"iss": iss, "iss": iss,

View File

@ -25,6 +25,17 @@ func VerifySignature(jwt jose.JWT, keys []key.PublicKey) (bool, error) {
return false, nil return false, nil
} }
// containsString returns true if the given string(needle) is found
// in the string array(haystack).
func containsString(needle string, haystack []string) bool {
for _, v := range haystack {
if v == needle {
return true
}
}
return false
}
// Verify claims in accordance with OIDC spec // Verify claims in accordance with OIDC spec
// http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation // http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation
func VerifyClaims(jwt jose.JWT, issuer, clientID string) error { func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
@ -45,7 +56,8 @@ func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
} }
// iss REQUIRED. Issuer Identifier for the Issuer of the response. // iss REQUIRED. Issuer Identifier for the Issuer of the response.
// The iss value is a case sensitive URL using the https scheme that contains scheme, host, and optionally, port number and path components and no query or fragment components. // The iss value is a case sensitive URL using the https scheme that contains scheme,
// host, and optionally, port number and path components and no query or fragment components.
if iss, exists := claims["iss"].(string); exists { if iss, exists := claims["iss"].(string); exists {
if !urlEqual(iss, issuer) { if !urlEqual(iss, issuer) {
return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss) return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss)
@ -55,19 +67,27 @@ func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
} }
// iat REQUIRED. Time at which the JWT was issued. // iat REQUIRED. Time at which the JWT was issued.
// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. // Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z
// as measured in UTC until the date/time.
if _, exists := claims["iat"].(float64); !exists { if _, exists := claims["iat"].(float64); !exists {
return errors.New("missing claim: 'iat'") return errors.New("missing claim: 'iat'")
} }
// aud REQUIRED. Audience(s) that this ID Token is intended for. // aud REQUIRED. Audience(s) that this ID Token is intended for.
// It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value. It MAY also contain identifiers for other audiences. In the general case, the aud value is an array of case sensitive strings. In the common special case when there is one audience, the aud value MAY be a single case sensitive string. // It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value.
if aud, exists := claims["aud"].(string); exists { // It MAY also contain identifiers for other audiences. In the general case, the aud
// value is an array of case sensitive strings. In the common special case when there
// is one audience, the aud value MAY be a single case sensitive string.
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
if aud != clientID { if aud != clientID {
return errors.New("invalid claim value: 'aud'") return fmt.Errorf("invalid claims, 'aud' claim and 'client_id' do not match, aud=%s, client_id=%s", aud, clientID)
}
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
if !containsString(clientID, aud) {
return fmt.Errorf("invalid claims, cannot find 'client_id' in 'aud' claim, aud=%v, client_id=%s", aud, clientID)
} }
} else { } else {
return errors.New("missing claim: 'aud'") return errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
} }
return nil return nil
@ -97,16 +117,17 @@ func VerifyClientClaims(jwt jose.JWT, issuer string) (string, error) {
return "", errors.New("missing required 'sub' claim") return "", errors.New("missing required 'sub' claim")
} }
aud, ok, err := claims.StringClaim("aud") if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
if err != nil { if aud != sub {
return "", fmt.Errorf("failed to parse 'aud' claim: %v", err)
} else if !ok {
return "", errors.New("missing required 'aud' claim")
}
if sub != aud {
return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub) return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub)
} }
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
if !containsString(sub, aud) {
return "", fmt.Errorf("invalid claims, cannot find 'sud' in 'aud' claim, aud=%v, sub=%s", aud, sub)
}
} else {
return "", errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
}
now := time.Now().UTC() now := time.Now().UTC()
exp, ok, err := claims.TimeClaim("exp") exp, ok, err := claims.TimeClaim("exp")

View File

@ -350,7 +350,7 @@ func TestOIDCAuthentication(t *testing.T) {
op.generateGoodToken(t, srv.URL, "client-foo", "client-bar", "sub", "user-foo"), op.generateGoodToken(t, srv.URL, "client-foo", "client-bar", "sub", "user-foo"),
nil, nil,
false, false,
"oidc: JWT claims invalid: invalid claim value: 'aud'", "oidc: JWT claims invalid: invalid claims, 'aud' claim and 'client_id' do not match",
}, },
{ {
// Invalid issuer. // Invalid issuer.