/*
 * Decompiled with CFR 0.152.
 */
package org.jgroups.protocols;

import java.util.HashMap;
import java.util.Map;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslException;
import org.jgroups.Address;
import org.jgroups.Event;
import org.jgroups.Message;
import org.jgroups.annotations.MBean;
import org.jgroups.annotations.Property;
import org.jgroups.auth.sasl.SaslClientContext;
import org.jgroups.auth.sasl.SaslContext;
import org.jgroups.auth.sasl.SaslServerContext;
import org.jgroups.conf.ClassConfigurator;
import org.jgroups.conf.PropertyConverters;
import org.jgroups.protocols.SaslHeader;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.protocols.pbcast.JoinRsp;
import org.jgroups.stack.Protocol;
import org.jgroups.util.MessageBatch;

@MBean(description="Provides SASL authentication")
public class SASL
extends Protocol {
    /** Protocol id registered for {@link GMS}; used as the key when reading or writing GMS headers on messages. */
    public static final short GMS_ID = ClassConfigurator.getProtocolId(GMS.class);
    /** Protocol id registered for this protocol; used as the key when reading the {@link SaslHeader} of a message. */
    public static final short SASL_ID = ClassConfigurator.getProtocolId(SASL.class);
    /** Name of the SASL mechanism handed to every client and server context created by this protocol. */
    @Property(name="mech", description="The name of the mech to require for authentication. Can be any mech supported by your local SASL provider. The JDK comes standard with CRAM-MD5, DIGEST-MD5, GSSAPI, NTLM")
    protected String mech;
    /** Mechanism-specific properties handed to every client and server context created by this protocol. */
    @Property(name="sasl_props", description="Properties specific to the chosen mech", converter=PropertyConverters.StringProperties.class)
    protected Map<String, String> sasl_props = new HashMap<String, String>();
    /** Value passed to {@code SaslServerContext.awaitCompletion} while a server-side handshake is in progress. */
    @Property(name="timeout", description="How long to wait (in ms) for a response to a challenge")
    protected long timeout = 5000L;
    /** Callback handler handed to every client and server context; may be set directly or by class name. */
    protected CallbackHandler callback_handler;
    /** Local address as last delivered through {@link #down(Event)}; passed to server contexts on creation. */
    protected Address local_addr;
    /**
     * Live SASL contexts keyed by peer address: client contexts are stored under the destination of an
     * outgoing join/merge request, server contexts under the source of an incoming one.
     * NOTE(review): this is a plain HashMap that is touched from both up() and down(); confirm these
     * paths cannot run concurrently, or guard the map.
     */
    protected final Map<Address, SaslContext> sasl_context = new HashMap<Address, SaslContext>();

    public SASL() {
        this.name = this.getClass().getSimpleName();
    }

    /**
     * Instantiates the callback handler from a class name (used by the {@code callback_handler_class}
     * property) and installs it as this protocol's callback handler.
     *
     * @param handlerClass fully qualified name of a {@link CallbackHandler} implementation that has a
     *                     no-argument constructor
     * @throws Exception if the class cannot be loaded, is not a {@link CallbackHandler}, or cannot be
     *                   instantiated; an exception thrown by the constructor itself is reported wrapped
     *                   in an {@link java.lang.reflect.InvocationTargetException}
     */
    @Property(name="callback_handler_class")
    public void setCallbackHandlerClass(String handlerClass) throws Exception {
        Class<? extends CallbackHandler> handlerType = Class.forName(handlerClass).asSubclass(CallbackHandler.class);
        // Class.newInstance() is deprecated because it rethrows checked constructor exceptions
        // unchecked; going through the Constructor object is the documented replacement.
        this.callback_handler = handlerType.getDeclaredConstructor().newInstance();
    }

    /**
     * Returns the class name of the installed callback handler.
     *
     * @return the fully qualified class name of the current handler, or {@code null} when none is set
     */
    public String getCallbackHandlerClass() {
        if (callback_handler == null) {
            return null;
        }
        return callback_handler.getClass().getName();
    }

    /**
     * Returns the callback handler handed to newly created SASL contexts.
     *
     * @return the current handler, possibly {@code null}
     */
    public CallbackHandler getCallbackHandler() {
        return callback_handler;
    }

    /**
     * Installs the callback handler handed to SASL contexts created after this call.
     *
     * @param callback_handler the handler to use; stored as given, without validation
     */
    public void setCallbackHandler(CallbackHandler callback_handler) {
        this.callback_handler = callback_handler;
    }

    /**
     * Sets the SASL mechanism name used for contexts created after this call.
     *
     * @param mech the mechanism name; stored as given, without validation
     */
    public void setMech(String mech) {
        this.mech = mech;
    }

    /**
     * Returns the configured SASL mechanism name.
     *
     * @return the mechanism name, or {@code null} if none has been configured
     */
    public String getMech() {
        return mech;
    }

    /**
     * Replaces the mechanism-specific properties handed to SASL contexts created after this call.
     *
     * @param sasl_props the property map; the reference is stored as given (no defensive copy)
     */
    public void setSaslProps(Map<String, String> sasl_props) {
        this.sasl_props = sasl_props;
    }

    /**
     * Returns the mechanism-specific properties.
     *
     * @return the live property map held by this protocol (not a copy)
     */
    public Map<String, String> getSaslProps() {
        return sasl_props;
    }

    /**
     * Sets how long the server side waits for a handshake to complete.
     *
     * @param timeout the wait time, in milliseconds according to the {@code timeout} property description
     */
    public void setTimeout(long timeout) {
        this.timeout = timeout;
    }

    /**
     * Returns how long the server side waits for a handshake to complete.
     *
     * @return the configured wait time (5000 unless changed)
     */
    public long getTimeout() {
        return timeout;
    }

    /**
     * Returns the local address recorded by {@link #down(Event)}.
     *
     * @return the local address, or {@code null} if it has not been delivered yet
     */
    public Address getAddress() {
        return local_addr;
    }

    /**
     * Initialises the protocol; this class adds nothing to the inherited behaviour.
     *
     * @throws Exception if the superclass initialisation fails
     */
    @Override
    public void init() throws Exception {
        // No SASL-specific set-up: contexts are created lazily per handshake.
        super.init();
    }

    /**
     * Stops the protocol and then disposes every SASL context still registered.
     */
    @Override
    public void stop() {
        super.stop();
        // Release all contexts that are still held for any peer.
        cleanup();
    }

    /**
     * Destroys the protocol and then disposes every SASL context still registered.
     */
    @Override
    public void destroy() {
        super.destroy();
        // Same clean-up as stop(): nothing may outlive the protocol instance.
        cleanup();
    }

    /**
     * Disposes all registered SASL contexts and empties the registry.
     */
    private void cleanup() {
        for (Map.Entry<Address, SaslContext> registered : sasl_context.entrySet()) {
            registered.getValue().dispose();
        }
        sasl_context.clear();
    }

    /**
     * Handles events travelling up the stack and intercepts everything that belongs to the SASL handshake.
     * <p>
     * Only {@link Event#MSG} events are inspected; every other event is passed up unchanged. For a message:
     * <ul>
     *   <li>a GMS join or merge request is authenticated through
     *       {@link #serverChallenge(GMS.GmsHeader, SaslHeader, Message)} and only passed up on success;</li>
     *   <li>any other message carrying a {@link SaslHeader} is a handshake step; it is processed with the
     *       context registered for its sender and is never passed up;</li>
     *   <li>all remaining messages are passed up unchanged.</li>
     * </ul>
     *
     * @param evt the event received from the protocol below
     * @return the result of the protocol above, or {@code null} when the message was consumed here
     * @throws IllegalStateException if a join or merge request carries no SASL header, or a handshake
     *                               message arrives from a sender without a registered context
     */
    @Override
    public Object up(Event evt) {
        if (evt.getType() != Event.MSG) {
            return this.up_prot.up(evt);
        }
        Message msg = (Message)evt.getArg();
        SaslHeader saslHeader = (SaslHeader)msg.getHeader(SASL_ID);
        GMS.GmsHeader gmsHeader = (GMS.GmsHeader)msg.getHeader(GMS_ID);

        if (SASL.needsAuthentication(gmsHeader)) {
            if (saslHeader == null) {
                throw new IllegalStateException("Found GMS join or merge request but no SASL header");
            }
            // A rejected request is dropped here so that the protocols above never see it.
            return this.serverChallenge(gmsHeader, saslHeader, msg) ? this.up_prot.up(evt) : null;
        }
        if (saslHeader == null) {
            return this.up_prot.up(evt);
        }

        Address remoteAddress = msg.getSrc();
        SaslContext saslContext = this.sasl_context.get(remoteAddress);
        if (saslContext == null) {
            // %s renders a null sender as "null" instead of failing with a NullPointerException.
            throw new IllegalStateException(String.format("Cannot find server context to challenge SASL request from %s", remoteAddress));
        }
        switch (saslHeader.getType()) {
            case CHALLENGE: {
                this.handleChallenge(saslContext, saslHeader, gmsHeader, remoteAddress);
                break;
            }
            case RESPONSE: {
                this.handleResponse(saslContext, saslHeader, remoteAddress);
                break;
            }
        }
        // Handshake traffic is internal to this protocol.
        return null;
    }

    /** Answers a CHALLENGE from {@code remoteAddress} with the next RESPONSE of the handshake. */
    private void handleChallenge(SaslContext saslContext, SaslHeader saslHeader, GMS.GmsHeader gmsHeader, Address remoteAddress) {
        try {
            if (this.log.isTraceEnabled()) {
                this.log.trace("%s: received CHALLENGE from %s", this.getAddress(), remoteAddress);
            }
            Message response = saslContext.nextMessage(remoteAddress, saslHeader);
            if (this.log.isTraceEnabled()) {
                this.log.trace("%s: sending RESPONSE to %s", this.getAddress(), remoteAddress);
            }
            this.down_prot.down(new Event(Event.MSG, response));
        }
        catch (SaslException e) {
            this.disposeContext(remoteAddress);
            if (this.log.isWarnEnabled()) {
                this.log.warn("failed to validate CHALLENGE from " + remoteAddress + ", token", e);
            }
            // This path is only reached for messages that are not join/merge requests, so the GMS
            // header may be absent; dereferencing it unconditionally would throw a NullPointerException.
            if (gmsHeader != null) {
                this.sendRejectionMessage(gmsHeader.getType(), remoteAddress, "authentication failed");
            }
        }
    }

    /** Processes a RESPONSE from {@code remoteAddress}, sending the next CHALLENGE if the handshake is not finished. */
    private void handleResponse(SaslContext saslContext, SaslHeader saslHeader, Address remoteAddress) {
        try {
            if (this.log.isTraceEnabled()) {
                this.log.trace("%s: received RESPONSE from %s", this.getAddress(), remoteAddress);
            }
            Message challenge = saslContext.nextMessage(remoteAddress, saslHeader);
            if (challenge != null) {
                if (this.log.isTraceEnabled()) {
                    this.log.trace("%s: sending CHALLENGE to %s", this.getAddress(), remoteAddress);
                }
                this.down_prot.down(new Event(Event.MSG, challenge));
            } else if (this.log.isTraceEnabled()) {
                // No further challenge means there is nothing left to send.
                this.log.trace("%s: authentication complete from %s", this.getAddress(), remoteAddress);
            }
        }
        catch (SaslException e) {
            this.disposeContext(remoteAddress);
            if (this.log.isWarnEnabled()) {
                this.log.warn("failed to validate RESPONSE from " + remoteAddress + ", token", e);
            }
        }
    }

    /**
     * Unregisters the SASL context stored for {@code address}, if any, and disposes it.
     *
     * @param address the peer whose context should be released
     */
    private void disposeContext(Address address) {
        SaslContext removed = sasl_context.remove(address);
        if (removed == null) {
            return;
        }
        removed.dispose();
    }

    /**
     * Authenticates every GMS join or merge request contained in a batch before the batch is passed up.
     * <p>
     * A request without a SASL header is answered with a rejection and removed; a request whose
     * handshake fails is removed as well. The batch is passed up only if messages remain in it.
     * <p>
     * NOTE(review): unlike {@link #up(Event)}, this method does not consume CHALLENGE/RESPONSE
     * handshake messages; they stay in the batch and are passed up. Confirm whether handshake
     * messages can ever arrive batched.
     *
     * @param batch the batch received from the protocol below; rejected messages are removed from it
     */
    @Override
    public void up(MessageBatch batch) {
        for (Message msg : batch) {
            GMS.GmsHeader gmsHeader = (GMS.GmsHeader)msg.getHeader(GMS_ID);
            if (!SASL.needsAuthentication(gmsHeader)) {
                continue;
            }
            // Use the same header key as up(Event) so both delivery paths find the same header.
            SaslHeader saslHeader = (SaslHeader)msg.getHeader(SASL_ID);
            if (saslHeader == null) {
                this.log.warn("Found GMS join or merge request but no SASL header");
                this.sendRejectionMessage(gmsHeader.getType(), batch.sender(), "join or merge without an SASL header");
                batch.remove(msg);
            } else if (!this.serverChallenge(gmsHeader, saslHeader, msg)) {
                // Authentication failed: keep the request away from the protocols above.
                batch.remove(msg);
            }
        }
        if (!batch.isEmpty()) {
            this.up_prot.up(batch);
        }
    }

    /**
     * Handles events travelling down the stack.
     * <p>
     * Records the local address from {@link Event#SET_LOCAL_ADDRESS}. For an outgoing GMS join or
     * merge request, creates a client context for the destination, registers it, and lets it add
     * the initial SASL header to the message. Every event is then passed to the protocol below.
     *
     * @param evt the event received from the protocol above
     * @return the result of the protocol below
     * @throws SecurityException if the client context cannot be created or cannot produce its header
     */
    @Override
    public Object down(Event evt) {
        switch (evt.getType()) {
            case Event.SET_LOCAL_ADDRESS: {
                this.local_addr = (Address)evt.getArg();
                break;
            }
            case Event.MSG: {
                Message msg = (Message)evt.getArg();
                GMS.GmsHeader hdr = (GMS.GmsHeader)msg.getHeader(GMS_ID);
                if (!SASL.needsAuthentication(hdr)) {
                    break;
                }
                SaslClientContext ctx = null;
                Address remoteAddress = msg.getDest();
                try {
                    ctx = new SaslClientContext(this.mech, remoteAddress, this.callback_handler, this.sasl_props);
                    // A repeated request to the same peer replaces the earlier context; dispose the
                    // replaced one so it does not linger until the protocol is stopped.
                    SaslContext previous = this.sasl_context.put(remoteAddress, ctx);
                    if (previous != null) {
                        previous.dispose();
                    }
                    ctx.addHeader(msg, null);
                }
                catch (SaslException e) {
                    // Only unregister when our own context made it into the registry.
                    if (ctx != null) {
                        this.disposeContext(remoteAddress);
                    }
                    throw new SecurityException(e);
                }
                break;
            }
            default: {
                break;
            }
        }
        return this.down_prot.down(evt);
    }

    /**
     * Tells whether a GMS header denotes a request that must be authenticated: a join request,
     * a join request with state transfer, or a merge request.
     *
     * @param hdr the GMS header of a message, or {@code null} if the message has none
     * @return {@code true} for join and merge requests, {@code false} otherwise (including {@code null})
     */
    protected static boolean needsAuthentication(GMS.GmsHeader hdr) {
        if (hdr == null) {
            return false;
        }
        switch (hdr.getType()) {
            case GMS.GmsHeader.JOIN_REQ:
            case GMS.GmsHeader.JOIN_REQ_WITH_STATE_TRANSFER:
            case GMS.GmsHeader.MERGE_REQ:
                return true;
            default:
                return false;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     * Unable to fully structure code
     */
    protected boolean serverChallenge(GMS.GmsHeader gmsHeader, SaslHeader saslHeader, Message msg) {
        switch (gmsHeader.getType()) {
            case 1: 
            case 6: 
            case 11: {
                remoteAddress = msg.getSrc();
                ctx = null;
                ctx = new SaslServerContext(this.mech, this.local_addr, this.callback_handler, this.sasl_props);
                this.sasl_context.put(remoteAddress, ctx);
                this.getDownProtocol().down(new Event(1, ctx.nextMessage(remoteAddress, saslHeader)));
                ctx.awaitCompletion(this.timeout);
                if (!ctx.isSuccessful()) ** GOTO lbl19
                if (this.log.isDebugEnabled()) {
                    this.log.debug("Authorization successful for %s", new Object[]{ctx.getAuthorizationID()});
                }
                var6_6 = true;
                if (ctx != null && !ctx.needsWrapping()) {
                    this.disposeContext(remoteAddress);
                }
                return var6_6;
lbl19:
                // 1 sources

                try {
                    this.log.warn("failed to validate SaslHeader from %s, header: %s", new Object[]{msg.getSrc(), saslHeader});
                    this.sendRejectionMessage(gmsHeader.getType(), msg.getSrc(), "authentication failed");
                    var6_7 = false;
                    if (ctx != null && !ctx.needsWrapping()) {
                        this.disposeContext(remoteAddress);
                    }
                    return var6_7;
                }
                catch (SaslException e) {
                    this.log.warn("failed to validate SaslHeader from %s, header: %s", new Object[]{msg.getSrc(), saslHeader});
                    this.sendRejectionMessage(gmsHeader.getType(), msg.getSrc(), "authentication failed");
                    if (ctx == null || ctx.needsWrapping()) break;
                    this.disposeContext(remoteAddress);
                    break;
                }
                catch (InterruptedException e) {
                    var7_10 = false;
                    return var7_10;
                    {
                        catch (Throwable var8_11) {
                            throw var8_11;
                        }
                    }
                    finally {
                        if (ctx != null && !ctx.needsWrapping()) {
                            this.disposeContext(remoteAddress);
                        }
                    }
                }
            }
        }
        return true;
    }

    /**
     * Sends the rejection that matches the type of the rejected GMS request.
     *
     * @param type      the GMS header type of the rejected request
     * @param dest      the peer to notify
     * @param error_msg the reason placed in a join rejection (not used for merge rejections)
     */
    protected void sendRejectionMessage(byte type, Address dest, String error_msg) {
        if (type == GMS.GmsHeader.JOIN_REQ || type == GMS.GmsHeader.JOIN_REQ_WITH_STATE_TRANSFER) {
            sendJoinRejectionMessage(dest, error_msg);
            return;
        }
        if (type == GMS.GmsHeader.MERGE_REQ) {
            sendMergeRejectionMessage(dest);
            return;
        }
        // Any other type has no rejection message; just report it.
        log.error("type " + type + " unknown");
    }

    /**
     * Sends a join response carrying {@code error_msg} to {@code dest}; does nothing when
     * {@code dest} is {@code null}.
     *
     * @param dest      the peer whose join request is rejected
     * @param error_msg the failure reason wrapped in the {@link JoinRsp}
     */
    protected void sendJoinRejectionMessage(Address dest, String error_msg) {
        if (dest != null) {
            JoinRsp rejection = new JoinRsp(error_msg);
            Message reply = new Message(dest);
            reply.putHeader(GMS_ID, new GMS.GmsHeader(GMS.GmsHeader.JOIN_RSP));
            reply.setBuffer(GMS.marshal(rejection));
            down_prot.down(new Event(Event.MSG, reply));
        }
    }

    /**
     * Sends an out-of-band merge response flagged as rejected to {@code dest}; does nothing when
     * {@code dest} is {@code null}, matching {@link #sendJoinRejectionMessage(Address, String)}.
     *
     * @param dest the peer whose merge request is rejected
     */
    protected void sendMergeRejectionMessage(Address dest) {
        if (dest == null) {
            // Without a destination the rejection would not be addressed to the rejected peer.
            return;
        }
        Message msg = new Message(dest).setFlag(Message.Flag.OOB);
        GMS.GmsHeader hdr = new GMS.GmsHeader(GMS.GmsHeader.MERGE_RSP);
        hdr.setMergeRejected(true);
        msg.putHeader(GMS_ID, hdr);
        if (this.log.isDebugEnabled()) {
            this.log.debug("merge response=" + hdr);
        }
        this.down_prot.down(new Event(Event.MSG, msg));
    }
}

