package org.kie.kogito.explainability.local.lime.optim;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutionException;

import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.optaplanner.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;
import org.optaplanner.core.api.solver.SolverManager;
import org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig;
import org.optaplanner.core.config.localsearch.LocalSearchType;
import org.optaplanner.core.config.score.director.ScoreDirectorFactoryConfig;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.SolverManagerConfig;
import org.optaplanner.core.config.solver.termination.TerminationConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tunes the hyper-parameters of a {@link LimeConfig} with an OptaPlanner late-acceptance local search.
 *
 * <p>The groups of parameters that are searched (sampling, proximity, encoding, weighting) are
 * selected through the fluent {@code with*} methods. Candidate configurations are scored by an
 * {@link EasyScoreCalculator}, which is a {@link LimeStabilityScoreCalculator} unless configured otherwise.
 * All fluent methods mutate this instance and return it for chaining.
 */
public class LimeConfigOptimizer {

    private static final Logger logger = LoggerFactory.getLogger(LimeConfigOptimizer.class);

    /** Default solver budget in seconds (it is fed to {@link TerminationConfig#setSecondsSpentLimit}). */
    private static final long DEFAULT_TIME_LIMIT = 30;
    private static final boolean DEFAULT_PROXIMITY_ENTITIES = true;
    private static final boolean DEFAULT_SAMPLING_ENTITIES = true;
    private static final boolean DEFAULT_ENCODING_ENTITIES = true;
    private static final boolean DEFAULT_WEIGHTING_ENTITIES = true;

    /** Maximum allowed deviation from 1 of the sum of the negative and positive stability weights. */
    private static final double WEIGHT_SUM_TOLERANCE = 0.001d;

    private long timeLimit = DEFAULT_TIME_LIMIT;
    private EasyScoreCalculator<LimeConfigSolution, SimpleBigDecimalScore> scoreCalculator = new LimeStabilityScoreCalculator();
    private boolean proximityEntities = DEFAULT_PROXIMITY_ENTITIES;
    private boolean samplingEntities = DEFAULT_SAMPLING_ENTITIES;
    private boolean encodingEntities = DEFAULT_ENCODING_ENTITIES;
    private boolean weightingEntities = DEFAULT_WEIGHTING_ENTITIES;

    /**
     * Sets the solver time budget.
     *
     * @param timeLimit maximum solving time, in seconds
     * @return this optimizer
     */
    public LimeConfigOptimizer withTimeLimit(long timeLimit) {
        this.timeLimit = timeLimit;
        return this;
    }

    /**
     * Enables or disables tuning of the entities produced by
     * {@code LimeConfigEntityFactory.createProximityEntities}.
     *
     * @param proximity whether proximity parameters are searched
     * @return this optimizer
     */
    public LimeConfigOptimizer withProximity(boolean proximity) {
        this.proximityEntities = proximity;
        return this;
    }

    /**
     * Enables or disables tuning of the entities produced by
     * {@code LimeConfigEntityFactory.createSamplingEntities}.
     *
     * @param sampling whether sampling parameters are searched
     * @return this optimizer
     */
    public LimeConfigOptimizer withSampling(boolean sampling) {
        this.samplingEntities = sampling;
        return this;
    }

    /**
     * Enables or disables tuning of the entities produced by
     * {@code LimeConfigEntityFactory.createEncodingEntities}.
     *
     * @param encoding whether encoding parameters are searched
     * @return this optimizer
     */
    public LimeConfigOptimizer withEncoding(boolean encoding) {
        this.encodingEntities = encoding;
        return this;
    }

    /**
     * Enables or disables tuning of the entities produced by
     * {@code LimeConfigEntityFactory.createWeightingEntities}.
     *
     * @param weighting whether weighting parameters are searched
     * @return this optimizer
     */
    public LimeConfigOptimizer withWeighting(boolean weighting) {
        this.weightingEntities = weighting;
        return this;
    }

    /**
     * Replaces the score calculator used to judge candidate configurations.
     *
     * <p>NOTE(review): {@link #optimize} hands only the calculator's class to OptaPlanner. The
     * solver therefore appears to build its own instance, and any state held by this object is not
     * used. Confirm this before relying on stateful calculators.
     *
     * @param easyScoreCalculator the calculator; must not be {@code null}
     * @return this optimizer
     * @throws NullPointerException if {@code easyScoreCalculator} is {@code null}
     */
    public LimeConfigOptimizer withScoreCalculator(EasyScoreCalculator<LimeConfigSolution, SimpleBigDecimalScore> easyScoreCalculator) {
        // Fail here rather than with an NPE at scoreCalculator.getClass() inside optimize().
        this.scoreCalculator = Objects.requireNonNull(easyScoreCalculator, "easyScoreCalculator");
        return this;
    }

    /**
     * Uses a {@link LimeStabilityScoreCalculator} that weighs negative and positive stability.
     *
     * <p>NOTE(review): like {@link #withScoreCalculator}, only the calculator class reaches the solver.
     * The weights may therefore be ignored during {@link #optimize}. Confirm this.
     *
     * @param negativeWeight weight in [0, 1] given to negative stability
     * @param positiveWeight weight in [0, 1] given to positive stability
     * @return this optimizer
     * @throws IllegalArgumentException if a weight is outside [0, 1] or NaN, or if the two weights do
     *         not sum to 1 (within a tolerance of 0.001)
     */
    public LimeConfigOptimizer withWeightedStability(double negativeWeight, double positiveWeight) {
        // Written as "not inside the range" so that NaN is rejected too
        // (NaN would otherwise make BigDecimal.valueOf throw NumberFormatException below).
        if (!(negativeWeight >= 0.0d && negativeWeight <= 1.0d)) {
            throw new IllegalArgumentException("negative weight must be between 0 and 1");
        }
        if (!(positiveWeight >= 0.0d && positiveWeight <= 1.0d)) {
            throw new IllegalArgumentException("positive weight must be between 0 and 1");
        }
        if (Math.abs((1.0d - negativeWeight) - positiveWeight) > WEIGHT_SUM_TOLERANCE) {
            throw new IllegalArgumentException("negative and positive weights must sum up to 1");
        }
        this.scoreCalculator = new LimeStabilityScoreCalculator(BigDecimal.valueOf(negativeWeight), BigDecimal.valueOf(positiveWeight));
        return this;
    }

    /**
     * Searches for a better LIME configuration, starting from {@code limeConfig}.
     *
     * @param limeConfig the starting configuration; it is returned unchanged when no parameter
     *        group is enabled
     * @param list predictions handed to the solution; presumably used by the score calculator
     *        to evaluate candidates
     * @param predictionProvider model handed to the solution; presumably queried while scoring
     * @return the configuration derived from the best solution found before the time limit
     * @throws IllegalStateException if solving fails, or if the calling thread is interrupted
     *         (in that case the interrupt flag is restored)
     */
    // Entity lists are untyped here because this class never names their element type.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public LimeConfig optimize(LimeConfig limeConfig, List<Prediction> list, PredictionProvider predictionProvider) {
        List entities = collectEntities(limeConfig);
        if (entities.isEmpty()) {
            // Nothing is enabled for tuning, so skip the solver entirely.
            return limeConfig;
        }
        LimeConfigSolution initialSolution = new LimeConfigSolution(limeConfig, list, entities, predictionProvider);
        SolverConfig solverConfig = buildSolverConfig();

        try (SolverManager<LimeConfigSolution, UUID> solverManager =
                SolverManager.create(solverConfig, new SolverManagerConfig())) {
            // Blocks until the termination condition (time limit) is reached.
            LimeConfigSolution bestSolution =
                    solverManager.solve(UUID.randomUUID(), initialSolution).getFinalBestSolution();
            LimeConfig bestConfig = LimeConfigEntityFactory.toLimeConfig(bestSolution);
            logger.info("final best solution score {} with config {}", bestSolution.getScore().getScore(), bestConfig);
            return bestConfig;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Solving failed (Thread interrupted)", e);
        } catch (ExecutionException e) {
            logger.error("Solving failed: {}", e.getMessage());
            throw new IllegalStateException("Prediction returned an error", e);
        }
    }

    /** Gathers the planning entities of every enabled parameter group, in a fixed group order. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private List collectEntities(LimeConfig limeConfig) {
        List entities = new ArrayList();
        if (samplingEntities) {
            entities.addAll(LimeConfigEntityFactory.createSamplingEntities(limeConfig));
        }
        if (proximityEntities) {
            entities.addAll(LimeConfigEntityFactory.createProximityEntities(limeConfig));
        }
        if (encodingEntities) {
            entities.addAll(LimeConfigEntityFactory.createEncodingEntities(limeConfig));
        }
        if (weightingEntities) {
            entities.addAll(LimeConfigEntityFactory.createWeightingEntities(limeConfig));
        }
        return entities;
    }

    /** Builds the solver config: entity/solution classes, score calculator, time limit, late acceptance. */
    private SolverConfig buildSolverConfig() {
        SolverConfig solverConfig = new SolverConfig();
        solverConfig.withEntityClasses(NumericLimeConfigEntity.class, BooleanLimeConfigEntity.class);
        solverConfig.withSolutionClass(LimeConfigSolution.class);

        ScoreDirectorFactoryConfig scoreDirectorFactoryConfig = new ScoreDirectorFactoryConfig();
        // NOTE(review): only the class is passed, so OptaPlanner presumably instantiates it through a
        // no-arg constructor and any state of this.scoreCalculator (e.g. custom weights) is dropped.
        scoreDirectorFactoryConfig.setEasyScoreCalculatorClass(scoreCalculator.getClass());
        solverConfig.setScoreDirectorFactoryConfig(scoreDirectorFactoryConfig);

        TerminationConfig terminationConfig = new TerminationConfig();
        terminationConfig.setSecondsSpentLimit(timeLimit);
        solverConfig.setTerminationConfig(terminationConfig);

        LocalSearchPhaseConfig localSearchPhaseConfig = new LocalSearchPhaseConfig();
        localSearchPhaseConfig.setLocalSearchType(LocalSearchType.LATE_ACCEPTANCE);
        solverConfig.setPhaseConfigList(Collections.singletonList(localSearchPhaseConfig));
        return solverConfig;
    }

    /**
     * Switches to a {@link LimeImpactScoreCalculator}.
     *
     * @return this optimizer
     */
    public LimeConfigOptimizer forImpactScore() {
        this.scoreCalculator = new LimeImpactScoreCalculator();
        return this;
    }

    /**
     * Switches back to an unweighted {@link LimeStabilityScoreCalculator} (the default).
     *
     * @return this optimizer
     */
    public LimeConfigOptimizer forStabilityScore() {
        this.scoreCalculator = new LimeStabilityScoreCalculator();
        return this;
    }
}
