/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.services.clientpolicy.executor;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.ws.rs.core.MultivaluedMap;
import org.jboss.logging.Logger;
import org.keycloak.common.util.Base64Url;
import org.keycloak.component.ComponentModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper;
import org.keycloak.protocol.oidc.endpoints.request.AuthorizationEndpointRequest;
import org.keycloak.protocol.oidc.utils.OAuth2Code;
import org.keycloak.protocol.oidc.utils.OAuth2CodeParser;
import org.keycloak.protocol.oidc.utils.OIDCResponseType;
import org.keycloak.representations.idm.ClientRepresentation;
import org.keycloak.services.clientpolicy.AuthorizationRequestContext;
import org.keycloak.services.clientpolicy.ClientPolicyContext;
import org.keycloak.services.clientpolicy.ClientPolicyException;
import org.keycloak.services.clientpolicy.TokenRequestContext;
import org.keycloak.services.clientpolicy.executor.AbstractAugumentingClientRegistrationPolicyExecutor;

/**
 * Client policy executor that enforces PKCE with the {@code S256} code challenge method.
 *
 * <p>Through the {@link #augment} / {@link #validate} hooks of its superclass it can set, and
 * always requires, a client's configured PKCE code challenge method to be {@code S256}.
 * On an {@code AUTHORIZATION_REQUEST} event it requires {@code code_challenge_method=S256}
 * and a well-formed {@code code_challenge}. On a {@code TOKEN_REQUEST} event it requires a
 * well-formed {@code code_verifier} that matches the challenge stored with the authorization
 * code. All failures are reported as {@link ClientPolicyException}.
 */
public class PKCEEnforceExecutor
extends AbstractAugumentingClientRegistrationPolicyExecutor {
    private static final Logger logger = Logger.getLogger(PKCEEnforceExecutor.class);
    private static final Pattern VALID_CODE_CHALLENGE_PATTERN = Pattern.compile("^[0-9a-zA-Z\\-\\.~_]+$");
    private static final Pattern VALID_CODE_VERIFIER_PATTERN = Pattern.compile("^[0-9a-zA-Z\\-\\.~_]+$");

    /** The only PKCE code challenge method this executor accepts. */
    private static final String S256 = "S256";
    /** Inclusive length bounds applied to both code_challenge and code_verifier. */
    private static final int MIN_PKCE_VALUE_LENGTH = 43;
    private static final int MAX_PKCE_VALUE_LENGTH = 128;

    public PKCEEnforceExecutor(KeycloakSession session, ComponentModel componentModel) {
        super(session, componentModel);
    }

    /**
     * Sets the client's PKCE code challenge method to {@code S256} when the component config
     * entry {@code is-augment} is {@code "true"} (case-insensitive); otherwise leaves
     * {@code rep} untouched.
     */
    @Override
    protected void augment(ClientRepresentation rep) {
        if (Boolean.parseBoolean(this.componentModel.getConfig().getFirst("is-augment"))) {
            OIDCAdvancedConfigWrapper.fromClientRepresentation(rep).setPkceCodeChallengeMethod(S256);
        }
    }

    /**
     * Rejects a client representation whose PKCE code challenge method is not {@code S256}.
     *
     * @throws ClientPolicyException {@code invalid_client_metadata} if the method is missing or different
     */
    @Override
    protected void validate(ClientRepresentation rep) throws ClientPolicyException {
        String pkceMethod = OIDCAdvancedConfigWrapper.fromClientRepresentation(rep).getPkceCodeChallengeMethod();
        if (!S256.equals(pkceMethod)) {
            throw new ClientPolicyException("invalid_client_metadata", "Invalid client metadata: code_challenge_method");
        }
    }

    /**
     * Delegates to the superclass, then applies the PKCE checks for authorization and token
     * requests. Other events are ignored here.
     */
    @Override
    public void executeOnEvent(ClientPolicyContext context) throws ClientPolicyException {
        super.executeOnEvent(context);
        switch (context.getEvent()) {
            case AUTHORIZATION_REQUEST: {
                AuthorizationRequestContext authorizationRequestContext = (AuthorizationRequestContext) context;
                this.executeOnAuthorizationRequest(authorizationRequestContext.getparsedResponseType(),
                        authorizationRequestContext.getAuthorizationEndpointRequest(),
                        authorizationRequestContext.getRedirectUri());
                return;
            }
            case TOKEN_REQUEST: {
                TokenRequestContext tokenRequestContext = (TokenRequestContext) context;
                this.executeOnTokenRequest(tokenRequestContext.getParams(), tokenRequestContext.getParseResult());
                return;
            }
            default:
                return;
        }
    }

    /**
     * Validates the PKCE parameters of an authorization request.
     *
     * <p>The checks run in this order, and the first failure wins: the method is present;
     * the method is {@code S256}; the method matches the client's configured method (if any);
     * the challenge is present; the challenge is well-formed.
     *
     * @throws ClientPolicyException {@code invalid_request} on the first failed check
     */
    private void executeOnAuthorizationRequest(OIDCResponseType parsedResponseType, AuthorizationEndpointRequest request, String redirectUri) throws ClientPolicyException {
        ClientModel client = this.session.getContext().getClient();
        String codeChallenge = request.getCodeChallenge();
        String codeChallengeMethod = request.getCodeChallengeMethod();
        String pkceCodeChallengeMethod = OIDCAdvancedConfigWrapper.fromClientModel(client).getPkceCodeChallengeMethod();
        if (codeChallengeMethod == null) {
            throw new ClientPolicyException("invalid_request", "Missing parameter: code_challenge_method");
        }
        if (!this.isAcceptableCodeChallengeMethod(codeChallengeMethod)) {
            throw new ClientPolicyException("invalid_request", "Invalid parameter: invalid code_challenge_method");
        }
        if (pkceCodeChallengeMethod != null && !codeChallengeMethod.equals(pkceCodeChallengeMethod)) {
            throw new ClientPolicyException("invalid_request", "Invalid parameter: code challenge method is not configured one");
        }
        if (codeChallenge == null) {
            throw new ClientPolicyException("invalid_request", "Missing parameter: code_challenge");
        }
        if (!this.isValidPkceCodeChallenge(codeChallenge)) {
            throw new ClientPolicyException("invalid_request", "Invalid parameter: code_challenge");
        }
    }

    /** Returns {@code true} only for the {@code S256} method. */
    private boolean isAcceptableCodeChallengeMethod(String method) {
        return S256.equals(method);
    }

    /** Returns whether {@code codeChallenge} is 43..128 characters drawn from [0-9a-zA-Z-.~_]. */
    private boolean isValidPkceCodeChallenge(String codeChallenge) {
        return isWellFormedPkceValue(codeChallenge, VALID_CODE_CHALLENGE_PATTERN);
    }

    /**
     * Verifies the {@code code_verifier} form parameter of a token request against the
     * challenge stored with the parsed authorization code.
     */
    private void executeOnTokenRequest(MultivaluedMap<String, String> params, OAuth2CodeParser.ParseResult parseResult) throws ClientPolicyException {
        String codeVerifier = params.getFirst("code_verifier");
        OAuth2Code codeData = parseResult.getCodeData();
        String codeChallenge = codeData.getCodeChallenge();
        String codeChallengeMethod = codeData.getCodeChallengeMethod();
        this.checkParamsForPkceEnforcedClient(codeVerifier, codeChallenge, codeChallengeMethod);
    }

    /**
     * Requires a code verifier to be present, then verifies it.
     *
     * @throws ClientPolicyException {@code code_verifier_missing} if {@code codeVerifier} is null
     */
    private void checkParamsForPkceEnforcedClient(String codeVerifier, String codeChallenge, String codeChallengeMethod) throws ClientPolicyException {
        if (codeVerifier == null) {
            throw new ClientPolicyException("code_verifier_missing", "PKCE code verifier not specified");
        }
        this.verifyCodeVerifier(codeVerifier, codeChallenge, codeChallengeMethod);
    }

    /**
     * Checks that {@code codeVerifier} is well-formed and, once transformed by
     * {@code codeChallengeMethod}, equals {@code codeChallenge}.
     *
     * <p>With {@code S256} the verifier is hashed before comparison. Any other method value,
     * including {@code null}, compares the verifier as-is.
     *
     * @throws ClientPolicyException {@code invalid_code_verifier} for a malformed verifier, or
     *         {@code pkce_verification_failed} when there is no stored challenge, SHA-256 is
     *         unavailable, or the values do not match
     */
    private void verifyCodeVerifier(String codeVerifier, String codeChallenge, String codeChallengeMethod) throws ClientPolicyException {
        if (!this.isValidFormattedCodeVerifier(codeVerifier)) {
            throw new ClientPolicyException("invalid_code_verifier", "PKCE invalid code verifier");
        }
        if (codeChallenge == null) {
            // The code carries no stored challenge: fail verification instead of throwing an NPE.
            throw new ClientPolicyException("pkce_verification_failed", "PKCE verification failed");
        }
        String codeVerifierEncoded;
        if (S256.equals(codeChallengeMethod)) {
            try {
                codeVerifierEncoded = this.generateS256CodeChallenge(codeVerifier);
            }
            catch (NoSuchAlgorithmException e) {
                logger.warn("SHA-256 is not available for PKCE verification", e);
                throw new ClientPolicyException("pkce_verification_failed", "PKCE code verification failed, not supported algorithm specified");
            }
        } else {
            codeVerifierEncoded = codeVerifier;
        }
        // Constant-time comparison so response timing does not reveal how many leading bytes matched.
        if (!MessageDigest.isEqual(codeChallenge.getBytes(StandardCharsets.UTF_8),
                codeVerifierEncoded.getBytes(StandardCharsets.UTF_8))) {
            throw new ClientPolicyException("pkce_verification_failed", "PKCE verification failed");
        }
    }

    /** Returns whether {@code codeVerifier} is 43..128 characters drawn from [0-9a-zA-Z-.~_]. */
    private boolean isValidFormattedCodeVerifier(String codeVerifier) {
        return isWellFormedPkceValue(codeVerifier, VALID_CODE_VERIFIER_PATTERN);
    }

    /** Shared length-bounds plus character-set check for PKCE challenges and verifiers. */
    private static boolean isWellFormedPkceValue(String value, Pattern allowedChars) {
        int length = value.length();
        return length >= MIN_PKCE_VALUE_LENGTH
                && length <= MAX_PKCE_VALUE_LENGTH
                && allowedChars.matcher(value).matches();
    }

    /**
     * Computes the S256 challenge: the base64url encoding (via {@link Base64Url#encode}) of the
     * SHA-256 digest of the verifier's ASCII bytes.
     */
    private String generateS256CodeChallenge(String codeVerifier) throws NoSuchAlgorithmException {
        MessageDigest md = MessageDigest.getInstance("SHA-256");
        // The verifier has already been validated as unreserved ASCII, so US_ASCII is lossless.
        byte[] digestBytes = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
        return Base64Url.encode(digestBytes);
    }
}

