package org.kie.kogito.explainability.local.lime;

import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.LinkedList;
import java.util.List;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.model.EncodingParams;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

/**
 * Tests for {@link DatasetEncoder}: encoding an empty dataset, and encoding a small perturbed dataset
 * whose features are all of one {@link Type} (binary, vector, categorical, boolean, numeric or text).
 */
class DatasetEncoderTest {

    /** Number of perturbed samples built for each typed test. */
    private static final int SAMPLES = 10;

    /** Number of features in every perturbed input and in the original input. */
    private static final int FEATURES = 3;

    /** Fixed charset so binary feature payloads do not depend on the platform default charset. */
    private static final Charset CHARSET = StandardCharsets.UTF_8;

    /**
     * Creates one feature for a test dataset.
     * Perturbed inputs call it with {@code (sampleIndex, featureIndex)}; the original input calls it with
     * {@code (featureIndex, featureIndex)}.
     */
    @FunctionalInterface
    private interface FeatureFactory {
        Feature create(int sample, int index);
    }

    @Test
    void testEmptyDatasetEncoding() {
        DatasetEncoder datasetEncoder = new DatasetEncoder(new LinkedList<>(), new LinkedList<>(), new LinkedList<>(),
                new Output("foo", Type.NUMBER, new Value(1), 1.0d), new EncodingParams(1.0d, 0.1d));
        List<Pair<double[], Double>> encodedTrainingSet = datasetEncoder.getEncodedTrainingSet();
        Assertions.assertNotNull(encodedTrainingSet);
        Assertions.assertTrue(encodedTrainingSet.isEmpty());
    }

    @Test
    void testDatasetEncodingWithBinaryData() {
        // Concatenate (not add) the indices so each payload is a distinct string such as "01" or "12".
        assertEncode((sample, index) -> TestUtils.getMockedFeature(Type.BINARY,
                new Value(ByteBuffer.wrap((sample + "" + index).getBytes(CHARSET)))));
    }

    @Test
    void testDatasetEncodingWithVectorData() {
        // Both the perturbed inputs and the original input carry VECTOR features.
        assertEncode((sample, index) -> TestUtils.getMockedFeature(Type.VECTOR,
                new Value(new double[] { sample, index })));
    }

    @Test
    void testDatasetEncodingWithCategoricalData() {
        assertEncode((sample, index) -> TestUtils.getMockedFeature(Type.CATEGORICAL, new Value(sample + index)));
    }

    @Test
    void testDatasetEncodingWithBooleanData() {
        assertEncode((sample, index) -> TestUtils.getMockedFeature(Type.BOOLEAN,
                new Value(Boolean.valueOf(index % 2 == 0))));
    }

    @Test
    void testDatasetEncodingWithNumericData() {
        assertEncode((sample, index) -> TestUtils.getMockedNumericFeature());
    }

    @Test
    void testDatasetEncodingWithTextData() {
        assertEncode((sample, index) -> TestUtils.getMockedTextFeature(sample + " " + index));
    }

    /**
     * Builds {@link #SAMPLES} perturbed inputs and one original input with {@link #FEATURES} features each,
     * then checks their encoding.
     *
     * @param featureFactory creates feature {@code j} of perturbed sample {@code i} from {@code (i, j)}, and
     *                       feature {@code j} of the original input from {@code (j, j)}
     */
    private void assertEncode(FeatureFactory featureFactory) {
        List<PredictionInput> perturbedInputs = new LinkedList<>();
        for (int sample = 0; sample < SAMPLES; sample++) {
            List<Feature> features = new LinkedList<>();
            for (int index = 0; index < FEATURES; index++) {
                features.add(featureFactory.create(sample, index));
            }
            perturbedInputs.add(new PredictionInput(features));
        }
        List<Feature> originalFeatures = new LinkedList<>();
        for (int index = 0; index < FEATURES; index++) {
            originalFeatures.add(featureFactory.create(index, index));
        }
        assertEncode(perturbedInputs, new PredictionInput(originalFeatures));
    }

    /**
     * Encodes {@code perturbedInputs} against {@code originalInput} and checks the training set.
     * The training set must be non-null, contain one entry per perturbed input, and every entry must have
     * a non-null key and a value in [0, 1].
     *
     * @param perturbedInputs perturbed samples; one numeric output is generated for each of them
     * @param originalInput   input whose features are used as the encoding target
     */
    private void assertEncode(List<PredictionInput> perturbedInputs, PredictionInput originalInput) {
        List<Output> perturbedOutputs = new LinkedList<>();
        for (int i = 0; i < perturbedInputs.size(); i++) {
            // Alternate 1.0 / 0.0 so the perturbed outputs are not all identical.
            perturbedOutputs.add(new Output("o", Type.NUMBER, new Value(Double.valueOf(i % 2 == 0 ? 1.0d : 0.0d)), 1.0d));
        }
        // NOTE(review): the target output is typed BOOLEAN but holds a Double, while the perturbed outputs are
        // NUMBER. This is kept unchanged from the original test; confirm whether the mismatch is intentional.
        Output targetOutput = new Output("o", Type.BOOLEAN, new Value(Double.valueOf(1.0d)), 1.0d);
        DatasetEncoder datasetEncoder = new DatasetEncoder(perturbedInputs, perturbedOutputs,
                originalInput.getFeatures(), targetOutput, new EncodingParams(1.0d, 0.1d));

        List<Pair<double[], Double>> encodedTrainingSet = datasetEncoder.getEncodedTrainingSet();
        Assertions.assertNotNull(encodedTrainingSet);
        Assertions.assertEquals(perturbedInputs.size(), encodedTrainingSet.size());
        for (Pair<double[], Double> pair : encodedTrainingSet) {
            Assertions.assertNotNull(pair.getKey());
            Assertions.assertNotNull(pair.getValue());
            org.assertj.core.api.Assertions.assertThat(pair.getValue()).isBetween(0.0d, 1.0d);
        }
    }
}
