package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.LocalExplanationException;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;

/* loaded from: input_file:org/kie/kogito/explainability/local/lime/LimeExplainerTest.class */
/**
 * Tests for {@link LimeExplainer}: rejection of an empty input, explanation of a non-empty input,
 * the effect of the balance/sparsity penalty, and weight normalization.
 *
 * <p>Every explainer is configured with a {@link PerturbationContext} backed by a seeded
 * {@link Random} so each run is reproducible.
 */
class LimeExplainerTest {

    /** Perturbation count passed to the {@link PerturbationContext} of most explainers below. */
    private static final int DEFAULT_NO_OF_PERTURBATIONS = 1;

    /** Key of the saliency-map entry inspected by the tests. */
    private static final String DECISION_NAME = "sum-but0";

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testEmptyPrediction(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        LimeExplainer limeExplainer = new LimeExplainer(seededConfig(seed, DEFAULT_NO_OF_PERTURBATIONS, 10));
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        Prediction prediction = predict(model, new PredictionInput(Collections.emptyList()));

        // The returned future is never awaited, so the exception must be raised synchronously by
        // explainAsync itself for this assertion to pass.
        Assertions.assertThrows(LocalExplanationException.class,
                () -> limeExplainer.explainAsync(prediction, model));
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testNonEmptyInput(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        LimeExplainer limeExplainer = new LimeExplainer(seededConfig(seed, DEFAULT_NO_OF_PERTURBATIONS, 10));
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        Prediction prediction = predict(model, numericInput(4));

        Assertions.assertNotNull(explain(limeExplainer, prediction, model));
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testSparseBalance(int seed) throws InterruptedException, ExecutionException, TimeoutException {
        for (int noOfFeatures = 1; noOfFeatures < 4; noOfFeatures++) {
            PredictionProvider model = TestUtils.getSumSkipModel(0);
            Prediction prediction = predict(model, numericInput(noOfFeatures));

            // Both explainers get their own, identically seeded perturbation context so that the
            // comparison below is reproducible and (assuming the penalty flag does not change how
            // random numbers are consumed) differs only in the balance/sparsity penalty.
            Saliency unpenalized = saliencyFor(
                    new LimeExplainer(seededConfig(seed, DEFAULT_NO_OF_PERTURBATIONS, 100)
                            .withPenalizeBalanceSparse(false)),
                    prediction, model);
            Saliency penalized = saliencyFor(
                    new LimeExplainer(seededConfig(seed, DEFAULT_NO_OF_PERTURBATIONS, 100)
                            .withPenalizeBalanceSparse(true)),
                    prediction, model);

            for (int i = 0; i < noOfFeatures; i++) {
                double penalizedScore = penalized.getPerFeatureImportance().get(i).getScore();
                double unpenalizedScore = unpenalized.getPerFeatureImportance().get(i).getScore();
                AssertionsForClassTypes.assertThat(Math.abs(penalizedScore))
                        .isLessThanOrEqualTo(Math.abs(unpenalizedScore));
            }
        }
    }

    @Test
    void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
        LimeExplainer limeExplainer = new LimeExplainer(seededConfig(4L, 2, 10).withNormalizeWeights(true));
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        Prediction prediction = predict(model, numericInput(4));

        Saliency saliency = saliencyFor(limeExplainer, prediction, model);
        for (FeatureImportance importance : saliency.getPerFeatureImportance()) {
            AssertionsForClassTypes.assertThat(importance.getScore()).isBetween(-1.0d, 1.0d);
        }
    }

    /**
     * Builds a {@link LimeConfig} whose perturbation context uses a fresh {@link Random} seeded with
     * {@code seed}.
     *
     * @param seed              seed for the perturbation context's random source
     * @param noOfPerturbations second argument passed to {@link PerturbationContext}
     * @param samples           value passed to {@link LimeConfig#withSamples}
     * @return the configured {@link LimeConfig}
     */
    private static LimeConfig seededConfig(long seed, int noOfPerturbations, int samples) {
        return new LimeConfig()
                .withPerturbationContext(new PerturbationContext(new Random(seed), noOfPerturbations))
                .withSamples(samples);
    }

    /**
     * Builds a {@link PredictionInput} of {@code noOfFeatures} mocked numeric features, created with
     * {@code TestUtils.getMockedNumericFeature(0 .. noOfFeatures - 1)}.
     */
    private static PredictionInput numericInput(int noOfFeatures) {
        return new PredictionInput(IntStream.range(0, noOfFeatures)
                .mapToObj(TestUtils::getMockedNumericFeature)
                .collect(Collectors.toList()));
    }

    /**
     * Runs {@code model} on {@code input}, waiting at most the configured async timeout, and pairs
     * the first output with the input.
     */
    private static Prediction predict(PredictionProvider model, PredictionInput input)
            throws ExecutionException, InterruptedException, TimeoutException {
        PredictionOutput output = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .get(0);
        return new Prediction(input, output);
    }

    /** Explains {@code prediction}, waiting at most the configured async timeout for the result. */
    private static Map<String, Saliency> explain(LimeExplainer explainer, Prediction prediction,
            PredictionProvider model) throws ExecutionException, InterruptedException, TimeoutException {
        return explainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /**
     * Explains {@code prediction} and returns the {@link #DECISION_NAME} saliency, asserting that
     * both the saliency map and that entry are present.
     */
    private static Saliency saliencyFor(LimeExplainer explainer, Prediction prediction,
            PredictionProvider model) throws ExecutionException, InterruptedException, TimeoutException {
        Map<String, Saliency> saliencyMap = explain(explainer, prediction, model);
        AssertionsForClassTypes.assertThat(saliencyMap).isNotNull();
        Saliency saliency = saliencyMap.get(DECISION_NAME);
        AssertionsForClassTypes.assertThat(saliency).isNotNull();
        return saliency;
    }
}
