/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.ui.csrf.impl;

import java.util.function.BiPredicate;
import java.util.function.Predicate;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.slf4j.Logger;
import org.springframework.webflow.core.collection.MutableAttributeMap;
import org.springframework.webflow.definition.StateDefinition;
import org.springframework.webflow.execution.Event;
import org.springframework.webflow.execution.FlowExecutionListener;
import org.springframework.webflow.execution.FlowSession;
import org.springframework.webflow.execution.RequestContext;
import org.springframework.webflow.execution.View;

import net.shibboleth.idp.ui.csrf.CSRFToken;
import net.shibboleth.idp.ui.csrf.CSRFTokenManager;
import net.shibboleth.idp.ui.csrf.InvalidCSRFTokenException;
import net.shibboleth.shared.annotation.constraint.NonnullAfterInit;
import net.shibboleth.shared.component.AbstractInitializableComponent;
import net.shibboleth.shared.component.ComponentInitializationException;
import net.shibboleth.shared.logic.Constraint;
import net.shibboleth.shared.primitive.LoggerFactory;

/**
 * A flow execution lifecycle listener that, if enabled:
 * <ol>
 *      <li>Sets an anti-CSRF token into the flow-scope map when a flow session starts and a token per-flow is 
 *      enabled.</li>
 *      <li>Sets an anti-CSRF token into the view-scope map when rendering a suitable view-state. This token is 
 *      either retrieved from the flow-scope, if available from step 1, or generated anew.</li>
 *      <li>Checks the CSRF token in a HTTP request matches that stored in the view-scope map when a suitable 
 *       view-state event occurs.</li>
 * </ol>
 */
public class CSRFTokenFlowExecutionListener extends AbstractInitializableComponent implements FlowExecutionListener {
    
    /** The name of the view scope parameter that holds the CSRF token. */
    @Nonnull public static final String CSRF_TOKEN_VIEWSCOPE_NAME = "csrfToken";

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(CSRFTokenFlowExecutionListener.class);    
 
    /** Should the request context and event be checked for a valid (matching) CSRF token? */
    @NonnullAfterInit private BiPredicate<RequestContext,Event> eventRequiresCSRFTokenValidationPredicate;

    /** Does the view being rendered require a CSRF token to be set? */
    @NonnullAfterInit private Predicate<RequestContext> viewRequiresCSRFTokenPredicate;

    /** Is this listener enabled? Defaults to {@literal false}. */
    private boolean enabled;
    
    /** Should a new token be created for each flow session rather than for each view? Defaults to {@literal false}. */
    private boolean tokenPerFlow;

    /** The CSRF token manager for generating and validating tokens. */    
    @NonnullAfterInit private CSRFTokenManager csrfTokenManager;
    
    
    /** Constructor. Both {@code enabled} and {@code tokenPerFlow} start out {@literal false}. */
    public CSRFTokenFlowExecutionListener() {
        enabled = false;
        tokenPerFlow = false;
    }
    
    /**
     * Set whether CSRF protection is globally enabled or disabled. 
     * 
     * @param enable enabled/disable CSRF protection (default is {@literal false}).
     */
    public void setEnabled(final boolean enable) {
        checkSetterPreconditions();
        enabled = enable;
    }
    
    /**
     * Sets whether a new token should be created for each flow session rather than for each view.
     * 
     * @param flag enable or disable the token per flow pattern (default is {@literal false}).
     */
    public void setTokenPerFlow(final boolean flag) {
        checkSetterPreconditions();
        tokenPerFlow = flag;
    }
    
    /**
     * Sets the request context condition to determine if a CSRF token should be added to the view-scope.
     *  
     * @param condition the condition to apply.
     */
    public void setViewRequiresCSRFTokenPredicate(@Nonnull final Predicate<RequestContext> condition) {
        checkSetterPreconditions();
        viewRequiresCSRFTokenPredicate = Constraint.isNotNull(condition, 
                        "Does view require CSRF token predicate can not be null");
    }
    
    /**
     * Set the request context and event condition to determine if a CSRF token should be validated.
     *  
     * @param condition the condition to apply
     */
    public void setEventRequiresCSRFTokenValidationPredicate(
            @Nonnull final BiPredicate<RequestContext,Event> condition) {
        checkSetterPreconditions();
        eventRequiresCSRFTokenValidationPredicate = Constraint.isNotNull(condition, 
                "Validate CSRF token condition cannot be null");
    }

    /**
     * Sets the CSRF token manager.
     * 
     * @param tokenManager the CSRF token manager.
     */
    public void setCsrfTokenManager(@Nonnull final CSRFTokenManager tokenManager) {    
        checkSetterPreconditions();
        csrfTokenManager = Constraint.isNotNull(tokenManager, "CSRF Token manager can not be null");
    }
    
    /**
     * {@inheritDoc}
     * 
     * <p>If the listener is enabled and per flow-session tokens are enabled, creates a CSRF token and adds it to 
     * the request context flow scope so that {@link #viewRendering(RequestContext, View, StateDefinition)} can 
     * copy it into the view scope later on.</p>
     */
    @Override
    public void sessionStarting(final RequestContext context, final FlowSession session, 
            final MutableAttributeMap<?> input) {

        if (enabled && tokenPerFlow) {
            context.getFlowScope().put(CSRF_TOKEN_VIEWSCOPE_NAME, csrfTokenManager.generateCSRFToken()); 
        }
    }
    
    /**
     * Places a CSRF token into the request context view scope, overwriting any existing view-scoped token.
     * 
     * <p>Only applies if the listener is enabled, the state is a view-state, and the
     * <code>viewRequiresCSRFTokenPredicate</code> condition matches. If a {@link CSRFToken} is present in the flow 
     * scope (the token-per-flow case), that token is reused; otherwise a new token is generated by the token 
     * manager.</p>
     * 
     * {@inheritDoc}
     */
    @Override
    public void viewRendering(final RequestContext context, final View view,
            final StateDefinition viewState) {

        // State here should always be a view-state, but guard anyway.
        if (enabled && viewState.isViewState() && viewRequiresCSRFTokenPredicate.test(context)) {
            final Object flowScopedCsrfTokenObject = context.getFlowScope().get(CSRF_TOKEN_VIEWSCOPE_NAME);
            if (flowScopedCsrfTokenObject instanceof final CSRFToken token) {
                // Token-per-flow: reuse the token created when the flow session started.
                context.getViewScope().put(CSRF_TOKEN_VIEWSCOPE_NAME, token);
            } else {
                // Token-per-view (or no flow-scoped token available): mint a fresh one.
                context.getViewScope().put(CSRF_TOKEN_VIEWSCOPE_NAME, csrfTokenManager.generateCSRFToken());
            }
        }
    }

    /**
     * Checks that the CSRF token in the HTTP request matches the one stored in the request context view scope.
     * 
     * <p>Only applies if the listener is enabled, the current state is a view-state, and the request context and 
     * event match the <code>eventRequiresCSRFTokenValidationPredicate</code> condition.</p>
     * 
     * <p>Invalid tokens (not found in the view scope, not found in the request, or not matching) are signalled by 
     * throwing an {@link InvalidCSRFTokenException}.</p>
     * 
     * {@inheritDoc}
     */
    @Override
    public void eventSignaled(final @Nullable RequestContext context, final @Nullable Event event) {

        assert context != null && event != null;
        // Always make sure the listener is enabled and the current state is an active view-state.
        if (!enabled || !context.inViewState()
                || !eventRequiresCSRFTokenValidationPredicate.test(context, event)) {
            return;
        }

        final String stateId = context.getCurrentState().getId();
        
        log.trace("Event '{}' signaled from view '{}' requires a CSRF token", event.getId(),stateId);

        final Object storedCsrfTokenObject = context.getViewScope().get(CSRF_TOKEN_VIEWSCOPE_NAME);
        final String activeFlowId = context.getActiveFlow().getId();
        assert activeFlowId != null;
        // instanceof is false for null, so this covers both "missing" and "wrong type".
        if (!(storedCsrfTokenObject instanceof final CSRFToken storedCsrfToken)) {
            log.warn("CSRF token is required but was not found in the view-scope; for "
                    + "view-state '{}' and event '{}'.",stateId,event.getId());
            throw new InvalidCSRFTokenException(activeFlowId, stateId,
                    "Invalid CSRF token");               
        }

        // External context and request parameter map should never be null.
        final Object requestTokenObject =
                context.getExternalContext().getRequestParameterMap().get(storedCsrfToken.getParameterName());

        if (!(requestTokenObject instanceof final String csrfTokenFromRequest)) {    
            log.warn("CSRF token is required but was not found in the request; for "
                    + "view-state '{}' and event '{}'.",stateId,event.getId());
            throw new InvalidCSRFTokenException(activeFlowId, stateId,
                    "Invalid CSRF token");
        }

        log.trace("Stored (viewScoped) CSRF Token '{}', CSRF Token in HTTP request '{}'", 
                storedCsrfToken.getToken(), csrfTokenFromRequest);

        if (!csrfTokenManager.isValidCSRFToken(storedCsrfToken, csrfTokenFromRequest)) {
            log.warn("CSRF token in the request did not match that stored in the view-scope; for "
                    + "view-state '{}' and event '{}'.",stateId,event.getId());
            throw new InvalidCSRFTokenException(activeFlowId, stateId,
                    "Invalid CSRF token");
        }
    } 


    /**
     * {@inheritDoc}
     * 
     * <p>Ensures the token manager and both predicates have been supplied.</p>
     */
    @Override
    public void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        
        if (csrfTokenManager == null) {
            throw new ComponentInitializationException("CSRF token manager can not be null");
        }
        if (viewRequiresCSRFTokenPredicate == null) {
            throw new ComponentInitializationException("View requires CSRF token predicate can not be null");
        }
        if (eventRequiresCSRFTokenValidationPredicate == null) {
            throw new ComponentInitializationException("Event requires CSRF token validation predicate can "
                    + "not be null");
        }
    }

}
