diff --git a/staging/src/k8s.io/apiserver/plugin/pkg/authenticator/token/oidc/oidc.go b/staging/src/k8s.io/apiserver/plugin/pkg/authenticator/token/oidc/oidc.go index 0309fb581ca..9e13c58350e 100644 --- a/staging/src/k8s.io/apiserver/plugin/pkg/authenticator/token/oidc/oidc.go +++ b/staging/src/k8s.io/apiserver/plugin/pkg/authenticator/token/oidc/oidc.go @@ -464,7 +464,7 @@ func (r *claimResolver) Verifier(iss string) (*oidc.IDTokenVerifier, error) { // }, // }, // } -func (r *claimResolver) expand(c claims) error { +func (r *claimResolver) expand(ctx context.Context, c claims) error { const ( // The claim containing a map of endpoint references per claim. // OIDC Connect Core 1.0, section 5.6.2. @@ -516,14 +516,14 @@ func (r *claimResolver) expand(c claims) error { // This is maybe an aggregated claim (ep.JWT != ""). return nil } - return r.resolve(ep, c) + return r.resolve(ctx, ep, c) } // resolve requests distributed claims from all endpoints passed in, // and inserts the lookup results into allClaims. -func (r *claimResolver) resolve(endpoint endpoint, allClaims claims) error { +func (r *claimResolver) resolve(ctx context.Context, endpoint endpoint, allClaims claims) error { // TODO: cache resolved claims. - jwt, err := getClaimJWT(r.client, endpoint.URL, endpoint.AccessToken) + jwt, err := getClaimJWT(ctx, r.client, endpoint.URL, endpoint.AccessToken) if err != nil { return fmt.Errorf("while getting distributed claim %q: %v", r.claim, err) } @@ -535,7 +535,7 @@ func (r *claimResolver) resolve(endpoint endpoint, allClaims claims) error { if err != nil { return fmt.Errorf("verifying untrusted issuer %v failed: %v", untrustedIss, err) } - t, err := v.Verify(context.Background(), jwt) + t, err := v.Verify(ctx, jwt) if err != nil { return fmt.Errorf("verify distributed claim token: %v", err) } @@ -571,7 +571,7 @@ func (a *Authenticator) AuthenticateToken(ctx context.Context, token string) (*a return nil, false, fmt.Errorf("oidc: parse claims: %v", err) } if a.resolver != nil { - if err := a.resolver.expand(c); err != nil { + if err := a.resolver.expand(ctx, c); err != nil { return nil, false, fmt.Errorf("oidc: could not expand distributed claims: %v", err) } } @@ -645,10 +645,7 @@ func (a *Authenticator) AuthenticateToken(ctx context.Context, token string) (*a // token as bearer token. If the access token is "", the authorization header // will not be set. // TODO: Allow passing in JSON hints to the IDP. -func getClaimJWT(client *http.Client, url, accessToken string) (string, error) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - +func getClaimJWT(ctx context.Context, client *http.Client, url, accessToken string) (string, error) { // TODO: Allow passing request body with configurable information. req, err := http.NewRequest("GET", url, nil) if err != nil { diff --git a/staging/src/k8s.io/apiserver/plugin/pkg/authenticator/token/oidc/oidc_test.go b/staging/src/k8s.io/apiserver/plugin/pkg/authenticator/token/oidc/oidc_test.go index 1ff78a66750..b6081a1b716 100644 --- a/staging/src/k8s.io/apiserver/plugin/pkg/authenticator/token/oidc/oidc_test.go +++ b/staging/src/k8s.io/apiserver/plugin/pkg/authenticator/token/oidc/oidc_test.go @@ -296,7 +296,7 @@ func (c *claimsTest) run(t *testing.T) { t.Fatalf("serialize token: %v", err) } - got, ok, err := a.AuthenticateToken(context.Background(), token) + got, ok, err := a.AuthenticateToken(testContext(t), token) expectErr := len(c.wantErr) > 0 @@ -1581,3 +1581,9 @@ type errTransport string func (e errTransport) RoundTrip(_ *http.Request) (*http.Response, error) { return nil, fmt.Errorf("%s", e) } + +func testContext(t *testing.T) context.Context { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return ctx +}