/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package org.jboss.netty.channel.socket.http;

import java.io.EOFException;
import java.io.IOException;
import java.io.PushbackInputStream;
import java.net.SocketAddress;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.local.DefaultLocalClientChannelFactory;
import org.jboss.netty.channel.local.LocalAddress;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.logging.InternalLogger;
import org.jboss.netty.logging.InternalLoggerFactory;

/**
 * An {@link HttpServlet} that proxies incoming data to the actual server
 * and vice versa.  Please refer to the
 * <a href="package-summary.html#package_description">package summary</a> for
 * the detailed usage.
 * @apiviz.landmark
 */
public class HttpTunnelingServlet extends HttpServlet {

    private static final long serialVersionUID = 4259910275899756070L;

    private static final String ENDPOINT = "endpoint";
    private static final String CONNECT_ATTEMPTS = "connectAttempts";
    private static final String RETRY_DELAY = "retryDelay";

    static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpTunnelingServlet.class);

    /** Address every tunneled request is relayed to; set by {@link #init()}. */
    private volatile SocketAddress remoteAddress;
    /** Factory for the channels that connect to {@link #remoteAddress}; set by {@link #init()}. */
    private volatile ChannelFactory channelFactory;
    /** Maximum connection attempts per request (init-param {@code connectAttempts}, &gt;= 1). */
    private volatile long connectAttempts = 1;
    /** Milliseconds to sleep between failed connection attempts (init-param {@code retryDelay}, &gt;= 0). */
    private volatile long retryDelay;

    /**
     * Reads the servlet's init-params and creates the channel factory.
     * <ul>
     * <li>{@code endpoint} (required): the address to relay to, parsed by
     *     {@link #parseEndpoint(String)}</li>
     * <li>{@code connectAttempts} (optional, default 1): must be &gt;= 1</li>
     * <li>{@code retryDelay} (optional, milliseconds, default 0): must be &gt;= 0</li>
     * </ul>
     * All parameters are validated before the channel factory is created, so a
     * bad parameter never leaves a factory behind (the container does not call
     * {@link #destroy()} when {@code init()} fails).
     *
     * @throws ServletException if a parameter is missing or invalid, or the
     *         channel factory cannot be created
     */
    @Override
    public void init() throws ServletException {
        ServletConfig config = getServletConfig();
        String endpoint = config.getInitParameter(ENDPOINT);
        if (endpoint == null) {
            throw new ServletException("init-param '" + ENDPOINT + "' must be specified.");
        }

        SocketAddress address;
        try {
            address = parseEndpoint(endpoint.trim());
        } catch (ServletException e) {
            throw e;
        } catch (Exception e) {
            throw new ServletException("Failed to parse an endpoint.", e);
        }

        long attempts = parseLongInitParameter(config, CONNECT_ATTEMPTS, connectAttempts, 1);
        long delay = parseLongInitParameter(config, RETRY_DELAY, retryDelay, 0);

        ChannelFactory factory;
        try {
            factory = createChannelFactory(address);
        } catch (ServletException e) {
            throw e;
        } catch (Exception e) {
            throw new ServletException("Failed to create a channel factory.", e);
        }

        remoteAddress = address;
        connectAttempts = attempts;
        retryDelay = delay;
        channelFactory = factory;
    }

    /**
     * Reads an optional {@code long} init-param.
     *
     * @param config       the servlet configuration to read from
     * @param name         the init-param name
     * @param defaultValue the value returned when the parameter is absent
     * @param minValue     the smallest accepted value
     * @return the parsed value, or {@code defaultValue} if the parameter is not set
     * @throws ServletException if the value is not a number or is below {@code minValue}
     */
    private static long parseLongInitParameter(
            ServletConfig config, String name, long defaultValue, long minValue) throws ServletException {
        String value = config.getInitParameter(name);
        if (value == null) {
            return defaultValue;
        }

        long parsed;
        try {
            parsed = Long.parseLong(value);
        } catch (NumberFormatException e) {
            throw new ServletException(
                   "init-param '" + name + "' is not a valid number. Actual value: " + value, e);
        }
        if (parsed < minValue) {
            throw new ServletException(
                   "init-param '" + name + "' must be >= " + minValue + ". Actual value: " + parsed);
        }
        return parsed;
    }

    /**
     * Converts the {@code endpoint} init-param into a remote address.  Only
     * {@code local:<name>} endpoints are supported here; subclasses may
     * override this to support other schemes.
     *
     * @param endpoint the trimmed {@code endpoint} init-param value
     * @return the address to connect to
     * @throws ServletException if the endpoint is not recognized
     * @throws Exception        if a subclass fails to parse the endpoint
     */
    protected SocketAddress parseEndpoint(String endpoint) throws Exception {
        if (endpoint.startsWith("local:")) {
            return new LocalAddress(endpoint.substring(6).trim());
        } else {
            throw new ServletException(
                    "Invalid or unknown endpoint: " + endpoint);
        }
    }

    /**
     * Creates the factory used to open a channel to {@code remoteAddress} for
     * each request.  Only {@link LocalAddress} is supported here; subclasses
     * may override this.
     *
     * @param remoteAddress the address returned by {@link #parseEndpoint(String)}
     * @return a channel factory able to connect to {@code remoteAddress}
     * @throws ServletException if the address type is not supported
     * @throws Exception        if a subclass fails to create the factory
     */
    protected ChannelFactory createChannelFactory(SocketAddress remoteAddress) throws Exception {
        if (remoteAddress instanceof LocalAddress) {
            return new DefaultLocalClientChannelFactory();
        } else {
            throw new ServletException(
                    "Unsupported remote address type: " +
                    remoteAddress.getClass().getName());
        }
    }

    /**
     * Releases the channel factory.  Failures are logged, not rethrown.  Does
     * nothing if {@link #init()} never created a factory.
     */
    @Override
    public void destroy() {
        ChannelFactory factory = channelFactory;
        if (factory == null) {
            // init() failed or was never called; there is nothing to release.
            return;
        }
        try {
            destroyChannelFactory(factory);
        } catch (Exception e) {
            if (logger.isWarnEnabled()) {
                logger.warn("Failed to destroy a channel factory.", e);
            }
        }
    }

    /**
     * Releases the resources held by {@code factory}.
     *
     * @param factory the factory created by {@link #createChannelFactory(SocketAddress)}
     * @throws Exception if the release fails
     */
    protected void destroyChannelFactory(ChannelFactory factory) throws Exception {
        factory.releaseExternalResources();
    }

    /**
     * Tunnels one HTTP request.  Only POST is accepted.  The request body is
     * streamed into a freshly connected channel, and every message received
     * from that channel is written to the response body.  Responds with 405
     * for other methods and with 503 if the endpoint cannot be reached.
     */
    @Override
    protected void service(HttpServletRequest req, HttpServletResponse res)
            throws ServletException, IOException {
        if (!"POST".equalsIgnoreCase(req.getMethod())) {
            if (logger.isWarnEnabled()) {
                logger.warn("Unallowed method: " + req.getMethod());
            }
            res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
            return;
        }

        final ChannelPipeline pipeline = Channels.pipeline();
        final ServletOutputStream out = res.getOutputStream();
        final OutboundConnectionHandler handler = new OutboundConnectionHandler(out);
        pipeline.addLast("handler", handler);

        Channel channel = channelFactory.newChannel(pipeline);
        ChannelFuture future = connectWithRetries(channel);

        if (!future.isSuccess()) {
            // The channel never got connected; release it instead of leaking it.
            channel.close();
            if (logger.isWarnEnabled()) {
                Throwable cause = future.getCause();
                if (cause != null) {
                    logger.warn("Endpoint unavailable: " + cause.getMessage(), cause);
                } else {
                    // A failed future without a cause was cancelled.
                    logger.warn("Endpoint unavailable: connection attempt cancelled");
                }
            }
            res.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
            return;
        }

        ChannelFuture lastWriteFuture = null;
        try {
            res.setStatus(HttpServletResponse.SC_OK);
            res.setHeader(HttpHeaders.Names.CONTENT_TYPE, "application/octet-stream");
            res.setHeader(HttpHeaders.Names.CONTENT_TRANSFER_ENCODING, HttpHeaders.Values.BINARY);

            // Initiate chunked encoding by flushing the headers.
            out.flush();

            PushbackInputStream in =
                    new PushbackInputStream(req.getInputStream());
            while (channel.isConnected()) {
                ChannelBuffer buffer;
                try {
                    buffer = read(in);
                } catch (EOFException e) {
                    break;
                }
                if (buffer == null) {
                    break;
                }
                lastWriteFuture = channel.write(buffer);
            }
        } finally {
            if (lastWriteFuture == null) {
                channel.close();
            } else {
                // Close only after the last pending write has completed.
                lastWriteFuture.addListener(ChannelFutureListener.CLOSE);
            }
        }
    }

    /**
     * Connects {@code channel} to {@link #remoteAddress}.  It makes at most
     * {@link #connectAttempts} attempts (always at least one) and sleeps
     * {@link #retryDelay} milliseconds between consecutive failed attempts.
     * If the thread is interrupted while sleeping, the interrupt status is
     * restored and no further attempts are made.
     *
     * @param channel the unconnected channel
     * @return the future of the last attempt; it is complete, and
     *         {@link ChannelFuture#isSuccess()} tells whether it succeeded
     */
    private ChannelFuture connectWithRetries(Channel channel) {
        final SocketAddress address = remoteAddress;
        final long maxAttempts = connectAttempts;
        final long delay = retryDelay;

        long attempts = 0;
        for (;;) {
            ChannelFuture future = channel.connect(address).awaitUninterruptibly();
            attempts++;
            if (future.isSuccess() || attempts >= maxAttempts) {
                // Either connected or out of attempts: no sleep after the final try.
                return future;
            }
            try {
                Thread.sleep(delay);
            } catch (InterruptedException e) {
                // Preserve the interrupt for the container and stop retrying.
                Thread.currentThread().interrupt();
                return future;
            }
        }
    }

    /**
     * Reads the next chunk from {@code in}.  It takes whatever is currently
     * available; if nothing is, it blocks until at least one byte arrives.
     *
     * @param in the request body stream
     * @return a buffer with at least one byte, or {@code null} at end of stream
     * @throws IOException if reading fails
     */
    private static ChannelBuffer read(PushbackInputStream in) throws IOException {
        byte[] buf;
        int readBytes;

        int bytesToRead = in.available();
        if (bytesToRead > 0) {
            buf = new byte[bytesToRead];
            readBytes = in.read(buf);
        } else if (bytesToRead == 0) {
            // Block for one byte (or EOF), then push it back so that the bulk
            // read below picks it up together with anything else available.
            int b = in.read();
            if (b < 0) {
                return null;
            }
            in.unread(b);
            // available() includes the pushed-back byte.  Clamp to 1 in case the
            // underlying stream reports a negative count, so the byte is never dropped.
            bytesToRead = Math.max(in.available(), 1);
            buf = new byte[bytesToRead];
            readBytes = in.read(buf);
        } else {
            // A negative count is treated as end of stream.
            return null;
        }

        if (readBytes < 0) {
            // The stream hit EOF despite reporting available bytes.
            return null;
        }
        assert readBytes > 0;

        ChannelBuffer buffer;
        if (readBytes == buf.length) {
            buffer = ChannelBuffers.wrappedBuffer(buf);
        } else {
            // A rare case, but it sometimes happens.
            buffer = ChannelBuffers.wrappedBuffer(buf, 0, readBytes);
        }
        return buffer;
    }

    /**
     * Upstream handler for the endpoint channel.  It copies each received
     * {@link ChannelBuffer} to the servlet response stream, flushing after
     * every message.
     */
    private static final class OutboundConnectionHandler extends SimpleChannelUpstreamHandler {

        private final ServletOutputStream out;

        public OutboundConnectionHandler(ServletOutputStream out) {
            this.out = out;
        }

        @Override
        public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
            ChannelBuffer buffer = (ChannelBuffer) e.getMessage();
            // Serialize writes so concurrent messages do not interleave on the stream.
            synchronized (this) {
                buffer.readBytes(out, buffer.readableBytes());
                out.flush();
            }
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
            if (logger.isWarnEnabled()) {
                logger.warn("Unexpected exception while HTTP tunneling", e.getCause());
            }
            e.getChannel().close();
        }
    }
}
