/*
 * Decompiled with CFR 0.152.
 */
package org.apache.activemq.transport.ws.jetty9;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportAcceptListener;
import org.apache.activemq.transport.TransportSupport;
import org.apache.activemq.transport.util.HttpTransportUtils;
import org.apache.activemq.transport.ws.jetty9.MQTTSocket;
import org.apache.activemq.transport.ws.jetty9.StompSocket;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;

/**
 * WebSocket servlet that creates either an MQTT or a STOMP socket for each WebSocket upgrade
 * and hands the resulting transport to the {@link TransportAcceptListener} found in the
 * servlet context under the attribute name {@code "acceptListener"}.
 *
 * <p>If any requested sub-protocol starts with {@code "mqtt"}, an {@link MQTTSocket} is
 * created. Otherwise a {@link StompSocket} is created. The accepted sub-protocol sent back to
 * the client is the highest-priority requested protocol known for that family. If none
 * matches, it is the family's default ({@code "mqtt"} or {@code "stomp"}).
 */
public class WSServlet
extends WebSocketServlet {
    private static final long serialVersionUID = -4716657876092884139L;

    /** Sub-protocol prefix whose presence selects the MQTT transport. */
    private static final String MQTT_PREFIX = "mqtt";

    /** Listener that receives every newly created socket; set in {@link #init()}. */
    private TransportAcceptListener listener;

    /** Known STOMP sub-protocols mapped to their priority (a higher value is preferred). */
    private static final Map<String, Integer> stompProtocols = new ConcurrentHashMap<String, Integer>();
    /** Known MQTT sub-protocols mapped to their priority (a higher value is preferred). */
    private static final Map<String, Integer> mqttProtocols = new ConcurrentHashMap<String, Integer>();

    static {
        stompProtocols.put("v12.stomp", 3);
        stompProtocols.put("v11.stomp", 2);
        stompProtocols.put("v10.stomp", 1);
        stompProtocols.put("stomp", 0);
        mqttProtocols.put("mqttv3.1", 1);
        mqttProtocols.put("mqtt", 0);
    }

    /**
     * Initializes the servlet and looks up the accept listener from the servlet context.
     *
     * @throws ServletException if the {@code "acceptListener"} context attribute is missing
     */
    @Override
    public void init() throws ServletException {
        super.init();
        this.listener = (TransportAcceptListener) this.getServletContext().getAttribute("acceptListener");
        if (this.listener == null) {
            throw new ServletException("No such attribute 'acceptListener' available in the ServletContext");
        }
    }

    /**
     * Forwards GET requests that reach this method to the container's servlet named
     * {@code "default"}.
     *
     * @throws ServletException if no servlet named {@code "default"} is registered, or if the
     *     forwarded servlet fails
     */
    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        RequestDispatcher dispatcher = this.getServletContext().getNamedDispatcher("default");
        if (dispatcher == null) {
            // getNamedDispatcher returns null when no such servlet exists; fail with a clear
            // message instead of a bare NullPointerException.
            throw new ServletException("No servlet named 'default' available in the ServletContext to forward the request to");
        }
        dispatcher.forward(request, response);
    }

    /**
     * Installs a creator that builds an MQTT or STOMP socket for each upgrade request, selects
     * the accepted sub-protocol, and registers the socket with the accept listener.
     */
    @Override
    public void configure(WebSocketServletFactory factory) {
        factory.setCreator(new WebSocketCreator() {

            @Override
            public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
                List<String> requested = req.getSubProtocols();
                String remoteAddress = HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest());
                TransportSupport socket;
                if (isMqttRequest(requested)) {
                    socket = new MQTTSocket(remoteAddress);
                    resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols, requested, "mqtt"));
                    socket.setPeerCertificates(req.getCertificates());
                } else {
                    StompSocket stompSocket = new StompSocket(remoteAddress);
                    stompSocket.setCertificates(req.getCertificates());
                    socket = stompSocket;
                    resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols, requested, "stomp"));
                }
                WSServlet.this.listener.onAccept(socket);
                return socket;
            }
        });
    }

    /**
     * Returns {@code true} if any requested sub-protocol starts with {@code "mqtt"}.
     * A {@code null} list and {@code null} entries are tolerated.
     */
    private static boolean isMqttRequest(List<String> subProtocols) {
        if (subProtocols == null) {
            return false;
        }
        for (String subProtocol : subProtocols) {
            if (subProtocol != null && subProtocol.startsWith(MQTT_PREFIX)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Picks the requested sub-protocol with the highest priority in {@code protocols}.
     *
     * @param protocols known sub-protocols mapped to priority (a higher value wins)
     * @param subProtocols sub-protocols requested by the client, in request order; may be
     *     {@code null} or contain {@code null} entries
     * @param defaultProtocol value returned when no requested sub-protocol is known
     * @return the highest-priority known requested sub-protocol (the earliest one wins on a
     *     tie), or {@code defaultProtocol} if none is known
     */
    private static String getAcceptedSubProtocol(Map<String, Integer> protocols, List<String> subProtocols, String defaultProtocol) {
        if (subProtocols == null) {
            return defaultProtocol;
        }
        String best = null;
        int bestPriority = Integer.MIN_VALUE;
        for (String subProtocol : subProtocols) {
            // Check for null BEFORE the lookup: ConcurrentHashMap.get(null) throws NPE.
            if (subProtocol == null) {
                continue;
            }
            Integer priority = protocols.get(subProtocol);
            // Strict '>' keeps the earliest entry on ties (same as the former stable sort).
            if (priority != null && (best == null || priority > bestPriority)) {
                best = subProtocol;
                bestPriority = priority;
            }
        }
        return best != null ? best : defaultProtocol;
    }
}

