package org.kie.kogito.explainability.utils;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/**
 * Tests for {@link RandomChoice}: weighted and uniform sampling, plus rejection of a weight list
 * whose length does not match the object list. Every sampling test uses a {@link Random} seeded
 * from the parameter so runs are reproducible.
 */
class RandomChoiceTest {
    final List<String> obj = List.of("a", "b", "c", "d", "e");

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testOnlyOneWeight(int seed) {
        Random random = new Random(seed);
        // Only "c" carries non-zero weight, so every draw must be "c".
        List<Double> weights = List.of(0.0, 0.0, 1.0, 0.0, 0.0);
        RandomChoice<String> randomChoice = new RandomChoice<>(obj, weights);

        Assertions.assertEquals(List.of("c", "c", "c"), randomChoice.sample(3, random));
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testTwoWeight(int seed) {
        Random random = new Random(seed);
        // Only "a" and "c" carry non-zero weight.
        List<Double> weights = List.of(1.0, 0.0, 1.0, 0.0, 0.0);
        List<String> sample = new RandomChoice<>(obj, weights).sample(5, random);

        // Guard against a vacuous pass: an empty sample would satisfy the per-item check below.
        Assertions.assertEquals(5, sample.size());
        for (String item : sample) {
            Assertions.assertTrue(item.equals("a") || item.equals("c"),
                    () -> "unexpected item sampled: " + item);
        }
    }

    @Test
    void weightMismatch() {
        // Three weights for five objects must be rejected.
        List<Double> weights = List.of(1.0, 1.0, 0.0);
        Assertions.assertThrows(IllegalArgumentException.class, () -> new RandomChoice<>(obj, weights));
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testUniform(int seed) {
        RandomChoice<String> randomChoice = new RandomChoice<>(obj);
        Random random = new Random(seed);
        for (int round = 0; round < 100; round++) {
            Map<String, Integer> counts = countOccurrences(randomChoice.sample(1000, random));
            // Assuming uniform sampling, each of the 5 items is expected ~200 times per 1000 draws.
            for (String item : obj) {
                assertCountBetween(counts, item, 70, 324);
            }
        }
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testMultiWeight(int seed) {
        List<Double> weights = List.of(5.0, 4.0, 3.0, 2.0, 1.0);
        RandomChoice<String> randomChoice = new RandomChoice<>(obj, weights);
        Random random = new Random(seed);
        for (int round = 0; round < 100; round++) {
            Map<String, Integer> counts = countOccurrences(randomChoice.sample(1000, random));
            // Assuming draws proportional to weight/15, expected counts are roughly
            // a≈333, b≈267, c≈200, d≈133, e≈67; the bounds below are loose tolerances around those.
            assertCountBetween(counts, "a", 171, 475);
            assertCountBetween(counts, "b", 118, 401);
            assertCountBetween(counts, "c", 70, 324);
            assertCountBetween(counts, "d", 28, 242);
            assertCountBetween(counts, "e", 28, 151);
        }
    }

    /** Counts how many times each distinct item occurs in {@code sample}. */
    private static Map<String, Integer> countOccurrences(List<String> sample) {
        Map<String, Integer> counts = new HashMap<>();
        for (String item : sample) {
            counts.merge(item, 1, Integer::sum);
        }
        return counts;
    }

    /**
     * Asserts that {@code item} occurs strictly between {@code lowerExclusive} and
     * {@code upperExclusive} times. An item that was never sampled counts as 0, which yields an
     * assertion failure with a message instead of a NullPointerException.
     */
    private static void assertCountBetween(Map<String, Integer> counts, String item,
            int lowerExclusive, int upperExclusive) {
        int count = counts.getOrDefault(item, 0);
        Assertions.assertTrue(count > lowerExclusive && count < upperExclusive,
                () -> "count of '" + item + "' = " + count
                        + " not in (" + lowerExclusive + ", " + upperExclusive + ")");
    }
}
