/*
 * Decompiled with CFR 0.152.
 */
package com.hivemq.extensions.handler;

import com.google.common.collect.ImmutableMap;
import com.hivemq.bootstrap.ClientConnectionContext;
import com.hivemq.configuration.HivemqId;
import com.hivemq.configuration.service.FullConfigurationService;
import com.hivemq.extension.sdk.api.annotations.NotNull;
import com.hivemq.extension.sdk.api.async.TimeoutFallback;
import com.hivemq.extension.sdk.api.client.parameter.ClientInformation;
import com.hivemq.extension.sdk.api.client.parameter.ConnectionInformation;
import com.hivemq.extension.sdk.api.client.parameter.ServerInformation;
import com.hivemq.extension.sdk.api.interceptor.connect.ConnectInboundInterceptor;
import com.hivemq.extension.sdk.api.interceptor.connect.ConnectInboundInterceptorProvider;
import com.hivemq.extension.sdk.api.interceptor.connect.parameter.ConnectInboundInput;
import com.hivemq.extension.sdk.api.interceptor.connect.parameter.ConnectInboundOutput;
import com.hivemq.extension.sdk.api.interceptor.connect.parameter.ConnectInboundProviderInput;
import com.hivemq.extensions.ExtensionInformationUtil;
import com.hivemq.extensions.HiveMQExtension;
import com.hivemq.extensions.HiveMQExtensions;
import com.hivemq.extensions.client.parameter.ClientInformationImpl;
import com.hivemq.extensions.executor.PluginOutPutAsyncer;
import com.hivemq.extensions.executor.PluginTaskExecutorService;
import com.hivemq.extensions.executor.task.PluginInOutTask;
import com.hivemq.extensions.executor.task.PluginInOutTaskContext;
import com.hivemq.extensions.handler.ExtensionParameterHolder;
import com.hivemq.extensions.interceptor.connect.parameter.ConnectInboundInputImpl;
import com.hivemq.extensions.interceptor.connect.parameter.ConnectInboundOutputImpl;
import com.hivemq.extensions.interceptor.connect.parameter.ConnectInboundProviderInputImpl;
import com.hivemq.extensions.packets.connect.ConnectPacketImpl;
import com.hivemq.extensions.packets.connect.ModifiableConnectPacketImpl;
import com.hivemq.extensions.services.interceptor.Interceptors;
import com.hivemq.mqtt.handler.connack.MqttConnacker;
import com.hivemq.mqtt.message.connect.CONNECT;
import com.hivemq.mqtt.message.reason.Mqtt5ConnAckReasonCode;
import com.hivemq.util.Exceptions;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import javax.inject.Inject;
import javax.inject.Singleton;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Singleton
public class ConnectInboundInterceptorHandler {
    // Class-wide logger; the interceptor task uses it to report exceptions thrown by extension code.
    private static final Logger log = LoggerFactory.getLogger(ConnectInboundInterceptorHandler.class);
    // Handed to ModifiableConnectPacketImpl when the modifiable CONNECT packet is created.
    @NotNull
    private final FullConfigurationService configurationService;
    // Handed to every ConnectInboundOutputImpl created by this handler.
    @NotNull
    private final PluginOutPutAsyncer asyncer;
    // Used to look up the HiveMQExtension that belongs to a provider-map key.
    @NotNull
    private final HiveMQExtensions hiveMQExtensions;
    // Receives one in/out task per interceptor provider for execution.
    @NotNull
    private final PluginTaskExecutorService executorService;
    // Its value is passed to CONNECT.from(...) when the CONNECT is rebuilt after interception.
    @NotNull
    private final HivemqId hivemqId;
    // Source of the registered CONNECT inbound interceptor providers.
    @NotNull
    private final Interceptors interceptors;
    // Part of the provider input that is given to each interceptor provider.
    @NotNull
    private final ServerInformation serverInformation;
    // Used via connackError(...) when an interceptor prevented the connect.
    @NotNull
    private final MqttConnacker connacker;

    /**
     * Creates the handler and stores the injected collaborators unchanged.
     *
     * @param configurationService configuration used when building the modifiable CONNECT packet
     * @param asyncer              asyncer passed to every interceptor output
     * @param hiveMQExtensions     registry used to resolve the extension for a provider
     * @param executorService      executor that runs the interceptor tasks
     * @param hivemqId             id holder whose value is used when the CONNECT is rebuilt
     * @param interceptors         registry of CONNECT inbound interceptor providers
     * @param serverInformation    server information exposed to interceptor providers
     * @param connacker            used to answer a prevented connect
     */
    @Inject
    public ConnectInboundInterceptorHandler(
            @NotNull FullConfigurationService configurationService,
            @NotNull PluginOutPutAsyncer asyncer,
            @NotNull HiveMQExtensions hiveMQExtensions,
            @NotNull PluginTaskExecutorService executorService,
            @NotNull HivemqId hivemqId,
            @NotNull Interceptors interceptors,
            @NotNull ServerInformation serverInformation,
            @NotNull MqttConnacker connacker) {
        // Extension lookup and task execution.
        this.interceptors = interceptors;
        this.hiveMQExtensions = hiveMQExtensions;
        this.executorService = executorService;
        this.asyncer = asyncer;

        // Broker-side information and configuration.
        this.configurationService = configurationService;
        this.serverInformation = serverInformation;
        this.hivemqId = hivemqId;

        // Outcome handling.
        this.connacker = connacker;
    }

    /**
     * Passes an inbound CONNECT through all registered CONNECT inbound interceptors.
     *
     * <p>If the connection context has no client id the method returns without forwarding
     * the message. If no interceptor provider is registered the CONNECT is forwarded
     * unchanged via {@code ctx.fireChannelRead}. Otherwise one task per provider is handed
     * to the executor service; the shared {@link ConnectInterceptorContext} forwards or
     * rejects the CONNECT once every provider has been accounted for.
     *
     * @param ctx     the channel handler context the CONNECT arrived on
     * @param connect the decoded CONNECT message
     */
    public void handleInboundConnect(@NotNull ChannelHandlerContext ctx, @NotNull CONNECT connect) {
        Channel channel = ctx.channel();
        ClientConnectionContext clientConnectionContext = ClientConnectionContext.of(channel);
        String clientId = clientConnectionContext.getClientId();
        if (clientId == null) {
            // Without a client id the CONNECT is neither intercepted nor forwarded.
            return;
        }

        ImmutableMap<String, ConnectInboundInterceptorProvider> providers = this.interceptors.connectInboundInterceptorProviders();
        if (providers.isEmpty()) {
            // Nothing to intercept: hand the message on untouched.
            ctx.fireChannelRead(connect);
            return;
        }

        // Input shared by all providers.
        ClientInformation clientInfo = ExtensionInformationUtil.getAndSetClientInformation(channel, clientId);
        ConnectionInformation connectionInfo = ExtensionInformationUtil.getAndSetConnectionInformation(channel);
        ConnectInboundProviderInputImpl providerInput = new ConnectInboundProviderInputImpl(this.serverInformation, clientInfo, connectionInfo);

        // Fall back to "now" when the context carries no connect-received timestamp.
        long timestamp = Objects.requireNonNullElse(clientConnectionContext.getConnectReceivedTimestamp(), System.currentTimeMillis());
        ConnectPacketImpl packet = new ConnectPacketImpl(connect, timestamp);

        // The holders let consecutive interceptors see the input/output left by the previous one.
        ConnectInboundInputImpl input = new ConnectInboundInputImpl(clientInfo, connectionInfo, packet);
        ExtensionParameterHolder<ConnectInboundInputImpl> inputHolder = new ExtensionParameterHolder<>(input);
        ModifiableConnectPacketImpl modifiablePacket = new ModifiableConnectPacketImpl(packet, this.configurationService);
        ConnectInboundOutputImpl output = new ConnectInboundOutputImpl(this.asyncer, modifiablePacket);
        ExtensionParameterHolder<ConnectInboundOutputImpl> outputHolder = new ExtensionParameterHolder<>(output);

        ConnectInterceptorContext context = new ConnectInterceptorContext(clientId, providers.size(), ctx, inputHolder, outputHolder);
        for (Map.Entry<String, ConnectInboundInterceptorProvider> entry : providers.entrySet()) {
            ConnectInboundInterceptorProvider provider = entry.getValue();
            HiveMQExtension extension = this.hiveMQExtensions.getExtension(entry.getKey());
            if (extension == null) {
                // No extension for this key: count the provider as done so the context
                // can still reach the expected total of providers.size().
                context.finishInterceptor();
                continue;
            }
            ConnectInterceptorTask task = new ConnectInterceptorTask(provider, providerInput, extension.getId(), clientId);
            this.executorService.handlePluginInOutTaskExecution(context, inputHolder, outputHolder, task);
        }
    }

    /**
     * In/out task that obtains the CONNECT inbound interceptor of a single provider and
     * invokes it, unless the connect has already been prevented.
     */
    private static class ConnectInterceptorTask
            implements PluginInOutTask<ConnectInboundInputImpl, ConnectInboundOutputImpl> {

        @NotNull
        private final ConnectInboundInterceptorProvider provider;
        @NotNull
        private final ConnectInboundProviderInputImpl providerInput;
        @NotNull
        private final String extensionId;
        @NotNull
        private final String clientId;

        private ConnectInterceptorTask(
                @NotNull ConnectInboundInterceptorProvider provider,
                @NotNull ConnectInboundProviderInputImpl providerInput,
                @NotNull String extensionId,
                @NotNull String clientId) {
            this.provider = provider;
            this.providerInput = providerInput;
            this.extensionId = extensionId;
            this.clientId = clientId;
        }

        /**
         * Invokes the provider's interceptor when the output is not yet marked as prevented.
         *
         * @param input  the current interceptor input
         * @param output the current interceptor output
         * @return the same {@code output} instance that was passed in
         */
        @Override
        @NotNull
        public ConnectInboundOutputImpl apply(@NotNull ConnectInboundInputImpl input, @NotNull ConnectInboundOutputImpl output) {
            if (!output.isPrevent()) {
                invokeInterceptor(input, output);
            }
            return output;
        }

        /**
         * Asks the provider for an interceptor and calls it. Any throwable raised on the way
         * is logged, turns the output into a prevented one and is then passed to
         * {@code Exceptions.rethrowError}.
         */
        private void invokeInterceptor(@NotNull ConnectInboundInputImpl input, @NotNull ConnectInboundOutputImpl output) {
            try {
                ConnectInboundInterceptor connectInterceptor = provider.getConnectInboundInterceptor(providerInput);
                if (connectInterceptor == null) {
                    // The provider chose not to supply an interceptor for this client.
                    return;
                }
                connectInterceptor.onConnect(input, output);
            } catch (Throwable throwable) {
                log.warn("Uncaught exception was thrown from extension with id \"{}\" on inbound CONNECT interception. Extensions are responsible for their own exception handling.", extensionId, throwable);
                String failureDetail = String.format("Connect with client ID %s failed because of an exception was thrown by an CONNECT inbound interceptor.", clientId);
                output.prevent(failureDetail, "Exception in CONNECT inbound interceptor");
                Exceptions.rethrowError(throwable);
            }
        }

        /**
         * @return the class loader that loaded the provider's class
         */
        @Override
        @NotNull
        public ClassLoader getPluginClassLoader() {
            Class<?> providerClass = provider.getClass();
            return providerClass.getClassLoader();
        }
    }

    /**
     * Shared context for all interceptor tasks of one CONNECT.
     *
     * <p>It counts finished interceptors; when the count reaches the expected number the
     * context schedules itself on the channel's executor, where {@link #run()} either
     * rejects the connect or forwards the (possibly modified) CONNECT down the pipeline.
     */
    private class ConnectInterceptorContext
            extends PluginInOutTaskContext<ConnectInboundOutputImpl>
            implements Runnable {

        // Number of providers that must be accounted for before the result is processed.
        private final int interceptorCount;
        // Number of providers accounted for so far.
        @NotNull
        private final AtomicInteger counter;
        @NotNull
        private final ChannelHandlerContext ctx;
        @NotNull
        private final ExtensionParameterHolder<ConnectInboundInputImpl> inputHolder;
        @NotNull
        private final ExtensionParameterHolder<ConnectInboundOutputImpl> outputHolder;

        /**
         * @param clientId         client id, passed to the super class as the task identifier
         * @param interceptorCount number of providers expected to finish
         * @param ctx              channel handler context of the connecting client
         * @param inputHolder      holder of the current interceptor input
         * @param outputHolder     holder of the current interceptor output
         */
        ConnectInterceptorContext(
                @NotNull String clientId,
                int interceptorCount,
                @NotNull ChannelHandlerContext ctx,
                @NotNull ExtensionParameterHolder<ConnectInboundInputImpl> inputHolder,
                @NotNull ExtensionParameterHolder<ConnectInboundOutputImpl> outputHolder) {
            super(clientId);
            this.interceptorCount = interceptorCount;
            this.counter = new AtomicInteger(0);
            this.ctx = ctx;
            this.inputHolder = inputHolder;
            this.outputHolder = outputHolder;
        }

        /**
         * Called with the output of one interceptor task after it was executed.
         *
         * @param output the output the task worked on
         */
        @Override
        public void pluginPost(@NotNull ConnectInboundOutputImpl output) {
            if (output.isPrevent()) {
                // Already prevented: nothing to merge, just count this interceptor.
                this.finishInterceptor();
                return;
            }
            if (output.isTimedOut() && output.getTimeoutFallback() == TimeoutFallback.FAILURE) {
                // A timeout with FAILURE fallback is turned into a prevented connect.
                // getIdentifier() is inherited; presumably the client id given to the super constructor.
                output.prevent("Connect with client ID " + this.getIdentifier() + " failed because of an interceptor timeout", "Extension interceptor timeout");
                this.finishInterceptor();
                return;
            }
            if (output.getConnectPacket().isModified()) {
                // Carry the modifications over into the input seen from now on.
                this.inputHolder.set(this.inputHolder.get().update(output));
            }
            // Only prepare a fresh output when further interceptors are still outstanding.
            if (!this.finishInterceptor()) {
                this.outputHolder.set(output.update(this.inputHolder.get()));
            }
        }

        /**
         * Counts one provider as finished and, when it was the last one, schedules
         * {@link #run()} on the channel's executor.
         *
         * @return {@code true} if this call completed the expected count, otherwise {@code false}
         */
        public boolean finishInterceptor() {
            if (this.counter.incrementAndGet() != this.interceptorCount) {
                return false;
            }
            this.ctx.executor().execute(this);
            return true;
        }

        /**
         * Processes the final interceptor result: calls {@code connackError} when the connect
         * was prevented, otherwise rebuilds the CONNECT from the current input packet, copies
         * its values onto the client connection context and forwards it.
         */
        @Override
        public void run() {
            ConnectInboundOutputImpl output = this.outputHolder.get();
            if (output.isPrevent()) {
                String logMessage = output.getLogMessage();
                String reasonString = output.getReasonString();
                // The log message is passed for both the second and the third argument.
                ConnectInboundInterceptorHandler.this.connacker.connackError(this.ctx.channel(), logMessage, logMessage, Mqtt5ConnAckReasonCode.UNSPECIFIED_ERROR, reasonString);
                return;
            }

            CONNECT connect = CONNECT.from(this.inputHolder.get().getConnectPacket(), ConnectInboundInterceptorHandler.this.hivemqId.get());
            // Refresh the connection context from the rebuilt CONNECT before forwarding it.
            ClientConnectionContext clientConnectionContext = ClientConnectionContext.of(this.ctx.channel());
            clientConnectionContext.setClientId(connect.getClientIdentifier());
            clientConnectionContext.setExtensionClientInformation(new ClientInformationImpl(connect.getClientIdentifier()));
            clientConnectionContext.setCleanStart(connect.isCleanStart());
            clientConnectionContext.setConnectKeepAlive(connect.getKeepAlive());
            clientConnectionContext.setAuthUsername(connect.getUsername());
            clientConnectionContext.setAuthPassword(connect.getPassword());
            this.ctx.fireChannelRead(connect);
        }
    }
}

