package org.kie.kogito.explainability.local.counterfactual;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.DataDomain;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.NumericFeatureDistribution;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionFeatureDomain;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.model.domain.EmptyFeatureDomain;
import org.kie.kogito.explainability.model.domain.NumericalFeatureDomain;
import org.kie.kogito.explainability.utils.DataUtils;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore;
import org.optaplanner.core.api.solver.SolverJob;
import org.optaplanner.core.api.solver.SolverManager;
import org.optaplanner.core.config.solver.EnvironmentMode;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.termination.TerminationConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/kogito/explainability/local/counterfactual/CounterfactualExplainerTest.class */
class CounterfactualExplainerTest {
    /** Timeout for awaiting explanation results; presumably inlined by the decompiler as {@code 10L} in runCounterfactualSearch. */
    final long predictionTimeOut = 10;
    /** Time unit used when awaiting explanation results in runCounterfactualSearch. */
    final TimeUnit predictionTimeUnit = TimeUnit.MINUTES;
    /** Score-calculation budget used as the solver termination limit in runCounterfactualSearch. */
    final Long steps = 30000L;
    /** Default goal threshold (0.01); not referenced in the portion of this file visible in this chunk. */
    final double DEFAULT_GOAL_THRESHOLD = 0.01d;
    /** Mocked solver-manager factory, recreated before each test by setup(). */
    private Function<SolverConfig, SolverManager<CounterfactualSolution, UUID>> solverManagerFactory;
    /** Mocked solver manager, recreated before each test by setup(); verified in testSequenceIds. */
    private SolverManager<CounterfactualSolution, UUID> solverManager;
    private static final Logger logger = LoggerFactory.getLogger(CounterfactualExplainerTest.class);
    /** Running-time cap in seconds; not referenced in the portion of this file visible in this chunk. */
    private static final Long MAX_RUNNING_TIME_SECONDS = 60L;

    /** No-arg constructor; per-test mocks are created in {@code setup()} instead. */
    CounterfactualExplainerTest() {
        // intentionally empty
    }

    /**
     * Creates fresh Mockito mocks for the solver-manager factory and the solver manager before
     * each test.
     *
     * <p>JUnit Jupiter requires lifecycle methods not to be private, so this method is
     * package-private.
     */
    @BeforeEach
    @SuppressWarnings("unchecked") // Mockito.mock(Class) cannot carry the generic type arguments
    void setup() {
        this.solverManagerFactory = Mockito.mock(Function.class);
        this.solverManager = Mockito.mock(SolverManager.class);
    }

    /**
     * Runs a reproducible counterfactual search and waits for its result.
     *
     * @param randomSeed    solver random seed
     * @param goal          desired outputs
     * @param constraints   per-feature flags passed to the prediction as constraints
     * @param dataDomain    per-feature search domains
     * @param features      original input features
     * @param model         model the explainer queries
     * @param goalThreshold goal threshold passed to {@link CounterfactualConfig#withGoalThreshold}
     * @return the explainer's result
     */
    private CounterfactualResult runCounterfactualSearch(Long randomSeed,
            List<Output> goal,
            List<Boolean> constraints,
            DataDomain dataDomain,
            List<Feature> features,
            PredictionProvider model,
            double goalThreshold) throws InterruptedException, ExecutionException, TimeoutException {
        // Fixed step budget, fixed seed and REPRODUCIBLE mode keep runs deterministic.
        final TerminationConfig termination = new TerminationConfig().withScoreCalculationCountLimit(steps);
        final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(termination).build();
        solverConfig.setRandomSeed(randomSeed);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        final CounterfactualConfig config = new CounterfactualConfig();
        config.withSolverConfig(solverConfig).withGoalThreshold(goalThreshold);
        final CounterfactualExplainer explainer = new CounterfactualExplainer(config);

        final PredictionInput input = new PredictionInput(features);
        final PredictionOutput output = new PredictionOutput(goal);
        final PredictionFeatureDomain searchDomain = new PredictionFeatureDomain(dataDomain.getFeatureDomains());
        final CounterfactualPrediction prediction = new CounterfactualPrediction(input, output, searchDomain,
                constraints, (DataDistribution) null, UUID.randomUUID(), (Long) null);

        return (CounterfactualResult) explainer.explainAsync(prediction, model)
                .get(predictionTimeOut, predictionTimeUnit);
    }

    /**
     * Runs a very short search (10 score calculations) against {@code TestUtils.getSumSkipModel(0)}
     * and checks that a result with a non-null entity list is returned.
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testNonEmptyInput(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final int featureCount = 4;
        final List<Output> goal = List.of(new Output("class", Type.NUMBER, new Value(10.0d), 0.0d));
        final List<Feature> features = IntStream.range(0, featureCount)
                .mapToObj(TestUtils::getMockedNumericFeature)
                .collect(Collectors.toList());
        final List<Boolean> constraints = Collections.nCopies(featureCount, false);

        final SolverConfig solverConfig = SolverConfigBuilder.builder()
                .withTerminationConfig(new TerminationConfig().withScoreCalculationCountLimit(10L))
                .build();
        solverConfig.setRandomSeed(Long.valueOf(seed));
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        final CounterfactualExplainer explainer =
                new CounterfactualExplainer(new CounterfactualConfig().withSolverConfig(solverConfig));
        final CounterfactualPrediction prediction = new CounterfactualPrediction(
                new PredictionInput(features),
                new PredictionOutput(goal),
                new PredictionFeatureDomain(IntStream.range(0, featureCount)
                        .mapToObj(index -> NumericalFeatureDomain.create(0.0d, 1000.0d))
                        .collect(Collectors.toList())),
                constraints,
                (DataDistribution) null,
                UUID.randomUUID(),
                (Long) null);
        final CounterfactualResult result = (CounterfactualResult) explainer
                .explainAsync(prediction, TestUtils.getSumSkipModel(0))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

        // Assert before dereferencing, so a missing result fails the assertion instead of throwing NPE.
        Assertions.assertNotNull(result);
        Assertions.assertNotNull(result.getEntities());
        for (CounterfactualEntity entity : result.getEntities()) {
            logger.debug("Entity: {}", entity);
        }
        logger.debug("Outputs: {}", result.getOutput().get(0).getOutputs());
    }

    /**
     * Searches over four unconstrained numeric features and checks that the counterfactual's feature
     * sum lies within {@code center ± epsilon} (the parameters of the sum-threshold model) and that
     * the result is valid.
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testCounterfactualMatch(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0d;
        final double epsilon = 10.0d;
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));

        // Index i of each list describes feature i.
        final List<Feature> features = List.of(
                FeatureFactory.newNumericalFeature("f-num1", 100.0d),
                FeatureFactory.newNumericalFeature("f-num2", 150.0d),
                FeatureFactory.newNumericalFeature("f-num3", 1.0d),
                FeatureFactory.newNumericalFeature("f-num4", 2.0d));
        final List<Boolean> constraints = List.of(false, false, false, false);
        final DataDomain dataDomain = new DataDomain(List.of(
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d)));

        final CounterfactualResult result = runCounterfactualSearch(Long.valueOf(seed), goal, constraints, dataDomain,
                features, TestUtils.getSumThresholdModel(center, epsilon), DEFAULT_GOAL_THRESHOLD);

        double totalSum = 0.0d;
        for (CounterfactualEntity entity : result.getEntities()) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        logger.debug("Outputs: {}", result.getOutput().get(0).getOutputs());
        Assertions.assertTrue(totalSum <= center + epsilon, "feature sum " + totalSum + " above center + epsilon");
        Assertions.assertTrue(totalSum >= center - epsilon, "feature sum " + totalSum + " below center - epsilon");
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Searches with the first and last features fixed by constraints. Checks that those two entities
     * are unchanged, that the feature sum lies within {@code center ± epsilon}, and that the result
     * is valid.
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testCounterfactualConstrainedMatchUnscaled(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0d;
        final double epsilon = 10.0d;
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));

        // Index i of each list describes feature i; f-num1 and f-num4 are constrained (fixed).
        final List<Feature> features = List.of(
                FeatureFactory.newNumericalFeature("f-num1", 100.0d),
                FeatureFactory.newNumericalFeature("f-num2", 100.0d),
                FeatureFactory.newNumericalFeature("f-num3", 100.0d),
                FeatureFactory.newNumericalFeature("f-num4", 100.0d));
        final List<Boolean> constraints = List.of(true, false, false, true);
        final DataDomain dataDomain = new DataDomain(List.of(
                EmptyFeatureDomain.create(),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                EmptyFeatureDomain.create()));

        final CounterfactualResult result = runCounterfactualSearch(Long.valueOf(seed), goal, constraints, dataDomain,
                features, TestUtils.getSumThresholdModel(center, epsilon), DEFAULT_GOAL_THRESHOLD);
        final List<CounterfactualEntity> entities = result.getEntities();

        double totalSum = 0.0d;
        for (CounterfactualEntity entity : entities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertFalse(entities.get(0).isChanged(), "constrained feature f-num1 must not change");
        Assertions.assertFalse(entities.get(3).isChanged(), "constrained feature f-num4 must not change");
        Assertions.assertTrue(totalSum <= center + epsilon, "feature sum " + totalSum + " above center + epsilon");
        Assertions.assertTrue(totalSum >= center - epsilon, "feature sum " + totalSum + " below center - epsilon");
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Intended to be the scaled variant of the constrained-match test. The first and last features
     * are fixed by constraints. Checks that those two entities are unchanged, that the feature sum
     * lies within {@code center ± epsilon}, and that the result is valid.
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testCounterfactualConstrainedMatchScaled(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0d;
        final double epsilon = 10.0d;
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));

        final Feature fNum1 = FeatureFactory.newNumericalFeature("f-num1", 100.0d);
        final Feature fNum2 = FeatureFactory.newNumericalFeature("f-num2", 100.0d);
        final Feature fNum3 = FeatureFactory.newNumericalFeature("f-num3", 100.0d);
        final Feature fNum4 = FeatureFactory.newNumericalFeature("f-num4", 100.0d);

        // Index i of each list describes feature i; f-num1 and f-num4 are constrained (fixed).
        final List<Feature> features = List.of(fNum1, fNum2, fNum3, fNum4);
        final List<Boolean> constraints = List.of(true, false, false, true);
        final DataDomain dataDomain = new DataDomain(List.of(
                EmptyFeatureDomain.create(),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                EmptyFeatureDomain.create()));

        // NOTE(review): these distributions are never passed to the search, because
        // runCounterfactualSearch always supplies a null DataDistribution. As written, this test
        // therefore exercises the same path as the unscaled variant.
        // TODO: wire them in through a DataDistribution built from this list.
        final List<NumericFeatureDistribution> featureDistributions = List.of(
                new NumericFeatureDistribution(fNum1, new NormalDistribution(500.0d, 1.1d).sample(1000)),
                new NumericFeatureDistribution(fNum2, new NormalDistribution(430.0d, 1.7d).sample(1000)),
                new NumericFeatureDistribution(fNum3, new NormalDistribution(470.0d, 2.9d).sample(1000)),
                new NumericFeatureDistribution(fNum4, new NormalDistribution(2390.0d, 0.3d).sample(1000)));
        logger.debug("Feature distributions (not yet used by the search): {}", featureDistributions.size());

        final CounterfactualResult result = runCounterfactualSearch(Long.valueOf(seed), goal, constraints, dataDomain,
                features, TestUtils.getSumThresholdModel(center, epsilon), DEFAULT_GOAL_THRESHOLD);
        final List<CounterfactualEntity> entities = result.getEntities();

        double totalSum = 0.0d;
        for (CounterfactualEntity entity : entities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertFalse(entities.get(0).isChanged(), "constrained feature f-num1 must not change");
        Assertions.assertFalse(entities.get(3).isChanged(), "constrained feature f-num4 must not change");
        Assertions.assertTrue(totalSum <= center + epsilon, "feature sum " + totalSum + " above center + epsilon");
        Assertions.assertTrue(totalSum >= center - epsilon, "feature sum " + totalSum + " below center - epsilon");
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Searches over four mocked numeric features plus one boolean feature, with the third numeric
     * feature (index 2) fixed by a constraint. Checks that the fixed entity is unchanged, that the
     * summed entity values lie within {@code center ± epsilon}, and that the result is valid.
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testCounterfactualBoolean(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0d;
        final double epsilon = 10.0d;
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));

        // Index i of each list describes feature i; only index 2 is constrained (fixed).
        final List<Feature> features = List.of(
                TestUtils.getMockedNumericFeature(0),
                TestUtils.getMockedNumericFeature(1),
                TestUtils.getMockedNumericFeature(2),
                TestUtils.getMockedNumericFeature(3),
                FeatureFactory.newBooleanFeature("f-bool", true));
        final List<Boolean> constraints = List.of(false, false, true, false, false);
        final DataDomain dataDomain = new DataDomain(List.of(
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                EmptyFeatureDomain.create()));

        final CounterfactualResult result = runCounterfactualSearch(Long.valueOf(seed), goal, constraints, dataDomain,
                features, TestUtils.getSumThresholdModel(center, epsilon), DEFAULT_GOAL_THRESHOLD);
        final List<CounterfactualEntity> entities = result.getEntities();

        double totalSum = 0.0d;
        for (CounterfactualEntity entity : entities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertFalse(entities.get(2).isChanged(), "constrained feature at index 2 must not change");
        Assertions.assertTrue(totalSum <= center + epsilon, "feature sum " + totalSum + " above center + epsilon");
        Assertions.assertTrue(totalSum >= center - epsilon, "feature sum " + totalSum + " below center - epsilon");
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Placeholder for a categorical-feature test whose body could not be recovered.
     *
     * <p>JADX failed to decompile this method (it skipped a 644-instruction dump). The stub below
     * unconditionally throws {@link UnsupportedOperationException}, so the test is disabled: it now
     * reports as skipped instead of failing for reasons unrelated to the code under test.
     * TODO: restore the original body from the original source and remove {@code @Disabled}.
     */
    @Disabled("Method body lost in decompilation; restore from original source before re-enabling")
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testCounterfactualCategoricalStrictFail(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        throw new UnsupportedOperationException("Method not decompiled: org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainerTest.testCounterfactualCategoricalStrictFail(int):void");
    }

    /**
     * Placeholder for a categorical-feature test whose body could not be recovered.
     *
     * <p>JADX failed to decompile this method (it skipped a 644-instruction dump). The stub below
     * unconditionally throws {@link UnsupportedOperationException}, so the test is disabled: it now
     * reports as skipped instead of failing for reasons unrelated to the code under test.
     * TODO: restore the original body from the original source and remove {@code @Disabled}.
     */
    @Disabled("Method body lost in decompilation; restore from original source before re-enabling")
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testCounterfactualCategoricalNotStrict(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        throw new UnsupportedOperationException("Method not decompiled: org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainerTest.testCounterfactualCategoricalNotStrict(int):void");
    }

    /**
     * Uses a goal output with score 0.9. Checks three things:
     * <ul>
     *   <li>the counterfactual's feature sum lies within {@code center ± epsilon};</li>
     *   <li>re-predicting the counterfactual with the same model yields a score of at least 0.9;</li>
     *   <li>the result is valid.</li>
     * </ul>
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testCounterfactualMatchThreshold(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0d;
        final double epsilon = 10.0d;
        final double goalScore = 0.9d;
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), goalScore));

        final List<Feature> features = List.of(
                FeatureFactory.newNumericalFeature("f-num1", 100.0d),
                FeatureFactory.newNumericalFeature("f-num2", 100.0d),
                FeatureFactory.newNumericalFeature("f-num3", 100.0d),
                FeatureFactory.newNumericalFeature("f-num4", 100.0d));
        final List<Boolean> constraints = List.of(false, false, false, false);
        final DataDomain dataDomain = new DataDomain(List.of(
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d)));
        final PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);

        final CounterfactualResult result = runCounterfactualSearch(Long.valueOf(seed), goal, constraints, dataDomain,
                features, model, DEFAULT_GOAL_THRESHOLD);
        final List<CounterfactualEntity> entities = result.getEntities();

        double totalSum = 0.0d;
        for (CounterfactualEntity entity : entities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertTrue(totalSum <= center + epsilon, "feature sum " + totalSum + " above center + epsilon");
        Assertions.assertTrue(totalSum >= center - epsilon, "feature sum " + totalSum + " below center - epsilon");

        // Re-run the model on the counterfactual features to read back the first output's score.
        final List<Feature> counterfactualFeatures = entities.stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());
        final List<PredictionOutput> predictions = model
                .predictAsync(List.of(new PredictionInput(counterfactualFeatures)))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        final double score = predictions.get(0).getOutputs().get(0).getScore();
        logger.debug("Prediction score: {}", score);
        Assertions.assertTrue(score >= goalScore, "prediction score " + score + " below goal score " + goalScore);
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Uses a goal output with score 0.0. Checks three things:
     * <ul>
     *   <li>the counterfactual's feature sum lies within {@code center ± epsilon};</li>
     *   <li>re-predicting the counterfactual with the same model yields a score below 0.5;</li>
     *   <li>the result is valid.</li>
     * </ul>
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testCounterfactualMatchNoThreshold(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0d;
        final double epsilon = 10.0d;
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));

        final List<Feature> features = List.of(
                FeatureFactory.newNumericalFeature("f-num1", 100.0d),
                FeatureFactory.newNumericalFeature("f-num2", 100.0d),
                FeatureFactory.newNumericalFeature("f-num3", 100.0d),
                FeatureFactory.newNumericalFeature("f-num4", 100.0d));
        final List<Boolean> constraints = List.of(false, false, false, false);
        final DataDomain dataDomain = new DataDomain(List.of(
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d),
                NumericalFeatureDomain.create(0.0d, 1000.0d)));
        final PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);

        final CounterfactualResult result = runCounterfactualSearch(Long.valueOf(seed), goal, constraints, dataDomain,
                features, model, DEFAULT_GOAL_THRESHOLD);
        final List<CounterfactualEntity> entities = result.getEntities();

        double totalSum = 0.0d;
        for (CounterfactualEntity entity : entities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertTrue(totalSum <= center + epsilon, "feature sum " + totalSum + " above center + epsilon");
        Assertions.assertTrue(totalSum >= center - epsilon, "feature sum " + totalSum + " below center - epsilon");

        // Re-run the model on the counterfactual features to read back the first output's score.
        final List<Feature> counterfactualFeatures = entities.stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());
        final List<PredictionOutput> predictions = model
                .predictAsync(List.of(new PredictionInput(counterfactualFeatures)))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        final double score = predictions.get(0).getOutputs().get(0).getScore();
        logger.debug("Prediction score: {}", score);
        Assertions.assertTrue(score < 0.5d, "prediction score " + score + " expected below 0.5");
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Searches from perturbed starting features with two fixed features and two features bounded to
     * [0, 2], against a sum-threshold model with center 500 and epsilon 1. Checks that the result is
     * reported as not valid.
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testNoCounterfactualPossible(long seed) throws ExecutionException, InterruptedException, TimeoutException {
        final PerturbationContext perturbationContext = new PerturbationContext(Long.valueOf(seed), new Random(), 4);
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));

        // Mutable copy: the list is handed to DataUtils.perturbFeatures.
        final List<Feature> features = new LinkedList<>(List.of(
                FeatureFactory.newNumericalFeature("f-num1", 1.0d),
                FeatureFactory.newNumericalFeature("f-num2", 1.0d),
                FeatureFactory.newNumericalFeature("f-num3", 1.0d),
                FeatureFactory.newNumericalFeature("f-num4", 1.0d)));
        final List<Boolean> constraints = List.of(true, false, false, true);
        final DataDomain dataDomain = new DataDomain(List.of(
                EmptyFeatureDomain.create(),
                NumericalFeatureDomain.create(0.0d, 2.0d),
                NumericalFeatureDomain.create(0.0d, 2.0d),
                EmptyFeatureDomain.create()));

        final CounterfactualResult result = runCounterfactualSearch(Long.valueOf(seed), goal, constraints, dataDomain,
                DataUtils.perturbFeatures(features, perturbationContext),
                TestUtils.getSumThresholdModel(500.0d, 1.0d), DEFAULT_GOAL_THRESHOLD);
        Assertions.assertFalse(result.isValid(), "no valid counterfactual was expected");
    }

    /**
     * Runs a search with a mocked intermediate-result consumer and verifies that the consumer
     * received at least one result.
     */
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testConsumers(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));
        final List<Feature> features = List.of(
                FeatureFactory.newNumericalFeature("f-num1", 100.0d),
                FeatureFactory.newNumericalFeature("f-num2", 100.0d),
                FeatureFactory.newNumericalFeature("f-num3", 100.0d),
                FeatureFactory.newNumericalFeature("f-num4", 100.0d));
        final List<Boolean> constraints = List.of(false, false, false, false);

        final SolverConfig solverConfig = SolverConfigBuilder.builder()
                .withTerminationConfig(new TerminationConfig().withScoreCalculationCountLimit(10000L))
                .build();
        solverConfig.setRandomSeed(Long.valueOf(seed));
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        @SuppressWarnings("unchecked") // Mockito.mock(Class) cannot carry the generic type argument
        final Consumer<CounterfactualResult> intermediateResults = Mockito.mock(Consumer.class);

        final CounterfactualExplainer explainer = new CounterfactualExplainer(
                new CounterfactualConfig().withSolverConfig(solverConfig).withGoalThreshold(DEFAULT_GOAL_THRESHOLD));
        final CounterfactualPrediction prediction = new CounterfactualPrediction(
                new PredictionInput(features),
                new PredictionOutput(goal),
                new PredictionFeatureDomain(List.of(
                        NumericalFeatureDomain.create(0.0d, 1000.0d),
                        NumericalFeatureDomain.create(0.0d, 1000.0d),
                        NumericalFeatureDomain.create(0.0d, 1000.0d),
                        NumericalFeatureDomain.create(0.0d, 1000.0d))),
                constraints,
                (DataDistribution) null,
                UUID.randomUUID(),
                (Long) null);
        final CounterfactualResult result = (CounterfactualResult) explainer
                .explainAsync(prediction, TestUtils.getSumThresholdModel(500.0d, 10.0d), intermediateResults)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

        for (CounterfactualEntity entity : result.getEntities()) {
            logger.debug("Entity: {}", entity);
        }
        logger.debug("Outputs: {}", result.getOutput().get(0).getOutputs());
        Mockito.verify(intermediateResults, Mockito.atLeast(1)).accept(ArgumentMatchers.any());
    }

    /**
     * Drives the mocked solver manager's best-solution listener a given number of times. Checks that
     * every intermediate result plus the final result carries a distinct sequence id.
     *
     * @param numberOfIntermediateSolutions how many intermediate solutions to push through the listener
     */
    @ValueSource(ints = {1, 2, 3, 5, 8})
    @ParameterizedTest
    void testSequenceIds(int numberOfIntermediateSolutions) throws ExecutionException, InterruptedException, TimeoutException {
        final List<Long> sequenceIds = new ArrayList<>();
        final Consumer<CounterfactualResult> collectSequenceIds =
                counterfactualResult -> sequenceIds.add(counterfactualResult.getSequenceId());

        @SuppressWarnings("unchecked") // ArgumentCaptor.forClass cannot carry the generic type argument
        final ArgumentCaptor<Consumer<CounterfactualSolution>> listenerCaptor = ArgumentCaptor.forClass(Consumer.class);

        final CounterfactualResult finalResult = mockExplainerInvocation(collectSequenceIds, null);
        Mockito.verify(this.solverManager).solveAndListen(
                ArgumentMatchers.any(), ArgumentMatchers.any(), listenerCaptor.capture(), ArgumentMatchers.any());

        // Feed mocked intermediate solutions through the captured best-solution listener.
        final Consumer<CounterfactualSolution> bestSolutionListener = listenerCaptor.getValue();
        for (int index = 0; index < numberOfIntermediateSolutions; index++) {
            final CounterfactualSolution solution = Mockito.mock(CounterfactualSolution.class);
            Mockito.when(solution.getScore()).thenReturn(BendableBigDecimalScore.zero(0, 0));
            bestSolutionListener.accept(solution);
        }
        sequenceIds.add(finalResult.getSequenceId());

        final int expectedCount = numberOfIntermediateSolutions + 1;
        Assertions.assertEquals(expectedCount, sequenceIds.size());
        Assertions.assertEquals(expectedCount, (int) sequenceIds.stream().distinct().count(),
                "sequence ids must be unique: " + sequenceIds);
    }

    /**
     * Collects solution and execution ids from intermediate results. Checks that solution ids are
     * unique and that every result carries the execution id supplied with the prediction.
     */
    @Disabled("FAI-713")
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testIntermediateUniqueIds(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));
        final List<Feature> features = List.of(
                FeatureFactory.newNumericalFeature("f-num1", 10.0d),
                FeatureFactory.newNumericalFeature("f-num2", 10.0d),
                FeatureFactory.newNumericalFeature("f-num3", 10.0d),
                FeatureFactory.newNumericalFeature("f-num4", 10.0d));
        final List<Boolean> constraints = List.of(false, false, false, false);
        final PredictionProvider model = TestUtils.getSumThresholdModel(400.0d, 10.0d);

        final SolverConfig solverConfig = SolverConfigBuilder.builder()
                .withTerminationConfig(new TerminationConfig()
                        .withBestScoreFeasible(true)
                        .withScoreCalculationCountLimit(10000L))
                .build();
        solverConfig.setRandomSeed(Long.valueOf(seed));
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        final List<UUID> solutionIds = new ArrayList<>();
        final List<UUID> executionIds = new ArrayList<>();
        final Consumer<CounterfactualResult> collectSolutionIds =
                counterfactualResult -> solutionIds.add(counterfactualResult.getSolutionId());
        final Consumer<CounterfactualResult> collectExecutionIds =
                counterfactualResult -> executionIds.add(counterfactualResult.getExecutionId());

        final CounterfactualExplainer explainer =
                new CounterfactualExplainer(new CounterfactualConfig().withSolverConfig(solverConfig));
        final UUID executionId = UUID.randomUUID();
        final CounterfactualPrediction prediction = new CounterfactualPrediction(
                new PredictionInput(features),
                new PredictionOutput(goal),
                new PredictionFeatureDomain(List.of(
                        NumericalFeatureDomain.create(0.0d, 10000.0d),
                        NumericalFeatureDomain.create(0.0d, 10000.0d),
                        NumericalFeatureDomain.create(0.0d, 10000.0d),
                        NumericalFeatureDomain.create(0.0d, 10000.0d))),
                constraints,
                (DataDistribution) null,
                executionId,
                (Long) null);
        final CounterfactualResult result = (CounterfactualResult) explainer
                .explainAsync(prediction, model, collectSolutionIds.andThen(collectExecutionIds))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

        for (CounterfactualEntity entity : result.getEntities()) {
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertEquals(solutionIds.size(), (int) solutionIds.stream().distinct().count(),
                "solution ids must be unique: " + solutionIds);
        Assertions.assertEquals(1, (int) executionIds.stream().distinct().count());
        // expected value first, actual second (JUnit convention)
        Assertions.assertEquals(executionId, executionIds.get(0));
    }

    @Disabled("FAI-713")
    @ValueSource(ints = {0, 1, 2})
    @ParameterizedTest
    void testFinalUniqueIds(int i) throws ExecutionException, InterruptedException, TimeoutException {
        final int featureCount = 4;
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.5d));

        // Features f-num1..f-num4, all starting at 10.0, each with a constraint flag of false
        // and a [0, 10000] numerical search domain.
        final List<Feature> features = new LinkedList<>();
        final List<Boolean> constraints = new LinkedList<>();
        for (int n = 1; n <= featureCount; n++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + n, 10.0d));
            constraints.add(false);
        }
        final PredictionFeatureDomain featureDomain = new PredictionFeatureDomain(
                IntStream.range(0, featureCount)
                        .mapToObj(n -> NumericalFeatureDomain.create(0.0d, 10000.0d))
                        .collect(Collectors.toList()));
        final PredictionProvider model = TestUtils.getSumThresholdModel(400.0d, 10.0d);

        // Seeded, reproducible solver that stops on the first feasible score or after 10000 score calculations.
        final SolverConfig solverConfig = SolverConfigBuilder.builder()
                .withTerminationConfig(new TerminationConfig()
                        .withBestScoreFeasible(true)
                        .withScoreCalculationCountLimit(10000L))
                .build();
        solverConfig.setRandomSeed(Long.valueOf(i));
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        // Record the solution id and execution id of every result handed to the consumer.
        final List<UUID> solutionIds = new ArrayList<>();
        final List<UUID> executionIds = new ArrayList<>();
        final Consumer<CounterfactualResult> collectSolutionIds = result -> solutionIds.add(result.getSolutionId());
        final Consumer<CounterfactualResult> collectExecutionIds = result -> executionIds.add(result.getExecutionId());

        final CounterfactualExplainer explainer =
                new CounterfactualExplainer(new CounterfactualConfig().withSolverConfig(solverConfig));
        final UUID executionId = UUID.randomUUID();
        final CounterfactualPrediction prediction = new CounterfactualPrediction(
                new PredictionInput(features),
                new PredictionOutput(goal),
                featureDomain,
                constraints,
                (DataDistribution) null,
                executionId,
                (Long) null);
        final CounterfactualResult finalResult = explainer
                .explainAsync(prediction, model, collectSolutionIds.andThen(collectExecutionIds))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

        for (CounterfactualEntity entity : finalResult.getEntities()) {
            logger.debug("Entity: {}", entity);
        }

        // Every solution id reported to the consumer is distinct.
        Assertions.assertEquals((long) solutionIds.size(), solutionIds.stream().distinct().count());
        Assertions.assertFalse(solutionIds.isEmpty());
        Assertions.assertFalse(executionIds.isEmpty());
        Assertions.assertEquals(solutionIds.size(), executionIds.size());
        // All reported results share one execution id...
        Assertions.assertEquals(1L, executionIds.stream().distinct().count());
        // ...the final result carries a solution id different from the last one reported...
        Assertions.assertNotEquals(solutionIds.get(solutionIds.size() - 1), finalResult.getSolutionId());
        // ...and that shared execution id is the one supplied with the prediction.
        Assertions.assertEquals(executionId, executionIds.get(0));
    }

    @ValueSource(ints = {0, 1, 2, 3, 4})
    @ParameterizedTest
    void testSparsity(int i) throws ExecutionException, InterruptedException, TimeoutException {
        final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0d));

        // Two numerical features (starting at 0 and 5), each with a [0, 10] domain and a constraint flag of false.
        final List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newNumericalFeature("f-num1", 0));
        features.add(FeatureFactory.newNumericalFeature("f-num2", 5));
        final List<Boolean> constraints = new ArrayList<>();
        constraints.add(false);
        constraints.add(false);
        final DataDomain dataDomain = new DataDomain(new ArrayList<>(List.of(
                NumericalFeatureDomain.create(0.0d, 10.0d),
                NumericalFeatureDomain.create(0.0d, 10.0d))));

        final CounterfactualResult result = runCounterfactualSearch(Long.valueOf(i), goal, constraints, dataDomain,
                features, TestUtils.getSumThresholdModel(10.0d, 0.1d), 0.01d);

        // A sparse counterfactual must not change both features.
        final boolean bothChanged = result.getEntities().get(0).isChanged() && result.getEntities().get(1).isChanged();
        Assertions.assertFalse(bothChanged, "expected at most one changed feature");
        Assertions.assertTrue(result.isValid());
    }

    @Test
    void testTerminationSpentLimitWhenDefined() throws ExecutionException, InterruptedException, TimeoutException {
        // Capture the SolverConfig the explainer passes to the solver-manager factory.
        final ArgumentCaptor<SolverConfig> solverConfigCaptor = ArgumentCaptor.forClass(SolverConfig.class);
        mockExplainerInvocation(Mockito.mock(Consumer.class), MAX_RUNNING_TIME_SECONDS);
        Mockito.verify(this.solverManagerFactory).apply(solverConfigCaptor.capture());
        // The max running time given to the prediction must show up as the solver's spent limit.
        Assertions.assertEquals(MAX_RUNNING_TIME_SECONDS,
                solverConfigCaptor.getValue().getTerminationConfig().getSpentLimit().getSeconds());
    }

    @Test
    void testTerminationSpentLimitWhenUndefined() throws ExecutionException, InterruptedException, TimeoutException {
        // Capture the SolverConfig the explainer passes to the solver-manager factory.
        final ArgumentCaptor<SolverConfig> solverConfigCaptor = ArgumentCaptor.forClass(SolverConfig.class);
        mockExplainerInvocation(Mockito.mock(Consumer.class), null);
        Mockito.verify(this.solverManagerFactory).apply(solverConfigCaptor.capture());
        // Without a max running time, no seconds-spent limit may be configured.
        Assertions.assertNull(solverConfigCaptor.getValue().getTerminationConfig().getSecondsSpentLimit());
    }

    /**
     * Runs {@code explainAsync} on a {@link CounterfactualExplainer} whose solver pipeline is entirely mocked.
     * <p>
     * The solver-manager factory yields {@code this.solverManager}. Its {@code solveAndListen} yields a mocked
     * job whose final best solution is a mocked {@link CounterfactualSolution} scored with
     * {@code BendableBigDecimalScore.zero(0, 0)}. The prediction uses empty inputs, outputs, domains and
     * constraints.
     *
     * @param consumer consumer forwarded to {@code explainAsync}
     * @param l        last argument of the {@link CounterfactualPrediction}; the termination tests pass their
     *                 max running time (or {@code null}) here
     * @return the result obtained from the explainer's future within the configured async timeout
     */
    CounterfactualResult mockExplainerInvocation(Consumer<CounterfactualResult> consumer, Long l) throws ExecutionException, InterruptedException, TimeoutException {
        final SolverJob stubbedJob = (SolverJob) Mockito.mock(SolverJob.class);
        final CounterfactualSolution stubbedSolution = (CounterfactualSolution) Mockito.mock(CounterfactualSolution.class);

        // Wire factory -> manager -> job -> best solution.
        Mockito.when(this.solverManagerFactory.apply((SolverConfig) ArgumentMatchers.any())).thenReturn(this.solverManager);
        Mockito.when(this.solverManager.solveAndListen((UUID) ArgumentMatchers.any(), (Function) ArgumentMatchers.any(), (Consumer) ArgumentMatchers.any(), (BiConsumer) ArgumentMatchers.any())).thenReturn(stubbedJob);
        Mockito.when((CounterfactualSolution) stubbedJob.getFinalBestSolution()).thenReturn(stubbedSolution);
        Mockito.when(stubbedSolution.getScore()).thenReturn(BendableBigDecimalScore.zero(0, 0));

        // Model that immediately completes with an empty output list.
        final PredictionProvider emptyModel = inputs -> CompletableFuture.completedFuture(Collections.emptyList());

        final CounterfactualExplainer explainer =
                new CounterfactualExplainer(new CounterfactualConfig().withSolverManagerFactory(this.solverManagerFactory));
        final CounterfactualPrediction emptyPrediction = new CounterfactualPrediction(
                new PredictionInput(Collections.emptyList()),
                new PredictionOutput(Collections.emptyList()),
                new PredictionFeatureDomain(Collections.emptyList()),
                Collections.emptyList(),
                (DataDistribution) null,
                UUID.randomUUID(),
                l);

        return (CounterfactualResult) explainer
                .explainAsync(emptyPrediction, emptyModel, consumer)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }
}
