docker: split private token helper functions to reusable pkg

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
This commit is contained in:
Tonis Tiigi 2020-08-02 22:30:31 -07:00
parent bd92d567a5
commit 957bcb3dff
4 changed files with 292 additions and 228 deletions

View File

@ -0,0 +1,219 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/containerd/containerd/log"
"github.com/pkg/errors"
"golang.org/x/net/context/ctxhttp"
)
var (
// ErrNoToken is returned if a request is successful but the body does not
// contain an authorization token.
ErrNoToken = errors.New("authorization server did not include a token in the response")
)
// ErrUnexpectedStatus is returned if a token request returned with unexpected HTTP status
type ErrUnexpectedStatus struct {
Status string
StatusCode int
Body []byte
}
func (e ErrUnexpectedStatus) Error() string {
return fmt.Sprintf("unexpected status: %s", e.Status)
}
func newUnexpectedStatusErr(resp *http.Response) error {
var b []byte
if resp.Body != nil {
b, _ = ioutil.ReadAll(io.LimitReader(resp.Body, 64000)) // 64KB
}
return ErrUnexpectedStatus{Status: resp.Status, StatusCode: resp.StatusCode, Body: b}
}
func GenerateTokenOptions(ctx context.Context, host, username, secret string, c Challenge) (TokenOptions, error) {
realm, ok := c.Parameters["realm"]
if !ok {
return TokenOptions{}, errors.New("no realm specified for token auth challenge")
}
realmURL, err := url.Parse(realm)
if err != nil {
return TokenOptions{}, errors.Wrap(err, "invalid token auth challenge realm")
}
to := TokenOptions{
Realm: realmURL.String(),
Service: c.Parameters["service"],
Username: username,
Secret: secret,
}
scope, ok := c.Parameters["scope"]
if ok {
to.Scopes = append(to.Scopes, scope)
} else {
log.G(ctx).WithField("host", host).Debug("no scope specified for token auth challenge")
}
return to, nil
}
// TokenOptions are optios for requesting a token
type TokenOptions struct {
Realm string
Service string
Scopes []string
Username string
Secret string
}
type postTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"`
}
func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.Header, clientID string, to TokenOptions) (string, error) {
form := url.Values{}
if len(to.Scopes) > 0 {
form.Set("scope", strings.Join(to.Scopes, " "))
}
form.Set("service", to.Service)
form.Set("client_id", clientID)
if to.Username == "" {
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", to.Secret)
} else {
form.Set("grant_type", "password")
form.Set("username", to.Username)
form.Set("password", to.Secret)
}
req, err := http.NewRequest("POST", to.Realm, strings.NewReader(form.Encode()))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
if headers != nil {
for k, v := range headers {
req.Header[k] = append(req.Header[k], v...)
}
}
resp, err := ctxhttp.Do(ctx, client, req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return "", errors.WithStack(newUnexpectedStatusErr(resp))
}
decoder := json.NewDecoder(resp.Body)
var tr postTokenResponse
if err = decoder.Decode(&tr); err != nil {
return "", errors.Errorf("unable to decode token response: %s", err)
}
return tr.AccessToken, nil
}
type getTokenResponse struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"`
}
// FetchToken fetches a token using a GET request
func FetchToken(ctx context.Context, client *http.Client, headers http.Header, to TokenOptions) (string, error) {
req, err := http.NewRequest("GET", to.Realm, nil)
if err != nil {
return "", err
}
if headers != nil {
for k, v := range headers {
req.Header[k] = append(req.Header[k], v...)
}
}
reqParams := req.URL.Query()
if to.Service != "" {
reqParams.Add("service", to.Service)
}
for _, scope := range to.Scopes {
reqParams.Add("scope", scope)
}
if to.Secret != "" {
req.SetBasicAuth(to.Username, to.Secret)
}
req.URL.RawQuery = reqParams.Encode()
resp, err := ctxhttp.Do(ctx, client, req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return "", errors.WithStack(newUnexpectedStatusErr(resp))
}
decoder := json.NewDecoder(resp.Body)
var tr getTokenResponse
if err = decoder.Decode(&tr); err != nil {
return "", errors.Errorf("unable to decode token response: %s", err)
}
// `access_token` is equivalent to `token` and if both are specified
// the choice is undefined. Canonicalize `access_token` by sticking
// things in `token`.
if tr.AccessToken != "" {
tr.Token = tr.AccessToken
}
if tr.Token == "" {
return "", ErrNoToken
}
return tr.Token, nil
}

View File

@ -14,7 +14,7 @@
limitations under the License. limitations under the License.
*/ */
package docker package auth
import ( import (
"net/http" "net/http"
@ -22,31 +22,35 @@ import (
"strings" "strings"
) )
type authenticationScheme byte // AuthenticationScheme defines scheme of the authentication method
type AuthenticationScheme byte
const ( const (
basicAuth authenticationScheme = 1 << iota // Defined in RFC 7617 // BasicAuth is scheme for Basic HTTP Authentication RFC 7617
digestAuth // Defined in RFC 7616 BasicAuth AuthenticationScheme = 1 << iota
bearerAuth // Defined in RFC 6750 // DigestAuth is scheme for HTTP Digest Access Authentication RFC 7616
DigestAuth
// BearerAuth is scheme for OAuth 2.0 Bearer Tokens RFC 6750
BearerAuth
) )
// challenge carries information from a WWW-Authenticate response header. // Challenge carries information from a WWW-Authenticate response header.
// See RFC 2617. // See RFC 2617.
type challenge struct { type Challenge struct {
// scheme is the auth-scheme according to RFC 2617 // scheme is the auth-scheme according to RFC 2617
scheme authenticationScheme Scheme AuthenticationScheme
// parameters are the auth-params according to RFC 2617 // parameters are the auth-params according to RFC 2617
parameters map[string]string Parameters map[string]string
} }
type byScheme []challenge type byScheme []Challenge
func (bs byScheme) Len() int { return len(bs) } func (bs byScheme) Len() int { return len(bs) }
func (bs byScheme) Swap(i, j int) { bs[i], bs[j] = bs[j], bs[i] } func (bs byScheme) Swap(i, j int) { bs[i], bs[j] = bs[j], bs[i] }
// Sort in priority order: token > digest > basic // Sort in priority order: token > digest > basic
func (bs byScheme) Less(i, j int) bool { return bs[i].scheme > bs[j].scheme } func (bs byScheme) Less(i, j int) bool { return bs[i].Scheme > bs[j].Scheme }
// Octet types from RFC 2616. // Octet types from RFC 2616.
type octetType byte type octetType byte
@ -90,22 +94,23 @@ func init() {
} }
} }
func parseAuthHeader(header http.Header) []challenge { // ParseAuthHeader parses challenges from WWW-Authenticate header
challenges := []challenge{} func ParseAuthHeader(header http.Header) []Challenge {
challenges := []Challenge{}
for _, h := range header[http.CanonicalHeaderKey("WWW-Authenticate")] { for _, h := range header[http.CanonicalHeaderKey("WWW-Authenticate")] {
v, p := parseValueAndParams(h) v, p := parseValueAndParams(h)
var s authenticationScheme var s AuthenticationScheme
switch v { switch v {
case "basic": case "basic":
s = basicAuth s = BasicAuth
case "digest": case "digest":
s = digestAuth s = DigestAuth
case "bearer": case "bearer":
s = bearerAuth s = BearerAuth
default: default:
continue continue
} }
challenges = append(challenges, challenge{scheme: s, parameters: p}) challenges = append(challenges, Challenge{Scheme: s, Parameters: p})
} }
sort.Stable(byScheme(challenges)) sort.Stable(byScheme(challenges))
return challenges return challenges

View File

@ -19,21 +19,16 @@ package docker
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/json"
"fmt" "fmt"
"io"
"io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
"time"
"github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/log" "github.com/containerd/containerd/log"
"github.com/containerd/containerd/remotes/docker/auth"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/context/ctxhttp"
) )
type dockerAuthorizer struct { type dockerAuthorizer struct {
@ -135,8 +130,8 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
for _, c := range parseAuthHeader(last.Header) { for _, c := range auth.ParseAuthHeader(last.Header) {
if c.scheme == bearerAuth { if c.Scheme == auth.BearerAuth {
if err := invalidAuthorization(c, responses); err != nil { if err := invalidAuthorization(c, responses); err != nil {
delete(a.handlers, host) delete(a.handlers, host)
return err return err
@ -152,26 +147,35 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R
return nil return nil
} }
common, err := a.generateTokenOptions(ctx, host, c) var username, secret string
if a.credentials != nil {
var err error
username, secret, err = a.credentials(host)
if err != nil {
return err
}
}
common, err := auth.GenerateTokenOptions(ctx, host, username, secret, c)
if err != nil { if err != nil {
return err return err
} }
a.handlers[host] = newAuthHandler(a.client, a.header, c.scheme, common) a.handlers[host] = newAuthHandler(a.client, a.header, c.Scheme, common)
return nil return nil
} else if c.scheme == basicAuth && a.credentials != nil { } else if c.Scheme == auth.BasicAuth && a.credentials != nil {
username, secret, err := a.credentials(host) username, secret, err := a.credentials(host)
if err != nil { if err != nil {
return err return err
} }
if username != "" && secret != "" { if username != "" && secret != "" {
common := tokenOptions{ common := auth.TokenOptions{
username: username, Username: username,
secret: secret, Secret: secret,
} }
a.handlers[host] = newAuthHandler(a.client, a.header, c.scheme, common) a.handlers[host] = newAuthHandler(a.client, a.header, c.Scheme, common)
return nil return nil
} }
} }
@ -179,38 +183,6 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R
return errors.Wrap(errdefs.ErrNotImplemented, "failed to find supported auth scheme") return errors.Wrap(errdefs.ErrNotImplemented, "failed to find supported auth scheme")
} }
func (a *dockerAuthorizer) generateTokenOptions(ctx context.Context, host string, c challenge) (tokenOptions, error) {
realm, ok := c.parameters["realm"]
if !ok {
return tokenOptions{}, errors.New("no realm specified for token auth challenge")
}
realmURL, err := url.Parse(realm)
if err != nil {
return tokenOptions{}, errors.Wrap(err, "invalid token auth challenge realm")
}
to := tokenOptions{
realm: realmURL.String(),
service: c.parameters["service"],
}
scope, ok := c.parameters["scope"]
if ok {
to.scopes = append(to.scopes, scope)
} else {
log.G(ctx).WithField("host", host).Debug("no scope specified for token auth challenge")
}
if a.credentials != nil {
to.username, to.secret, err = a.credentials(host)
if err != nil {
return tokenOptions{}, err
}
}
return to, nil
}
// authResult is used to control limit rate. // authResult is used to control limit rate.
type authResult struct { type authResult struct {
sync.WaitGroup sync.WaitGroup
@ -227,17 +199,17 @@ type authHandler struct {
client *http.Client client *http.Client
// only support basic and bearer schemes // only support basic and bearer schemes
scheme authenticationScheme scheme auth.AuthenticationScheme
// common contains common challenge answer // common contains common challenge answer
common tokenOptions common auth.TokenOptions
// scopedTokens caches token indexed by scopes, which used in // scopedTokens caches token indexed by scopes, which used in
// bearer auth case // bearer auth case
scopedTokens map[string]*authResult scopedTokens map[string]*authResult
} }
func newAuthHandler(client *http.Client, hdr http.Header, scheme authenticationScheme, opts tokenOptions) *authHandler { func newAuthHandler(client *http.Client, hdr http.Header, scheme auth.AuthenticationScheme, opts auth.TokenOptions) *authHandler {
return &authHandler{ return &authHandler{
header: hdr, header: hdr,
client: client, client: client,
@ -249,17 +221,17 @@ func newAuthHandler(client *http.Client, hdr http.Header, scheme authenticationS
func (ah *authHandler) authorize(ctx context.Context) (string, error) { func (ah *authHandler) authorize(ctx context.Context) (string, error) {
switch ah.scheme { switch ah.scheme {
case basicAuth: case auth.BasicAuth:
return ah.doBasicAuth(ctx) return ah.doBasicAuth(ctx)
case bearerAuth: case auth.BearerAuth:
return ah.doBearerAuth(ctx) return ah.doBearerAuth(ctx)
default: default:
return "", errors.Wrap(errdefs.ErrNotImplemented, "failed to find supported auth scheme") return "", errors.Wrapf(errdefs.ErrNotImplemented, "failed to find supported auth scheme: %s", string(ah.scheme))
} }
} }
func (ah *authHandler) doBasicAuth(ctx context.Context) (string, error) { func (ah *authHandler) doBasicAuth(ctx context.Context) (string, error) {
username, secret := ah.common.username, ah.common.secret username, secret := ah.common.Username, ah.common.Secret
if username == "" || secret == "" { if username == "" || secret == "" {
return "", fmt.Errorf("failed to handle basic auth because missing username or secret") return "", fmt.Errorf("failed to handle basic auth because missing username or secret")
@ -273,10 +245,10 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (string, error) {
// copy common tokenOptions // copy common tokenOptions
to := ah.common to := ah.common
to.scopes = GetTokenScopes(ctx, to.scopes) to.Scopes = GetTokenScopes(ctx, to.Scopes)
// Docs: https://docs.docker.com/registry/spec/auth/scope // Docs: https://docs.docker.com/registry/spec/auth/scope
scoped := strings.Join(to.scopes, " ") scoped := strings.Join(to.Scopes, " ")
ah.Lock() ah.Lock()
if r, exist := ah.scopedTokens[scoped]; exist { if r, exist := ah.scopedTokens[scoped]; exist {
@ -296,13 +268,30 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (string, error) {
token string token string
err error err error
) )
if to.secret != "" { if to.Secret != "" {
// credential information is provided, use oauth POST endpoint // credential information is provided, use oauth POST endpoint
token, err = ah.fetchTokenWithOAuth(ctx, to) // TODO: Allow setting client_id
token, err = auth.FetchTokenWithOAuth(ctx, ah.client, ah.header, "containerd-client", to)
if err != nil {
var errStatus auth.ErrUnexpectedStatus
if errors.As(err, &errStatus) {
// Registries without support for POST may return 404 for POST /v2/token.
// As of September 2017, GCR is known to return 404.
// As of February 2018, JFrog Artifactory is known to return 401.
if (errStatus.StatusCode == 405 && to.Username != "") || errStatus.StatusCode == 404 || errStatus.StatusCode == 401 {
token, err = auth.FetchToken(ctx, ah.client, ah.header, to)
} else {
log.G(ctx).WithFields(logrus.Fields{
"status": errStatus.Status,
"body": string(errStatus.Body),
}).Debugf("token request failed")
}
}
}
err = errors.Wrap(err, "failed to fetch oauth token") err = errors.Wrap(err, "failed to fetch oauth token")
} else { } else {
// do request anonymously // do request anonymously
token, err = ah.fetchToken(ctx, to) token, err = auth.FetchToken(ctx, ah.client, ah.header, to)
err = errors.Wrap(err, "failed to fetch anonymous token") err = errors.Wrap(err, "failed to fetch anonymous token")
} }
token = fmt.Sprintf("Bearer %s", token) token = fmt.Sprintf("Bearer %s", token)
@ -312,153 +301,8 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (string, error) {
return r.token, r.err return r.token, r.err
} }
type tokenOptions struct { func invalidAuthorization(c auth.Challenge, responses []*http.Response) error {
realm string errStr := c.Parameters["error"]
service string
scopes []string
username string
secret string
}
type postTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"`
}
func (ah *authHandler) fetchTokenWithOAuth(ctx context.Context, to tokenOptions) (string, error) {
form := url.Values{}
if len(to.scopes) > 0 {
form.Set("scope", strings.Join(to.scopes, " "))
}
form.Set("service", to.service)
// TODO: Allow setting client_id
form.Set("client_id", "containerd-client")
if to.username == "" {
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", to.secret)
} else {
form.Set("grant_type", "password")
form.Set("username", to.username)
form.Set("password", to.secret)
}
req, err := http.NewRequest("POST", to.realm, strings.NewReader(form.Encode()))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
if ah.header != nil {
for k, v := range ah.header {
req.Header[k] = append(req.Header[k], v...)
}
}
resp, err := ctxhttp.Do(ctx, ah.client, req)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Registries without support for POST may return 404 for POST /v2/token.
// As of September 2017, GCR is known to return 404.
// As of February 2018, JFrog Artifactory is known to return 401.
if (resp.StatusCode == 405 && to.username != "") || resp.StatusCode == 404 || resp.StatusCode == 401 {
return ah.fetchToken(ctx, to)
} else if resp.StatusCode < 200 || resp.StatusCode >= 400 {
b, _ := ioutil.ReadAll(io.LimitReader(resp.Body, 64000)) // 64KB
log.G(ctx).WithFields(logrus.Fields{
"status": resp.Status,
"body": string(b),
}).Debugf("token request failed")
// TODO: handle error body and write debug output
return "", errors.Errorf("unexpected status: %s", resp.Status)
}
decoder := json.NewDecoder(resp.Body)
var tr postTokenResponse
if err = decoder.Decode(&tr); err != nil {
return "", fmt.Errorf("unable to decode token response: %s", err)
}
return tr.AccessToken, nil
}
type getTokenResponse struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"`
}
// fetchToken fetches a token using a GET request
func (ah *authHandler) fetchToken(ctx context.Context, to tokenOptions) (string, error) {
req, err := http.NewRequest("GET", to.realm, nil)
if err != nil {
return "", err
}
if ah.header != nil {
for k, v := range ah.header {
req.Header[k] = append(req.Header[k], v...)
}
}
reqParams := req.URL.Query()
if to.service != "" {
reqParams.Add("service", to.service)
}
for _, scope := range to.scopes {
reqParams.Add("scope", scope)
}
if to.secret != "" {
req.SetBasicAuth(to.username, to.secret)
}
req.URL.RawQuery = reqParams.Encode()
resp, err := ctxhttp.Do(ctx, ah.client, req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
// TODO: handle error body and write debug output
return "", errors.Errorf("unexpected status: %s", resp.Status)
}
decoder := json.NewDecoder(resp.Body)
var tr getTokenResponse
if err = decoder.Decode(&tr); err != nil {
return "", fmt.Errorf("unable to decode token response: %s", err)
}
// `access_token` is equivalent to `token` and if both are specified
// the choice is undefined. Canonicalize `access_token` by sticking
// things in `token`.
if tr.AccessToken != "" {
tr.Token = tr.AccessToken
}
if tr.Token == "" {
return "", ErrNoToken
}
return tr.Token, nil
}
func invalidAuthorization(c challenge, responses []*http.Response) error {
errStr := c.parameters["error"]
if errStr == "" { if errStr == "" {
return nil return nil
} }

View File

@ -40,10 +40,6 @@ import (
) )
var ( var (
// ErrNoToken is returned if a request is successful but the body does not
// contain an authorization token.
ErrNoToken = errors.New("authorization server did not include a token in the response")
// ErrInvalidAuthorization is used when credentials are passed to a server but // ErrInvalidAuthorization is used when credentials are passed to a server but
// those credentials are rejected. // those credentials are rejected.
ErrInvalidAuthorization = errors.New("authorization failed") ErrInvalidAuthorization = errors.New("authorization failed")