/*
 * JBoss, Home of Professional Open Source. Copyright 2009, Red Hat Middleware LLC, and individual contributors as
 * indicated by the @author tags. See the copyright.txt file in the distribution for a full listing of individual
 * contributors.
 * 
 * This is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any
 * later version.
 * 
 * This software is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
 * details.
 * 
 * You should have received a copy of the GNU Lesser General Public License along with this software; if not, write to
 * the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA, or see the FSF site:
 * http://www.fsf.org.
 */
package org.picketlink.identity.federation.core.wstrust;

import java.util.ArrayList;
import java.util.List;

import javax.xml.bind.Binder;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.transform.Source;

import org.apache.log4j.Logger;
import org.picketlink.identity.federation.core.saml.v2.common.SAMLDocumentHolder;
import org.picketlink.identity.federation.core.saml.v2.util.DocumentUtil;
import org.picketlink.identity.federation.core.util.JAXBUtil;
import org.picketlink.identity.federation.core.wstrust.wrappers.BaseRequestSecurityToken;
import org.picketlink.identity.federation.core.wstrust.wrappers.BaseRequestSecurityTokenResponse;
import org.picketlink.identity.federation.core.wstrust.wrappers.RequestSecurityToken;
import org.picketlink.identity.federation.core.wstrust.wrappers.RequestSecurityTokenCollection;
import org.picketlink.identity.federation.core.wstrust.wrappers.RequestSecurityTokenResponse;
import org.picketlink.identity.federation.core.wstrust.wrappers.RequestSecurityTokenResponseCollection;
import org.picketlink.identity.federation.ws.trust.ObjectFactory;
import org.picketlink.identity.federation.ws.trust.RequestSecurityTokenCollectionType;
import org.picketlink.identity.federation.ws.trust.RequestSecurityTokenResponseCollectionType;
import org.picketlink.identity.federation.ws.trust.RequestSecurityTokenType;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/**
 * <p>
 * This factory implements utility methods for converting between JAXB model objects and XML source.
 * </p>
 * <p>
 * A single shared instance is exposed through {@link #getInstance()}. The JAXB {@code Marshaller} and {@code Binder}
 * held by that instance are not thread-safe, so every use of them happens while holding the monitor of the object
 * being used. The marshalling methods temporarily detach raw DOM content (cancel/renew/validate targets and issued
 * tokens) from the supplied object model so that JAXB does not marshal it, and always re-attach it before returning,
 * leaving the caller's objects as they were handed in.
 * </p>
 * 
 * @author <a href="mailto:sguilhen@redhat.com">Stefan Guilhen</a>
 */
public class WSTrustJAXBFactory
{
   private static final Logger log = Logger.getLogger(WSTrustJAXBFactory.class);

   // local names of the WS-Trust elements this factory looks up in DOM documents.
   private static final String REQUEST_SECURITY_TOKEN = "RequestSecurityToken";

   private static final String VALIDATE_TARGET = "ValidateTarget";

   private static final String RENEW_TARGET = "RenewTarget";

   private static final String CANCEL_TARGET = "CancelTarget";

   private static final String REQUESTED_SECURITY_TOKEN = "RequestedSecurityToken";

   private static final WSTrustJAXBFactory instance = new WSTrustJAXBFactory();

   // JAXB marshallers and binders are not thread-safe and this instance is shared: only use them while holding
   // their monitor.
   private final Marshaller marshaller;

   private final Binder<Node> binder;

   private final ObjectFactory objectFactory;

   private final ThreadLocal<SAMLDocumentHolder> holders = new ThreadLocal<SAMLDocumentHolder>();

   /**
    * <p>
    * Creates the {@code WSTrustJAXBFactory} singleton instance.
    * </p>
    * 
    * @throws RuntimeException wrapping the {@code JAXBException} if the JAXB marshaller or binder cannot be created.
    */
   private WSTrustJAXBFactory()
   {
      try
      {
         String packages = this.getPackages();
         this.marshaller = JAXBUtil.getMarshaller(packages);
         this.binder = JAXBUtil.getJAXBContext(packages).createBinder();
         this.objectFactory = new ObjectFactory();
      }
      catch (JAXBException e)
      {
         throw new RuntimeException(e.getMessage(), e);
      }
   }

   /**
    * <p>
    * Gets a reference to the singleton instance.
    * </p>
    * 
    * @return a reference to the {@code WSTrustJAXBFactory} instance.
    */
   public static WSTrustJAXBFactory getInstance()
   {
      return instance;
   }

   /**
    * Builds the colon-separated list of JAXB context packages used to create the marshaller and the binder.
    * 
    * @return the JAXB context path.
    */
   private String getPackages()
   {
      StringBuilder packages = new StringBuilder();
      packages.append("org.picketlink.identity.federation.ws.addressing");
      packages.append(":org.picketlink.identity.federation.ws.policy");
      packages.append(":org.picketlink.identity.federation.ws.trust");
      packages.append(":org.picketlink.identity.federation.ws.wss.secext");
      packages.append(":org.picketlink.identity.federation.ws.wss.utility");
      packages.append(":org.picketlink.identity.xmlsec.w3.xmldsig");
      return packages.toString();
   }

   /**
    * <p>
    * Creates a {@code BaseRequestSecurityToken} from the specified XML source.
    * </p>
    * 
    * @param request the XML source containing the security token request message.
    * @return the constructed {@code BaseRequestSecurityToken} instance. It will be an instance of {@code
    *         RequestSecurityToken} if the message contains a single token request, and an instance of {@code
    *         RequestSecurityTokenCollection} if multiple requests are being made in the same message.
    * @throws WSTrustException if the source cannot be parsed, contains no {@code RequestSecurityToken} element, or
    *            its root element is neither a request nor a request collection.
    */
   public BaseRequestSecurityToken parseRequestSecurityToken(Source request) throws WSTrustException
   {
      try
      {
         Node documentNode = DocumentUtil.getNodeFromSource(request);
         Document document = documentNode instanceof Document ? (Document) documentNode : documentNode
               .getOwnerDocument();
         if (this.findNodeByNameNS(document, REQUEST_SECURITY_TOKEN, WSTrustConstants.BASE_NAMESPACE) == null)
            throw new WSTrustException("The request document must contain at least one RequestSecurityToken node");

         Object unmarshalled;
         synchronized (this.binder)
         {
            unmarshalled = this.binder.unmarshal(document);
         }
         if (!(unmarshalled instanceof JAXBElement<?>))
            throw new WSTrustException("Request message doesn't contain a valid request type");

         JAXBElement<?> jaxbRST = (JAXBElement<?>) unmarshalled;
         if (jaxbRST.getDeclaredType().equals(RequestSecurityTokenType.class))
         {
            RequestSecurityToken requestSecToken = new RequestSecurityToken((RequestSecurityTokenType) jaxbRST
                  .getValue());
            requestSecToken.setRSTDocument(document);
            return requestSecToken;
         }
         if (jaxbRST.getDeclaredType().equals(RequestSecurityTokenCollectionType.class))
         {
            return new RequestSecurityTokenCollection((RequestSecurityTokenCollectionType) jaxbRST.getValue(),
                  document);
         }
         throw new WSTrustException("Request message doesn't contain a valid request type");
      }
      catch (WSTrustException e)
      {
         // our own diagnostics are already precise: don't bury them under a second wrapper.
         throw e;
      }
      catch (Exception e)
      {
         throw new WSTrustException("Error parsing security token request", e);
      }
   }

   /**
    * <p>
    * Creates a {@code BaseRequestSecurityTokenResponse} from the specified XML source.
    * </p>
    * 
    * @param response the XML source containing the security token response message.
    * @return the constructed {@code BaseRequestSecurityTokenResponse} instance. According to the WS-Trust
    *         specification, the returned object will be an instance of {@code RequestSecurityTokenResponseCollection}.
    * @throws WSTrustException if the source cannot be transformed into a DOM node.
    * @throws RuntimeException if the XML cannot be unmarshalled or its root is not a response collection.
    */
   public BaseRequestSecurityTokenResponse parseRequestSecurityTokenResponse(Source response) throws WSTrustException
   {
      Node documentNode = null;
      try
      {
         documentNode = DocumentUtil.getNodeFromSource(response);
      }
      catch (Exception e)
      {
         throw new WSTrustException("Failed to transform response source", e);
      }

      Object object;
      try
      {
         synchronized (this.binder)
         {
            object = this.binder.unmarshal(documentNode);
         }
      }
      catch (JAXBException e)
      {
         throw new RuntimeException("Failed to unmarshall security token response", e);
      }

      if (object instanceof JAXBElement<?>)
      {
         JAXBElement<?> element = (JAXBElement<?>) object;
         if (element.getDeclaredType().equals(RequestSecurityTokenResponseCollectionType.class))
            return new RequestSecurityTokenResponseCollection((RequestSecurityTokenResponseCollectionType) element
                  .getValue());
         throw new RuntimeException("Invalid response type: " + element.getDeclaredType());
      }
      throw new RuntimeException("Invalid response type: " + object.getClass().getName());
   }

   /**
    * <p>
    * Creates a {@code javax.xml.transform.Source} from the specified request object.
    * </p>
    * <p>
    * The validate/renew/cancel target of the request (if any) is detached during JAXB marshalling, inserted into the
    * generated document as a copy, and then re-attached to {@code request}.
    * </p>
    * 
    * @param request a {@code RequestSecurityToken} representing the object model of the security token request.
    * @return the constructed {@code Source} instance.
    * @throws RuntimeException if marshalling fails.
    */
   public Source marshallRequestSecurityToken(RequestSecurityToken request)
   {
      // only validate, renew and cancel requests carry a target that must be preserved from JAXB marshaling.
      String requestType = request.getRequestType() == null ? null : request.getRequestType().toString();
      String targetName = null;
      if (WSTrustConstants.VALIDATE_REQUEST.equalsIgnoreCase(requestType))
         targetName = VALIDATE_TARGET;
      else if (WSTrustConstants.RENEW_REQUEST.equalsIgnoreCase(requestType))
         targetName = RENEW_TARGET;
      else if (WSTrustConstants.CANCEL_REQUEST.equalsIgnoreCase(requestType))
         targetName = CANCEL_TARGET;

      // detached outside the try block: if the cast inside fails nothing was modified, so nothing must be restored.
      Element targetElement = this.detachTarget(request, targetName);

      Document result = null;
      try
      {
         result = DocumentUtil.createDocument();
         synchronized (this.binder)
         {
            this.binder.marshal(this.objectFactory.createRequestSecurityToken(request.getDelegate()), result);
         }

         // insert (a copy of) the original target in the appropriate element.
         if (targetElement != null)
         {
            Node node = this.findNodeByNameNS(result, targetName, WSTrustConstants.BASE_NAMESPACE);
            if (node == null)
               throw new RuntimeException("Unsupported request type:" + requestType);
            node.appendChild(result.importNode(targetElement, true));
         }
      }
      catch (Exception e)
      {
         throw new RuntimeException("Failed to marshall security token request", e);
      }
      finally
      {
         // leave the caller's object model as it was handed to us.
         this.reattachTarget(request, targetName, targetElement);
      }

      return DocumentUtil.getXMLSource(result);
   }

   /**
    * <p>
    * Creates a {@code javax.xml.transform.Source} from the specified request object.
    * </p>
    * <p>
    * The validate/renew/cancel targets of the requests are detached during JAXB marshalling, inserted into the
    * generated document as copies, and then re-attached to their requests.
    * </p>
    * 
    * @param collection a {@code RequestSecurityTokenCollection} representing the object model of the security token
    *           batch request.
    * @return the constructed {@code Source} instance.
    * @throws IllegalArgumentException if the collection is {@code null} or empty, if a request type is {@code null}
    *            or not a WS-Trust batch type, or if the requests do not all share the same type.
    * @throws RuntimeException if marshalling fails.
    */
   public Source marshallRequestSecurityTokenCollection(RequestSecurityTokenCollection collection)
   {
      // validation: the collection must contain at least one request.
      if (collection == null || collection.getRequestSecurityTokens().size() == 0)
         throw new IllegalArgumentException("The request collection must contain at least one request");

      // validation: all requests must be of the same type and must match one of the batch request types.
      String requestType = null;
      for (RequestSecurityToken request : collection.getRequestSecurityTokens())
      {
         // this is the first request: ensure its type is valid.
         if (requestType == null)
         {
            if (request.getRequestType() == null || !isValidBatchRequestType(request.getRequestType().toString()))
               throw new IllegalArgumentException(
                     "The request type cannot be null and must be a valid WS-Trust batch request type");
            requestType = request.getRequestType().toString();
         }
         // for the other requests, ensure their type matches the type of the first request.
         else
         {
            if (request.getRequestType() == null || !requestType.equals(request.getRequestType().toString()))
               throw new IllegalArgumentException("All requests must be of the same type. Invalid type: "
                     + request.getRequestType());
         }
      }

      // batch issue requests have no target to preserve.
      String targetName = null;
      if (requestType.equals(WSTrustConstants.BATCH_CANCEL_REQUEST))
         targetName = CANCEL_TARGET;
      else if (requestType.equals(WSTrustConstants.BATCH_RENEW_REQUEST))
         targetName = RENEW_TARGET;
      else if (requestType.equals(WSTrustConstants.BATCH_VALIDATE_REQUEST))
         targetName = VALIDATE_TARGET;

      // one entry per request that has a target holder, in request order (entries may be null for empty holders).
      // Requests without a holder contribute no entry, so the list lines up with the target elements JAXB emits.
      List<Element> targets = new ArrayList<Element>();
      Document result = null;
      try
      {
         if (targetName != null)
         {
            for (RequestSecurityToken request : collection.getRequestSecurityTokens())
            {
               if (this.hasTarget(request, targetName))
                  targets.add(this.detachTarget(request, targetName));
            }
         }

         // marshal the document and reinsert the target elements in the generated XML document.
         result = DocumentUtil.createDocument();
         synchronized (this.binder)
         {
            this.binder.marshal(this.objectFactory.createRequestSecurityTokenCollection(collection.getDelegate()),
                  result);
         }

         if (targetName != null)
            this.appendInOrder(result, targetName, targets);
      }
      catch (Exception e)
      {
         throw new RuntimeException("Failed to marshall security token request", e);
      }
      finally
      {
         // re-attach exactly the targets that were detached, in the same order.
         int restored = 0;
         for (RequestSecurityToken request : collection.getRequestSecurityTokens())
         {
            if (restored == targets.size())
               break;
            if (this.hasTarget(request, targetName))
               this.reattachTarget(request, targetName, targets.get(restored++));
         }
      }

      return DocumentUtil.getXMLSource(result);
   }

   /**
    * <p>
    * Creates a {@code javax.xml.transform.Source} from the specified response object.
    * </p>
    * <p>
    * Issued tokens are detached during JAXB marshalling, inserted into the generated document as copies, and then
    * re-attached to their responses.
    * </p>
    * 
    * @param collection a {@code RequestSecurityTokenResponseCollection} representing the object model of the security
    *           token response.
    * @return the constructed {@code Source} instance.
    * @throws IllegalArgumentException if the collection is {@code null} or contains no responses.
    * @throws RuntimeException if marshalling fails.
    */
   public Source marshallRequestSecurityTokenResponse(RequestSecurityTokenResponseCollection collection)
   {
      if (collection == null || collection.getRequestSecurityTokenResponses().size() == 0)
         throw new IllegalArgumentException("The response collection must contain at least one response");

      // one entry per response that has a RequestedSecurityToken, in response order (entries may be null).
      List<Element> tokenElements = new ArrayList<Element>();
      Document result = null;
      try
      {
         // issued tokens are raw DOM elements: keep them away from JAXB and insert them in the DOM afterwards.
         for (RequestSecurityTokenResponse response : collection.getRequestSecurityTokenResponses())
         {
            if (response.getRequestedSecurityToken() != null)
            {
               Element token = (Element) response.getRequestedSecurityToken().getAny();
               response.getRequestedSecurityToken().setAny(null);
               tokenElements.add(token);
            }
         }

         // marshall the response to a document and insert the issued tokens directly into the document.
         result = DocumentUtil.createDocument();
         synchronized (this.marshaller)
         {
            this.marshaller.marshal(this.objectFactory.createRequestSecurityTokenResponseCollection(collection
                  .getDelegate()), result);
         }

         // the document is a ws-trust template - we need to insert the tokens in the appropriate elements.
         this.appendInOrder(result, REQUESTED_SECURITY_TOKEN, tokenElements);

         if (log.isTraceEnabled())
         {
            log.trace("Final RSTR doc:" + DocumentUtil.asString(result));
         }
      }
      catch (Exception e)
      {
         throw new RuntimeException("Failed to marshall security token response", e);
      }
      finally
      {
         // re-attach exactly the tokens that were detached, in the same order.
         int restored = 0;
         for (RequestSecurityTokenResponse response : collection.getRequestSecurityTokenResponses())
         {
            if (restored == tokenElements.size())
               break;
            if (response.getRequestedSecurityToken() != null)
               response.getRequestedSecurityToken().setAny(tokenElements.get(restored++));
         }
      }
      return DocumentUtil.getXMLSource(result);
   }

   /**
    * Returns the {@code SAMLDocumentHolder} stored for the current thread.
    * <p>
    * Note: nothing in this class ever populates the thread-local, so this method currently returns {@code null}.
    * </p>
    * 
    * @return the holder bound to the calling thread, or {@code null} if none is bound.
    */
   public SAMLDocumentHolder getSAMLDocumentHolderOnThread()
   {
      return holders.get();
   }

   /**
    * Tells whether the request has the target holder named {@code targetName}.
    * 
    * @param request the request to inspect.
    * @param targetName one of {@code ValidateTarget}, {@code RenewTarget} or {@code CancelTarget}; any other value
    *           (including {@code null}) yields {@code false}.
    * @return {@code true} if the corresponding holder is non-null.
    */
   private boolean hasTarget(RequestSecurityToken request, String targetName)
   {
      if (VALIDATE_TARGET.equals(targetName))
         return request.getValidateTarget() != null;
      if (RENEW_TARGET.equals(targetName))
         return request.getRenewTarget() != null;
      if (CANCEL_TARGET.equals(targetName))
         return request.getCancelTarget() != null;
      return false;
   }

   /**
    * Removes and returns the DOM element held by the request's {@code targetName} holder.
    * <p>
    * The holder content is cast before it is cleared, so a failing cast leaves the request untouched.
    * </p>
    * 
    * @param request the request to strip.
    * @param targetName the target holder local name; {@code null} or an unknown name is a no-op.
    * @return the detached element, or {@code null} if there is no holder or it is empty.
    */
   private Element detachTarget(RequestSecurityToken request, String targetName)
   {
      Element target = null;
      if (VALIDATE_TARGET.equals(targetName) && request.getValidateTarget() != null)
      {
         target = (Element) request.getValidateTarget().getAny();
         request.getValidateTarget().setAny(null);
      }
      else if (RENEW_TARGET.equals(targetName) && request.getRenewTarget() != null)
      {
         target = (Element) request.getRenewTarget().getAny();
         request.getRenewTarget().setAny(null);
      }
      else if (CANCEL_TARGET.equals(targetName) && request.getCancelTarget() != null)
      {
         target = (Element) request.getCancelTarget().getAny();
         request.getCancelTarget().setAny(null);
      }
      return target;
   }

   /**
    * Puts {@code target} back into the request's {@code targetName} holder, undoing {@link #detachTarget}.
    * 
    * @param request the request to restore.
    * @param targetName the target holder local name; {@code null} or an unknown name is a no-op.
    * @param target the element previously returned by {@code detachTarget} (may be {@code null}).
    */
   private void reattachTarget(RequestSecurityToken request, String targetName, Element target)
   {
      if (VALIDATE_TARGET.equals(targetName) && request.getValidateTarget() != null)
         request.getValidateTarget().setAny(target);
      else if (RENEW_TARGET.equals(targetName) && request.getRenewTarget() != null)
         request.getRenewTarget().setAny(target);
      else if (CANCEL_TARGET.equals(targetName) && request.getCancelTarget() != null)
         request.getCancelTarget().setAny(target);
   }

   /**
    * Appends a deep copy of each element of {@code contents} to the WS-Trust element named {@code localName} at the
    * same position in document order. {@code null} entries leave the corresponding element empty, and surplus
    * entries on either side are ignored.
    * 
    * @param document the document to fill in.
    * @param localName the local name (in the WS-Trust base namespace) of the container elements.
    * @param contents the elements to copy in, in document order.
    */
   private void appendInOrder(Document document, String localName, List<Element> contents)
   {
      // snapshot the live NodeList so that the content appended below cannot shift the positions being iterated.
      NodeList live = document.getElementsByTagNameNS(WSTrustConstants.BASE_NAMESPACE, localName);
      List<Node> containers = new ArrayList<Node>(live.getLength());
      for (int i = 0; i < live.getLength(); i++)
         containers.add(live.item(i));

      for (int i = 0; i < containers.size() && i < contents.size(); i++)
      {
         Element content = contents.get(i);
         if (content != null)
            containers.get(i).appendChild(document.importNode(content, true));
      }
   }

   /**
    * <p>
    * Finds in the specified document a node that matches the specified name and namespace.
    * </p>
    * 
    * @param document the {@code Document} instance upon which the search is made.
    * @param localName a {@code String} containing the local name of the searched node.
    * @param namespace a {@code String} containing the namespace of the searched node.
    * @return a {@code Node} representing the searched node. If more than one node is found in the document, the first
    *         one will be returned. If no nodes were found according to the search parameters, then {@code null} is
    *         returned.
    */
   private Node findNodeByNameNS(Document document, String localName, String namespace)
   {
      NodeList list = document.getElementsByTagNameNS(namespace, localName);
      return (list == null || list.getLength() == 0) ? null : list.item(0);
   }

   /**
    * <p>
    * Verifies if the specified {@code String} represents one of the valid WS-Trust batch request types or not.
    * </p>
    * 
    * @param requestType the {@code String} to be verified.
    * @return {@code true} if the request type matches one of the WS-Trust batch requests; {@code false} otherwise.
    */
   private boolean isValidBatchRequestType(String requestType)
   {
      // the request type must match one of the WS-Trust batch request types.
      return requestType.equals(WSTrustConstants.BATCH_ISSUE_REQUEST)
            || requestType.equals(WSTrustConstants.BATCH_RENEW_REQUEST)
            || requestType.equals(WSTrustConstants.BATCH_CANCEL_REQUEST)
            || requestType.equals(WSTrustConstants.BATCH_VALIDATE_REQUEST);
   }
}