package org.kie.kogito.explainability.explainability.integrationtests.dmn;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.dmn.api.core.DMNRuntime;
import org.kie.kogito.dmn.DMNKogito;
import org.kie.kogito.dmn.DmnDecisionModel;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;

/* loaded from: input_file:org/kie/kogito/explainability/explainability/integrationtests/dmn/FraudScoringDmnLimeExplainerTest.class */
class FraudScoringDmnLimeExplainerTest {
    FraudScoringDmnLimeExplainerTest() {
    }

    @Test
    void testFraudScoringDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
        DMNRuntime createGenericDMNRuntime = DMNKogito.createGenericDMNRuntime(new Reader[]{new InputStreamReader(getClass().getResourceAsStream("/dmn/fraud.dmn"))});
        Assertions.assertEquals(1, createGenericDMNRuntime.getModels().size());
        DmnDecisionModel dmnDecisionModel = new DmnDecisionModel(createGenericDMNRuntime, "http://www.redhat.com/dmn/definitions/_81556584-7d78-4f8c-9d5f-b3cddb9b5c73", "fraud-scoring");
        LinkedList linkedList = new LinkedList();
        HashMap hashMap = new HashMap();
        hashMap.put("Card Type", "Debit");
        hashMap.put("Location", "Local");
        hashMap.put("Amount", 1000);
        hashMap.put("Auth Code", "Authorized");
        linkedList.add(hashMap);
        HashMap hashMap2 = new HashMap();
        hashMap2.put("Card Type", "Credit");
        hashMap2.put("Location", "Local");
        hashMap2.put("Amount", 100000);
        hashMap2.put("Auth Code", "Denied");
        linkedList.add(hashMap2);
        HashMap hashMap3 = new HashMap();
        hashMap3.put("Transactions", linkedList);
        DecisionModelWrapper decisionModelWrapper = new DecisionModelWrapper(dmnDecisionModel);
        LinkedList linkedList2 = new LinkedList();
        linkedList2.add(FeatureFactory.newCompositeFeature("context", hashMap3));
        PredictionInput predictionInput = new PredictionInput(linkedList2);
        Prediction prediction = new Prediction(predictionInput, (PredictionOutput) ((List) decisionModelWrapper.predictAsync(List.of(predictionInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())).get(0));
        for (Saliency saliency : ((Map) new LimeExplainer(100, 5).explainAsync(prediction, decisionModelWrapper).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())).values()) {
            Assertions.assertNotNull(saliency);
            List topFeatures = saliency.getTopFeatures(4);
            double abs = Math.abs(((Double) topFeatures.stream().map((v0) -> {
                return v0.getScore();
            }).findFirst().orElse(Double.valueOf(0.0d))).doubleValue());
            if (!topFeatures.isEmpty() && abs > 0.0d) {
                org.assertj.core.api.Assertions.assertThat(ExplainabilityMetrics.impactScore(decisionModelWrapper, prediction, topFeatures)).isPositive();
            }
        }
    }
}
