/*
 * Decompiled with CFR 0.152.
 */
package org.wildfly.security.mechanism.scram;

import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import org.wildfly.common.Assert;
import org.wildfly.security._private.ElytronMessages;
import org.wildfly.security.auth.callback.ChannelBindingCallback;
import org.wildfly.security.mechanism.AuthenticationMechanismException;
import org.wildfly.security.mechanism.MechanismUtil;
import org.wildfly.security.mechanism.scram.ScramFinalClientMessage;
import org.wildfly.security.mechanism.scram.ScramFinalServerMessage;
import org.wildfly.security.mechanism.scram.ScramInitialClientMessage;
import org.wildfly.security.mechanism.scram.ScramInitialServerMessage;
import org.wildfly.security.mechanism.scram.ScramInitialServerResult;
import org.wildfly.security.mechanism.scram.ScramMechanism;
import org.wildfly.security.mechanism.scram.ScramUtil;
import org.wildfly.security.password.interfaces.ScramDigestPassword;
import org.wildfly.security.password.spec.IteratedPasswordAlgorithmSpec;
import org.wildfly.security.sasl.util.StringPrep;
import org.wildfly.security.util.ByteIterator;
import org.wildfly.security.util.ByteStringBuilder;

/**
 * Server side of a SCRAM (RFC 5802) exchange for one {@link ScramMechanism}.
 * <p>
 * The exchange is driven by calling, in order, {@link #parseInitialClientMessage},
 * {@link #evaluateInitialResponse}, {@link #parseFinalClientMessage} and
 * {@link #evaluateFinalClientMessage}. Every field is final and no method mutates them; the
 * per-exchange state lives in the message/result objects each step returns and the caller hands
 * to the next step.
 */
public final class ScramServer {
    /** Iteration count requested when a password has to be generated, clamped to the configured bounds. */
    private static final int DEFAULT_ITERATION_COUNT = 20000;

    /**
     * Flags passed to {@link StringPrep#encode} for authentication/authorization names. The value is
     * carried over unchanged from the original code; presumably a SASLprep-style profile — confirm
     * against the {@code StringPrep} flag constants.
     */
    private static final long NAME_PREP_FLAGS = 0x10001FFFL;

    /** Length argument passed to {@link ScramUtil#generateNonce} for the server nonce. */
    private static final int SERVER_NONCE_LENGTH = 28;

    private final Supplier<Provider[]> providers;
    private final ScramMechanism mechanism;
    private final CallbackHandler callbackHandler;
    private final SecureRandom random;
    private final byte[] bindingData;
    private final String bindingType;
    private final int minimumIterationCount;
    private final int maximumIterationCount;

    ScramServer(ScramMechanism mechanism, CallbackHandler callbackHandler, SecureRandom random, byte[] bindingData, String bindingType, int minimumIterationCount, int maximumIterationCount, Supplier<Provider[]> providers) {
        this.mechanism = mechanism;
        this.callbackHandler = callbackHandler;
        this.random = random;
        this.bindingData = bindingData;
        this.bindingType = bindingType;
        this.minimumIterationCount = minimumIterationCount;
        this.maximumIterationCount = maximumIterationCount;
        this.providers = providers;
    }

    /**
     * Parses the client-first-message:
     * {@code <cbind-flag>[=<cb-name>],[a=<authzid>],n=<user>,r=<client-nonce>}.
     *
     * @param bindingCallback channel-binding information available to the server, or {@code null} if none
     * @param bytes the raw message; it is copied, so the caller's array is not retained
     * @return the parsed message, including the offset where the "bare" part ({@code n=...}) starts
     * @throws AuthenticationMechanismException if the message is malformed, or its channel-binding flag
     *         is incompatible with this mechanism or with the available binding data
     */
    public ScramInitialClientMessage parseInitialClientMessage(ChannelBindingCallback bindingCallback, byte[] bytes) throws AuthenticationMechanismException {
        final byte[] response = bytes.clone();
        final ByteIterator bi = ByteIterator.ofBytes(response);
        try {
            final char cbindFlag = (char) bi.next();
            final String bindingType;
            final byte[] bindingData;
            if (bindingCallback != null) {
                bindingType = bindingCallback.getBindingType();
                bindingData = bindingCallback.getBindingData();
            } else {
                bindingType = null;
                bindingData = null;
            }

            final boolean binding;
            if (cbindFlag == 'p') {
                // "p=<cb-name>": the client requires channel binding, which only -PLUS mechanisms offer.
                if (!this.mechanism.isPlus()) {
                    throw ElytronMessages.log.mechChannelBindingNotSupported(this.mechanism.toString());
                }
                if (bindingType == null || bindingData == null) {
                    throw ElytronMessages.log.mechChannelBindingNotProvided(this.mechanism.toString());
                }
                if (bi.next() != '=') {
                    throw ElytronMessages.log.mechInvalidMessageReceived(this.mechanism.toString());
                }
                if (!bindingType.equals(bi.delimitedBy(',').asUtf8String().drainToString())) {
                    throw ElytronMessages.log.mechChannelBindingTypeMismatch(this.mechanism.toString());
                }
                binding = true;
            } else if (cbindFlag == 'y') {
                // "y": the client supports channel binding but believes the server does not.
                if (this.mechanism.isPlus()) {
                    throw ElytronMessages.log.mechChannelBindingNotProvided(this.mechanism.toString());
                }
                // RFC 5802 section 6: if the server does support channel binding it MUST fail here
                // (possible downgrade). Rejecting now also stops a later final message from switching
                // to a "p=" header on this non-PLUS mechanism.
                if (bindingType != null || bindingData != null) {
                    throw ElytronMessages.log.mechChannelBindingNotSupported(this.mechanism.toString());
                }
                binding = true;
            } else if (cbindFlag == 'n') {
                // "n": the client does not support channel binding.
                if (this.mechanism.isPlus()) {
                    throw ElytronMessages.log.mechChannelBindingNotProvided(this.mechanism.toString());
                }
                if (bindingType != null || bindingData != null) {
                    throw ElytronMessages.log.mechChannelBindingNotSupported(this.mechanism.toString());
                }
                binding = false;
            } else {
                throw ElytronMessages.log.mechInvalidMessageReceived(this.mechanism.toString());
            }
            if (bi.next() != ',') {
                throw ElytronMessages.log.mechInvalidMessageReceived(this.mechanism.toString());
            }

            // Optional authorization id: "a=<authzid>," or an empty field ",".
            final String authorizationID;
            final int c = bi.next();
            if (c == 'a') {
                if (bi.next() != '=') {
                    throw ElytronMessages.log.mechInvalidClientMessage(this.mechanism.toString());
                }
                authorizationID = bi.delimitedBy(',').asUtf8String().drainToString();
                bi.next(); // skip the ',' ending the gs2 header
            } else if (c == ',') {
                authorizationID = null;
            } else {
                throw ElytronMessages.log.mechInvalidClientMessage(this.mechanism.toString());
            }

            // Start of client-first-message-bare; it is fed into the signatures later.
            final int initialPartIndex = bi.offset();
            if (bi.next() != 'n' || bi.next() != '=') {
                throw ElytronMessages.log.mechInvalidClientMessage(this.mechanism.toString());
            }
            final ByteStringBuilder bsb = new ByteStringBuilder();
            StringPrep.encode(bi.delimitedBy(',').asUtf8String().drainToString(), bsb, NAME_PREP_FLAGS);
            final String authenticationName = new String(bsb.toArray(), StandardCharsets.UTF_8);
            bi.next(); // skip ','
            if (bi.next() != 'r' || bi.next() != '=') {
                throw ElytronMessages.log.mechInvalidClientMessage(this.mechanism.toString());
            }
            final byte[] nonce = bi.delimitedBy(',').drain();
            if (bi.hasNext()) {
                // Trailing data (e.g. extensions) is not accepted.
                throw ElytronMessages.log.mechInvalidClientMessage(this.mechanism.toString());
            }
            return new ScramInitialClientMessage(this.mechanism, authorizationID, authenticationName, binding, bindingType, bindingData, nonce, initialPartIndex, response);
        } catch (NoSuchElementException ignored) {
            // The message ended early; report it as malformed.
            throw ElytronMessages.log.mechInvalidMessageReceived(this.mechanism.toString());
        }
    }

    /**
     * Looks up the user's {@link ScramDigestPassword} through the callback handler and builds the
     * server-first-message {@code r=<client-nonce><server-nonce>,s=<base64 salt>,i=<iterations>}.
     *
     * @param clientMessage the parsed client-first-message
     * @return the server-first-message together with the password it was built from
     * @throws AuthenticationMechanismException if the message belongs to another mechanism, the callback
     *         handler cannot resolve the name or password, or the stored iteration count is out of bounds
     */
    public ScramInitialServerResult evaluateInitialResponse(ScramInitialClientMessage clientMessage) throws AuthenticationMechanismException {
        final boolean trace = ElytronMessages.log.isTraceEnabled();
        if (clientMessage.getMechanism() != this.mechanism) {
            throw ElytronMessages.log.mechUnmatchedMechanism(this.mechanism.toString(), clientMessage.getMechanism().toString());
        }
        final NameCallback nameCallback = new NameCallback("Remote authentication name", clientMessage.getAuthenticationName());
        try {
            MechanismUtil.handleCallbacks(this.mechanism.toString(), this.callbackHandler, nameCallback);
        } catch (UnsupportedCallbackException e) {
            throw ElytronMessages.log.mechCallbackHandlerDoesNotSupportUserName(this.mechanism.toString(), e);
        }

        // Parameters used only if the password has to be generated; clamp the default into [min, max].
        final int generateIterations = Math.max(this.minimumIterationCount, Math.min(this.maximumIterationCount, DEFAULT_ITERATION_COUNT));
        final IteratedPasswordAlgorithmSpec generateParameters = new IteratedPasswordAlgorithmSpec(generateIterations);
        final ScramDigestPassword password = MechanismUtil.getPasswordCredential(clientMessage.getAuthenticationName(), this.callbackHandler, ScramDigestPassword.class, this.mechanism.getPasswordAlgorithm(), null, generateParameters, this.providers);
        final byte[] saltedPasswordBytes = password.getDigest();
        final int iterationCount = password.getIterationCount();
        if (iterationCount < this.minimumIterationCount) {
            throw ElytronMessages.log.mechIterationCountIsTooLow(this.mechanism.toString(), iterationCount, this.minimumIterationCount);
        }
        if (iterationCount > this.maximumIterationCount) {
            throw ElytronMessages.log.mechIterationCountIsTooHigh(this.mechanism.toString(), iterationCount, this.maximumIterationCount);
        }
        final byte[] salt = password.getSalt();
        if (trace) {
            // NOTE(review): the salted password is credential-equivalent; trace output exposes it.
            ElytronMessages.log.tracef("[S] Salt: %s%n", hex(salt));
            ElytronMessages.log.tracef("[S] Salted password: %s%n", hex(saltedPasswordBytes));
        }

        final ByteStringBuilder b = new ByteStringBuilder();
        b.append('r').append('=');
        b.append(clientMessage.getRawNonce());
        final byte[] serverNonce = ScramUtil.generateNonce(SERVER_NONCE_LENGTH, this.getRandom());
        b.append(serverNonce);
        b.append(',');
        b.append('s').append('=');
        b.appendLatin1(ByteIterator.ofBytes(salt).base64Encode());
        b.append(',');
        b.append('i').append('=');
        b.append(Integer.toString(iterationCount));
        final byte[] messageBytes = b.toArray();
        return new ScramInitialServerResult(new ScramInitialServerMessage(clientMessage, serverNonce, salt, iterationCount, messageBytes), password);
    }

    /**
     * Parses the client-final-message {@code c=<base64 gs2 header + cb data>,r=<nonce>,p=<base64 proof>}
     * and checks it against the first round trip.
     *
     * @param initialResponse the parsed client-first-message
     * @param initialResult the result of {@link #evaluateInitialResponse}
     * @param bytes the raw message; it is copied, so the caller's array is not retained
     * @return the parsed final message, including the offset at which the proof part starts
     * @throws AuthenticationMechanismException if the message is malformed, its gs2 header, channel
     *         binding or authorization id differs from the first message, or the nonce does not match
     * @throws IllegalArgumentException if {@code initialResponse}, {@code initialResult} or its
     *         challenge is {@code null} (via {@code Assert.checkNotNullParam})
     */
    public ScramFinalClientMessage parseFinalClientMessage(ScramInitialClientMessage initialResponse, ScramInitialServerResult initialResult, byte[] bytes) throws AuthenticationMechanismException {
        Assert.checkNotNullParam("initialResponse", initialResponse);
        Assert.checkNotNullParam("initialResult", initialResult);
        final ScramInitialServerMessage initialChallenge = initialResult.getScramInitialChallenge();
        Assert.checkNotNullParam("initialChallenge", initialChallenge);
        final ScramMechanism mechanism = initialResponse.getMechanism();
        if (mechanism != initialChallenge.getMechanism()) {
            throw ElytronMessages.log.mechUnmatchedMechanism(mechanism.toString(), initialChallenge.getMechanism().toString());
        }
        final byte[] response = bytes.clone();
        final ByteIterator bi = ByteIterator.ofBytes(response);
        try {
            if (bi.next() != 'c' || bi.next() != '=') {
                throw ElytronMessages.log.mechInvalidMessageReceived(mechanism.toString());
            }
            // Decoded view of the channel-binding field: gs2 header followed by any binding data.
            final ByteIterator ibi = bi.delimitedBy(',').base64Decode();
            final char cbindFlag = (char) ibi.next();
            final String bindingType = initialResponse.getBindingType();
            final byte[] bindingData = initialResponse.getRawBindingData();
            final boolean binding = initialResponse.isBinding();
            if (cbindFlag == 'p') {
                // Binding requires that the first message negotiated it, which needs a -PLUS mechanism.
                if (!binding || !mechanism.isPlus()) {
                    throw ElytronMessages.log.mechChannelBindingNotSupported(mechanism.toString());
                }
                if (bindingType == null || bindingData == null) {
                    throw ElytronMessages.log.mechChannelBindingNotProvided(mechanism.toString());
                }
                if (ibi.next() != '=') {
                    throw ElytronMessages.log.mechInvalidMessageReceived(mechanism.toString());
                }
                if (!bindingType.equals(ibi.delimitedBy(',').asUtf8String().drainToString())) {
                    throw ElytronMessages.log.mechChannelBindingTypeMismatch(mechanism.toString());
                }
            } else if (cbindFlag == 'y') {
                if (!binding) {
                    throw ElytronMessages.log.mechChannelBindingNotSupported(mechanism.toString());
                }
                if (mechanism.isPlus()) {
                    throw ElytronMessages.log.mechChannelBindingNotProvided(mechanism.toString());
                }
                if (bindingType != null || bindingData != null) {
                    throw ElytronMessages.log.mechChannelBindingNotSupported(mechanism.toString());
                }
            } else if (cbindFlag == 'n') {
                if (binding) {
                    throw ElytronMessages.log.mechChannelBindingNotSupported(mechanism.toString());
                }
                if (mechanism.isPlus()) {
                    throw ElytronMessages.log.mechChannelBindingNotProvided(mechanism.toString());
                }
            } else {
                throw ElytronMessages.log.mechInvalidMessageReceived(mechanism.toString());
            }
            if (ibi.next() != ',') {
                throw ElytronMessages.log.mechInvalidMessageReceived(mechanism.toString());
            }

            // The authorization id must be exactly the one sent in the first message.
            final int c = ibi.next();
            if (c == 'a') {
                if (ibi.next() != '=') {
                    throw ElytronMessages.log.mechInvalidClientMessage(mechanism.toString());
                }
                final String authorizationID = ibi.delimitedBy(',').asUtf8String().drainToString();
                ibi.next(); // skip the ',' ending the gs2 header
                if (!authorizationID.equals(initialResponse.getAuthorizationId())) {
                    throw ElytronMessages.log.mechAuthorizationIdChanged(mechanism.toString());
                }
            } else if (c == ',') {
                if (initialResponse.getAuthorizationId() != null) {
                    throw ElytronMessages.log.mechAuthorizationIdChanged(mechanism.toString());
                }
            } else {
                throw ElytronMessages.log.mechInvalidClientMessage(mechanism.toString());
            }
            if (bindingData != null && !ibi.contentEquals(ByteIterator.ofBytes(bindingData))) {
                throw ElytronMessages.log.mechChannelBindingChanged(mechanism.toString());
            }

            bi.next(); // skip the ',' after the c= field
            if (bi.next() != 'r' || bi.next() != '=') {
                throw ElytronMessages.log.mechInvalidClientMessage(mechanism.toString());
            }
            // The nonce must be the client nonce immediately followed by the server nonce.
            final byte[] clientNonce = initialResponse.getRawNonce();
            final byte[] serverNonce = initialChallenge.getRawServerNonce();
            if (!bi.delimitedBy(',').limitedTo(clientNonce.length).contentEquals(ByteIterator.ofBytes(clientNonce)) || !bi.delimitedBy(',').limitedTo(serverNonce.length).contentEquals(ByteIterator.ofBytes(serverNonce))) {
                throw ElytronMessages.log.mechNoncesDoNotMatch(mechanism.toString());
            }

            // Everything before this offset is client-final-message-without-proof.
            final int proofOffset = bi.offset();
            // The nonce must end exactly here; reject any trailing bytes instead of skipping one blindly.
            if (bi.next() != ',') {
                throw ElytronMessages.log.mechInvalidClientMessage(mechanism.toString());
            }
            if (bi.next() != 'p' || bi.next() != '=') {
                throw ElytronMessages.log.mechInvalidClientMessage(mechanism.toString());
            }
            final byte[] proof = bi.delimitedBy(',').base64Decode().drain();
            if (bi.hasNext()) {
                throw ElytronMessages.log.mechInvalidClientMessage(mechanism.toString());
            }
            return new ScramFinalClientMessage(initialResponse, initialChallenge, initialResult.getScramDigestPassword(), proof, response, proofOffset);
        } catch (NoSuchElementException ignored) {
            // The message ended early; report it as malformed.
            throw ElytronMessages.log.mechInvalidMessageReceived(mechanism.toString());
        }
    }

    /**
     * Verifies the client proof, runs the authorization check, and builds the server-final-message
     * {@code v=<base64 server signature>}.
     * <p>
     * The AuthMessage fed to both signatures is client-first-message-bare + "," +
     * server-first-message + "," + client-final-message-without-proof.
     *
     * @param initialResult the result of {@link #evaluateInitialResponse}
     * @param clientMessage the parsed client-final-message
     * @return the server signature and the encoded server-final-message
     * @throws AuthenticationMechanismException if the message belongs to another mechanism, the proof
     *         is invalid, authorization is refused or unsupported, or the HMAC/digest is unavailable
     */
    public ScramFinalServerMessage evaluateFinalClientMessage(ScramInitialServerResult initialResult, ScramFinalClientMessage clientMessage) throws AuthenticationMechanismException {
        final boolean trace = ElytronMessages.log.isTraceEnabled();
        if (clientMessage.getMechanism() != this.mechanism) {
            throw ElytronMessages.log.mechUnmatchedMechanism(this.mechanism.toString(), clientMessage.getMechanism().toString());
        }
        try {
            final Mac mac = Mac.getInstance(this.getMechanism().getHmacName());
            final MessageDigest messageDigest = MessageDigest.getInstance(this.getMechanism().getMessageDigestName());

            // ClientKey = HMAC(SaltedPassword, "Client Key")
            mac.reset();
            final byte[] saltedPassword = initialResult.getScramDigestPassword().getDigest();
            mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
            mac.update(ScramUtil.CLIENT_KEY_BYTES);
            final byte[] clientKey = mac.doFinal();
            if (trace) {
                ElytronMessages.log.tracef("[S] Client key: %s%n", hex(clientKey));
            }

            // StoredKey = H(ClientKey)
            messageDigest.reset();
            messageDigest.update(clientKey);
            final byte[] storedKey = messageDigest.digest();
            if (trace) {
                ElytronMessages.log.tracef("[S] Stored key: %s%n", hex(storedKey));
            }

            // ClientSignature = HMAC(StoredKey, AuthMessage)
            mac.reset();
            mac.init(new SecretKeySpec(storedKey, mac.getAlgorithm()));
            final byte[] clientFirstMessage = clientMessage.getInitialResponse().getRawMessageBytes();
            final int clientFirstMessageBareStart = clientMessage.getInitialResponse().getInitialPartIndex();
            mac.update(clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length - clientFirstMessageBareStart);
            if (trace) {
                ElytronMessages.log.tracef("[S] Using client first message: %s%n", hex(Arrays.copyOfRange(clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length)));
            }
            mac.update((byte) ',');
            final byte[] serverFirstMessage = initialResult.getScramInitialChallenge().getRawMessageBytes();
            mac.update(serverFirstMessage);
            if (trace) {
                ElytronMessages.log.tracef("[S] Using server first message: %s%n", hex(serverFirstMessage));
            }
            mac.update((byte) ',');
            final byte[] response = clientMessage.getRawMessageBytes();
            final int proofOffset = clientMessage.getProofOffset();
            mac.update(response, 0, proofOffset);
            if (trace) {
                ElytronMessages.log.tracef("[S] Using client final message without proof: %s%n", hex(Arrays.copyOfRange(response, 0, proofOffset)));
            }
            final byte[] clientSignature = mac.doFinal();
            if (trace) {
                ElytronMessages.log.tracef("[S] Client signature: %s%n", hex(clientSignature));
            }

            // ServerKey = HMAC(SaltedPassword, "Server Key")
            mac.reset();
            mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
            mac.update(ScramUtil.SERVER_KEY_BYTES);
            final byte[] serverKey = mac.doFinal();
            if (trace) {
                ElytronMessages.log.tracef("[S] Server key: %s%n", hex(serverKey));
            }

            // ServerSignature = HMAC(ServerKey, AuthMessage)
            mac.reset();
            mac.init(new SecretKeySpec(serverKey, mac.getAlgorithm()));
            mac.update(clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length - clientFirstMessageBareStart);
            mac.update((byte) ',');
            mac.update(serverFirstMessage);
            mac.update((byte) ',');
            mac.update(response, 0, proofOffset);
            final byte[] serverSignature = mac.doFinal();
            if (trace) {
                ElytronMessages.log.tracef("[S] Server signature: %s%n", hex(serverSignature));
            }

            final byte[] recoveredClientProof = clientMessage.getRawClientProof();
            if (trace) {
                ElytronMessages.log.tracef("[S] Client proof: %s%n", hex(recoveredClientProof));
            }
            // A proof whose length differs from the HMAC output can never be valid; reject it before
            // the XOR instead of relying on ScramUtil.xor to cope with mismatched lengths.
            if (recoveredClientProof.length != clientSignature.length) {
                throw ElytronMessages.log.mechAuthenticationRejectedInvalidProof(this.mechanism.toString());
            }
            // ClientKey' = ClientProof XOR ClientSignature; must equal the ClientKey derived above.
            final byte[] recoveredClientKey = clientSignature.clone();
            ScramUtil.xor(recoveredClientKey, recoveredClientProof);
            if (trace) {
                ElytronMessages.log.tracef("[S] Recovered client key: %s%n", hex(recoveredClientKey));
            }
            // Constant-time comparison so response timing does not reveal how much of the key matched.
            if (!MessageDigest.isEqual(recoveredClientKey, clientKey)) {
                throw ElytronMessages.log.mechAuthenticationRejectedInvalidProof(this.mechanism.toString());
            }

            // Authorization: default the authorization id to the authentication name, else prepare it.
            final String userName = clientMessage.getInitialResponse().getAuthenticationName();
            String authorizationID = clientMessage.getInitialResponse().getAuthorizationId();
            if (authorizationID == null) {
                authorizationID = userName;
            } else {
                final ByteStringBuilder bsb = new ByteStringBuilder();
                StringPrep.encode(authorizationID, bsb, NAME_PREP_FLAGS);
                authorizationID = new String(bsb.toArray(), StandardCharsets.UTF_8);
            }
            final AuthorizeCallback authorizeCallback = new AuthorizeCallback(userName, authorizationID);
            try {
                MechanismUtil.handleCallbacks(this.mechanism.toString(), this.callbackHandler, authorizeCallback);
            } catch (UnsupportedCallbackException e) {
                throw ElytronMessages.log.mechAuthorizationUnsupported(this.mechanism.toString(), e);
            }
            if (!authorizeCallback.isAuthorized()) {
                throw ElytronMessages.log.mechAuthorizationFailed(this.mechanism.toString(), userName, authorizationID);
            }

            final ByteStringBuilder b = new ByteStringBuilder();
            b.append('v').append('=');
            b.appendUtf8(ByteIterator.ofBytes(serverSignature).base64Encode());
            return new ScramFinalServerMessage(serverSignature, b.toArray());
        } catch (InvalidKeyException | NoSuchAlgorithmException e) {
            throw ElytronMessages.log.mechMacAlgorithmNotSupported(this.mechanism.toString(), e);
        }
    }

    /** Hex-encodes {@code bytes} for trace output. */
    private static String hex(byte[] bytes) {
        return ByteIterator.ofBytes(bytes).hexEncode().drainToString();
    }

    /**
     * Returns the mechanism this server implements.
     *
     * @return the mechanism
     */
    public ScramMechanism getMechanism() {
        return this.mechanism;
    }

    /**
     * Returns the callback handler used for name, credential and authorization callbacks.
     *
     * @return the callback handler
     */
    public CallbackHandler getCallbackHandler() {
        return this.callbackHandler;
    }

    /** Returns the configured random source, falling back to {@link ThreadLocalRandom} when none was given. */
    Random getRandom() {
        return this.random != null ? this.random : ThreadLocalRandom.current();
    }

    /**
     * Returns a copy of the binding data this server was configured with.
     *
     * @return a defensive copy, or {@code null} if none was configured
     */
    public byte[] getBindingData() {
        return this.bindingData == null ? null : this.bindingData.clone();
    }

    /** Returns the configured binding data without copying; callers must not modify it. */
    byte[] getRawBindingData() {
        return this.bindingData;
    }

    /**
     * Returns the binding type this server was configured with.
     *
     * @return the binding type, or {@code null} if none was configured
     */
    public String getBindingType() {
        return this.bindingType;
    }
}

