Move public key getter to interface
This commit is contained in:
		| @@ -96,7 +96,7 @@ type Extra struct { | ||||
| 	// ServiceAccountIssuerDiscovery | ||||
| 	ServiceAccountIssuerURL        string | ||||
| 	ServiceAccountJWKSURI          string | ||||
| 	ServiceAccountPublicKeys []interface{} | ||||
| 	ServiceAccountPublicKeysGetter serviceaccount.PublicKeysGetter | ||||
|  | ||||
| 	SystemNamespaces []string | ||||
|  | ||||
| @@ -363,6 +363,7 @@ func CreateConfig( | ||||
| 		return nil, nil, fmt.Errorf("failed to apply admission: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if len(opts.Authentication.ServiceAccounts.KeyFiles) > 0 { | ||||
| 		// Load and set the public keys. | ||||
| 		var pubKeys []interface{} | ||||
| 		for _, f := range opts.Authentication.ServiceAccounts.KeyFiles { | ||||
| @@ -372,9 +373,14 @@ func CreateConfig( | ||||
| 			} | ||||
| 			pubKeys = append(pubKeys, keys...) | ||||
| 		} | ||||
| 		keysGetter, err := serviceaccount.StaticPublicKeysGetter(pubKeys) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, fmt.Errorf("failed to set up public service account keys: %w", err) | ||||
| 		} | ||||
| 		config.ServiceAccountPublicKeysGetter = keysGetter | ||||
| 	} | ||||
| 	config.ServiceAccountIssuerURL = opts.Authentication.ServiceAccounts.Issuers[0] | ||||
| 	config.ServiceAccountJWKSURI = opts.Authentication.ServiceAccounts.JWKSURI | ||||
| 	config.ServiceAccountPublicKeys = pubKeys | ||||
|  | ||||
| 	return config, genericInitializers, nil | ||||
| } | ||||
|   | ||||
| @@ -93,13 +93,11 @@ func (c completedConfig) New(name string, delegationTarget genericapiserver.Dele | ||||
| 		routes.Logs{}.Install(generic.Handler.GoRestfulContainer) | ||||
| 	} | ||||
|  | ||||
| 	// Metadata and keys are expected to only change across restarts at present, | ||||
| 	// so we just marshal immediately and serve the cached JSON bytes. | ||||
| 	md, err := serviceaccount.NewOpenIDMetadata( | ||||
| 	md, err := serviceaccount.NewOpenIDMetadataProvider( | ||||
| 		c.ServiceAccountIssuerURL, | ||||
| 		c.ServiceAccountJWKSURI, | ||||
| 		c.Generic.ExternalAddress, | ||||
| 		c.ServiceAccountPublicKeys, | ||||
| 		c.ServiceAccountPublicKeysGetter, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		// If there was an error, skip installing the endpoints and log the | ||||
| @@ -120,8 +118,7 @@ func (c completedConfig) New(name string, delegationTarget genericapiserver.Dele | ||||
| 			klog.Info(msg) | ||||
| 		} | ||||
| 	} else { | ||||
| 		routes.NewOpenIDMetadataServer(md.ConfigJSON, md.PublicKeysetJSON). | ||||
| 			Install(generic.Handler.GoRestfulContainer) | ||||
| 		routes.NewOpenIDMetadataServer(md).Install(generic.Handler.GoRestfulContainer) | ||||
| 	} | ||||
|  | ||||
| 	s := &Server{ | ||||
|   | ||||
| @@ -62,7 +62,6 @@ type Config struct { | ||||
| 	AuthenticationConfig        *apiserver.AuthenticationConfiguration | ||||
| 	AuthenticationConfigData    string | ||||
| 	OIDCSigningAlgs             []string | ||||
| 	ServiceAccountKeyFiles      []string | ||||
| 	ServiceAccountLookup        bool | ||||
| 	ServiceAccountIssuers       []string | ||||
| 	APIAudiences                authenticator.Audiences | ||||
| @@ -79,7 +78,9 @@ type Config struct { | ||||
|  | ||||
| 	RequestHeaderConfig *authenticatorfactory.RequestHeaderConfig | ||||
|  | ||||
| 	// TODO, this is the only non-serializable part of the entire config.  Factor it out into a clientconfig | ||||
| 	// ServiceAccountPublicKeysGetter returns public keys for verifying service account tokens. | ||||
| 	ServiceAccountPublicKeysGetter serviceaccount.PublicKeysGetter | ||||
| 	// ServiceAccountTokenGetter fetches API objects used to verify bound objects in service account token claims. | ||||
| 	ServiceAccountTokenGetter   serviceaccount.ServiceAccountTokenGetter | ||||
| 	SecretsWriter               typedv1core.SecretsGetter | ||||
| 	BootstrapTokenAuthenticator authenticator.Token | ||||
| @@ -127,15 +128,15 @@ func (config Config) New(serverLifecycle context.Context) (authenticator.Request | ||||
| 		} | ||||
| 		tokenAuthenticators = append(tokenAuthenticators, authenticator.WrapAudienceAgnosticToken(config.APIAudiences, tokenAuth)) | ||||
| 	} | ||||
| 	if len(config.ServiceAccountKeyFiles) > 0 { | ||||
| 		serviceAccountAuth, err := newLegacyServiceAccountAuthenticator(config.ServiceAccountKeyFiles, config.ServiceAccountLookup, config.APIAudiences, config.ServiceAccountTokenGetter, config.SecretsWriter) | ||||
| 	if config.ServiceAccountPublicKeysGetter != nil { | ||||
| 		serviceAccountAuth, err := newLegacyServiceAccountAuthenticator(config.ServiceAccountPublicKeysGetter, config.ServiceAccountLookup, config.APIAudiences, config.ServiceAccountTokenGetter, config.SecretsWriter) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, nil, nil, err | ||||
| 		} | ||||
| 		tokenAuthenticators = append(tokenAuthenticators, serviceAccountAuth) | ||||
| 	} | ||||
| 	if len(config.ServiceAccountIssuers) > 0 { | ||||
| 		serviceAccountAuth, err := newServiceAccountAuthenticator(config.ServiceAccountIssuers, config.ServiceAccountKeyFiles, config.APIAudiences, config.ServiceAccountTokenGetter) | ||||
| 		serviceAccountAuth, err := newServiceAccountAuthenticator(config.ServiceAccountIssuers, config.ServiceAccountPublicKeysGetter, config.APIAudiences, config.ServiceAccountTokenGetter) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, nil, nil, err | ||||
| 		} | ||||
| @@ -338,36 +339,25 @@ func newAuthenticatorFromTokenFile(tokenAuthFile string) (authenticator.Token, e | ||||
| } | ||||
|  | ||||
| // newLegacyServiceAccountAuthenticator returns an authenticator.Token or an error | ||||
| func newLegacyServiceAccountAuthenticator(keyfiles []string, lookup bool, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (authenticator.Token, error) { | ||||
| 	allPublicKeys := []interface{}{} | ||||
| 	for _, keyfile := range keyfiles { | ||||
| 		publicKeys, err := keyutil.PublicKeysFromFile(keyfile) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		allPublicKeys = append(allPublicKeys, publicKeys...) | ||||
| func newLegacyServiceAccountAuthenticator(publicKeysGetter serviceaccount.PublicKeysGetter, lookup bool, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (authenticator.Token, error) { | ||||
| 	if publicKeysGetter == nil { | ||||
| 		return nil, fmt.Errorf("no public key getter provided") | ||||
| 	} | ||||
| 	validator, err := serviceaccount.NewLegacyValidator(lookup, serviceAccountGetter, secretsWriter) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("while creating legacy validator, err: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	tokenAuthenticator := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer}, allPublicKeys, apiAudiences, validator) | ||||
| 	tokenAuthenticator := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer}, publicKeysGetter, apiAudiences, validator) | ||||
| 	return tokenAuthenticator, nil | ||||
| } | ||||
|  | ||||
| // newServiceAccountAuthenticator returns an authenticator.Token or an error | ||||
| func newServiceAccountAuthenticator(issuers []string, keyfiles []string, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter) (authenticator.Token, error) { | ||||
| 	allPublicKeys := []interface{}{} | ||||
| 	for _, keyfile := range keyfiles { | ||||
| 		publicKeys, err := keyutil.PublicKeysFromFile(keyfile) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| func newServiceAccountAuthenticator(issuers []string, publicKeysGetter serviceaccount.PublicKeysGetter, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter) (authenticator.Token, error) { | ||||
| 	if publicKeysGetter == nil { | ||||
| 		return nil, fmt.Errorf("no public key getter provided") | ||||
| 	} | ||||
| 		allPublicKeys = append(allPublicKeys, publicKeys...) | ||||
| 	} | ||||
|  | ||||
| 	tokenAuthenticator := serviceaccount.JWTTokenAuthenticator(issuers, allPublicKeys, apiAudiences, serviceaccount.NewValidator(serviceAccountGetter)) | ||||
| 	tokenAuthenticator := serviceaccount.JWTTokenAuthenticator(issuers, publicKeysGetter, apiAudiences, serviceaccount.NewValidator(serviceAccountGetter)) | ||||
| 	return tokenAuthenticator, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -47,6 +47,7 @@ import ( | ||||
| 	"k8s.io/client-go/informers" | ||||
| 	"k8s.io/client-go/kubernetes" | ||||
| 	v1listers "k8s.io/client-go/listers/core/v1" | ||||
| 	"k8s.io/client-go/util/keyutil" | ||||
| 	cliflag "k8s.io/component-base/cli/flag" | ||||
| 	"k8s.io/klog/v2" | ||||
| 	openapicommon "k8s.io/kube-openapi/pkg/common" | ||||
| @@ -54,6 +55,7 @@ import ( | ||||
| 	"k8s.io/kubernetes/pkg/features" | ||||
| 	kubeauthenticator "k8s.io/kubernetes/pkg/kubeapiserver/authenticator" | ||||
| 	authzmodes "k8s.io/kubernetes/pkg/kubeapiserver/authorizer/modes" | ||||
| 	"k8s.io/kubernetes/pkg/serviceaccount" | ||||
| 	"k8s.io/kubernetes/pkg/util/filesystem" | ||||
| 	"k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/bootstrap" | ||||
| 	"k8s.io/utils/pointer" | ||||
| @@ -559,7 +561,21 @@ func (o *BuiltInAuthenticationOptions) ToAuthenticationConfig() (kubeauthenticat | ||||
| 		if len(o.ServiceAccounts.Issuers) != 0 && len(o.APIAudiences) == 0 { | ||||
| 			ret.APIAudiences = authenticator.Audiences(o.ServiceAccounts.Issuers) | ||||
| 		} | ||||
| 		ret.ServiceAccountKeyFiles = o.ServiceAccounts.KeyFiles | ||||
| 		if len(o.ServiceAccounts.KeyFiles) > 0 { | ||||
| 			allPublicKeys := []interface{}{} | ||||
| 			for _, keyfile := range o.ServiceAccounts.KeyFiles { | ||||
| 				publicKeys, err := keyutil.PublicKeysFromFile(keyfile) | ||||
| 				if err != nil { | ||||
| 					return kubeauthenticator.Config{}, err | ||||
| 				} | ||||
| 				allPublicKeys = append(allPublicKeys, publicKeys...) | ||||
| 			} | ||||
| 			keysGetter, err := serviceaccount.StaticPublicKeysGetter(allPublicKeys) | ||||
| 			if err != nil { | ||||
| 				return kubeauthenticator.Config{}, fmt.Errorf("failed to set up public service account keys: %w", err) | ||||
| 			} | ||||
| 			ret.ServiceAccountPublicKeysGetter = keysGetter | ||||
| 		} | ||||
| 		ret.ServiceAccountIssuers = o.ServiceAccounts.Issuers | ||||
| 		ret.ServiceAccountLookup = o.ServiceAccounts.Lookup | ||||
| 	} | ||||
|   | ||||
| @@ -17,6 +17,7 @@ limitations under the License. | ||||
| package routes | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
|  | ||||
| 	restful "github.com/emicklei/go-restful/v3" | ||||
| @@ -34,7 +35,8 @@ const ( | ||||
| 	// cacheControl is the value of the Cache-Control header. Overrides the | ||||
| 	// global `private, no-cache` setting. | ||||
| 	headerCacheControl = "Cache-Control" | ||||
| 	cacheControl       = "public, max-age=3600" // 1 hour | ||||
|  | ||||
| 	cacheControlTemplate = "public, max-age=%d" | ||||
|  | ||||
| 	// mimeJWKS is the content type of the keyset response | ||||
| 	mimeJWKS = "application/jwk-set+json" | ||||
| @@ -42,18 +44,14 @@ const ( | ||||
|  | ||||
| // OpenIDMetadataServer is an HTTP server for metadata of the KSA token issuer. | ||||
| type OpenIDMetadataServer struct { | ||||
| 	configJSON []byte | ||||
| 	keysetJSON []byte | ||||
| 	provider serviceaccount.OpenIDMetadataProvider | ||||
| } | ||||
|  | ||||
| // NewOpenIDMetadataServer creates a new OpenIDMetadataServer. | ||||
| // The issuer is the OIDC issuer; keys are the keys that may be used to sign | ||||
| // KSA tokens. | ||||
| func NewOpenIDMetadataServer(configJSON, keysetJSON []byte) *OpenIDMetadataServer { | ||||
| 	return &OpenIDMetadataServer{ | ||||
| 		configJSON: configJSON, | ||||
| 		keysetJSON: keysetJSON, | ||||
| 	} | ||||
| func NewOpenIDMetadataServer(provider serviceaccount.OpenIDMetadataProvider) *OpenIDMetadataServer { | ||||
| 	return &OpenIDMetadataServer{provider: provider} | ||||
| } | ||||
|  | ||||
| // Install adds this server to the request router c. | ||||
| @@ -95,19 +93,21 @@ func fromStandard(h http.HandlerFunc) restful.RouteFunction { | ||||
| } | ||||
|  | ||||
| func (s *OpenIDMetadataServer) serveConfiguration(w http.ResponseWriter, req *http.Request) { | ||||
| 	configJSON, maxAge := s.provider.GetConfigJSON() | ||||
| 	w.Header().Set(restful.HEADER_ContentType, restful.MIME_JSON) | ||||
| 	w.Header().Set(headerCacheControl, cacheControl) | ||||
| 	if _, err := w.Write(s.configJSON); err != nil { | ||||
| 	w.Header().Set(headerCacheControl, fmt.Sprintf(cacheControlTemplate, maxAge)) | ||||
| 	if _, err := w.Write(configJSON); err != nil { | ||||
| 		klog.Errorf("failed to write service account issuer metadata response: %v", err) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *OpenIDMetadataServer) serveKeys(w http.ResponseWriter, req *http.Request) { | ||||
| 	keysetJSON, maxAge := s.provider.GetKeysetJSON() | ||||
| 	// Per RFC7517 : https://tools.ietf.org/html/rfc7517#section-8.5.1 | ||||
| 	w.Header().Set(restful.HEADER_ContentType, mimeJWKS) | ||||
| 	w.Header().Set(headerCacheControl, cacheControl) | ||||
| 	if _, err := w.Write(s.keysetJSON); err != nil { | ||||
| 	w.Header().Set(headerCacheControl, fmt.Sprintf(cacheControlTemplate, maxAge)) | ||||
| 	if _, err := w.Write(keysetJSON); err != nil { | ||||
| 		klog.Errorf("failed to write service account issuer JWKS response: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -225,22 +225,97 @@ func (j *jwtTokenGenerator) GenerateToken(claims *jwt.Claims, privateClaims inte | ||||
| // JWTTokenAuthenticator authenticates tokens as JWT tokens produced by JWTTokenGenerator | ||||
| // Token signatures are verified using each of the given public keys until one works (allowing key rotation) | ||||
| // If lookup is true, the service account and secret referenced as claims inside the token are retrieved and verified with the provided ServiceAccountTokenGetter | ||||
| func JWTTokenAuthenticator[PrivateClaims any](issuers []string, keys []interface{}, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token { | ||||
| func JWTTokenAuthenticator[PrivateClaims any](issuers []string, publicKeysGetter PublicKeysGetter, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token { | ||||
| 	issuersMap := make(map[string]bool) | ||||
| 	for _, issuer := range issuers { | ||||
| 		issuersMap[issuer] = true | ||||
| 	} | ||||
| 	return &jwtTokenAuthenticator[PrivateClaims]{ | ||||
| 		issuers:      issuersMap, | ||||
| 		keys:         keys, | ||||
| 		keysGetter:   publicKeysGetter, | ||||
| 		implicitAuds: implicitAuds, | ||||
| 		validator:    validator, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Listener is an interface to use to notify interested parties of a change. | ||||
| type Listener interface { | ||||
| 	// Enqueue should be called when an input may have changed | ||||
| 	Enqueue() | ||||
| } | ||||
|  | ||||
| // PublicKeysGetter returns public keys for a given key id. | ||||
| type PublicKeysGetter interface { | ||||
| 	// AddListener is adds a listener to be notified of potential input changes. | ||||
| 	// This is a noop on static providers. | ||||
| 	AddListener(listener Listener) | ||||
|  | ||||
| 	// GetCacheAgeMaxSeconds returns the seconds a call to GetPublicKeys() can be cached for. | ||||
| 	// If the results of GetPublicKeys() can be dynamic, this means a new key must be included in the results | ||||
| 	// for at least this long before it is used to sign new tokens. | ||||
| 	GetCacheAgeMaxSeconds() int | ||||
|  | ||||
| 	// GetPublicKeys returns public keys to use for verifying a token with the given key id. | ||||
| 	// keyIDHint may be empty if the token did not have a kid header, or if all public keys are desired. | ||||
| 	GetPublicKeys(keyIDHint string) []PublicKey | ||||
| } | ||||
|  | ||||
| type PublicKey struct { | ||||
| 	KeyID     string | ||||
| 	PublicKey interface{} | ||||
| } | ||||
|  | ||||
| type staticPublicKeysGetter struct { | ||||
| 	allPublicKeys  []PublicKey | ||||
| 	publicKeysByID map[string][]PublicKey | ||||
| } | ||||
|  | ||||
| // StaticPublicKeysGetter constructs an implementation of PublicKeysGetter | ||||
| // which returns all public keys when key id is unspecified, and returns | ||||
| // the public keys matching the keyIDFromPublicKey-derived key id when | ||||
| // a key id is specified. | ||||
| func StaticPublicKeysGetter(keys []interface{}) (PublicKeysGetter, error) { | ||||
| 	allPublicKeys := []PublicKey{} | ||||
| 	publicKeysByID := map[string][]PublicKey{} | ||||
| 	for _, key := range keys { | ||||
| 		if privateKey, isPrivateKey := key.(publicKeyGetter); isPrivateKey { | ||||
| 			// This is a private key. Extract its public key. | ||||
| 			key = privateKey.Public() | ||||
| 		} | ||||
|  | ||||
| 		keyID, err := keyIDFromPublicKey(key) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		pk := PublicKey{PublicKey: key, KeyID: keyID} | ||||
| 		publicKeysByID[keyID] = append(publicKeysByID[keyID], pk) | ||||
| 		allPublicKeys = append(allPublicKeys, pk) | ||||
| 	} | ||||
| 	return &staticPublicKeysGetter{ | ||||
| 		allPublicKeys:  allPublicKeys, | ||||
| 		publicKeysByID: publicKeysByID, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s staticPublicKeysGetter) AddListener(listener Listener) { | ||||
| 	// no-op, static key content never changes | ||||
| } | ||||
|  | ||||
| func (s staticPublicKeysGetter) GetCacheAgeMaxSeconds() int { | ||||
| 	// hard-coded to match cache max-age set in OIDC discovery | ||||
| 	return 3600 | ||||
| } | ||||
|  | ||||
| func (s staticPublicKeysGetter) GetPublicKeys(keyID string) []PublicKey { | ||||
| 	if len(keyID) == 0 { | ||||
| 		return s.allPublicKeys | ||||
| 	} | ||||
| 	return s.publicKeysByID[keyID] | ||||
| } | ||||
|  | ||||
| type jwtTokenAuthenticator[PrivateClaims any] struct { | ||||
| 	issuers      map[string]bool | ||||
| 	keys         []interface{} | ||||
| 	keysGetter   PublicKeysGetter | ||||
| 	validator    Validator[PrivateClaims] | ||||
| 	implicitAuds authenticator.Audiences | ||||
| } | ||||
| @@ -269,13 +344,25 @@ func (j *jwtTokenAuthenticator[PrivateClaims]) AuthenticateToken(ctx context.Con | ||||
| 	public := &jwt.Claims{} | ||||
| 	private := new(PrivateClaims) | ||||
|  | ||||
| 	// TODO: Pick the key that has the same key ID as `tok`, if one exists. | ||||
| 	// Pick the key that has the same key ID as `tok`, if one exists. | ||||
| 	var kid string | ||||
| 	for _, header := range tok.Headers { | ||||
| 		if header.KeyID != "" { | ||||
| 			kid = header.KeyID | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var ( | ||||
| 		found   bool | ||||
| 		errlist []error | ||||
| 	) | ||||
| 	for _, key := range j.keys { | ||||
| 		if err := tok.Claims(key, public, private); err != nil { | ||||
| 	keys := j.keysGetter.GetPublicKeys(kid) | ||||
| 	if len(keys) == 0 { | ||||
| 		return nil, false, fmt.Errorf("invalid signature, no keys found") | ||||
| 	} | ||||
| 	for _, key := range keys { | ||||
| 		if err := tok.Claims(key.PublicKey, public, private); err != nil { | ||||
| 			errlist = append(errlist, err) | ||||
| 			continue | ||||
| 		} | ||||
|   | ||||
| @@ -247,7 +247,7 @@ func TestTokenGenerateAndValidate(t *testing.T) { | ||||
| 			Token:       rsaToken, | ||||
| 			Client:      nil, | ||||
| 			Keys:        []interface{}{}, | ||||
| 			ExpectedErr: false, | ||||
| 			ExpectedErr: true, | ||||
| 			ExpectedOK:  false, | ||||
| 		}, | ||||
| 		"invalid keys (rsa)": { | ||||
| @@ -385,7 +385,13 @@ func TestTokenGenerateAndValidate(t *testing.T) { | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("While creating legacy validator, err: %v", err) | ||||
| 		} | ||||
| 		authn := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer, "bar"}, tc.Keys, auds, validator) | ||||
| 		staticKeysGetter, err := serviceaccount.StaticPublicKeysGetter(tc.Keys) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 		keysGetter := &keyIDPrefixer{PublicKeysGetter: staticKeysGetter} | ||||
|  | ||||
| 		authn := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer, "bar"}, keysGetter, auds, validator) | ||||
|  | ||||
| 		// An invalid, non-JWT token should always fail | ||||
| 		ctx := authenticator.WithAudiences(context.Background(), auds) | ||||
| @@ -394,6 +400,16 @@ func TestTokenGenerateAndValidate(t *testing.T) { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if tc.ExpectedOK { | ||||
| 			// if authentication is otherwise expected to succeed, demonstrate changing key ids makes it fail | ||||
| 			keysGetter.keyIDPrefix = "bogus" | ||||
| 			if _, ok, err := authn.AuthenticateToken(ctx, tc.Token); err == nil || !strings.Contains(err.Error(), "no keys found") || ok { | ||||
| 				t.Errorf("%s: Expected err containing 'no keys found', ok=false when key lookup by ID fails", k) | ||||
| 				continue | ||||
| 			} | ||||
| 			keysGetter.keyIDPrefix = "" | ||||
| 		} | ||||
|  | ||||
| 		resp, ok, err := authn.AuthenticateToken(ctx, tc.Token) | ||||
| 		if (err != nil) != tc.ExpectedErr { | ||||
| 			t.Errorf("%s: Expected error=%v, got %v", k, tc.ExpectedErr, err) | ||||
| @@ -424,6 +440,26 @@ func TestTokenGenerateAndValidate(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type keyIDPrefixer struct { | ||||
| 	serviceaccount.PublicKeysGetter | ||||
| 	keyIDPrefix string | ||||
| } | ||||
|  | ||||
| func (k *keyIDPrefixer) GetPublicKeys(keyIDHint string) []serviceaccount.PublicKey { | ||||
| 	if k.keyIDPrefix == "" { | ||||
| 		return k.PublicKeysGetter.GetPublicKeys(keyIDHint) | ||||
| 	} | ||||
| 	if keyIDHint != "" { | ||||
| 		keyIDHint = k.keyIDPrefix + keyIDHint | ||||
| 	} | ||||
| 	var retval []serviceaccount.PublicKey | ||||
| 	for _, key := range k.PublicKeysGetter.GetPublicKeys(keyIDHint) { | ||||
| 		key.KeyID = k.keyIDPrefix + key.KeyID | ||||
| 		retval = append(retval, key) | ||||
| 	} | ||||
| 	return retval | ||||
| } | ||||
|  | ||||
| func checkJSONWebSignatureHasKeyID(t *testing.T, jwsString string, expectedKeyID string) { | ||||
| 	jws, err := jose.ParseSigned(jwsString) | ||||
| 	if err != nil { | ||||
| @@ -502,3 +538,76 @@ func generateECDSATokenWithMalformedIss(t *testing.T, serviceAccount *v1.Service | ||||
|  | ||||
| 	return string(out) | ||||
| } | ||||
|  | ||||
| func TestStaticPublicKeysGetter(t *testing.T) { | ||||
| 	ecPrivate := getPrivateKey(ecdsaPrivateKey) | ||||
| 	ecPublic := getPublicKey(ecdsaPublicKey) | ||||
| 	rsaPublic := getPublicKey(rsaPublicKey) | ||||
|  | ||||
| 	testcases := []struct { | ||||
| 		Name       string | ||||
| 		Keys       []interface{} | ||||
| 		ExpectErr  bool | ||||
| 		ExpectKeys []serviceaccount.PublicKey | ||||
| 	}{ | ||||
| 		{ | ||||
| 			Name:       "empty", | ||||
| 			Keys:       nil, | ||||
| 			ExpectKeys: []serviceaccount.PublicKey{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Name: "simple", | ||||
| 			Keys: []interface{}{ecPublic, rsaPublic}, | ||||
| 			ExpectKeys: []serviceaccount.PublicKey{ | ||||
| 				{KeyID: "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc", PublicKey: ecPublic}, | ||||
| 				{KeyID: "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU", PublicKey: rsaPublic}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Name: "private --> public", | ||||
| 			Keys: []interface{}{ecPrivate}, | ||||
| 			ExpectKeys: []serviceaccount.PublicKey{ | ||||
| 				{KeyID: "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc", PublicKey: ecPublic}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Name:      "invalid", | ||||
| 			Keys:      []interface{}{"bogus"}, | ||||
| 			ExpectErr: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testcases { | ||||
| 		t.Run(tc.Name, func(t *testing.T) { | ||||
| 			getter, err := serviceaccount.StaticPublicKeysGetter(tc.Keys) | ||||
| 			if tc.ExpectErr { | ||||
| 				if err == nil { | ||||
| 					t.Fatal("expected construction error, got none") | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("unexpected construction error: %v", err) | ||||
| 			} | ||||
|  | ||||
| 			bogusKeys := getter.GetPublicKeys("bogus") | ||||
| 			if len(bogusKeys) != 0 { | ||||
| 				t.Fatalf("unexpected bogus keys: %#v", bogusKeys) | ||||
| 			} | ||||
|  | ||||
| 			allKeys := getter.GetPublicKeys("") | ||||
| 			if !reflect.DeepEqual(tc.ExpectKeys, allKeys) { | ||||
| 				t.Fatalf("unexpected keys: %#v", allKeys) | ||||
| 			} | ||||
| 			for _, key := range allKeys { | ||||
| 				keysByID := getter.GetPublicKeys(key.KeyID) | ||||
| 				if len(keysByID) != 1 { | ||||
| 					t.Fatalf("expected 1 key for id %s, got %d", key.KeyID, len(keysByID)) | ||||
| 				} | ||||
| 				if !reflect.DeepEqual(key, keysByID[0]) { | ||||
| 					t.Fatalf("unexpected key for id %s", key.KeyID) | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										49
									
								
								pkg/serviceaccount/keyid_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								pkg/serviceaccount/keyid_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| /* | ||||
| Copyright 2024 The Kubernetes Authors. | ||||
|  | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|  | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| */ | ||||
|  | ||||
| package serviceaccount | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"k8s.io/client-go/util/keyutil" | ||||
| ) | ||||
|  | ||||
| const rsaPublicKey = `-----BEGIN PUBLIC KEY----- | ||||
| MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA249XwEo9k4tM8fMxV7zx | ||||
| OhcrP+WvXn917koM5Qr2ZXs4vo26e4ytdlrV0bQ9SlcLpQVSYjIxNfhTZdDt+ecI | ||||
| zshKuv1gKIxbbLQMOuK1eA/4HALyEkFgmS/tleLJrhc65tKPMGD+pKQ/xhmzRuCG | ||||
| 51RoiMgbQxaCyYxGfNLpLAZK9L0Tctv9a0mJmGIYnIOQM4kC1A1I1n3EsXMWmeJU | ||||
| j7OTh/AjjCnMnkgvKT2tpKxYQ59PgDgU8Ssc7RDSmSkLxnrv+OrN80j6xrw0OjEi | ||||
| B4Ycr0PqfzZcvy8efTtFQ/Jnc4Bp1zUtFXt7+QeevePtQ2EcyELXE0i63T1CujRM | ||||
| WwIDAQAB | ||||
| -----END PUBLIC KEY----- | ||||
| ` | ||||
|  | ||||
| func TestKeyIDStability(t *testing.T) { | ||||
| 	keys, err := keyutil.ParsePublicKeysPEM([]byte(rsaPublicKey)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	keyID, err := keyIDFromPublicKey(keys[0]) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	// The derived key id for a given public key must not change or validation of previously issued tokens will fail to find associated keys | ||||
| 	if expected, actual := "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU", keyID; expected != actual { | ||||
| 		t.Fatalf("expected stable key id %q, got %q", expected, actual) | ||||
| 	} | ||||
| } | ||||
| @@ -24,11 +24,13 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"sync/atomic" | ||||
|  | ||||
| 	jose "gopkg.in/square/go-jose.v2" | ||||
|  | ||||
| 	"k8s.io/apimachinery/pkg/util/errors" | ||||
| 	"k8s.io/apimachinery/pkg/util/sets" | ||||
| 	"k8s.io/klog/v2" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -44,26 +46,68 @@ const ( | ||||
| 	JWKSPath = "/openid/v1/jwks" | ||||
| ) | ||||
|  | ||||
| // OpenIDMetadata contains the pre-rendered responses for OIDC discovery endpoints. | ||||
| type OpenIDMetadata struct { | ||||
| 	ConfigJSON       []byte | ||||
| 	PublicKeysetJSON []byte | ||||
| // OpenIDMetadataProvider returns pre-rendered responses for OIDC discovery endpoints. | ||||
| type OpenIDMetadataProvider interface { | ||||
| 	GetConfigJSON() (json []byte, maxAge int) | ||||
| 	GetKeysetJSON() (json []byte, maxAge int) | ||||
| } | ||||
|  | ||||
| // NewOpenIDMetadata returns the pre-rendered JSON responses for the OIDC discovery | ||||
| type openidConfigProvider struct { | ||||
| 	issuerURL, jwksURI string | ||||
| 	pubKeyGetter       PublicKeysGetter | ||||
| 	config             atomic.Pointer[openidConfig] | ||||
| } | ||||
| type openidConfig struct { | ||||
| 	configJSON []byte | ||||
| 	keysetJSON []byte | ||||
| } | ||||
|  | ||||
| func (p *openidConfigProvider) GetConfigJSON() ([]byte, int) { | ||||
| 	return p.config.Load().configJSON, p.pubKeyGetter.GetCacheAgeMaxSeconds() | ||||
| } | ||||
| func (p *openidConfigProvider) GetKeysetJSON() ([]byte, int) { | ||||
| 	return p.config.Load().keysetJSON, p.pubKeyGetter.GetCacheAgeMaxSeconds() | ||||
| } | ||||
| func (p *openidConfigProvider) Enqueue() { | ||||
| 	err := p.Update() | ||||
| 	if err != nil { | ||||
| 		klog.ErrorS(err, "failed to update openid config metadata") | ||||
| 	} | ||||
| } | ||||
| func (p *openidConfigProvider) Update() error { | ||||
| 	pubKeys := p.pubKeyGetter.GetPublicKeys("") | ||||
| 	if len(pubKeys) == 0 { | ||||
| 		return fmt.Errorf("no keys provided for validating keyset") | ||||
| 	} | ||||
| 	configJSON, err := openIDConfigJSON(p.issuerURL, p.jwksURI, pubKeys) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("could not marshal issuer discovery JSON, error: %w", err) | ||||
| 	} | ||||
| 	keysetJSON, err := openIDKeysetJSON(pubKeys) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("could not marshal issuer keys JSON, error: %w", err) | ||||
| 	} | ||||
| 	p.config.Store(&openidConfig{ | ||||
| 		configJSON: configJSON, | ||||
| 		keysetJSON: keysetJSON, | ||||
| 	}) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // NewOpenIDMetadataProvider returns a provider for the OIDC discovery | ||||
| // endpoints, or an error if they could not be constructed. Callers should note | ||||
| // that this function may perform additional validation on inputs that is not | ||||
| // backwards-compatible with all command-line validation. The recommendation is | ||||
| // to log the error and skip installing the OIDC discovery endpoints. | ||||
| func NewOpenIDMetadata(issuerURL, jwksURI, defaultExternalAddress string, pubKeys []interface{}) (*OpenIDMetadata, error) { | ||||
| func NewOpenIDMetadataProvider(issuerURL, jwksURI, defaultExternalAddress string, pubKeyGetter PublicKeysGetter) (OpenIDMetadataProvider, error) { | ||||
| 	if issuerURL == "" { | ||||
| 		return nil, fmt.Errorf("empty issuer URL") | ||||
| 	} | ||||
| 	if jwksURI == "" && defaultExternalAddress == "" { | ||||
| 		return nil, fmt.Errorf("either the JWKS URI or the default external address, or both, must be set") | ||||
| 	} | ||||
| 	if len(pubKeys) == 0 { | ||||
| 		return nil, fmt.Errorf("no keys provided for validating keyset") | ||||
| 	if pubKeyGetter == nil { | ||||
| 		return nil, fmt.Errorf("no public key getter provided") | ||||
| 	} | ||||
|  | ||||
| 	// Ensure the issuer URL meets the OIDC spec (this is the additional | ||||
| @@ -126,20 +170,18 @@ func NewOpenIDMetadata(issuerURL, jwksURI, defaultExternalAddress string, pubKey | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	configJSON, err := openIDConfigJSON(issuerURL, jwksURI, pubKeys) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not marshal issuer discovery JSON, error: %v", err) | ||||
| 	provider := &openidConfigProvider{ | ||||
| 		issuerURL:    issuerURL, | ||||
| 		jwksURI:      jwksURI, | ||||
| 		pubKeyGetter: pubKeyGetter, | ||||
| 	} | ||||
|  | ||||
| 	keysetJSON, err := openIDKeysetJSON(pubKeys) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not marshal issuer keys JSON, error: %v", err) | ||||
| 	// Register to be notified if public keys change | ||||
| 	pubKeyGetter.AddListener(provider) | ||||
| 	// Synchronously construct the config / keyset json once at startup to ensure a successful starting point | ||||
| 	if err := provider.Update(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &OpenIDMetadata{ | ||||
| 		ConfigJSON:       configJSON, | ||||
| 		PublicKeysetJSON: keysetJSON, | ||||
| 	}, nil | ||||
| 	return provider, nil | ||||
| } | ||||
|  | ||||
| // openIDMetadata provides a minimal subset of OIDC provider metadata: | ||||
| @@ -159,7 +201,7 @@ type openIDMetadata struct { | ||||
|  | ||||
| // openIDConfigJSON returns the JSON OIDC Discovery Doc for the service | ||||
| // account issuer. | ||||
| func openIDConfigJSON(iss, jwksURI string, keys []interface{}) ([]byte, error) { | ||||
| func openIDConfigJSON(iss, jwksURI string, keys []PublicKey) ([]byte, error) { | ||||
| 	keyset, errs := publicJWKSFromKeys(keys) | ||||
| 	if errs != nil { | ||||
| 		return nil, errs | ||||
| @@ -183,7 +225,7 @@ func openIDConfigJSON(iss, jwksURI string, keys []interface{}) ([]byte, error) { | ||||
|  | ||||
| // openIDKeysetJSON returns the JSON Web Key Set for the service account | ||||
| // issuer's keys. | ||||
| func openIDKeysetJSON(keys []interface{}) ([]byte, error) { | ||||
| func openIDKeysetJSON(keys []PublicKey) ([]byte, error) { | ||||
| 	keyset, errs := publicJWKSFromKeys(keys) | ||||
| 	if errs != nil { | ||||
| 		return nil, errs | ||||
| @@ -212,21 +254,12 @@ type publicKeyGetter interface { | ||||
|  | ||||
| // publicJWKSFromKeys constructs a JSONWebKeySet from a list of keys. The key | ||||
| // set will only contain the public keys associated with the input keys. | ||||
| func publicJWKSFromKeys(in []interface{}) (*jose.JSONWebKeySet, errors.Aggregate) { | ||||
| func publicJWKSFromKeys(in []PublicKey) (*jose.JSONWebKeySet, errors.Aggregate) { | ||||
| 	// Decode keys into a JWKS. | ||||
| 	var keys jose.JSONWebKeySet | ||||
| 	var errs []error | ||||
| 	for i, key := range in { | ||||
| 		var pubkey *jose.JSONWebKey | ||||
| 		var err error | ||||
|  | ||||
| 		switch k := key.(type) { | ||||
| 		case publicKeyGetter: | ||||
| 			// This is a private key. Get its public key | ||||
| 			pubkey, err = jwkFromPublicKey(k.Public()) | ||||
| 		default: | ||||
| 			pubkey, err = jwkFromPublicKey(k) | ||||
| 		} | ||||
| 		pubkey, err := jwkFromPublicKey(key) | ||||
| 		if err != nil { | ||||
| 			errs = append(errs, fmt.Errorf("error constructing JWK for key #%d: %v", i, err)) | ||||
| 			continue | ||||
| @@ -244,21 +277,16 @@ func publicJWKSFromKeys(in []interface{}) (*jose.JSONWebKeySet, errors.Aggregate | ||||
| 	return &keys, nil | ||||
| } | ||||
|  | ||||
| func jwkFromPublicKey(publicKey crypto.PublicKey) (*jose.JSONWebKey, error) { | ||||
| 	alg, err := algorithmFromPublicKey(publicKey) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	keyID, err := keyIDFromPublicKey(publicKey) | ||||
| func jwkFromPublicKey(publicKey PublicKey) (*jose.JSONWebKey, error) { | ||||
| 	alg, err := algorithmFromPublicKey(publicKey.PublicKey) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	jwk := &jose.JSONWebKey{ | ||||
| 		Algorithm: string(alg), | ||||
| 		Key:       publicKey, | ||||
| 		KeyID:     keyID, | ||||
| 		Key:       publicKey.PublicKey, | ||||
| 		KeyID:     publicKey.KeyID, | ||||
| 		Use:       "sig", | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -39,7 +39,7 @@ const ( | ||||
| 	exampleIssuer = "https://issuer.example.com" | ||||
| ) | ||||
|  | ||||
| func setupServer(t *testing.T, iss string, keys []interface{}) (*httptest.Server, string) { | ||||
| func setupServer(t *testing.T, iss string, keys serviceaccount.PublicKeysGetter) (*httptest.Server, string) { | ||||
| 	t.Helper() | ||||
|  | ||||
| 	c := restful.NewContainer() | ||||
| @@ -53,13 +53,13 @@ func setupServer(t *testing.T, iss string, keys []interface{}) (*httptest.Server | ||||
| 	jwksURI.Scheme = "https" | ||||
| 	jwksURI.Path = serviceaccount.JWKSPath | ||||
|  | ||||
| 	md, err := serviceaccount.NewOpenIDMetadata( | ||||
| 	md, err := serviceaccount.NewOpenIDMetadataProvider( | ||||
| 		iss, jwksURI.String(), "", keys) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	srv := routes.NewOpenIDMetadataServer(md.ConfigJSON, md.PublicKeysetJSON) | ||||
| 	srv := routes.NewOpenIDMetadataServer(md) | ||||
| 	srv.Install(c) | ||||
|  | ||||
| 	return s, jwksURI.String() | ||||
| @@ -77,20 +77,59 @@ type Configuration struct { | ||||
| 	SubjectTypes  []string `json:"subject_types_supported"` | ||||
| } | ||||
|  | ||||
| type proxyKeyGetter struct { | ||||
| 	serviceaccount.PublicKeysGetter | ||||
| 	listeners []serviceaccount.Listener | ||||
| } | ||||
|  | ||||
| func (p *proxyKeyGetter) AddListener(listener serviceaccount.Listener) { | ||||
| 	p.listeners = append(p.listeners, listener) | ||||
| 	p.PublicKeysGetter.AddListener(listener) | ||||
| } | ||||
|  | ||||
| func TestServeConfiguration(t *testing.T) { | ||||
| 	s, jwksURI := setupServer(t, exampleIssuer, defaultKeys) | ||||
| 	ecKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{getPublicKey(ecdsaPublicKey)}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	rsaKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{getPublicKey(rsaPublicKey)}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	keysGetter := &proxyKeyGetter{PublicKeysGetter: ecKeysGetter} | ||||
| 	s, jwksURI := setupServer(t, exampleIssuer, keysGetter) | ||||
| 	defer s.Close() | ||||
|  | ||||
| 	want := Configuration{ | ||||
| 	wantEC := Configuration{ | ||||
| 		Issuer:        exampleIssuer, | ||||
| 		JWKSURI:       jwksURI, | ||||
| 		ResponseTypes: []string{"id_token"}, | ||||
| 		SubjectTypes:  []string{"public"}, | ||||
| 		SigningAlgs:   []string{"ES256", "RS256"}, | ||||
| 		SigningAlgs:   []string{"ES256"}, | ||||
| 	} | ||||
| 	wantRSA := Configuration{ | ||||
| 		Issuer:        exampleIssuer, | ||||
| 		JWKSURI:       jwksURI, | ||||
| 		ResponseTypes: []string{"id_token"}, | ||||
| 		SubjectTypes:  []string{"public"}, | ||||
| 		SigningAlgs:   []string{"RS256"}, | ||||
| 	} | ||||
|  | ||||
| 	reqURL := s.URL + "/.well-known/openid-configuration" | ||||
|  | ||||
| 	expectConfiguration(t, reqURL, wantEC) | ||||
|  | ||||
| 	// modify the underlying keys, expect the same response | ||||
| 	keysGetter.PublicKeysGetter = rsaKeysGetter | ||||
| 	expectConfiguration(t, reqURL, wantEC) | ||||
|  | ||||
| 	// notify the metadata the keys changed, expected a modified response | ||||
| 	for _, listener := range keysGetter.listeners { | ||||
| 		listener.Enqueue() | ||||
| 	} | ||||
| 	expectConfiguration(t, reqURL, wantRSA) | ||||
| } | ||||
|  | ||||
| func expectConfiguration(t *testing.T, reqURL string, want Configuration) { | ||||
| 	resp, err := http.Get(reqURL) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err) | ||||
| @@ -185,16 +224,49 @@ func TestServeKeys(t *testing.T) { | ||||
|  | ||||
| 	for _, tt := range serveKeysTests { | ||||
| 		t.Run(tt.Name, func(t *testing.T) { | ||||
| 			s, _ := setupServer(t, exampleIssuer, tt.Keys) | ||||
| 			initialKeysGetter, err := serviceaccount.StaticPublicKeysGetter(tt.Keys) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 			updatedKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{wantPubRSA}) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 			keysGetter := &proxyKeyGetter{PublicKeysGetter: initialKeysGetter} | ||||
| 			s, _ := setupServer(t, exampleIssuer, keysGetter) | ||||
| 			defer s.Close() | ||||
|  | ||||
| 			reqURL := s.URL + "/openid/v1/jwks" | ||||
| 			expectKeys(t, reqURL, tt.WantKeys) | ||||
|  | ||||
| 			// modify the underlying keys, expect the same response | ||||
| 			keysGetter.PublicKeysGetter = updatedKeysGetter | ||||
| 			expectKeys(t, reqURL, tt.WantKeys) | ||||
|  | ||||
| 			// notify the metadata the keys changed, expected a modified response | ||||
| 			for _, listener := range keysGetter.listeners { | ||||
| 				listener.Enqueue() | ||||
| 			} | ||||
| 			expectKeys(t, reqURL, []jose.JSONWebKey{{ | ||||
| 				Algorithm:                   "RS256", | ||||
| 				Key:                         wantPubRSA, | ||||
| 				KeyID:                       rsaKeyID, | ||||
| 				Use:                         "sig", | ||||
| 				Certificates:                []*x509.Certificate{}, | ||||
| 				CertificateThumbprintSHA1:   []uint8{}, | ||||
| 				CertificateThumbprintSHA256: []uint8{}, | ||||
| 			}}) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| func expectKeys(t *testing.T, reqURL string, wantKeys []jose.JSONWebKey) { | ||||
| 	resp, err := http.Get(reqURL) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err) | ||||
| 	} | ||||
| 			defer resp.Body.Close() | ||||
| 	defer func() { | ||||
| 		_ = resp.Body.Close() | ||||
| 	}() | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		t.Errorf("Get(%s) = %v, _ want: %v, _", reqURL, resp.StatusCode, http.StatusOK) | ||||
| @@ -216,16 +288,18 @@ func TestServeKeys(t *testing.T) { | ||||
| 		func(x, y *big.Int) bool { | ||||
| 			return x.Cmp(y) == 0 | ||||
| 		}) | ||||
| 			if !cmp.Equal(tt.WantKeys, ks.Keys, bigIntComparer) { | ||||
| 	if !cmp.Equal(wantKeys, ks.Keys, bigIntComparer) { | ||||
| 		t.Errorf("unexpected diff in JWKS keys (-want, +got): %v", | ||||
| 					cmp.Diff(tt.WantKeys, ks.Keys, bigIntComparer)) | ||||
| 			} | ||||
| 		}) | ||||
| 			cmp.Diff(wantKeys, ks.Keys, bigIntComparer)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestURLBoundaries(t *testing.T) { | ||||
| 	s, _ := setupServer(t, exampleIssuer, defaultKeys) | ||||
| 	keysGetter, err := serviceaccount.StaticPublicKeysGetter(defaultKeys) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	s, _ := setupServer(t, exampleIssuer, keysGetter) | ||||
| 	defer s.Close() | ||||
|  | ||||
| 	for _, tt := range []struct { | ||||
| @@ -380,7 +454,11 @@ func TestNewOpenIDMetadata(t *testing.T) { | ||||
| 	} | ||||
| 	for _, tc := range cases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			md, err := serviceaccount.NewOpenIDMetadata(tc.issuerURL, tc.jwksURI, tc.externalAddress, tc.keys) | ||||
| 			keysGetter, err := serviceaccount.StaticPublicKeysGetter(tc.keys) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 			md, err := serviceaccount.NewOpenIDMetadataProvider(tc.issuerURL, tc.jwksURI, tc.externalAddress, keysGetter) | ||||
| 			if tc.err { | ||||
| 				if err == nil { | ||||
| 					t.Fatalf("got <nil>, want error") | ||||
| @@ -390,13 +468,13 @@ func TestNewOpenIDMetadata(t *testing.T) { | ||||
| 				t.Fatalf("got error %v, want <nil>", err) | ||||
| 			} | ||||
|  | ||||
| 			config := string(md.ConfigJSON) | ||||
| 			keyset := string(md.PublicKeysetJSON) | ||||
| 			if config != tc.wantConfig { | ||||
| 				t.Errorf("got metadata %s, want %s", config, tc.wantConfig) | ||||
| 			config, _ := md.GetConfigJSON() | ||||
| 			keyset, _ := md.GetKeysetJSON() | ||||
| 			if string(config) != tc.wantConfig { | ||||
| 				t.Errorf("got metadata %s, want %s", string(config), tc.wantConfig) | ||||
| 			} | ||||
| 			if keyset != tc.wantKeyset { | ||||
| 				t.Errorf("got keyset %s, want %s", keyset, tc.wantKeyset) | ||||
| 			if string(keyset) != tc.wantKeyset { | ||||
| 				t.Errorf("got keyset %s, want %s", string(keyset), tc.wantKeyset) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jordan Liggitt
					Jordan Liggitt