package net.shibboleth.oidc.security.impl;

import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.AESEncrypter;
import com.nimbusds.jose.crypto.DirectEncrypter;
import com.nimbusds.jose.crypto.ECDHEncrypter;
import com.nimbusds.jose.crypto.RSAEncrypter;
import com.nimbusds.jwt.EncryptedJWT;
import com.nimbusds.jwt.JWT;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.util.function.BiConsumer;
import java.util.function.Function;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import net.shibboleth.oidc.security.CredentialConversionUtil;
import net.shibboleth.oidc.security.jose.EncryptionParameters;
import net.shibboleth.oidc.security.jose.context.SecurityParametersContext;
import net.shibboleth.utilities.java.support.annotation.constraint.NonnullAfterInit;
import net.shibboleth.utilities.java.support.annotation.constraint.NotEmpty;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.component.ComponentSupport;
import net.shibboleth.utilities.java.support.logic.Constraint;
import net.shibboleth.utilities.java.support.primitive.StringSupport;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.messaging.handler.AbstractMessageHandler;
import org.opensaml.messaging.handler.MessageHandlerException;
import org.opensaml.security.credential.Credential;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:net/shibboleth/oidc/security/impl/EncryptJWTHandler.class */
/**
 * Message handler that encrypts a {@link Payload} into a JWE and passes the resulting {@link EncryptedJWT} to a
 * configurable consumer.
 *
 * <p>The payload is obtained via {@link #setPayloadToEncryptLookupStrategy(Function)}. The algorithms and credentials
 * come from the {@link EncryptionParameters} of a {@link SecurityParametersContext}, located via
 * {@link #setSecurityParametersLookupStrategy(Function)} (by default a {@link ChildContextLookup}). Exactly one of the
 * key transport credential (RSA, ECDH-ES, AES key wrap families) or the data encryption credential ({@code dir}) must
 * be present. The JWE header always carries content type {@code JWT}.</p>
 *
 * <p>NOTE(review): the resolved encryption parameters are stored in an instance field between
 * {@link #doPreInvoke(MessageContext)} and {@link #doInvoke(MessageContext)}, so one instance presumably must not be
 * invoked concurrently; confirm how instances are scoped.</p>
 */
public class EncryptJWTHandler extends AbstractMessageHandler {

    /** Strategy used to locate the payload to encrypt. */
    @NonnullAfterInit
    private Function<MessageContext, Payload> payloadToEncryptLookupStrategy;

    /** Receives the encrypted JWT together with the message context it was produced for. */
    @NonnullAfterInit
    private BiConsumer<JWT, MessageContext> jwtUpdateConsumer;

    /** Encryption parameters resolved by {@link #doPreInvoke(MessageContext)} for the current invocation. */
    @Nullable
    private EncryptionParameters encryptionParameters;

    /** Class logger. */
    @Nonnull
    private final Logger log = LoggerFactory.getLogger(EncryptJWTHandler.class);

    /** Name of the encrypted object, used only to identify it in log messages. */
    @Nonnull
    private String logName = "not-specified";

    /** Strategy used to locate the {@link SecurityParametersContext}. */
    @Nonnull
    private Function<MessageContext, SecurityParametersContext> securityParametersLookupStrategy =
            new ChildContextLookup<>(SecurityParametersContext.class);

    /**
     * Set the name of the encrypted object, used to identify it in log messages.
     *
     * @param name non-empty name
     */
    public void setLogName(@NotEmpty @Nonnull final String name) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        ComponentSupport.ifDestroyedThrowDestroyedComponentException(this);
        logName = Constraint.isNotEmpty(name, "Log name can not be null or empty");
    }

    /**
     * Set the consumer that receives the encrypted JWT and the message context.
     *
     * @param consumer JWT update consumer
     */
    public void setJwtUpdateConsumer(@Nonnull final BiConsumer<JWT, MessageContext> consumer) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        ComponentSupport.ifDestroyedThrowDestroyedComponentException(this);
        jwtUpdateConsumer = Constraint.isNotNull(consumer, "JWT Update Consumer can not be null");
    }

    /**
     * Set the strategy used to locate the payload to encrypt.
     *
     * @param strategy payload lookup strategy
     */
    public void setPayloadToEncryptLookupStrategy(@Nonnull final Function<MessageContext, Payload> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        ComponentSupport.ifDestroyedThrowDestroyedComponentException(this);
        payloadToEncryptLookupStrategy =
                Constraint.isNotNull(strategy, "Payload To Encrypt Lookup Strategy can not be null");
    }

    /**
     * Set the strategy used to locate the {@link SecurityParametersContext}.
     *
     * @param strategy security parameters context lookup strategy
     */
    public void setSecurityParametersLookupStrategy(
            @Nonnull final Function<MessageContext, SecurityParametersContext> strategy) {
        ComponentSupport.ifInitializedThrowUnmodifiabledComponentException(this);
        // Consistent with the other setters: a destroyed component must not be reconfigured.
        ComponentSupport.ifDestroyedThrowDestroyedComponentException(this);
        securityParametersLookupStrategy =
                Constraint.isNotNull(strategy, "SecurityParameterContext lookup strategy cannot be null");
    }

    /**
     * Verify that the mandatory payload lookup strategy and JWT update consumer have been set.
     *
     * @throws ComponentInitializationException if either is missing
     */
    @Override
    protected void doInitialize() throws ComponentInitializationException {
        if (payloadToEncryptLookupStrategy == null) {
            throw new ComponentInitializationException("Payload To Encrypt Lookup Strategy can not be null");
        }
        if (jwtUpdateConsumer == null) {
            throw new ComponentInitializationException("JWT Update Consumer can not be null");
        }
        super.doInitialize();
    }

    /**
     * Resolve and validate the encryption parameters for this invocation.
     *
     * @param messageContext the current message context
     * @return {@code false} to skip encryption when there is no security parameters context or no encryption
     *         parameters, {@code true} when the parameters are complete
     * @throws MessageHandlerException if an algorithm is missing, if neither credential is set, or if both are set
     */
    @Override
    protected boolean doPreInvoke(@Nonnull final MessageContext messageContext) throws MessageHandlerException {
        if (!super.doPreInvoke(messageContext)) {
            return false;
        }

        final SecurityParametersContext securityContext = securityParametersLookupStrategy.apply(messageContext);
        if (securityContext == null) {
            log.trace("{} Message context did not contain an encryption parameters context, encryption skipped",
                    getLogPrefix());
            return false;
        }

        encryptionParameters = securityContext.getEncryptionParameters();
        if (encryptionParameters == null) {
            log.debug("{} Message context did not contain encryption parameters, '{}' will not be encrypted",
                    getLogPrefix(), logName);
            return false;
        }

        final Credential keyTransportCredential = encryptionParameters.getKeyTransportEncryptionCredential();
        final Credential dataEncryptionCredential = encryptionParameters.getDataEncryptionCredential();
        if (StringSupport.trimOrNull(encryptionParameters.getKeyTransportEncryptionAlgorithm()) == null
                || StringSupport.trimOrNull(encryptionParameters.getDataEncryptionAlgorithm()) == null
                || (keyTransportCredential == null && dataEncryptionCredential == null)) {
            throw new MessageHandlerException("Message context did not contain all required encryption parameters");
        }
        if (keyTransportCredential != null && dataEncryptionCredential != null) {
            throw new MessageHandlerException(
                    "Message context contained both a content encryption and key transport credential. Only one required.");
        }
        return true;
    }

    /**
     * Encrypt the looked-up payload and hand the encrypted JWT to the update consumer.
     *
     * <p>Does nothing if the payload lookup returns {@code null}.</p>
     *
     * @param messageContext the current message context
     * @throws MessageHandlerException if the algorithm/key combination is unsupported or encryption fails
     */
    @Override
    protected void doInvoke(@Nonnull final MessageContext messageContext) throws MessageHandlerException {
        final Payload payload = payloadToEncryptLookupStrategy.apply(messageContext);
        if (payload == null) {
            log.trace("{} No plain text source payload provided to encrypt, encryption skipped", getLogPrefix());
            return;
        }

        final JWEAlgorithm alg = JWEAlgorithm.parse(encryptionParameters.getKeyTransportEncryptionAlgorithm());
        final EncryptionMethod enc = EncryptionMethod.parse(encryptionParameters.getDataEncryptionAlgorithm());
        final Credential keyTransportCredential = encryptionParameters.getKeyTransportEncryptionCredential();
        final Credential dataEncryptionCredential = encryptionParameters.getDataEncryptionCredential();
        // 'dir' encrypts directly with the data encryption credential; every other algorithm uses key transport.
        final Credential kidSource = JWEAlgorithm.DIR.equals(alg) ? dataEncryptionCredential : keyTransportCredential;
        final String kid = kidSource == null ? null : CredentialConversionUtil.resolveKid(kidSource);

        try {
            final JWEEncrypter encrypter = createEncrypter(alg, keyTransportCredential, dataEncryptionCredential);
            if (encrypter == null) {
                log.error("{} Unsupported algorithm '{}' or key '{}'", getLogPrefix(), alg.getName(), kid);
                throw new MessageHandlerException("Unsupported algorithm " + alg.getName());
            }

            final JWEObject jweObject =
                    new JWEObject(new JWEHeader.Builder(alg, enc).contentType("JWT").keyID(kid).build(), payload);
            logEncryption(kid, alg.getName(), enc.getName());
            jweObject.encrypt(encrypter);

            final JWT encryptedJwt = EncryptedJWT.parse(jweObject.serialize());
            jwtUpdateConsumer.accept(encryptedJwt, messageContext);
            if (log.isTraceEnabled()) {
                log.trace("{} Encrypted '{}' JWT: {}", getLogPrefix(), logName, encryptedJwt.serialize());
            } else {
                log.debug("{} Encrypted '{}' JWT", getLogPrefix(), logName);
            }
        } catch (final MessageHandlerException e) {
            // Already logged with a specific message; do not re-log and wrap it as a generic failure.
            throw e;
        } catch (final Exception e) {
            log.error("{} Encryption failed", getLogPrefix(), e);
            throw new MessageHandlerException("Encryption failed", e);
        }
    }

    /**
     * Build the encrypter that matches the key management algorithm and the supplied credentials.
     *
     * @param alg key management algorithm
     * @param keyTransportCredential credential for the RSA, ECDH-ES and AES key wrap families, may be null
     * @param dataEncryptionCredential credential for the {@code dir} algorithm, may be null
     * @return the encrypter, or {@code null} if the algorithm is unsupported or no key of the required type exists
     * @throws JOSEException if the encrypter cannot be created from the key
     */
    @Nullable
    private static JWEEncrypter createEncrypter(@Nonnull final JWEAlgorithm alg,
            @Nullable final Credential keyTransportCredential, @Nullable final Credential dataEncryptionCredential)
            throws JOSEException {
        if (keyTransportCredential != null) {
            // instanceof guards a null key and also a key of the wrong type (which previously caused a ClassCastException).
            if (JWEAlgorithm.Family.RSA.contains(alg)
                    && keyTransportCredential.getPublicKey() instanceof RSAPublicKey) {
                return new RSAEncrypter((RSAPublicKey) keyTransportCredential.getPublicKey());
            }
            if (JWEAlgorithm.Family.ECDH_ES.contains(alg)
                    && keyTransportCredential.getPublicKey() instanceof ECPublicKey) {
                return new ECDHEncrypter((ECPublicKey) keyTransportCredential.getPublicKey());
            }
            if ((JWEAlgorithm.Family.AES_KW.contains(alg) || JWEAlgorithm.Family.AES_GCM_KW.contains(alg))
                    && keyTransportCredential.getSecretKey() != null) {
                return new AESEncrypter(keyTransportCredential.getSecretKey());
            }
        }
        if (JWEAlgorithm.DIR.equals(alg) && dataEncryptionCredential != null
                && dataEncryptionCredential.getSecretKey() != null) {
            return new DirectEncrypter(dataEncryptionCredential.getSecretKey());
        }
        return null;
    }

    /**
     * Log, at debug level, the key ID and algorithms about to be used for encryption.
     *
     * @param kid key ID placed in the JWE header, may be null
     * @param alg key management algorithm name
     * @param enc content encryption method name
     */
    private void logEncryption(@Nullable final String kid, @Nullable final String alg, @Nullable final String enc) {
        log.debug("{} Encrypting '{}' with kid '{}' and params alg: {} enc: {}", getLogPrefix(), logName, kid, alg, enc);
    }
}
