package org.kie.kogito.explainability.local.counterfactual;

import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntityFactory;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.utils.CompositeFeatureUtils;
import org.optaplanner.core.api.solver.SolverManager;
import org.optaplanner.core.config.solver.SolverConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/kogito/explainability/local/counterfactual/CounterfactualExplainer.class */
/**
 * {@link LocalExplainer} that searches for counterfactuals by running a solver obtained from the
 * configured {@link CounterfactualConfig}.
 *
 * <p>Feasible intermediate best solutions are forwarded to a caller-supplied consumer while the
 * solver runs; the final best solution is returned through a {@link CompletableFuture}.
 */
public class CounterfactualExplainer implements LocalExplainer<CounterfactualResult> {

    /** Stamps a freshly generated random solution id on every solution it receives. */
    public static final Consumer<CounterfactualSolution> assignSolutionId =
            solution -> solution.setSolutionId(UUID.randomUUID());

    private static final Logger logger = LoggerFactory.getLogger(CounterfactualExplainer.class);

    private final CounterfactualConfig counterfactualConfig;

    /** Creates an explainer backed by a default-constructed {@link CounterfactualConfig}. */
    public CounterfactualExplainer() {
        this.counterfactualConfig = new CounterfactualConfig();
    }

    /**
     * Creates an explainer backed by the supplied configuration.
     *
     * @param counterfactualConfig configuration providing the solver config, solver manager
     *        factory, executor and goal threshold used by {@link #explainAsync}
     */
    public CounterfactualExplainer(CounterfactualConfig counterfactualConfig) {
        this.counterfactualConfig = counterfactualConfig;
    }

    /**
     * @return the configuration this explainer was created with
     */
    public CounterfactualConfig getCounterfactualConfig() {
        return this.counterfactualConfig;
    }

    /**
     * Builds the listener invoked for each best solution reported by the solver.
     *
     * <p>Solutions whose score is not feasible are ignored. For feasible ones a
     * {@link CounterfactualResult} is built and handed to {@code consumer}, tagged with the next
     * value of {@code sequenceId}.
     *
     * @param consumer receiver of the intermediate results
     * @param sequenceId counter shared with the final result so sequence ids keep increasing
     * @return the listener to register with the solver
     */
    private Consumer<CounterfactualSolution> createSolutionConsumer(Consumer<CounterfactualResult> consumer, AtomicLong sequenceId) {
        return solution -> {
            if (!solution.getScore().isFeasible()) {
                return;
            }
            final List<CounterfactualEntity> entities = solution.getEntities();
            final List<Feature> flatFeatures = entities.stream()
                    .map(CounterfactualEntity::asFeature)
                    .collect(Collectors.toList());
            final List<Feature> features =
                    CompositeFeatureUtils.unflattenFeatures(flatFeatures, solution.getOriginalFeatures());
            consumer.accept(new CounterfactualResult(
                    entities,
                    features,
                    solution.getPredictionOutputs(),
                    solution.getScore().isFeasible(),
                    solution.getSolutionId(),
                    solution.getExecutionId(),
                    sequenceId.incrementAndGet()));
        };
    }

    /**
     * Wraps the features of the given entities into a single-element list of
     * {@link PredictionInput}.
     *
     * @param list entities to convert with {@link CounterfactualEntity#asFeature()}
     * @return an immutable list holding exactly one prediction input
     */
    private static List<PredictionInput> buildInput(List<CounterfactualEntity> list) {
        final List<Feature> features = list.stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());
        return List.of(new PredictionInput(features));
    }

    /**
     * Runs the counterfactual search asynchronously on the configured executor.
     *
     * <p>The solver is started with an initial solution built from the prediction's input and
     * output. Once the solver finishes, the final entities are sent to {@code predictionProvider}
     * and the resulting outputs are packaged, together with the final solution, into the returned
     * result.
     *
     * @param prediction the prediction to explain; cast to {@link CounterfactualPrediction}, so any
     *        other type fails with a {@link ClassCastException}
     * @param predictionProvider the model queried for the final solution's outputs (and handed to
     *        the initial {@link CounterfactualSolution})
     * @param consumer receiver of feasible intermediate results
     * @return a future completing with the final result; it completes exceptionally with an
     *         {@link IllegalStateException} as cause if solving is interrupted or fails
     */
    @Override
    public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider predictionProvider, Consumer<CounterfactualResult> consumer) {
        final AtomicLong sequenceId = new AtomicLong(0L);
        final CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
        final UUID executionId = counterfactualPrediction.getExecutionId();
        final Long maxRunningTimeSeconds = counterfactualPrediction.getMaxRunningTimeSeconds();
        final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
        final List<Output> goal = prediction.getOutput().getOutputs();
        final List<Feature> originalFeatures = prediction.getInput().getFeatures();

        // Problem finder for the solver: the problem id argument is not used, every call builds
        // the same initial solution (with a new random solution id).
        final Function<UUID, CounterfactualSolution> initialSolution = problemId -> new CounterfactualSolution(
                entities,
                originalFeatures,
                predictionProvider,
                goal,
                UUID.randomUUID(),
                executionId,
                this.counterfactualConfig.getGoalThreshold());

        final CompletableFuture<CounterfactualSolution> solutionFuture = CompletableFuture.supplyAsync(() -> {
            final SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
            if (Objects.nonNull(maxRunningTimeSeconds)) {
                // NOTE(review): this sets the limit on the object returned by getSolverConfig();
                // if that instance is shared across calls the limit would persist - confirm.
                solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
            }
            // try-with-resources closes the solver manager on both the normal and failure paths.
            try (SolverManager<CounterfactualSolution, UUID> solverManager =
                    this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
                try {
                    // Blocks until the solver terminates.
                    return solverManager.solveAndListen(
                            executionId,
                            initialSolution,
                            assignSolutionId.andThen(createSolutionConsumer(consumer, sequenceId)),
                            (BiConsumer<? super UUID, ? super Throwable>) null)
                            .getFinalBestSolution();
                } catch (InterruptedException e) {
                    // Restore the interrupt flag before translating to an unchecked exception.
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("Solving failed (Thread interrupted)", e);
                } catch (ExecutionException e) {
                    // Pass the exception as last argument so the stack trace is logged as well.
                    logger.error("Solving failed: {}", e.getMessage(), e);
                    throw new IllegalStateException("Prediction returned an error", e);
                }
            }
        }, this.counterfactualConfig.getExecutor());

        // Once solved, query the model with the final entities and assemble the final result.
        return solutionFuture.thenCompose(solution -> predictionProvider
                .predictAsync(buildInput(solution.getEntities()))
                .thenApply(predictionOutputs -> new CounterfactualResult(
                        solution.getEntities(),
                        solution.getOriginalFeatures(),
                        predictionOutputs,
                        solution.getScore().isFeasible(),
                        UUID.randomUUID(),
                        solution.getExecutionId(),
                        sequenceId.incrementAndGet())));
    }
}
