package org.kie.kogito.explainability.local.counterfactual;

import java.math.BigDecimal;
import java.time.Duration;
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.utils.CompositeFeatureUtils;
import org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/kogito/explainability/local/counterfactual/CounterFactualScoreCalculator.class */
/**
 * OptaPlanner {@link EasyScoreCalculator} that scores a {@link CounterfactualSolution}.
 * <p>
 * The score is the sum of an input part, computed from the entities of the solution (their mean similarity and
 * how many of them, constrained or not, were changed), and an output part, computed by running the solution's
 * model on the proposed features and comparing each prediction output with the goal.
 */
public class CounterFactualScoreCalculator implements EasyScoreCalculator<CounterfactualSolution, BendableBigDecimalScore> {

    /** Distance returned when two outputs differ categorically or a value needed for comparison is missing. */
    private static final double DEFAULT_DISTANCE = 1.0d;

    /**
     * Whole seconds between {@link LocalTime#MIN} and {@link LocalTime#MAX}, used to normalise time distances.
     * Kept as a {@code double} so the division in {@link #timeDistance} is not an integer division.
     */
    private static final double TIME_INTERVAL_SECONDS = LocalTime.MIN.until(LocalTime.MAX, ChronoUnit.SECONDS);

    private static final Logger logger = LoggerFactory.getLogger(CounterFactualScoreCalculator.class);

    /** Output types compared by plain equality of their underlying values. */
    private static final Set<Type> SUPPORTED_CATEGORICAL_TYPES = Set.of(Type.CATEGORICAL, Type.BOOLEAN, Type.TEXT, Type.CURRENCY, Type.BINARY, Type.UNDEFINED);

    /**
     * Computes the distance between a predicted output and a goal output, with a threshold of zero.
     *
     * @param prediction the predicted output
     * @param goal the goal output
     * @return the distance between the two outputs
     * @throws IllegalArgumentException if the outputs cannot be compared
     * @see #outputDistance(Output, Output, double)
     */
    public static Double outputDistance(Output prediction, Output goal) throws IllegalArgumentException {
        return outputDistance(prediction, goal, 0.0d);
    }

    /**
     * Computes the distance between a predicted output and a goal output.
     * <ul>
     * <li>{@code NUMBER}: absolute difference if either value is zero, otherwise the difference relative to the
     * larger magnitude of the two values.</li>
     * <li>{@code DURATION}: the same rule applied to whole seconds; {@link #DEFAULT_DISTANCE} if either value is
     * {@code null}.</li>
     * <li>{@code TIME}: absolute difference in seconds as a fraction of the {@code LocalTime.MIN}..{@code MAX}
     * interval; {@link #DEFAULT_DISTANCE} if either value is {@code null}.</li>
     * <li>Categorical types: {@code 0} if the underlying values are equal, otherwise {@link #DEFAULT_DISTANCE}.</li>
     * </ul>
     * For numeric, duration and time outputs, a distance below {@code threshold} is reported as {@code 0}.
     *
     * @param prediction the predicted output
     * @param goal the goal output
     * @param threshold distances strictly below this value are treated as zero (not applied to categorical
     *        outputs or to {@code null} duration/time values)
     * @return the distance between the two outputs
     * @throws IllegalArgumentException if the outputs have different types (and the prediction value is not
     *         {@code null}), if a numeric value is NaN, or if the type is not supported
     */
    public static Double outputDistance(Output prediction, Output goal, double threshold) throws IllegalArgumentException {
        final Type type = prediction.getType();
        final Type goalType = goal.getType();
        if (type != goalType) {
            // A missing prediction value is scored as "maximally different" rather than rejected.
            if (Objects.isNull(prediction.getValue().getUnderlyingObject())) {
                return DEFAULT_DISTANCE;
            }
            String message = String.format("Features must have the same type. Feature '%s', has type '%s' and '%s'", prediction.getName(), type.toString(), goalType.toString());
            logger.error(message);
            throw new IllegalArgumentException(message);
        }

        final double distance;
        if (type == Type.NUMBER) {
            distance = numericDistance(prediction, goal);
        } else if (type == Type.DURATION) {
            final Duration predictionDuration = (Duration) prediction.getValue().getUnderlyingObject();
            final Duration goalDuration = (Duration) goal.getValue().getUnderlyingObject();
            if (Objects.isNull(predictionDuration) || Objects.isNull(goalDuration)) {
                return DEFAULT_DISTANCE;
            }
            distance = durationDistance(predictionDuration, goalDuration);
        } else if (type == Type.TIME) {
            final LocalTime predictionTime = (LocalTime) prediction.getValue().getUnderlyingObject();
            final LocalTime goalTime = (LocalTime) goal.getValue().getUnderlyingObject();
            if (Objects.isNull(predictionTime) || Objects.isNull(goalTime)) {
                return DEFAULT_DISTANCE;
            }
            distance = timeDistance(predictionTime, goalTime);
        } else if (SUPPORTED_CATEGORICAL_TYPES.contains(type)) {
            final boolean same = Objects.equals(goal.getValue().getUnderlyingObject(), prediction.getValue().getUnderlyingObject());
            return same ? 0.0d : DEFAULT_DISTANCE;
        } else {
            String message = String.format("Feature '%s' has unsupported type '%s'", prediction.getName(), type.toString());
            logger.error(message);
            throw new IllegalArgumentException(message);
        }
        return distance < threshold ? 0.0d : distance;
    }

    /** Distance between two numeric outputs; rejects NaN values. */
    private static double numericDistance(Output prediction, Output goal) {
        final double predictionValue = prediction.getValue().asNumber();
        final double goalValue = goal.getValue().asNumber();
        if (Double.isNaN(predictionValue) || Double.isNaN(goalValue)) {
            String message = String.format("Unsupported NaN or NULL for numeric feature '%s'", prediction.getName());
            logger.error(message);
            throw new IllegalArgumentException(message);
        }
        final double difference = Math.abs(predictionValue - goalValue);
        if (predictionValue == 0.0d || goalValue == 0.0d) {
            // A relative change from/to zero is undefined, so fall back to the absolute difference.
            return difference;
        }
        // Normalise by the larger magnitude so negative values never yield a negative distance.
        return difference / Math.max(Math.abs(predictionValue), Math.abs(goalValue));
    }

    /** Distance between two non-null durations, measured in whole seconds. */
    private static double durationDistance(Duration prediction, Duration goal) {
        final double difference = prediction.minus(goal).abs().getSeconds();
        final double scale = Math.max(Math.abs((double) prediction.getSeconds()), Math.abs((double) goal.getSeconds()));
        // Sub-second durations have a whole-second scale of 0; avoid 0/0 = NaN by using the absolute difference.
        if (prediction.isZero() || goal.isZero() || scale == 0.0d) {
            return difference;
        }
        return difference / scale;
    }

    /** Distance between two non-null times, as a fraction of the MIN..MAX time interval. */
    private static double timeDistance(LocalTime prediction, LocalTime goal) {
        return Math.abs(prediction.until(goal, ChronoUnit.SECONDS)) / TIME_INTERVAL_SECONDS;
    }

    /**
     * Scores the input side of a solution.
     * <p>
     * Hard levels: {@code [0, -(changed constrained entities), 0]}. Soft levels:
     * {@code [-sqrt(|1 - mean similarity|), -(changed entities)]}.
     */
    private BendableBigDecimalScore calculateInputScore(CounterfactualSolution counterfactualSolution) {
        final int entityCount = counterfactualSolution.getEntities().size();
        // The description is only used for debug logging; skip building it on the hot path otherwise.
        final boolean debugEnabled = logger.isDebugEnabled();
        final StringBuilder solutionDescription = new StringBuilder();
        double meanSimilarity = 0.0d;
        int changedPenalty = 0;
        int changedConstrainedPenalty = 0;
        for (CounterfactualEntity entity : counterfactualSolution.getEntities()) {
            final double similarity = entity.similarity();
            meanSimilarity += similarity / entityCount;
            if (debugEnabled) {
                final Feature feature = entity.asFeature();
                if (solutionDescription.length() > 0) {
                    solutionDescription.append(", ");
                }
                solutionDescription.append(String.format("%s=%s (d:%f)", feature.getName(), feature.getValue().getUnderlyingObject(), similarity));
            }
            if (entity.isChanged()) {
                changedPenalty--;
                if (entity.isConstrained()) {
                    changedConstrainedPenalty--;
                }
            }
        }
        // NOTE(review): 1.0 is presumably the maximum similarity an entity can report — confirm against
        // CounterfactualEntity#similarity.
        final double featureDistance = -Math.sqrt(Math.abs(1.0d - meanSimilarity));
        if (debugEnabled) {
            logger.debug("Current solution: {}", solutionDescription);
            logger.debug("Changed constraints penalty: {}", changedConstrainedPenalty);
            logger.debug("Feature distance: {}", featureDistance);
        }
        return BendableBigDecimalScore.of(
                new BigDecimal[] { BigDecimal.ZERO, BigDecimal.valueOf(changedConstrainedPenalty), BigDecimal.ZERO },
                new BigDecimal[] { BigDecimal.valueOf(featureDistance), BigDecimal.valueOf(changedPenalty) });
    }

    /**
     * Scores the output side of a solution against its goal.
     * <p>
     * Hard levels: {@code [-(sum over predictions of Euclidean output distance), 0, -(outputs whose score is
     * below the goal output's score)]}. Soft levels are zero.
     *
     * @throws IllegalArgumentException if a prediction has a different number of outputs than the goal
     */
    private BendableBigDecimalScore calculateOutputScore(CounterfactualSolution counterfactualSolution) {
        final List<Output> goal = counterfactualSolution.getGoal();
        double distancePenalty = 0.0d;
        int confidencePenalty = 0;
        for (PredictionOutput predictionOutput : counterfactualSolution.getPredictionOutputs()) {
            final List<Output> outputs = predictionOutput.getOutputs();
            if (goal.size() != outputs.size()) {
                throw new IllegalArgumentException("Prediction size must be equal to goal size");
            }
            // Per-prediction accumulator: each prediction contributes its own Euclidean distance only.
            double squaredDistance = 0.0d;
            for (int i = 0; i < outputs.size(); i++) {
                final Output output = outputs.get(i);
                final Output goalOutput = goal.get(i);
                final double distance = outputDistance(output, goalOutput, counterfactualSolution.getGoalThreshold());
                squaredDistance += distance * distance;
                if (output.getScore() < goalOutput.getScore()) {
                    confidencePenalty--;
                }
            }
            distancePenalty -= Math.sqrt(squaredDistance);
            logger.debug("Distance penalty: {}", distancePenalty);
            logger.debug("Confidence threshold penalty: {}", confidencePenalty);
        }
        return BendableBigDecimalScore.of(
                new BigDecimal[] { BigDecimal.valueOf(distancePenalty), BigDecimal.ZERO, BigDecimal.valueOf(confidencePenalty) },
                new BigDecimal[] { BigDecimal.ZERO, BigDecimal.ZERO });
    }

    /**
     * Calculates the score of a solution: the input score plus, if the model prediction succeeds within the
     * configured timeout, the output score. The prediction outputs are stored on the solution.
     * <p>
     * NOTE(review): when the prediction fails, is interrupted or times out, only the input score is returned.
     * The output hard levels are then 0, the best value they can take, and the solution's prediction outputs
     * are not updated. Confirm this is the intended treatment of failed predictions.
     */
    @Override // org.optaplanner.core.api.score.calculator.EasyScoreCalculator
    public BendableBigDecimalScore calculateScore(CounterfactualSolution counterfactualSolution) {
        BendableBigDecimalScore score = calculateInputScore(counterfactualSolution);
        final List<Feature> flattenedFeatures = counterfactualSolution.getEntities().stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());
        final PredictionInput input = new PredictionInput(
                CompositeFeatureUtils.unflattenFeatures(flattenedFeatures, counterfactualSolution.getOriginalFeatures()));
        try {
            counterfactualSolution.setPredictionOutputs(counterfactualSolution.getModel()
                    .predictAsync(List.of(input))
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
            score = score.add(calculateOutputScore(counterfactualSolution));
        } catch (InterruptedException e) {
            logger.error("Interrupted while waiting for prediction {}", e.getMessage(), e);
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            logger.error("Prediction returned an error {}", e.getMessage(), e);
        } catch (TimeoutException e) {
            logger.error("Timed out while waiting for prediction", e);
        }
        return score;
    }
}
