package org.kie.kogito.explainability.explainability.integrationtests.dmn;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.dmn.api.core.DMNRuntime;
import org.kie.kogito.dmn.DMNKogito;
import org.kie.kogito.dmn.DmnDecisionModel;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.utils.DataUtils;

/**
 * Integration test that runs the {@link PartialDependencePlotExplainer} against the
 * "Prequalification" DMN model loaded from the test classpath.
 */
class PrequalificationDmnPDPExplainerTest {

    /** Classpath location of the DMN model under test. */
    private static final String DMN_RESOURCE = "/dmn/Prequalification-1.dmn";

    /** Namespace of the DMN model, as passed to {@link DmnDecisionModel}. */
    private static final String MODEL_NAMESPACE =
            "http://www.trisotech.com/definitions/_f31e1f8e-d4ce-4a3a-ac3b-747efa6b3401";

    /** Name of the DMN model, as passed to {@link DmnDecisionModel}. */
    private static final String MODEL_NAME = "Prequalification";

    /** Number of perturbed inputs generated and fed to the model. */
    private static final int SAMPLE_COUNT = 100;

    /** Fixed seed so the perturbed inputs (and therefore any failure) are reproducible. */
    private static final long RANDOM_SEED = 0L;

    @Test
    void testPrequalificationDMNExplanation()
            throws ExecutionException, InterruptedException, TimeoutException, IOException {
        DMNRuntime dmnRuntime;
        // Fail with a descriptive message when the resource is missing (instead of an NPE inside
        // InputStreamReader), decode the XML explicitly as UTF-8 rather than the platform default,
        // and close the reader once the runtime has been built from it.
        try (Reader dmnReader = new InputStreamReader(
                Objects.requireNonNull(getClass().getResourceAsStream(DMN_RESOURCE),
                        "missing test resource " + DMN_RESOURCE),
                StandardCharsets.UTF_8)) {
            dmnRuntime = DMNKogito.createGenericDMNRuntime(dmnReader);
        }
        Assertions.assertEquals(1, dmnRuntime.getModels().size());

        DecisionModelWrapper model =
                new DecisionModelWrapper(new DmnDecisionModel(dmnRuntime, MODEL_NAMESPACE, MODEL_NAME));
        List<PredictionInput> inputs = getInputs();
        List<PredictionOutput> outputs = model.predictAsync(inputs)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        // Outputs are paired with inputs by index below, so every input must have produced an output.
        Assertions.assertEquals(inputs.size(), outputs.size());

        List<Prediction> predictions = new ArrayList<>(outputs.size());
        for (int i = 0; i < outputs.size(); i++) {
            predictions.add(new Prediction(inputs.get(i), outputs.get(i)));
        }

        List<?> explanations = new PartialDependencePlotExplainer().explainFromPredictions(model, predictions);
        // NOTE(review): 25 is the expected number of explanation entries for this model;
        // confirm how it is derived (presumably from the model's features/outputs).
        org.assertj.core.api.Assertions.assertThat(explanations).isNotNull().hasSize(25);
    }

    /**
     * Builds {@link #SAMPLE_COUNT} prediction inputs by randomly perturbing a single base
     * applicant, expressed as one composite feature named "context".
     *
     * @return a new mutable list of perturbed prediction inputs
     */
    private List<PredictionInput> getInputs() {
        Map<String, Object> borrower = new HashMap<>();
        borrower.put("Monthly Other Debt", 1000);
        borrower.put("Monthly Income", 10000);

        Map<String, Object> context = new HashMap<>();
        context.put("Appraised Value", 500000);
        context.put("Loan Amount", 300000);
        context.put("Credit Score", 600);
        context.put("Borrower", borrower);

        List<Feature> baseFeatures = new LinkedList<>();
        baseFeatures.add(FeatureFactory.newCompositeFeature("context", context));
        PredictionInput template = new PredictionInput(baseFeatures);

        Random random = new Random(RANDOM_SEED);
        List<PredictionInput> inputs = new ArrayList<>(SAMPLE_COUNT);
        for (int i = 0; i < SAMPLE_COUNT; i++) {
            List<Feature> templateFeatures = template.getFeatures();
            inputs.add(new PredictionInput(DataUtils.perturbFeatures(templateFeatures,
                    new PerturbationContext(random, templateFeatures.size()))));
        }
        return inputs;
    }
}
