remotes/docker/authorizer.go: invalidate auth tokens when they expire.

Signed-off-by: Iain Macdonald <xiainx@gmail.com>
This commit is contained in:
Iain Macdonald 2023-06-23 14:16:52 -07:00
parent be9336fed1
commit af6a90bf5c
2 changed files with 32 additions and 15 deletions

View File

@ -86,11 +86,11 @@ type TokenOptions struct {
// OAuthTokenResponse is response from fetching token with a OAuth POST request // OAuthTokenResponse is response from fetching token with a OAuth POST request
type OAuthTokenResponse struct { type OAuthTokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` ExpiresInSeconds int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"` IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"` Scope string `json:"scope"`
} }
// FetchTokenWithOAuth fetches a token using a POST request // FetchTokenWithOAuth fetches a token using a POST request
@ -152,11 +152,11 @@ func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.
// FetchTokenResponse is response from fetching token with GET request // FetchTokenResponse is response from fetching token with GET request
type FetchTokenResponse struct { type FetchTokenResponse struct {
Token string `json:"token"` Token string `json:"token"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"` ExpiresInSeconds int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"` IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }
// FetchToken fetches a token using a GET request // FetchToken fetches a token using a GET request

View File

@ -24,6 +24,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"time"
"github.com/containerd/containerd/v2/core/remotes/docker/auth" "github.com/containerd/containerd/v2/core/remotes/docker/auth"
remoteerrors "github.com/containerd/containerd/v2/core/remotes/errors" remoteerrors "github.com/containerd/containerd/v2/core/remotes/errors"
@ -205,9 +206,10 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R
// authResult is used to control limit rate. // authResult is used to control limit rate.
type authResult struct { type authResult struct {
sync.WaitGroup sync.WaitGroup
token string token string
refreshToken string refreshToken string
err error expirationTime *time.Time
err error
} }
// authHandler is used to handle auth request per registry server. // authHandler is used to handle auth request per registry server.
@ -270,8 +272,12 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
// Docs: https://docs.docker.com/registry/spec/auth/scope // Docs: https://docs.docker.com/registry/spec/auth/scope
scoped := strings.Join(to.Scopes, " ") scoped := strings.Join(to.Scopes, " ")
// Keep track of the expiration time of cached bearer tokens so they can be
// refreshed when they expire without a server roundtrip.
var expirationTime *time.Time
ah.Lock() ah.Lock()
if r, exist := ah.scopedTokens[scoped]; exist { if r, exist := ah.scopedTokens[scoped]; exist && (r.expirationTime == nil || r.expirationTime.After(time.Now())) {
ah.Unlock() ah.Unlock()
r.Wait() r.Wait()
return r.token, r.refreshToken, r.err return r.token, r.refreshToken, r.err
@ -285,7 +291,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
defer func() { defer func() {
token = fmt.Sprintf("Bearer %s", token) token = fmt.Sprintf("Bearer %s", token)
r.token, r.refreshToken, r.err = token, refreshToken, err r.token, r.refreshToken, r.err, r.expirationTime = token, refreshToken, err, expirationTime
r.Done() r.Done()
}() }()
@ -311,6 +317,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.Token, resp.RefreshToken, nil return resp.Token, resp.RefreshToken, nil
} }
log.G(ctx).WithFields(log.Fields{ log.G(ctx).WithFields(log.Fields{
@ -320,6 +327,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
} }
return "", "", err return "", "", err
} }
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.AccessToken, resp.RefreshToken, nil return resp.AccessToken, resp.RefreshToken, nil
} }
// do request anonymously // do request anonymously
@ -327,9 +335,18 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
if err != nil { if err != nil {
return "", "", fmt.Errorf("failed to fetch anonymous token: %w", err) return "", "", fmt.Errorf("failed to fetch anonymous token: %w", err)
} }
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.Token, resp.RefreshToken, nil return resp.Token, resp.RefreshToken, nil
} }
func getExpirationTime(expiresInSeconds int) *time.Time {
if expiresInSeconds <= 0 {
return nil
}
expirationTime := time.Now().Add(time.Duration(expiresInSeconds) * time.Second)
return &expirationTime
}
func invalidAuthorization(ctx context.Context, c auth.Challenge, responses []*http.Response) (retry bool, _ error) { func invalidAuthorization(ctx context.Context, c auth.Challenge, responses []*http.Response) (retry bool, _ error) {
errStr := c.Parameters["error"] errStr := c.Parameters["error"]
if errStr == "" { if errStr == "" {