package org.kie.kogito.explainability.model;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.TestUtils;

/**
 * Unit tests for {@link Saliency}: selection of top, positive and negative features, and merging
 * of several per-output saliency maps into one.
 */
class SaliencyTest {

    /** Tolerance for comparing merged (floating-point) scores. */
    private static final double DELTA = 0.001d;

    @Test
    void testGetTopFeatures() {
        Saliency saliency = new Saliency(new Output("name", Type.NUMBER),
                mockedImportances(0.19d, -0.44d, 0.04d));

        List<FeatureImportance> topFeatures = saliency.getTopFeatures(2);

        Assertions.assertNotNull(topFeatures);
        Assertions.assertEquals(2, topFeatures.size());
        // -0.44 is expected to be kept while 0.04 is dropped, i.e. selection is by magnitude,
        // not by signed value.
        List<Double> scores = scoresOf(topFeatures);
        Assertions.assertTrue(scores.contains(-0.44d));
        Assertions.assertTrue(scores.contains(0.19d));
    }

    @Test
    void testGetPositiveFeatures() {
        Saliency saliency = new Saliency(new Output("name", Type.NUMBER),
                mockedImportances(0.19d, -0.44d, 0.04d));

        List<FeatureImportance> positiveFeatures = saliency.getPositiveFeatures(2);

        Assertions.assertNotNull(positiveFeatures);
        Assertions.assertEquals(2, positiveFeatures.size());
        List<Double> scores = scoresOf(positiveFeatures);
        Assertions.assertTrue(scores.contains(0.04d));
        Assertions.assertTrue(scores.contains(0.19d));
    }

    @Test
    void testGetNegativeFeatures() {
        Saliency saliency = new Saliency(new Output("name", Type.NUMBER),
                mockedImportances(0.19d, -0.44d, 0.04d));

        List<FeatureImportance> negativeFeatures = saliency.getNegativeFeatures(2);

        Assertions.assertNotNull(negativeFeatures);
        // Only one negative score exists, so fewer than the requested 2 are returned.
        Assertions.assertEquals(1, negativeFeatures.size());
        Assertions.assertTrue(scoresOf(negativeFeatures).contains(-0.44d));
    }

    @Test
    void testSameImportantFeatures() {
        Saliency saliency = new Saliency(new Output("name", Type.NUMBER),
                mockedImportances(0.1d, 0.1d, 0.1d));

        List<FeatureImportance> topFeatures = saliency.getTopFeatures(2);
        Assertions.assertNotNull(topFeatures);
        Assertions.assertEquals(2, topFeatures.size());
        // Ties must not break selection: every selected score is the shared value.
        Assertions.assertTrue(scoresOf(topFeatures).stream().allMatch(score -> score == 0.1d));

        List<FeatureImportance> negativeFeatures = saliency.getNegativeFeatures(2);
        Assertions.assertNotNull(negativeFeatures);
        Assertions.assertTrue(negativeFeatures.isEmpty());

        List<FeatureImportance> positiveFeatures = saliency.getPositiveFeatures(2);
        Assertions.assertNotNull(positiveFeatures);
        Assertions.assertEquals(2, positiveFeatures.size());
        Assertions.assertTrue(scoresOf(positiveFeatures).stream().allMatch(score -> score == 0.1d));
    }

    @Test
    void testMergeSaliencyMaps() {
        List<FeatureImportance> firstImportances = new ArrayList<>();
        firstImportances.add(new FeatureImportance(FeatureFactory.newTextFeature("f1", "foo"), 0.1d));
        firstImportances.add(new FeatureImportance(FeatureFactory.newTextFeature("f2", "bar"), -0.4d));
        firstImportances.add(new FeatureImportance(FeatureFactory.newNumericalFeature("f3", 10), 0.01d));
        Map<String, Saliency> firstMap = new HashMap<>();
        firstMap.put("out", new Saliency(new Output("out", Type.NUMBER), firstImportances));

        List<FeatureImportance> secondImportances = new ArrayList<>();
        secondImportances.add(new FeatureImportance(FeatureFactory.newTextFeature("f1", "foo"), 0.2d));
        secondImportances.add(new FeatureImportance(FeatureFactory.newTextFeature("f2", "bar"), -0.2d));
        secondImportances.add(new FeatureImportance(FeatureFactory.newNumericalFeature("f3", 10), 0.03d));
        Map<String, Saliency> secondMap = new HashMap<>();
        secondMap.put("out", new Saliency(new Output("out", Type.NUMBER), secondImportances));

        Map<String, Saliency> merged = Saliency.merge(List.of(firstMap, secondMap));

        Assertions.assertNotNull(merged);
        Assertions.assertEquals(1, merged.size());
        List<FeatureImportance> perFeatureImportance = merged.get("out").getPerFeatureImportance();
        Assertions.assertNotNull(perFeatureImportance);
        Assertions.assertEquals(3, perFeatureImportance.size());
        // Expected values are the per-feature means of the two input maps, in input order.
        Assertions.assertEquals(0.15d, perFeatureImportance.get(0).getScore(), DELTA);
        Assertions.assertEquals(-0.3d, perFeatureImportance.get(1).getScore(), DELTA);
        Assertions.assertEquals(0.02d, perFeatureImportance.get(2).getScore(), DELTA);
    }

    /**
     * Builds one {@link FeatureImportance} per score, in the given order. Each importance uses the
     * feature returned by its own call to {@code TestUtils.getMockedNumericFeature()}.
     *
     * @param scores the importance scores to assign
     * @return a mutable list with one importance per score
     */
    private static List<FeatureImportance> mockedImportances(double... scores) {
        List<FeatureImportance> importances = new ArrayList<>(scores.length);
        for (double score : scores) {
            importances.add(new FeatureImportance(TestUtils.getMockedNumericFeature(), score));
        }
        return importances;
    }

    /**
     * Extracts the score of each importance, preserving list order.
     *
     * @param importances the importances to read
     * @return the boxed scores, so they can be checked with {@link List#contains(Object)}
     */
    private static List<Double> scoresOf(List<FeatureImportance> importances) {
        return importances.stream()
                .map(FeatureImportance::getScore)
                .collect(Collectors.toList());
    }
}
