package org.kie.kogito.explainability.global.pdp;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;

import static org.assertj.core.api.Assertions.assertThat;

/* loaded from: input_file:org/kie/kogito/explainability/global/pdp/PartialDependencePlotExplainerTest.class */
class PartialDependencePlotExplainerTest {
    PartialDependencePlotExplainerTest() {
    }

    private PredictionProviderMetadata getMetadata(final Random random) {
        return new PredictionProviderMetadata() { // from class: org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainerTest.1
            public DataDistribution getDataDistribution() {
                return DataUtils.generateRandomDataDistribution(3, 100, random);
            }

            public PredictionInput getInputShape() {
                LinkedList linkedList = new LinkedList();
                linkedList.add(FeatureFactory.newNumericalFeature("f0", 0));
                linkedList.add(FeatureFactory.newNumericalFeature("f1", 0));
                linkedList.add(FeatureFactory.newNumericalFeature("f2", 0));
                return new PredictionInput(linkedList);
            }

            public PredictionOutput getOutputShape() {
                LinkedList linkedList = new LinkedList();
                linkedList.add(new Output("sum-but0", Type.BOOLEAN, new Value(false), 0.0d));
                return new PredictionOutput(linkedList);
            }
        };
    }

    @Test
    void testPdpNumericClassifier() throws Exception {
        Random random = new Random();
        for (int i = 0; i < 5; i++) {
            random.setSeed(i);
            List<PartialDependenceGraph> explainFromMetadata = new PartialDependencePlotExplainer().explainFromMetadata(TestUtils.getSumSkipModel(0), getMetadata(random));
            Assertions.assertNotNull(explainFromMetadata);
            for (PartialDependenceGraph partialDependenceGraph : explainFromMetadata) {
                Assertions.assertNotNull(partialDependenceGraph.getFeature());
                Assertions.assertNotNull(partialDependenceGraph.getX());
                Assertions.assertNotNull(partialDependenceGraph.getY());
                Assertions.assertEquals(partialDependenceGraph.getX().size(), partialDependenceGraph.getY().size());
                assertGraph(partialDependenceGraph);
            }
            Assertions.assertEquals(1L, ((PartialDependenceGraph) explainFromMetadata.get(0)).getY().stream().distinct().count());
            org.assertj.core.api.Assertions.assertThat(((PartialDependenceGraph) explainFromMetadata.get(1)).getY().stream().distinct().count()).isGreaterThan(1L);
            org.assertj.core.api.Assertions.assertThat(((PartialDependenceGraph) explainFromMetadata.get(2)).getY().stream().distinct().count()).isGreaterThan(1L);
        }
    }

    private void assertGraph(PartialDependenceGraph partialDependenceGraph) {
        for (int i = 0; i < partialDependenceGraph.getX().size(); i++) {
            Assertions.assertNotEquals(Double.NaN, ((Value) partialDependenceGraph.getY().get(i)).asNumber());
            if (i > 0) {
                Assertions.assertTrue(((Value) partialDependenceGraph.getX().get(i)).asNumber() >= ((Value) partialDependenceGraph.getX().get(i - 1)).asNumber());
            }
        }
    }

    @Test
    void testBrokenPredict() {
        Config.INSTANCE.setAsyncTimeout(1L);
        Config.INSTANCE.setAsyncTimeUnit(TimeUnit.MILLISECONDS);
        Random random = new Random();
        for (int i = 0; i < 5; i++) {
            random.setSeed(i);
            PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
            PredictionProvider predictionProvider = list -> {
                return CompletableFuture.supplyAsync(() -> {
                    try {
                        Thread.sleep(1000L);
                        return Collections.emptyList();
                    } catch (InterruptedException e) {
                        throw new RuntimeException("this is a test");
                    }
                });
            };
            Assertions.assertThrows(TimeoutException.class, () -> {
                partialDependencePlotExplainer.explainFromMetadata(predictionProvider, getMetadata(random));
            });
        }
        Config.INSTANCE.setAsyncTimeout(5L);
        Config.INSTANCE.setAsyncTimeUnit(Config.DEFAULT_ASYNC_TIMEUNIT);
    }

    @Test
    void testTextClassifier() throws Exception {
        Random random = new Random();
        for (int i = 0; i < 5; i++) {
            random.setSeed(i);
            PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
            PredictionProvider dummyTextClassifier = TestUtils.getDummyTextClassifier();
            ArrayList arrayList = new ArrayList(3);
            for (String str : List.of("we want your money", "please reply quickly", "you are the lucky winner", "huge donation for you!", "bitcoin for you")) {
                ArrayList arrayList2 = new ArrayList();
                arrayList2.add(FeatureFactory.newFulltextFeature("text", str));
                PredictionInput predictionInput = new PredictionInput(arrayList2);
                arrayList.add(new Prediction(predictionInput, (PredictionOutput) ((List) dummyTextClassifier.predictAsync(List.of(predictionInput)).get()).get(0)));
            }
            org.assertj.core.api.Assertions.assertThat(partialDependencePlotExplainer.explainFromPredictions(dummyTextClassifier, arrayList)).isNotEmpty();
        }
    }
}
