package org.kie.kogito.explainability.explainability.integrationtests.opennlp;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;

import opennlp.tools.langdetect.Language;
import opennlp.tools.langdetect.LanguageDetectorME;
import opennlp.tools.langdetect.LanguageDetectorModel;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.kogito.explainability.utils.ValidationUtils;

/**
 * Integration tests running {@link LimeExplainer} (plain and after {@link LimeConfigOptimizer} tuning)
 * against an OpenNLP language-detection model loaded from {@value #LANG_DETECT_MODEL}.
 */
class OpenNLPLimeExplainerTest {

    /** Classpath location of the OpenNLP language-detection model used by every test. */
    private static final String LANG_DETECT_MODEL = "/opennlp/langdetect-183.bin";

    /** Number of LIME perturbation samples used by the base configuration. */
    private static final int LIME_SAMPLES = 10;

    /** How many of the {@link #getSamples} inputs are fed to the config optimizer. */
    private static final int OPTIMIZATION_SAMPLE_SIZE = 5;

    @ValueSource(ints = {0})
    @ParameterizedTest
    void testOpenNLPLangDetect(int seed) throws Exception {
        LimeExplainer limeExplainer = new LimeExplainer(seededConfig(seed));
        PredictionProvider model = getModel();
        Function<String, List<String>> tokenizer = getTokenizer();
        PredictionInput testInput = getTestInput(tokenizer);

        PredictionOutput predictionOutput = predictSingle(model, testInput);
        Assertions.assertNotNull(predictionOutput);
        List<Output> outputs = predictionOutput.getOutputs();
        Assertions.assertNotNull(outputs);
        Assertions.assertEquals(1, outputs.size());
        Output languageOutput = outputs.get(0);
        Assertions.assertEquals("ita", languageOutput.getValue().asString());
        Assertions.assertEquals(0.03d, languageOutput.getScore(), 0.01d);

        SimplePrediction prediction = new SimplePrediction(testInput, predictionOutput);
        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        // Guard against a vacuous pass: the loop below asserts nothing if no saliency was produced.
        Assertions.assertFalse(saliencyMap.isEmpty());
        for (Saliency saliency : saliencyMap.values()) {
            Assertions.assertNotNull(saliency);
            Assertions.assertEquals(1.0d,
                    ExplainabilityMetrics.impactScore(model, prediction, saliency.getPositiveFeatures(3)));
        }

        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 2, 0.6d, 0.6d);
        });

        double f1 = ExplainabilityMetrics.getLocalSaliencyF1("lang", model, limeExplainer,
                new PredictionInputsDataDistribution(getSamples(tokenizer)), 2, 2);
        AssertionsForClassTypes.assertThat(f1).isBetween(0.5d, 1.0d);
    }

    /**
     * Returns a tokenizer that splits text on every non-word character.
     * Consecutive separators yield empty tokens, as with {@link String#split(String)}.
     */
    private Function<String, List<String>> getTokenizer() {
        return text -> Arrays.asList(text.split("\\W"));
    }

    /**
     * Builds a {@link PredictionProvider} that joins each input's feature values with single spaces
     * and emits one {@code "lang"} TEXT output holding the detected language code and its confidence.
     *
     * @throws IOException if the model resource is missing or cannot be read
     */
    private PredictionProvider getModel() throws IOException {
        LanguageDetectorME detector = loadLanguageDetector();
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new ArrayList<>(inputs.size());
            for (PredictionInput input : inputs) {
                Language language = detector.predictLanguage(joinFeatureValues(input));
                outputs.add(new PredictionOutput(List.of(
                        new Output("lang", Type.TEXT, new Value(language.getLang()), language.getConfidence()))));
            }
            return outputs;
        });
    }

    /** Loads the OpenNLP detector, closing the classpath stream once the model has been read. */
    private LanguageDetectorME loadLanguageDetector() throws IOException {
        try (InputStream modelStream = getClass().getResourceAsStream(LANG_DETECT_MODEL)) {
            if (modelStream == null) {
                throw new IOException("Language detection model not found on classpath: " + LANG_DETECT_MODEL);
            }
            return new LanguageDetectorME(new LanguageDetectorModel(modelStream));
        }
    }

    /** Concatenates the string form of every feature value, space-separated. */
    private static String joinFeatureValues(PredictionInput input) {
        StringBuilder text = new StringBuilder();
        for (Feature feature : input.getFeatures()) {
            // Separator only once some text has been written, so leading empty values add no spaces.
            if (text.length() > 0) {
                text.append(' ');
            }
            text.append(feature.getValue().asString());
        }
        return text.toString();
    }

    /** Returns one single-feature ("text") input per fixed sample sentence, tokenized with {@code tokenizer}. */
    private List<PredictionInput> getSamples(Function<String, List<String>> tokenizer) {
        List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner",
                "italiani, spaghetti pizza mandolino", "guten tag", "allez les bleus", "daje roma");
        List<PredictionInput> samples = new ArrayList<>(texts.size());
        for (String text : texts) {
            samples.add(new PredictionInput(List.of(FeatureFactory.newFulltextFeature("text", text, tokenizer))));
        }
        return samples;
    }

    /** Returns the Italian test sentence as a single full-text feature input. */
    private PredictionInput getTestInput(Function<String, List<String>> tokenizer) {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newFulltextFeature("text", "italiani,spaghetti pizza mandolino", tokenizer));
        return new PredictionInput(features);
    }

    /** Creates the base LIME configuration with a perturbation context driven by a {@code seed}-ed {@link Random}. */
    private static LimeConfig seededConfig(long seed) {
        return new LimeConfig()
                .withSamples(LIME_SAMPLES)
                .withPerturbationContext(new PerturbationContext(new Random(seed), 1));
    }

    /**
     * Runs {@code optimizer} on {@code baseConfig}, using the model's predictions for the first
     * {@link #OPTIMIZATION_SAMPLE_SIZE} samples (inputs and outputs paired one-to-one).
     */
    private LimeConfig optimizeConfig(LimeConfigOptimizer optimizer, PredictionProvider model, LimeConfig baseConfig)
            throws ExecutionException, InterruptedException, TimeoutException {
        List<PredictionInput> inputs = getSamples(getTokenizer()).subList(0, OPTIMIZATION_SAMPLE_SIZE);
        List<PredictionOutput> outputs = model.predictAsync(inputs)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return optimizer.optimize(baseConfig, DataUtils.getPredictions(inputs, outputs), model);
    }

    /** Predicts a single input, waiting at most the configured async timeout, and returns its first output. */
    private static PredictionOutput predictSingle(PredictionProvider model, PredictionInput input)
            throws ExecutionException, InterruptedException, TimeoutException {
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Assertions.assertNotNull(outputs);
        Assertions.assertFalse(outputs.isEmpty());
        return outputs.get(0);
    }

    @Test
    void testExplanationStabilityWithOptimization() throws ExecutionException, InterruptedException, TimeoutException, IOException {
        PredictionProvider model = getModel();
        LimeConfig baseConfig = seededConfig(0L);
        LimeConfig optimizedConfig = optimizeConfig(new LimeConfigOptimizer().withSampling(false), model, baseConfig);
        org.assertj.core.api.Assertions.assertThat(optimizedConfig).isNotSameAs(baseConfig);

        LimeExplainer limeExplainer = new LimeExplainer(optimizedConfig);
        PredictionInput testInput = getTestInput(getTokenizer());
        SimplePrediction prediction = new SimplePrediction(testInput, predictSingle(model, testInput));
        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.9d, 0.8d);
        });
    }

    @Test
    void testExplanationImpactScoreWithOptimization() throws ExecutionException, InterruptedException, TimeoutException, IOException {
        PredictionProvider model = getModel();
        LimeConfig baseConfig = seededConfig(0L);
        LimeConfig optimizedConfig = optimizeConfig(
                new LimeConfigOptimizer().forImpactScore().withSampling(false), model, baseConfig);
        org.assertj.core.api.Assertions.assertThat(optimizedConfig).isNotSameAs(baseConfig);
    }

    @Test
    void testExplanationWeightedStabilityWithOptimization() throws ExecutionException, InterruptedException, TimeoutException, IOException {
        PredictionProvider model = getModel();
        LimeConfig baseConfig = seededConfig(0L);
        LimeConfig optimizedConfig = optimizeConfig(new LimeConfigOptimizer().withSampling(false), model, baseConfig);
        org.assertj.core.api.Assertions.assertThat(optimizedConfig).isNotSameAs(baseConfig);

        LimeExplainer limeExplainer = new LimeExplainer(optimizedConfig);
        PredictionInput testInput = getTestInput(getTokenizer());
        SimplePrediction prediction = new SimplePrediction(testInput, predictSingle(model, testInput));
        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.8d, 0.9d);
        });
    }
}
