remotes/docker/authorizer.go: invalidate auth tokens when they expire.

Signed-off-by: Iain Macdonald <xiainx@gmail.com>
This commit is contained in:
Iain Macdonald 2023-06-23 14:16:52 -07:00
parent be9336fed1
commit af6a90bf5c
2 changed files with 32 additions and 15 deletions

View File

@ -88,7 +88,7 @@ type TokenOptions struct {
type OAuthTokenResponse struct { type OAuthTokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` ExpiresInSeconds int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"` IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"` Scope string `json:"scope"`
} }
@ -154,7 +154,7 @@ func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.
type FetchTokenResponse struct { type FetchTokenResponse struct {
Token string `json:"token"` Token string `json:"token"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"` ExpiresInSeconds int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"` IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }

View File

@ -24,6 +24,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"time"
"github.com/containerd/containerd/v2/core/remotes/docker/auth" "github.com/containerd/containerd/v2/core/remotes/docker/auth"
remoteerrors "github.com/containerd/containerd/v2/core/remotes/errors" remoteerrors "github.com/containerd/containerd/v2/core/remotes/errors"
@ -207,6 +208,7 @@ type authResult struct {
sync.WaitGroup sync.WaitGroup
token string token string
refreshToken string refreshToken string
expirationTime *time.Time
err error err error
} }
@ -270,8 +272,12 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
// Docs: https://docs.docker.com/registry/spec/auth/scope // Docs: https://docs.docker.com/registry/spec/auth/scope
scoped := strings.Join(to.Scopes, " ") scoped := strings.Join(to.Scopes, " ")
// Keep track of the expiration time of cached bearer tokens so they can be
// refreshed when they expire without a server roundtrip.
var expirationTime *time.Time
ah.Lock() ah.Lock()
if r, exist := ah.scopedTokens[scoped]; exist { if r, exist := ah.scopedTokens[scoped]; exist && (r.expirationTime == nil || r.expirationTime.After(time.Now())) {
ah.Unlock() ah.Unlock()
r.Wait() r.Wait()
return r.token, r.refreshToken, r.err return r.token, r.refreshToken, r.err
@ -285,7 +291,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
defer func() { defer func() {
token = fmt.Sprintf("Bearer %s", token) token = fmt.Sprintf("Bearer %s", token)
r.token, r.refreshToken, r.err = token, refreshToken, err r.token, r.refreshToken, r.err, r.expirationTime = token, refreshToken, err, expirationTime
r.Done() r.Done()
}() }()
@ -311,6 +317,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.Token, resp.RefreshToken, nil return resp.Token, resp.RefreshToken, nil
} }
log.G(ctx).WithFields(log.Fields{ log.G(ctx).WithFields(log.Fields{
@ -320,6 +327,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
} }
return "", "", err return "", "", err
} }
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.AccessToken, resp.RefreshToken, nil return resp.AccessToken, resp.RefreshToken, nil
} }
// do request anonymously // do request anonymously
@ -327,9 +335,18 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
if err != nil { if err != nil {
return "", "", fmt.Errorf("failed to fetch anonymous token: %w", err) return "", "", fmt.Errorf("failed to fetch anonymous token: %w", err)
} }
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.Token, resp.RefreshToken, nil return resp.Token, resp.RefreshToken, nil
} }
func getExpirationTime(expiresInSeconds int) *time.Time {
if expiresInSeconds <= 0 {
return nil
}
expirationTime := time.Now().Add(time.Duration(expiresInSeconds) * time.Second)
return &expirationTime
}
func invalidAuthorization(ctx context.Context, c auth.Challenge, responses []*http.Response) (retry bool, _ error) { func invalidAuthorization(ctx context.Context, c auth.Challenge, responses []*http.Response) (retry bool, _ error) {
errStr := c.Parameters["error"] errStr := c.Parameters["error"]
if errStr == "" { if errStr == "" {