package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.apache.commons.lang3.StringUtils;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.model.Dataset;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

/**
 * Smoke tests for {@link FairnessMetrics} evaluated with {@link TestUtils#getDummyTextClassifier()}
 * on a small corpus of e-mail-like inputs, each made of a "subject" and a "text" full-text feature.
 *
 * <p>The assertions only bound each metric (a range or a sign); they do not pin exact values.
 */
class FairnessMetricsTest {

    /** Splits text on single spaces; shared by every full-text feature built in this test. */
    private static final Function<String, List<String>> TOKENIZER = text -> Arrays.asList(text.split(" "));

    @Test
    void testIndividualConsistencyTextClassifier() throws ExecutionException, InterruptedException {
        AssertionsForClassTypes.assertThat(FairnessMetrics.individualConsistency(
                FairnessMetricsTest::twoMostSimilarNeighbours,
                getTestInputs(),
                TestUtils.getDummyTextClassifier()))
                .isBetween(0.0d, 1.0d);
    }

    @Test
    void testGroupSPDTextClassifier() throws ExecutionException, InterruptedException {
        AssertionsForClassTypes.assertThat(FairnessMetrics.groupStatisticalParityDifference(
                FairnessMetricsTest::mentionsPlease,
                getTestInputs(),
                TestUtils.getDummyTextClassifier(),
                notSpamOutput()))
                .isBetween(-1.0d, 1.0d);
    }

    @Test
    void testGroupDIRTextClassifier() throws ExecutionException, InterruptedException {
        AssertionsForClassTypes.assertThat(FairnessMetrics.groupDisparateImpactRatio(
                FairnessMetricsTest::mentionsPlease,
                getTestInputs(),
                TestUtils.getDummyTextClassifier(),
                notSpamOutput()))
                .isPositive();
    }

    @Test
    void testGroupAODTextClassifier() throws ExecutionException, InterruptedException {
        AssertionsForClassTypes.assertThat(FairnessMetrics.groupAverageOddsDifference(
                FairnessMetricsTest::mentionsPlease,
                FairnessMetricsTest::isNotSpam,
                new Dataset(getTestData()),
                TestUtils.getDummyTextClassifier()))
                .isBetween(-1.0d, 1.0d);
    }

    @Test
    void testGroupAPVDTextClassifier() throws ExecutionException, InterruptedException {
        AssertionsForClassTypes.assertThat(FairnessMetrics.groupAveragePredictiveValueDifference(
                FairnessMetricsTest::mentionsPlease,
                FairnessMetricsTest::isNotSpam,
                new Dataset(getTestData()),
                TestUtils.getDummyTextClassifier()))
                .isBetween(-1.0d, 1.0d);
    }

    /**
     * Proximity function for the individual-consistency test.
     *
     * <p>Ranks {@code candidates} by decreasing fuzzy similarity to the textified {@code input} and
     * returns ranks 1 and 2. Rank 0 is skipped. Presumably this is because the best match is
     * {@code input} itself, as the samples double as candidates (TODO confirm against
     * {@code FairnessMetrics.individualConsistency}).
     *
     * @param input      the sample whose neighbours are requested
     * @param candidates the pool of samples to rank; must hold at least 3 elements
     * @return the second and third most similar candidates
     */
    private static List<PredictionInput> twoMostSimilarNeighbours(PredictionInput input,
                                                                  List<PredictionInput> candidates) {
        String target = DataUtils.textify(input);
        // Per the Commons Lang Javadoc, getFuzzyDistance is a similarity score (higher = closer),
        // so the most similar candidates must come first.
        Comparator<PredictionInput> mostSimilarFirst =
                Comparator.comparingInt((PredictionInput candidate) -> fuzzySimilarity(candidate, target)).reversed();
        return candidates.stream()
                .sorted(mostSimilarFirst)
                .collect(Collectors.toList())
                .subList(1, 3);
    }

    /** Fuzzy score of {@code candidate}'s text against {@code target}, using the default locale. */
    private static int fuzzySimilarity(PredictionInput candidate, String target) {
        return StringUtils.getFuzzyDistance(DataUtils.textify(candidate), target, Locale.getDefault());
    }

    /** Group selector: true when the input's textified content contains "please". */
    private static boolean mentionsPlease(PredictionInput input) {
        return DataUtils.textify(input).contains("please");
    }

    /**
     * Output selector: true when the "spam" output's numeric value is 0.
     *
     * @throws AssertionError if the prediction has no "spam" output (a broken fixture)
     */
    private static boolean isNotSpam(PredictionOutput output) {
        return output.getByName("spam")
                .orElseThrow(() -> new AssertionError("prediction has no 'spam' output"))
                .getValue()
                .asNumber() == 0.0d;
    }

    /** The "spam = false" output passed as the reference output to the group SPD/DIR metrics. */
    private static Output notSpamOutput() {
        return new Output("spam", Type.BOOLEAN, new Value(false), 1.0d);
    }

    /** Builds an e-mail input from a subject line and a body, both tokenized by {@link #TOKENIZER}. */
    private static PredictionInput email(String subject, String text) {
        // Kept in a mutable ArrayList, as the original fixture was, in case PredictionInput keeps a reference.
        return new PredictionInput(new ArrayList<>(List.of(
                FeatureFactory.newFulltextFeature("subject", subject, TOKENIZER),
                FeatureFactory.newFulltextFeature("text", text, TOKENIZER))));
    }

    /** Builds a labelled e-mail whose single boolean output "spam" carries {@code spam}. */
    private static Prediction labelledEmail(String subject, String text, boolean spam) {
        PredictionOutput label = new PredictionOutput(List.of(new Output("spam", Type.BOOLEAN, new Value(spam), 1.0d)));
        return new SimplePrediction(email(subject, text), label);
    }

    /** Unlabelled e-mail samples scored by the dummy classifier. */
    private static List<PredictionInput> getTestInputs() {
        List<PredictionInput> inputs = new ArrayList<>();
        inputs.add(email("urgent inquiry", "please give me some money"));
        inputs.add(email("please reply", "we got urgent matter! please reply"));
        inputs.add(email("please reply", "we got money matter! please reply"));
        inputs.add(email("inquiry", "would you like to get a 100% secure way to invest your money?"));
        inputs.add(email("you win", "you just won an incredible 1M $ prize !"));
        inputs.add(email("prize waiting", "you are the lucky winner of a 100k $ prize"));
        inputs.add(email("urgent matter", "we got an urgent inquiry for you to answer."));
        inputs.add(email("password change", "you just requested to change your password"));
        inputs.add(email("password stolen", "we stole your password, if you want it back, send some money ."));
        return inputs;
    }

    /** E-mail samples with their expected "spam" label, used as the ground-truth dataset. */
    private static List<Prediction> getTestData() {
        List<Prediction> data = new ArrayList<>();
        data.add(labelledEmail("urgent inquiry", "please give me some money", true));
        data.add(labelledEmail("do not reply", "if you asked to reset your password, ignore this", false));
        data.add(labelledEmail("please reply", "we got money matter! please reply", true));
        data.add(labelledEmail("inquiry", "would you like to get a 100% secure way to invest your money?", true));
        data.add(labelledEmail("clear some space", "you just finished your space, upgrade today for 1 $ a week", false));
        data.add(labelledEmail("prize waiting", "you are the lucky winner of a 100k $ prize", true));
        data.add(labelledEmail("urgent matter", "we got an urgent inquiry for you to answer.", true));
        data.add(labelledEmail("password change", "you just requested to change your password", false));
        data.add(labelledEmail("password stolen", "we stole your password, if you want it back, send some money .", true));
        return data;
    }
}
