package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.LocalExplanationException;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.GenericFeatureDistribution;
import org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.mockito.Mockito;

/* loaded from: input_file:org/kie/kogito/explainability/local/lime/LimeExplainerTest.class */
/**
 * Tests for {@link LimeExplainer}, driven by the toy models provided by {@link TestUtils}.
 */
class LimeExplainerTest {

    // Perturbations per sample for every seeded PerturbationContext built via seededConfig.
    // Only used as a perturbation count; loop bounds and indices below use their own literals.
    private static final int DEFAULT_NO_OF_PERTURBATIONS = 1;

    /**
     * A prediction whose input has no features must be rejected: {@code explainAsync} is expected to
     * throw {@link LocalExplanationException}.
     */
    @ParameterizedTest
    @ValueSource(longs = {0, 1, 2, 3, 4})
    void testEmptyPrediction(long seed) throws ExecutionException, InterruptedException, TimeoutException {
        LimeExplainer limeExplainer = new LimeExplainer(seededConfig(seed, new Random()).withSamples(10));
        PredictionInput input = new PredictionInput(Collections.emptyList());
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        Prediction prediction = predictionFor(input, model);
        Assertions.assertThrows(LocalExplanationException.class,
                () -> limeExplainer.explainAsync(prediction, model));
    }

    /** Explaining a 4-feature input of the sum-skip model completes with a non-null result. */
    @ParameterizedTest
    @ValueSource(longs = {0, 1, 2, 3, 4})
    void testNonEmptyInput(long seed) throws ExecutionException, InterruptedException, TimeoutException {
        LimeExplainer limeExplainer = new LimeExplainer(seededConfig(seed, new Random()).withSamples(10));
        PredictionInput input = new PredictionInput(mockedNumericFeatures(4));
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        Assertions.assertNotNull(explain(limeExplainer, predictionFor(input, model), model));
    }

    /**
     * For inputs of 1, 2 and 3 features, enabling {@code penalizeBalanceSparse} must never increase the
     * absolute score of any feature compared with the same explanation computed without the penalty.
     */
    @ParameterizedTest
    @ValueSource(longs = {0, 1, 2, 3, 4})
    void testSparseBalance(long seed) throws InterruptedException, ExecutionException, TimeoutException {
        for (int featureCount = 1; featureCount < 4; featureCount++) {
            // Both explainers share this Random and the same seed; their configs differ only in the
            // penalizeBalanceSparse flag.
            Random random = new Random();
            LimeExplainer plainExplainer = new LimeExplainer(
                    seededConfig(seed, random).withSamples(100).withPenalizeBalanceSparse(false));
            PredictionInput input = new PredictionInput(mockedNumericFeatures(featureCount));
            PredictionProvider model = TestUtils.getSumSkipModel(0);
            Prediction prediction = predictionFor(input, model);

            Map<String, Saliency> plainExplanation = explain(plainExplainer, prediction, model);
            AssertionsForClassTypes.assertThat(plainExplanation).isNotNull();
            Saliency plainSaliency = plainExplanation.get("sum-but0");

            // Built only after the first explanation has run, matching the original draw order on
            // the shared Random.
            LimeExplainer penalizingExplainer = new LimeExplainer(
                    seededConfig(seed, random).withSamples(100).withPenalizeBalanceSparse(true));
            Map<String, Saliency> penalizedExplanation = explain(penalizingExplainer, prediction, model);
            AssertionsForClassTypes.assertThat(penalizedExplanation).isNotNull();
            Saliency penalizedSaliency = penalizedExplanation.get("sum-but0");

            for (int k = 0; k < featureCount; k++) {
                double penalizedScore = penalizedSaliency.getPerFeatureImportance().get(k).getScore();
                double plainScore = plainSaliency.getPerFeatureImportance().get(k).getScore();
                AssertionsForClassTypes.assertThat(Math.abs(penalizedScore))
                        .isLessThanOrEqualTo(Math.abs(plainScore));
            }
        }
    }

    /** With {@code normalizeWeights} enabled, every feature score must lie in [0, 1]. */
    @Test
    void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
        LimeExplainer limeExplainer = new LimeExplainer(new LimeConfig()
                .withNormalizeWeights(true)
                .withPerturbationContext(new PerturbationContext(4L, new Random(), 2))
                .withSamples(10));
        PredictionInput input = new PredictionInput(mockedNumericFeatures(4));
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        Map<String, Saliency> explanation = explain(limeExplainer, predictionFor(input, model), model);
        AssertionsForClassTypes.assertThat(explanation).isNotNull();
        for (FeatureImportance importance : explanation.get("sum-but0").getPerFeatureImportance()) {
            AssertionsForClassTypes.assertThat(importance.getScore()).isBetween(0.0d, 1.0d);
        }
    }

    /**
     * Explains the sum-threshold model for four NaN-valued numeric features, each backed by a data
     * distribution of four random NUMBER values. Expects a saliency for the "inside" output.
     */
    @Test
    void testWithDataDistribution() throws InterruptedException, ExecutionException, TimeoutException {
        Random random = new Random();
        PerturbationContext perturbationContext = new PerturbationContext(4L, random, DEFAULT_NO_OF_PERTURBATIONS);

        List<Feature> features = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-" + i, Double.NaN));
        }
        // Sequential streams: values are drawn feature by feature, four per feature, all before the
        // threshold model consumes `random` below.
        IndependentFeaturesDataDistribution dataDistribution = new IndependentFeaturesDataDistribution(
                features.stream()
                        .map(feature -> new GenericFeatureDistribution(feature,
                                IntStream.range(0, 4)
                                        .mapToObj(k -> Type.NUMBER.randomValue(perturbationContext))
                                        .collect(Collectors.toList())))
                        .collect(Collectors.toList()));

        LimeExplainer limeExplainer = new LimeExplainer(new LimeConfig()
                .withDataDistribution(dataDistribution)
                .withPerturbationContext(perturbationContext)
                .withSamples(10));
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getSumThresholdModel(random.nextDouble(), random.nextDouble());
        Map<String, Saliency> explanation = explain(limeExplainer, predictionFor(input, model), model);
        AssertionsForClassTypes.assertThat(explanation).isNotNull();
        AssertionsForClassTypes.assertThat(explanation.get("inside")).isNotNull();
    }

    /** A config with zero samples must still produce a non-null explanation. */
    @Test
    void testZeroSampleSize() throws ExecutionException, InterruptedException, TimeoutException {
        LimeExplainer limeExplainer = new LimeExplainer(new LimeConfig().withSamples(0));
        PredictionInput input = new PredictionInput(mockedNumericFeatures(4));
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        Assertions.assertNotNull(explain(limeExplainer, predictionFor(input, model), model));
    }

    /** Two runs with the same seed must yield identical per-feature scores for "sum-but0". */
    @ParameterizedTest
    @ValueSource(longs = {0, 1, 2, 3, 4})
    void testDeterministic(long seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Saliency> saliencies = new ArrayList<>();
        for (int run = 0; run < 2; run++) {
            LimeExplainer limeExplainer = new LimeExplainer(seededConfig(seed, new Random()).withSamples(10));
            PredictionInput input = new PredictionInput(mockedNumericFeatures(4));
            PredictionProvider model = TestUtils.getSumSkipModel(0);
            saliencies.add(explain(limeExplainer, predictionFor(input, model), model).get("sum-but0"));
        }
        AssertionsForClassTypes.assertThat(scoresOf(saliencies.get(0)))
                .isEqualTo(scoresOf(saliencies.get(1)));
    }

    /** An unstubbed Mockito {@link Prediction} is rejected with the "input is empty" message. */
    @Test
    void testEmptyInput() {
        LimeExplainer limeExplainer = new LimeExplainer();
        PredictionProvider model = Mockito.mock(PredictionProvider.class);
        Prediction prediction = Mockito.mock(Prediction.class);
        AssertionsForClassTypes.assertThatCode(() -> limeExplainer.explainAsync(prediction, model))
                .hasMessage("cannot explain a prediction whose input is empty");
    }

    // Fresh LimeConfig whose PerturbationContext uses the given seed and Random, with
    // DEFAULT_NO_OF_PERTURBATIONS perturbations.
    private static LimeConfig seededConfig(long seed, Random random) {
        return new LimeConfig()
                .withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS));
    }

    // Mocked numeric features built by TestUtils.getMockedNumericFeature(0 .. count-1).
    private static List<Feature> mockedNumericFeatures(int count) {
        List<Feature> features = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        return features;
    }

    // Runs the model on the single input (bounded by the configured async timeout) and pairs the
    // first output with that input.
    private static SimplePrediction predictionFor(PredictionInput input, PredictionProvider model)
            throws InterruptedException, ExecutionException, TimeoutException {
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return new SimplePrediction(input, outputs.get(0));
    }

    // Runs the explainer and waits for its result, bounded by the configured async timeout.
    private static Map<String, Saliency> explain(LimeExplainer explainer, Prediction prediction,
            PredictionProvider model) throws InterruptedException, ExecutionException, TimeoutException {
        return explainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    // Per-feature scores of a saliency, in feature order.
    private static List<Double> scoresOf(Saliency saliency) {
        return saliency.getPerFeatureImportance().stream()
                .map(FeatureImportance::getScore)
                .collect(Collectors.toList());
    }
}
