package org.kie.kogito.explainability.explainability.integrationtests.dmn;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.dmn.api.core.DMNRuntime;
import org.kie.kogito.dmn.DMNKogito;
import org.kie.kogito.dmn.DmnDecisionModel;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.kogito.explainability.utils.ValidationUtils;

/**
 * Integration tests for {@link LimeExplainer} against the {@code LoanEligibility} DMN model.
 *
 * <p>Covers a plain LIME explanation (non-null saliencies, local saliency stability and local
 * saliency F1) and {@link LimeConfigOptimizer}-driven tuning for stability, impact score and
 * weighted stability.
 */
class LoanEligibilityDmnLimeExplainerTest {

    /** Classpath location of the DMN model under test. */
    private static final String DMN_RESOURCE = "/dmn/LoanEligibility.dmn";

    /** Namespace of the model inside {@link #DMN_RESOURCE}. */
    private static final String DMN_NAMESPACE = "https://github.com/kiegroup/kogito-examples/dmn-quarkus-listener-example";

    /** Name of the model inside {@link #DMN_RESOURCE}. */
    private static final String DMN_MODEL_NAME = "LoanEligibility";

    /** Fixed seed so perturbations, and therefore explanations, are reproducible across runs. */
    private static final long SEED = 0L;

    /**
     * Explains a single loan application. Checks that every saliency is present, that the
     * explanation is locally stable, and that the local saliency F1 over perturbed inputs is
     * reasonable.
     */
    @Test
    void testLoanEligibilityDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionProvider model = getModel();
        PredictionInput testInput = getTestInput();
        SimplePrediction prediction = predict(model, testInput);

        // The same perturbation context (and thus the same seeded Random) is shared by the
        // explainer and by the perturbed inputs below, so statement order affects the results.
        PerturbationContext perturbationContext = new PerturbationContext(new Random(SEED), 1);
        LimeExplainer limeExplainer = new LimeExplainer(new LimeConfig().withPerturbationContext(perturbationContext));

        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (Saliency saliency : saliencyMap.values()) {
            Assertions.assertNotNull(saliency);
        }

        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.4d, 0.4d);
        });

        List<PredictionInput> perturbedInputs = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            perturbedInputs.add(new PredictionInput(DataUtils.perturbFeatures(testInput.getFeatures(), perturbationContext)));
        }
        double f1 = ExplainabilityMetrics.getLocalSaliencyF1("Eligibility", model, limeExplainer,
                new PredictionInputsDataDistribution(perturbedInputs), 2, 2);
        AssertionsForClassTypes.assertThat(f1).isBetween(0.5d, 1.0d);
    }

    /**
     * Optimizes a {@link LimeConfig} for stability, then checks that the optimized explainer is
     * locally stable on the test input.
     */
    @Test
    void testExplanationStabilityWithOptimization() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionProvider model = getModel();
        // Optimization samples must match this model's input schema (loan eligibility, not fraud scoring).
        List<Prediction> predictions = samplePredictions(model, DmnTestUtils.randomLoanEligibilityInputs(), 5);

        LimeConfig initialConfig = new LimeConfig().withPerturbationContext(new PerturbationContext(new Random(SEED), 1));
        LimeConfig optimizedConfig = new LimeConfigOptimizer().optimize(initialConfig, predictions, model);
        org.assertj.core.api.Assertions.assertThat(optimizedConfig).isNotSameAs(initialConfig);

        LimeExplainer limeExplainer = new LimeExplainer(optimizedConfig);
        SimplePrediction prediction = predict(model, getTestInput());
        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.5d, 0.5d);
        });
    }

    /** Optimizes a {@link LimeConfig} for impact score (without sampling) and checks a new config is produced. */
    @Test
    void testExplanationImpactScoreWithOptimization() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionProvider model = getModel();
        List<Prediction> predictions = samplePredictions(model, DmnTestUtils.randomLoanEligibilityInputs(), 10);

        LimeConfigOptimizer optimizer = new LimeConfigOptimizer().forImpactScore().withSampling(false);
        LimeConfig initialConfig = new LimeConfig()
                .withSamples(10)
                .withPerturbationContext(new PerturbationContext(new Random(SEED), 1));
        org.assertj.core.api.Assertions.assertThat(optimizer.optimize(initialConfig, predictions, model))
                .isNotSameAs(initialConfig);
    }

    /**
     * Optimizes a {@link LimeConfig} for weighted stability (weights 0.4 / 0.6), then checks that
     * the optimized explainer is locally stable on the test input.
     */
    @Test
    void testExplanationWeightedStabilityWithOptimization() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionProvider model = getModel();
        // Optimization samples must match this model's input schema (loan eligibility, not fraud scoring).
        List<Prediction> predictions = samplePredictions(model, DmnTestUtils.randomLoanEligibilityInputs(), 5);

        LimeConfigOptimizer optimizer = new LimeConfigOptimizer().withWeightedStability(0.4d, 0.6d);
        LimeConfig initialConfig = new LimeConfig().withPerturbationContext(new PerturbationContext(new Random(SEED), 1));
        LimeConfig optimizedConfig = optimizer.optimize(initialConfig, predictions, model);
        org.assertj.core.api.Assertions.assertThat(optimizedConfig).isNotSameAs(initialConfig);

        LimeExplainer limeExplainer = new LimeExplainer(optimizedConfig);
        SimplePrediction prediction = predict(model, getTestInput());
        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.4d, 0.6d);
        });
    }

    /**
     * Runs {@code model} on {@code inputs}, waiting at most the configured async timeout.
     *
     * @return one output per input, as returned by the model
     */
    private static List<PredictionOutput> predictOutputs(PredictionProvider model, List<PredictionInput> inputs)
            throws ExecutionException, InterruptedException, TimeoutException {
        return model.predictAsync(inputs).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /** Runs {@code model} on a single input and pairs the input with its first output. */
    private static SimplePrediction predict(PredictionProvider model, PredictionInput input)
            throws ExecutionException, InterruptedException, TimeoutException {
        return new SimplePrediction(input, predictOutputs(model, List.of(input)).get(0));
    }

    /**
     * Predicts the first {@code count} of {@code samples} and pairs them with their inputs via
     * {@link DataUtils#getPredictions}.
     */
    private static List<Prediction> samplePredictions(PredictionProvider model, List<PredictionInput> samples, int count)
            throws ExecutionException, InterruptedException, TimeoutException {
        return DataUtils.getPredictions(samples, predictOutputs(model, samples.subList(0, count)));
    }

    /**
     * Loads {@link #DMN_RESOURCE} into a DMN runtime and wraps its {@code LoanEligibility} model
     * as a {@link PredictionProvider}.
     *
     * @throws NullPointerException if the DMN resource is not on the classpath
     * @throws UncheckedIOException if the resource reader cannot be closed
     */
    private PredictionProvider getModel() {
        InputStream dmnStream = Objects.requireNonNull(getClass().getResourceAsStream(DMN_RESOURCE),
                "DMN resource not found on classpath: " + DMN_RESOURCE);
        DMNRuntime dmnRuntime;
        // DMN files are XML documents; decode explicitly as UTF-8 rather than the platform charset.
        try (Reader dmnReader = new InputStreamReader(dmnStream, StandardCharsets.UTF_8)) {
            dmnRuntime = DMNKogito.createGenericDMNRuntime(dmnReader);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read " + DMN_RESOURCE, e);
        }
        // Parsing has already happened by this point: the runtime reports its models after the reader is closed.
        Assertions.assertEquals(1, dmnRuntime.getModels().size());
        return new DecisionModelWrapper(new DmnDecisionModel(dmnRuntime, DMN_NAMESPACE, DMN_MODEL_NAME));
    }

    /**
     * Builds a single loan application as one composite {@code context} feature. The feature holds
     * a {@code Client} (age, salary, existing payments) and a {@code Loan} (duration, installment).
     */
    private static PredictionInput getTestInput() {
        Map<String, Object> client = new HashMap<>();
        client.put("Age", 43);
        client.put("Salary", 1950);
        client.put("Existing payments", 100);

        Map<String, Object> loan = new HashMap<>();
        loan.put("Duration", 15);
        loan.put("Installment", 100);

        Map<String, Object> context = new HashMap<>();
        context.put("Client", client);
        context.put("Loan", loan);

        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newCompositeFeature("context", context));
        return new PredictionInput(features);
    }
}
