package org.kie.kogito.explainability.explainability.integrationtests.dmn;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.dmn.api.core.DMNRuntime;
import org.kie.kogito.dmn.DMNKogito;
import org.kie.kogito.dmn.DmnDecisionModel;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.utils.DataUtils;

/**
 * Integration test that runs the {@link PartialDependencePlotExplainer} against the
 * {@code fraud-scoring} DMN model loaded from the {@code /dmn/fraud.dmn} test resource.
 */
class FraudScoringDmnPDPExplainerTest {

    /** Classpath location of the DMN model under test. */
    private static final String FRAUD_DMN_RESOURCE = "/dmn/fraud.dmn";

    /** Namespace of the fraud-scoring DMN model. */
    private static final String FRAUD_NAMESPACE =
            "http://www.redhat.com/dmn/definitions/_81556584-7d78-4f8c-9d5f-b3cddb9b5c73";

    /** Name of the fraud-scoring DMN model. */
    private static final String FRAUD_MODEL_NAME = "fraud-scoring";

    /** Number of perturbed copies of the base input appended to the data set. */
    private static final int PERTURBED_INPUT_COUNT = 100;

    /** Fixed seed so the perturbed inputs, and therefore the test, are reproducible. */
    private static final long RANDOM_SEED = 0L;

    @Test
    void testFraudScoringDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
        DMNRuntime dmnRuntime = createFraudDmnRuntime();
        Assertions.assertEquals(1, dmnRuntime.getModels().size());

        DecisionModelWrapper model = new DecisionModelWrapper(
                new DmnDecisionModel(dmnRuntime, FRAUD_NAMESPACE, FRAUD_MODEL_NAME));
        List<PredictionInput> inputs = getInputs();
        List<PredictionOutput> outputs = model.predictAsync(inputs)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

        // Pair every input with the output the model produced for it. The collector's element
        // type is inferred from explainFromPredictions' parameter, so no raw list is needed.
        List<?> graphs = new PartialDependencePlotExplainer().explainFromPredictions(model,
                IntStream.range(0, outputs.size())
                        .mapToObj(i -> new SimplePrediction(inputs.get(i), outputs.get(i)))
                        .collect(Collectors.toList()));

        AssertionsForClassTypes.assertThat(graphs).isNotNull();
        // NOTE(review): 32 is the count the original test expected; its derivation from the
        // model's features/outputs is not visible here.
        org.assertj.core.api.Assertions.assertThat(graphs).hasSize(32);
    }

    /**
     * Builds a DMN runtime from the fraud-scoring model resource, decoding it as UTF-8 and
     * closing the reader once the runtime has been created.
     *
     * @return the runtime created from {@link #FRAUD_DMN_RESOURCE}
     * @throws UncheckedIOException if the resource stream cannot be closed
     */
    private DMNRuntime createFraudDmnRuntime() {
        InputStream dmnStream = getClass().getResourceAsStream(FRAUD_DMN_RESOURCE);
        Assertions.assertNotNull(dmnStream, "Missing test resource " + FRAUD_DMN_RESOURCE);
        try (Reader dmnReader = new InputStreamReader(dmnStream, StandardCharsets.UTF_8)) {
            return DMNKogito.createGenericDMNRuntime(dmnReader);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read " + FRAUD_DMN_RESOURCE, e);
        }
    }

    /**
     * Builds the data set fed to the model. It contains one base input followed by
     * {@link #PERTURBED_INPUT_COUNT} perturbations of it.
     *
     * <p>The base input is a single composite feature named {@code "context"} holding a
     * {@code "Transactions"} list with two transaction maps.
     *
     * @return a mutable list whose first element is the unperturbed base input
     */
    private List<PredictionInput> getInputs() {
        Map<String, Object> debitTransaction = new HashMap<>();
        debitTransaction.put("Card Type", "Debit");
        debitTransaction.put("Location", "Local");
        debitTransaction.put("Amount", 1000);
        debitTransaction.put("Auth Code", "Authorized");

        Map<String, Object> creditTransaction = new HashMap<>();
        creditTransaction.put("Card Type", "Credit");
        creditTransaction.put("Location", "Local");
        creditTransaction.put("Amount", 100000);
        creditTransaction.put("Auth Code", "Denied");

        List<Map<String, Object>> transactions = new ArrayList<>();
        transactions.add(debitTransaction);
        transactions.add(creditTransaction);

        Map<String, Object> context = new HashMap<>();
        context.put("Transactions", transactions);

        // Keep the feature list mutable, as the original ArrayList-based construction was.
        PredictionInput baseInput = new PredictionInput(
                new ArrayList<>(Collections.singletonList(FeatureFactory.newCompositeFeature("context", context))));

        List<PredictionInput> inputs = new ArrayList<>(PERTURBED_INPUT_COUNT + 1);
        inputs.add(baseInput);
        Random random = new Random(RANDOM_SEED);
        for (int i = 0; i < PERTURBED_INPUT_COUNT; i++) {
            inputs.add(new PredictionInput(DataUtils.perturbFeatures(baseInput.getFeatures(),
                    new PerturbationContext(random, baseInput.getFeatures().size()))));
        }
        return inputs;
    }
}
