package org.kie.kogito.explainability.global.lime;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;

/* loaded from: input_file:org/kie/kogito/explainability/global/lime/AggregatedLimeExplainerTest.class */
class AggregatedLimeExplainerTest {
    AggregatedLimeExplainerTest() {
    }

    @Test
    void testExplainWithMetadata() throws ExecutionException, InterruptedException {
        for (int i = 0; i < 5; i++) {
            final Random random = new Random();
            random.setSeed(i);
            Map map = (Map) new AggregatedLimeExplainer(new LimeExplainer()).explainFromMetadata(TestUtils.getSumSkipModel(1), new PredictionProviderMetadata() { // from class: org.kie.kogito.explainability.global.lime.AggregatedLimeExplainerTest.1
                public DataDistribution getDataDistribution() {
                    return DataUtils.generateRandomDataDistribution(3, 100, random);
                }

                public PredictionInput getInputShape() {
                    LinkedList linkedList = new LinkedList();
                    linkedList.add(FeatureFactory.newNumericalFeature("f0", 0));
                    linkedList.add(FeatureFactory.newNumericalFeature("f1", 0));
                    linkedList.add(FeatureFactory.newNumericalFeature("f2", 0));
                    return new PredictionInput(linkedList);
                }

                public PredictionOutput getOutputShape() {
                    LinkedList linkedList = new LinkedList();
                    linkedList.add(new Output("sum-but1", Type.BOOLEAN, new Value(false), 0.0d));
                    return new PredictionOutput(linkedList);
                }
            }).get();
            Assertions.assertNotNull(map);
            Assertions.assertEquals(1, map.size());
            Assertions.assertTrue(map.containsKey("sum-but1"));
            Saliency saliency = (Saliency) map.get("sum-but1");
            Assertions.assertNotNull(saliency);
            Assertions.assertFalse(((List) saliency.getPositiveFeatures(2).stream().map((v0) -> {
                return v0.getFeature();
            }).map((v0) -> {
                return v0.getName();
            }).collect(Collectors.toList())).contains("f1"));
        }
    }

    @Test
    void testExplainWithPredictions() throws ExecutionException, InterruptedException {
        for (int i = 0; i < 5; i++) {
            Random random = new Random();
            random.setSeed(i);
            PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
            List sample = DataUtils.generateRandomDataDistribution(3, 100, random).sample(10);
            Map map = (Map) new AggregatedLimeExplainer(new LimeExplainer()).explainFromPredictions(sumSkipModel, DataUtils.getPredictions(sample, (List) sumSkipModel.predictAsync(sample).get())).get();
            Assertions.assertNotNull(map);
            Assertions.assertEquals(1, map.size());
            Assertions.assertTrue(map.containsKey("sum-but1"));
            Saliency saliency = (Saliency) map.get("sum-but1");
            Assertions.assertNotNull(saliency);
            Assertions.assertFalse(((List) saliency.getPositiveFeatures(2).stream().map((v0) -> {
                return v0.getFeature();
            }).map((v0) -> {
                return v0.getName();
            }).collect(Collectors.toList())).contains("f1"));
        }
    }
}
