package org.kie.kogito.explainability.local.lime;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;

/**
 * Checks {@link LimeExplainer} against small dummy models obtained from {@link TestUtils}.
 *
 * <p>Every scenario is repeated with seeds {@code 0 .. SEED_COUNT - 1}. The test re-seeds one
 * {@link Random} instance and hands it to the explainer on each run.
 */
class DummyModelsLimeExplainerTest {

    /** Number of seeds each scenario is repeated with. */
    private static final int SEED_COUNT = 5;

    /** Number of top-ranked features inspected in the numerical scenarios. */
    private static final int TOP_K = 3;

    @Test
    void testMapOneFeatureToOutputRegression() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < SEED_COUNT; seed++) {
            random.setSeed(seed);
            PredictionInput input = inputOf(
                    FeatureFactory.newNumericalFeature("f1", 100),
                    FeatureFactory.newNumericalFeature("f2", 20),
                    FeatureFactory.newNumericalFeature("f3", 0.1d));
            PredictionProvider model = TestUtils.getFeaturePassModel(1);
            Prediction prediction = predict(model, input);
            for (Saliency saliency : explain(new LimeExplainer(100, 1, random), prediction, model)) {
                assertTopFeaturesImpact(model, prediction, saliency, 1d);
            }
        }
    }

    @Test
    void testUnusedFeatureRegression() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < SEED_COUNT; seed++) {
            random.setSeed(seed);
            PredictionProvider model = TestUtils.getSumSkipModel(2);
            PredictionInput input = inputOf(
                    FeatureFactory.newNumericalFeature("f1", 100),
                    FeatureFactory.newNumericalFeature("f2", 20),
                    FeatureFactory.newNumericalFeature("f3", 10));
            Prediction prediction = predict(model, input);
            for (Saliency saliency : explain(new LimeExplainer(1000, 1, random), prediction, model)) {
                assertTopFeaturesImpact(model, prediction, saliency, 1d);
            }
        }
    }

    @Test
    void testMapOneFeatureToOutputClassification() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < SEED_COUNT; seed++) {
            random.setSeed(seed);
            PredictionInput input = inputOf(
                    FeatureFactory.newNumericalFeature("f1", 1),
                    FeatureFactory.newNumericalFeature("f2", 1),
                    FeatureFactory.newNumericalFeature("f3", 3));
            PredictionProvider model = TestUtils.getEvenFeatureModel(1);
            Prediction prediction = predict(model, input);
            for (Saliency saliency : explain(new LimeExplainer(1000, 2, random), prediction, model)) {
                assertTopFeaturesImpact(model, prediction, saliency, 1d);
            }
        }
    }

    @Test
    void testTextSpamClassification() throws Exception {
        Random random = new Random();
        // Whitespace tokenizer for the full-text features; split(...) already returns a fresh array.
        Function<String, List<String>> tokenizer = text -> Arrays.asList(text.split(" "));
        for (int seed = 0; seed < SEED_COUNT; seed++) {
            random.setSeed(seed);
            PredictionInput input = inputOf(
                    FeatureFactory.newFulltextFeature("f1", "we go here and there", tokenizer),
                    FeatureFactory.newFulltextFeature("f2", "please give me some money", tokenizer),
                    FeatureFactory.newFulltextFeature("f3", "dear friend, please reply", tokenizer));
            PredictionProvider model = TestUtils.getDummyTextClassifier();
            Prediction prediction = predict(model, input);
            for (Saliency saliency : explain(new LimeExplainer(1000, 1, random), prediction, model)) {
                Assertions.assertNotNull(saliency);
                List<FeatureImportance> positiveFeatures = saliency.getPositiveFeatures(1);
                Assertions.assertEquals(1, positiveFeatures.size());
                Assertions.assertEquals(1d, ExplainabilityMetrics.impactScore(model, prediction, positiveFeatures));
            }
        }
    }

    @Test
    void testUnusedFeatureClassification() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < SEED_COUNT; seed++) {
            random.setSeed(seed);
            PredictionProvider model = TestUtils.getEvenSumModel(2);
            PredictionInput input = inputOf(
                    FeatureFactory.newNumericalFeature("f1", 6),
                    FeatureFactory.newNumericalFeature("f2", 3),
                    FeatureFactory.newNumericalFeature("f3", 5));
            Prediction prediction = predict(model, input);
            for (Saliency saliency : explain(new LimeExplainer(1000, 1, random), prediction, model)) {
                assertTopFeaturesImpact(model, prediction, saliency, 1d);
            }
        }
    }

    @Test
    void testFixedOutput() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < SEED_COUNT; seed++) {
            random.setSeed(seed);
            PredictionProvider model = TestUtils.getFixedOutputClassifier();
            PredictionInput input = inputOf(
                    FeatureFactory.newNumericalFeature("f1", 6),
                    FeatureFactory.newNumericalFeature("f2", 3),
                    FeatureFactory.newNumericalFeature("f3", 5));
            Prediction prediction = predict(model, input);
            for (Saliency saliency : explain(new LimeExplainer(1000, 1, random), prediction, model)) {
                Assertions.assertNotNull(saliency);
                List<FeatureImportance> topFeatures = saliency.getTopFeatures(TOP_K);
                Assertions.assertEquals(TOP_K, topFeatures.size());
                for (FeatureImportance importance : topFeatures) {
                    Assertions.assertEquals(0d, importance.getScore());
                }
                Assertions.assertEquals(0d, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
            }
        }
    }

    /**
     * Wraps the given features in a {@link PredictionInput}.
     *
     * @param features the input features, in order
     * @return a prediction input backed by a mutable copy of {@code features}
     */
    private static PredictionInput inputOf(Feature... features) {
        // Mutable list, matching the LinkedList the original tests built by hand.
        return new PredictionInput(new LinkedList<>(Arrays.asList(features)));
    }

    /**
     * Runs {@code model} on a single input and pairs the input with the first returned output.
     *
     * @param model the model under test
     * @param input the single input to score
     * @return the resulting prediction
     * @throws Exception if the async prediction fails or exceeds the configured timeout
     */
    private static Prediction predict(PredictionProvider model, PredictionInput input) throws Exception {
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return new Prediction(input, outputs.get(0));
    }

    /**
     * Explains {@code prediction} with {@code explainer}. Asserts that at least one saliency was
     * produced, so the callers' per-saliency loops cannot pass vacuously.
     *
     * @param explainer the LIME explainer to use
     * @param prediction the prediction to explain
     * @param model the model that produced {@code prediction}
     * @return the saliencies returned by the explainer
     * @throws Exception if the async explanation fails or exceeds the configured timeout
     */
    private static Collection<Saliency> explain(LimeExplainer explainer, Prediction prediction,
            PredictionProvider model) throws Exception {
        Map<String, Saliency> saliencies = explainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Assertions.assertFalse(saliencies.isEmpty(), "LIME returned no saliency");
        return saliencies.values();
    }

    /**
     * Asserts that {@code saliency} ranks {@link #TOP_K} features and that their impact score on
     * {@code model} equals {@code expectedImpact}.
     *
     * @param model the model under test
     * @param prediction the explained prediction
     * @param saliency one saliency returned by the explainer
     * @param expectedImpact the expected {@link ExplainabilityMetrics#impactScore} value
     * @throws Exception if computing the impact score fails
     */
    private static void assertTopFeaturesImpact(PredictionProvider model, Prediction prediction,
            Saliency saliency, double expectedImpact) throws Exception {
        Assertions.assertNotNull(saliency);
        List<FeatureImportance> topFeatures = saliency.getTopFeatures(TOP_K);
        Assertions.assertEquals(TOP_K, topFeatures.size());
        Assertions.assertEquals(expectedImpact, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
    }
}
