package org.kie.kogito.explainability.utils;

import java.time.Duration;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

/**
 * Unit tests for {@link DataUtils}: statistics helpers, distance and kernel functions, feature
 * perturbation, feature linearization/dropping and sampling with replacement.
 *
 * <p>All tests share a single {@link Random} that is seeded once in {@link #setupBefore()}, so the
 * random values a given test observes depend on the order in which the tests are executed.
 */
class DataUtilsTest {

    private static final Random random = new Random();

    @BeforeAll
    static void setupBefore() {
        random.setSeed(4L);
    }

    @Test
    void testDataGeneration() {
        double mean = 0.5d;
        double stdDeviation = 0.1d;
        double[] data = DataUtils.generateData(mean, stdDeviation, 100, random);
        Assertions.assertEquals(mean, DataUtils.getMean(data), 0.01d);
        Assertions.assertEquals(stdDeviation, DataUtils.getStdDev(data, mean), 0.01d);
        // the deviations from the requested mean are expected to (almost) cancel out
        double deviationSum = 0.0d;
        for (double value : data) {
            deviationSum += value - mean;
        }
        Assertions.assertEquals(0.0d, deviationSum, 1.0E-4d);
    }

    @Test
    void testGetMean() {
        double[] data = {2.0d, 4.0d, 3.0d, 5.0d, 1.0d};
        Assertions.assertEquals(3.0d, DataUtils.getMean(data), 1.0E-6d);
    }

    @Test
    void testGetStdDev() {
        double[] data = {2.0d, 4.0d, 3.0d, 5.0d, 1.0d};
        Assertions.assertEquals(1.41d, DataUtils.getStdDev(data, 3.0d), 0.01d);
    }

    @Test
    void testGaussianKernel() {
        Assertions.assertEquals(0.398d, DataUtils.gaussianKernel(0.0d, 0.0d, 1.0d), 0.001d);
        Assertions.assertEquals(0.389d, DataUtils.gaussianKernel(0.218d, 0.0d, 1.0d), 0.001d);
    }

    @Test
    void testEuclideanDistance() {
        double[] origin = {1.0d, 1.0d};
        Assertions.assertEquals(2.236d, DataUtils.euclideanDistance(origin, new double[]{2.0d, 3.0d}), 0.001d);
        // vectors of different length have no defined distance
        Assertions.assertTrue(Double.isNaN(DataUtils.euclideanDistance(origin, new double[0])));
    }

    @Test
    void testHammingDistanceDouble() {
        double[] reference = {2.0d, 1.0d};
        Assertions.assertEquals(1.0d, DataUtils.hammingDistance(reference, new double[]{2.0d, 3.0d}), 0.1d);
        // vectors of different length have no defined distance
        Assertions.assertTrue(Double.isNaN(DataUtils.hammingDistance(reference, new double[0])));
    }

    @Test
    void testHammingDistanceString() {
        Assertions.assertEquals(1.0d, DataUtils.hammingDistance("test1", "test2"), 0.1d);
        // strings of different length have no defined distance
        Assertions.assertTrue(Double.isNaN(DataUtils.hammingDistance("test1", "testTooLong")));
    }

    @Test
    void testExponentialSmoothingKernel() {
        Assertions.assertEquals(0.994d, DataUtils.exponentialSmoothingKernel(0.218d, 2.0d), 0.001d);
    }

    @Test
    void testPerturbFeaturesEmpty() {
        List<Feature> features = new ArrayList<>();
        List<Feature> perturbed = DataUtils.perturbFeatures(features, new PerturbationContext(random, 0));
        Assertions.assertNotNull(perturbed);
        Assertions.assertEquals(features.size(), perturbed.size());
    }

    @Test
    void testPerturbDropNumericZero() {
        assertPerturbDropNumeric(new PredictionInput(numericFeatures(5)), 0);
    }

    @Test
    void testPerturbDropNumericOne() {
        assertPerturbDropNumeric(new PredictionInput(numericFeatures(0.55d)), 1);
    }

    @Test
    void testPerturbDropNumericTwo() {
        assertPerturbDropNumeric(new PredictionInput(numericFeatures(0.55d)), 2);
    }

    @Test
    void testPerturbDropNumericThree() {
        assertPerturbDropNumeric(new PredictionInput(numericFeatures(0.55d)), 3);
    }

    @Test
    void testPerturbDropStringZero() {
        assertPerturbDropString(new PredictionInput(textFeatures()), 0);
    }

    @Test
    void testPerturbDropStringOne() {
        assertPerturbDropString(new PredictionInput(textFeatures()), 1);
    }

    @Test
    void testPerturbDropStringTwo() {
        assertPerturbDropString(new PredictionInput(textFeatures()), 2);
    }

    @Test
    void testPerturbDropStringThree() {
        assertPerturbDropString(new PredictionInput(textFeatures()), 3);
    }

    @Test
    void testPerturbDropCompositeStringZero() {
        assertPerturbDropString(compositeTextInput(), 0);
    }

    @Test
    void testPerturbDropCompositeStringOne() {
        assertPerturbDropString(compositeTextInput(), 1);
    }

    @Test
    void testPerturbDropCompositeStringTwo() {
        assertPerturbDropString(compositeTextInput(), 2);
    }

    @Test
    void testPerturbDropCompositeStringThree() {
        assertPerturbDropString(compositeTextInput(), 3);
    }

    /**
     * Builds three numerical features {@code f0=1}, {@code f1=3.14} and {@code f2=lastValue}.
     *
     * @param lastValue the value of the third feature ({@code f2})
     * @return a new mutable list of features
     */
    private static List<Feature> numericFeatures(Number lastValue) {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newNumericalFeature("f0", 1));
        features.add(FeatureFactory.newNumericalFeature("f1", 3.14d));
        features.add(FeatureFactory.newNumericalFeature("f2", lastValue));
        return features;
    }

    /**
     * Builds the four text features used by the string perturbation tests, including a blank
     * value and one with a trailing space.
     *
     * @return a new mutable list of features
     */
    private static List<Feature> textFeatures() {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newTextFeature("f0", "foo"));
        features.add(FeatureFactory.newTextFeature("f1", "foo bar"));
        features.add(FeatureFactory.newTextFeature("f2", " "));
        features.add(FeatureFactory.newTextFeature("f3", "foo bar "));
        return features;
    }

    /** Wraps the {@link #textFeatures()} into a single composite feature inside a prediction input. */
    private static PredictionInput compositeTextInput() {
        return new PredictionInput(List.of(FeatureFactory.newCompositeFeature("composite", textFeatures())));
    }

    /**
     * Perturbs the input's features and checks how many numeric values changed.
     *
     * @param input the input whose features get perturbed
     * @param noOfPerturbations the number of perturbations requested from {@link DataUtils}
     */
    private void assertPerturbDropNumeric(PredictionInput input, int noOfPerturbations) {
        List<Feature> originals = input.getFeatures();
        List<Feature> perturbed = DataUtils.perturbFeatures(originals, new PerturbationContext(random, noOfPerturbations));
        int changed = 0;
        for (int i = 0; i < originals.size(); i++) {
            if (originals.get(i).getValue().asNumber() != perturbed.get(i).getValue().asNumber()) {
                changed++;
            }
        }
        assertChangedCountInRange(changed, originals.size(), noOfPerturbations);
    }

    /**
     * Perturbs the input's features and checks how many string values changed.
     *
     * @param input the input whose features get perturbed
     * @param noOfPerturbations the number of perturbations requested from {@link DataUtils}
     */
    private void assertPerturbDropString(PredictionInput input, int noOfPerturbations) {
        List<Feature> originals = input.getFeatures();
        List<Feature> perturbed = DataUtils.perturbFeatures(originals, new PerturbationContext(random, noOfPerturbations));
        int changed = 0;
        for (int i = 0; i < originals.size(); i++) {
            if (!originals.get(i).getValue().asString().equals(perturbed.get(i).getValue().asString())) {
                changed++;
            }
        }
        assertChangedCountInRange(changed, originals.size(), noOfPerturbations);
    }

    /**
     * Asserts that the number of changed features lies between {@code noOfPerturbations} and half
     * of the feature count (whichever bound is smaller is the lower bound; both are truncated to int).
     */
    private static void assertChangedCountInRange(int changed, int featureCount, int noOfPerturbations) {
        double halfOfFeatures = featureCount * 0.5d;
        int lowerBound = (int) Math.min(noOfPerturbations, halfOfFeatures);
        int upperBound = (int) Math.max(noOfPerturbations, halfOfFeatures);
        org.assertj.core.api.Assertions.assertThat(changed).isBetween(lowerBound, upperBound);
    }

    @Test
    void testDoublesToFeatures() {
        double[] values = new double[10];
        for (int i = 0; i < values.length; i++) {
            values[i] = i % 2 == 0 ? 1.0d : 0.0d;
        }
        List<Feature> features = DataUtils.doublesToFeatures(values);
        Assertions.assertNotNull(features);
        Assertions.assertEquals(10, features.size());
        for (Feature feature : features) {
            Assertions.assertNotNull(feature);
            Assertions.assertNotNull(feature.getName());
            Assertions.assertEquals(Type.NUMBER, feature.getType());
            Assertions.assertNotNull(feature.getValue());
        }
    }

    @Test
    void testDoubleToFeature() {
        Feature feature = DataUtils.doubleToFeature(0.5d);
        Assertions.assertNotNull(feature);
        Assertions.assertNotNull(feature.getName());
        Assertions.assertEquals(Type.NUMBER, feature.getType());
        Assertions.assertNotNull(feature.getValue());
    }

    @Test
    void testRandomDistributionGeneration() {
        DataDistribution distribution = DataUtils.generateRandomDataDistribution(10, 10, random);
        Assertions.assertNotNull(distribution);
        List<FeatureDistribution> featureDistributions = distribution.asFeatureDistributions();
        Assertions.assertNotNull(featureDistributions);
        for (FeatureDistribution featureDistribution : featureDistributions) {
            Assertions.assertNotNull(featureDistribution);
        }
    }

    @Test
    void testLinearizedNumericFeatures() {
        List<Feature> features = new ArrayList<>();
        features.add(TestUtils.getMockedNumericFeature());
        Assertions.assertEquals(features.size(), DataUtils.getLinearizedFeatures(features).size());
    }

    @Test
    void testLinearizedTextFeatures() {
        List<Feature> features = new ArrayList<>();
        features.add(TestUtils.getMockedTextFeature("foo bar "));
        Assertions.assertEquals(1, DataUtils.getLinearizedFeatures(features).size());
    }

    @Test
    void testCompositeLinearizedFeatures() {
        List<Feature> nested = new ArrayList<>();
        nested.add(FeatureFactory.newTextFeature("f0", "foo bar"));
        // the fulltext feature tokenizes "foo bar" into two words
        nested.add(FeatureFactory.newFulltextFeature("f0", "foo bar", text -> Arrays.asList(text.split(" "))));
        nested.add(FeatureFactory.newCategoricalFeature("f0", "1"));
        nested.add(FeatureFactory.newBooleanFeature("f1", true));
        nested.add(FeatureFactory.newNumericalFeature("f2", 13));
        nested.add(FeatureFactory.newDurationFeature("f3", Duration.ofDays(13L)));
        nested.add(FeatureFactory.newTimeFeature("f4", LocalTime.now()));
        nested.add(FeatureFactory.newObjectFeature("f5", new float[]{0.4f, 0.4f}));
        nested.add(FeatureFactory.newObjectFeature("f6", FeatureFactory.newObjectFeature("nf-0", new Object())));
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newCompositeFeature("name", nested));
        // 9 nested features, one of which linearizes into 2 tokens
        Assertions.assertEquals(10, DataUtils.getLinearizedFeatures(features).size());
    }

    /** Builds {@code [numeric, middle, text("foo bar"), numeric]} around the given feature. */
    private static List<Feature> featuresAround(Feature middle) {
        List<Feature> features = new ArrayList<>();
        features.add(TestUtils.getMockedNumericFeature());
        features.add(middle);
        features.add(TestUtils.getMockedTextFeature("foo bar"));
        features.add(TestUtils.getMockedNumericFeature());
        return features;
    }

    @Test
    void testDropFeature() {
        for (Type type : Type.values()) {
            Feature target = TestUtils.getMockedFeature(type, new Value(1.0d));
            List<Feature> features = featuresAround(target);
            Assertions.assertNotEquals(features, DataUtils.dropFeature(features, target));
        }
    }

    @Test
    void testDropLinearizedFeature() {
        for (Type type : Type.values()) {
            Feature target = TestUtils.getMockedFeature(type, new Value(1.0d));
            Feature composite = FeatureFactory.newCompositeFeature("composite", featuresAround(target));
            Assertions.assertNotEquals(composite, DataUtils.dropOnLinearizedFeatures(target, composite));
        }
    }

    @Test
    void testSampleWithReplacement() {
        List<Double> emptySample = DataUtils.sampleWithReplacement(new ArrayList<Double>(), 1, random);
        Assertions.assertNotNull(emptySample);
        Assertions.assertEquals(0, emptySample.size());

        List<Double> values = Arrays.stream(DataUtils.generateData(0.0d, 1.0d, 100, random))
                .boxed()
                .collect(Collectors.toList());

        List<Double> smallSample = DataUtils.sampleWithReplacement(values, 10, random);
        Assertions.assertNotNull(smallSample);
        Assertions.assertEquals(10, smallSample.size());
        // every drawn element (not just a random one) must come from the source list
        org.assertj.core.api.Assertions.assertThat(values).containsAll(smallSample);

        // with replacement, the sample may be larger than the source
        List<Double> largeSample = DataUtils.sampleWithReplacement(values, 300, random);
        org.assertj.core.api.Assertions.assertThat(largeSample).isNotNull();
        org.assertj.core.api.Assertions.assertThat(largeSample).hasSize(300);
        org.assertj.core.api.Assertions.assertThat(values).containsAll(largeSample);
    }
}
