auth: return token structs from fetcher functions

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
This commit is contained in:
Tonis Tiigi 2020-08-04 12:54:40 -07:00
parent 957bcb3dff
commit b5185eae6c
2 changed files with 53 additions and 39 deletions

View File

@ -57,6 +57,7 @@ func newUnexpectedStatusErr(resp *http.Response) error {
return ErrUnexpectedStatus{Status: resp.Status, StatusCode: resp.StatusCode, Body: b} return ErrUnexpectedStatus{Status: resp.Status, StatusCode: resp.StatusCode, Body: b}
} }
// GenerateTokenOptions generates options for fetching a token based on a challenge
func GenerateTokenOptions(ctx context.Context, host, username, secret string, c Challenge) (TokenOptions, error) { func GenerateTokenOptions(ctx context.Context, host, username, secret string, c Challenge) (TokenOptions, error) {
realm, ok := c.Parameters["realm"] realm, ok := c.Parameters["realm"]
if !ok { if !ok {
@ -94,7 +95,8 @@ type TokenOptions struct {
Secret string Secret string
} }
type postTokenResponse struct { // OAuthTokenResponse is response from fetching token with a OAuth POST request
type OAuthTokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
@ -102,7 +104,8 @@ type postTokenResponse struct {
Scope string `json:"scope"` Scope string `json:"scope"`
} }
func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.Header, clientID string, to TokenOptions) (string, error) { // FetchTokenWithOAuth fetches a token using a POST request
func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.Header, clientID string, to TokenOptions) (*OAuthTokenResponse, error) {
form := url.Values{} form := url.Values{}
if len(to.Scopes) > 0 { if len(to.Scopes) > 0 {
form.Set("scope", strings.Join(to.Scopes, " ")) form.Set("scope", strings.Join(to.Scopes, " "))
@ -121,7 +124,7 @@ func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.
req, err := http.NewRequest("POST", to.Realm, strings.NewReader(form.Encode())) req, err := http.NewRequest("POST", to.Realm, strings.NewReader(form.Encode()))
if err != nil { if err != nil {
return "", err return nil, err
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
if headers != nil { if headers != nil {
@ -132,25 +135,30 @@ func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.
resp, err := ctxhttp.Do(ctx, client, req) resp, err := ctxhttp.Do(ctx, client, req)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 { if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return "", errors.WithStack(newUnexpectedStatusErr(resp)) return nil, errors.WithStack(newUnexpectedStatusErr(resp))
} }
decoder := json.NewDecoder(resp.Body) decoder := json.NewDecoder(resp.Body)
var tr postTokenResponse var tr OAuthTokenResponse
if err = decoder.Decode(&tr); err != nil { if err = decoder.Decode(&tr); err != nil {
return "", errors.Errorf("unable to decode token response: %s", err) return nil, errors.Wrap(err, "unable to decode token response")
} }
return tr.AccessToken, nil if tr.AccessToken == "" {
return nil, errors.WithStack(ErrNoToken)
}
return &tr, nil
} }
type getTokenResponse struct { // FetchTokenResponse is response from fetching token with GET request
type FetchTokenResponse struct {
Token string `json:"token"` Token string `json:"token"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
@ -159,10 +167,10 @@ type getTokenResponse struct {
} }
// FetchToken fetches a token using a GET request // FetchToken fetches a token using a GET request
func FetchToken(ctx context.Context, client *http.Client, headers http.Header, to TokenOptions) (string, error) { func FetchToken(ctx context.Context, client *http.Client, headers http.Header, to TokenOptions) (*FetchTokenResponse, error) {
req, err := http.NewRequest("GET", to.Realm, nil) req, err := http.NewRequest("GET", to.Realm, nil)
if err != nil { if err != nil {
return "", err return nil, err
} }
if headers != nil { if headers != nil {
@ -189,19 +197,19 @@ func FetchToken(ctx context.Context, client *http.Client, headers http.Header, t
resp, err := ctxhttp.Do(ctx, client, req) resp, err := ctxhttp.Do(ctx, client, req)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 { if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return "", errors.WithStack(newUnexpectedStatusErr(resp)) return nil, errors.WithStack(newUnexpectedStatusErr(resp))
} }
decoder := json.NewDecoder(resp.Body) decoder := json.NewDecoder(resp.Body)
var tr getTokenResponse var tr FetchTokenResponse
if err = decoder.Decode(&tr); err != nil { if err = decoder.Decode(&tr); err != nil {
return "", errors.Errorf("unable to decode token response: %s", err) return nil, errors.Wrap(err, "unable to decode token response")
} }
// `access_token` is equivalent to `token` and if both are specified // `access_token` is equivalent to `token` and if both are specified
@ -212,8 +220,8 @@ func FetchToken(ctx context.Context, client *http.Client, headers http.Header, t
} }
if tr.Token == "" { if tr.Token == "" {
return "", ErrNoToken return nil, errors.WithStack(ErrNoToken)
} }
return tr.Token, nil return &tr, nil
} }

View File

@ -241,7 +241,7 @@ func (ah *authHandler) doBasicAuth(ctx context.Context) (string, error) {
return fmt.Sprintf("Basic %s", auth), nil return fmt.Sprintf("Basic %s", auth), nil
} }
func (ah *authHandler) doBearerAuth(ctx context.Context) (string, error) { func (ah *authHandler) doBearerAuth(ctx context.Context) (token string, err error) {
// copy common tokenOptions // copy common tokenOptions
to := ah.common to := ah.common
@ -263,15 +263,20 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (string, error) {
ah.scopedTokens[scoped] = r ah.scopedTokens[scoped] = r
ah.Unlock() ah.Unlock()
defer func() {
token = fmt.Sprintf("Bearer %s", token)
r.token, r.err = token, err
r.Done()
}()
// fetch token for the resource scope // fetch token for the resource scope
var (
token string
err error
)
if to.Secret != "" { if to.Secret != "" {
defer func() {
err = errors.Wrap(err, "failed to fetch oauth token")
}()
// credential information is provided, use oauth POST endpoint // credential information is provided, use oauth POST endpoint
// TODO: Allow setting client_id // TODO: Allow setting client_id
token, err = auth.FetchTokenWithOAuth(ctx, ah.client, ah.header, "containerd-client", to) resp, err := auth.FetchTokenWithOAuth(ctx, ah.client, ah.header, "containerd-client", to)
if err != nil { if err != nil {
var errStatus auth.ErrUnexpectedStatus var errStatus auth.ErrUnexpectedStatus
if errors.As(err, &errStatus) { if errors.As(err, &errStatus) {
@ -279,26 +284,27 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (string, error) {
// As of September 2017, GCR is known to return 404. // As of September 2017, GCR is known to return 404.
// As of February 2018, JFrog Artifactory is known to return 401. // As of February 2018, JFrog Artifactory is known to return 401.
if (errStatus.StatusCode == 405 && to.Username != "") || errStatus.StatusCode == 404 || errStatus.StatusCode == 401 { if (errStatus.StatusCode == 405 && to.Username != "") || errStatus.StatusCode == 404 || errStatus.StatusCode == 401 {
token, err = auth.FetchToken(ctx, ah.client, ah.header, to) resp, err := auth.FetchToken(ctx, ah.client, ah.header, to)
} else { if err != nil {
log.G(ctx).WithFields(logrus.Fields{ return "", err
"status": errStatus.Status, }
"body": string(errStatus.Body), return resp.Token, nil
}).Debugf("token request failed")
} }
log.G(ctx).WithFields(logrus.Fields{
"status": errStatus.Status,
"body": string(errStatus.Body),
}).Debugf("token request failed")
} }
return "", err
} }
err = errors.Wrap(err, "failed to fetch oauth token") return resp.AccessToken, nil
} else {
// do request anonymously
token, err = auth.FetchToken(ctx, ah.client, ah.header, to)
err = errors.Wrap(err, "failed to fetch anonymous token")
} }
token = fmt.Sprintf("Bearer %s", token) // do request anonymously
resp, err := auth.FetchToken(ctx, ah.client, ah.header, to)
r.token, r.err = token, err if err != nil {
r.Done() return "", errors.Wrap(err, "failed to fetch anonymous token")
return r.token, r.err }
return resp.Token, nil
} }
func invalidAuthorization(c auth.Challenge, responses []*http.Response) error { func invalidAuthorization(c auth.Challenge, responses []*http.Response) error {