package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.shap.ShapConfig;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.MatrixUtils;
import org.kie.kogito.explainability.utils.RandomChoice;
import org.kie.kogito.explainability.utils.WeightedLinearRegression;
import org.kie.kogito.explainability.utils.WeightedLinearRegressionResults;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Kernel SHAP local explainer.
 *
 * <p>For a single prediction it builds synthetic samples from the explained input, a feature mask
 * and the configured background matrix, runs the model on them, and fits a weighted linear
 * regression over the feature masks. The result is one value plus one confidence value per
 * feature and per model output.
 *
 * <p>How the sample budget is spent:
 * <ul>
 *   <li>subset sizes whose complete enumeration fits in the remaining budget are added
 *       exhaustively ({@code addCompleteSubsets});</li>
 *   <li>the rest of the budget goes to masks drawn at random ({@code addNonCompleteSubsets}),
 *       using the random generator of the configured perturbation context.</li>
 * </ul>
 */
public class ShapKernelExplainer implements LocalExplainer<Saliency[]> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ShapKernelExplainer.class);

    /** Longest mask that can be encoded exactly as the bits of a non-negative {@code int}. */
    private static final int MAX_EXACT_MASK_BITS = Integer.SIZE - 1;

    /** Background row count above which a slowness warning is logged. */
    private static final int LARGE_BACKGROUND_ROWS = 100;

    private ShapConfig config;

    public ShapKernelExplainer(ShapConfig shapConfig) {
        this.config = shapConfig;
    }

    public void setConfig(ShapConfig shapConfig) {
        this.config = shapConfig;
    }

    /**
     * Runs the model on the background data and collects the per-explanation state.
     *
     * @param predictionProvider the model to explain
     * @return a carrier holding the background shape, the (asynchronous) output size, the mean
     *         background output ({@code fnull}), its linked row vector and the sample budget
     */
    private ShapDataCarrier initialize(PredictionProvider predictionProvider) {
        int[] backgroundShape = MatrixUtils.getShape(this.config.getBackgroundMatrix());
        int rows = backgroundShape[0];
        int cols = backgroundShape[1];
        if (rows > LARGE_BACKGROUND_ROWS) {
            LOGGER.debug("Warning: Background data sets larger than 100 samples might be slow!");
        }

        CompletableFuture<double[][]> backgroundOutputs = predictionProvider
                .predictAsync(this.config.getBackground())
                .thenApply(outputs -> MatrixUtils.matrixFromPredictionOutput(outputs));
        CompletableFuture<Integer> outputSize = backgroundOutputs
                .thenApply(outputMatrix -> MatrixUtils.getShape(outputMatrix)[1]);
        // average background output: scale every row by 1/rows, then sum the rows together
        CompletableFuture<double[]> fnull = backgroundOutputs.thenApply(outputMatrix -> MatrixUtils.sum(
                MatrixUtils.matrixMultiply(outputMatrix, 1.0 / rows), MatrixUtils.Axis.ROW));
        CompletableFuture<double[][]> linkNull = fnull.thenApply(mean -> MatrixUtils.rowVector(link(mean)));

        int numSamples = this.config.getNSamples().orElseGet(() -> 2048 + 2 * cols);
        // never request more than the 2^cols - 2 masks that are neither empty nor full
        if (cols <= 30) {
            numSamples = Math.min(numSamples, (1 << cols) - 2);
        }

        ShapDataCarrier shapDataCarrier = new ShapDataCarrier();
        shapDataCarrier.setRows(rows);
        shapDataCarrier.setCols(cols);
        shapDataCarrier.setOutputSize(outputSize);
        shapDataCarrier.setModel(predictionProvider);
        shapDataCarrier.setFnull(fnull);
        shapDataCarrier.setLinkNull(linkNull);
        shapDataCarrier.setNumSamples(numSamples);
        return shapDataCarrier;
    }

    /** Applies the configured link: identity, otherwise the logit {@code ln(d / (1 - d))}. */
    private double link(double d) {
        return this.config.getLink().equals(ShapConfig.LinkType.IDENTITY) ? d : Math.log(d / (1.0d - d));
    }

    /** Element-wise {@link #link(double)}. */
    private double[] link(double[] dArr) {
        return Arrays.stream(dArr).map(this::link).toArray();
    }

    /**
     * Records which feature columns vary: a column varies when the background values plus the
     * explained input's value contain more than one distinct number.
     */
    private void setVaryingFeatureGroups(PredictionInput predictionInput, ShapDataCarrier shapDataCarrier) {
        double[][] background = this.config.getBackgroundMatrix();
        double[] instance = MatrixUtils.matrixFromPredictionInput(predictionInput)[0];
        int rows = shapDataCarrier.getRows();
        double[] column = new double[rows + 1];
        ArrayList<Integer> varying = new ArrayList<>();
        for (int col = 0; col < shapDataCarrier.getCols(); col++) {
            System.arraycopy(MatrixUtils.getCol(background, col), 0, column, 0, rows);
            column[rows] = instance[col];
            if (Arrays.stream(column).distinct().count() > 1) {
                varying.add(col);
            }
        }
        shapDataCarrier.setVaryingFeatureGroups(varying);
        shapDataCarrier.setNumVarying(varying.size());
    }

    /** Scales {@code weights} to sum to 1; returns the input unchanged when the sum is zero. */
    private double[] normalizeWeightVector(double[] weights) {
        double[][] asRow = MatrixUtils.rowVector(weights);
        double total = MatrixUtils.sum(asRow, MatrixUtils.Axis.COLUMN)[0];
        if (total == 0.0) {
            return weights;
        }
        return MatrixUtils.matrixMultiply(asRow, 1.0 / total)[0];
    }

    /**
     * Builds the mask for {@code subset} (positions into the varying-feature list) and either
     * registers a new synthetic sample or, if that mask was already used, increments the weight of
     * the existing sample.
     *
     * @param subset     positions into the varying-feature list
     * @param weight     initial weight of a newly created sample
     * @param complement when true, mask every varying feature except those in {@code subset}
     * @param fixed      stored on the sample; non-fixed samples are rescaled later
     */
    private void addSample(PredictionInput predictionInput, List<Integer> subset, double weight,
            boolean complement, boolean fixed, ShapDataCarrier shapDataCarrier) {
        boolean[] mask = new boolean[shapDataCarrier.getCols()];
        if (complement) {
            for (int v = 0; v < shapDataCarrier.getNumVarying(); v++) {
                mask[shapDataCarrier.getVaryingFeatureGroups(v)] = true;
            }
        }
        for (int position : subset) {
            mask[shapDataCarrier.getVaryingFeatureGroups(position)] = !complement;
        }
        int maskKey = hashMask(mask);
        if (shapDataCarrier.getMasksUsed().containsKey(maskKey)) {
            shapDataCarrier.getSamplesAdded(shapDataCarrier.getMasksUsed(maskKey)).incrementWeight();
            return;
        }
        ShapSyntheticDataSample sample =
                new ShapSyntheticDataSample(predictionInput, mask, this.config.getBackgroundMatrix(), weight, fixed);
        shapDataCarrier.addMask(maskKey, shapDataCarrier.getSamplesAddedSize());
        shapDataCarrier.addSample(sample);
    }

    /**
     * Computes the key used to deduplicate masks.
     *
     * <p>Masks of up to 31 entries are encoded exactly as bits, with the first entry as the most
     * significant bit. Longer masks cannot fit in an {@code int}: the previous
     * {@code Math.pow}-based sum saturated at {@code Integer.MAX_VALUE} and merged almost every
     * mask, so longer masks now fall back to {@link Arrays#hashCode(boolean[])}.
     * NOTE(review): that fallback can still collide; an exact key would need a wider key type in
     * ShapDataCarrier.
     */
    private int hashMask(boolean[] mask) {
        if (mask.length > MAX_EXACT_MASK_BITS) {
            return Arrays.hashCode(mask);
        }
        int code = 0;
        for (boolean bit : mask) {
            code = (code << 1) | (bit ? 1 : 0);
        }
        return code;
    }

    /**
     * Wraps an outputs-by-features importance matrix into one {@link Saliency} per output.
     *
     * @param importances     row {@code o}, column {@code f}: importance of feature {@code f} for output {@code o}
     * @param predictionInput supplies the features (by column index)
     * @param predictionOutput supplies the outputs (by row index)
     */
    public static Saliency[] saliencyFromMatrix(double[][] importances, PredictionInput predictionInput,
            PredictionOutput predictionOutput) {
        int[] shape = MatrixUtils.getShape(importances);
        Saliency[] saliencies = new Saliency[shape[0]];
        for (int o = 0; o < shape[0]; o++) {
            List<FeatureImportance> featureImportances = new ArrayList<>(shape[1]);
            for (int f = 0; f < shape[1]; f++) {
                featureImportances.add(new FeatureImportance(predictionInput.getFeatures().get(f), importances[o][f]));
            }
            saliencies[o] = new Saliency(predictionOutput.getOutputs().get(o), featureImportances);
        }
        return saliencies;
    }

    /**
     * Same as {@link #saliencyFromMatrix(double[][], PredictionInput, PredictionOutput)} but also
     * attaches a per-feature confidence value taken from {@code confidences} (same shape).
     */
    public static Saliency[] saliencyFromMatrix(double[][] importances, double[][] confidences,
            PredictionInput predictionInput, PredictionOutput predictionOutput) {
        int[] shape = MatrixUtils.getShape(importances);
        Saliency[] saliencies = new Saliency[shape[0]];
        for (int o = 0; o < shape[0]; o++) {
            List<FeatureImportance> featureImportances = new ArrayList<>(shape[1]);
            for (int f = 0; f < shape[1]; f++) {
                featureImportances.add(new FeatureImportance(
                        predictionInput.getFeatures().get(f), importances[o][f], confidences[o][f]));
            }
            saliencies[o] = new Saliency(predictionOutput.getOutputs().get(o), featureImportances);
        }
        return saliencies;
    }

    /**
     * Computes the saliencies of one prediction.
     *
     * @throws IllegalArgumentException (synchronously) if the input feature count differs from the
     *         background column count; an output-size mismatch instead completes the returned
     *         future exceptionally
     */
    private CompletableFuture<Saliency[]> explain(Prediction prediction, PredictionProvider predictionProvider) {
        ShapDataCarrier sdc = initialize(predictionProvider);
        sdc.setSamplesAdded(new ArrayList<>());
        PredictionInput input = prediction.getInput();
        PredictionOutput output = prediction.getOutput();
        if (input.getFeatures().size() != sdc.getCols()) {
            throw new IllegalArgumentException(String.format(
                    "Prediction input feature count (%d) does not match background data feature count (%d)",
                    input.getFeatures().size(), sdc.getCols()));
        }
        int cols = sdc.getCols();
        // all-zero outputs x features matrix, produced only after the output size has been validated
        CompletableFuture<double[][]> zeroSaliencies = sdc.getOutputSize().thenApply(outputSize -> {
            if (output.getOutputs().size() != outputSize) {
                throw new IllegalArgumentException(String.format(
                        "Prediction output size (%d) does not match background data output size (%d)",
                        output.getOutputs().size(), outputSize));
            }
            return new double[outputSize][cols];
        });
        double[][] outputMatrix = MatrixUtils.matrixFromPredictionOutput(output);
        setVaryingFeatureGroups(input, sdc);

        if (sdc.getNumVarying() == 0) {
            return zeroSaliencies.thenApply(zeros -> saliencyFromMatrix(zeros, input, output));
        }
        if (sdc.getNumVarying() == 1) {
            // The single varying feature takes the whole link(f(x)) - link(fnull) difference.
            // Link the model output, as solve() does, so LOGIT explanations stay in link space.
            int onlyVarying = sdc.getVaryingFeatureGroups(0);
            double[][] linkedOutput = MatrixUtils.rowVector(link(outputMatrix[0]));
            CompletableFuture<double[]> delta = sdc.getLinkNull()
                    .thenApply(linkNull -> MatrixUtils.matrixDifference(linkedOutput, linkNull)[0]);
            return zeroSaliencies.thenCompose(ignored -> delta.thenCombine(sdc.getOutputSize(), (diff, outputSize) -> {
                double[][] saliencies = new double[outputSize][cols];
                for (int o = 0; o < outputSize; o++) {
                    saliencies[o][onlyVarying] = diff[o];
                }
                return saliencyFromMatrix(saliencies, input, output);
            }));
        }

        ShapStatistics stats = computeSubsetStatistics(sdc);
        initializeWeights(stats, sdc);
        addCompleteSubsets(stats, input, sdc);
        renormalizeWeights(stats);
        addNonCompleteSubsets(stats, input, sdc);
        CompletableFuture<double[][]> syntheticExpectations = runSyntheticData(sdc);
        return zeroSaliencies.thenCompose(ignored -> solveSystem(syntheticExpectations, outputMatrix[0], sdc)
                .thenApply(solution -> saliencyFromMatrix(solution[0], solution[1], input, output)));
    }

    /**
     * Computes the subset-size bookkeeping:
     * <ul>
     *   <li>the number of subset sizes considered, {@code ceil((M-1)/2)};</li>
     *   <li>the largest size sampled together with its complement;</li>
     *   <li>the number of subsets of each size, {@code C(M, k)}, clamped to {@code Integer.MAX_VALUE}.</li>
     * </ul>
     */
    private ShapStatistics computeSubsetStatistics(ShapDataCarrier shapDataCarrier) {
        int numVarying = shapDataCarrier.getNumVarying();
        int numSubsetSizes = (int) Math.ceil((numVarying - 1) / 2.0d);
        int largestPairedSubsetSize = numVarying % 2 == 1 ? numSubsetSizes : numSubsetSizes - 1;
        int[] numSubsetsAtSize = new int[numSubsetSizes + 1];
        for (int size = 1; size <= numSubsetSizes; size++) {
            long count;
            try {
                count = CombinatoricsUtils.binomialCoefficient(numVarying, size);
            } catch (MathArithmeticException ignored) {
                // C(n, k) does not even fit a long: far beyond any sample budget
                count = Long.MAX_VALUE;
            }
            // clamp instead of a plain (int) cast, which could wrap to a negative count
            numSubsetsAtSize[size] = (int) Math.min(count, Integer.MAX_VALUE);
        }
        return new ShapStatistics(numSubsetSizes, largestPairedSubsetSize, numSubsetsAtSize,
                shapDataCarrier.getNumSamples());
    }

    /**
     * Sets the normalized per-size weights.
     *
     * <p>Index {@code k} holds {@code (M-1) / (k (M-k))}, doubled for paired sizes; index 0 is unused.
     * A copy is also stored as the initial remaining weights.
     */
    private void initializeWeights(ShapStatistics shapStatistics, ShapDataCarrier shapDataCarrier) {
        int numVarying = shapDataCarrier.getNumVarying();
        double[] weights = new double[shapStatistics.getNumSubsetSizes() + 1];
        for (int size = 1; size <= shapStatistics.getNumSubsetSizes(); size++) {
            double weight = (numVarying - 1.0d) / (size * (numVarying - size));
            if (size <= shapStatistics.getLargestPairedSubsetSize()) {
                weight *= 2.0d;
            }
            weights[size] = weight;
        }
        double[] normalized = normalizeWeightVector(weights);
        shapStatistics.setWeightOfSubsetSize(normalized);
        shapStatistics.setRemainingWeights(Arrays.copyOf(normalized, normalized.length));
    }

    /**
     * Fully enumerates subset sizes, smallest first, while the remaining budget times the size's
     * remaining weight covers all of its subsets (and complements, for paired sizes). Stops at
     * the first size that does not fit.
     */
    private void addCompleteSubsets(ShapStatistics shapStatistics, PredictionInput predictionInput,
            ShapDataCarrier shapDataCarrier) {
        shapDataCarrier.setMasksUsed(new HashMap<>());
        for (int size = 1; size <= shapStatistics.getNumSubsetSizes(); size++) {
            boolean paired = size <= shapStatistics.getLargestPairedSubsetSize();
            // long arithmetic: the subset count may be clamped to Integer.MAX_VALUE
            long subsetsNeeded = (long) shapStatistics.getNumSubsetsAtSize()[size] * (paired ? 2 : 1);
            if (shapStatistics.getNumSamplesRemaining() * shapStatistics.getRemainingWeights()[size] < subsetsNeeded) {
                return;
            }
            shapStatistics.incrementNumFullSubsets();
            // safe cast: subsetsNeeded <= remaining samples * weight <= remaining samples
            shapStatistics.decreaseNumSamplesRemainingBy((int) subsetsNeeded);
            double[] remainingWeights = shapStatistics.getRemainingWeights();
            remainingWeights[size] = 0.0d;
            shapStatistics.setRemainingWeights(normalizeWeightVector(remainingWeights));

            double sampleWeight = shapStatistics.getWeightOfSubsetSize()[size] / subsetsNeeded;
            Iterator<int[]> combinations = CombinatoricsUtils.combinationsIterator(shapDataCarrier.getNumVarying(), size);
            while (combinations.hasNext()) {
                List<Integer> subset = Arrays.stream(combinations.next()).boxed().collect(Collectors.toList());
                addSample(predictionInput, subset, sampleWeight, false, true, shapDataCarrier);
                if (paired) {
                    addSample(predictionInput, subset, sampleWeight, true, true, shapDataCarrier);
                }
            }
        }
    }

    /**
     * Computes the sampling distribution over the sizes that were not fully enumerated.
     *
     * <p>Paired sizes are halved because every random draw of such a size adds two samples (the
     * subset and its complement). Weight arrays are indexed by subset size (index 0 unused), so the
     * paired sizes are indices {@code 1..largestPairedSubsetSize}.
     */
    private void renormalizeWeights(ShapStatistics shapStatistics) {
        double[] weights = shapStatistics.getWeightOfSubsetSize();
        double[] adjusted = Arrays.copyOf(weights, weights.length);
        int lastPaired = Math.min(shapStatistics.getLargestPairedSubsetSize(), adjusted.length - 1);
        for (int size = 1; size <= lastPaired; size++) {
            adjusted[size] /= 2.0d;
        }
        shapStatistics.setFinalRemainingWeights(normalizeWeightVector(Arrays.copyOfRange(
                adjusted, shapStatistics.getNumFullSubsets() + 1, shapStatistics.getNumSubsetSizes() + 1)));
    }

    /**
     * Spends the remaining sample budget on random masks.
     *
     * <p>Sizes are drawn from the final remaining weights, and features are picked by shuffling
     * them. Every random choice uses the configured perturbation context, so explanations are
     * reproducible with a seeded context.
     */
    private void addNonCompleteSubsets(ShapStatistics shapStatistics, PredictionInput predictionInput,
            ShapDataCarrier shapDataCarrier) {
        if (shapStatistics.getNumFullSubsets() >= shapStatistics.getNumSubsetSizes()) {
            return;
        }
        List<Integer> candidateSizes = IntStream
                .range(shapStatistics.getNumFullSubsets() + 1, shapStatistics.getNumSubsetSizes() + 1)
                .boxed()
                .collect(Collectors.toList());
        List<Double> sizeWeights = Arrays.stream(shapStatistics.getFinalRemainingWeights())
                .boxed()
                .collect(Collectors.toList());
        List<Integer> drawnSizes = new RandomChoice<>(candidateSizes, sizeWeights)
                .sample(shapStatistics.getNumSamplesRemaining() * 4, this.config.getPC().getRandom());
        List<Integer> featureOrder = IntStream.range(0, shapDataCarrier.getNumVarying())
                .boxed()
                .collect(Collectors.toList());

        int draw = 0;
        while (shapStatistics.getNumSamplesRemaining() > 0 && draw < drawnSizes.size()) {
            int size = drawnSizes.get(draw++);
            Collections.shuffle(featureOrder, this.config.getPC().getRandom());
            List<Integer> subset = featureOrder.subList(0, size);
            addSample(predictionInput, subset, 1.0d, false, false, shapDataCarrier);
            shapStatistics.decreaseNumSamplesRemainingBy(1);
            if (shapStatistics.getNumSamplesRemaining() > 0 && size <= shapStatistics.getLargestPairedSubsetSize()) {
                addSample(predictionInput, subset, 1.0d, true, false, shapDataCarrier);
                shapStatistics.decreaseNumSamplesRemainingBy(1);
            }
        }
        normalizeSampleWeights(shapStatistics, shapDataCarrier);
    }

    /**
     * Rescales the non-fixed (randomly drawn) samples so their weights sum to the total weight of
     * the subset sizes that were not fully enumerated. Nothing changes if their current sum is zero.
     */
    private void normalizeSampleWeights(ShapStatistics shapStatistics, ShapDataCarrier shapDataCarrier) {
        double targetTotal = 0.0d;
        for (int size = shapStatistics.getNumFullSubsets() + 1; size <= shapStatistics.getNumSubsetSizes(); size++) {
            targetTotal += shapStatistics.getWeightOfSubsetSize()[size];
        }
        int numSamples = shapDataCarrier.getSamplesAddedSize();
        double currentTotal = 0.0d;
        for (int s = 0; s < numSamples; s++) {
            if (!shapDataCarrier.getSamplesAdded(s).isFixed()) {
                currentTotal += shapDataCarrier.getSamplesAdded(s).getWeight();
            }
        }
        if (currentTotal == 0.0d) {
            return;
        }
        for (int s = 0; s < numSamples; s++) {
            ShapSyntheticDataSample sample = shapDataCarrier.getSamplesAdded(s);
            if (!sample.isFixed()) {
                sample.setWeight((sample.getWeight() * targetTotal) / currentTotal);
            }
        }
    }

    /**
     * Evaluates the model on every synthetic sample.
     *
     * @return a future of a samples x outputs matrix; row {@code s} is
     *         {@code link(mean model output over sample s's synthetic data) - linkNull}
     */
    private CompletableFuture<double[][]> runSyntheticData(ShapDataCarrier shapDataCarrier) {
        return shapDataCarrier.getLinkNull().thenCompose(linkNull -> {
            int numSamples = shapDataCarrier.getSamplesAddedSize();
            List<CompletableFuture<double[]>> rows = new ArrayList<>(numSamples);
            for (int s = 0; s < numSamples; s++) {
                rows.add(syntheticSampleDelta(shapDataCarrier, shapDataCarrier.getSamplesAdded(s), linkNull));
            }
            return CompletableFuture.allOf(rows.toArray(new CompletableFuture<?>[0])).thenApply(ignored -> {
                double[][] expectations = new double[numSamples][];
                for (int s = 0; s < numSamples; s++) {
                    expectations[s] = rows.get(s).join();
                }
                return expectations;
            });
        });
    }

    /** Predicts one synthetic sample and returns {@code link(mean output) - linkNull} as a vector. */
    private CompletableFuture<double[]> syntheticSampleDelta(ShapDataCarrier shapDataCarrier,
            ShapSyntheticDataSample sample, double[][] linkNull) {
        return shapDataCarrier.getModel().predictAsync(sample.getSyntheticData())
                .thenApply(outputs -> MatrixUtils.matrixFromPredictionOutput(outputs))
                .thenApply(outputMatrix -> MatrixUtils.sum(
                        MatrixUtils.matrixMultiply(outputMatrix, 1.0d / shapDataCarrier.getRows()), MatrixUtils.Axis.ROW))
                .thenApply(mean -> link(mean))
                .thenApply(linkedMean -> MatrixUtils.matrixDifference(MatrixUtils.rowVector(linkedMean), linkNull)[0]);
    }

    /**
     * Solves the weighted regression for one model output.
     *
     * <p>{@code eliminatedFeature} is removed from the system: its mask column (scaled by the
     * total link-space difference) is subtracted from the targets and from every other mask
     * column.
     *
     * @return {@code {values, confidences}}, each of length cols
     */
    private double[][] solve(double[][] expectations, int outputIndex, double[] modelOutput, double[] fnull,
            int eliminatedFeature, ShapDataCarrier shapDataCarrier) {
        int numSamples = shapDataCarrier.getSamplesAddedSize();
        int cols = shapDataCarrier.getCols();
        double[][] masks = new double[numSamples][cols];
        double[] weights = new double[numSamples];
        double[] targets = new double[numSamples];
        for (int s = 0; s < numSamples; s++) {
            ShapSyntheticDataSample sample = shapDataCarrier.getSamplesAdded(s);
            boolean[] mask = sample.getMask();
            for (int c = 0; c < cols; c++) {
                masks[s][c] = mask[c] ? 1.0d : 0.0d;
            }
            targets[s] = expectations[s][outputIndex];
            weights[s] = sample.getWeight();
        }
        double linkDelta = link(modelOutput[outputIndex]) - link(fnull[outputIndex]);
        double[][] eliminatedColumn = MatrixUtils.rowVector(MatrixUtils.getCol(masks, eliminatedFeature));
        double[] adjustedTargets = MatrixUtils.matrixDifference(
                MatrixUtils.rowVector(targets), MatrixUtils.matrixMultiply(eliminatedColumn, linkDelta))[0];

        ArrayList<Integer> keptFeatures = new ArrayList<>();
        for (int feature : shapDataCarrier.getVaryingFeatureGroups()) {
            if (feature != eliminatedFeature) {
                keptFeatures.add(feature);
            }
        }
        double[][] design = MatrixUtils.transpose(MatrixUtils.matrixRowDifference(
                MatrixUtils.transpose(MatrixUtils.getCols(masks, keptFeatures)), eliminatedColumn[0]));
        return runWLRR(design, adjustedTargets, weights, linkDelta, eliminatedFeature, shapDataCarrier);
    }

    /**
     * Solves every model output in parallel on the configured executor. The last varying feature
     * is the one eliminated from each system.
     *
     * @return a future of {@code {values[output][feature], confidences[output][feature]}}
     */
    private CompletableFuture<double[][][]> solveSystem(CompletableFuture<double[][]> expectations,
            double[] modelOutput, ShapDataCarrier shapDataCarrier) {
        int eliminatedFeature = shapDataCarrier.getVaryingFeatureGroups(
                shapDataCarrier.getVaryingFeatureGroups().size() - 1);
        return expectations.thenCompose(expectationMatrix ->
                shapDataCarrier.getFnull().thenCompose(fnull ->
                        shapDataCarrier.getOutputSize().thenCompose(outputSize -> {
                            int numOutputs = outputSize;
                            List<CompletableFuture<double[][]>> perOutput = new ArrayList<>(numOutputs);
                            for (int o = 0; o < numOutputs; o++) {
                                int outputIndex = o;
                                perOutput.add(CompletableFuture.supplyAsync(
                                        () -> solve(expectationMatrix, outputIndex, modelOutput, fnull,
                                                eliminatedFeature, shapDataCarrier),
                                        this.config.getExecutor()));
                            }
                            return CompletableFuture.allOf(perOutput.toArray(new CompletableFuture<?>[0]))
                                    .thenApply(ignored -> {
                                        double[][][] result = new double[2][numOutputs][];
                                        for (int o = 0; o < numOutputs; o++) {
                                            double[][] solved = perOutput.get(o).join();
                                            result[0][o] = solved[0];
                                            result[1][o] = solved[1];
                                        }
                                        return result;
                                    });
                        })));
    }

    /**
     * Fits the weighted linear regression (no intercept) and scatters the coefficients back to
     * feature positions.
     *
     * <p>The eliminated feature receives {@code linkDelta} minus the sum of all coefficients, so
     * the values add up to {@code linkDelta}. Its confidence is the root-sum-square of the other
     * confidence values.
     */
    private double[][] runWLRR(double[][] design, double[] targets, double[] weights, double linkDelta,
            int eliminatedFeature, ShapDataCarrier shapDataCarrier) {
        WeightedLinearRegressionResults fit =
                WeightedLinearRegression.fit(design, targets, weights, false, this.config.getPC().getRandom());
        double[] coefficients = fit.getCoefficients();
        double[] conf = fit.getConf(1.0d - this.config.getConfidence());
        double[] values = new double[shapDataCarrier.getCols()];
        double[] confidences = new double[shapDataCarrier.getCols()];
        int coefficientIndex = 0;
        for (int v = 0; v < shapDataCarrier.getVaryingFeatureGroups().size(); v++) {
            int feature = shapDataCarrier.getVaryingFeatureGroups(v);
            if (feature != eliminatedFeature) {
                values[feature] = coefficients[coefficientIndex];
                confidences[feature] = conf[coefficientIndex];
                coefficientIndex++;
            }
        }
        values[eliminatedFeature] = linkDelta - Arrays.stream(coefficients).sum();
        confidences[eliminatedFeature] = Math.sqrt(Arrays.stream(conf).map(c -> c * c).sum());
        return new double[][] { values, confidences };
    }

    @Override // org.kie.kogito.explainability.local.LocalExplainer
    public CompletableFuture<Saliency[]> explainAsync(Prediction prediction, PredictionProvider predictionProvider) {
        return explainAsync(prediction, predictionProvider, null);
    }

    /** Explains {@code prediction}; the {@code consumer} argument is currently never invoked. */
    @Override // org.kie.kogito.explainability.local.LocalExplainer
    public CompletableFuture<Saliency[]> explainAsync(Prediction prediction, PredictionProvider predictionProvider,
            Consumer<Saliency[]> consumer) {
        return explain(prediction, predictionProvider);
    }
}
