package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/kogito/explainability/utils/ExplainabilityMetrics.class */
public class ExplainabilityMetrics {
    private static final Logger LOGGER = LoggerFactory.getLogger(ExplainabilityMetrics.class);

    /**
     * In {@link #impactScore}, a modified output whose score falls below this fraction of the original output's
     * score is counted as changed, even if its value is the same.
     */
    private static final double CONFIDENCE_DROP_RATIO = 0.2d;

    private ExplainabilityMetrics() {
        // static utility class: not instantiable
    }

    /**
     * Combines three equally weighted (0.333) terms into a single score:
     * {@code 0.333 / inputChunks + 0.333 / outputChunks + 0.333 * (1 - interactionRatio)}.
     *
     * <p>NOTE(review): the parameter names are presumed meanings (counts of input / output "chunks" and an
     * interaction ratio), so confirm them against callers. The guard only returns {@code 0} when the two counts sum
     * to a non-positive value. If exactly one count is {@code 0}, its term divides by zero and the result is
     * {@link Double#POSITIVE_INFINITY}.
     *
     * @param inputChunks first count, used as a divisor
     * @param outputChunks second count, used as a divisor
     * @param interactionRatio ratio whose complement contributes the third term
     * @return the combined score, or {@code 0} if {@code inputChunks + outputChunks <= 0}
     */
    public static double quantifyExplainability(int inputChunks, int outputChunks, double interactionRatio) {
        if (inputChunks + outputChunks > 0) {
            return (0.333d / inputChunks) + (0.333d / outputChunks) + (0.333d * (1.0d - interactionRatio));
        }
        return 0.0d;
    }

    /**
     * Drops the given features from the prediction's input, runs the model on the reduced input, and measures how
     * much the outputs changed.
     *
     * <p>Outputs are compared position by position. An output counts as changed when its string value differs from
     * the original, or when its score drops below {@link #CONFIDENCE_DROP_RATIO} times the original score. Each changed
     * output adds {@code 1 / n}, where {@code n} is the number of outputs in that model result. Only positions present
     * in both the original and the modified output lists are compared.
     *
     * @param predictionProvider the model to query
     * @param prediction the original prediction (input and outputs)
     * @param topFeatures features to remove from the input
     * @return the accumulated impact
     * @throws IllegalStateException if the model call is interrupted, fails, or times out
     */
    public static double impactScore(PredictionProvider predictionProvider, Prediction prediction, List<FeatureImportance> topFeatures) throws InterruptedException, ExecutionException, TimeoutException {
        List<Feature> remainingFeatures = List.copyOf(prediction.getInput().getFeatures());
        for (FeatureImportance featureImportance : topFeatures) {
            remainingFeatures = DataUtils.dropFeature(remainingFeatures, featureImportance.getFeature());
        }

        List<PredictionOutput> modifiedPredictions;
        try {
            modifiedPredictions = predictionProvider.predictAsync(List.of(new PredictionInput(remainingFeatures)))
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Impossible to obtain prediction (Thread interrupted)", e);
        } catch (ExecutionException | TimeoutException e) {
            LOGGER.error("Impossible to obtain prediction {}", e.getMessage());
            throw new IllegalStateException("Impossible to obtain prediction", e);
        }

        List<Output> originalOutputs = prediction.getOutput().getOutputs();
        double impact = 0.0d;
        for (PredictionOutput predictionOutput : modifiedPredictions) {
            List<Output> modifiedOutputs = predictionOutput.getOutputs();
            double size = modifiedOutputs.size();
            // Bound by both lists so a model returning extra outputs cannot trigger an IndexOutOfBoundsException.
            int comparable = Math.min(modifiedOutputs.size(), originalOutputs.size());
            for (int i = 0; i < comparable; i++) {
                Output original = originalOutputs.get(i);
                Output modified = modifiedOutputs.get(i);
                boolean changed = !original.getValue().asString().equals(modified.getValue().asString())
                        || modified.getScore() < original.getScore() * CONFIDENCE_DROP_RATIO;
                if (changed) {
                    impact += 1.0d / size;
                }
            }
        }
        return impact;
    }

    /**
     * For each {@link Type#BOOLEAN} output, checks whether the output's numeric value and the sum of the paired
     * saliency's feature scores have the same sign (both {@code >= 0} or both {@code < 0}).
     *
     * @param pairs saliency / prediction pairs to evaluate
     * @return the fraction of evaluated boolean outputs whose sign agrees, or {@code 0} if none were evaluated
     */
    public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
        double agreeing = 0.0d;
        double evaluated = 0.0d;
        for (Pair<Saliency, Prediction> pair : pairs) {
            Saliency saliency = pair.getLeft();
            for (Output output : pair.getRight().getOutput().getOutputs()) {
                if (Type.BOOLEAN.equals(output.getType())) {
                    double importanceSum = saliency.getPerFeatureImportance().stream()
                            .mapToDouble(FeatureImportance::getScore)
                            .sum();
                    double outputValue = output.getValue().asNumber();
                    // Kept as two explicit clauses (not a boolean ==) so NaN values never count as agreement.
                    if ((outputValue >= 0.0d && importanceSum >= 0.0d) || (outputValue < 0.0d && importanceSum < 0.0d)) {
                        agreeing += 1.0d;
                    }
                    evaluated += 1.0d;
                }
            }
        }
        if (evaluated == 0.0d) {
            return 0.0d;
        }
        return agreeing / evaluated;
    }

    /**
     * Runs the explainer {@code runs} times on the same prediction. For each decision and each {@code k} in
     * {@code [1, topK]}, it records the most frequent top-k positive and negative feature-name lists. Each list's
     * stability is the fraction of the retained saliencies that produced exactly that list.
     *
     * @param predictionProvider the model to explain
     * @param prediction the prediction to explain
     * @param localExplainer the explainer producing per-decision saliencies
     * @param topK largest k to evaluate
     * @param runs number of explanation runs
     * @return the collected stability figures
     */
    public static LocalSaliencyStability getLocalSaliencyStability(PredictionProvider predictionProvider, Prediction prediction, LocalExplainer<Map<String, Saliency>> localExplainer, int topK, int runs) throws InterruptedException, ExecutionException, TimeoutException {
        Map<String, List<Saliency>> saliencies = getMultipleSaliencies(predictionProvider, prediction, localExplainer, runs);
        LocalSaliencyStability stability = new LocalSaliencyStability(saliencies.keySet());
        for (Map.Entry<String, List<Saliency>> entry : saliencies.entrySet()) {
            String decision = entry.getKey();
            List<Saliency> perDecisionSaliencies = entry.getValue();
            // Floating-point total: long / int division would truncate every partial agreement to 0.
            double total = perDecisionSaliencies.size();
            for (int k = 1; k <= topK; k++) {
                int finalK = k;
                Pair<List<String>, Long> positive = getMostFrequent(
                        getTopKFeaturesFrequency(perDecisionSaliencies, saliency -> saliency.getPositiveFeatures(finalK)));
                Pair<List<String>, Long> negative = getMostFrequent(
                        getTopKFeaturesFrequency(perDecisionSaliencies, saliency -> saliency.getNegativeFeatures(finalK)));
                stability.add(decision, k,
                        positive.getKey(), positive.getValue() / total,
                        negative.getKey(), negative.getValue() / total);
            }
        }
        return stability;
    }

    /**
     * Collects the saliencies produced by {@code runs} explainer invocations, grouped by decision name. Saliencies
     * whose top feature is missing or has a zero score are skipped, so a decision is present only if it has at least
     * one retained saliency.
     */
    private static Map<String, List<Saliency>> getMultipleSaliencies(PredictionProvider predictionProvider, Prediction prediction, LocalExplainer<Map<String, Saliency>> localExplainer, int runs) throws InterruptedException, ExecutionException, TimeoutException {
        Map<String, List<Saliency>> saliencies = new HashMap<>();
        int skipped = 0;
        for (int run = 0; run < runs; run++) {
            Map<String, Saliency> explanation = explain(localExplainer, prediction, predictionProvider);
            for (Map.Entry<String, Saliency> entry : explanation.entrySet()) {
                List<FeatureImportance> topFeatures = entry.getValue().getTopFeatures(1);
                if (topFeatures.isEmpty() || topFeatures.get(0).getScore() == 0.0d) {
                    LOGGER.debug("skipping empty / zero saliency for {}", entry.getKey());
                    skipped++;
                } else {
                    saliencies.computeIfAbsent(entry.getKey(), key -> new ArrayList<>()).add(entry.getValue());
                }
            }
        }
        LOGGER.debug("skipped {} useless saliencies", skipped);
        return saliencies;
    }

    /** Counts how often each list of feature names (as produced by {@code extractor}) occurs across saliencies. */
    private static Map<List<String>, Long> getTopKFeaturesFrequency(List<Saliency> saliencies, Function<Saliency, List<FeatureImportance>> extractor) {
        return saliencies.stream()
                .map(extractor)
                .map(features -> features.stream()
                        .map(featureImportance -> featureImportance.getFeature().getName())
                        .collect(Collectors.toList()))
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    /** Returns the entry with the highest count; throws {@link java.util.NoSuchElementException} on an empty map. */
    private static Pair<List<String>, Long> getMostFrequent(Map<List<String>, Long> frequencies) {
        Map.Entry<List<String>, Long> mostFrequent = Collections.max(frequencies.entrySet(), Map.Entry.comparingByValue());
        return Pair.of(mostFrequent.getKey(), mostFrequent.getValue());
    }

    /**
     * Estimates saliency recall for {@code outputName}.
     *
     * <p>The method pairs the {@code chunkSize} highest-scored predictions with the {@code chunkSize} lowest-scored
     * ones. For each top prediction, it takes the {@code k} features with the highest importance. It writes them into
     * the paired bottom prediction's input, re-runs the model, and checks whether the masked output equals the top
     * prediction's original output. Equal counts as a true positive, different as a false negative.
     *
     * @return {@code tp / (tp + fn)}, or {@link Double#NaN} if no pair could be evaluated
     * @throws IndexOutOfBoundsException if {@code chunkSize} exceeds the number of available predictions
     */
    public static double getLocalSaliencyRecall(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
        List<Prediction> sorted = DataUtils.getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
        List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
        List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));

        double truePositives = 0.0d;
        double falseNegatives = 0.0d;
        int currentChunk = 0;
        for (Prediction prediction : topChunk) {
            Optional<Output> optionalOutput = prediction.getOutput().getByName(outputName);
            if (!optionalOutput.isPresent()) {
                continue;
            }
            Output output = optionalOutput.get();
            Map<String, Saliency> saliencies = explain(localExplainer, prediction, predictionProvider);
            if (!saliencies.containsKey(outputName)) {
                continue;
            }
            List<FeatureImportance> mostPositive = selectFeatures(saliencies.get(outputName),
                    Comparator.<FeatureImportance>comparingDouble(FeatureImportance::getScore).reversed(), k);
            PredictionInput bottomInput = bottomChunk.get(currentChunk).getInput();
            currentChunk++;

            Optional<Output> maskedOutput = predictMaskedOutput(outputName, predictionProvider, mostPositive, bottomInput);
            if (maskedOutput.isPresent()) {
                // Compare the original top output with the masked prediction (previously compared it to itself).
                if (output.getValue().equals(maskedOutput.get().getValue())) {
                    truePositives += 1.0d;
                } else {
                    falseNegatives += 1.0d;
                }
            }
        }
        return ratio(truePositives, falseNegatives);
    }

    /** Collects the {@link Feature}s of the given importances and writes them into a copy of {@code input}. */
    private static PredictionInput maskInput(List<FeatureImportance> featureImportances, PredictionInput input) {
        List<Feature> replacements = featureImportances.stream()
                .map(FeatureImportance::getFeature)
                .collect(Collectors.toList());
        return replaceAllFeatures(replacements, input);
    }

    /**
     * Estimates saliency precision for {@code outputName}.
     *
     * <p>The method pairs the {@code chunkSize} lowest-scored predictions with the {@code chunkSize} highest-scored
     * ones. For each bottom prediction, it takes the {@code k} features with the lowest importance. It writes them
     * into the paired top prediction's input, re-runs the model, and compares the masked output with the top
     * prediction's original output.
     *
     * <p>NOTE(review): an unchanged output is counted in the numerator, which reads as the opposite of the usual
     * "masking the features flips the output" notion of a true positive. Confirm the intended definition.
     *
     * @return {@code unchanged / (unchanged + changed)}, or {@link Double#NaN} if no pair could be evaluated
     * @throws IndexOutOfBoundsException if {@code chunkSize} exceeds the number of available predictions
     */
    public static double getLocalSaliencyPrecision(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
        List<Prediction> sorted = DataUtils.getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
        List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
        List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));

        double unchanged = 0.0d;
        double changed = 0.0d;
        int currentChunk = 0;
        for (Prediction bottomPrediction : bottomChunk) {
            Map<String, Saliency> saliencies = explain(localExplainer, bottomPrediction, predictionProvider);
            if (!saliencies.containsKey(outputName)) {
                continue;
            }
            List<FeatureImportance> mostNegative = selectFeatures(saliencies.get(outputName),
                    Comparator.<FeatureImportance>comparingDouble(FeatureImportance::getScore), k);
            Prediction topPrediction = topChunk.get(currentChunk);
            currentChunk++;

            Optional<Output> maskedOutput = predictMaskedOutput(outputName, predictionProvider, mostNegative, topPrediction.getInput());
            if (maskedOutput.isPresent()) {
                Optional<Output> originalOutput = topPrediction.getOutput().getByName(outputName);
                if (originalOutput.isPresent()) {
                    if (originalOutput.get().getValue().equals(maskedOutput.get().getValue())) {
                        unchanged += 1.0d;
                    } else {
                        changed += 1.0d;
                    }
                }
            }
        }
        return ratio(unchanged, changed);
    }

    /**
     * Harmonic mean of {@link #getLocalSaliencyPrecision} and {@link #getLocalSaliencyRecall} for the same arguments.
     *
     * @return the F1 score, or {@link Double#NaN} if precision + recall is not finite or not positive
     */
    public static double getLocalSaliencyF1(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
        double precision = getLocalSaliencyPrecision(outputName, predictionProvider, localExplainer, dataDistribution, k, chunkSize);
        double recall = getLocalSaliencyRecall(outputName, predictionProvider, localExplainer, dataDistribution, k, chunkSize);
        double sum = precision + recall;
        if (!Double.isFinite(sum) || sum <= 0.0d) {
            return Double.NaN;
        }
        return (2.0d * precision * recall) / sum;
    }

    /** Applies {@link DataUtils#replaceFeatures} for each given feature on a copy of {@code input}'s features. */
    private static PredictionInput replaceAllFeatures(List<Feature> replacements, PredictionInput input) {
        List<Feature> features = List.copyOf(input.getFeatures());
        for (Feature replacement : replacements) {
            features = DataUtils.replaceFeatures(replacement, features);
        }
        return new PredictionInput(features);
    }

    /** Runs the explainer, waiting at most the configured async timeout. */
    private static Map<String, Saliency> explain(LocalExplainer<Map<String, Saliency>> localExplainer, Prediction prediction, PredictionProvider predictionProvider) throws InterruptedException, ExecutionException, TimeoutException {
        return localExplainer.explainAsync(prediction, predictionProvider)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /** Sorts the saliency's feature importances by {@code order} and keeps the first {@code k}. */
    private static List<FeatureImportance> selectFeatures(Saliency saliency, Comparator<FeatureImportance> order, int k) {
        return saliency.getPerFeatureImportance().stream()
                .sorted(order)
                .limit(k)
                .collect(Collectors.toList());
    }

    /**
     * Masks {@code input} with the given features, runs the model, and returns the named output of the first result,
     * or empty if the model returned nothing or lacks that output.
     */
    private static Optional<Output> predictMaskedOutput(String outputName, PredictionProvider predictionProvider, List<FeatureImportance> features, PredictionInput input) throws InterruptedException, ExecutionException, TimeoutException {
        List<PredictionOutput> outputs = predictionProvider.predictAsync(List.of(maskInput(features, input)))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        if (outputs.isEmpty()) {
            return Optional.empty();
        }
        return outputs.get(0).getByName(outputName);
    }

    /** Returns {@code hits / (hits + misses)}, or {@link Double#NaN} when nothing was counted. */
    private static double ratio(double hits, double misses) {
        double total = hits + misses;
        if (total > 0.0d) {
            return hits / total;
        }
        return Double.NaN;
    }
}
