package org.kie.kogito.explainability.local.lime;

import com.sun.xml.bind.v2.runtime.reflect.opt.Const;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.LocalExplanationException;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.LinearModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/kogito/explainability/local/lime/LimeExplainer.class */
/**
 * Local explainer implementing LIME (Local Interpretable Model-agnostic Explanations).
 *
 * <p>The prediction input is perturbed, the {@link PredictionProvider} is queried with the perturbed
 * samples, and for each prediction output the resulting dataset is encoded and used to fit a
 * {@link LinearModel}. The fitted weights become the {@link FeatureImportance}s of one {@link Saliency}
 * per output, keyed by output name.
 *
 * <p>If the perturbed dataset for an output is not separable enough, the explanation is retried up to
 * {@link LimeConfig#getNoOfRetries()} times. When {@link LimeConfig#adaptDatasetVariance()} is enabled,
 * each retry perturbs more features and draws more samples.
 */
public class LimeExplainer implements LocalExplainer<Map<String, Saliency>> {

    private static final Logger LOGGER = LoggerFactory.getLogger(LimeExplainer.class);

    private final LimeConfig limeConfig;

    /** Creates an explainer using a default {@link LimeConfig}. */
    public LimeExplainer() {
        this(new LimeConfig());
    }

    /**
     * Creates an explainer using the given configuration.
     *
     * @param limeConfig the LIME configuration (sampling, retries, filters, encoding parameters)
     */
    public LimeExplainer(LimeConfig limeConfig) {
        this.limeConfig = limeConfig;
    }

    /**
     * Returns the configuration used by this explainer.
     *
     * @return the configuration used by this explainer
     */
    public LimeConfig getLimeConfig() {
        return this.limeConfig;
    }

    /**
     * Asynchronously explains the given prediction.
     *
     * @param prediction the prediction to explain; its input must contain at least one feature
     * @param predictionProvider the model queried with perturbed versions of the prediction input
     * @return a future completing with one {@link Saliency} per prediction output, keyed by output name
     * @throws LocalExplanationException if the prediction input is empty or its linearization yields
     *         no features
     */
    @Override
    public CompletableFuture<Map<String, Saliency>> explainAsync(Prediction prediction, PredictionProvider predictionProvider) {
        PredictionInput input = prediction.getInput();
        if (input.getFeatures().isEmpty()) {
            throw new LocalExplanationException("cannot explain a prediction whose input is empty");
        }
        List<Feature> linearizedFeatures = DataUtils.linearizeInputs(List.of(input)).get(0).getFeatures();
        if (linearizedFeatures.isEmpty()) {
            throw new LocalExplanationException("input features linearization failed");
        }
        return explainRetryCycle(predictionProvider, input, linearizedFeatures, prediction.getOutput().getOutputs(),
                limeConfig.getNoOfRetries(), limeConfig.getNoOfSamples(), limeConfig.getPerturbationContext());
    }

    /**
     * Runs one explanation attempt. If the perturbed dataset turns out not to be separable while
     * retries remain, recursively retries, optionally with more perturbed features and more samples.
     *
     * @param predictionProvider model queried with the perturbed inputs
     * @param predictionInput original input whose features are perturbed
     * @param linearizedFeatures linearized features of {@code predictionInput}
     * @param actualOutputs outputs of the prediction being explained
     * @param noOfRetries remaining retries; when not positive, a poorly separable dataset is used anyway
     * @param noOfSamples minimum number of perturbed samples to generate in this attempt
     * @param perturbationContext perturbation settings for this attempt
     * @return a future completing with the saliencies, keyed by output name
     */
    protected CompletableFuture<Map<String, Saliency>> explainRetryCycle(PredictionProvider predictionProvider,
            PredictionInput predictionInput, List<Feature> linearizedFeatures, List<Output> actualOutputs,
            int noOfRetries, int noOfSamples, PerturbationContext perturbationContext) {
        List<PredictionInput> perturbedInputs =
                getPerturbedInputs(predictionInput.getFeatures(), perturbationContext, noOfSamples);
        return predictionProvider.predictAsync(perturbedInputs).thenCompose(perturbedOutputs -> {
            try {
                // Separability is enforced strictly only while there are retries left.
                List<LimeInputs> limeInputs = getLimeInputs(linearizedFeatures, actualOutputs, perturbedInputs,
                        perturbedOutputs, noOfRetries > 0);
                return CompletableFuture.completedFuture(getSaliencies(linearizedFeatures, actualOutputs, limeInputs));
            } catch (DatasetNotSeparableException e) {
                if (noOfRetries <= 0) {
                    throw e;
                }
                PerturbationContext nextContext = perturbationContext;
                int nextNoOfSamples = noOfSamples;
                if (limeConfig.adaptDatasetVariance()) {
                    int featureCount = linearizedFeatures.size();
                    // Perturb at least one more feature per retry, but never more than all but one.
                    int nextNoOfPerturbations = Math.min(featureCount - 1,
                            Math.max(perturbationContext.getNoOfPerturbations() + 1, featureCount / noOfRetries));
                    nextContext = new PerturbationContext(perturbationContext.getRandom(), nextNoOfPerturbations);
                    // Grow the sample count by an equal share of the configured samples on each retry;
                    // the max(1, ...) only guards against a zero configured retry count.
                    int configuredRetries = Math.max(1, limeConfig.getNoOfRetries());
                    nextNoOfSamples = noOfSamples + limeConfig.getNoOfSamples() / configuredRetries;
                }
                return explainRetryCycle(predictionProvider, predictionInput, linearizedFeatures, actualOutputs,
                        noOfRetries - 1, nextNoOfSamples, nextContext);
            }
        });
    }

    /** Builds one {@link LimeInputs} per actual output, in output order. */
    private List<LimeInputs> getLimeInputs(List<Feature> linearizedFeatures, List<Output> actualOutputs,
            List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs, boolean strict) {
        List<LimeInputs> limeInputs = new ArrayList<>(actualOutputs.size());
        for (int o = 0; o < actualOutputs.size(); o++) {
            limeInputs.add(prepareInputs(perturbedInputs, perturbedOutputs, linearizedFeatures, o,
                    actualOutputs.get(o), strict));
        }
        return limeInputs;
    }

    /** Computes the saliency of each actual output and collects them into a map keyed by output name. */
    private Map<String, Saliency> getSaliencies(List<Feature> linearizedFeatures, List<Output> actualOutputs,
            List<LimeInputs> limeInputs) {
        Map<String, Saliency> saliencies = new HashMap<>();
        for (int o = 0; o < actualOutputs.size(); o++) {
            Output output = actualOutputs.get(o);
            getSaliency(linearizedFeatures, saliencies, limeInputs.get(o), output);
            LOGGER.debug("weights set for output {}", output);
        }
        return saliencies;
    }

    /**
     * Fits a weighted linear model on the encoded perturbed dataset for {@code output} and stores the
     * resulting saliency in {@code saliencies} under the output name. If the fit returns NaN, the
     * stored saliency has no feature importances.
     */
    private void getSaliency(List<Feature> linearizedFeatures, Map<String, Saliency> saliencies,
            LimeInputs limeInputs, Output output) {
        List<Pair<double[], Double>> trainingSet = new DatasetEncoder(limeInputs.getPerturbedInputs(),
                limeInputs.getPerturbedOutputs(), linearizedFeatures, output, limeConfig.getEncodingParams())
                .getEncodedTrainingSet();

        int featureCount = linearizedFeatures.size();
        // The configured kernel width is scaled by sqrt(number of features).
        double[] sampleWeights = SampleWeighter.getSampleWeights(linearizedFeatures, trainingSet,
                limeConfig.getProximityKernelWidth() * Math.sqrt(featureCount));

        // Per-feature multipliers applied to the fitted weights. They start at 1.0 and are passed
        // to the sparse-balance filter when enabled.
        double[] featureWeights = new double[featureCount];
        Arrays.fill(featureWeights, 1.0d);
        if (limeConfig.isPenalizeBalanceSparse()) {
            new IndependentSparseFeatureBalanceFilter().apply(featureWeights, linearizedFeatures, trainingSet);
        }
        if (limeConfig.isProximityFilter()) {
            new ProximityFilter(limeConfig.getProximityThreshold(),
                    limeConfig.getProximityFilteredDatasetMinimum().doubleValue()).apply(trainingSet, sampleWeights);
        }

        List<FeatureImportance> importances = new ArrayList<>(featureCount);
        LinearModel linearModel = new LinearModel(featureCount, limeInputs.isClassification());
        double fitOutcome = linearModel.fit(trainingSet, sampleWeights);
        if (!Double.isNaN(fitOutcome)) {
            double[] weights = linearModel.getWeights();
            int i = 0;
            for (Feature feature : linearizedFeatures) {
                importances.add(new FeatureImportance(feature, weights[i] * featureWeights[i]));
                i++;
            }
        }
        saliencies.put(output.getName(), new Saliency(output, importances));
    }

    /**
     * Prepares the LIME training inputs for the output at {@code outputIndex}.
     *
     * <p>Outputs with a missing value produce empty, non-classification inputs. Otherwise the
     * perturbed outputs are grouped into classes, and the dataset is treated as not separable when it
     * has a single class or when the largest class covers at least
     * {@link LimeConfig#getSeparableDatasetRatio()} of the samples.
     *
     * @param strict when {@code true}, a non-separable dataset raises an exception; otherwise a
     *        warning is logged and the dataset is used anyway
     * @throws DatasetNotSeparableException if {@code strict} and the dataset is not separable
     */
    private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs,
            List<Feature> linearizedFeatures, int outputIndex, Output currentOutput, boolean strict) {
        Value currentValue = currentOutput.getValue();
        if (currentValue == null || currentValue.getUnderlyingObject() == null) {
            return new LimeInputs(false, linearizedFeatures, currentOutput, Collections.emptyList(),
                    Collections.emptyList());
        }
        Map<Double, Long> classBalance = getClassBalance(perturbedOutputs, currentValue, outputIndex);
        long largestClassSize = classBalance.values().stream().max(Long::compareTo).orElse(1L);
        // Floating-point division is required; integer division would collapse the ratio to 0 or 1.
        double separationRatio = (double) largestClassSize / perturbedInputs.size();
        boolean notSeparable = classBalance.size() <= 1
                || separationRatio >= limeConfig.getSeparableDatasetRatio();

        if (notSeparable) {
            if (strict) {
                throw new DatasetNotSeparableException(currentOutput, classBalance);
            }
            LOGGER.warn("Using an hardly separable dataset for output '{}' of type '{}' with value '{}' ({})",
                    currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), classBalance);
        }
        List<Output> outputs = perturbedOutputs.stream()
                .map(predictionOutput -> predictionOutput.getOutputs().get(outputIndex))
                .collect(Collectors.toList());
        boolean classification = classBalance.size() == 2;
        return new LimeInputs(classification, linearizedFeatures, currentOutput, perturbedInputs, outputs);
    }

    /** Counts how many perturbed outputs (at {@code outputIndex}) fall into each class, as mapped by {@link #toDouble}. */
    private Map<Double, Long> getClassBalance(List<PredictionOutput> perturbedOutputs, Value referenceValue,
            int outputIndex) {
        Map<Double, Long> classBalance = perturbedOutputs.stream()
                .map(predictionOutput -> predictionOutput.getOutputs().get(outputIndex))
                .map(output -> toDouble(output, referenceValue))
                .collect(Collectors.groupingBy(d -> d, Collectors.counting()));
        LOGGER.debug("raw samples per class: {}", classBalance);
        return classBalance;
    }

    /**
     * Maps an output to a class value. Numeric outputs map to their number. Other outputs map to
     * {@code 1.0} when they match {@code referenceValue} (both underlying objects null, or equal
     * string forms) and to {@code 0.0} otherwise.
     */
    private double toDouble(Output output, Value referenceValue) {
        if (Type.NUMBER.equals(output.getType())) {
            return output.getValue().asNumber();
        }
        boolean bothNull = output.getValue().getUnderlyingObject() == null
                && referenceValue.getUnderlyingObject() == null;
        boolean sameString = output.getValue().getUnderlyingObject() != null
                && output.getValue().asString().equals(referenceValue.asString());
        return (bothNull || sameString) ? 1.0d : 0.0d;
    }

    /**
     * Generates perturbed copies of {@code features}. The count is the larger of
     * {@code noOfSamples} and {@code 2^features.size()}.
     */
    private List<PredictionInput> getPerturbedInputs(List<Feature> features, PerturbationContext perturbationContext,
            int noOfSamples) {
        // NOTE(review): 2^features.size() grows very quickly; with many features this loop becomes
        // prohibitively large. Confirm that max (rather than min) is the intended bound.
        double size = Math.max(noOfSamples, Math.pow(2.0d, features.size()));
        List<PredictionInput> perturbedInputs = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            perturbedInputs.add(new PredictionInput(DataUtils.perturbFeatures(features, perturbationContext)));
        }
        return perturbedInputs;
    }
}
