package org.kie.kogito.explainability.explainability.integrationtests.dmn;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.dmn.api.core.DMNRuntime;
import org.kie.kogito.dmn.DMNKogito;
import org.kie.kogito.dmn.DmnDecisionModel;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.local.counterfactual.SolverConfigBuilder;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.model.domain.NumericalFeatureDomain;
import org.kie.kogito.explainability.utils.CompositeFeatureUtils;
import org.optaplanner.core.config.solver.EnvironmentMode;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.termination.TerminationConfig;

/**
 * Integration test that runs a {@link CounterfactualExplainer} against the "Prequalification" DMN model
 * (classpath resource {@code /dmn/Prequalification-1.dmn}).
 *
 * <p>The test first checks that a fixed reference input yields {@code "Qualified?" == false}. It then asks the
 * explainer for a counterfactual whose goal is {@code "Qualified?" == true}. Finally it re-evaluates the model on
 * the counterfactual input and asserts that the decision really flipped.
 */
class PrequalificationDmnCounterfactualExplainerTest {

    /** Score-calculation-count limit passed to the solver's termination config. */
    private static final long STEPS = 100_000L;

    /** Solver random seed; combined with {@link EnvironmentMode#REPRODUCIBLE} to aim for repeatable runs. */
    private static final long RANDOM_SEED = 23L;

    private static final String DMN_RESOURCE = "/dmn/Prequalification-1.dmn";
    private static final String MODEL_NAMESPACE =
            "http://www.trisotech.com/definitions/_f31e1f8e-d4ce-4a3a-ac3b-747efa6b3401";
    private static final String MODEL_NAME = "Prequalification";

    /** Name of the decision output whose value the counterfactual search tries to change. */
    private static final String QUALIFIED = "Qualified?";

    @Test
    void testValidCounterfactual() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionProvider model = getModel();

        // Sanity check: the reference applicant is not qualified, so there is a decision to flip.
        Output originalOutput = predictFirstOutput(model, getTestInputFixed());
        Assertions.assertEquals(QUALIFIED, originalOutput.getName());
        Assertions.assertFalse((Boolean) originalOutput.getValue().getUnderlyingObject());

        SolverConfig solverConfig = SolverConfigBuilder.builder()
                .withTerminationConfig(new TerminationConfig().withScoreCalculationCountLimit(STEPS))
                .build();
        solverConfig.setRandomSeed(RANDOM_SEED);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withGoalThreshold(0.1d);
        counterfactualConfig.withSolverConfig(solverConfig);
        CounterfactualExplainer explainer = new CounterfactualExplainer(counterfactualConfig);

        PredictionInput variableInput = getTestInputVariable();
        // Desired outcome of the counterfactual search.
        PredictionOutput goal =
                new PredictionOutput(List.of(new Output(QUALIFIED, Type.BOOLEAN, new Value(true), 0.0d)));
        CounterfactualPrediction prediction = new CounterfactualPrediction(
                variableInput, goal, (DataDistribution) null, UUID.randomUUID(), (Long) null);

        // The search is bounded by the score-calculation limit configured above, so no extra timeout is applied here.
        CounterfactualResult result = explainer.explainAsync(prediction, model).get();
        Assertions.assertTrue(result.isValid());

        // Map the counterfactual entities back onto the structure of the original (composite) input
        // and confirm the model now returns the goal value.
        List<Feature> counterfactualFeatures = result.getEntities().stream()
                .map(entity -> entity.asFeature())
                .collect(Collectors.toList());
        PredictionInput counterfactualInput = new PredictionInput(
                CompositeFeatureUtils.unflattenFeatures(counterfactualFeatures, variableInput.getFeatures()));
        Output counterfactualOutput = predictFirstOutput(model, counterfactualInput);
        Assertions.assertEquals(QUALIFIED, counterfactualOutput.getName());
        Assertions.assertTrue((Boolean) counterfactualOutput.getValue().getUnderlyingObject());
    }

    /**
     * Evaluates {@code model} on a single input and returns the first output of the first prediction.
     * The wait is bounded by the timeout configured in {@link Config#INSTANCE}.
     */
    private static Output predictFirstOutput(PredictionProvider model, PredictionInput input)
            throws ExecutionException, InterruptedException, TimeoutException {
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return outputs.get(0).getOutputs().get(0);
    }

    /** Builds the reference input in which every value is a plain constant. */
    private PredictionInput getTestInputFixed() {
        return buildContextInput(5000, 10000, 500000);
    }

    /**
     * Builds the input the counterfactual search starts from. "Monthly Other Debt", "Monthly Income" and
     * "Loan Amount" are numerical features with bounded domains; every other value is a plain constant.
     */
    private PredictionInput getTestInputVariable() {
        // NOTE(review): the search starts from Monthly Other Debt = 10000 while getTestInputFixed uses 5000;
        // confirm this difference is intended.
        return buildContextInput(
                FeatureFactory.newNumericalFeature("Monthly Other Debt", 10000,
                        NumericalFeatureDomain.create(0.0d, 10000.0d)),
                FeatureFactory.newNumericalFeature("Monthly Income", 10000,
                        NumericalFeatureDomain.create(1000.0d, 500000.0d)),
                FeatureFactory.newNumericalFeature("Loan Amount", 500000,
                        NumericalFeatureDomain.create(10.0d, 500000.0d)));
    }

    /**
     * Wraps the model's input values in a single composite feature named "context".
     * Each argument is either a plain value or a feature (as produced by {@link FeatureFactory}).
     */
    private static PredictionInput buildContextInput(Object monthlyOtherDebt, Object monthlyIncome,
            Object loanAmount) {
        Map<String, Object> borrower = new HashMap<>();
        borrower.put("Monthly Other Debt", monthlyOtherDebt);
        borrower.put("Monthly Income", monthlyIncome);

        Map<String, Object> context = new HashMap<>();
        context.put("Appraised Value", 500000);
        context.put("Loan Amount", loanAmount);
        context.put("Credit Score", 700);
        context.put("Best Rate", 1);
        context.put("Borrower", borrower);

        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newCompositeFeature("context", context));
        return new PredictionInput(features);
    }

    /**
     * Loads the Prequalification DMN model from the classpath and wraps it as a {@link PredictionProvider}.
     *
     * @throws NullPointerException if the DMN resource is not on the classpath
     * @throws UncheckedIOException if closing the resource reader fails
     */
    private PredictionProvider getModel() {
        InputStream dmnStream = Objects.requireNonNull(getClass().getResourceAsStream(DMN_RESOURCE),
                "DMN resource not found on classpath: " + DMN_RESOURCE);
        DMNRuntime dmnRuntime;
        // Read explicitly as UTF-8 (not the platform default) and close the reader once the runtime is built.
        try (Reader dmnReader = new InputStreamReader(dmnStream, StandardCharsets.UTF_8)) {
            dmnRuntime = DMNKogito.createGenericDMNRuntime(dmnReader);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read DMN resource " + DMN_RESOURCE, e);
        }
        Assertions.assertEquals(1, dmnRuntime.getModels().size());
        // Decision names handed to DecisionModelWrapper; presumably these are excluded from the
        // prediction outputs so "Qualified?" comes first. Confirm in DecisionModelWrapper.
        return new DecisionModelWrapper(new DmnDecisionModel(dmnRuntime, MODEL_NAMESPACE, MODEL_NAME),
                List.of("LTV", "LLPA", "DTI", "Loan Payment"));
    }
}
