package io.strimzi.kafka.oauth.validator;

import com.fasterxml.jackson.databind.JsonNode;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.strimzi.kafka.oauth.common.HttpUtil;
import io.strimzi.kafka.oauth.common.JSONUtil;
import io.strimzi.kafka.oauth.common.PrincipalExtractor;
import io.strimzi.kafka.oauth.common.TimeUtil;
import io.strimzi.kafka.oauth.common.TokenInfo;
import io.strimzi.kafka.oauth.jsonpath.JsonPathFilterQuery;
import io.strimzi.kafka.oauth.validator.TokenValidationException;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.Provider;
import java.security.PublicKey;
import java.security.Security;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory;
import org.apache.kafka.common.utils.Time;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.keycloak.TokenVerifier;
import org.keycloak.crypto.AsymmetricSignatureVerifierContext;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.crypto.SignatureVerifierContext;
import org.keycloak.exceptions.TokenSignatureInvalidException;
import org.keycloak.jose.jwk.JSONWebKeySet;
import org.keycloak.jose.jwk.JWK;
import org.keycloak.representations.AccessToken;
import org.keycloak.util.JWKSUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/strimzi/kafka/oauth/validator/JWTSignatureValidator.class */
/**
 * Validates JWT access tokens locally by verifying their signature against public keys
 * fetched from a JWKS endpoint, plus optional issuer, audience, token-type and custom-claim checks.
 *
 * <p>Keys are fetched once in the constructor and then refreshed periodically on a
 * single-thread scheduled executor. An extra refresh is requested through a
 * {@code BackOffTaskScheduler} when a token references an unknown key id.
 *
 * <p>The key cache fields are written by the refresh thread and read by the threads calling
 * {@link #validate(String)}. They are {@code volatile} so readers always see the most
 * recently published values.
 */
public class JWTSignatureValidator implements TokenValidator {
    private static final Logger log = LoggerFactory.getLogger(JWTSignatureValidator.class);

    /** Ensures the BouncyCastle provider is inserted at most once per JVM (class-loader). */
    private static final AtomicBoolean bouncyInstalled = new AtomicBoolean(false);

    private static final TokenVerifier.TokenTypeCheck TOKEN_TYPE_CHECK = new TokenVerifier.TokenTypeCheck("Bearer");

    private final BackOffTaskScheduler fastScheduler;
    private final URI keysUri;
    private final String issuerUri;
    private final int maxStaleSeconds;
    private final boolean checkAccessTokenType;
    private final String audience;
    private final JsonPathFilterQuery customClaimMatcher;
    private final SSLSocketFactory socketFactory;
    private final HostnameVerifier hostnameVerifier;
    private final PrincipalExtractor principalExtractor;

    // Written by the key-refresh thread and read by validating threads. 'volatile' guarantees
    // visibility, and for the long it also guarantees atomic reads/writes (JLS 17.7).
    private volatile long lastFetchTime;
    private volatile Map<String, PublicKey> cache = Collections.emptyMap();
    private volatile Map<String, PublicKey> oldCache = Collections.emptyMap();

    /**
     * Creates the validator, performs the initial (blocking) JWKS fetch and schedules periodic key refreshes.
     *
     * @param keysEndpointUri JWKS endpoint URI; must not be null
     * @param sslSocketFactory socket factory for the JWKS request; may be null; if set the URI must be 'https'
     * @param hostnameVerifier hostname verifier for the JWKS request; may be null; if set the URI must be 'https'
     * @param principalExtractor used to derive the principal name from a validated token
     * @param validIssuerUri expected issuer passed to {@code TokenVerifier.realmUrl}; null disables the check
     * @param certsRefreshSeconds period, in seconds, of the regular key refresh; must be positive
     * @param certsRefreshMinPauseSeconds minimum pause, in seconds, handed to the on-demand refresh scheduler
     * @param certsExpirySeconds how long, in seconds, fetched keys remain usable; must be at least
     *                           {@code certsRefreshSeconds + 60}
     * @param checkAccessTokenType whether to apply a token type check for "Bearer"
     * @param audience required audience; null disables the check
     * @param customClaimCheck JsonPath filter query the token claims must match; null disables it; must not be blank
     * @param enableBouncyCastleProvider whether to install the BouncyCastle security provider (once per JVM)
     * @param bouncyCastleProviderPosition position passed to {@link Security#insertProviderAt}
     * @throws IllegalArgumentException if any configuration value is invalid
     * @throws RuntimeException if the initial key fetch fails
     */
    public JWTSignatureValidator(String keysEndpointUri, SSLSocketFactory sslSocketFactory, HostnameVerifier hostnameVerifier, PrincipalExtractor principalExtractor, String validIssuerUri, int certsRefreshSeconds, int certsRefreshMinPauseSeconds, int certsExpirySeconds, boolean checkAccessTokenType, String audience, String customClaimCheck, boolean enableBouncyCastleProvider, int bouncyCastleProviderPosition) {
        this.keysUri = parseKeysUri(keysEndpointUri);
        boolean isHttps = "https".equals(this.keysUri.getScheme());

        if (sslSocketFactory != null && !isHttps) {
            throw new IllegalArgumentException("SSL socket factory set but keysEndpointUri not 'https'");
        }
        this.socketFactory = sslSocketFactory;

        if (hostnameVerifier != null && !isHttps) {
            throw new IllegalArgumentException("Certificate hostname verifier set but keysEndpointUri not 'https'");
        }
        this.hostnameVerifier = hostnameVerifier;
        this.principalExtractor = principalExtractor;

        validateIssuerUri(validIssuerUri);
        this.issuerUri = validIssuerUri;

        validateRefreshConfig(certsRefreshSeconds, certsExpirySeconds);
        this.maxStaleSeconds = certsExpirySeconds;
        this.checkAccessTokenType = checkAccessTokenType;
        this.audience = audience;
        this.customClaimMatcher = parseCustomClaimCheck(customClaimCheck);

        if (enableBouncyCastleProvider && !bouncyInstalled.getAndSet(true)) {
            installBouncyCastle(bouncyCastleProviderPosition);
        }

        // Initial fetch is synchronous so the validator is usable as soon as it is constructed.
        fetchKeys();

        // One thread serves both the periodic refresh and the on-demand (back-off) refresh,
        // so fetchKeys() never runs concurrently with itself after construction.
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(new DaemonThreadFactory());
        this.fastScheduler = new BackOffTaskScheduler(scheduler, certsRefreshMinPauseSeconds, certsRefreshSeconds, this::fetchKeys);
        setupRefreshKeysJob(scheduler, certsRefreshSeconds);

        if (log.isDebugEnabled()) {
            log.debug("Configured JWTSignatureValidator:\n    keysEndpointUri: " + keysEndpointUri
                    + "\n    sslSocketFactory: " + sslSocketFactory
                    + "\n    hostnameVerifier: " + this.hostnameVerifier
                    + "\n    principalExtractor: " + principalExtractor
                    + "\n    validIssuerUri: " + validIssuerUri
                    + "\n    certsRefreshSeconds: " + certsRefreshSeconds
                    + "\n    certsRefreshMinPauseSeconds: " + certsRefreshMinPauseSeconds
                    + "\n    certsExpirySeconds: " + certsExpirySeconds
                    + "\n    checkAccessTokenType: " + checkAccessTokenType
                    + "\n    audience: " + audience
                    + "\n    customClaimCheck: " + customClaimCheck
                    + "\n    enableBouncyCastleProvider: " + enableBouncyCastleProvider
                    + "\n    bouncyCastleProviderPosition: " + bouncyCastleProviderPosition);
        }
    }

    /**
     * Parses the JWKS endpoint URI.
     *
     * @throws IllegalArgumentException if the value is null or not a syntactically valid URI
     */
    private static URI parseKeysUri(String keysEndpointUri) {
        if (keysEndpointUri == null) {
            throw new IllegalArgumentException("keysEndpointUri == null");
        }
        try {
            return new URI(keysEndpointUri);
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Invalid keysEndpointUri: " + keysEndpointUri, e);
        }
    }

    /**
     * Checks that the (optional) issuer URI is syntactically valid. Null is accepted and disables the issuer check.
     *
     * @throws IllegalArgumentException if the value is non-null and not a valid URI
     */
    private static void validateIssuerUri(String validIssuerUri) {
        if (validIssuerUri == null) {
            return;
        }
        try {
            new URI(validIssuerUri);
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Value of validIssuerUri not a valid URI: " + validIssuerUri, e);
        }
    }

    /**
     * Inserts the BouncyCastle provider at the given position and logs the result.
     * At debug level it also logs the full list of installed providers.
     */
    private static void installBouncyCastle(int position) {
        // insertProviderAt returns the actual position, or -1 if the provider was already installed.
        int installedAt = Security.insertProviderAt(new BouncyCastleProvider(), position);
        log.info("BouncyCastle security provider installed at position: {}", installedAt);

        if (log.isDebugEnabled()) {
            StringBuilder sb = new StringBuilder("Installed security providers:\n");
            for (Provider provider : Security.getProviders()) {
                sb.append("  - ").append(provider).append("  [").append(provider.getClass().getName()).append("]\n");
                sb.append("   ").append(provider.getInfo()).append("\n");
            }
            log.debug(sb.toString());
        }
    }

    /**
     * Parses the optional custom claim check into a JsonPath filter query.
     *
     * @return the parsed query, or null if {@code str} is null
     * @throws IllegalArgumentException if {@code str} is blank
     */
    private JsonPathFilterQuery parseCustomClaimCheck(String str) {
        if (str == null) {
            return null;
        }
        String trimmed = str.trim();
        if (trimmed.isEmpty()) {
            throw new IllegalArgumentException("Value of customClaimCheck is empty");
        }
        return JsonPathFilterQuery.parse(trimmed);
    }

    /**
     * Validates the key refresh and expiry settings.
     *
     * @param refreshSeconds period of the regular refresh; must be positive
     * @param expirySeconds key lifetime; must exceed {@code refreshSeconds} by at least 60 seconds
     * @throws IllegalArgumentException if either constraint is violated
     */
    private void validateRefreshConfig(int refreshSeconds, int expirySeconds) {
        if (refreshSeconds <= 0) {
            throw new IllegalArgumentException("refreshSeconds has to be a positive number - (refreshSeconds=" + refreshSeconds + ")");
        }
        if (expirySeconds < refreshSeconds + 60) {
            throw new IllegalArgumentException("expirySeconds has to be at least 60 seconds longer than refreshSeconds - (expirySeconds=" + expirySeconds + ", refreshSeconds=" + refreshSeconds + ")");
        }
    }

    /**
     * Schedules the periodic refresh. Each tick only asks the back-off scheduler to run a fetch.
     * Failures are logged rather than propagated, so a single failing tick does not cancel the
     * repeating task.
     */
    private void setupRefreshKeysJob(ScheduledExecutorService scheduler, int refreshSeconds) {
        scheduler.scheduleAtFixedRate(() -> {
            try {
                this.fastScheduler.scheduleTask();
            } catch (Exception e) {
                log.error(e.getMessage(), e);
            }
        }, refreshSeconds, refreshSeconds, TimeUnit.SECONDS);
    }

    /** Returns the public key for the given key id, or null if unknown or the key cache is stale. */
    private PublicKey getPublicKey(String keyId) {
        return getKeyUnlessStale(keyId);
    }

    /**
     * Looks up a key in the current cache, refusing every key once the cache is older than
     * {@code maxStaleSeconds} (that is, no successful fetch happened within that window).
     *
     * @return the key, or null (with a warning logged) if the cache is stale or has no such key id
     */
    private PublicKey getKeyUnlessStale(String keyId) {
        // 1000L forces long arithmetic: an int 'maxStaleSeconds * 1000' overflows beyond ~24.8 days.
        if (this.lastFetchTime + this.maxStaleSeconds * 1000L <= System.currentTimeMillis()) {
            log.warn("The cached public key with id '{}' is expired!", keyId);
            return null;
        }
        PublicKey publicKey = this.cache.get(keyId);
        if (publicKey == null) {
            log.warn("No public key for id: {}", keyId);
        }
        return publicKey;
    }

    /**
     * Fetches the JWKS, keeps the signature-use keys and publishes them.
     * When the key set changed, the previous set is retained in {@code oldCache} so that tokens
     * signed with a just-rotated-out key get a more specific error. {@code lastFetchTime} is
     * updated on every successful fetch.
     *
     * @throws RuntimeException wrapping any failure to fetch or parse the keys
     */
    private void fetchKeys() {
        try {
            JSONWebKeySet jwks = HttpUtil.get(this.keysUri, this.socketFactory, this.hostnameVerifier, null, JSONWebKeySet.class);
            Map<String, PublicKey> fetched = Collections.unmodifiableMap(JWKSUtils.getKeysForUse(jwks, JWK.Use.SIG));
            if (!this.cache.equals(fetched)) {
                log.info("JWKS keys change detected. Keys updated.");
                // Publish oldCache before cache, so a reader that misses the new key
                // already sees the previous key set in oldCache.
                this.oldCache = this.cache;
                this.cache = fetched;
            }
            this.lastFetchTime = System.currentTimeMillis();
        } catch (Exception e) {
            throw new RuntimeException("Failed to fetch public keys needed to validate JWT signatures: " + this.keysUri, e);
        }
    }

    /**
     * Validates the token's signature, type, issuer, audience, expiry and optional custom claim
     * check, then extracts the principal.
     *
     * @param token the raw JWT
     * @return token info for a valid token
     * @throws TokenValidationException if the token is malformed or fails any check. Exceptions
     *         raised by the specific checks below are propagated as thrown rather than re-wrapped.
     */
    @Override // io.strimzi.kafka.oauth.validator.TokenValidator
    @SuppressFBWarnings(value = {"BC_UNCONFIRMED_CAST_OF_RETURN_VALUE"}, justification = "We tell TokenVerifier to parse AccessToken. It will return AccessToken or fail.")
    public TokenInfo validate(String token) {
        TokenVerifier<AccessToken> verifier = initializeTokenVerifier(token);
        // Only header parsing gets the generic INVALID_TOKEN wrapping. Later, more specific
        // failures must not be re-wrapped.
        String keyId = readKeyId(verifier, token);

        try {
            PublicKey publicKey = getPublicKey(keyId);
            if (publicKey == null) {
                if (this.oldCache.get(keyId) != null) {
                    throw new TokenValidationException("Token validation failed: The signing key is no longer valid (kid:" + keyId + ")");
                }
                // Possibly a newly rotated-in key: request a refresh (throttled by the back-off scheduler).
                this.fastScheduler.scheduleTask();
                throw new TokenValidationException("Token validation failed: Unknown signing key (kid:" + keyId + ")");
            }

            KeyWrapper keyWrapper = new KeyWrapper();
            keyWrapper.setPublicKey(publicKey);
            keyWrapper.setAlgorithm(verifier.getHeader().getAlgorithm().name());
            keyWrapper.setKid(keyId);

            log.debug("Signature algorithm used: [{}]", publicKey.getAlgorithm());
            SignatureVerifierContext verifierContext = isAlgorithmEC(publicKey.getAlgorithm())
                    ? new ECDSASignatureVerifierContext(keyWrapper)
                    : new AsymmetricSignatureVerifierContext(keyWrapper);
            verifier.verifierContext(verifierContext);
            log.debug("SignatureVerifierContext set to: {}", verifierContext);

            verifier.verify();
            AccessToken accessToken = (AccessToken) verifier.getToken();

            // long arithmetic: 'exp' in epoch seconds times an int 1000 overflows int.
            // A missing 'exp' yields 0, i.e. the token is treated as already expired.
            Number exp = accessToken.getExp();
            long expiresMillis = exp != null ? exp.longValue() * 1000L : 0L;
            if (Time.SYSTEM.milliseconds() > expiresMillis) {
                throw new TokenExpiredException("Token expired at: " + expiresMillis + " (" + TimeUtil.formatIsoDateTimeUTC(expiresMillis) + " UTC)");
            }

            JsonNode json = null;
            if (this.customClaimMatcher != null) {
                json = JSONUtil.asJson(accessToken);
                if (!this.customClaimMatcher.matches(json)) {
                    throw new TokenValidationException("Token validation failed: Custom claim check failed");
                }
            }
            return new TokenInfo(accessToken, token, extractPrincipal(accessToken, json));
        } catch (TokenSignatureInvalidException e) {
            throw new TokenSignatureException("Signature check failed:", e);
        } catch (TokenValidationException e) {
            // Raised deliberately above (or a subclass of it): propagate unchanged.
            throw e;
        } catch (Exception e) {
            throw new TokenValidationException("Token validation failed:", e);
        }
    }

    /**
     * Reads the key id ('kid') from the token header.
     *
     * @throws TokenValidationException with status INVALID_TOKEN if the header cannot be parsed
     */
    private static String readKeyId(TokenVerifier<AccessToken> verifier, String token) {
        try {
            return verifier.getHeader().getKeyId();
        } catch (Exception e) {
            // NOTE(review): the message embeds the raw token, which is a bearer credential.
            // Consider omitting or truncating it.
            throw new TokenValidationException("Token signature validation failed: " + token, e).status(TokenValidationException.Status.INVALID_TOKEN);
        }
    }

    /**
     * Determines the principal name. If the extractor is configured, it is applied to the token's
     * JSON form; otherwise the extractor's {@code getSub} is used.
     *
     * @param json the token as JSON if already computed, or null to compute it on demand
     * @throws RuntimeException if no principal could be determined
     */
    private String extractPrincipal(AccessToken accessToken, JsonNode json) {
        String principal;
        if (this.principalExtractor.isConfigured()) {
            principal = this.principalExtractor.getPrincipal(json != null ? json : JSONUtil.asJson(accessToken));
        } else {
            principal = this.principalExtractor.getSub(accessToken);
        }
        if (principal == null) {
            throw new RuntimeException("Failed to extract principal - check usernameClaim, fallbackUsernameClaim configuration");
        }
        return principal;
    }

    /**
     * Creates a {@link TokenVerifier} for the token with the configured optional checks:
     * issuer (realm URL), token type "Bearer", and audience.
     */
    private TokenVerifier<AccessToken> initializeTokenVerifier(String token) {
        TokenVerifier<AccessToken> verifier = TokenVerifier.create(token, AccessToken.class);
        if (this.issuerUri != null) {
            verifier.realmUrl(this.issuerUri);
        }
        if (this.checkAccessTokenType) {
            verifier.withChecks(TOKEN_TYPE_CHECK);
        }
        if (this.audience != null) {
            verifier.audience(this.audience);
        }
        return verifier;
    }

    /** Returns true for the JCA key algorithm names used for elliptic-curve keys ("EC" / "ECDSA"). */
    private static boolean isAlgorithmEC(String algorithm) {
        return "EC".equals(algorithm) || "ECDSA".equals(algorithm);
    }
}
