package io.strimzi.kafka.oauth.validator;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.strimzi.kafka.oauth.common.HttpUtil;
import io.strimzi.kafka.oauth.common.TimeUtil;
import io.strimzi.kafka.oauth.common.TokenInfo;
import io.strimzi.kafka.oauth.validator.TokenValidationException;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.PublicKey;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory;
import org.apache.kafka.common.utils.Time;
import org.keycloak.TokenVerifier;
import org.keycloak.exceptions.TokenSignatureInvalidException;
import org.keycloak.jose.jwk.JSONWebKeySet;
import org.keycloak.jose.jwk.JWK;
import org.keycloak.representations.AccessToken;
import org.keycloak.util.JWKSUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TokenValidator} that validates JWT access tokens locally by checking their signature
 * against public keys fetched from a JWKS endpoint.
 *
 * <p>Keys marked for use {@code sig} are fetched once in the constructor, so construction fails if
 * the endpoint cannot be read. They are then refreshed every {@code refreshSeconds} on a daemon
 * thread. If no refresh has succeeded for {@code expirySeconds}, the cached keys are treated as
 * stale. No key is then supplied to the verifier, so validation is expected to fail.
 *
 * <p>The key cache and its timestamp are written by the refresh thread and read by callers of
 * {@link #validate(String)}; both fields are therefore {@code volatile}.
 */
public class JWTSignatureValidator implements TokenValidator {
    private static final Logger log = LoggerFactory.getLogger(JWTSignatureValidator.class);

    private final ScheduledExecutorService scheduler;
    private final URI keysUri;
    private final String issuerUri;
    private final int maxStaleSeconds;
    private final boolean defaultChecks;
    private final String audience;
    private final SSLSocketFactory socketFactory;
    private final HostnameVerifier hostnameVerifier;

    // Written by the refresh thread, read by validate() callers. volatile gives visibility and,
    // for the long, also rules out torn reads.
    private volatile long lastFetchTime;
    private volatile Map<String, PublicKey> cache = new ConcurrentHashMap<>();

    /** Thread factory producing daemon threads so the refresh thread never blocks JVM exit. */
    static class DaemonThreadFactory implements ThreadFactory {
        DaemonThreadFactory() {
        }

        @Override
        public Thread newThread(Runnable runnable) {
            Thread thread = Executors.defaultThreadFactory().newThread(runnable);
            thread.setDaemon(true);
            return thread;
        }
    }

    /**
     * Creates a validator and performs the initial, blocking fetch of the JWKS keys.
     *
     * @param keysEndpointUri JWKS endpoint URI; must not be {@code null}
     * @param sslSocketFactory optional socket factory; only allowed with an {@code https} URI
     * @param hostnameVerifier optional hostname verifier; only allowed with an {@code https} URI
     * @param validIssuerUri expected issuer, enforced via {@code realmUrl} when default checks are
     *     enabled; must not be {@code null}
     * @param refreshSeconds interval between scheduled key refreshes; must be positive
     * @param expirySeconds age of the last successful fetch after which cached keys are rejected;
     *     must be at least {@code refreshSeconds + 60}
     * @param defaultChecks whether to apply the verifier's default checks plus the issuer check
     * @param audience expected audience, or {@code null} to skip the audience check
     * @throws IllegalArgumentException if any argument is invalid
     * @throws RuntimeException if the initial key fetch fails
     */
    public JWTSignatureValidator(String keysEndpointUri, SSLSocketFactory sslSocketFactory,
            HostnameVerifier hostnameVerifier, String validIssuerUri, int refreshSeconds,
            int expirySeconds, boolean defaultChecks, String audience) {
        if (keysEndpointUri == null) {
            throw new IllegalArgumentException("keysEndpointUri == null");
        }
        URI uri;
        try {
            uri = new URI(keysEndpointUri);
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Invalid keysEndpointUri: " + keysEndpointUri, e);
        }
        this.keysUri = uri;

        boolean https = "https".equals(uri.getScheme());
        if (sslSocketFactory != null && !https) {
            throw new IllegalArgumentException("SSL socket factory set but keysEndpointUri not 'https'");
        }
        this.socketFactory = sslSocketFactory;
        if (hostnameVerifier != null && !https) {
            throw new IllegalArgumentException("Certificate hostname verifier set but keysEndpointUri not 'https'");
        }
        this.hostnameVerifier = hostnameVerifier;

        if (validIssuerUri == null) {
            throw new IllegalArgumentException("validIssuerUri == null");
        }
        this.issuerUri = validIssuerUri;

        // Validate before the network fetch. scheduleAtFixedRate would otherwise reject a
        // non-positive period only after the initial fetch.
        if (refreshSeconds <= 0) {
            throw new IllegalArgumentException("refreshSeconds has to be a positive number");
        }
        // Add in long so a very large refreshSeconds cannot overflow and slip past the check.
        if (expirySeconds < (long) refreshSeconds + 60) {
            throw new IllegalArgumentException("expirySeconds has to be at least 60 seconds longer than refreshSeconds");
        }
        this.maxStaleSeconds = expirySeconds;
        this.defaultChecks = defaultChecks;
        this.audience = audience;

        // Initial fetch is allowed to throw, so misconfiguration surfaces at construction time.
        fetchKeys();

        this.scheduler = Executors.newSingleThreadScheduledExecutor(new DaemonThreadFactory());
        this.scheduler.scheduleAtFixedRate(this::refreshKeys, refreshSeconds, refreshSeconds, TimeUnit.SECONDS);
    }

    /**
     * Scheduled refresh task. Logs fetch failures instead of propagating them.
     *
     * <p>If a {@code scheduleAtFixedRate} task throws, all of its subsequent executions are
     * suppressed. Letting the exception escape would therefore stop key rotation forever after a
     * single transient JWKS outage. Here the previously fetched keys stay in use (until they go
     * stale) and the fetch is retried on the next tick.
     */
    private void refreshKeys() {
        try {
            fetchKeys();
        } catch (RuntimeException e) {
            log.error("Scheduled refresh of public keys from {} failed; will retry on next refresh", keysUri, e);
        }
    }

    private PublicKey getPublicKey(String kid) {
        return getKeyUnlessStale(kid);
    }

    /**
     * Looks up a cached signing key, refusing to use the cache once it is older than
     * {@code maxStaleSeconds}.
     *
     * @param kid key id from the token header; may be {@code null}
     * @return the key, or {@code null} if the cache is stale or has no key for {@code kid}
     */
    private PublicKey getKeyUnlessStale(String kid) {
        // 1000L keeps the arithmetic in long: int * 1000 overflows for expirySeconds > ~24.8 days.
        if (lastFetchTime + maxStaleSeconds * 1000L <= System.currentTimeMillis()) {
            log.warn("The cached public key with id '{}' is expired!", kid);
            return null;
        }
        // Guard null: some Map implementations (e.g. ConcurrentHashMap) throw on a null key.
        PublicKey publicKey = kid == null ? null : cache.get(kid);
        if (publicKey == null) {
            log.warn("No public key for id: {}", kid);
        }
        return publicKey;
    }

    /**
     * Fetches the JWKS document and replaces the cache with its {@code sig} keys.
     *
     * @throws RuntimeException wrapping any failure to fetch or parse the keys
     */
    private void fetchKeys() {
        try {
            JSONWebKeySet jwks = (JSONWebKeySet) HttpUtil.get(keysUri, socketFactory, hostnameVerifier, null, JSONWebKeySet.class);
            Map<String, PublicKey> keys = JWKSUtils.getKeysForUse(jwks, JWK.Use.SIG);
            // Publish the keys before the timestamp. A reader that sees the new (volatile)
            // timestamp is then guaranteed to see the new key map as well.
            cache = keys;
            lastFetchTime = System.currentTimeMillis();
        } catch (Exception e) {
            throw new RuntimeException("Failed to fetch public keys needed to validate JWT signatures: " + keysUri, e);
        }
    }

    /**
     * Validates the signature and claims of a JWT access token.
     *
     * <p>When default checks are enabled, the verifier's default checks and the issuer check are
     * applied. When an audience is configured, it is checked too. Finally the token's {@code exp}
     * claim is compared with the current time.
     *
     * @param token the raw JWT string
     * @return information about the validated token
     * @throws TokenValidationException with status {@code INVALID_TOKEN} if the token header
     *     cannot be parsed
     * @throws TokenSignatureException if the signature check fails
     * @throws TokenValidationException if any other verification step fails
     * @throws TokenExpiredException if the token has expired
     */
    @Override
    @SuppressFBWarnings(value = {"BC_UNCONFIRMED_CAST_OF_RETURN_VALUE"}, justification = "We tell TokenVerifier to parse AccessToken. It will return AccessToken or fail.")
    public TokenInfo validate(String token) {
        TokenVerifier<AccessToken> verifier = TokenVerifier.create(token, AccessToken.class);
        if (defaultChecks) {
            verifier.withDefaultChecks().realmUrl(issuerUri);
        }

        String kid;
        try {
            kid = verifier.getHeader().getKeyId();
        } catch (Exception e) {
            // NOTE(review): this message embeds the raw token, which is a bearer credential and
            // may end up in logs. Consider dropping it from the message.
            throw new TokenValidationException("Token signature validation failed: " + token, e)
                    .status(TokenValidationException.Status.INVALID_TOKEN);
        }
        verifier.publicKey(getPublicKey(kid));
        if (audience != null) {
            verifier.audience(audience);
        }

        AccessToken accessToken;
        try {
            verifier.verify();
            accessToken = verifier.getToken();
        } catch (TokenSignatureInvalidException e) {
            // Must precede the generic catch below, otherwise it is unreachable.
            throw new TokenSignatureException("Signature check failed:", e);
        } catch (Exception e) {
            throw new TokenValidationException("Token validation failed:", e);
        }

        // Check expiry outside the try so TokenExpiredException is not re-wrapped by the catch
        // above. 1000L avoids int overflow when converting epoch seconds to milliseconds.
        long expiresMillis = accessToken.getExpiration() * 1000L;
        if (Time.SYSTEM.milliseconds() > expiresMillis) {
            throw new TokenExpiredException("Token expired at: " + expiresMillis + " (" + TimeUtil.formatIsoDateTimeUTC(expiresMillis) + ")");
        }
        return new TokenInfo(accessToken, token);
    }
}
