Describe the bug

Using NimbusJwtClientAuthenticationParametersConverter with a resolver that reads the JWK from a rolling key source produces JWTs that are always signed with the first resolved key.

The NimbusJwsEncoder is cached in a concurrent map together with the first key and is never refreshed when the JWK kid changes. A check should verify that the kid is still the same and, if it is not, re-create the encoder.
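
The caching behavior described above can be modeled with the JDK alone. This is only a simplified sketch, not the actual converter code: `Key`, `Encoder`, `StaleEncoderCacheDemo` and the registration ID `my-client` are hypothetical stand-ins, and the assumption is that the encoder is bound to one key when it is constructed.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Simplified model of the caching problem; Key and Encoder are hypothetical stand-ins.
class StaleEncoderCacheDemo {

    // Stand-in for a JWK: only the key ID matters here.
    static final class Key {
        final String kid;
        Key(String kid) { this.kid = kid; }
    }

    // Stand-in for an encoder that is bound to one key when it is constructed.
    static final class Encoder {
        private final Key key;
        Encoder(Key key) { this.key = key; }
        String sign(String claims) { return claims + " signed-with " + key.kid; }
    }

    private final Map<String, Encoder> encoders = new ConcurrentHashMap<>();
    private final Function<String, Key> keyResolver;

    StaleEncoderCacheDemo(Function<String, Key> keyResolver) {
        this.keyResolver = keyResolver;
    }

    String sign(String registrationId, String claims) {
        // The resolver is consulted on every call...
        Key current = keyResolver.apply(registrationId);
        // ...but the mapping function only runs when no encoder is cached yet,
        // so keys resolved on later calls never reach the encoder.
        Encoder encoder = encoders.computeIfAbsent(registrationId, id -> new Encoder(current));
        return encoder.sign(claims);
    }

    // Signs twice with a resolver that returns a new key on each call.
    static String[] demo() {
        AtomicInteger counter = new AtomicInteger();
        StaleEncoderCacheDemo converter =
                new StaleEncoderCacheDemo(id -> new Key("key-" + counter.incrementAndGet()));
        return new String[] { converter.sign("my-client", "a"), converter.sign("my-client", "b") };
    }

    public static void main(String[] args) {
        for (String jwt : demo()) {
            System.out.println(jwt);
        }
    }
}
```

In this model both results are signed with `key-1`, even though the resolver has already handed out `key-2` by the second call.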

To Reproduce

Use

public OAuth2AuthorizedClientProvider getJwtBearerClientProvider() {
    var requestEntityConverter = new JwtBearerGrantRequestEntityConverter();
    requestEntityConverter.addParametersConverter(new NimbusJwtClientAuthenticationParametersConverter<>(Foo::getJWK));

    var tokenResponseClient = new DefaultJwtBearerTokenResponseClient();
    tokenResponseClient.setRequestEntityConverter(requestEntityConverter);

    var jwtBearerClientProvider = new JwtBearerOAuth2AuthorizedClientProvider(JwtSecurityConfig::getJWK);
    jwtBearerClientProvider.setAccessTokenResponseClient(tokenResponseClient);

    return jwtBearerClientProvider;
}

where Foo::getJWK returns a different key on each call.

Manually configure a non-existent authorization server to force the authorization request to fail. This avoids caching and forces NimbusJwtClientAuthenticationParametersConverter to rebuild a JWT assertion for each authorization request.

Debug and observe that the JWT is always signed with the first provided key: the kid of the JWT bearer assertion is always the first one.

Note that a JwtEncodingException is thrown if the key algorithm changes, because the JWKMatcher fails to match the key algorithm.

Expected behavior

The JWT bearer assertion must be signed with the currently resolved JWK.

The old NimbusJwsEncoder should be evicted to avoid a memory leak in the long term.

Comment From: jgrandja

Thanks for reporting this bug @scrocquesel. Would you be interested in submitting a PR for the fix?

Comment From: sclorng

I'm not a Java expert and I'm not sure how this could be done. If NimbusJwsEncoder exposed its jwkSource, it might be possible to remap the value when the JWK changes.

If we can only use the map key to determine which JWK the encoder currently uses, we may need a composite key that includes the kid, so that a new instance is computed for every different kid. But how should we evict the old one, given that we will never see that kid again? By calling entrySet().removeIf(...) before each call to compute?
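
One possible way around the eviction question, sketched with the JDK only: keep the registration ID as the map key, and let `compute` replace the cached encoder whenever the kid differs. The map then never holds more than one encoder per registration, so nothing has to be evicted separately. `Key`, `Encoder` and `KidAwareEncoderCache` are hypothetical stand-ins, not Spring Security or Nimbus types, and the sketch assumes the cached value can expose the key it was built with.

```java
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Sketch of a kid-aware cache; Key and Encoder are hypothetical stand-ins.
class KidAwareEncoderCache {

    // Stand-in for a JWK: only the key ID matters here.
    static final class Key {
        final String kid;
        Key(String kid) { this.kid = kid; }
    }

    // Stand-in for an encoder bound to one key; it remembers that key for later comparison.
    static final class Encoder {
        final Key key;
        Encoder(Key key) { this.key = key; }
        String sign(String claims) { return claims + " signed-with " + key.kid; }
    }

    private final Map<String, Encoder> encoders = new ConcurrentHashMap<>();
    private final Function<String, Key> keyResolver;

    KidAwareEncoderCache(Function<String, Key> keyResolver) {
        this.keyResolver = keyResolver;
    }

    String sign(String registrationId, String claims) {
        Key current = keyResolver.apply(registrationId);
        // compute() runs atomically per map key: reuse the cached encoder while the kid
        // is unchanged, otherwise replace it so the old encoder becomes unreachable.
        Encoder encoder = encoders.compute(registrationId, (id, cached) ->
                (cached != null && Objects.equals(cached.key.kid, current.kid))
                        ? cached
                        : new Encoder(current));
        return encoder.sign(claims);
    }

    // Number of encoders currently held by the cache.
    int cachedEncoders() {
        return encoders.size();
    }

    // Builds a cache whose resolver returns a new key on each call.
    static KidAwareEncoderCache rolling() {
        AtomicInteger counter = new AtomicInteger();
        return new KidAwareEncoderCache(id -> new Key("key-" + counter.incrementAndGet()));
    }
}
```

Comparing only the kid is a design choice of this sketch; a real fix might need to compare the whole key, since two different keys could in principle share a kid.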

Is there any JWKSource implementation that can take a delegate function?

Something like:

NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(),
        (clientRegistrationId) -> {
            // ResolverJWKSet is hypothetical: a JWKSource that calls the supplier on every lookup
            JWKSource<SecurityContext> jwkSource = new ResolverJWKSet<>(() -> new JWKSet(this.jwkResolver.apply(clientRegistration)));
            return new NimbusJwsEncoder(jwkSource);
        });

Can clientRegistration be captured? It seems to be tightly linked to the clientRegistrationId used as the key.
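
To make the delegate idea concrete, here is a JDK-only sketch in which the cached encoder asks a `Supplier` for the key on every sign call instead of holding a fixed key. `Key`, `Encoder` and `DelegatingKeySourceDemo` are hypothetical stand-ins. As far as I can tell, Nimbus's JWKSource declares a single method, get(JWKSelector, C), so a lambda might play the role of the supplier there, but I have not tried it. Note that in this model the resolver takes the registration ID, so the lambda only captures the map key. With a resolver that takes the ClientRegistration itself, the captured instance would be the one from the first call, which is the concern raised above.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;

// Sketch of an encoder with a delegating key source; all types are hypothetical stand-ins.
class DelegatingKeySourceDemo {

    // Stand-in for a JWK: only the key ID matters here.
    static final class Key {
        final String kid;
        Key(String kid) { this.kid = kid; }
    }

    // Stand-in for an encoder that resolves its key lazily on every sign call.
    static final class Encoder {
        private final Supplier<Key> keySource;
        Encoder(Supplier<Key> keySource) { this.keySource = keySource; }
        String sign(String claims) { return claims + " signed-with " + keySource.get().kid; }
    }

    private final Map<String, Encoder> encoders = new ConcurrentHashMap<>();
    private final Function<String, Key> keyResolver;

    DelegatingKeySourceDemo(Function<String, Key> keyResolver) {
        this.keyResolver = keyResolver;
    }

    String sign(String registrationId, String claims) {
        // The encoder is created once per registration ID, but its supplier
        // defers to the resolver each time a signature is produced.
        Encoder encoder = encoders.computeIfAbsent(registrationId,
                id -> new Encoder(() -> keyResolver.apply(id)));
        return encoder.sign(claims);
    }

    // Builds a converter whose resolver returns a new key on each call.
    static DelegatingKeySourceDemo rolling() {
        AtomicInteger counter = new AtomicInteger();
        return new DelegatingKeySourceDemo(id -> new Key("key-" + counter.incrementAndGet()));
    }
}
```

With this shape the map never needs eviction, because the single cached encoder per registration always signs with whatever the resolver currently returns.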