package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.PredictionInput;

/**
 * Unit tests for {@link ShapSyntheticDataSample}: synthetic data generation from a mask and
 * background rows, the fixed flag, the mask accessor, and weight bookkeeping.
 */
class ShapSyntheticDataSampleTest {

    /**
     * Builds the sample under test.
     *
     * <p>It has five numerical features {@code f1..f5}, all with value {@code -1}, and the mask
     * {@code [true, true, false, false, true]}. There are two background rows,
     * {@code [0..4]} and {@code [5..9]}. The initial weight is {@code 0.5} (checked by
     * {@link #testWeight()}) and the fixed flag is {@code true} (checked by {@link #testIsFixed()}).
     */
    private ShapSyntheticDataSample generateShapSample() {
        PredictionInput input = new PredictionInput(mutableListOf(
                FeatureFactory.newNumericalFeature("f1", -1),
                FeatureFactory.newNumericalFeature("f2", -1),
                FeatureFactory.newNumericalFeature("f3", -1),
                FeatureFactory.newNumericalFeature("f4", -1),
                FeatureFactory.newNumericalFeature("f5", -1)));
        boolean[] mask = {true, true, false, false, true};
        // A proper 2-D initializer: one row per background sample.
        double[][] background = {
                {0.0d, 1.0d, 2.0d, 3.0d, 4.0d},
                {5.0d, 6.0d, 7.0d, 8.0d, 9.0d}
        };
        return new ShapSyntheticDataSample(input, mask, background, 0.5d, true);
    }

    /**
     * Builds the synthetic inputs expected from {@link #generateShapSample()}.
     *
     * <p>There is one input per background row. Features whose mask entry is {@code true}
     * (f1, f2, f5) keep the original value {@code -1}. Features whose mask entry is
     * {@code false} (f3, f4) take the value from the corresponding background row.
     */
    private List<PredictionInput> generateExpectedSynthData() {
        List<PredictionInput> expected = new ArrayList<>();
        expected.add(new PredictionInput(mutableListOf(
                FeatureFactory.newNumericalFeature("f1", -1),
                FeatureFactory.newNumericalFeature("f2", -1),
                FeatureFactory.newNumericalFeature("f3", Double.valueOf(2.0d)),
                FeatureFactory.newNumericalFeature("f4", Double.valueOf(3.0d)),
                FeatureFactory.newNumericalFeature("f5", -1))));
        expected.add(new PredictionInput(mutableListOf(
                FeatureFactory.newNumericalFeature("f1", -1),
                FeatureFactory.newNumericalFeature("f2", -1),
                FeatureFactory.newNumericalFeature("f3", Double.valueOf(7.0d)),
                FeatureFactory.newNumericalFeature("f4", Double.valueOf(8.0d)),
                FeatureFactory.newNumericalFeature("f5", -1))));
        return expected;
    }

    /**
     * Collects the given items into a new mutable {@link ArrayList}.
     *
     * <p>The original test built feature lists with {@code ArrayList}. Mutability is kept in case
     * the code under test modifies the list it receives.
     */
    @SafeVarargs
    private static <T> List<T> mutableListOf(T... items) {
        List<T> list = new ArrayList<>(items.length);
        for (T item : items) {
            list.add(item);
        }
        return list;
    }

    @Test
    void testSyntheticCreation() {
        List<PredictionInput> expected = generateExpectedSynthData();
        List<PredictionInput> actual = generateShapSample().getSyntheticData();

        // Check the row count first, so an empty or truncated result cannot pass vacuously.
        Assertions.assertEquals(expected.size(), actual.size(), "number of synthetic rows");
        for (int row = 0; row < expected.size(); row++) {
            // List equality compares element-wise and also catches feature-count mismatches.
            Assertions.assertEquals(expected.get(row).getFeatures(), actual.get(row).getFeatures(),
                    "features of synthetic row " + row);
        }
    }

    @Test
    void testIsFixed() {
        Assertions.assertTrue(generateShapSample().isFixed());
    }

    @Test
    void testGetMask() {
        Assertions.assertArrayEquals(new boolean[]{true, true, false, false, true}, generateShapSample().getMask());
    }

    /** Checks the initial weight, {@code incrementWeight()} (+1) and {@code setWeight}. */
    @Test
    void testWeight() {
        ShapSyntheticDataSample sample = generateShapSample();
        Assertions.assertEquals(0.5d, sample.getWeight());
        sample.incrementWeight();
        Assertions.assertEquals(1.5d, sample.getWeight());
        sample.setWeight(2.5d);
        Assertions.assertEquals(2.5d, sample.getWeight());
    }
}
