/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tomcat.websocket.server;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.CloseReason;
import javax.websocket.DeploymentException;
import javax.websocket.Encoder;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.catalina.ThreadBindingListener;
import org.apache.tomcat.websocket.ClassIntrospecter;
import org.apache.tomcat.websocket.Constants;
import org.apache.tomcat.websocket.InstanceFactory;
import org.apache.tomcat.websocket.WsSession;
import org.apache.tomcat.websocket.WsWebSocketContainer;
import org.apache.tomcat.websocket.pojo.PojoEndpointServer;
import org.apache.tomcat.websocket.pojo.PojoMethodMapping;
import org.apache.tomcat.websocket.server.UpgradeUtil;
import org.apache.tomcat.websocket.server.UriTemplate;
import org.apache.tomcat.websocket.server.WsFilter;
import org.apache.tomcat.websocket.server.WsMappingResult;
import org.apache.tomcat.websocket.server.WsWriteTimeout;
import org.jboss.web.WebsocketsLogger;
import org.jboss.web.WebsocketsMessages;

public class WsServerContainer
extends WsWebSocketContainer
implements ServerContainer {
    /** Close reason sent to WebSocket sessions closed through {@link #closeAuthenticatedSession(String)}. */
    private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED = new CloseReason(CloseReason.CloseCodes.VIOLATED_POLICY, WebsocketsMessages.MESSAGES.expiredHttpSession());
    /** Write timeout helper exposed through {@link #getTimeout()}. */
    private final WsWriteTimeout wsWriteTimeout = new WsWriteTimeout();
    /** Servlet context of the web application this container belongs to. */
    private final ServletContext servletContext;
    /** Endpoint configurations whose path has no template parameters, keyed by that path. */
    private final Map<String, ServerEndpointConfig> configExactMatchMap = new ConcurrentHashMap<String, ServerEndpointConfig>();
    /** Endpoint configurations with templated paths, grouped by the segment count of the template. */
    private final ConcurrentHashMap<Integer, SortedSet<TemplatePathMatch>> configTemplateMatchMap = new ConcurrentHashMap<Integer, SortedSet<TemplatePathMatch>>();
    /** When {@code true}, endpoints may no longer be added once {@link #addAllowed} has been cleared. */
    private volatile boolean enforceNoAddAfterHandshake = Constants.STRICT_SPEC_COMPLIANCE;
    /** Cleared by the first call to {@link #findMapping(String)}. */
    private volatile boolean addAllowed = true;
    /** WebSocket sessions that have a user principal, grouped by the id of their HTTP session. */
    private final ConcurrentHashMap<String, Set<WsSession>> authenticatedSessions = new ConcurrentHashMap<String, Set<WsSession>>();
    /** Executor created in the constructor and shut down by {@link #destroy()}. */
    private final ExecutorService executorService;
    /** Thread group that owns every thread created for {@link #executorService}. */
    private final ThreadGroup threadGroup;
    /** Set to {@code true} once an endpoint has been added successfully. */
    private volatile boolean endpointsRegistered = false;
    /** Optional introspecter read from the servlet context attributes; may be {@code null}. */
    private final ClassIntrospecter classIntrospecter;
    /** Listener read from the servlet context attributes, or the default listener when none was supplied. */
    private final ThreadBindingListener threadBindingListener;

    /**
     * Creates the WebSocket server container for one web application.
     * <p>
     * The constructor reads (and then removes) the optional
     * {@link ClassIntrospecter} and {@link ThreadBindingListener} servlet
     * context attributes, applies the {@code org.apache.tomcat.websocket.*}
     * context init parameters, registers a {@link WsFilter} mapped to
     * {@code /*} and creates the executor together with its thread group.
     *
     * @param servletContext the servlet context of the web application
     * @throws NumberFormatException if a numeric init parameter cannot be parsed
     */
    WsServerContainer(ServletContext servletContext) {
        this.servletContext = servletContext;

        // Integration hooks are handed over as context attributes and removed
        // once read. A missing introspecter simply stays null.
        this.classIntrospecter = (ClassIntrospecter)servletContext.getAttribute(ClassIntrospecter.class.getName());
        servletContext.removeAttribute(ClassIntrospecter.class.getName());
        ThreadBindingListener listener = (ThreadBindingListener)servletContext.getAttribute(ThreadBindingListener.class.getName());
        servletContext.removeAttribute(ThreadBindingListener.class.getName());
        this.threadBindingListener = listener != null ? listener : DEFAULT_THREAD_BINDING_LISTENER;

        // Buffer sizes and the "no add after handshake" switch.
        String value = servletContext.getInitParameter("org.apache.tomcat.websocket.binaryBufferSize");
        if (value != null) {
            this.setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value));
        }
        value = servletContext.getInitParameter("org.apache.tomcat.websocket.textBufferSize");
        if (value != null) {
            this.setDefaultMaxTextMessageBufferSize(Integer.parseInt(value));
        }
        value = servletContext.getInitParameter("org.apache.tomcat.websocket.noAddAfterHandshake");
        if (value != null) {
            this.setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value));
        }

        // Executor sizing, with defaults used when no init parameter is set.
        int executorCoreSize = 0;
        int executorMaxSize = 200;
        long executorKeepAliveTimeSeconds = 60L;
        value = servletContext.getInitParameter("org.apache.tomcat.websocket.executorCoreSize");
        if (value != null) {
            executorCoreSize = Integer.parseInt(value);
        }
        value = servletContext.getInitParameter("org.apache.tomcat.websocket.executorMaxSize");
        if (value != null) {
            executorMaxSize = Integer.parseInt(value);
        }
        value = servletContext.getInitParameter("org.apache.tomcat.websocket.executorKeepAliveTimeSeconds");
        if (value != null) {
            executorKeepAliveTimeSeconds = Long.parseLong(value);
        }

        // Route every REQUEST and FORWARD dispatch of the application through
        // the WebSocket filter.
        FilterRegistration.Dynamic fr = servletContext.addFilter("Tomcat WebSocket (JSR356) Filter", (Filter)new WsFilter());
        fr.setAsyncSupported(true);
        EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD);
        fr.addMappingForUrlPatterns(types, true, new String[]{"/*"});

        // The thread group is named after the context path; the root context
        // has an empty path and is labelled ROOT instead.
        StringBuilder threadGroupName = new StringBuilder("WebSocketServer-");
        String contextPath = servletContext.getContextPath();
        if ("".equals(contextPath)) {
            threadGroupName.append("ROOT");
        } else {
            threadGroupName.append(contextPath);
        }
        this.threadGroup = new ThreadGroup(threadGroupName.toString());
        WsThreadFactory wsThreadFactory = new WsThreadFactory(this.threadGroup);

        // NOTE(review): ThreadPoolExecutor only grows beyond its core size when
        // the work queue rejects a task. With an unbounded LinkedBlockingQueue
        // that never happens, so executorMaxSize appears to have no effect
        // here - confirm whether that is intended.
        this.executorService = new ThreadPoolExecutor(executorCoreSize, executorMaxSize, executorKeepAliveTimeSeconds, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(), wsThreadFactory);
    }

    /**
     * Registers a programmatic endpoint configuration.
     * <p>
     * Paths containing template parameters are stored in a sorted set shared
     * by all templates with the same segment count; all other paths are stored
     * in the exact match map. In both cases the first registration for a path
     * wins and a later duplicate is rejected without replacing it.
     *
     * @param sec the endpoint configuration to register
     * @throws DeploymentException if registrations are no longer allowed, if
     *                             no servlet context is available, if the path
     *                             is not a valid URI template, or if the path
     *                             is already registered
     */
    @Override
    public void addEndpoint(ServerEndpointConfig sec) throws DeploymentException {
        if (this.enforceNoAddAfterHandshake && !this.addAllowed) {
            throw new DeploymentException(WebsocketsMessages.MESSAGES.addNotAllowed());
        }
        if (this.servletContext == null) {
            throw new DeploymentException(WebsocketsMessages.MESSAGES.missingServletContext());
        }
        String path = sec.getPath();
        UriTemplate uriTemplate = new UriTemplate(path);
        if (uriTemplate.hasParameters()) {
            Integer key = uriTemplate.getSegmentCount();
            SortedSet<TemplatePathMatch> templateMatches = this.configTemplateMatchMap.get(key);
            if (templateMatches == null) {
                // Use whichever set ended up in the map if another thread
                // registered one for this segment count first.
                SortedSet<TemplatePathMatch> created = new TreeSet<TemplatePathMatch>(TemplatePathMatchComparator.getInstance());
                SortedSet<TemplatePathMatch> existing = this.configTemplateMatchMap.putIfAbsent(key, created);
                templateMatches = existing != null ? existing : created;
            }
            if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) {
                throw new DeploymentException(WebsocketsMessages.MESSAGES.duplicatePaths(path));
            }
        } else {
            // Check before storing so that a rejected duplicate cannot replace
            // the endpoint that is already mapped to this path. Writers are
            // serialised to keep the check-then-put atomic; lookups remain
            // lock free.
            synchronized (this.configExactMatchMap) {
                if (this.configExactMatchMap.containsKey(path)) {
                    throw new DeploymentException(WebsocketsMessages.MESSAGES.duplicatePaths(path));
                }
                this.configExactMatchMap.put(path, sec);
            }
        }
        this.endpointsRegistered = true;
    }

    /**
     * Registers an annotated (POJO) endpoint.
     * <p>
     * The class must carry {@link ServerEndpoint}. Its encoders are validated,
     * the method mapping is built, a configurator is selected and the
     * resulting {@link ServerEndpointConfig} is passed to
     * {@link #addEndpoint(ServerEndpointConfig)}.
     *
     * @param pojo the annotated endpoint class
     * @throws DeploymentException if the class is not annotated, an encoder or
     *                             the configurator cannot be created, or the
     *                             resulting configuration is rejected
     */
    @Override
    public void addEndpoint(Class<?> pojo) throws DeploymentException {
        ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class);
        if (annotation == null) {
            throw new DeploymentException(WebsocketsMessages.MESSAGES.cannotDeployPojo(pojo.getName()));
        }
        String path = annotation.value();
        WsServerContainer.validateEncoders(annotation.encoders());
        PojoMethodMapping methodMapping = new PojoMethodMapping(pojo, annotation.decoders(), path);

        // Instantiate the configurator only when the annotation names one
        // other than the default Configurator class.
        Class<? extends ServerEndpointConfig.Configurator> configuratorClazz = annotation.configurator();
        ServerEndpointConfig.Configurator configurator = null;
        if (!configuratorClazz.equals(ServerEndpointConfig.Configurator.class)) {
            try {
                configurator = configuratorClazz.newInstance();
            }
            catch (InstantiationException e) {
                // Report the endpoint class itself; pojo.getClass() would
                // always be java.lang.Class.
                throw new DeploymentException(WebsocketsMessages.MESSAGES.configuratorFailed(configuratorClazz.getName(), pojo.getName()), e);
            }
            catch (IllegalAccessException e) {
                throw new DeploymentException(WebsocketsMessages.MESSAGES.configuratorFailed(configuratorClazz.getName(), pojo.getName()), e);
            }
        }

        // NOTE(review): when a class introspecter is present this replaces any
        // configurator created above, so a custom configurator declared on the
        // annotation is not used - confirm that this is intended.
        if (this.classIntrospecter != null) {
            try {
                configurator = new ServerInstanceFactoryConfigurator(this.classIntrospecter.createInstanceFactory(pojo));
            }
            catch (NoSuchMethodException e) {
                throw new DeploymentException(WebsocketsMessages.MESSAGES.configuratorFailed(ServerInstanceFactoryConfigurator.class.getName(), pojo.getName()), e);
            }
        }

        ServerEndpointConfig sec = ServerEndpointConfig.Builder.create(pojo, path).decoders(Arrays.asList(annotation.decoders())).encoders(Arrays.asList(annotation.encoders())).subprotocols(Arrays.asList(annotation.subprotocols())).configurator(configurator).build();
        sec.getUserProperties().put("org.apache.tomcat.websocket.pojo.PojoEndpoint.methodMapping", methodMapping);
        this.addEndpoint(sec);
    }

    /**
     * Shuts the container down: stops the executor, runs the superclass
     * clean-up and finally tries to destroy the executor's thread group.
     * A thread group that cannot be destroyed is only logged.
     */
    @Override
    public void destroy() {
        shutdownExecutor();
        super.destroy();

        final ThreadGroup group = this.threadGroup;
        try {
            group.destroy();
        } catch (IllegalThreadStateException stillInUse) {
            WebsocketsLogger.ROOT_LOGGER.threadGroupNotDestryed(group.getName());
        }
    }

    /**
     * Reports whether an endpoint has been added to this container.
     *
     * @return {@code true} once a call to {@code addEndpoint} has completed
     *         successfully
     */
    boolean areEndpointsRegistered() {
        return endpointsRegistered;
    }

    /**
     * Hands the HTTP request over to {@link UpgradeUtil#doUpgrade}, passing
     * this container along with the supplied arguments.
     *
     * @param request    the HTTP request to upgrade
     * @param response   the HTTP response associated with the request
     * @param sec        the endpoint configuration selected for the request
     * @param pathParams the path parameters for the request
     * @throws ServletException if thrown by the upgrade helper
     * @throws IOException      if thrown by the upgrade helper
     */
    public void doUpgrade(final HttpServletRequest request, final HttpServletResponse response, final ServerEndpointConfig sec, final Map<String, String> pathParams) throws ServletException, IOException {
        UpgradeUtil.doUpgrade(this, request, response, sec, pathParams);
    }

    /**
     * Looks up the endpoint configuration registered for a request path.
     * <p>
     * Exact matches are tried first and yield an empty parameter map.
     * Otherwise the path is compared against the registered templates with
     * the same segment count, in the order of their sorted set, and the first
     * template that matches is used.
     *
     * @param path the request path to resolve
     * @return the mapping for the path, or {@code null} when the path cannot
     *         be parsed or nothing matches
     */
    public WsMappingResult findMapping(String path) {
        // The first lookup clears the flag that addEndpoint consults when
        // enforceNoAddAfterHandshake is enabled.
        if (this.addAllowed) {
            this.addAllowed = false;
        }

        ServerEndpointConfig exactMatch = this.configExactMatchMap.get(path);
        if (exactMatch != null) {
            return new WsMappingResult(exactMatch, Collections.<String, String>emptyMap());
        }

        UriTemplate requested;
        try {
            requested = new UriTemplate(path);
        } catch (DeploymentException e) {
            // A path that cannot be parsed cannot match any template.
            return null;
        }

        Integer segmentCount = requested.getSegmentCount();
        SortedSet<TemplatePathMatch> candidates = this.configTemplateMatchMap.get(segmentCount);
        if (candidates == null) {
            return null;
        }

        for (TemplatePathMatch candidate : candidates) {
            Map<String, String> pathParams = candidate.getUriTemplate().match(requested);
            if (pathParams == null) {
                continue;
            }
            ServerEndpointConfig matched = candidate.getConfig();
            if (matched == null) {
                return null;
            }
            // Endpoint classes that are not PojoEndpointServer subtypes get
            // the path parameters stored in the config's user properties.
            if (!PojoEndpointServer.class.isAssignableFrom(matched.getEndpointClass())) {
                matched.getUserProperties().put("org.apache.tomcat.websocket.pojo.PojoEndpoint.pathParams", pathParams);
            }
            return new WsMappingResult(matched, pathParams);
        }
        return null;
    }

    /**
     * Returns the current value of the flag that makes {@code addEndpoint}
     * reject registrations once adding is no longer allowed.
     *
     * @return the current setting
     */
    public boolean isEnforceNoAddAfterHandshake() {
        return enforceNoAddAfterHandshake;
    }

    /**
     * Changes the flag that makes {@code addEndpoint} reject registrations
     * once adding is no longer allowed.
     *
     * @param enforceNoAddAfterHandshake the new setting
     */
    public void setEnforceNoAddAfterHandshake(final boolean enforceNoAddAfterHandshake) {
        this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake;
    }

    /**
     * Returns the write timeout helper owned by this container.
     *
     * @return the {@link WsWriteTimeout} created together with this container
     */
    protected WsWriteTimeout getTimeout() {
        return wsWriteTimeout;
    }

    /**
     * Registers the session with the superclass and, when the session is
     * open, has a user principal and an HTTP session id, also tracks it under
     * that HTTP session id.
     *
     * @param endpoint  the endpoint the session belongs to
     * @param wsSession the session being registered
     */
    @Override
    protected void registerSession(Endpoint endpoint, WsSession wsSession) {
        super.registerSession(endpoint, wsSession);

        if (!wsSession.isOpen() || wsSession.getUserPrincipal() == null) {
            return;
        }
        String httpSessionId = wsSession.getHttpSessionId();
        if (httpSessionId != null) {
            registerAuthenticatedSession(wsSession, httpSessionId);
        }
    }

    /**
     * Stops tracking the session under its HTTP session id (when it has a
     * user principal and an HTTP session id) and then unregisters it from the
     * superclass.
     *
     * @param endpoint  the endpoint the session belongs to
     * @param wsSession the session being unregistered
     */
    @Override
    protected void unregisterSession(Endpoint endpoint, WsSession wsSession) {
        if (wsSession.getUserPrincipal() != null) {
            String httpSessionId = wsSession.getHttpSessionId();
            if (httpSessionId != null) {
                unregisterAuthenticatedSession(wsSession, httpSessionId);
            }
        }
        super.unregisterSession(endpoint, wsSession);
    }

    /**
     * Adds the WebSocket session to the set tracked for the given HTTP
     * session id, creating that set on first use.
     *
     * @param wsSession     the WebSocket session to track
     * @param httpSessionId the id of the HTTP session it belongs to
     */
    private void registerAuthenticatedSession(WsSession wsSession, String httpSessionId) {
        Set<WsSession> wsSessions = this.authenticatedSessions.get(httpSessionId);
        if (wsSessions == null) {
            Set<WsSession> created = Collections.newSetFromMap(new ConcurrentHashMap<WsSession, Boolean>());
            // Take the winner from putIfAbsent's return value. Reading the map
            // again could yield null if closeAuthenticatedSession removed the
            // entry in between, which would fail on the add below.
            Set<WsSession> existing = this.authenticatedSessions.putIfAbsent(httpSessionId, created);
            wsSessions = existing != null ? existing : created;
        }
        wsSessions.add(wsSession);
    }

    /**
     * Removes the WebSocket session from the set tracked for the given HTTP
     * session id; does nothing when no set exists for that id.
     *
     * @param wsSession     the WebSocket session to stop tracking
     * @param httpSessionId the id of the HTTP session it belongs to
     */
    private void unregisterAuthenticatedSession(WsSession wsSession, String httpSessionId) {
        Set<WsSession> tracked = this.authenticatedSessions.get(httpSessionId);
        if (tracked == null) {
            return;
        }
        tracked.remove(wsSession);
    }

    /**
     * Closes every WebSocket session tracked for the given HTTP session id
     * and drops the tracking entry for that id.
     *
     * @param httpSessionId the id of the HTTP session whose WebSocket sessions
     *                      must be closed
     */
    public void closeAuthenticatedSession(String httpSessionId) {
        Set<WsSession> tracked = this.authenticatedSessions.remove(httpSessionId);
        if (tracked == null) {
            return;
        }
        for (WsSession session : tracked) {
            try {
                session.close(AUTHENTICATED_HTTP_SESSION_CLOSED);
            } catch (IOException ignored) {
                // Best effort: a failure to close one session must not stop
                // the remaining sessions from being closed.
            }
        }
    }

    /**
     * Returns the executor created for this container.
     *
     * @return the executor service built in the constructor
     */
    ExecutorService getExecutorService() {
        return executorService;
    }

    /**
     * Returns the thread binding listener selected in the constructor: the
     * one supplied through the servlet context attributes, or the default
     * listener when none was supplied.
     *
     * @return the listener in use; never {@code null} as long as the default
     *         listener is not {@code null}
     */
    @Override
    public ThreadBindingListener getThreadBindingListener() {
        return threadBindingListener;
    }

    /**
     * Returns the class loader reported by the servlet context of this
     * container.
     *
     * @return the result of {@code ServletContext.getClassLoader()}
     */
    @Override
    public ClassLoader getClassLoader() {
        return servletContext.getClassLoader();
    }

    /**
     * Requests an orderly shutdown of the executor and waits up to ten
     * seconds for running tasks to finish. If the waiting thread is
     * interrupted, the wait ends early and the interrupt status is restored
     * so that callers further up the stack can still observe it.
     */
    private void shutdownExecutor() {
        if (this.executorService == null) {
            return;
        }
        this.executorService.shutdown();
        try {
            this.executorService.awaitTermination(10L, TimeUnit.SECONDS);
        }
        catch (InterruptedException e) {
            // Catching InterruptedException clears the flag; set it again
            // rather than silently discarding the interruption request.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Verifies that every encoder class can be instantiated through its
     * no-argument constructor. The created instances are discarded.
     *
     * @param encoders the encoder classes to check
     * @throws DeploymentException if an encoder cannot be instantiated; the
     *                             original failure is kept as the cause
     */
    private static void validateEncoders(Class<? extends Encoder>[] encoders) throws DeploymentException {
        for (int i = 0; i < encoders.length; i++) {
            Class<? extends Encoder> candidate = encoders[i];
            try {
                candidate.newInstance();
            } catch (InstantiationException cause) {
                throw new DeploymentException(WebsocketsMessages.MESSAGES.cannotInstatiateEncoder(candidate.getName()), cause);
            } catch (IllegalAccessException cause) {
                throw new DeploymentException(WebsocketsMessages.MESSAGES.cannotInstatiateEncoder(candidate.getName()), cause);
            }
        }
    }

    /**
     * Configurator that obtains endpoint instances from an
     * {@link InstanceFactory} instead of instantiating the endpoint class
     * itself.
     */
    private static final class ServerInstanceFactoryConfigurator extends ServerEndpointConfig.Configurator {
        /** Factory every endpoint instance is requested from. */
        private final InstanceFactory instanceFactory;

        private ServerInstanceFactoryConfigurator(InstanceFactory factory) {
            this.instanceFactory = factory;
        }

        /**
         * Returns whatever the factory creates; {@code endpointClass} is not
         * consulted.
         *
         * @param endpointClass the requested endpoint type (unused)
         * @return the instance produced by the factory, cast to {@code T}
         * @throws InstantiationException declared by the overridden method
         */
        @Override
        public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationException {
            Object created = this.instanceFactory.createInstance();
            return (T)created;
        }
    }

    /**
     * Thread factory that creates every thread inside one thread group and
     * names it {@code <group name>-<sequence number>}, with sequence numbers
     * starting at 1.
     */
    private static class WsThreadFactory implements ThreadFactory {
        /** Group all created threads belong to. */
        private final ThreadGroup group;
        /** Number of threads handed out so far. */
        private final AtomicLong sequence = new AtomicLong();

        private WsThreadFactory(ThreadGroup tg) {
            this.group = tg;
        }

        @Override
        public Thread newThread(Runnable r) {
            Thread worker = new Thread(group, r);
            String workerName = group.getName() + "-" + sequence.incrementAndGet();
            worker.setName(workerName);
            return worker;
        }
    }

    /**
     * Singleton comparator that orders template matches by comparing the
     * normalized paths of their URI templates.
     */
    private static class TemplatePathMatchComparator implements Comparator<TemplatePathMatch> {
        private static final TemplatePathMatchComparator INSTANCE = new TemplatePathMatchComparator();

        private TemplatePathMatchComparator() {
        }

        /**
         * @return the shared comparator instance
         */
        public static TemplatePathMatchComparator getInstance() {
            return INSTANCE;
        }

        @Override
        public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) {
            String left = tpm1.getUriTemplate().getNormalizedPath();
            String right = tpm2.getUriTemplate().getNormalizedPath();
            return left.compareTo(right);
        }
    }

    /**
     * Immutable pair of an endpoint configuration and the URI template built
     * from its path.
     */
    private static class TemplatePathMatch {
        /** Endpoint configuration registered for the template. */
        private final ServerEndpointConfig config;
        /** Template the request path is matched against. */
        private final UriTemplate uriTemplate;

        public TemplatePathMatch(final ServerEndpointConfig config, final UriTemplate uriTemplate) {
            this.config = config;
            this.uriTemplate = uriTemplate;
        }

        /**
         * @return the endpoint configuration of this entry
         */
        public ServerEndpointConfig getConfig() {
            return config;
        }

        /**
         * @return the URI template of this entry
         */
        public UriTemplate getUriTemplate() {
            return uriTemplate;
        }
    }
}

