package org.kie.kogito.explainability.local.counterfactual;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntityFactory;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionFeatureDomain;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.domain.FeatureDomain;
import org.optaplanner.core.api.solver.SolverManager;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.SolverManagerConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link LocalExplainer} that searches for a counterfactual explanation by running an OptaPlanner
 * {@link SolverManager} over {@link CounterfactualEntity} instances derived from the prediction input.
 * <p>
 * The no-arg constructor uses the solver configuration from
 * {@code CounterfactualConfigurationFactory.builder().build()} and {@link ForkJoinPool#commonPool()};
 * use {@link #builder()} to customise either.
 */
public class CounterfactualExplainer implements LocalExplainer<CounterfactualResult> {

    private static final Logger logger = LoggerFactory.getLogger(CounterfactualExplainer.class);

    /**
     * Intermediate-result consumer used when the prediction does not supply one:
     * logs the entities of each intermediate counterfactual at DEBUG level.
     */
    public static final Consumer<CounterfactualResult> defaultIntermediateConsumer =
            counterfactualResult -> logger.debug("Intermediate counterfactual: {}", counterfactualResult.getEntities());

    /**
     * Stamps every solution it receives with a freshly generated random solution id.
     */
    public static final Consumer<CounterfactualSolution> assignSolutionId =
            counterfactualSolution -> counterfactualSolution.setSolutionId(UUID.randomUUID());

    private final SolverConfig solverConfig;
    private final Executor executor;

    /**
     * Fluent builder for {@link CounterfactualExplainer}.
     */
    public static class Builder {
        private Executor executor = ForkJoinPool.commonPool();
        private SolverConfig solverConfig = null;

        private Builder() {
        }

        /**
         * Sets the executor on which the (blocking) solver run is scheduled.
         *
         * @param executor executor to use; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code executor} is {@code null}
         */
        public Builder withExecutor(Executor executor) {
            this.executor = Objects.requireNonNull(executor, "executor");
            return this;
        }

        /**
         * Sets the OptaPlanner solver configuration. If this is never called, or is called with
         * {@code null}, {@link #build()} falls back to {@code CounterfactualConfigurationFactory.builder().build()}.
         *
         * @param solverConfig solver configuration, or {@code null} for the default
         * @return this builder
         */
        public Builder withSolverConfig(SolverConfig solverConfig) {
            this.solverConfig = solverConfig;
            return this;
        }

        /**
         * Creates the explainer, filling in the default solver configuration if none was set.
         *
         * @return a new {@link CounterfactualExplainer}
         */
        public CounterfactualExplainer build() {
            if (this.solverConfig == null) {
                this.solverConfig = CounterfactualConfigurationFactory.builder().build();
            }
            return new CounterfactualExplainer(this.solverConfig, this.executor);
        }
    }

    /**
     * Creates an explainer with the default solver configuration and the common fork-join pool.
     */
    public CounterfactualExplainer() {
        this(CounterfactualConfigurationFactory.builder().build(), ForkJoinPool.commonPool());
    }

    /**
     * @param solverConfig OptaPlanner solver configuration; must not be {@code null}
     * @param executor executor on which solving is scheduled; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    protected CounterfactualExplainer(SolverConfig solverConfig, Executor executor) {
        this.solverConfig = Objects.requireNonNull(solverConfig, "solverConfig");
        this.executor = Objects.requireNonNull(executor, "executor");
    }

    /**
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builds one {@link CounterfactualEntity} per input feature, pairing the feature at index {@code i}
     * with the constraint, feature domain and (optional) feature distribution at the same index.
     *
     * @param predictionInput input whose features are turned into entities
     * @param featureDomain per-feature search domains, indexed like the features
     * @param constraints per-feature constraint flags, indexed like the features
     * @param dataDistribution optional distribution data; {@code null} means no distribution is passed on
     * @return the entities, in feature order
     */
    private static List<CounterfactualEntity> createEntities(PredictionInput predictionInput,
                                                             PredictionFeatureDomain featureDomain,
                                                             List<Boolean> constraints,
                                                             DataDistribution dataDistribution) {
        final List<FeatureDomain> featureDomains = featureDomain.getFeatureDomains();
        // Resolved once rather than once per feature; null when no distribution data is available.
        final List<FeatureDistribution> featureDistributions = Optional.ofNullable(dataDistribution)
                .map(DataDistribution::asFeatureDistributions)
                .orElse(null);
        final int featureCount = predictionInput.getFeatures().size();
        return IntStream.range(0, featureCount)
                .mapToObj(i -> CounterfactualEntityFactory.from(
                        predictionInput.getFeatures().get(i),
                        constraints.get(i),
                        featureDomains.get(i),
                        featureDistributions == null ? null : featureDistributions.get(i)))
                .collect(Collectors.toList());
    }

    /**
     * Adapts a {@link CounterfactualResult} consumer so it can receive solver solutions.
     *
     * @param consumer downstream consumer of intermediate results
     * @return a solution consumer that converts each solution to a {@link CounterfactualResult}
     */
    private static Consumer<CounterfactualSolution> createSolutionConsumer(Consumer<CounterfactualResult> consumer) {
        return solution -> consumer.accept(new CounterfactualResult(
                solution.getEntities(),
                solution.getPredictionOutputs(),
                solution.getScore().isFeasible(),
                solution.getSolutionId(),
                solution.getExecutionId()));
    }

    /**
     * Converts the entities of a solution back into a {@link PredictionInput}.
     */
    private static PredictionInput toPredictionInput(CounterfactualSolution solution) {
        return new PredictionInput(solution.getEntities().stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList()));
    }

    /**
     * Runs the solver for {@code executionId} and blocks until it produces its final best solution.
     * The {@link SolverManager} is created per call and always closed.
     *
     * @throws IllegalStateException if solving is interrupted (the interrupt flag is restored)
     *                               or the solver job fails
     */
    private CounterfactualSolution solve(UUID executionId,
                                         Function<UUID, CounterfactualSolution> problemFinder,
                                         Consumer<CounterfactualResult> intermediateConsumer) {
        try (SolverManager<CounterfactualSolution, UUID> solverManager =
                SolverManager.create(this.solverConfig, new SolverManagerConfig())) {
            try {
                return solverManager.solveAndListen(executionId,
                        problemFinder,
                        assignSolutionId.andThen(createSolutionConsumer(intermediateConsumer)),
                        // No custom exception handler; failures are expected to surface as
                        // ExecutionException from getFinalBestSolution() below.
                        (BiConsumer<UUID, Throwable>) null)
                        .getFinalBestSolution();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            }
        }
    }

    /**
     * Searches for a counterfactual for {@code prediction}.
     * <p>
     * Solving runs on this explainer's executor. Each improved solution is reported to the prediction's
     * intermediate consumer, or {@link #defaultIntermediateConsumer} if it has none. The final best
     * solution's features are then sent back through {@code predictionProvider}, and the returned
     * result carries those model outputs.
     *
     * @param prediction must be a {@link CounterfactualPrediction}
     * @param predictionProvider model used both during solving and for the final re-prediction
     * @return a future completing with the counterfactual result
     * @throws IllegalArgumentException if {@code prediction} is not a {@link CounterfactualPrediction}
     */
    @Override
    public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider predictionProvider) {
        if (!(prediction instanceof CounterfactualPrediction)) {
            throw new IllegalArgumentException("CounterfactualExplainer requires a CounterfactualPrediction, got: "
                    + (prediction == null ? "null" : prediction.getClass().getName()));
        }
        final CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
        final UUID executionId = counterfactualPrediction.getExecutionId();
        final List<CounterfactualEntity> entities = createEntities(prediction.getInput(),
                counterfactualPrediction.getDomain(),
                counterfactualPrediction.getConstraints(),
                counterfactualPrediction.getDataDistribution());
        final List<Output> outputs = prediction.getOutput().getOutputs();
        final Consumer<CounterfactualResult> intermediateConsumer =
                counterfactualPrediction.getIntermediateConsumer() == null
                        ? defaultIntermediateConsumer
                        : counterfactualPrediction.getIntermediateConsumer();

        // The solver requests the initial problem by id; every request gets a new random solution id.
        final Function<UUID, CounterfactualSolution> problemFinder =
                problemId -> new CounterfactualSolution(entities, predictionProvider, outputs, UUID.randomUUID(), executionId);

        final CompletableFuture<CounterfactualSolution> solutionFuture = CompletableFuture.supplyAsync(
                () -> solve(executionId, problemFinder, intermediateConsumer), this.executor);

        // Re-run the model on the best solution's features and report those outputs with the solution.
        return solutionFuture.thenCompose(solution -> predictionProvider
                .predictAsync(List.of(toPredictionInput(solution)))
                .thenApply(predictionOutputs -> new CounterfactualResult(
                        solution.getEntities(),
                        predictionOutputs,
                        solution.getScore().isFeasible(),
                        solution.getSolutionId(),
                        solution.getExecutionId())));
    }
}
