/*
 * Decompiled with CFR 0.152.
 */
package com.hivemq.mqtt.handler.connect;

import com.google.common.annotations.VisibleForTesting;
import com.hivemq.bootstrap.ClientConnectionContext;
import com.hivemq.extension.sdk.api.annotations.NotNull;
import com.hivemq.mqtt.handler.disconnect.MqttServerDisconnector;
import com.hivemq.mqtt.message.Message;
import com.hivemq.mqtt.message.auth.AUTH;
import com.hivemq.mqtt.message.connack.CONNACK;
import com.hivemq.mqtt.message.connect.CONNECT;
import com.hivemq.mqtt.message.reason.Mqtt5ConnAckReasonCode;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.concurrent.GenericFutureListener;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Optional;
import java.util.Queue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MessageBarrier
extends ChannelDuplexHandler {
    private static final Logger log = LoggerFactory.getLogger(MessageBarrier.class);
    private static final ChannelFutureListener ENABLE_AUTO_READ_LISTENER = future -> {
        if (future.isSuccess()) {
            MessageBarrier.resumeRead(future.channel());
        }
    };
    @NotNull
    private final MqttServerDisconnector serverDisconnector;
    @NotNull
    private final Queue<Message> messageQueue = new LinkedList<Message>();
    private boolean connectReceived;
    private boolean connackSent;

    public MessageBarrier(@NotNull MqttServerDisconnector serverDisconnector) {
        this.serverDisconnector = serverDisconnector;
    }

    public void channelRead(@NotNull ChannelHandlerContext ctx, @NotNull Object msg) {
        if (msg instanceof Message) {
            if (msg instanceof CONNECT) {
                this.connectReceived = true;
                MessageBarrier.suspendRead(ctx.channel());
            } else {
                if (!this.connectReceived) {
                    this.serverDisconnector.logAndClose(ctx.channel(), "A client (IP: {}) sent other message before CONNECT. Disconnecting client.", "Sent other message before CONNECT");
                    return;
                }
                if (msg instanceof AUTH) {
                    MessageBarrier.suspendRead(ctx.channel());
                } else if (!this.connackSent) {
                    this.messageQueue.add((Message)msg);
                    return;
                }
            }
        }
        ctx.fireChannelRead(msg);
    }

    public void write(@NotNull ChannelHandlerContext ctx, @NotNull Object msg, @NotNull ChannelPromise promise) {
        if (msg instanceof CONNACK && ((CONNACK)msg).getReasonCode() == Mqtt5ConnAckReasonCode.SUCCESS) {
            promise.addListener((GenericFutureListener)((ChannelFutureListener)future -> {
                if (future.isSuccess()) {
                    future.channel().pipeline().remove((ChannelHandler)this);
                    this.connackSent = true;
                    this.releaseQueuedMessages(ctx);
                }
            }));
            promise.addListener((GenericFutureListener)ENABLE_AUTO_READ_LISTENER);
        } else if (msg instanceof AUTH) {
            promise.addListener((GenericFutureListener)ENABLE_AUTO_READ_LISTENER);
        }
        ctx.write(msg, promise);
    }

    private void releaseQueuedMessages(@NotNull ChannelHandlerContext ctx) {
        for (Message message : this.messageQueue) {
            ctx.fireChannelRead((Object)message);
        }
    }

    private static void suspendRead(@NotNull Channel channel) {
        if (log.isTraceEnabled()) {
            ClientConnectionContext clientConnectionContext = ClientConnectionContext.of(channel);
            Optional<String> channelIP = clientConnectionContext.getChannelIP();
            log.trace("Suspending read operations for MQTT client with id {} and IP {}", (Object)clientConnectionContext.getClientId(), (Object)channelIP.orElse("UNKNOWN"));
        }
        channel.config().setAutoRead(false);
    }

    private static void resumeRead(@NotNull Channel channel) {
        if (log.isTraceEnabled()) {
            ClientConnectionContext clientConnectionContext = ClientConnectionContext.of(channel);
            Optional<String> channelIP = clientConnectionContext.getChannelIP();
            log.trace("Restarting read operations for MQTT client with id {} and IP {}", (Object)clientConnectionContext.getClientId(), (Object)channelIP.orElse("UNKNOWN"));
        }
        channel.config().setAutoRead(true);
    }

    @VisibleForTesting
    boolean getConnectReceived() {
        return this.connectReceived;
    }

    @VisibleForTesting
    @NotNull
    Collection<Message> getQueue() {
        return Collections.unmodifiableCollection(this.messageQueue);
    }
}

