package org.kie.kogito.explainability.utils;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that score explanations: how explainable they are, how much removing the
 * features they highlight changes a prediction, and how often the sign of a saliency agrees
 * with a boolean classification.
 */
public final class ExplainabilityMetrics {
    private static final Logger LOGGER = LoggerFactory.getLogger(ExplainabilityMetrics.class);

    /**
     * A perturbed output whose score falls below this fraction of the original output's score is
     * considered impacted, even if its value did not change.
     */
    private static final double CONFIDENCE_DROP_RATIO = 0.2d;

    /** Equal weight given to each of the three components of the explainability score. */
    private static final double COMPONENT_WEIGHT = 0.333d;

    /**
     * Numeric threshold separating a positive from a negative boolean prediction.
     * NOTE(review): assumes {@code Value#asNumber} maps true/false to 1/0 — TODO confirm; 0.5 also
     * separates a 1/-1 encoding, so the check is correct for either.
     */
    private static final double BOOLEAN_DECISION_THRESHOLD = 0.5d;

    private ExplainabilityMetrics() {
        // utility class, not meant to be instantiated
    }

    /**
     * Quantifies the explainability of an explanation.
     *
     * <p>The score is the sum of three equally weighted (0.333) terms:
     * <ul>
     *   <li>the reciprocal of the input cognitive chunks;</li>
     *   <li>the reciprocal of the output cognitive chunks;</li>
     *   <li>the complement of the interaction ratio.</li>
     * </ul>
     * A chunk count that is not positive contributes nothing, rather than producing an infinite
     * (division by zero) or negative term. When neither count is positive there is no explanation
     * to measure and {@code 0} is returned.
     *
     * @param inputCognitiveChunks  number of cognitive chunks needed to produce the explanation
     * @param outputCognitiveChunks number of cognitive chunks contained in the explanation output
     * @param interactionRatio      ratio of interaction required by the explanation
     * @return the explainability score, or {@code 0} if neither chunk count is positive
     */
    public static double quantifyExplainability(int inputCognitiveChunks, int outputCognitiveChunks, double interactionRatio) {
        // Checked per count rather than via their sum: the int sum can overflow, and a single zero
        // count would otherwise divide by zero.
        if (inputCognitiveChunks <= 0 && outputCognitiveChunks <= 0) {
            return 0.0d;
        }
        return reciprocalTerm(inputCognitiveChunks)
                + reciprocalTerm(outputCognitiveChunks)
                + (COMPONENT_WEIGHT * (1.0d - interactionRatio));
    }

    /** Returns the weighted reciprocal of a chunk count, or 0 when the count is not positive. */
    private static double reciprocalTerm(int cognitiveChunks) {
        return cognitiveChunks > 0 ? COMPONENT_WEIGHT / cognitiveChunks : 0.0d;
    }

    /**
     * Measures how much a prediction is affected by removing the given features from its input.
     *
     * <p>Each feature in {@code topFeatures} is dropped from a copy of the prediction input, and the
     * model is queried once with the reduced input. Outputs at the same position are then compared.
     * A perturbed output is impacted when its value differs (by string form) or its score drops
     * below {@link #CONFIDENCE_DROP_RATIO} times the original score. Each impacted output adds
     * {@code 1 / n}, where {@code n} is the number of compared outputs. Only positions present in
     * both the original and the perturbed outputs are compared, so the score lies in
     * {@code [0, 1]} for a single perturbed prediction.
     *
     * @param model       the model used to predict on the reduced input
     * @param prediction  the original prediction (input and outputs)
     * @param topFeatures the features to remove from the input
     * @return the impact score
     * @throws InterruptedException if interrupted while waiting for the model
     * @throws ExecutionException   if the model prediction failed
     * @throws TimeoutException     if the model did not answer within the configured timeout
     */
    public static double impactScore(PredictionProvider model, Prediction prediction, List<FeatureImportance> topFeatures)
            throws InterruptedException, ExecutionException, TimeoutException {
        List<Feature> remainingFeatures = List.copyOf(prediction.getInput().getFeatures());
        for (FeatureImportance featureImportance : topFeatures) {
            remainingFeatures = DataUtils.dropFeature(remainingFeatures, featureImportance.getFeature());
        }

        List<PredictionOutput> perturbedPredictions;
        try {
            perturbedPredictions = model.predictAsync(List.of(new PredictionInput(remainingFeatures)))
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            // Trailing throwable argument makes SLF4J log the stack trace too.
            LOGGER.error("Impossible to obtain prediction {}", e.getMessage(), e);
            throw e;
        }

        List<Output> originalOutputs = prediction.getOutput().getOutputs();
        double impact = 0.0d;
        for (PredictionOutput perturbedPrediction : perturbedPredictions) {
            List<Output> perturbedOutputs = perturbedPrediction.getOutputs();
            // Guard against output lists of different length instead of indexing past the end.
            int compared = Math.min(originalOutputs.size(), perturbedOutputs.size());
            for (int i = 0; i < compared; i++) {
                if (isImpacted(originalOutputs.get(i), perturbedOutputs.get(i))) {
                    impact += 1.0d / compared;
                }
            }
        }
        return impact;
    }

    /** Returns whether the perturbed output changed value or lost most of the original confidence. */
    private static boolean isImpacted(Output original, Output perturbed) {
        return !original.getValue().asString().equals(perturbed.getValue().asString())
                || perturbed.getScore() < original.getScore() * CONFIDENCE_DROP_RATIO;
    }

    /**
     * Computes the fraction of boolean outputs whose predicted class agrees with the sign of the
     * total feature importance of the paired saliency.
     *
     * <p>For every boolean output:
     * <ul>
     *   <li>a positive prediction (numeric value at or above 0.5) is faithful when the importance
     *       sum is {@code >= 0};</li>
     *   <li>a negative prediction is faithful when the importance sum is {@code < 0}.</li>
     * </ul>
     * Outputs of other types are ignored.
     *
     * @param pairs saliencies paired with the predictions they explain
     * @return faithful boolean outputs divided by evaluated boolean outputs, or {@code 0} if none
     */
    public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
        int faithful = 0;
        int evaluated = 0;
        for (Pair<Saliency, Prediction> pair : pairs) {
            Saliency saliency = pair.getLeft();
            for (Output output : pair.getRight().getOutput().getOutputs()) {
                if (!Type.BOOLEAN.equals(output.getType())) {
                    continue;
                }
                double importanceSum = saliency.getPerFeatureImportance().stream()
                        .mapToDouble(FeatureImportance::getScore)
                        .sum();
                double predicted = output.getValue().asNumber();
                // Explicit comparisons on both sides so NaN values never count as faithful.
                if ((predicted >= BOOLEAN_DECISION_THRESHOLD && importanceSum >= 0.0d)
                        || (predicted < BOOLEAN_DECISION_THRESHOLD && importanceSum < 0.0d)) {
                    faithful++;
                }
                evaluated++;
            }
        }
        return evaluated == 0 ? 0.0d : (double) faithful / evaluated;
    }
}
