package org.kie.kogito.explainability.explainability.integrationtests.pmml;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.kie.api.pmml.PMML4Result;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.pmml.evaluator.api.executor.PMMLRuntime;
import org.kie.test.util.filesystem.FileUtils;

/**
 * Integration tests that explain PMML model predictions with {@link LimeExplainer}.
 *
 * <p>Each test wraps a PMML runtime in a {@link PredictionProvider}, computes a reference
 * {@link Prediction} for a fixed input, asks LIME for the per-output saliencies and checks
 * {@link ExplainabilityMetrics#impactScore} of the top-ranked features.
 */
class PmmlLimeExplainerTest {

    private static PMMLRuntime logisticRegressionIris;
    private static PMMLRuntime categoricalVariableRegression;
    private static PMMLRuntime scorecardCategorical;
    private static PMMLRuntime compoundScoreCard;

    /** Loads every PMML model once for the whole test class. */
    @BeforeAll
    static void setUpBefore() {
        logisticRegressionIris = AbstractPMMLTest.getPMMLRuntime(LogisticRegressionIrisDataExecutor.MODEL_NAME, FileUtils.getFile("logisticRegressionIrisData.pmml"));
        categoricalVariableRegression = AbstractPMMLTest.getPMMLRuntime("categoricalVariables_Model", FileUtils.getFile("categoricalVariablesRegression.pmml"));
        scorecardCategorical = AbstractPMMLTest.getPMMLRuntime("SimpleScorecardCategorical", FileUtils.getFile("SimpleScorecardCategorical.pmml"));
        compoundScoreCard = AbstractPMMLTest.getPMMLRuntime("CompoundNestedPredicateScorecard", FileUtils.getFile("CompoundNestedPredicateScorecard.pmml"));
    }

    @Test
    void testPMMLRegression() throws Exception {
        PredictionProvider model = modelOf(PmmlLimeExplainerTest::predictIrisSpecies);
        for (int seed = 0; seed < 5; seed++) {
            // A fresh Random per seed keeps every iteration reproducible.
            LimeExplainer limeExplainer = new LimeExplainer(100, 1, new Random(seed));
            List<Feature> features = new ArrayList<>();
            features.add(FeatureFactory.newNumericalFeature("sepalLength", 6.9d));
            features.add(FeatureFactory.newNumericalFeature("sepalWidth", 3.1d));
            features.add(FeatureFactory.newNumericalFeature("petalLength", 5.1d));
            features.add(FeatureFactory.newNumericalFeature("petalWidth", 2.3d));
            Prediction prediction = predict(model, new PredictionInput(features));
            for (Saliency saliency : explain(limeExplainer, prediction, model).values()) {
                Assertions.assertNotNull(saliency);
                Assertions.assertEquals(1.0d, ExplainabilityMetrics.impactScore(model, prediction, saliency.getPositiveFeatures(2)));
            }
        }
    }

    @Test
    void testPMMLRegressionCategorical() throws Exception {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newCategoricalFeature("mapX", "red"));
        features.add(FeatureFactory.newCategoricalFeature("mapY", "classB"));
        PredictionProvider model = modelOf(PmmlLimeExplainerTest::predictCategoricalRegression);
        LimeExplainer limeExplainer = new LimeExplainer(10, 1);
        Prediction prediction = predict(model, new PredictionInput(features));
        for (Saliency saliency : explain(limeExplainer, prediction, model).values()) {
            Assertions.assertNotNull(saliency);
            Assertions.assertEquals(1.0d, ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(1)));
        }
    }

    @Test
    void testPMMLScorecardCategorical() throws Exception {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newCategoricalFeature("input1", "classA"));
        features.add(FeatureFactory.newCategoricalFeature("input2", "classB"));
        PredictionProvider model = modelOf(PmmlLimeExplainerTest::predictSimpleScorecard);
        LimeExplainer limeExplainer = new LimeExplainer(10, 1);
        Prediction prediction = predict(model, new PredictionInput(features));
        for (Saliency saliency : explain(limeExplainer, prediction, model).values()) {
            Assertions.assertNotNull(saliency);
            Assertions.assertEquals(0.33d, ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(1)), 0.01d);
        }
    }

    @Test
    void testPMMLCompoundScorecard() throws Exception {
        PredictionProvider model = modelOf(PmmlLimeExplainerTest::predictCompoundScorecard);
        for (int seed = 0; seed < 5; seed++) {
            // A fresh Random per seed keeps every iteration reproducible.
            LimeExplainer limeExplainer = new LimeExplainer(100, 2, new Random(seed));
            List<Feature> features = new ArrayList<>();
            features.add(FeatureFactory.newNumericalFeature("input1", -50));
            features.add(FeatureFactory.newTextFeature("input2", "classB"));
            Prediction prediction = predict(model, new PredictionInput(features));
            for (Saliency saliency : explain(limeExplainer, prediction, model).values()) {
                Assertions.assertNotNull(saliency);
                Assertions.assertEquals(1.0d, ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2)));
            }
        }
    }

    /**
     * Adapts a synchronous per-input evaluator into an asynchronous {@link PredictionProvider}.
     *
     * @param evaluator maps one input to its prediction output
     * @return a provider that evaluates every input, in order, on the default async executor
     */
    private static PredictionProvider modelOf(Function<PredictionInput, PredictionOutput> evaluator) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new ArrayList<>(inputs.size());
            for (PredictionInput input : inputs) {
                outputs.add(evaluator.apply(input));
            }
            return outputs;
        });
    }

    /**
     * Runs {@code model} on a single input and pairs the input with its output.
     * Waits no longer than the configured async timeout.
     */
    private static Prediction predict(PredictionProvider model, PredictionInput input) throws Exception {
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return new Prediction(input, outputs.get(0));
    }

    /**
     * Runs LIME on {@code prediction} and returns its saliencies.
     * Waits no longer than the configured async timeout.
     */
    private static Map<String, Saliency> explain(LimeExplainer explainer, Prediction prediction, PredictionProvider model) throws Exception {
        return explainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /** Evaluates the logistic-regression Iris model; emits the "Species" result as a text output. */
    private static PredictionOutput predictIrisSpecies(PredictionInput input) {
        List<Feature> features = input.getFeatures();
        LogisticRegressionIrisDataExecutor executor = new LogisticRegressionIrisDataExecutor(
                features.get(0).getValue().asNumber(),
                features.get(1).getValue().asNumber(),
                features.get(2).getValue().asNumber(),
                features.get(3).getValue().asNumber());
        String species = executor.execute(logisticRegressionIris).getResultVariables().get("Species").toString();
        return new PredictionOutput(List.of(new Output("species", Type.TEXT, new Value<>(species), 1.0d)));
    }

    /** Evaluates the categorical-variables regression model; emits its "result" (stringified) as the output value. */
    private static PredictionOutput predictCategoricalRegression(PredictionInput input) {
        List<Feature> features = input.getFeatures();
        CategoricalVariablesRegressionExecutor executor = new CategoricalVariablesRegressionExecutor(
                features.get(0).getValue().asString(),
                features.get(1).getValue().asString());
        String result = executor.execute(categoricalVariableRegression).getResultVariables().get("result").toString();
        return new PredictionOutput(List.of(new Output("result", Type.NUMBER, new Value<>(result), 1.0d)));
    }

    /** Evaluates the simple categorical scorecard; emits the score and the first two reason codes. */
    private static PredictionOutput predictSimpleScorecard(PredictionInput input) {
        List<Feature> features = input.getFeatures();
        PMML4Result result = new SimpleScorecardCategoricalExecutor(
                features.get(0).getValue().asString(),
                features.get(1).getValue().asString()).execute(scorecardCategorical);
        Map<String, Object> variables = result.getResultVariables();
        return new PredictionOutput(List.of(
                new Output("score", Type.TEXT, new Value<>(variables.get("Score")), 1.0d),
                new Output("reason1", Type.TEXT, new Value<>(variables.get("Reason Code 1")), 1.0d),
                new Output("reason2", Type.TEXT, new Value<>(variables.get("Reason Code 2")), 1.0d)));
    }

    /** Evaluates the compound nested-predicate scorecard; emits the score and the first reason code. */
    private static PredictionOutput predictCompoundScorecard(PredictionInput input) {
        List<Feature> features = input.getFeatures();
        PMML4Result result = new CompoundNestedPredicateScorecardExecutor(
                features.get(0).getValue().asNumber(),
                features.get(1).getValue().asString()).execute(compoundScoreCard);
        Map<String, Object> variables = result.getResultVariables();
        return new PredictionOutput(List.of(
                new Output("score", Type.TEXT, new Value<>(variables.get("Score")), 1.0d),
                new Output("reason1", Type.TEXT, new Value<>(variables.get("Reason Code 1")), 1.0d)));
    }
}
