package org.kie.kogito.explainability.explainability.integrationtests.dmn;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.dmn.api.core.DMNRuntime;
import org.kie.kogito.dmn.DMNKogito;
import org.kie.kogito.dmn.DmnDecisionModel;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.local.counterfactual.SolverConfigBuilder;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionFeatureDomain;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.model.domain.EmptyFeatureDomain;
import org.kie.kogito.explainability.model.domain.NumericalFeatureDomain;
import org.optaplanner.core.config.solver.EnvironmentMode;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.termination.TerminationConfig;

/**
 * Integration tests that run the {@link CounterfactualExplainer} against the {@code myComplexEligibility}
 * model loaded from {@code /dmn/ComplexEligibility.dmn}.
 *
 * <p>Every test asks for {@code inputsAreValid = true}, {@code canRequestLoan = true} and a target value
 * for {@code my-scoring-function}. It lets the solver vary the unconstrained input features and then
 * checks the resulting counterfactual. The solver is configured with a fixed random seed and
 * {@link EnvironmentMode#REPRODUCIBLE} so that runs are repeatable.
 */
class ComplexEligibilityDmnCounterfactualExplainerTest {

    private static final String DMN_RESOURCE = "/dmn/ComplexEligibility.dmn";
    private static final String MODEL_NAMESPACE = "https://kiegroup.org/dmn/_B305FE71-3B8C-48C5-B5B1-D9CC04825B16";
    private static final String MODEL_NAME = "myComplexEligibility";

    private static final String INPUTS_ARE_VALID = "inputsAreValid";
    private static final String CAN_REQUEST_LOAN = "canRequestLoan";
    private static final String SCORING_FUNCTION = "my-scoring-function";

    /** Solver termination budget: the search stops after this many score calculations. */
    private static final long SCORE_CALCULATION_COUNT_LIMIT = 10_000L;
    private static final long RANDOM_SEED = 23L;
    private static final double GOAL_THRESHOLD = 0.01d;
    // NOTE(review): passed as the last CounterfactualPrediction constructor argument; presumably the
    // maximum running time in seconds -- confirm against the CounterfactualPrediction API.
    private static final long MAX_RUNNING_TIME_SECONDS = 60L;

    /**
     * With {@code age} (40) and {@code hasReferral} fixed, the explainer must find a {@code monthlySalary}
     * in [10, 10000] that satisfies both boolean goals and a score of 0.6 (within 0.05).
     */
    @Test
    void testDMNValidCounterfactualExplanation() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionInput input = new PredictionInput(List.of(
                FeatureFactory.newNumericalFeature("age", 40),
                FeatureFactory.newBooleanFeature("hasReferral", true),
                FeatureFactory.newNumericalFeature("monthlySalary", 500)));
        PredictionFeatureDomain domains = new PredictionFeatureDomain(List.of(
                EmptyFeatureDomain.create(),
                EmptyFeatureDomain.create(),
                NumericalFeatureDomain.create(10.0d, 10000.0d)));
        List<Boolean> constraints = List.of(true, true, false);

        CounterfactualResult result = explain(input, domains, constraints, generateGoal(true, true, 0.6d));

        Assertions.assertTrue(result.isValid());
        assertGoalReached(result, 0.6d, 0.05d);

        List<CounterfactualEntity> entities = result.getEntities();
        Assertions.assertEquals("age", entities.get(0).asFeature().getName());
        // Expected value first, actual second (the original had these swapped).
        Assertions.assertEquals(40.0d, entities.get(0).asFeature().getValue().asNumber());
        Assertions.assertEquals("hasReferral", entities.get(1).asFeature().getName());
        Assertions.assertEquals(Boolean.TRUE, entities.get(1).asFeature().getValue().getUnderlyingObject());
        Assertions.assertEquals("monthlySalary", entities.get(2).asFeature().getName());
        Assertions.assertTrue(entities.get(2).asFeature().getValue().asNumber() > 6000.0d);
    }

    /**
     * Lets the solver vary both {@code age} (in [18, 60]) and {@code monthlySalary} (in [10, 100000]) to
     * reach a scoring-function value of 1.0, and checks the specific solution that is found.
     */
    @Test
    void testDMNScoringFunction() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionInput input = new PredictionInput(List.of(
                FeatureFactory.newNumericalFeature("age", 40),
                FeatureFactory.newBooleanFeature("hasReferral", true),
                FeatureFactory.newNumericalFeature("monthlySalary", 500)));
        PredictionFeatureDomain domains = new PredictionFeatureDomain(List.of(
                NumericalFeatureDomain.create(18.0d, 60.0d),
                EmptyFeatureDomain.create(),
                NumericalFeatureDomain.create(10.0d, 100000.0d)));
        List<Boolean> constraints = List.of(false, true, false);

        CounterfactualResult result = explain(input, domains, constraints, generateGoal(true, true, 1.0d));

        Assertions.assertTrue(result.isValid());
        assertGoalReached(result, 1.0d, 0.01d);

        List<CounterfactualEntity> entities = result.getEntities();
        Assertions.assertEquals("age", entities.get(0).asFeature().getName());
        Assertions.assertEquals(18.0d, entities.get(0).asFeature().getValue().asNumber());
        Assertions.assertEquals("hasReferral", entities.get(1).asFeature().getName());
        Assertions.assertEquals(Boolean.TRUE, entities.get(1).asFeature().getValue().getUnderlyingObject());
        Assertions.assertEquals("monthlySalary", entities.get(2).asFeature().getName());
        double monthlySalary = entities.get(2).asFeature().getValue().asNumber();
        Assertions.assertEquals(7900.0d, monthlySalary, 10.0d);
        // NOTE(review): 7 * salary / 2000 - 10 presumably mirrors a term of the DMN scoring function that
        // should balance the age of 18; confirm against ComplexEligibility.dmn.
        Assertions.assertEquals(18.0d, ((7.0d * monthlySalary) / 2000.0d) - 10.0d, 0.5d);
    }

    /**
     * Same setup as {@link #testDMNValidCounterfactualExplanation()} except that the fixed age is 61.
     * With this change, no valid counterfactual is expected.
     */
    @Test
    void testDMNInvalidCounterfactualExplanation() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionInput input = new PredictionInput(List.of(
                FeatureFactory.newNumericalFeature("age", 61),
                FeatureFactory.newBooleanFeature("hasReferral", true),
                FeatureFactory.newNumericalFeature("monthlySalary", 500)));
        PredictionFeatureDomain domains = new PredictionFeatureDomain(List.of(
                EmptyFeatureDomain.create(),
                EmptyFeatureDomain.create(),
                NumericalFeatureDomain.create(10.0d, 10000.0d)));
        List<Boolean> constraints = List.of(true, true, false);

        CounterfactualResult result = explain(input, domains, constraints, generateGoal(true, true, 0.6d));

        Assertions.assertFalse(result.isValid());
    }

    /**
     * Runs the counterfactual explainer against the DMN model and waits for the result.
     *
     * <p>{@code domains} and {@code constraints} are positional and must line up with the features of
     * {@code input}. Judging by the assertions in the tests, a {@code true} constraint keeps that feature
     * fixed, and {@code false} lets the solver vary it within its domain.
     *
     * @param input       the original input features
     * @param domains     the search domain of each feature
     * @param constraints per-feature constraint flags
     * @param goal        the desired model outputs
     * @return the counterfactual result produced by the explainer
     * @throws ExecutionException   if the asynchronous explanation fails
     * @throws InterruptedException if interrupted while waiting for the result
     * @throws TimeoutException     if the explanation exceeds the configured async timeout
     */
    private CounterfactualResult explain(PredictionInput input, PredictionFeatureDomain domains,
            List<Boolean> constraints, List<Output> goal)
            throws ExecutionException, InterruptedException, TimeoutException {
        CounterfactualConfig config = new CounterfactualConfig()
                .withSolverConfig(buildSolverConfig())
                .withGoalThreshold(GOAL_THRESHOLD);
        CounterfactualPrediction prediction = new CounterfactualPrediction(input, new PredictionOutput(goal),
                domains, constraints, (DataDistribution) null, UUID.randomUUID(), MAX_RUNNING_TIME_SECONDS);
        return new CounterfactualExplainer(config)
                .explainAsync(prediction, getModel())
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /**
     * Builds the solver configuration shared by all tests. It bounds the number of score calculations
     * and pins the random seed in {@link EnvironmentMode#REPRODUCIBLE} mode.
     */
    private static SolverConfig buildSolverConfig() {
        SolverConfig solverConfig = SolverConfigBuilder.builder()
                .withTerminationConfig(
                        new TerminationConfig().withScoreCalculationCountLimit(SCORE_CALCULATION_COUNT_LIMIT))
                .build();
        solverConfig.setRandomSeed(RANDOM_SEED);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
        return solverConfig;
    }

    /**
     * Asserts that the first prediction output of {@code result} reaches the goal. Both boolean outputs
     * must be {@code true}, and the scoring function must be within {@code scoreDelta} of
     * {@code expectedScore}.
     */
    private static void assertGoalReached(CounterfactualResult result, double expectedScore, double scoreDelta) {
        List<Output> outputs = result.getOutput().get(0).getOutputs();
        Assertions.assertEquals(INPUTS_ARE_VALID, outputs.get(0).getName());
        Assertions.assertEquals(Boolean.TRUE, outputs.get(0).getValue().getUnderlyingObject());
        Assertions.assertEquals(CAN_REQUEST_LOAN, outputs.get(1).getName());
        Assertions.assertEquals(Boolean.TRUE, outputs.get(1).getValue().getUnderlyingObject());
        Assertions.assertEquals(SCORING_FUNCTION, outputs.get(2).getName());
        Assertions.assertEquals(expectedScore,
                ((BigDecimal) outputs.get(2).getValue().getUnderlyingObject()).doubleValue(), scoreDelta);
    }

    /**
     * Loads {@code ComplexEligibility.dmn} from the classpath and wraps the {@code myComplexEligibility}
     * model as a {@link PredictionProvider}.
     *
     * <p>The resource is read as UTF-8, and the reader is closed once the runtime has been built.
     *
     * @throws NullPointerException  if the DMN resource is not on the classpath
     * @throws UncheckedIOException if the resource reader cannot be closed
     */
    private PredictionProvider getModel() {
        InputStream dmnStream = Objects.requireNonNull(getClass().getResourceAsStream(DMN_RESOURCE),
                "DMN resource not found on classpath: " + DMN_RESOURCE);
        DMNRuntime dmnRuntime;
        try (Reader dmnReader = new InputStreamReader(dmnStream, StandardCharsets.UTF_8)) {
            dmnRuntime = DMNKogito.createGenericDMNRuntime(new Reader[] { dmnReader });
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read DMN resource " + DMN_RESOURCE, e);
        }
        Assertions.assertEquals(1, dmnRuntime.getModels().size());
        return new DecisionModelWrapper(new DmnDecisionModel(dmnRuntime, MODEL_NAMESPACE, MODEL_NAME), List.of());
    }

    /**
     * Builds the goal outputs for the counterfactual search, in the same order the assertions expect
     * the model outputs.
     *
     * @param inputsAreValid desired value of {@code inputsAreValid}
     * @param canRequestLoan desired value of {@code canRequestLoan}
     * @param score          desired value of {@code my-scoring-function}
     * @return an immutable list of the three goal outputs
     */
    private static List<Output> generateGoal(boolean inputsAreValid, boolean canRequestLoan, double score) {
        return List.of(
                new Output(INPUTS_ARE_VALID, Type.BOOLEAN, new Value(inputsAreValid), 0.0d),
                new Output(CAN_REQUEST_LOAN, Type.BOOLEAN, new Value(canRequestLoan), 0.0d),
                new Output(SCORING_FUNCTION, Type.NUMBER, new Value(score), 0.0d));
    }
}
