package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;

import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;

/**
 * Tests {@link LimeExplainer} against the dummy models provided by {@link TestUtils}.
 *
 * <p>Each test builds a single prediction, explains it, and checks:
 * <ul>
 *   <li>the shape of the saliencies and the impact score of their selected features;</li>
 *   <li>explanation stability via {@link TestUtils#assertLimeStability};</li>
 *   <li>local saliency precision, recall and F1 over a distribution of mocked numeric inputs.</li>
 * </ul>
 */
class DummyModelsLimeExplainerTest {

    /** Number of random inputs in the distribution used for local saliency precision/recall/F1. */
    private static final int DISTRIBUTION_SIZE = 100;

    /** Number of top features requested from each saliency. */
    private static final int TOP_K = 3;

    @ParameterizedTest
    @ValueSource(longs = {0})
    void testMapOneFeatureToOutputRegression(long seed) throws Exception {
        int idx = 1;
        PredictionInput input = inputOf(
                TestUtils.getMockedNumericFeature(100.0d),
                TestUtils.getMockedNumericFeature(20.0d),
                TestUtils.getMockedNumericFeature(0.1d));
        PredictionProvider model = TestUtils.getFeaturePassModel(idx);
        SimplePrediction prediction = predict(model, input);
        LimeExplainer explainer = newLimeExplainer(100, seed, 1);

        assertTopFeaturesImpact(explain(explainer, prediction, model), model, prediction, 1.0d);
        TestUtils.assertLimeStability(model, prediction, explainer, 1, 0.5d, 0.5d);
        assertLocalSaliencyMetrics("feature-" + idx, model, explainer, 0.0d, 1.0d, 0.0d);
    }

    @ParameterizedTest
    @ValueSource(longs = {0})
    void testUnusedFeatureRegression(long seed) throws Exception {
        int idx = 2;
        PredictionInput input = inputOf(
                TestUtils.getMockedNumericFeature(100.0d),
                TestUtils.getMockedNumericFeature(20.0d),
                TestUtils.getMockedNumericFeature(10.0d));
        PredictionProvider model = TestUtils.getSumSkipModel(idx);
        SimplePrediction prediction = predict(model, input);
        LimeExplainer explainer = newLimeExplainer(10, seed, 1);

        assertTopFeaturesImpact(explain(explainer, prediction, model), model, prediction, 1.0d);
        TestUtils.assertLimeStability(model, prediction, explainer, 1, 0.5d, 0.5d);
        assertLocalSaliencyMetrics("sum-but" + idx, model, explainer, 1.0d, 1.0d, 1.0d);
    }

    @ParameterizedTest
    @ValueSource(longs = {0})
    void testMapOneFeatureToOutputClassification(long seed) throws Exception {
        int idx = 1;
        PredictionInput input = inputOf(
                FeatureFactory.newNumericalFeature("f1", 1),
                FeatureFactory.newNumericalFeature("f2", 1),
                FeatureFactory.newNumericalFeature("f3", 3));
        PredictionProvider model = TestUtils.getEvenFeatureModel(idx);
        SimplePrediction prediction = predict(model, input);
        LimeExplainer explainer = newLimeExplainer(100, seed, 2);

        assertTopFeaturesImpact(explain(explainer, prediction, model), model, prediction, 1.0d);
        TestUtils.assertLimeStability(model, prediction, explainer, 1, 0.5d, 0.5d);
        assertLocalSaliencyMetrics("feature-" + idx, model, explainer, 1.0d, 1.0d, 1.0d);
    }

    @ParameterizedTest
    @ValueSource(longs = {0})
    void testTextSpamClassification(long seed) throws Exception {
        // Whitespace tokenizer for the full-text features.
        Function<String, List<String>> tokenizer = s -> Arrays.asList(s.split(" "));
        PredictionInput input = inputOf(
                FeatureFactory.newFulltextFeature("f1", "we go here and there", tokenizer),
                FeatureFactory.newFulltextFeature("f2", "please give me some money", tokenizer),
                FeatureFactory.newFulltextFeature("f3", "dear friend, please reply", tokenizer));
        PredictionProvider model = TestUtils.getDummyTextClassifier();
        SimplePrediction prediction = predict(model, input);
        LimeExplainer explainer = newLimeExplainer(100, seed, 1);

        for (Saliency saliency : explain(explainer, prediction, model).values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> positiveFeatures = saliency.getPositiveFeatures(1);
            Assertions.assertEquals(1, positiveFeatures.size());
            Assertions.assertEquals(1.0d, ExplainabilityMetrics.impactScore(model, prediction, positiveFeatures));
        }
        TestUtils.assertLimeStability(model, prediction, explainer, 1, 0.5d, 0.2d);
        assertLocalSaliencyMetrics("spam", model, explainer, 1.0d, 1.0d, 1.0d);
    }

    @ParameterizedTest
    @ValueSource(longs = {0})
    void testUnusedFeatureClassification(long seed) throws Exception {
        int idx = 2;
        PredictionInput input = inputOf(
                FeatureFactory.newNumericalFeature("f1", 6),
                FeatureFactory.newNumericalFeature("f2", 3),
                FeatureFactory.newNumericalFeature("f3", 5));
        PredictionProvider model = TestUtils.getEvenSumModel(idx);
        SimplePrediction prediction = predict(model, input);
        LimeExplainer explainer = newLimeExplainer(100, seed, 1);

        assertTopFeaturesImpact(explain(explainer, prediction, model), model, prediction, 1.0d);
        TestUtils.assertLimeStability(model, prediction, explainer, 1, 0.5d, 0.5d);
        assertLocalSaliencyMetrics("sum-even-but" + idx, model, explainer, 1.0d, 1.0d, 1.0d);
    }

    @ParameterizedTest
    @ValueSource(longs = {0})
    void testFixedOutput(long seed) throws Exception {
        PredictionInput input = inputOf(
                FeatureFactory.newNumericalFeature("f1", 6),
                FeatureFactory.newNumericalFeature("f2", 3),
                FeatureFactory.newNumericalFeature("f3", 5));
        PredictionProvider model = TestUtils.getFixedOutputClassifier();
        SimplePrediction prediction = predict(model, input);
        LimeExplainer explainer = newLimeExplainer(10, seed, 1);

        for (Saliency saliency : explain(explainer, prediction, model).values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> topFeatures = saliency.getTopFeatures(TOP_K);
            Assertions.assertEquals(TOP_K, topFeatures.size());
            // Every selected feature must be scored exactly zero for this model.
            for (FeatureImportance featureImportance : topFeatures) {
                Assertions.assertEquals(0.0d, featureImportance.getScore());
            }
            Assertions.assertEquals(0.0d, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
        }
        TestUtils.assertLimeStability(model, prediction, explainer, 1, 0.5d, 0.5d);
        assertLocalSaliencyMetrics("class", model, explainer, 1.0d, 1.0d, 1.0d);
    }

    /** Wraps the given features, in order, in a {@link PredictionInput} backed by a mutable list. */
    private static PredictionInput inputOf(Feature... features) {
        return new PredictionInput(new ArrayList<>(Arrays.asList(features)));
    }

    /**
     * Runs {@code model} on {@code input} and pairs the input with the first returned output.
     *
     * @throws Exception if the asynchronous prediction fails or exceeds the configured timeout
     */
    private static SimplePrediction predict(PredictionProvider model, PredictionInput input) throws Exception {
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return new SimplePrediction(input, outputs.get(0));
    }

    /**
     * Creates a LIME explainer.
     *
     * @param samples number of samples passed to {@link LimeConfig#withSamples}
     * @param seed seed handed to the {@link PerturbationContext} together with a fresh {@link Random}
     * @param noOfPerturbations perturbation count handed to the {@link PerturbationContext}
     */
    private static LimeExplainer newLimeExplainer(int samples, long seed, int noOfPerturbations) {
        PerturbationContext perturbationContext = new PerturbationContext(seed, new Random(), noOfPerturbations);
        return new LimeExplainer(new LimeConfig()
                .withSamples(samples)
                .withPerturbationContext(perturbationContext));
    }

    /**
     * Explains {@code prediction} with {@code explainer} and waits for the result.
     *
     * @return the saliencies produced by the explainer, keyed as returned by {@code explainAsync}
     * @throws Exception if the explanation fails or exceeds the configured timeout
     */
    private static Map<String, Saliency> explain(LimeExplainer explainer, SimplePrediction prediction,
            PredictionProvider model) throws Exception {
        return explainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /**
     * Asserts that every saliency is non-null, exposes {@link #TOP_K} top features, and that those
     * features yield {@code expectedImpact} according to {@link ExplainabilityMetrics#impactScore}.
     */
    private static void assertTopFeaturesImpact(Map<String, Saliency> saliencies, PredictionProvider model,
            SimplePrediction prediction, double expectedImpact) throws Exception {
        for (Saliency saliency : saliencies.values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> topFeatures = saliency.getTopFeatures(TOP_K);
            Assertions.assertEquals(TOP_K, topFeatures.size());
            Assertions.assertEquals(expectedImpact, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
        }
    }

    /** Builds a distribution of {@link #DISTRIBUTION_SIZE} inputs, each made of three mocked numeric features. */
    private static PredictionInputsDataDistribution mockedNumericDistribution() {
        List<PredictionInput> inputs = new ArrayList<>(DISTRIBUTION_SIZE);
        for (int i = 0; i < DISTRIBUTION_SIZE; i++) {
            inputs.add(inputOf(
                    TestUtils.getMockedNumericFeature(),
                    TestUtils.getMockedNumericFeature(),
                    TestUtils.getMockedNumericFeature()));
        }
        return new PredictionInputsDataDistribution(inputs);
    }

    /**
     * Asserts local saliency precision, recall and F1 for {@code outputName} over a freshly built
     * {@link #mockedNumericDistribution()}; the same distribution instance is shared by all three metrics.
     */
    private static void assertLocalSaliencyMetrics(String outputName, PredictionProvider model, LimeExplainer explainer,
            double expectedPrecision, double expectedRecall, double expectedF1) throws Exception {
        PredictionInputsDataDistribution distribution = mockedNumericDistribution();
        // NOTE(review): the trailing (2, 10) arguments are kept from the original test; their meaning is
        // defined by ExplainabilityMetrics (presumably top-k and chunk size) — confirm there.
        AssertionsForClassTypes.assertThat(
                ExplainabilityMetrics.getLocalSaliencyPrecision(outputName, model, explainer, distribution, 2, 10))
                .isEqualTo(expectedPrecision);
        AssertionsForClassTypes.assertThat(
                ExplainabilityMetrics.getLocalSaliencyRecall(outputName, model, explainer, distribution, 2, 10))
                .isEqualTo(expectedRecall);
        AssertionsForClassTypes.assertThat(
                ExplainabilityMetrics.getLocalSaliencyF1(outputName, model, explainer, distribution, 2, 10))
                .isEqualTo(expectedF1);
    }
}
