package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.local.shap.ShapConfig;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.utils.MatrixUtils;

/* loaded from: input_file:org/kie/kogito/explainability/local/shap/ShapConfigTest.class */
/**
 * Tests for the {@link ShapConfig} builder: values passed to the builder must be
 * returned by the matching getters, and incomplete or empty configurations must
 * be rejected when {@code build()} is called.
 */
class ShapConfigTest {
    // Shared fixtures. Declaration order matters: each initializer reads the field(s) above it.
    PerturbationContext pc = new PerturbationContext(new Random(), 0);
    List<Feature> fs = Arrays.asList(
            FeatureFactory.newNumericalFeature("f", 1.0),
            FeatureFactory.newNumericalFeature("f", 2.0));
    PredictionInput pi = new PredictionInput(fs);
    // Background data set: the same prediction input listed twice.
    List<PredictionInput> pis = Arrays.asList(pi, pi);
    // Empty background, used to trigger the builder's validation error.
    List<PredictionInput> piEmpty = new ArrayList<>();
    // Matrix form of the background, compared against ShapConfig#getBackgroundMatrix().
    double[][] piMatrix = MatrixUtils.matrixFromPredictionInput(pis);

    /**
     * Builds a config with every option set explicitly and checks that each getter
     * returns exactly what was supplied (same instance where the value is an object).
     */
    @Test
    void testRecovery() {
        ForkJoinPool executor = ForkJoinPool.commonPool();
        ShapConfig config = ShapConfig.builder()
                .withLink(ShapConfig.LinkType.IDENTITY)
                .withBackground(pis)
                .withPC(pc)
                .withExecutor(executor)
                .withNSamples(100)
                .withConfidence(0.99)
                .build();

        Assertions.assertEquals(ShapConfig.LinkType.IDENTITY, config.getLink());
        Assertions.assertTrue(config.getNSamples().isPresent());
        Assertions.assertEquals(100, config.getNSamples().get());
        Assertions.assertEquals(0.99, config.getConfidence());
        Assertions.assertSame(pc, config.getPC());
        Assertions.assertSame(executor, config.getExecutor());
        Assertions.assertSame(pis, config.getBackground());
        Assertions.assertTrue(Arrays.deepEquals(piMatrix, config.getBackgroundMatrix()));
    }

    /**
     * Builds a config with only the link and the background set and checks that the
     * sample count is absent and that the executor is the common fork-join pool.
     */
    @Test
    void testNullRecovery() {
        ShapConfig config = ShapConfig.builder()
                .withLink(ShapConfig.LinkType.LOGIT)
                .withBackground(pis)
                .build();

        Assertions.assertEquals(ShapConfig.LinkType.LOGIT, config.getLink());
        Assertions.assertFalse(config.getNSamples().isPresent());
        Assertions.assertSame(pis, config.getBackground());
        Assertions.assertTrue(Arrays.deepEquals(piMatrix, config.getBackgroundMatrix()));
        Assertions.assertSame(ForkJoinPool.commonPool(), config.getExecutor());
    }

    /**
     * A builder given only a link, or only a background, must throw
     * {@link IllegalArgumentException} from {@code build()}.
     */
    @Test
    void testMandatoryErrors() {
        ShapConfig.Builder linkOnly = ShapConfig.builder().withLink(ShapConfig.LinkType.IDENTITY);
        ShapConfig.Builder backgroundOnly = ShapConfig.builder().withBackground(pis);

        Assertions.assertThrows(IllegalArgumentException.class, linkOnly::build);
        Assertions.assertThrows(IllegalArgumentException.class, backgroundOnly::build);
    }

    /**
     * A builder given a link and an empty background list must throw
     * {@link IllegalArgumentException} from {@code build()}.
     */
    @Test
    void testEmptyBackgroundMandatoryErrors() {
        ShapConfig.Builder emptyBackground = ShapConfig.builder()
                .withLink(ShapConfig.LinkType.IDENTITY)
                .withBackground(piEmpty);

        Assertions.assertThrows(IllegalArgumentException.class, emptyBackground::build);
    }
}
