/*
 * Decompiled with CFR 0.152.
 */
package org.wildfly.security.auth.realm.token.validator;

import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.net.URL;
import java.net.URLConnection;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.json.Json;
import javax.json.JsonArray;
import javax.json.JsonObject;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import org.wildfly.common.Assert;
import org.wildfly.security.auth.realm.token._private.ElytronMessages;

/**
 * Fetches and caches RSA public keys published as JWK sets at remote HTTPS URLs.
 *
 * <p>Keys are cached per URL and keyed by their {@code kid}. A URL's key set is re-fetched when
 * {@code updateTimeout} milliseconds (compared against {@link System#currentTimeMillis()}) have
 * elapsed since its last successful fetch.
 *
 * <p>Locking: the URL-to-cache map is guarded by synchronizing on {@code keys}. Refreshes of a
 * single URL's cache are serialized by synchronizing on that URL's cache map. Readers of a cache
 * map ({@link #getPublicKey}) do not take the lock; the per-URL maps are {@link ConcurrentHashMap}s.
 */
class JwkManager {
    private static final int CONNECTION_TIMEOUT = 2000;

    private final Map<URL, Map<String, RSAPublicKey>> keys = new LinkedHashMap<URL, Map<String, RSAPublicKey>>();
    private final Map<URL, Long> timeouts = new ConcurrentHashMap<URL, Long>();
    private final SSLContext sslContext;
    private final HostnameVerifier hostnameVerifier;
    private final long updateTimeout;

    /**
     * @param sslContext       SSL context whose socket factory is used for HTTPS connections
     * @param hostnameVerifier hostname verifier applied to HTTPS connections
     * @param updateTimeout    minimum number of milliseconds between refreshes of a URL's key set
     */
    JwkManager(SSLContext sslContext, HostnameVerifier hostnameVerifier, long updateTimeout) {
        this.sslContext = sslContext;
        this.hostnameVerifier = hostnameVerifier;
        this.updateTimeout = updateTimeout;
    }

    /**
     * Returns the public key with the given {@code kid} from the JWK set published at {@code url}.
     * The cached set is refreshed first if it is due.
     *
     * @param kid key id to look up
     * @param url location of the JWK set; must not be {@code null}
     * @return the key, or {@code null} if the JWK set could not be fetched or has no such {@code kid}
     */
    public PublicKey getPublicKey(String kid, URL url) {
        Map<String, RSAPublicKey> urlKeys = this.checkRemote(url);
        if (urlKeys == null) {
            return null;
        }
        PublicKey pk = urlKeys.get(kid);
        if (pk == null) {
            ElytronMessages.log.warn("Unknown kid: " + kid);
            return null;
        }
        return pk;
    }

    /**
     * Returns the cached key map for {@code url}, re-fetching it when it has never been fetched or
     * when {@code updateTimeout} ms have passed since the last successful fetch.
     *
     * @return the live cache map for the URL, or {@code null} if a required refresh failed
     */
    private Map<String, RSAPublicKey> checkRemote(URL url) {
        Assert.checkNotNullParam("url", url);
        Map<String, RSAPublicKey> urlKeys;
        synchronized (keys) {
            urlKeys = keys.get(url);
            if (urlKeys == null) {
                urlKeys = new ConcurrentHashMap<String, RSAPublicKey>();
                keys.put(url, urlKeys);
            }
        }
        synchronized (urlKeys) {
            Long lastUpdate = timeouts.get(url);
            long now = System.currentTimeMillis();
            // A missing timestamp always forces the first fetch. Comparing the elapsed time (rather
            // than lastUpdate + updateTimeout) cannot overflow for very large timeouts.
            if (lastUpdate == null || now - lastUpdate >= updateTimeout) {
                Map<String, RSAPublicKey> newJwks = getJwksFromUrl(url, sslContext, hostnameVerifier);
                if (newJwks == null) {
                    ElytronMessages.log.unableToFetchJwks(url.toString());
                    return null;
                }
                // getPublicKey reads this map without the lock. Dropping stale entries and then
                // adding the new ones (instead of clear() + putAll()) avoids exposing an empty map
                // to concurrent readers mid-refresh.
                urlKeys.keySet().retainAll(newJwks.keySet());
                urlKeys.putAll(newJwks);
                timeouts.put(url, System.currentTimeMillis());
            }
            return urlKeys;
        }
    }

    /**
     * Downloads the JWK set at {@code url} and converts its RSA entries into public keys.
     * Entries that lack {@code kid}, are not {@code kty=RSA}, lack {@code e}/{@code n}, or cannot be
     * decoded are logged and skipped.
     *
     * @return keys by {@code kid} in document order, or {@code null} if the set could not be fetched
     *         or has no {@code keys} array
     */
    private static Map<String, RSAPublicKey> getJwksFromUrl(URL url, SSLContext sslContext, HostnameVerifier hostnameVerifier) {
        JsonObject response;
        try {
            response = fetchJwkSet(url, sslContext, hostnameVerifier);
        } catch (IOException e) {
            ElytronMessages.log.warn("Unable to connect to " + url.toString(), e);
            return null;
        }
        if (response == null) {
            ElytronMessages.log.warn("No response when fetching jwk set from " + url.toString());
            return null;
        }
        JsonArray jwks = response.getJsonArray("keys");
        if (jwks == null) {
            ElytronMessages.log.warn("Unable to parse jwks");
            return null;
        }
        Map<String, RSAPublicKey> res = new LinkedHashMap<String, RSAPublicKey>();
        for (int i = 0; i < jwks.size(); ++i) {
            JsonObject jwk = jwks.getJsonObject(i);
            String kid = jwk.getString("kid", null);
            String kty = jwk.getString("kty", null);
            String e1 = jwk.getString("e", null);
            String n1 = jwk.getString("n", null);
            if (kid == null) {
                ElytronMessages.log.tokenRealmJwkMissingClaim("kid");
                continue;
            }
            if (!"RSA".equals(kty)) {
                ElytronMessages.log.tokenRealmJwkMissingClaim("kty");
                continue;
            }
            if (e1 == null) {
                ElytronMessages.log.tokenRealmJwkMissingClaim("e");
                continue;
            }
            if (n1 == null) {
                ElytronMessages.log.tokenRealmJwkMissingClaim("n");
                continue;
            }
            try {
                res.put(kid, createRsaPublicKey(n1, e1));
            } catch (NoSuchAlgorithmException | InvalidKeySpecException | IllegalArgumentException ex) {
                // IllegalArgumentException: malformed base64 in "n" or "e". Skip only this key.
                ElytronMessages.log.info("Fetched jwk could not be parsed, ignoring...", ex);
            }
        }
        return res;
    }

    /**
     * Performs an HTTPS GET of {@code url} and parses the body as a JSON object.
     * Only HTTPS is supported: for any other kind of connection this returns {@code null}
     * without sending a request.
     *
     * @throws IOException if connecting or reading fails
     */
    private static JsonObject fetchJwkSet(URL url, SSLContext sslContext, HostnameVerifier hostnameVerifier) throws IOException {
        URLConnection connection = url.openConnection();
        if (!(connection instanceof HttpsURLConnection)) {
            return null;
        }
        HttpsURLConnection conn = (HttpsURLConnection) connection;
        conn.setRequestMethod("GET");
        conn.setSSLSocketFactory(sslContext.getSocketFactory());
        conn.setHostnameVerifier(hostnameVerifier);
        conn.setConnectTimeout(CONNECTION_TIMEOUT);
        conn.setReadTimeout(CONNECTION_TIMEOUT);
        conn.connect();
        // Close the response stream so the underlying connection is released.
        try (InputStream inputStream = conn.getInputStream()) {
            return Json.createReader(inputStream).readObject();
        }
    }

    /**
     * Builds an RSA public key from JWK {@code n} (modulus) and {@code e} (exponent) values.
     *
     * @throws IllegalArgumentException if either value is not valid base64
     */
    private static RSAPublicKey createRsaPublicKey(String modulus, String exponent) throws NoSuchAlgorithmException, InvalidKeySpecException {
        BigInteger n = decodeUnsigned(modulus);
        BigInteger e = decodeUnsigned(exponent);
        return (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(n, e));
    }

    /**
     * Decodes a JWK integer value: a base64url string holding the unsigned big-endian bytes.
     */
    private static BigInteger decodeUnsigned(String value) {
        // JWK uses the URL-safe alphabet. Mapping '+' and '/' onto it keeps accepting values
        // encoded with the standard alphabet, which earlier versions decoded.
        byte[] magnitude = Base64.getUrlDecoder().decode(value.replace('+', '-').replace('/', '_'));
        // Signum 1: the bytes are unsigned, so a modulus with its top bit set must stay positive.
        return new BigInteger(1, magnitude);
    }
}

