/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.shared.spring.config;

import java.util.HashMap;
import java.util.Map;

import javax.annotation.Nonnull;

import org.springframework.beans.BeansException;
import org.springframework.beans.PropertyValue;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.beans.factory.support.ManagedSet;

import net.shibboleth.shared.annotation.ParameterName;
import net.shibboleth.shared.collection.CollectionSupport;
import net.shibboleth.shared.primitive.DeprecationSupport;
import net.shibboleth.shared.primitive.StringSupport;
import net.shibboleth.shared.primitive.DeprecationSupport.ObjectType;

/**
 * A {@link BeanFactoryPostProcessor} to rewrite and log on relocated classes or parent beans.
 * 
 * @since 9.0.0
 */
public class RelocatedBeanFactoryPostProcessor implements BeanFactoryPostProcessor {
    
    /** Relocated classes. */
    @Nonnull private Map<String,String> movedClasses;

    /** Relocated beans. */
    @Nonnull private Map<String,String> movedBeans;

    /** Constructor. */
    public RelocatedBeanFactoryPostProcessor() {
        movedClasses = CollectionSupport.emptyMap();
        movedBeans = CollectionSupport.emptyMap();
    }
    
    /**
     * Set class names to rewrite.
     * 
     * @param classes classes to detect with replacements identified
     */
    public void setClasses(@Nonnull @ParameterName(name="classes") final Map<String,String> classes) {
        movedClasses = new HashMap<>(classes.size());
        classes.forEach((k,v) -> {
            final String key = StringSupport.trimOrNull(k);
            final String val = StringSupport.trimOrNull(v);
            if (key != null) {
                movedClasses.put(key, val);
            }
        });
    }

    /**
     * Set bean names to rewrite.
     * 
     * @param beans beans to detect with replacements identified
     */
    public void setBeans(@Nonnull @ParameterName(name="classes") final Map<String,String> beans) {
        movedBeans = new HashMap<>(beans.size());
        beans.forEach((k,v) -> {
            final String key = StringSupport.trimOrNull(k);
            final String val = StringSupport.trimOrNull(v);
            if (key != null) {
                movedBeans.put(key, val);
            }
        });
    }

    /**
     * Process a bean definition by relocating its class or parent if appropriate.
     *
     * @param def bean definition to be processed
     * @param name name of the bean definition
     */
    private void processBeanDefinition(@Nonnull final BeanDefinition def, @Nonnull final String name) {
        final String className = def.getBeanClassName();
        if (className != null && movedClasses.containsKey(className)) {
            DeprecationSupport.warn(ObjectType.CLASS, className, "Bean ID: " + name, movedClasses.get(className));
            def.setBeanClassName(movedClasses.get(className));
        }
        
        final String parentName = def.getParentName();
        if (parentName != null && movedBeans.containsKey(parentName)) {
            DeprecationSupport.warn(ObjectType.BEAN, parentName, "Bean ID: " + name, movedBeans.get(parentName));
            def.setParentName(movedBeans.get(parentName));
        }

        // Look recursively inside any property values for nested and potentially un-named bean definitions
        for (final PropertyValue property : def.getPropertyValues()) {
            final var propValue = property.getValue();
            if (propValue instanceof BeanDefinitionHolder defHolder) {
                // Handle potentially unnamed bean definitions
                processBeanDefinition(defHolder.getBeanDefinition(), defHolder.getBeanName());
            } else if (propValue instanceof ManagedList<?> pv) {
                for (final var value : pv) {
                    if (value instanceof BeanDefinitionHolder defHolder) {
                        // Handle potentially unnamed bean definitions within lists
                        processBeanDefinition(defHolder.getBeanDefinition(), defHolder.getBeanName());
                    }
                }
            } else if (propValue instanceof ManagedSet<?> pv) {
                for (final var value : pv) {
                    if (value instanceof BeanDefinitionHolder defHolder) {
                        // Handle potentially unnamed bean definitions within sets
                        processBeanDefinition(defHolder.getBeanDefinition(), defHolder.getBeanName());
                    }
                }
            }
        }
    }

    /** {@inheritDoc} */
    public void postProcessBeanFactory(@Nonnull final ConfigurableListableBeanFactory beanFactory)
            throws BeansException {
        
        for (final String name : beanFactory.getBeanDefinitionNames()) {
            assert name != null;
            final BeanDefinition def = beanFactory.getBeanDefinition(name);

            // Handle the named bean definition
            processBeanDefinition(def, name);
        }
    }

}