package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;

/* loaded from: input_file:org/kie/kogito/explainability/utils/ExplainabilityMetricsTest.class */
/**
 * Tests for {@link ExplainabilityMetrics}: the {@code quantifyExplainability} score, the
 * {@code classificationFidelity} metric computed over LIME saliencies, and the timeout
 * behaviour of {@code impactScore} when the model is too slow.
 */
class ExplainabilityMetricsTest {

    @Test
    void testExplainabilityNoExplanation() {
        double explainability = ExplainabilityMetrics.quantifyExplainability(0, 0, 0.0d);
        // With nothing to explain the score must be a finite zero, not NaN/Infinity from a 0/0.
        Assertions.assertFalse(Double.isNaN(explainability));
        Assertions.assertFalse(Double.isInfinite(explainability));
        Assertions.assertEquals(0.0d, explainability);
    }

    @Test
    void testExplainabilityNoExplanationWithInteraction() {
        double explainability = ExplainabilityMetrics.quantifyExplainability(0, 0, 1.0d);
        Assertions.assertFalse(Double.isNaN(explainability));
        Assertions.assertFalse(Double.isInfinite(explainability));
        Assertions.assertEquals(0.0d, explainability);
    }

    @Test
    void testExplainabilitySameIOChunksNoInteraction() {
        double explainability = ExplainabilityMetrics.quantifyExplainability(10, 10, 0.0d);
        Assertions.assertFalse(Double.isNaN(explainability));
        Assertions.assertFalse(Double.isInfinite(explainability));
        org.assertj.core.api.Assertions.assertThat(explainability).isBetween(0.0d, 1.0d);
    }

    @Test
    void testExplainabilitySameIOChunksWithInteraction() {
        Assertions.assertEquals(0.2331d, ExplainabilityMetrics.quantifyExplainability(10, 10, 0.5d), 1.0E-5d);
    }

    @Test
    void testExplainabilityDifferentIOChunksNoInteraction() {
        Assertions.assertEquals(0.481d, ExplainabilityMetrics.quantifyExplainability(3, 9, 0.0d), 1.0E-5d);
    }

    @Test
    void testExplainabilityDifferentIOChunksInteraction() {
        Assertions.assertEquals(0.3145d, ExplainabilityMetrics.quantifyExplainability(3, 9, 0.5d), 1.0E-5d);
    }

    @Test
    void testFidelityWithTextClassifier() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionProvider model = TestUtils.getDummyTextClassifier();
        // Copied into a LinkedList so the feature list stays mutable, as in the original test.
        PredictionInput input = new PredictionInput(new LinkedList<>(Arrays.asList(
                FeatureFactory.newFulltextFeature("f-0", "brown fox", text -> Arrays.asList(text.split(" "))),
                FeatureFactory.newTextFeature("f-1", "money"))));
        assertFidelityComputable(model, input);
    }

    @Test
    void testFidelityWithEvenSumModel() throws ExecutionException, InterruptedException, TimeoutException {
        PredictionProvider model = TestUtils.getEvenSumModel(1);
        PredictionInput input = new PredictionInput(new LinkedList<>(Arrays.asList(
                FeatureFactory.newNumericalFeature("f-1", 1),
                FeatureFactory.newNumericalFeature("f-2", 2),
                FeatureFactory.newNumericalFeature("f-3", 3))));
        assertFidelityComputable(model, input);
    }

    /**
     * Predicts {@code input} with {@code model}, explains that prediction with a 10-sample LIME
     * explainer, and asserts that {@link ExplainabilityMetrics#classificationFidelity} does not
     * throw when given every resulting saliency paired with the prediction.
     *
     * @param model the model under test
     * @param input the single input to predict and explain
     * @throws ExecutionException   if the prediction or explanation future fails
     * @throws InterruptedException if interrupted while waiting on a future
     * @throws TimeoutException     if a future does not complete within the configured async timeout
     */
    private static void assertFidelityComputable(PredictionProvider model, PredictionInput input)
            throws ExecutionException, InterruptedException, TimeoutException {
        LimeExplainer limeExplainer = new LimeExplainer(new LimeConfig().withSamples(10));
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction prediction = new Prediction(input, outputs.get(0));

        List<Pair<Saliency, Prediction>> pairs = new LinkedList<>();
        for (Saliency saliency : limeExplainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .values()) {
            pairs.add(Pair.of(saliency, prediction));
        }
        Assertions.assertDoesNotThrow(() -> {
            ExplainabilityMetrics.classificationFidelity(pairs);
        });
    }

    @Test
    void testBrokenPredict() {
        // Config.INSTANCE is a shared singleton: capture its state and restore it in `finally`
        // so a failure here cannot leave a 1 ms timeout behind for later tests.
        long originalTimeout = Config.INSTANCE.getAsyncTimeout();
        TimeUnit originalTimeUnit = Config.INSTANCE.getAsyncTimeUnit();
        Config.INSTANCE.setAsyncTimeout(1L);
        Config.INSTANCE.setAsyncTimeUnit(TimeUnit.MILLISECONDS);
        try {
            Prediction prediction = new Prediction(
                    new PredictionInput(Collections.emptyList()),
                    new PredictionOutput(Collections.emptyList()));
            // A model that takes ~1 s to answer, far beyond the 1 ms timeout configured above.
            PredictionProvider slowModel = inputs -> CompletableFuture.supplyAsync(() -> {
                try {
                    Thread.sleep(1000L);
                    return Collections.emptyList();
                } catch (InterruptedException e) {
                    // Preserve the interrupt status and the cause instead of silently dropping them.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("this is a test", e);
                }
            });
            Assertions.assertThrows(TimeoutException.class, () -> {
                ExplainabilityMetrics.impactScore(slowModel, prediction, Collections.emptyList());
            });
        } finally {
            Config.INSTANCE.setAsyncTimeout(originalTimeout);
            Config.INSTANCE.setAsyncTimeUnit(originalTimeUnit);
        }
    }
}
