package org.kie.kogito.explainability.utils;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.kie.kogito.explainability.model.Dataset;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Value;

/* loaded from: input_file:org/kie/kogito/explainability/utils/FairnessMetrics.class */
public class FairnessMetrics {
    static final /* synthetic */ boolean $assertionsDisabled;

    private FairnessMetrics() {
    }

    public static double individualConsistency(BiFunction<PredictionInput, List<PredictionInput>, List<PredictionInput>> biFunction, List<PredictionInput> list, PredictionProvider predictionProvider) throws ExecutionException, InterruptedException {
        double d = 1.0d;
        for (PredictionInput predictionInput : list) {
            PredictionOutput predictionOutput = predictionProvider.predictAsync(List.of(predictionInput)).get().get(0);
            List<PredictionOutput> list2 = predictionProvider.predictAsync(biFunction.apply(predictionInput, list)).get();
            for (Output output : predictionOutput.getOutputs()) {
                Value value = output.getValue();
                Iterator<PredictionOutput> it = list2.iterator();
                while (it.hasNext()) {
                    Output orElse = it.next().getByName(output.getName()).orElse(null);
                    if (orElse != null && !value.equals(orElse.getValue())) {
                        d -= 1.0f / ((r0.size() * predictionOutput.getOutputs().size()) * list.size());
                    }
                }
            }
        }
        return d;
    }

    public static double groupStatisticalParityDifference(Predicate<PredictionInput> predicate, List<PredictionInput> list, PredictionProvider predictionProvider, Output output) throws ExecutionException, InterruptedException {
        return getFavorableLabelProbability(predicate.negate(), list, predictionProvider, output) - getFavorableLabelProbability(predicate, list, predictionProvider, output);
    }

    public static double groupDisparateImpactRatio(Predicate<PredictionInput> predicate, List<PredictionInput> list, PredictionProvider predictionProvider, Output output) throws ExecutionException, InterruptedException {
        return getFavorableLabelProbability(predicate.negate(), list, predictionProvider, output) / getFavorableLabelProbability(predicate, list, predictionProvider, output);
    }

    private static double getFavorableLabelProbability(Predicate<PredictionInput> predicate, List<PredictionInput> list, PredictionProvider predictionProvider, Output output) throws ExecutionException, InterruptedException {
        String name = output.getName();
        Value value = output.getValue();
        List<PredictionOutput> selectedPredictionOutputs = getSelectedPredictionOutputs(predicate, list, predictionProvider);
        return selectedPredictionOutputs.stream().map(predictionOutput -> {
            return predictionOutput.getByName(name);
        }).map((v0) -> {
            return v0.get();
        }).filter(output2 -> {
            return output2.getValue().equals(value);
        }).count() / selectedPredictionOutputs.size();
    }

    private static List<PredictionOutput> getSelectedPredictionOutputs(Predicate<PredictionInput> predicate, List<PredictionInput> list, PredictionProvider predictionProvider) throws InterruptedException, ExecutionException {
        return predictionProvider.predictAsync((List) list.stream().filter(predicate).collect(Collectors.toList())).get();
    }

    public static double groupAverageOddsDifference(Predicate<PredictionInput> predicate, Predicate<PredictionOutput> predicate2, Dataset dataset, PredictionProvider predictionProvider) throws ExecutionException, InterruptedException {
        Dataset filterByInput = dataset.filterByInput(predicate);
        Map<String, Integer> countMatchingOutputSelector = countMatchingOutputSelector(filterByInput, predictionProvider.predictAsync(filterByInput.getInputs()).get(), predicate2);
        Dataset filterByInput2 = dataset.filterByInput(predicate.negate());
        Map<String, Integer> countMatchingOutputSelector2 = countMatchingOutputSelector(filterByInput2, predictionProvider.predictAsync(filterByInput2.getInputs()).get(), predicate2);
        double intValue = countMatchingOutputSelector2.get("tp").intValue();
        double intValue2 = countMatchingOutputSelector2.get("tn").intValue();
        double intValue3 = countMatchingOutputSelector2.get("fp").intValue();
        double intValue4 = countMatchingOutputSelector2.get("fn").intValue();
        double intValue5 = countMatchingOutputSelector.get("tp").intValue();
        double intValue6 = countMatchingOutputSelector.get("tn").intValue();
        double intValue7 = countMatchingOutputSelector.get("fp").intValue();
        return (((intValue / (intValue + intValue4)) - (intValue5 / ((intValue5 + countMatchingOutputSelector.get("fn").intValue()) + 1.0E-10d))) / 2.0d) + (((intValue3 / (intValue3 + intValue2)) - (intValue7 / ((intValue7 + intValue6) + 1.0E-10d))) / 2.0d);
    }

    private static Map<String, Integer> countMatchingOutputSelector(Dataset dataset, List<PredictionOutput> list, Predicate<PredictionOutput> predicate) {
        if (!$assertionsDisabled && list.size() != dataset.getData().size()) {
            throw new AssertionError("dataset and predictions must have same size");
        }
        int i = 0;
        int i2 = 0;
        int i3 = 0;
        int i4 = 0;
        int i5 = 0;
        Iterator<Prediction> it = dataset.getData().iterator();
        while (it.hasNext()) {
            if (predicate.test(it.next().getOutput())) {
                if (predicate.test(list.get(i5))) {
                    i++;
                } else {
                    i4++;
                }
            } else if (predicate.test(list.get(i5))) {
                i3++;
            } else {
                i2++;
            }
            i5++;
        }
        HashMap hashMap = new HashMap();
        hashMap.put("tp", Integer.valueOf(i));
        hashMap.put("tn", Integer.valueOf(i2));
        hashMap.put("fp", Integer.valueOf(i3));
        hashMap.put("fn", Integer.valueOf(i4));
        return hashMap;
    }

    public static double groupAveragePredictiveValueDifference(Predicate<PredictionInput> predicate, Predicate<PredictionOutput> predicate2, Dataset dataset, PredictionProvider predictionProvider) throws ExecutionException, InterruptedException {
        Dataset filterByInput = dataset.filterByInput(predicate);
        Map<String, Integer> countMatchingOutputSelector = countMatchingOutputSelector(filterByInput, predictionProvider.predictAsync(filterByInput.getInputs()).get(), predicate2);
        double intValue = countMatchingOutputSelector.get("tp").intValue();
        double intValue2 = countMatchingOutputSelector.get("tn").intValue();
        double intValue3 = countMatchingOutputSelector.get("fp").intValue();
        double intValue4 = countMatchingOutputSelector.get("fn").intValue();
        Dataset filterByInput2 = dataset.filterByInput(predicate.negate());
        Map<String, Integer> countMatchingOutputSelector2 = countMatchingOutputSelector(filterByInput2, predictionProvider.predictAsync(filterByInput2.getInputs()).get(), predicate2);
        double intValue5 = countMatchingOutputSelector2.get("tp").intValue();
        double intValue6 = countMatchingOutputSelector2.get("tn").intValue();
        double intValue7 = countMatchingOutputSelector2.get("fp").intValue();
        double intValue8 = countMatchingOutputSelector2.get("fn").intValue();
        return (((intValue5 / (intValue5 + intValue7)) - (intValue / ((intValue + intValue3) + 1.0E-10d))) / 2.0d) + (((intValue8 / (intValue8 + intValue6)) - (intValue4 / ((intValue4 + intValue2) + 1.0E-10d))) / 2.0d);
    }

    static {
        $assertionsDisabled = !FairnessMetrics.class.desiredAssertionStatus();
    }
}
