rework oidc client auth provider

* Cache OpenID Connect clients to prevent reinitialization
* Don't retry requests in the http.RoundTripper.
  * Don't rely on the server not reading POST bodies.
  * Don't leak response body FDs.
  * Formerly ignored any throttling requests by the server.
* Determine if the id token's expired by inspecting it.
  * Similar to logic in golang.org/x/oauth2
* Synchronize around refreshing tokens and persisting the new config.
This commit is contained in:
Eric Chiang
2016-12-05 18:44:22 -08:00
parent 89a506a9b5
commit 46518e937c
2 changed files with 386 additions and 582 deletions

View File

@@ -22,6 +22,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/jose"
@@ -30,7 +31,6 @@ import (
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/client/restclient"
"k8s.io/kubernetes/pkg/util/wait"
)
const (
@@ -44,21 +44,68 @@ const (
cfgRefreshToken = "refresh-token"
)
var (
backoff = wait.Backoff{
Duration: 1 * time.Second,
Factor: 2,
Jitter: .1,
Steps: 5,
}
)
func init() {
if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil {
glog.Fatalf("Failed to register oidc auth plugin: %v", err)
}
}
// expiryDelta determines how earlier a token should be considered
// expired than its actual expiration time. It is used to avoid late
// expirations due to client-server time mismatches.
//
// NOTE(ericchiang): this is take from golang.org/x/oauth2
const expiryDelta = 10 * time.Second
var cache = newClientCache()
// Like TLS transports, keep a cache of OIDC clients indexed by issuer URL.
type clientCache struct {
mu sync.RWMutex
cache map[cacheKey]*oidcAuthProvider
}
func newClientCache() *clientCache {
return &clientCache{cache: make(map[cacheKey]*oidcAuthProvider)}
}
type cacheKey struct {
// Canonical issuer URL string of the provider.
issuerURL string
clientID string
clientSecret string
// Don't use CA as cache key because we only add a cache entry if we can connect
// to the issuer in the first place. A valid CA is a prerequisite.
}
func (c *clientCache) getClient(issuer, clientID, clientSecret string) (*oidcAuthProvider, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
client, ok := c.cache[cacheKey{issuer, clientID, clientSecret}]
return client, ok
}
// setClient attempts to put the client in the cache but may return any clients
// with the same keys set before. This is so there's only ever one client for a provider.
func (c *clientCache) setClient(issuer, clientID, clientSecret string, client *oidcAuthProvider) *oidcAuthProvider {
c.mu.Lock()
defer c.mu.Unlock()
key := cacheKey{issuer, clientID, clientSecret}
// If another client has already initialized a client for the given provider we want
// to use that client instead of the one we're trying to set. This is so all transports
// share a client and can coordinate around the same mutex when refreshing and writing
// to the kubeconfig.
if oldClient, ok := c.cache[key]; ok {
return oldClient
}
c.cache[key] = client
return client
}
func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
issuer := cfg[cfgIssuerUrl]
if issuer == "" {
@@ -75,6 +122,11 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
return nil, fmt.Errorf("Must provide %s", cfgClientSecret)
}
// Check cache for existing provider.
if provider, ok := cache.getClient(issuer, clientID, clientSecret); ok {
return provider, nil
}
var certAuthData []byte
var err error
if cfg[cfgCertificateAuthorityData] != "" {
@@ -112,146 +164,134 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
ProviderConfig: providerCfg,
Scope: append(scopes, oidc.DefaultScope...),
}
client, err := oidc.NewClient(oidcCfg)
if err != nil {
return nil, fmt.Errorf("error creating OIDC Client: %v", err)
}
oClient := &oidcClient{client}
var initialIDToken jose.JWT
if cfg[cfgIDToken] != "" {
initialIDToken, err = jose.ParseJWT(cfg[cfgIDToken])
if err != nil {
return nil, err
}
provider := &oidcAuthProvider{
client: &oidcClient{client},
cfg: cfg,
persister: persister,
now: time.Now,
}
return &oidcAuthProvider{
initialIDToken: initialIDToken,
refresher: &idTokenRefresher{
client: oClient,
cfg: cfg,
persister: persister,
},
}, nil
return cache.setClient(issuer, clientID, clientSecret, provider), nil
}
type oidcAuthProvider struct {
refresher *idTokenRefresher
initialIDToken jose.JWT
// Interface rather than a raw *oidc.Client for testing.
client OIDCClient
// Stubbed out for testing.
now func() time.Time
// Mutex guards persisting to the kubeconfig file and allows synchronized
// updates to the in-memory config. It also ensures concurrent calls to
// the RoundTripper only trigger a single refresh request.
mu sync.Mutex
cfg map[string]string
persister restclient.AuthProviderConfigPersister
}
func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
at := &oidc.AuthenticatedTransport{
TokenRefresher: g.refresher,
RoundTripper: rt,
}
at.SetJWT(g.initialIDToken)
func (p *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
return &roundTripper{
wrapped: at,
refresher: g.refresher,
wrapped: rt,
provider: p,
}
}
func (g *oidcAuthProvider) Login() error {
func (p *oidcAuthProvider) Login() error {
return errors.New("not yet implemented")
}
type OIDCClient interface {
refreshToken(rt string) (oauth2.TokenResponse, error)
verifyJWT(jwt jose.JWT) error
verifyJWT(jwt *jose.JWT) error
}
type roundTripper struct {
refresher *idTokenRefresher
wrapped *oidc.AuthenticatedTransport
provider *oidcAuthProvider
wrapped http.RoundTripper
}
func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var res *http.Response
var err error
firstTime := true
wait.ExponentialBackoff(backoff, func() (bool, error) {
if !firstTime {
var jwt jose.JWT
jwt, err = r.refresher.Refresh()
if err != nil {
return true, nil
}
r.wrapped.SetJWT(jwt)
} else {
firstTime = false
}
token, err := r.provider.idToken()
if err != nil {
return nil, err
}
res, err = r.wrapped.RoundTrip(req)
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *req
// deep copy of the Header so we don't modify the original
// request's Header (as per RoundTripper contract).
r2.Header = make(http.Header)
for k, s := range req.Header {
r2.Header[k] = s
}
r2.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return r.wrapped.RoundTrip(r2)
}
func (p *oidcAuthProvider) idToken() (string, error) {
p.mu.Lock()
defer p.mu.Unlock()
if idToken, ok := p.cfg[cfgIDToken]; ok && len(idToken) > 0 {
valid, err := verifyJWTExpiry(p.now(), idToken)
if err != nil {
return true, nil
return "", err
}
if res.StatusCode == http.StatusUnauthorized {
return false, nil
if valid {
// If the cached id token is still valid use it.
return idToken, nil
}
return true, nil
})
return res, err
}
}
type idTokenRefresher struct {
cfg map[string]string
client OIDCClient
persister restclient.AuthProviderConfigPersister
intialIDToken jose.JWT
}
// Try to request a new token using the refresh token.
rt, ok := p.cfg[cfgRefreshToken]
if !ok || len(rt) == 0 {
return "", errors.New("No valid id-token, and cannot refresh without refresh-token")
}
func (r *idTokenRefresher) Verify(jwt jose.JWT) error {
claims, err := jwt.Claims()
tokens, err := p.client.refreshToken(rt)
if err != nil {
return err
}
now := time.Now()
exp, ok, err := claims.TimeClaim("exp")
switch {
case err != nil:
return fmt.Errorf("failed to parse 'exp' claim: %v", err)
case !ok:
return errors.New("missing required 'exp' claim")
case exp.Before(now):
return fmt.Errorf("token already expired at: %v", exp)
}
return nil
}
func (r *idTokenRefresher) Refresh() (jose.JWT, error) {
rt, ok := r.cfg[cfgRefreshToken]
if !ok {
return jose.JWT{}, errors.New("No valid id-token, and cannot refresh without refresh-token")
}
tokens, err := r.client.refreshToken(rt)
if err != nil {
return jose.JWT{}, fmt.Errorf("could not refresh token: %v", err)
return "", fmt.Errorf("could not refresh token: %v", err)
}
jwt, err := jose.ParseJWT(tokens.IDToken)
if err != nil {
return jose.JWT{}, err
return "", err
}
if err := p.client.verifyJWT(&jwt); err != nil {
return "", err
}
// Create a new config to persist.
newCfg := make(map[string]string)
for key, val := range p.cfg {
newCfg[key] = val
}
if tokens.RefreshToken != "" && tokens.RefreshToken != rt {
r.cfg[cfgRefreshToken] = tokens.RefreshToken
}
r.cfg[cfgIDToken] = jwt.Encode()
err = r.persister.Persist(r.cfg)
if err != nil {
return jose.JWT{}, fmt.Errorf("could not perist new tokens: %v", err)
newCfg[cfgRefreshToken] = tokens.RefreshToken
}
return jwt, r.client.verifyJWT(jwt)
newCfg[cfgIDToken] = tokens.IDToken
if err = p.persister.Persist(newCfg); err != nil {
return "", fmt.Errorf("could not perist new tokens: %v", err)
}
// Update the in memory config to reflect the on disk one.
p.cfg = newCfg
return tokens.IDToken, nil
}
// oidcClient is the real implementation of the OIDCClient interface, which is
// used for testing.
type oidcClient struct {
client *oidc.Client
}
@@ -265,6 +305,29 @@ func (o *oidcClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt)
}
func (o *oidcClient) verifyJWT(jwt jose.JWT) error {
return o.client.VerifyJWT(jwt)
func (o *oidcClient) verifyJWT(jwt *jose.JWT) error {
return o.client.VerifyJWT(*jwt)
}
func verifyJWTExpiry(now time.Time, s string) (valid bool, err error) {
jwt, err := jose.ParseJWT(s)
if err != nil {
return false, fmt.Errorf("invalid %q", cfgIDToken)
}
claims, err := jwt.Claims()
if err != nil {
return false, err
}
exp, ok, err := claims.TimeClaim("exp")
switch {
case err != nil:
return false, fmt.Errorf("failed to parse 'exp' claim: %v", err)
case !ok:
return false, errors.New("missing required 'exp' claim")
case exp.After(now.Add(expiryDelta)):
return true, nil
}
return false, nil
}