/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2008, Red Hat Middleware LLC, and individual contributors
 * as indicated by the @author tags. See the copyright.txt file in the
 * distribution for a full listing of individual contributors.
 *
 * This is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of
 * the License, or (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this software; if not, write to the Free
 * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
 */

package org.picketlink.identity.federation.bindings.wildfly.sp;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.form.FormData;
import io.undertow.servlet.handlers.ServletRequestContext;
import io.undertow.servlet.spec.HttpServletRequestImpl;
import io.undertow.servlet.spec.PartImpl;
import io.undertow.servlet.spec.ServletContextImpl;

import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.Part;

/**
 * HttpServletRequest wrapper which changes the behavior of
 * {@link javax.servlet.ServletRequest#getParameter(String)}.
 * If the original getParameter() method returns null and a non-null {@link FormData} instance
 * is available, the wrapper returns the data from the FormData instead.
 * The wrapper is part of the fix for JBEAP-10449.
 *
 * @author Jiri Ondrusek
 */
public class SPFormAuthenticationRequestWrapper extends HttpServletRequestWrapper {

    /** Parsed form data used as a fallback source for parameters and parts; may be {@code null}. */
    private final FormData formData;

    /** Parts built lazily from {@link #formData}; {@code null} until they are first needed. */
    private List<Part> parts = null;

    /** Stream that replays the body bytes handed to the constructor. */
    private final ServletInputStream servletInputStream;

    /**
     * Creates the wrapper.
     *
     * @param request  the request to wrap
     * @param formData form data consulted when the wrapped request has no value; may be {@code null}
     * @param bytes    request body served by {@link #getInputStream()}; {@code null} is treated as an empty body
     */
    public SPFormAuthenticationRequestWrapper(HttpServletRequest request, FormData formData, byte[] bytes) {
        super(request);
        this.formData = formData;
        final ByteArrayInputStream body = new ByteArrayInputStream(bytes != null ? bytes : new byte[0]);
        this.servletInputStream = new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return body.read();
            }

            @Override
            public int read(byte[] buffer, int offset, int length) throws IOException {
                // Bulk read straight from the backing array instead of byte-by-byte.
                return body.read(buffer, offset, length);
            }

            @Override
            public int available() throws IOException {
                return body.available();
            }

            @Override
            public boolean isFinished() {
                return body.available() <= 0;
            }

            @Override
            public boolean isReady() {
                return !isFinished();
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                throw new IllegalStateException("Cannot set ReadListener: not an async request.");
            }
        };
    }

    /**
     * Returns the parameter from the wrapped request, falling back to the first non-file
     * form value of that name.
     *
     * @param name parameter name
     * @return the value, or {@code null} when neither source has a textual value
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (value != null || formData == null) {
            return value;
        }

        FormData.FormValue first = formData.getFirst(name);
        // File uploads carry no textual value; treat them like the other accessors do.
        if (first == null || first.isFile()) {
            return null;
        }
        return first.getValue();
    }

    /**
     * Returns all values of the parameter from the wrapped request, falling back to the
     * non-file form values of that name.
     *
     * @param name parameter name
     * @return the values, or {@code null} when there are none
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] requestValues = super.getParameterValues(name);
        if (requestValues != null) {
            // An empty array is reported as "no values", a non-empty one is returned as a copy.
            return requestValues.length == 0 ? null : requestValues.clone();
        }
        if (formData == null) {
            return null;
        }

        Deque<FormData.FormValue> formValues = formData.get(name);
        if (formValues == null) {
            return null;
        }
        List<String> values = getValuesFromForm(formValues);
        return values.isEmpty() ? null : values.toArray(new String[0]);
    }

    /**
     * Returns the parameters of the wrapped request merged with the non-file form values
     * for names the wrapped request does not already contain.
     *
     * @return a new map; the map of the wrapped request is never modified
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        // The map returned by the wrapped request may be immutable, so merge into a copy.
        Map<String, String[]> requestParameters = super.getParameterMap();
        Map<String, String[]> merged = new HashMap<>();
        if (requestParameters != null) {
            merged.putAll(requestParameters);
        }
        if (formData == null) {
            return merged;
        }

        for (String key : formData) {
            if (merged.containsKey(key)) {
                continue;
            }
            Deque<FormData.FormValue> formValues = formData.get(key);
            if (formValues == null) {
                continue;
            }
            List<String> values = getValuesFromForm(formValues);
            if (!values.isEmpty()) {
                merged.put(key, values.toArray(new String[0]));
            }
        }
        return merged;
    }

    /**
     * Returns the parameter names of the wrapped request plus the names of form entries
     * that have at least one non-file value.
     *
     * @return enumeration over the combined, de-duplicated names
     */
    @Override
    public Enumeration<String> getParameterNames() {
        Set<String> names = new HashSet<>();

        // Names of the wrapped request are always included, with or without form data.
        Enumeration<String> requestNames = super.getParameterNames();
        if (requestNames != null) {
            while (requestNames.hasMoreElements()) {
                names.add(requestNames.nextElement());
            }
        }

        if (formData != null) {
            for (String key : formData) {
                Deque<FormData.FormValue> formValues = formData.get(key);
                // Keep this consistent with getParameterMap(): file-only entries are not parameters.
                if (formValues != null && !getValuesFromForm(formValues).isEmpty()) {
                    names.add(key);
                }
            }
        }
        return Collections.enumeration(names);
    }

    /** Collects the textual values of the given form values, skipping file uploads. */
    private List<String> getValuesFromForm(Deque<FormData.FormValue> formValues) {
        List<String> values = new ArrayList<>();
        for (FormData.FormValue formValue : formValues) {
            if (!formValue.isFile()) {
                values.add(formValue.getValue());
            }
        }
        return values;
    }

    /**
     * Returns the named part from the wrapped request, falling back to the parts built
     * from the form data.
     *
     * @param name part name
     * @return the part, or {@code null} when no part has that name
     */
    @Override
    public Part getPart(String name) throws IOException, ServletException {
        Part requestPart = super.getPart(name);
        if (requestPart != null) {
            return requestPart;
        }

        for (Part candidate : formParts()) {
            if (candidate.getName().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Returns the parts of the wrapped request, or the parts built from the form data when
     * the wrapped request has none.
     *
     * @return the parts; possibly empty
     */
    @Override
    public Collection<Part> getParts() throws IOException, ServletException {
        Collection<Part> requestParts = super.getParts();
        if (requestParts != null && !requestParts.isEmpty()) {
            return requestParts;
        }
        return formParts();
    }

    /** Returns the form-data backed parts, building them on first use. */
    private List<Part> formParts() {
        if (parts == null) {
            parts = loadParts();
        }
        return parts;
    }

    /**
     * Builds one part per form value. When form data is present, the wrapped request is cast
     * to {@link HttpServletRequestImpl} to reach the exchange and servlet context.
     */
    private List<Part> loadParts() {
        // no need to check the mime-type - parent methods already handle that
        final List<Part> loaded = new ArrayList<>();
        if (formData == null) {
            return loaded;
        }

        HttpServletRequestImpl request = (HttpServletRequestImpl) getRequest();
        HttpServerExchange exchange = request.getExchange();
        ServletContextImpl servletContext = request.getServletContext();
        final ServletRequestContext requestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);

        for (final String namedPart : formData) {
            for (FormData.FormValue part : formData.get(namedPart)) {
                loaded.add(new PartImpl(namedPart,
                        part,
                        requestContext.getOriginalServletPathMatch().getServletChain().getManagedServlet().getMultipartConfig(),
                        servletContext, request));
            }
        }
        return loaded;
    }

    /**
     * Returns the stream over the body bytes given to the constructor. The same stream
     * instance is returned on every call.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return this.servletInputStream;
    }
}
