package org.jgroups.protocols;

import java.math.BigInteger;
import java.net.InetAddress;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
import org.jgroups.Global;
import org.jgroups.JChannel;
import org.jgroups.View;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.protocols.pbcast.NAKACK2;
import org.jgroups.protocols.pbcast.STABLE;
import org.jgroups.stack.IpAddress;
import org.jgroups.util.DefaultSocketFactory;
import org.jgroups.util.MyReceiver;
import org.jgroups.util.Util;
import org.testng.AssertJUnit;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import org.wildfly.security.x500.cert.BasicConstraintsExtension;
import org.wildfly.security.x500.cert.SelfSignedX509CertificateAndSigningKey;
import org.wildfly.security.x500.cert.X509CertificateBuilder;

@Test(groups = {Global.FUNCTIONAL, Global.ENCRYPT}, singleThreaded = true)
/**
 * Verifies that a TCP/TLS cluster only admits members whose certificates are signed by the test CA.
 *
 * <p>Members "A", "B" and "C" get key stores with a certificate signed by a self-signed test CA and
 * trust that CA. Two joiners must be rejected:
 * <ul>
 *   <li>"U" uses a plain (non-TLS) socket factory.</li>
 *   <li>"O" uses a TLS context whose key store only holds a certificate from a different self-signed
 *       CA (same DN, different key) and no private key.</li>
 * </ul>
 */
public class TLSTest {
    public static final String PROTOCOL = "TLSv1.2";
    public static final String BASE_DN = "CN=%s,OU=JGroups,O=JBoss,L=Red Hat";
    public static final String KEY_PASSWORD = "secret";
    public static final String KEY_ALGORITHM = "RSA";
    public static final String KEY_SIGNATURE_ALGORITHM = "SHA256withRSA";
    public static final String KEYSTORE_TYPE = "pkcs12";

    /** Names of the members that get a CA-signed certificate and trust the test CA. */
    private static final String[] MEMBERS = {"A", "B", "C"};

    JChannel a;
    JChannel b;
    JChannel c;
    MyReceiver<String> ra;
    MyReceiver<String> rb;
    MyReceiver<String> rc;
    /** Serial number source for the CA-signed member certificates. */
    private final AtomicLong certSerial = new AtomicLong(1);
    /** Key store per channel name ("A", "B", "C", "O"). */
    Map<String, KeyStore> keyStores = new HashMap<>();
    /** SSL context per channel name; names without an entry get a plain socket factory. */
    Map<String, SSLContext> sslContexts = new HashMap<>();
    final String cluster_name = getClass().getSimpleName();

    /**
     * Builds the test CA, the member key stores, the untrusted "O" key store, and one
     * {@link SSLContext} per key store.
     */
    @BeforeClass
    public void init() throws Exception {
        // A single key pair is shared by all CA-signed member certificates.
        KeyPair keyPair = KeyPairGenerator.getInstance(KEY_ALGORITHM).generateKeyPair();
        PrivateKey privateKey = keyPair.getPrivate();
        PublicKey publicKey = keyPair.getPublic();

        // Trust store: the CA certificate plus every member certificate it signs.
        KeyStore trustStore = KeyStore.getInstance(KEYSTORE_TYPE);
        trustStore.load(null);
        X500Principal caDn = dn("CA");
        SelfSignedX509CertificateAndSigningKey ca = createSelfSignedCertificate(caDn, true, "ca");
        trustStore.setCertificateEntry("ca", ca.getSelfSignedCertificate());
        for (String member : MEMBERS) {
            keyStores.put(member, createSignedCertificate(privateKey, publicKey, ca, caDn, member, trustStore));
        }

        // "O": a second self-signed CA with the same DN but a different key; it is not added to the
        // trust store and the key store holds no private key entry.
        SelfSignedX509CertificateAndSigningKey otherCa = createSelfSignedCertificate(caDn, true, "other");
        keyStores.put("O", createKeyStore(ks -> {
            try {
                ks.setCertificateEntry("O", otherCa.getSelfSignedCertificate());
            } catch (KeyStoreException e) {
                throw new RuntimeException(e);
            }
        }));

        TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        trustManagerFactory.init(trustStore);
        for (Map.Entry<String, KeyStore> entry : keyStores.entrySet()) {
            String name = entry.getKey();
            KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
            keyManagerFactory.init(entry.getValue(), KEY_PASSWORD.toCharArray());
            // Only the cluster members trust the test CA; a null trust-manager array makes
            // SSLContext.init fall back to the JVM's default trust managers.
            boolean member = Arrays.asList(MEMBERS).contains(name);
            SSLContext sslContext = SSLContext.getInstance(PROTOCOL);
            sslContext.init(keyManagerFactory.getKeyManagers(), member ? trustManagerFactory.getTrustManagers() : null, null);
            sslContexts.put(name, sslContext);
        }
    }

    /**
     * Forms a three-member TLS cluster and checks that neither a plain-TCP joiner ("U") nor a joiner
     * with an untrusted certificate ("O") gets into the view. All channels are closed even when a
     * check fails, so no sockets leak into subsequent tests.
     */
    @Test
    public void testTLS() throws Exception {
        try {
            // Receivers are attached before connect() so no early callbacks are missed.
            ra = new MyReceiver<String>().rawMsgs(true);
            a = create("A").setReceiver(ra);
            a.connect(cluster_name);

            rb = new MyReceiver<String>().rawMsgs(true);
            b = create("B").setReceiver(rb);
            b.connect(cluster_name);

            rc = new MyReceiver<String>().rawMsgs(true);
            c = create("C").setReceiver(rc);
            c.connect(cluster_name);

            Util.waitUntilAllChannelsHaveSameView(10000L, 500L, a, b, c);
            verifyForbiddenJoiner("U");
            verifyForbiddenJoiner("O");
        } finally {
            Util.close(c, b, a);
        }
    }

    /**
     * Tries to join the cluster with channel {@code str} and asserts that the views of A, B and C
     * still contain exactly the three original members. The joiner channel is always closed.
     *
     * @param str name of the joiner; selects its SSL context (or none) in {@link #create(String)}
     */
    private void verifyForbiddenJoiner(String str) throws Exception {
        JChannel joiner = create(str);
        try {
            ((GMS) joiner.getProtocolStack().findProtocol(GMS.class)).setMaxJoinAttempts(1);
            try {
                joiner.connect(cluster_name);
            } catch (Exception ignored) {
                // Failure to connect is an acceptable outcome for a forbidden joiner;
                // only the members' views below decide the test.
            }
            // Give a (wrongly) admitted joiner up to ~5s to show up in A's view.
            for (int i = 0; i < 10 && a.getView().size() <= 3; i++) {
                Util.sleep(500L);
            }
            for (JChannel ch : Arrays.asList(a, b, c)) {
                View view = ch.getView();
                AssertJUnit.assertEquals("unexpected view at " + ch.getAddress() + ": " + view, 3, view.size());
                AssertJUnit.assertTrue(view.containsMembers(a.getAddress(), b.getAddress(), c.getAddress()));
            }
        } finally {
            Util.close(joiner);
        }
    }

    /**
     * Creates an unconnected channel named {@code str} with a TCP/TCPPING stack bound to the loopback
     * address, starting at port 9600 and discovering members via that same address and port.
     * Uses the SSL context registered for {@code str}, or a plain socket factory if there is none.
     */
    private JChannel create(String str) throws Exception {
        TCP tcp = new TCP();
        tcp.setBindAddress(InetAddress.getLoopbackAddress());
        tcp.setBindPort(9600);
        SSLContext sslContext = sslContexts.get(str);
        tcp.setSocketFactory(sslContext != null ? new DefaultSocketFactory(sslContext) : new DefaultSocketFactory());
        TCPPING tcpping = new TCPPING();
        tcpping.setInitialHosts2(Collections.singletonList(new IpAddress(tcp.getBindAddress(), tcp.getBindPort())));
        return new JChannel(tcp, tcpping, new NAKACK2(), new UNICAST3(), new STABLE(), new GMS()).name(str);
    }

    /** Returns the distinguished name {@link #BASE_DN} with the common name set to {@code str}. */
    static X500Principal dn(String str) {
        return new X500Principal(String.format(BASE_DN, str));
    }

    /**
     * Creates a self-signed certificate and signing key for {@code dn}.
     *
     * @param dn    subject (and issuer) distinguished name
     * @param ca    if true, adds a BasicConstraints extension marking the certificate as a CA
     * @param alias not used by this method
     */
    protected SelfSignedX509CertificateAndSigningKey createSelfSignedCertificate(X500Principal dn, boolean ca, String alias) {
        SelfSignedX509CertificateAndSigningKey.Builder builder = SelfSignedX509CertificateAndSigningKey.builder()
                .setDn(dn)
                .setSignatureAlgorithmName(KEY_SIGNATURE_ALGORITHM)
                .setKeyAlgorithmName(KEY_ALGORITHM);
        if (ca) {
            builder.addExtension(false, "BasicConstraints", "CA:true,pathlen:2147483647");
        }
        return builder.build();
    }

    /**
     * Issues a non-CA certificate for {@code str}, signed by {@code ca}. It also adds that certificate
     * to {@code trustStore} under alias {@code str}.
     *
     * @return a new key store holding the CA certificate ("ca") and a key entry {@code str} with
     *         {@code privateKey} and the chain [member certificate, CA certificate]
     * @throws RuntimeException wrapping any failure while building the certificate or key store
     */
    protected KeyStore createSignedCertificate(PrivateKey privateKey, PublicKey publicKey, SelfSignedX509CertificateAndSigningKey ca, X500Principal issuerDn, String str, KeyStore trustStore) {
        try {
            X509Certificate caCert = ca.getSelfSignedCertificate();
            X509Certificate cert = new X509CertificateBuilder()
                    .setIssuerDn(issuerDn)
                    .setSubjectDn(dn(str))
                    .setSignatureAlgorithmName(KEY_SIGNATURE_ALGORITHM)
                    .setSigningKey(ca.getSigningKey())
                    .setPublicKey(publicKey)
                    .setSerialNumber(BigInteger.valueOf(certSerial.getAndIncrement()))
                    .addExtension(new BasicConstraintsExtension(false, false, -1))
                    .build();
            trustStore.setCertificateEntry(str, cert);
            return createKeyStore(ks -> {
                try {
                    ks.setCertificateEntry("ca", caCert);
                    ks.setKeyEntry(str, privateKey, KEY_PASSWORD.toCharArray(), new X509Certificate[]{cert, caCert});
                } catch (KeyStoreException e) {
                    throw new RuntimeException(e);
                }
            });
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an empty {@link #KEYSTORE_TYPE} key store and lets {@code consumer} populate it.
     *
     * @throws RuntimeException wrapping any failure from creating or populating the key store
     */
    private static KeyStore createKeyStore(Consumer<KeyStore> consumer) {
        try {
            KeyStore keyStore = KeyStore.getInstance(KEYSTORE_TYPE);
            keyStore.load(null);
            consumer.accept(keyStore);
            return keyStore;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
