package org.kie.kogito.explainability.local.lime.optim;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.kogito.explainability.utils.LocalSaliencyStability;
import org.optaplanner.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * OptaPlanner {@link EasyScoreCalculator} that scores a candidate LIME configuration by the
 * stability of the saliencies it produces.
 * <p>
 * The solution is converted into a {@link LimeConfig}. For every prediction held by the solution,
 * {@link ExplainabilityMetrics#getLocalSaliencyStability} is computed with a {@link LimeExplainer}
 * built from that configuration. The positive and negative top-k stability scores of every decision
 * are then averaged into a single {@link BigDecimal} score.
 * <p>
 * Predictions whose stability evaluation fails or times out are logged and skipped. If the thread is
 * interrupted, evaluation stops and the score of the decisions evaluated so far is returned.
 */
public class LimeStabilityScoreCalculator implements EasyScoreCalculator<LimeStabilitySolution, SimpleBigDecimalScore> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LimeStabilityScoreCalculator.class);

    /** Number of top features whose stability is scored: k = 1 .. TOP_K. */
    private static final int TOP_K = 2;

    /**
     * Last argument passed to {@code getLocalSaliencyStability}. Presumably this is the number of
     * repeated explanation runs per prediction; confirm against ExplainabilityMetrics.
     */
    private static final int NUM_RUNS = 5;

    /**
     * Decimal places kept by every division. {@code divide(divisor, roundingMode)} would inherit the
     * dividend's scale, which is often just 1 here, and would round small quotients away.
     */
    private static final int SCALE = 10;

    private static final BigDecimal TOP_K_DECIMAL = BigDecimal.valueOf(TOP_K);

    /** Divisor that averages the positive and negative stability contributions. */
    private static final BigDecimal TWO = BigDecimal.valueOf(2);

    /**
     * Computes the stability score of the LIME configuration encoded in the given solution.
     *
     * @param limeStabilitySolution the candidate solution holding the LIME parameters, the model and
     *                              the predictions to explain
     * @return the averaged stability score, or zero when the solution has no predictions
     */
    public SimpleBigDecimalScore calculateScore(LimeStabilitySolution limeStabilitySolution) {
        LimeConfig limeConfig = LimeConfigEntityFactory.toLimeConfig(limeStabilitySolution);
        List<Prediction> predictions = limeStabilitySolution.getPredictions();
        BigDecimal stabilityScore = predictions.isEmpty()
                ? BigDecimal.ZERO
                : getStabilityScore(limeStabilitySolution, limeConfig, predictions);
        return SimpleBigDecimalScore.of(stabilityScore);
    }

    /**
     * Sums the marginal stability score of every decision of every prediction.
     * The sum is divided by the number of decisions that were successfully evaluated.
     *
     * @param limeStabilitySolution solution providing the model to explain
     * @param limeConfig            LIME configuration under evaluation
     * @param list                  predictions to explain (non-empty)
     * @return the mean marginal score, or zero when no decision could be evaluated
     */
    private BigDecimal getStabilityScore(LimeStabilitySolution limeStabilitySolution, LimeConfig limeConfig, List<Prediction> list) {
        BigDecimal scoreSum = BigDecimal.ZERO;
        long evaluatedDecisions = 0;
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        for (Prediction prediction : list) {
            try {
                LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(
                        limeStabilitySolution.getModel(), prediction, limeExplainer, TOP_K, NUM_RUNS);
                for (String decision : stability.getDecisions()) {
                    scoreSum = scoreSum.add(getDecisionMarginalScore(TOP_K_DECIMAL, stability, decision));
                    evaluatedDecisions++;
                }
            } catch (InterruptedException e) {
                LOGGER.error("Interrupted while waiting for saliency stability calculation {}", e.getMessage(), e);
                Thread.currentThread().interrupt();
                // Honour the interrupt: stop launching further stability computations.
                break;
            } catch (ExecutionException e) {
                // Best effort: a failing prediction is skipped, but its cause is logged.
                LOGGER.error("Saliency stability calculation returned an error {}", e.getMessage(), e);
            } catch (TimeoutException e) {
                LOGGER.error("Timed out while waiting for saliency stability calculation", e);
            }
        }
        if (evaluatedDecisions == 0) {
            return scoreSum;
        }
        // NOTE(review): each marginal score is already divided by the number of decisions of its
        // prediction. Dividing the sum again by the evaluated-decision count yields a plain mean
        // only for single-output models. Confirm that the extra normalisation is intended.
        return scoreSum.divide(BigDecimal.valueOf(evaluatedDecisions), SCALE, RoundingMode.CEILING);
    }

    /**
     * Computes one decision's contribution to the stability score.
     * The contribution is (mean positive stability over k = 1..topK + mean negative stability over
     * k = 1..topK), divided by twice the number of decisions in {@code stability}.
     *
     * @param topK      number of top features considered (as a BigDecimal divisor)
     * @param stability stability results for a single prediction
     * @param decision  name of the decision (output) to score
     * @return the decision's marginal score
     */
    private BigDecimal getDecisionMarginalScore(BigDecimal topK, LocalSaliencyStability stability, String decision) {
        BigDecimal positiveSum = BigDecimal.ZERO;
        BigDecimal negativeSum = BigDecimal.ZERO;
        int k = topK.intValue();
        for (int i = 1; i <= k; i++) {
            positiveSum = positiveSum.add(BigDecimal.valueOf(stability.getPositiveStabilityScore(decision, i)));
            negativeSum = negativeSum.add(BigDecimal.valueOf(stability.getNegativeStabilityScore(decision, i)));
        }
        BigDecimal positiveMean = positiveSum.divide(topK, SCALE, RoundingMode.CEILING);
        BigDecimal negativeMean = negativeSum.divide(topK, SCALE, RoundingMode.CEILING);
        BigDecimal normalizer = TWO.multiply(BigDecimal.valueOf(stability.getDecisions().size()));
        return positiveMean.add(negativeMean).divide(normalizer, SCALE, RoundingMode.CEILING);
    }
}
