/*
 * JBoss, Home of Professional Open Source. Copyright 2008, Red Hat Middleware LLC, and individual contributors as
 * indicated by the @author tags. See the copyright.txt file in the distribution for a full listing of individual
 * contributors.
 * 
 * This is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any
 * later version.
 * 
 * This software is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
 * details.
 * 
 * You should have received a copy of the GNU Lesser General Public License along with this software; if not, write to
 * the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA, or see the FSF site:
 * http://www.fsf.org.
 */
package org.picketlink.identity.federation.bindings.jboss.auth;

import java.security.Principal;
import java.security.acl.Group;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.login.LoginException;
import javax.xml.bind.JAXBElement;

import org.jboss.security.auth.spi.AbstractServerLoginModule;
import org.picketlink.identity.federation.bindings.jboss.subject.PicketLinkGroup;
import org.picketlink.identity.federation.bindings.jboss.subject.PicketLinkPrincipal;
import org.picketlink.identity.federation.core.wstrust.SAMLPrincipal;
import org.picketlink.identity.federation.core.wstrust.STSClient;
import org.picketlink.identity.federation.core.wstrust.SamlCredential;
import org.picketlink.identity.federation.core.wstrust.WSTrustException;
import org.picketlink.identity.federation.core.wstrust.STSClientConfig.Builder;
import org.picketlink.identity.federation.core.wstrust.plugins.saml.SAMLUtil;
import org.picketlink.identity.federation.saml.v2.assertion.AssertionType;
import org.picketlink.identity.federation.saml.v2.assertion.AttributeStatementType;
import org.picketlink.identity.federation.saml.v2.assertion.AttributeType;
import org.picketlink.identity.federation.saml.v2.assertion.NameIDType;
import org.picketlink.identity.federation.saml.v2.assertion.StatementAbstractType;
import org.picketlink.identity.federation.saml.v2.assertion.SubjectType;
import org.w3c.dom.Element;

/**
 * <p>
 * This login module delegates authentication of clients to a security token service (STS). If the STS succeeds at
 * authenticating the client, it issues a SAML assertion that may contain other attributes such as security roles. The
 * assertion is then included in a {@code Principal} instance that is in turn added to the "CallerPrincipal" group. This
 * makes the SAML assertion available to JEE applications via {@code getUserPrincipal} or {@code getCallerPrincipal}
 * methods.
 * </p>
 * <p>
 * This login module defines three module options that are used to specify the location of the STS:
 * </p>
 * <ul>
 * <li>endpointAddress - this property specifies the URL of the STS (required).</li>
 * <li>serviceName - this property is used to specify the STS service name (optional, default value=PicketLinkSTS).</li>
 * <li>portName - this property is used to specify the STS port name (optional, default value=PicketLinkSTSPort).</li>
 * </ul>
 * <p>
 * Note: applications may use the resulting {@code Principal} to propagate the client's identity using a SAML assertion.
 * For example, a Web container may use this module to exchange the client's username/password for a SAML assertion and
 * then use this SAML assertion when calling EJBs or other services. The EJB container can then use the {@code
 * SAML2STSLoginModule} to validate the incoming assertion and establish the client's identity.
 * </p>
 * 
 * @author <a href="mailto:sguilhen@redhat.com">Stefan Guilhen</a>
 */
@SuppressWarnings("unchecked")
public class SAML2STSIssuingLoginModule extends AbstractServerLoginModule
{

   /** Name of the module option that holds the STS endpoint URL (required). */
   private static final String ENDPOINT_ADDRESS_OPTION = "endpointAddress";

   /** Name of the module option that holds the STS port name (optional). */
   private static final String PORT_NAME_OPTION = "portName";

   /** Name of the module option that holds the STS service name (optional). */
   private static final String SERVICE_NAME_OPTION = "serviceName";

   /** Shared state key under which the authenticated identity is stored. */
   private static final String SHARED_NAME_KEY = "javax.security.auth.login.name";

   /** Shared state key under which the authenticated credential is stored. */
   private static final String SHARED_PASSWORD_KEY = "javax.security.auth.login.password";

   /** Name of the SAML attribute whose values are mapped to security roles. */
   private static final String ROLE_ATTRIBUTE_NAME = "role";

   private String endpointURL = null;

   private String portName = "PicketLinkSTSPort";

   private String serviceName = "PicketLinkSTS";

   /** The principal established by {@link #login()}; holds the SAML assertion issued by the STS. */
   private SAMLPrincipal principal;

   /**
    * <p>
    * Initializes the module and reads the STS location from the module options.
    * </p>
    * 
    * @param subject the {@code Subject} being authenticated.
    * @param callbackHandler the handler used to obtain the client's username and password.
    * @param sharedState the state shared with the other configured login modules.
    * @param options the module options; {@code endpointAddress} is required, {@code portName} and {@code
    *           serviceName} are optional.
    * @throws IllegalArgumentException if the {@code endpointAddress} option has not been specified.
    * @see org.jboss.security.auth.spi.AbstractServerLoginModule#initialize(javax.security.auth.Subject,
    *      javax.security.auth.callback.CallbackHandler, java.util.Map, java.util.Map)
    */
   @Override
   public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState,
         Map<String, ?> options)
   {
      super.initialize(subject, callbackHandler, sharedState, options);
      // check if the required endpointAddress property has been specified.
      String address = (String) options.get(ENDPOINT_ADDRESS_OPTION);
      if (address == null)
         throw new IllegalArgumentException("The endpointAddress property is required and must specify the STS URL");
      this.endpointURL = address;
      // check if the optional properties have been specified; otherwise the defaults are kept.
      String option = (String) options.get(PORT_NAME_OPTION);
      if (option != null)
         this.portName = option;
      option = (String) options.get(SERVICE_NAME_OPTION);
      if (option != null)
         this.serviceName = option;
   }

   /**
    * <p>
    * Authenticates the client. If the superclass reports that the client has already been authenticated, the shared
    * principal is reused when it is a {@code SAMLPrincipal}. Otherwise the username/password pair is collected from
    * the callback handler and sent to the STS, and the issued SAML assertion is wrapped in a new {@code
    * SAMLPrincipal}.
    * </p>
    * 
    * @return {@code true} if the client has been authenticated; {@code false} if a shared identity exists but is not
    *         a {@code SAMLPrincipal}.
    * @throws LoginException if no callback handler is available, if the credentials cannot be collected, or if the
    *            STS fails to issue a token.
    */
   @Override
   public boolean login() throws LoginException
   {
      // if client has already been authenticated, just save the principal.
      if (super.login())
      {
         Object sharedPrincipal = super.sharedState.get(SHARED_NAME_KEY);
         if (sharedPrincipal instanceof SAMLPrincipal)
         {
            this.principal = (SAMLPrincipal) sharedPrincipal;
            return true;
         }
         super.log.warn("Shared principal is not a SAMLPrincipal.");
         // this module has no usable identity: make sure a later commit() does not treat the login as successful.
         super.loginOk = false;
         return false;
      }

      // client hasn't been authenticated: get username/password pair from callback handler.
      if (callbackHandler == null)
      {
         throw new LoginException("Error: no CallbackHandler available to collect authentication information");
      }

      NameCallback nc = new NameCallback("User name: ", "guest");
      PasswordCallback pc = new PasswordCallback("Password: ", false);
      Callback[] callbacks =
      {nc, pc};

      String username = null;
      char[] passwordChars = null;
      try
      {
         callbackHandler.handle(callbacks);
         username = nc.getName();
         // getPassword returns a copy, so the callback's internal copy can be wiped right after.
         passwordChars = pc.getPassword();
      }
      catch (Exception e)
      {
         LoginException exception = new LoginException("Error handling callback: " + e.getMessage());
         exception.initCause(e);
         throw exception;
      }
      finally
      {
         pc.clearPassword();
      }

      if (passwordChars == null)
         throw new LoginException("Error handling callback: no password was supplied");
      String password = new String(passwordChars);

      // create a WS-Trust request with the username/password and send it to the STS.
      Builder builder = new Builder();
      builder.endpointAddress(this.endpointURL).portName(this.portName).serviceName(this.serviceName);
      builder.username(username).password(password);
      STSClient client = new STSClient(builder.build());

      try
      {
         if (log.isTraceEnabled())
            log.trace("Calling STS at " + this.endpointURL);
         Element assertionElement = client.issueToken(SAMLUtil.SAML2_TOKEN_TYPE);
         SamlCredential credential = new SamlCredential(assertionElement);
         this.principal = new SAMLPrincipal(this.getAssertionSubjectName(assertionElement), credential);
      }
      catch (WSTrustException we)
      {
         LoginException exception = new LoginException("Failed to authenticate client via STS: " + we.getMessage());
         exception.initCause(we);
         throw exception;
      }

      // publish the established identity and credential in the shared state when first-pass sharing is enabled.
      if (super.getUseFirstPass())
      {
         super.sharedState.put(SHARED_NAME_KEY, this.principal);
         super.sharedState.put(SHARED_PASSWORD_KEY, this.principal.getSAMLCredential());
      }
      super.loginOk = true;
      return true;
   }

   /**
    * <p>
    * Obtains the identity established by this module.
    * </p>
    * 
    * @return the {@code SAMLPrincipal} set by {@link #login()}, or {@code null} if no principal has been set.
    * @see org.jboss.security.auth.spi.AbstractServerLoginModule#getIdentity()
    */
   @Override
   protected Principal getIdentity()
   {
      return this.principal;
   }

   /**
    * <p>
    * Builds the groups associated with the authenticated client. The "CallerPrincipal" group always contains the
    * {@code SAMLPrincipal}. When the assertion contains an attribute statement, a "Roles" group is also returned,
    * holding one principal for each value of every attribute named "role" in that statement.
    * </p>
    * 
    * @return an array with the "CallerPrincipal" group and, when an attribute statement is present, the "Roles"
    *         group.
    * @throws LoginException if the SAML assertion cannot be parsed or its role attributes cannot be read.
    * @see org.jboss.security.auth.spi.AbstractServerLoginModule#getRoleSets()
    */
   @Override
   protected Group[] getRoleSets() throws LoginException
   {
      // add the SAMLPrincipal to the CallerPrincipal group.
      Group callerPrincipal = new PicketLinkGroup("CallerPrincipal");
      callerPrincipal.addMember(this.principal);

      // try to extract roles from the SAML assertion.
      try
      {
         AssertionType assertion = SAMLUtil.fromElement(this.principal.getSAMLCredential().getAssertionAsElement());
         // check the assertion statements and look for role attributes.
         AttributeStatementType attributeStatement = this.getAttributeStatement(assertion);
         if (attributeStatement != null)
         {
            // a set is used so that a role listed more than once is only added once.
            Set<Principal> roles = new HashSet<Principal>();
            List<Object> attributeList = attributeStatement.getAttributeOrEncryptedAttribute();
            for (Object obj : attributeList)
            {
               if (obj instanceof AttributeType)
               {
                  AttributeType attribute = (AttributeType) obj;
                  // if this is a role attribute, get its values and add them to the role set. The constant is used
                  // as the receiver so that an attribute without a name is skipped instead of causing a failure.
                  if (ROLE_ATTRIBUTE_NAME.equals(attribute.getName()))
                  {
                     for (Object value : attribute.getAttributeValue())
                        roles.add(new PicketLinkPrincipal((String) value));
                  }
               }
            }
            Group rolesGroup = new PicketLinkGroup("Roles");
            for (Principal role : roles)
               rolesGroup.addMember(role);

            return new Group[] {callerPrincipal, rolesGroup};
         }
      }
      catch (Exception e)
      {
         LoginException le = new LoginException("Failed to parse assertion element: " + e.getMessage());
         le.initCause(e);
         throw le;
      }

      return new Group[] {callerPrincipal};
   }

   /**
    * <p>
    * Obtains the subject name of the specified SAML assertion.
    * </p>
    * 
    * @param assertionElement the assertion {@code Element}.
    * @return the value of the first {@code NameIDType} found in the assertion subject, or {@code null} if the
    *         assertion has no subject or the subject has no name identifier.
    * @throws RuntimeException if the assertion element cannot be parsed.
    */
   private String getAssertionSubjectName(Element assertionElement)
   {
      try
      {
         AssertionType assertion = SAMLUtil.fromElement(assertionElement);
         SubjectType subject = assertion.getSubject();
         if (subject != null)
         {
            for (JAXBElement<?> element : subject.getContent())
            {
               if (element.getDeclaredType().equals(NameIDType.class))
               {
                  NameIDType nameID = (NameIDType) element.getValue();
                  return nameID.getValue();
               }
            }
         }
      }
      catch (Exception e)
      {
         throw new RuntimeException("Failed to parse assertion element: " + e.getMessage(), e);
      }
      return null;
   }

   /**
    * <p>
    * Checks if the specified SAML assertion contains a {@code AttributeStatementType} and returns this type when it is
    * available. Only the first attribute statement found is returned.
    * </p>
    * 
    * @param assertion a reference to the {@code AssertionType} that may contain an {@code AttributeStatementType}.
    * @return the assertion's {@code AttributeStatementType}, or {@code null} if no such type can be found in the SAML
    *         assertion.
    */
   private AttributeStatementType getAttributeStatement(AssertionType assertion)
   {
      List<StatementAbstractType> statementList = assertion.getStatementOrAuthnStatementOrAuthzDecisionStatement();
      for (StatementAbstractType statement : statementList)
      {
         if (statement instanceof AttributeStatementType)
            return (AttributeStatementType) statement;
      }
      return null;
   }
}
