/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.profile.relyingparty.impl;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.core.metrics.MetricsSupport;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.profile.criterion.ProfileRequestContextCriterion;
import org.opensaml.saml.common.messaging.context.SAMLMetadataContext;
import org.opensaml.saml.common.messaging.context.SAMLPeerEntityContext;
import org.opensaml.saml.criterion.RoleDescriptorCriterion;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.RoleDescriptor;
import org.opensaml.security.credential.Credential;
import org.slf4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;

import com.codahale.metrics.Counter;
import com.codahale.metrics.Gauge;

import net.shibboleth.profile.context.RelyingPartyContext;
import net.shibboleth.profile.relyingparty.RelyingPartyConfiguration;
import net.shibboleth.profile.relyingparty.RelyingPartyConfigurationResolver;
import net.shibboleth.profile.relyingparty.VerifiedProfileCriterion;
import net.shibboleth.shared.annotation.constraint.NotEmpty;
import net.shibboleth.shared.annotation.constraint.NotLive;
import net.shibboleth.shared.annotation.constraint.Unmodifiable;
import net.shibboleth.shared.collection.CollectionSupport;
import net.shibboleth.shared.component.AbstractIdentifiableInitializableComponent;
import net.shibboleth.shared.component.ComponentInitializationException;
import net.shibboleth.shared.primitive.LoggerFactory;
import net.shibboleth.shared.primitive.StringSupport;
import net.shibboleth.shared.resolver.CriteriaSet;
import net.shibboleth.shared.resolver.Criterion;
import net.shibboleth.shared.resolver.ResolverException;
import net.shibboleth.spring.security.CredentialHolder;

/**
 * Retrieves a per-relying party configuration for a given profile request based on the
 * supplied {@link CriteriaSet}.
 *
 * <p>Supported {@link Criterion}:</p>
 * <ul>
 *  <li>{@link ProfileRequestContextCriterion}</li>
 *  <li>{@link EntityIdCriterion}</li>
 *  <li>{@link RoleDescriptorCriterion}</li>
 *  <li>{@link VerifiedProfileCriterion}</li>
 * </ul>
 * 
 * <p>
 * Note that this resolver does not permit more than one {@link RelyingPartyConfiguration} with the same ID.
 * </p>
 * 
 * <p>
 * When a metric name is set, a gauge exposing one usage counter per configuration ID (plus one each for
 * the default and unverified cases) is registered during initialization and removed on destruction.
 * </p>
 * 
 * @since 5.0.0
 */
public class DefaultRelyingPartyConfigurationResolver extends AbstractIdentifiableInitializableComponent
        implements RelyingPartyConfigurationResolver {

    /** Counter ID for default RP. */
    @Nonnull @NotEmpty private static final String DEFAULT_RELYING_PARTY_COUNTER = "shibboleth.DefaultRelyingParty";

    /** Counter ID for unverified RP. */
    @Nonnull @NotEmpty
    private static final String UNVERIFIED_RELYING_PARTY_COUNTER = "shibboleth.UnverifiedRelyingParty";

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(DefaultRelyingPartyConfigurationResolver.class);

    /** Registered relying party configurations. */
    @Nonnull private List<RelyingPartyConfiguration> rpConfigurations;

    /** Default relying party, used if no other verified instance matches. */
    @Nullable private RelyingPartyConfiguration defaultRelyingPartyConfiguration;

    /** Unverified relying party configuration, used if the request is unverified. */
    @Nullable private RelyingPartyConfiguration unverifiedConfiguration;
    
    /** The global list of all configured signing credentials. */
    @Nonnull private List<Credential> signingCredentials;
    
    /** The global list of all configured encryption credentials. */
    @Nonnull private List<Credential> encryptionCredentials;
    
    /** Name for counter map metric. */
    @Nullable private String metricName;

    /** Stored copy of metric for teardown. */
    @Nullable private Gauge<Map<String,Counter>> counterGauge;
    
    /** Map of override counters (unmodifiable; rebuilt at initialization when metrics are enabled). */
    @Nonnull private Map<String,Counter> counterMap;

    /** Constructor. */
    public DefaultRelyingPartyConfigurationResolver() {
        rpConfigurations = CollectionSupport.emptyList();
        signingCredentials = CollectionSupport.emptyList();
        encryptionCredentials = CollectionSupport.emptyList();
        counterMap = CollectionSupport.emptyMap();
    }
    
    /**
     * Get an unmodifiable list of verified relying party configurations.
     * 
     * @return unmodifiable list of verified relying party configurations
     */
    @Nonnull @Unmodifiable @NotLive
    public Collection<? extends RelyingPartyConfiguration> getRelyingPartyConfigurations() {
        return rpConfigurations;
    }

    /**
     * Set the verified relying party configurations.
     * 
     * @param configs list of verified relying party configurations
     */
    public void setRelyingPartyConfigurations(@Nullable final Collection<? extends RelyingPartyConfiguration> configs) {
        checkSetterPreconditions();
        
        if (configs != null) {
            rpConfigurations = CollectionSupport.copyToList(configs);
        } else {
            rpConfigurations = CollectionSupport.emptyList();
        }
    }

    /**
     * Get the {@link RelyingPartyConfiguration} to use if no other configuration is applicable.
     * 
     * @return default configuration
     */
    @Nullable public RelyingPartyConfiguration getDefaultConfiguration() {
        return defaultRelyingPartyConfiguration;
    }

    /**
     * Set the {@link RelyingPartyConfiguration} to use if no other configuration is applicable.
     * 
     * @param configuration default configuration
     */
    public void setDefaultConfiguration(@Nullable final RelyingPartyConfiguration configuration) {
        checkSetterPreconditions();
        
        defaultRelyingPartyConfiguration = configuration;
    }

    /**
     * Get the {@link RelyingPartyConfiguration} to use if the configuration is found to be "unverified"
     * (via use of {@link VerifiedProfileCriterion}).
     * 
     * @return unverified configuration
     */
    @Nullable public RelyingPartyConfiguration getUnverifiedConfiguration() {
        return unverifiedConfiguration;
    }

    /**
     * Set the {@link RelyingPartyConfiguration} to use if the configuration is found to be "unverified"
     * (via use of {@link VerifiedProfileCriterion}).
     * 
     * @param configuration unverified configuration
     */
    public void setUnverifiedConfiguration(@Nullable final RelyingPartyConfiguration configuration) {
        checkSetterPreconditions();
        
        unverifiedConfiguration = configuration;
    }
    
    /**
     * Gets the metric name to use for counters to track use of configurations.
     * 
     * @return name for counter metrics
     * 
     * @since 5.1.3
     */
    @Nullable public String getMetricName() {
        return metricName;
    }
    
    /**
     * Set name of metric to use for counters to track use of configurations.
     * 
     * @param name name for counter metrics
     * 
     * @since 5.0.0
     */
    public void setMetricName(@Nullable final String name) {
        checkSetterPreconditions();
        
        metricName = StringSupport.trimOrNull(name);
    }

    /** {@inheritDoc} */
    @Override protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();

        // Room for every configuration plus the two synthetic counter slots added below.
        final HashSet<String> configIds = new HashSet<>(rpConfigurations.size() + 2);
        for (final RelyingPartyConfiguration config : rpConfigurations) {
            // Set.add returns false when the ID was already present, i.e. on a duplicate.
            if (!configIds.add(config.getId())) {
                throw new ComponentInitializationException("Multiple RelyingPartyConfiguration configurations with ID "
                        + config.getId() + " detected, IDs must be unique.");
            }
        }

        // Set up counters for metrics, but add the two "other" slots in.
        configIds.add(DEFAULT_RELYING_PARTY_COUNTER);
        configIds.add(UNVERIFIED_RELYING_PARTY_COUNTER);
        
        final String localMetricName = getMetricName();
        if (localMetricName != null) {
            // Build map of config IDs to counters.
            counterMap = configIds.stream().collect(
                    CollectionSupport.nonnullCollector(
                            Collectors.toUnmodifiableMap(id -> id, id -> new Counter()))).get();
            // Note that this gauge must use the support method to register in a synchronized fashion,
            // and also must store off the instance for later use in destroy.
            counterGauge = MetricsSupport.register(localMetricName,
                    new Gauge<Map<String, Counter>>() {
                        @Override
                        public Map<String, Counter> getValue() {
                            return counterMap;
                        }
                    },
                    true);
        }
    }
    
    /** {@inheritDoc} */
    @Override protected void doDestroy() {
        final String localMetricName = getMetricName();
        final Gauge<Map<String,Counter>> localGauge = counterGauge;
        if (localMetricName != null && localGauge != null) {
            MetricsSupport.remove(localMetricName, localGauge);
            // Drop the reference so a destroyed instance no longer pins the registered gauge.
            counterGauge = null;
        }
        
        super.doDestroy();
    }

    /**
     * {@inheritDoc}
     *
     * <p>A null criteria set yields an empty result. If the criteria do not carry a verified
     * {@link VerifiedProfileCriterion}, only the unverified configuration (if set) is returned. Otherwise every
     * registered configuration whose predicate accepts the derived {@link ProfileRequestContext} is returned,
     * falling back to the default configuration (if set) when none match.</p>
     */
    @Override
    @Nonnull public Iterable<RelyingPartyConfiguration> resolve(@Nullable final CriteriaSet criteria)
            throws ResolverException {
        checkComponentActive();

        log.debug("Resolving relying party configuration");

        if (criteria == null) {
            return CollectionSupport.emptyList();
        }

        if (!isVerified(criteria)) {
            final RelyingPartyConfiguration uvc = selectUnverifiedConfiguration();
            if (uvc == null) {
                return CollectionSupport.emptyList();
            }
            return CollectionSupport.singleton(uvc);
        }

        final ProfileRequestContext context = getProfileRequestContext(criteria);

        final List<RelyingPartyConfiguration> matches = new ArrayList<>();
        for (final RelyingPartyConfiguration configuration : rpConfigurations) {
            if (testConfiguration(configuration, context)) {
                matches.add(configuration);
            }
        }

        if (!matches.isEmpty()) {
            return matches;
        }

        final RelyingPartyConfiguration defaultConfig = selectDefaultConfiguration();
        if (defaultConfig == null) {
            return CollectionSupport.emptyList();
        }
        return CollectionSupport.singleton(defaultConfig);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Same selection rules as {@link #resolve(CriteriaSet)}, but returns only the first applicable
     * registered configuration (in registration order), or null if nothing applies.</p>
     */
    @Override
    @Nullable public RelyingPartyConfiguration resolveSingle(@Nullable final CriteriaSet criteria)
            throws ResolverException {
        checkComponentActive();

        log.debug("Resolving relying party configuration");

        if (criteria == null) {
            return null;
        }

        if (!isVerified(criteria)) {
            return selectUnverifiedConfiguration();
        }

        final ProfileRequestContext context = getProfileRequestContext(criteria);

        for (final RelyingPartyConfiguration configuration : rpConfigurations) {
            if (testConfiguration(configuration, context)) {
                return configuration;
            }
        }

        return selectDefaultConfiguration();
    }
    
    /** {@inheritDoc} */
    @Override
    @Nonnull @Unmodifiable @NotLive public Collection<Credential> getSigningCredentials() {
        return signingCredentials;
    }
    
    /**
     * Set the list of all configured signing credentials.
     * 
     * @param credentials the list of signing credentials, may be null
     */
    @Autowired
    @Qualifier("signing")
    public void setSigningCredentials(@Nullable final List<CredentialHolder> credentials) {
        checkSetterPreconditions();
        
        if (credentials != null) {
            signingCredentials = credentials.stream()
                    .flatMap(h -> h.getCredentials().stream())
                    .collect(CollectionSupport.nonnullCollector(Collectors.toUnmodifiableList())).get();
        } else {
            signingCredentials = CollectionSupport.emptyList();
        }
    }

    /** {@inheritDoc} */
    @Override
    @Nonnull @Unmodifiable @NotLive public Collection<Credential> getEncryptionCredentials() {
        return encryptionCredentials;
    }
    
    /**
     * Set the list of all configured encryption credentials.
     * 
     * @param credentials the list of encryption credentials, may be null
     */
    @Autowired
    @Qualifier("encryption")
    public void setEncryptionCredentials(@Nullable final List<CredentialHolder> credentials) {
        checkSetterPreconditions();
        
        if (credentials != null) {
            encryptionCredentials = credentials.stream()
                    .flatMap(h -> h.getCredentials().stream())
                    .collect(CollectionSupport.nonnullCollector(Collectors.toUnmodifiableList())).get();
        } else {
            encryptionCredentials = CollectionSupport.emptyList();
        }
    }
    
    /**
     * Increment relying party counter, if one exists for the given name.
     * 
     * @param name name of counter
     */
    private void increment(@Nullable final String name) {
        // counterMap is an unmodifiable (null-hostile) map, so get(null) would throw rather than return null.
        if (name == null) {
            return;
        }
        final Counter counter = counterMap.get(name);
        if (counter != null) {
            counter.inc();
        }
    }

    /**
     * Check whether the criteria mark the request as verified.
     *
     * @param criteria input criteria
     *
     * @return true iff a {@link VerifiedProfileCriterion} is present and reports the request as verified
     */
    private boolean isVerified(@Nonnull final CriteriaSet criteria) {
        final VerifiedProfileCriterion vpc = criteria.get(VerifiedProfileCriterion.class);
        return vpc != null && vpc.isVerified();
    }

    /**
     * Select the unverified configuration, logging the outcome and counting its use.
     *
     * @return the unverified configuration, or null if none is set
     */
    @Nullable private RelyingPartyConfiguration selectUnverifiedConfiguration() {
        final RelyingPartyConfiguration uvc = getUnverifiedConfiguration();
        if (uvc == null) {
            log.warn("Profile request was unverified, but no such configuration is available");
            return null;
        }
        log.debug("Profile request is unverified, returning configuration {}", uvc.getId());
        increment(UNVERIFIED_RELYING_PARTY_COUNTER);
        return uvc;
    }

    /**
     * Select the default configuration after no registered configuration matched, logging the outcome and
     * counting its use.
     *
     * @return the default configuration, or null if none is set
     */
    @Nullable private RelyingPartyConfiguration selectDefaultConfiguration() {
        // Read the field once so the null check and the returned value cannot disagree.
        final RelyingPartyConfiguration defaultConfig = defaultRelyingPartyConfiguration;
        if (defaultConfig == null) {
            log.warn("No matching relying party configuration applicable, returning nothing");
            return null;
        }
        log.debug("No matching relying party configuration applicable, returning default: {}",
                defaultConfig.getId());
        increment(DEFAULT_RELYING_PARTY_COUNTER);
        return defaultConfig;
    }

    /**
     * Test a single registered configuration against the request, logging the result and incrementing the
     * configuration's counter when it applies.
     *
     * @param configuration configuration to test
     * @param context profile request context derived from the criteria, possibly null
     *
     * @return true iff the configuration's predicate accepts the context
     */
    private boolean testConfiguration(@Nonnull final RelyingPartyConfiguration configuration,
            @Nullable final ProfileRequestContext context) {
        log.debug("Checking if relying party configuration {} is applicable", configuration.getId());
        if (configuration.test(context)) {
            log.debug("Relying party configuration {} is applicable", configuration.getId());
            increment(configuration.getId());
            return true;
        }
        log.debug("Relying party configuration {} is not applicable", configuration.getId());
        return false;
    }
    
    /**
     * Get the {@link ProfileRequestContext} included in the input criteria, if any, or else synthesize one
     * from the entityID / role descriptor criteria.
     * 
     * @param criteria input criteria
     * 
     * @return embedded or synthesized profile request context, or null if the criteria supply nothing usable
     */
    @Nullable private ProfileRequestContext getProfileRequestContext(@Nonnull final CriteriaSet criteria) {
        
        final ProfileRequestContextCriterion prcCriterion = criteria.get(ProfileRequestContextCriterion.class);
        if (prcCriterion != null) {
            return prcCriterion.getProfileRequestContext();
        }

        // Resolve each piece exactly once: the entity descriptor is derived from the role descriptor,
        // and the entityID falls back to the entity descriptor.
        final RoleDescriptor roleDescriptor = resolveRoleDescriptor(criteria);
        final EntityDescriptor entityDescriptor = resolveEntityDescriptor(roleDescriptor);
        final String entityID = resolveEntityID(criteria, entityDescriptor);

        log.debug("Resolved effective entityID from criteria: {}", entityID);
        log.debug("Resolved effective entity descriptor from criteria: {}", entityDescriptor);
        log.debug("Resolved effective role descriptor from criteria: {}", roleDescriptor);

        if (entityID == null && entityDescriptor == null && roleDescriptor == null) {
            return null;
        }

        final ProfileRequestContext prc = new ProfileRequestContext();
        final RelyingPartyContext rpc = prc.ensureSubcontext(RelyingPartyContext.class);
        rpc.setVerified(true);

        rpc.setRelyingPartyId(entityID);

        if (entityDescriptor != null || roleDescriptor != null) {
            final SAMLPeerEntityContext peerContext = prc.ensureSubcontext(SAMLPeerEntityContext.class);
            rpc.setRelyingPartyIdContextTree(peerContext);

            peerContext.setEntityId(entityID);

            if (roleDescriptor != null) {
                // Prefer the xsi:type (schema type) when present, otherwise the element name.
                peerContext.setRole(roleDescriptor.getSchemaType() != null
                        ? roleDescriptor.getSchemaType() : roleDescriptor.getElementQName());
            }

            final SAMLMetadataContext metadataContext = peerContext.ensureSubcontext(SAMLMetadataContext.class);
            metadataContext.setEntityDescriptor(entityDescriptor);
            metadataContext.setRoleDescriptor(roleDescriptor);
        }
        return prc;
    }
    
    /**
     * Resolve the entityID from the criteria, falling back to the supplied entity descriptor.
     * 
     * @param criteria the input criteria
     * @param entityDescriptor entity descriptor already resolved from the criteria, possibly null
     * 
     * @return the input entityID criterion or null if could not be resolved
     */
    @Nullable private String resolveEntityID(@Nonnull final CriteriaSet criteria,
            @Nullable final EntityDescriptor entityDescriptor) {
        final EntityIdCriterion eic = criteria.get(EntityIdCriterion.class);
        if (eic != null) {
            return eic.getEntityId();
        }

        return entityDescriptor != null ? entityDescriptor.getEntityID() : null;
    }

    /**
     * Resolve the EntityDescriptor as the parent of the supplied role descriptor.
     *
     * @param roleDescriptor role descriptor already resolved from the criteria, possibly null
     * 
     * @return the parent entity descriptor, or null if there is no role or its parent is not an entity
     */
    @Nullable private EntityDescriptor resolveEntityDescriptor(@Nullable final RoleDescriptor roleDescriptor) {
        // instanceof is false for null, so no separate null check on the parent is needed.
        if (roleDescriptor != null && roleDescriptor.getParent() instanceof EntityDescriptor) {
            return (EntityDescriptor) roleDescriptor.getParent();
        }

        return null;
    }

    /**
     * Resolve the RoleDescriptor from the criteria.
     *
     * @param criteria the input criteria
     * 
     * @return the input role descriptor criterion or null if could not be resolved
     */
    @Nullable private RoleDescriptor resolveRoleDescriptor(@Nonnull final CriteriaSet criteria) {
        final RoleDescriptorCriterion rdc = criteria.get(RoleDescriptorCriterion.class);
        if (rdc != null) {
            return rdc.getRole();
        }

        return null;
    }

}