package org.kie.kogito.explainability.explainability.integrationtests.dmn;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Test;
import org.kie.dmn.api.core.DMNRuntime;
import org.kie.kogito.dmn.DMNKogito;
import org.kie.kogito.dmn.DmnDecisionModel;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.kogito.explainability.utils.ValidationUtils;

/**
 * Integration tests that run {@link LimeExplainer} against small DMN models loaded from the
 * test classpath.
 *
 * <p>Every test follows the same sequence: load one DMN model, evaluate it on a single composite
 * "context" feature, explain the resulting prediction, validate the stability of the explanation
 * and finally check that precision, recall and F1 of the explanations fall within [0, 1].
 *
 * <p>The {@link PerturbationContext} (and therefore its seeded {@link Random}) is shared between
 * the explainer and the sampling of the data distribution, so the order of the calls inside each
 * test matters for reproducibility and must not be rearranged.
 */
class DummyDmnModelsLimeExplainerTest {

    private static final String FUNCTIONAL_TEST_NAMESPACE =
            "https://kiegroup.org/dmn/_049CD980-1310-4B02-9E90-EFC57059F44A";
    private static final String ALL_TYPES_NAMESPACE =
            "https://kiegroup.org/dmn/_24B9EC8C-2F02-40EB-B6BB-E8CDE82FBF08";

    /** Number of samples the LIME explainer is configured with. */
    private static final int LIME_SAMPLES = 10;

    /** Number of perturbed inputs placed in the data distribution used for the metrics. */
    private static final int DISTRIBUTION_SIZE = 10;

    /**
     * Explains {@code functionalTest1.dmn} and expects {@code booleanInput} to be the top
     * positive feature of every saliency.
     */
    @Test
    void testFunctional1DMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
        DecisionModelWrapper model =
                loadModel("/dmn/functionalTest1.dmn", FUNCTIONAL_TEST_NAMESPACE, "functionalTest1");

        Map<String, Object> context = new HashMap<>();
        context.put("booleanInput", true);
        context.put("notUsedInput", 1);
        List<Feature> features = contextFeatures(context);

        Prediction prediction = predict(model, features);
        PerturbationContext perturbationContext = seededPerturbationContext(1);
        LimeExplainer explainer = newExplainer(perturbationContext);

        for (Saliency saliency : explain(explainer, prediction, model).values()) {
            assertTopPositiveFeature(saliency, "booleanInput");
        }
        AssertionsForClassTypes.assertThatCode(
                () -> ValidationUtils.validateLocalSaliencyStability(model, prediction, explainer, 1, 0.5d, 0.5d))
                .doesNotThrowAnyException();

        assertMetricsWithinUnitInterval("decision", model, explainer,
                sampleDistribution(features, perturbationContext));
    }

    /**
     * Explains {@code functionalTest2.dmn} and expects {@code numberInput} to be the top
     * positive feature of every saliency.
     */
    @Test
    void testFunctional2DMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
        DecisionModelWrapper model =
                loadModel("/dmn/functionalTest2.dmn", FUNCTIONAL_TEST_NAMESPACE, "new-file");

        Map<String, Object> context = new HashMap<>();
        context.put("numberInput", 1);
        context.put("notUsedInput", 1);
        List<Feature> features = contextFeatures(context);

        Prediction prediction = predict(model, features);
        PerturbationContext perturbationContext = seededPerturbationContext(1);
        LimeExplainer explainer = newExplainer(perturbationContext);

        for (Saliency saliency : explain(explainer, prediction, model).values()) {
            assertTopPositiveFeature(saliency, "numberInput");
        }
        AssertionsForClassTypes.assertThatCode(
                () -> ValidationUtils.validateLocalSaliencyStability(model, prediction, explainer, 1, 0.5d, 0.5d))
                .doesNotThrowAnyException();

        assertMetricsWithinUnitInterval("decision", model, explainer,
                sampleDistribution(features, perturbationContext));
    }

    /**
     * Explains {@code allTypes.dmn} using an input that contains strings, numbers, booleans,
     * temporal values, lists and nested maps. Only the presence of each saliency is checked here;
     * no particular feature ranking is expected.
     */
    @Test
    void testAllTypesDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
        DecisionModelWrapper model = loadModel("/dmn/allTypes.dmn", ALL_TYPES_NAMESPACE, "new-file");

        Map<String, Object> context = new HashMap<>();
        context.put("stringInput", "test");
        context.put("listOfStringInput", Collections.singletonList("test"));
        context.put("numberInput", 1);
        context.put("listOfNumbersInput", Collections.singletonList(1));
        context.put("booleanInput", true);
        context.put("listOfBooleansInput", Collections.singletonList(true));
        // NOTE(review): the leading 'h' looks unusual for a time value - confirm it is intended.
        context.put("timeInput", "h09:00");
        context.put("dateInput", "2020-04-02");
        context.put("dateAndTimeInput", "2020-04-02T09:00:00");
        context.put("daysAndTimeDurationInput", "P1DT1H");
        context.put("yearsAndMonthDurationInput", "P1Y1M");

        Map<String, Object> complexInput = new HashMap<>();
        complexInput.put("aNestedListOfNumbers", Collections.singletonList(1));
        complexInput.put("aNestedString", "test");
        complexInput.put("aNestedComplexInput", Collections.singletonMap("doubleNestedNumber", 1));
        // The same nested map instance is used both directly and wrapped in a list.
        context.put("complexInput", complexInput);
        context.put("listOfComplexInput", Collections.singletonList(complexInput));
        List<Feature> features = contextFeatures(context);

        Prediction prediction = predict(model, features);
        PerturbationContext perturbationContext = seededPerturbationContext(3);
        LimeExplainer explainer = newExplainer(perturbationContext);

        for (Saliency saliency : explain(explainer, prediction, model).values()) {
            AssertionsForClassTypes.assertThat(saliency).isNotNull();
        }
        AssertionsForClassTypes.assertThatCode(
                () -> ValidationUtils.validateLocalSaliencyStability(model, prediction, explainer, 1, 0.5d, 0.2d))
                .doesNotThrowAnyException();

        assertMetricsWithinUnitInterval("myDecision", model, explainer,
                sampleDistribution(features, perturbationContext));
    }

    /**
     * Loads a DMN file from the classpath, asserts that it yields exactly one model and wraps
     * that model for use by the explainer.
     *
     * <p>The reader is decoded as UTF-8 and closed once the runtime has been created. This
     * assumes {@code DMNKogito.createGenericDMNRuntime} fully consumes the reader before it
     * returns - TODO confirm against the Kogito DMN implementation.
     *
     * @param resourcePath absolute classpath location of the DMN file
     * @param namespace    namespace of the model to look up
     * @param modelName    name of the model to look up
     * @return the wrapped decision model
     * @throws NullPointerException if the resource is not on the classpath
     * @throws UncheckedIOException if handling the underlying stream fails
     */
    private static DecisionModelWrapper loadModel(String resourcePath, String namespace, String modelName) {
        InputStream stream = Objects.requireNonNull(
                DummyDmnModelsLimeExplainerTest.class.getResourceAsStream(resourcePath),
                "DMN resource not found on classpath: " + resourcePath);
        try (Reader reader = new InputStreamReader(stream, StandardCharsets.UTF_8)) {
            DMNRuntime runtime = DMNKogito.createGenericDMNRuntime(reader);
            AssertionsForClassTypes.assertThat(runtime.getModels().size()).isEqualTo(1);
            return new DecisionModelWrapper(new DmnDecisionModel(runtime, namespace, modelName));
        } catch (IOException e) {
            throw new UncheckedIOException("I/O error while handling DMN resource " + resourcePath, e);
        }
    }

    /** Wraps the given variables into a mutable single-element list holding one composite "context" feature. */
    private static List<Feature> contextFeatures(Map<String, Object> contextVariables) {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newCompositeFeature("context", contextVariables));
        return features;
    }

    /**
     * Evaluates the model on one input built from {@code features} and pairs that input with the
     * first output returned, waiting at most for the timeout configured in {@link Config}.
     */
    private static Prediction predict(DecisionModelWrapper model, List<Feature> features)
            throws ExecutionException, InterruptedException, TimeoutException {
        PredictionInput input = new PredictionInput(features);
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return new Prediction(input, outputs.get(0));
    }

    /** Creates a perturbation context backed by a {@link Random} seeded with 0 so runs are repeatable. */
    private static PerturbationContext seededPerturbationContext(int noOfPerturbations) {
        Random random = new Random();
        random.setSeed(0L);
        return new PerturbationContext(random, noOfPerturbations);
    }

    /** Builds a LIME explainer configured with {@link #LIME_SAMPLES} samples and the given perturbation context. */
    private static LimeExplainer newExplainer(PerturbationContext perturbationContext) {
        LimeConfig config = new LimeConfig()
                .withSamples(LIME_SAMPLES)
                .withPerturbationContext(perturbationContext);
        return new LimeExplainer(config);
    }

    /** Runs the explainer on the prediction, waiting at most for the timeout configured in {@link Config}. */
    private static Map<String, Saliency> explain(LimeExplainer explainer, Prediction prediction,
            DecisionModelWrapper model) throws ExecutionException, InterruptedException, TimeoutException {
        return explainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /**
     * Asserts that the saliency is present, has at least one positive feature among those
     * returned by {@code getPositiveFeatures(2)}, and that the first of them has the expected name.
     */
    private static void assertTopPositiveFeature(Saliency saliency, String expectedFeatureName) {
        AssertionsForClassTypes.assertThat(saliency).isNotNull();
        List<FeatureImportance> positiveFeatures = saliency.getPositiveFeatures(2);
        AssertionsForClassTypes.assertThat(positiveFeatures.isEmpty()).isFalse();
        AssertionsForClassTypes.assertThat(positiveFeatures.get(0).getFeature().getName())
                .isEqualTo(expectedFeatureName);
    }

    /**
     * Builds a data distribution of {@link #DISTRIBUTION_SIZE} inputs, each obtained by
     * perturbing {@code features} with the given (shared) perturbation context.
     */
    private static PredictionInputsDataDistribution sampleDistribution(List<Feature> features,
            PerturbationContext perturbationContext) {
        List<PredictionInput> samples = new ArrayList<>();
        for (int i = 0; i < DISTRIBUTION_SIZE; i++) {
            samples.add(new PredictionInput(DataUtils.perturbFeatures(features, perturbationContext)));
        }
        return new PredictionInputsDataDistribution(samples);
    }

    /**
     * Asserts that local saliency precision, recall and F1 (computed in that order) for the
     * named decision each lie within [0, 1].
     */
    private static void assertMetricsWithinUnitInterval(String decisionName, DecisionModelWrapper model,
            LimeExplainer explainer, PredictionInputsDataDistribution distribution)
            throws ExecutionException, InterruptedException, TimeoutException {
        // The trailing 2 and 5 are presumably top-k and chunk size - confirm against ExplainabilityMetrics.
        double precision =
                ExplainabilityMetrics.getLocalSaliencyPrecision(decisionName, model, explainer, distribution, 2, 5);
        AssertionsForClassTypes.assertThat(precision).isBetween(0.0d, 1.0d);
        double recall =
                ExplainabilityMetrics.getLocalSaliencyRecall(decisionName, model, explainer, distribution, 2, 5);
        AssertionsForClassTypes.assertThat(recall).isBetween(0.0d, 1.0d);
        double f1 =
                ExplainabilityMetrics.getLocalSaliencyF1(decisionName, model, explainer, distribution, 2, 5);
        AssertionsForClassTypes.assertThat(f1).isBetween(0.0d, 1.0d);
    }
}
