package org.kie.kogito.explainability.explainability.integrationtests.opennlp;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import opennlp.tools.langdetect.Language;
import opennlp.tools.langdetect.LanguageDetectorME;
import opennlp.tools.langdetect.LanguageDetectorModel;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.kogito.explainability.utils.ValidationUtils;

/* loaded from: input_file:org/kie/kogito/explainability/explainability/integrationtests/opennlp/OpenNLPLimeExplainerTest.class */
class OpenNLPLimeExplainerTest {
    OpenNLPLimeExplainerTest() {
    }

    @ValueSource(ints = {0, 1, 2, 3, 4})
    @ParameterizedTest
    void testOpenNLPLangDetect(int i) throws Exception {
        Random random = new Random();
        random.setSeed(i);
        LimeExplainer limeExplainer = new LimeExplainer(new LimeConfig().withSamples(100).withPerturbationContext(new PerturbationContext(random, 2)));
        LanguageDetectorME languageDetectorME = new LanguageDetectorME(new LanguageDetectorModel(getClass().getResourceAsStream("/opennlp/langdetect-183.bin")));
        PredictionProvider predictionProvider = list -> {
            return CompletableFuture.supplyAsync(() -> {
                LinkedList linkedList = new LinkedList();
                Iterator it = list.iterator();
                while (it.hasNext()) {
                    PredictionInput predictionInput = (PredictionInput) it.next();
                    StringBuilder sb = new StringBuilder();
                    for (Feature feature : predictionInput.getFeatures()) {
                        if (sb.length() > 0) {
                            sb.append(' ');
                        }
                        sb.append(feature.getValue().asString());
                    }
                    Language predictLanguage = languageDetectorME.predictLanguage(sb.toString());
                    linkedList.add(new PredictionOutput(List.of(new Output("lang", Type.TEXT, new Value(predictLanguage.getLang()), predictLanguage.getConfidence()))));
                }
                return linkedList;
            });
        };
        LinkedList linkedList = new LinkedList();
        linkedList.add(FeatureFactory.newFulltextFeature("text", "italiani,spaghetti pizza mandolino", str -> {
            return Arrays.asList(str.split("\\W"));
        }));
        PredictionInput predictionInput = new PredictionInput(linkedList);
        List list2 = (List) predictionProvider.predictAsync(List.of(predictionInput)).get();
        Assertions.assertNotNull(list2);
        Assertions.assertFalse(list2.isEmpty());
        PredictionOutput predictionOutput = (PredictionOutput) list2.get(0);
        Assertions.assertNotNull(predictionOutput);
        Assertions.assertNotNull(predictionOutput.getOutputs());
        Assertions.assertEquals(1, predictionOutput.getOutputs().size());
        Assertions.assertEquals("ita", ((Output) predictionOutput.getOutputs().get(0)).getValue().asString());
        Assertions.assertEquals(0.03d, ((Output) predictionOutput.getOutputs().get(0)).getScore(), 0.01d);
        Prediction prediction = new Prediction(predictionInput, predictionOutput);
        for (Saliency saliency : ((Map) limeExplainer.explainAsync(prediction, predictionProvider).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())).values()) {
            Assertions.assertNotNull(saliency);
            Assertions.assertEquals(1.0d, ExplainabilityMetrics.impactScore(predictionProvider, prediction, saliency.getPositiveFeatures(3)));
        }
        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(predictionProvider, prediction, limeExplainer, 2, 0.8d, 0.8d);
        });
    }
}
