package org.kie.kogito.explainability.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.TestUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

/* loaded from: input_file:org/kie/kogito/explainability/model/DatasetTest.class */
class DatasetTest {
    DatasetTest() {
    }

    @Test
    void testEmpty() {
        Dataset dataset = new Dataset(new ArrayList());
        Assertions.assertThat(dataset.getData()).isEmpty();
        Assertions.assertThat(dataset.getInputs()).isEmpty();
        Assertions.assertThat(dataset.getOutputs()).isEmpty();
    }

    @Test
    void testNotEmpty() {
        ArrayList arrayList = new ArrayList();
        arrayList.add(new SimplePrediction(new PredictionInput(List.of(TestUtils.getMockedNumericFeature())), new PredictionOutput(List.of(new Output("name", Type.UNDEFINED)))));
        Dataset dataset = new Dataset(arrayList);
        Assertions.assertThat(dataset.getData()).isNotEmpty();
        Assertions.assertThat(dataset.getInputs()).isNotEmpty();
        Assertions.assertThat(dataset.getOutputs()).isNotEmpty();
    }

    @Test
    void testInputFilter() {
        ArrayList arrayList = new ArrayList();
        arrayList.add(new SimplePrediction(new PredictionInput(List.of(TestUtils.getMockedNumericFeature())), new PredictionOutput(List.of(new Output("name", Type.UNDEFINED)))));
        Dataset filterByInput = new Dataset(arrayList).filterByInput(predictionInput -> {
            return predictionInput.getFeatures().size() == 1;
        });
        Assertions.assertThat(filterByInput.getData()).isNotEmpty();
        Assertions.assertThat(filterByInput.getInputs()).isNotEmpty();
        Assertions.assertThat(filterByInput.getOutputs()).isNotEmpty();
        Dataset filterByInput2 = new Dataset(arrayList).filterByInput(predictionInput2 -> {
            return predictionInput2.getFeatures().size() == 2;
        });
        Assertions.assertThat(filterByInput2.getData()).isEmpty();
        Assertions.assertThat(filterByInput2.getInputs()).isEmpty();
        Assertions.assertThat(filterByInput2.getOutputs()).isEmpty();
    }

    @Test
    void testOutFilter() {
        ArrayList arrayList = new ArrayList();
        arrayList.add(new SimplePrediction(new PredictionInput(List.of(TestUtils.getMockedNumericFeature())), new PredictionOutput(List.of(new Output("name", Type.UNDEFINED)))));
        Dataset filterByOutput = new Dataset(arrayList).filterByOutput(predictionOutput -> {
            return predictionOutput.getOutputs().size() == 1;
        });
        Assertions.assertThat(filterByOutput.getData()).isNotEmpty();
        Assertions.assertThat(filterByOutput.getInputs()).isNotEmpty();
        Assertions.assertThat(filterByOutput.getOutputs()).isNotEmpty();
        Dataset filterByOutput2 = new Dataset(arrayList).filterByOutput(predictionOutput2 -> {
            return predictionOutput2.getOutputs().size() == 2;
        });
        Assertions.assertThat(filterByOutput2.getData()).isEmpty();
        Assertions.assertThat(filterByOutput2.getInputs()).isEmpty();
        Assertions.assertThat(filterByOutput2.getOutputs()).isEmpty();
    }

    private Dataset createDatasetFeatureFiltering(Random random) {
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < 1000; i++) {
            ArrayList arrayList2 = new ArrayList();
            for (int i2 = 0; i2 < 4; i2++) {
                arrayList2.add(FeatureFactory.newNumericalFeature("f-" + i2, Double.valueOf(random.nextDouble() * 100.0d)));
            }
            for (int i3 = 4; i3 < 8; i3++) {
                arrayList2.add(FeatureFactory.newNumericalFeature("f-" + i3, Double.valueOf(100.0d + (random.nextDouble() * 100.0d))));
            }
            arrayList2.add(FeatureFactory.newBooleanFeature("f-8", true));
            arrayList2.add(FeatureFactory.newTextFeature("f-9", UUID.randomUUID().toString()));
            arrayList.add(new SimplePrediction(new PredictionInput(arrayList2), new PredictionOutput(List.of(new Output("output", Type.BOOLEAN, new Value(false), 1.0d)))));
        }
        return new Dataset(arrayList);
    }

    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testFilterByFeatureName(int i) {
        Random random = new Random(i);
        Dataset createDatasetFeatureFiltering = createDatasetFeatureFiltering(random);
        org.junit.jupiter.api.Assertions.assertEquals(1000, createDatasetFeatureFiltering.getData().size());
        int nextInt = random.nextInt(createDatasetFeatureFiltering.getData().size());
        org.junit.jupiter.api.Assertions.assertEquals(10, ((Prediction) createDatasetFeatureFiltering.getData().get(nextInt)).getInput().getFeatures().size());
        Predicate predicate = feature -> {
            return feature.getName().equals("f-3");
        };
        Dataset filterByFeature = createDatasetFeatureFiltering.filterByFeature(predicate.negate());
        org.junit.jupiter.api.Assertions.assertEquals(1000, filterByFeature.getData().size());
        org.junit.jupiter.api.Assertions.assertEquals(9, ((Prediction) filterByFeature.getData().get(nextInt)).getInput().getFeatures().size());
        org.junit.jupiter.api.Assertions.assertFalse(((List) ((Prediction) filterByFeature.getData().get(nextInt)).getInput().getFeatures().stream().map((v0) -> {
            return v0.getName();
        }).collect(Collectors.toList())).contains("f-3"));
    }

    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testFilterByFeatureType(int i) {
        Random random = new Random(i);
        Dataset createDatasetFeatureFiltering = createDatasetFeatureFiltering(random);
        org.junit.jupiter.api.Assertions.assertEquals(1000, createDatasetFeatureFiltering.getData().size());
        int nextInt = random.nextInt(createDatasetFeatureFiltering.getData().size());
        org.junit.jupiter.api.Assertions.assertEquals(10, ((Prediction) createDatasetFeatureFiltering.getData().get(nextInt)).getInput().getFeatures().size());
        Predicate predicate = feature -> {
            return feature.getType().equals(Type.NUMBER);
        };
        Dataset filterByFeature = createDatasetFeatureFiltering.filterByFeature(predicate.negate());
        org.junit.jupiter.api.Assertions.assertEquals(1000, filterByFeature.getData().size());
        org.junit.jupiter.api.Assertions.assertEquals(2, ((Prediction) filterByFeature.getData().get(nextInt)).getInput().getFeatures().size());
        org.junit.jupiter.api.Assertions.assertFalse(((List) ((Prediction) filterByFeature.getData().get(nextInt)).getInput().getFeatures().stream().map((v0) -> {
            return v0.getType();
        }).collect(Collectors.toList())).contains(Type.NUMBER));
    }

    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testFilterByFeatureValue(int i) {
        Random random = new Random(i);
        Dataset createDatasetFeatureFiltering = createDatasetFeatureFiltering(random);
        org.junit.jupiter.api.Assertions.assertEquals(1000, createDatasetFeatureFiltering.getData().size());
        int nextInt = random.nextInt(createDatasetFeatureFiltering.getData().size());
        org.junit.jupiter.api.Assertions.assertEquals(10, ((Prediction) createDatasetFeatureFiltering.getData().get(nextInt)).getInput().getFeatures().size());
        Dataset filterByFeature = createDatasetFeatureFiltering.filterByFeature(feature -> {
            return feature.getValue().asNumber() > 100.0d;
        });
        org.junit.jupiter.api.Assertions.assertEquals(1000, filterByFeature.getData().size());
        org.junit.jupiter.api.Assertions.assertEquals(4, ((Prediction) filterByFeature.getData().get(nextInt)).getInput().getFeatures().size());
        org.junit.jupiter.api.Assertions.assertTrue(((List) ((Prediction) filterByFeature.getData().get(nextInt)).getInput().getFeatures().stream().map((v0) -> {
            return v0.getValue();
        }).map((v0) -> {
            return v0.asNumber();
        }).collect(Collectors.toList())).stream().allMatch(d -> {
            return d.doubleValue() > 100.0d;
        }));
        org.junit.jupiter.api.Assertions.assertTrue(((List) ((Prediction) filterByFeature.getData().get(nextInt)).getInput().getFeatures().stream().map((v0) -> {
            return v0.getType();
        }).collect(Collectors.toList())).stream().allMatch(type -> {
            return type.equals(Type.NUMBER);
        }));
    }
}
