package org.kie.kogito.explainability.local.lime.optim;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;

import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.utils.DataUtils;

/* loaded from: input_file:org/kie/kogito/explainability/local/lime/optim/LimeConfigOptimizerTest.class */
class LimeConfigOptimizerTest {
    LimeConfigOptimizerTest() {
    }

    @Test
    void testImpactOptimization() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forImpactScore());
    }

    @Test
    void testImpactOptimizationNoSampling() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forImpactScore().withSampling(false));
    }

    @Test
    void testImpactOptimizationNoWeighting() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forImpactScore().withWeighting(false));
    }

    @Test
    void testImpactOptimizationNoEncoding() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forImpactScore().withEncoding(false));
    }

    @Test
    void testImpactOptimizationNoProximity() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forImpactScore().withProximity(false));
    }

    @Test
    void testImpactOptimizationNoEntity() {
        LimeConfigOptimizer withProximity = new LimeConfigOptimizer().forImpactScore().withSampling(false).withEncoding(false).withWeighting(false).withProximity(false);
        Assertions.assertThrows(AssertionError.class, () -> {
            assertConfigOptimized(withProximity);
        });
    }

    @Test
    void testStabilityOptimization() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forStabilityScore());
    }

    @Test
    void testStabilityOptimizationNoSampling() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forStabilityScore().withSampling(false));
    }

    @Test
    void testStabilityOptimizationNoWeighting() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forStabilityScore().withWeighting(false));
    }

    @Test
    void testStabilityOptimizationNoEncoding() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forStabilityScore().withEncoding(false));
    }

    @Test
    void testStabilityOptimizationNoProximity() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).forStabilityScore().withProximity(false));
    }

    @Test
    void testStabilityOptimizationNoEntity() {
        LimeConfigOptimizer withProximity = new LimeConfigOptimizer().forStabilityScore().withSampling(false).withEncoding(false).withWeighting(false).withProximity(false);
        Assertions.assertThrows(AssertionError.class, () -> {
            assertConfigOptimized(withProximity);
        });
    }

    @Test
    void testWeightedStabilityOptimization() throws Exception {
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).withWeightedStability(0.5d, 0.5d));
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).withWeightedStability(0.3d, 0.7d));
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).withWeightedStability(0.7d, 0.3d));
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).withWeightedStability(1.0d, 0.0d));
        assertConfigOptimized(new LimeConfigOptimizer().withTimeLimit(10L).withWeightedStability(0.0d, 1.0d));
    }

    @Test
    void testWeightedStabilityWrongParamsOptimization() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new LimeConfigOptimizer().withWeightedStability(0.8d, 0.7d);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new LimeConfigOptimizer().withWeightedStability(0.1d, 0.7d);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new LimeConfigOptimizer().withWeightedStability(0.1d, 1.1d);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new LimeConfigOptimizer().withWeightedStability(2.1d, 0.1d);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new LimeConfigOptimizer().withWeightedStability(-0.1d, 0.9d);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new LimeConfigOptimizer().withWeightedStability(0.1d, -0.9d);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new LimeConfigOptimizer().withWeightedStability(0.1d, 0.99d);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new LimeConfigOptimizer().withWeightedStability(0.009d, 0.99d);
        });
    }

    @Test
    void testSameConfig() throws ExecutionException, InterruptedException {
        ArrayList arrayList = new ArrayList();
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        List sample = DataUtils.generateRandomDataDistribution(5, 100, new Random()).sample(3);
        List predictions = DataUtils.getPredictions(sample, (List) sumSkipModel.predictAsync(sample).get());
        for (int i = 0; i < 2; i++) {
            arrayList.add(new LimeConfigOptimizer().withDeterministicExecution(true).withStepCountLimit(10).withTimeLimit(10L).optimize(new LimeConfig().withSamples(10).withPerturbationContext(new PerturbationContext(0L, new Random(), 1)), predictions, sumSkipModel));
        }
        LimeConfig limeConfig = (LimeConfig) arrayList.get(0);
        LimeConfig limeConfig2 = (LimeConfig) arrayList.get(1);
        AssertionsForClassTypes.assertThat(limeConfig.getNoOfRetries()).isEqualTo(limeConfig2.getNoOfRetries());
        AssertionsForClassTypes.assertThat(limeConfig.getNoOfSamples()).isEqualTo(limeConfig2.getNoOfSamples());
        AssertionsForClassTypes.assertThat(limeConfig.getProximityFilteredDatasetMinimum()).isEqualTo(limeConfig2.getProximityFilteredDatasetMinimum());
        AssertionsForClassTypes.assertThat(limeConfig.getProximityKernelWidth()).isEqualTo(limeConfig2.getProximityKernelWidth());
        AssertionsForClassTypes.assertThat(limeConfig.getProximityThreshold()).isEqualTo(limeConfig2.getProximityThreshold());
        AssertionsForClassTypes.assertThat(limeConfig.isProximityFilter()).isEqualTo(limeConfig2.isProximityFilter());
        AssertionsForClassTypes.assertThat(limeConfig.isAdaptDatasetVariance()).isEqualTo(limeConfig2.isAdaptDatasetVariance());
        AssertionsForClassTypes.assertThat(limeConfig.isPenalizeBalanceSparse()).isEqualTo(limeConfig2.isPenalizeBalanceSparse());
        AssertionsForClassTypes.assertThat(limeConfig.getEncodingParams().getNumericTypeClusterGaussianFilterWidth()).isEqualTo(limeConfig2.getEncodingParams().getNumericTypeClusterGaussianFilterWidth());
        AssertionsForClassTypes.assertThat(limeConfig.getEncodingParams().getNumericTypeClusterThreshold()).isEqualTo(limeConfig2.getEncodingParams().getNumericTypeClusterThreshold());
        AssertionsForClassTypes.assertThat(limeConfig.getSeparableDatasetRatio()).isEqualTo(limeConfig2.getSeparableDatasetRatio());
        AssertionsForClassTypes.assertThat(limeConfig.getPerturbationContext().getNoOfPerturbations()).isEqualTo(limeConfig2.getPerturbationContext().getNoOfPerturbations());
    }

    private void assertConfigOptimized(LimeConfigOptimizer limeConfigOptimizer) throws InterruptedException, ExecutionException {
        LimeConfig withSamples = new LimeConfig().withSamples(10);
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        Random random = new Random();
        random.setSeed(4L);
        List sample = DataUtils.generateRandomDataDistribution(5, 100, random).sample(10);
        LimeConfig optimize = limeConfigOptimizer.optimize(withSamples, DataUtils.getPredictions(sample, (List) sumSkipModel.predictAsync(sample).get()), sumSkipModel);
        AssertionsForClassTypes.assertThat(optimize).isNotNull();
        org.assertj.core.api.Assertions.assertThat(optimize).isNotSameAs(withSamples);
    }
}
