package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.shap.ShapConfig;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.LarsPath;
import org.kie.kogito.explainability.utils.LassoLarsIC;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;
import org.kie.kogito.explainability.utils.RandomChoice;
import org.kie.kogito.explainability.utils.WeightedLinearRegression;
import org.kie.kogito.explainability.utils.WeightedLinearRegressionResults;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Kernel-SHAP style local explainer.
 *
 * <p>For a single {@link Prediction}, this explainer finds the features whose value differs from
 * the configured background data ("varying" features). It then builds synthetic samples from
 * boolean feature masks over those features and evaluates the model on them. Finally, it fits a
 * weighted linear regression whose coefficients are constrained to sum to
 * {@code link(f(x)) - link(fnull)}, where {@code fnull} is the mean model output over the
 * background data. The coefficients are reported as per-feature importances, together with a
 * confidence value derived from the regression.
 *
 * <p>All model evaluations are asynchronous; results are delivered as {@link CompletableFuture}s.
 */
public class ShapKernelExplainer implements LocalExplainer<ShapResults> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ShapKernelExplainer.class);

    /** Largest mask length that can be packed losslessly into a non-negative {@code int}. */
    private static final int MAX_EXACT_MASK_BITS = 31;

    private ShapConfig config;

    /**
     * Creates an explainer with the given configuration.
     *
     * @param shapConfig background data, sampling, link, regularization and execution settings
     */
    public ShapKernelExplainer(ShapConfig shapConfig) {
        this.config = shapConfig;
    }

    /**
     * Replaces the configuration used by subsequent explanations.
     *
     * @param shapConfig the new configuration
     */
    public void setConfig(ShapConfig shapConfig) {
        this.config = shapConfig;
    }

    /**
     * Builds the per-explanation state.
     *
     * <p>The state includes the background dimensions, the sample budget and, asynchronously, the
     * model's mean output on the background data.
     *
     * @param predictionProvider the model being explained
     * @return a carrier holding the background rows/cols, the model, the sample budget, and futures
     *         for the output size, the mean background output ({@code fnull}) and its linked value
     */
    private ShapDataCarrier initialize(PredictionProvider predictionProvider) {
        RealMatrix backgroundMatrix = this.config.getBackgroundMatrix();
        int rows = backgroundMatrix.getRowDimension();
        int cols = backgroundMatrix.getColumnDimension();
        if (rows > 100) {
            LOGGER.debug("Warning: Background data sets larger than 100 samples might be slow!");
        }

        CompletableFuture<RealMatrix> backgroundOutputs = predictionProvider
                .predictAsync(this.config.getBackground())
                .thenApply(MatrixUtilsExtensions::matrixFromPredictionOutput);
        CompletableFuture<Integer> outputSize = backgroundOutputs.thenApply(RealMatrix::getColumnDimension);
        // Summed background outputs divided by the number of background rows, i.e. the mean output.
        CompletableFuture<RealVector> fnull = backgroundOutputs
                .thenApply(outputs -> MatrixUtilsExtensions.rowSum(outputs).mapDivide(rows));
        CompletableFuture<RealVector> linkNull = fnull.thenApply(this::link);

        int numSamples = this.config.getNSamples().orElseGet(() -> 2048 + 2 * cols);
        if (cols <= 30) {
            // Only 2^cols - 2 non-trivial masks exist, so never budget more samples than that.
            numSamples = Math.min(numSamples, (1 << cols) - 2);
        }

        ShapDataCarrier sdc = new ShapDataCarrier();
        sdc.setRows(rows);
        sdc.setCols(cols);
        sdc.setOutputSize(outputSize);
        sdc.setModel(predictionProvider);
        sdc.setFnull(fnull);
        sdc.setLinkNull(linkNull);
        sdc.setNumSamples(numSamples);
        return sdc;
    }

    /**
     * Applies the configured link function to a scalar.
     *
     * @param value the raw model output
     * @return {@code value} for the IDENTITY link, otherwise {@code log(value / (1 - value))}
     */
    private double link(double value) {
        return this.config.getLink().equals(ShapConfig.LinkType.IDENTITY)
                ? value
                : Math.log(value / (1.0 - value));
    }

    /**
     * Applies the configured link function element-wise.
     *
     * @param values the raw model outputs
     * @return a new vector with {@link #link(double)} applied to every entry
     */
    private RealVector link(RealVector values) {
        return values.map(this::link);
    }

    /**
     * Records which features take more than one distinct value across the background column plus
     * the prediction input. Only these features participate in the coalition sampling.
     *
     * @param predictionInput the input being explained
     * @param sdc carrier that receives the varying feature indexes and their count
     */
    private void setVaryingFeatureGroups(PredictionInput predictionInput, ShapDataCarrier sdc) {
        List<Integer> varying = new ArrayList<>();
        RealVector inputVector = MatrixUtilsExtensions.vectorFromPredictionInput(predictionInput);
        RealMatrix background = this.config.getBackgroundMatrix();
        int rows = sdc.getRows();
        // Holds one background column followed by the input's value for that feature.
        RealVector columnWithInput = MatrixUtils.createRealVector(new double[rows + 1]);
        for (int col = 0; col < sdc.getCols(); col++) {
            columnWithInput.setSubVector(0, background.getColumnVector(col));
            columnWithInput.setEntry(rows, inputVector.getEntry(col));
            if (Arrays.stream(columnWithInput.toArray()).distinct().count() > 1) {
                varying.add(col);
            }
        }
        sdc.setVaryingFeatureGroups(varying);
        sdc.setNumVarying(varying.size());
    }

    /**
     * Scales a weight vector so that its entries sum to one.
     *
     * @param weights the weights to normalize
     * @return the normalized vector, or {@code weights} unchanged if the sum computation signals an
     *         arithmetic error (deliberate best-effort fallback)
     */
    private RealVector normalizeWeightVector(RealVector weights) {
        try {
            return weights.mapDivide(MatrixUtilsExtensions.sum(weights));
        } catch (MathArithmeticException ignored) {
            return weights;
        }
    }

    /**
     * Adds a synthetic sample for a coalition of varying features, or bumps the weight of an
     * identical, already-added mask.
     *
     * @param predictionInput the input being explained
     * @param subset positions (into the varying-feature list) of the coalition members
     * @param weight the initial weight of a newly added sample
     * @param complement if {@code true}, the mask includes every varying feature except {@code subset}
     * @param fixed whether the sample belongs to a fully enumerated subset size
     * @param sdc carrier holding the samples and the mask index
     * @return {@code true} if a new sample was added, {@code false} if the mask was a duplicate
     */
    private boolean addSample(PredictionInput predictionInput, List<Integer> subset, double weight,
            boolean complement, boolean fixed, ShapDataCarrier sdc) {
        boolean[] mask = new boolean[sdc.getCols()];
        if (complement) {
            for (int v = 0; v < sdc.getNumVarying(); v++) {
                mask[sdc.getVaryingFeatureGroups(v).intValue()] = true;
            }
        }
        for (Integer member : subset) {
            mask[sdc.getVaryingFeatureGroups(member.intValue()).intValue()] = !complement;
        }

        int key = hashMask(mask);
        if (sdc.getMasksUsed().containsKey(key)) {
            sdc.getSamplesAdded(sdc.getMasksUsed(key).intValue()).incrementWeight();
            return false;
        }
        ShapSyntheticDataSample sample =
                new ShapSyntheticDataSample(predictionInput, mask, this.config.getBackgroundMatrix(), weight, fixed);
        sdc.addMask(key, sdc.getSamplesAddedSize());
        sdc.addSample(sample);
        return true;
    }

    /**
     * Computes the integer key under which a mask is de-duplicated.
     *
     * <p>Masks of up to 31 entries are packed exactly (first entry = most significant bit), so
     * distinct masks get distinct keys. Longer masks cannot fit in an {@code int}: a saturating
     * cast would map most of them to {@code Integer.MAX_VALUE}. For those, {@link Arrays#hashCode}
     * is used instead. Collisions are then possible but no longer systematic.
     *
     * @param mask the feature inclusion mask
     * @return the de-duplication key
     */
    private int hashMask(boolean[] mask) {
        if (mask.length <= MAX_EXACT_MASK_BITS) {
            int packed = 0;
            for (boolean included : mask) {
                packed = (packed << 1) | (included ? 1 : 0);
            }
            return packed;
        }
        return Arrays.hashCode(mask);
    }

    /**
     * Converts a matrix of importances into one {@link Saliency} per output.
     *
     * @param realMatrix importances, one row per output and one column per input feature
     * @param predictionInput supplies the features, by column index
     * @param predictionOutput supplies the outputs, by row index
     * @return one saliency per row of {@code realMatrix}
     */
    public static Saliency[] saliencyFromMatrix(RealMatrix realMatrix, PredictionInput predictionInput,
            PredictionOutput predictionOutput) {
        Saliency[] saliencies = new Saliency[realMatrix.getRowDimension()];
        for (int row = 0; row < realMatrix.getRowDimension(); row++) {
            List<FeatureImportance> importances = new ArrayList<>();
            for (int col = 0; col < realMatrix.getColumnDimension(); col++) {
                importances.add(new FeatureImportance(predictionInput.getFeatures().get(col),
                        realMatrix.getEntry(row, col)));
            }
            saliencies[row] = new Saliency(predictionOutput.getOutputs().get(row), importances);
        }
        return saliencies;
    }

    /**
     * Converts matrices of importances and their confidences into one {@link Saliency} per output.
     *
     * @param realMatrix importances, one row per output and one column per input feature
     * @param realMatrix2 confidence values, same shape as {@code realMatrix}
     * @param predictionInput supplies the features, by column index
     * @param predictionOutput supplies the outputs, by row index
     * @return one saliency per row of {@code realMatrix}
     */
    public static Saliency[] saliencyFromMatrix(RealMatrix realMatrix, RealMatrix realMatrix2,
            PredictionInput predictionInput, PredictionOutput predictionOutput) {
        Saliency[] saliencies = new Saliency[realMatrix.getRowDimension()];
        for (int row = 0; row < realMatrix.getRowDimension(); row++) {
            List<FeatureImportance> importances = new ArrayList<>();
            for (int col = 0; col < realMatrix.getColumnDimension(); col++) {
                importances.add(new FeatureImportance(predictionInput.getFeatures().get(col),
                        realMatrix.getEntry(row, col), realMatrix2.getEntry(row, col)));
            }
            saliencies[row] = new Saliency(predictionOutput.getOutputs().get(row), importances);
        }
        return saliencies;
    }

    /**
     * Runs the full explanation pipeline for one prediction.
     *
     * @param prediction the prediction to explain
     * @param predictionProvider the model that produced it
     * @return a future of the saliencies combined with {@code fnull}
     * @throws IllegalArgumentException if the input feature count differs from the background's
     *         (thrown immediately), or if the output size differs (surfaced through the future)
     */
    private CompletableFuture<ShapResults> explain(Prediction prediction, PredictionProvider predictionProvider) {
        ShapDataCarrier sdc = initialize(predictionProvider);
        sdc.setSamplesAdded(new ArrayList<>());
        PredictionInput input = prediction.getInput();
        PredictionOutput output = prediction.getOutput();
        int cols = sdc.getCols();
        if (input.getFeatures().size() != cols) {
            throw new IllegalArgumentException(String.format(
                    "Prediction input feature count (%d) does not match background data feature count (%d)",
                    input.getFeatures().size(), cols));
        }

        // Validates the output size and yields an all-zero importance matrix of the right shape.
        CompletableFuture<RealMatrix> zeroSaliencies = sdc.getOutputSize().thenApply(outputSize -> {
            if (output.getOutputs().size() != outputSize.intValue()) {
                throw new IllegalArgumentException(String.format(
                        "Prediction output size (%d) does not match background data output size (%d)",
                        output.getOutputs().size(), outputSize));
            }
            return MatrixUtils.createRealMatrix(new double[outputSize][cols]);
        });
        RealVector modelOutput = MatrixUtilsExtensions.vectorFromPredictionOutput(output);
        setVaryingFeatureGroups(input, sdc);

        if (sdc.getNumVarying() == 0) {
            // No feature differs from the background: every attribution is zero.
            return zeroSaliencies
                    .thenApply(zeros -> saliencyFromMatrix(zeros, input, output))
                    .thenCombine(sdc.getFnull(), ShapResults::new);
        }

        if (sdc.getNumVarying() == 1) {
            // With one varying feature, the efficiency constraint fixes its attribution to
            // link(f(x)) - link(fnull). This is the same quantity solve() uses for the general case.
            int varyingFeature = sdc.getVaryingFeatureGroups(0).intValue();
            RealVector linkedModelOutput = link(modelOutput);
            CompletableFuture<RealVector> delta = sdc.getLinkNull().thenApply(linkedModelOutput::subtract);
            return zeroSaliencies
                    .thenCompose(validated -> delta.thenCombine(sdc.getOutputSize(), (outputDelta, outputSize) -> {
                        RealMatrix phi = MatrixUtils.createRealMatrix(new double[outputSize][cols]);
                        for (int o = 0; o < outputSize; o++) {
                            phi.setEntry(o, varyingFeature, outputDelta.getEntry(o));
                        }
                        return saliencyFromMatrix(phi, input, output);
                    }))
                    .thenCombine(sdc.getFnull(), ShapResults::new);
        }

        ShapStatistics stats = computeSubsetStatistics(sdc);
        initializeWeights(stats, sdc);
        addCompleteSubsets(stats, input, sdc);
        renormalizeWeights(stats);
        addNonCompleteSubsets(stats, input, sdc);
        CompletableFuture<RealMatrix> syntheticOutputs = runSyntheticData(sdc);
        return zeroSaliencies
                .thenCompose(validated -> solveSystem(syntheticOutputs, modelOutput, sdc))
                .thenApply(phiAndConf -> saliencyFromMatrix(phiAndConf[0], phiAndConf[1], input, output))
                .thenCombine(sdc.getFnull(), ShapResults::new);
    }

    /**
     * Computes how many subset sizes are considered, how many of them are paired with their
     * complement, and how many subsets exist at each size.
     *
     * @param sdc carrier providing the varying-feature count and sample budget
     * @return the subset statistics; subset counts too large for an {@code int} are saturated at
     *         {@link Integer#MAX_VALUE}, which can never be fully enumerated
     */
    private ShapStatistics computeSubsetStatistics(ShapDataCarrier sdc) {
        int numVarying = sdc.getNumVarying();
        int numSubsetSizes = (int) Math.ceil((numVarying - 1) / 2.0);
        // For an even count, the middle size is its own complement, so it is not paired.
        int largestPairedSubsetSize = numVarying % 2 == 1 ? numSubsetSizes : numSubsetSizes - 1;
        int[] numSubsetsAtSize = new int[numSubsetSizes + 1];
        for (int size = 1; size <= numSubsetSizes; size++) {
            long count;
            try {
                count = CombinatoricsUtils.binomialCoefficient(numVarying, size);
            } catch (MathArithmeticException tooLargeForLong) {
                count = Long.MAX_VALUE;
            }
            // Saturate instead of truncating: a bare (int) cast can yield garbage or negative counts.
            numSubsetsAtSize[size] = (int) Math.min(count, Integer.MAX_VALUE);
        }
        return new ShapStatistics(numSubsetSizes, largestPairedSubsetSize, numSubsetsAtSize,
                sdc.getNumSamples().intValue());
    }

    /**
     * Initializes the per-subset-size weights. Index {@code k} holds subset size {@code k}; index 0
     * is unused and stays 0.
     *
     * @param stats receives the normalized weights and a working copy of them
     * @param sdc carrier providing the varying-feature count
     */
    private void initializeWeights(ShapStatistics stats, ShapDataCarrier sdc) {
        int numVarying = sdc.getNumVarying();
        double[] weights = new double[stats.getNumSubsetSizes() + 1];
        for (int size = 1; size <= stats.getNumSubsetSizes(); size++) {
            double weight = (numVarying - 1.0) / (size * (numVarying - size));
            if (size <= stats.getLargestPairedSubsetSize()) {
                // Paired sizes also account for their complement's weight.
                weight *= 2.0;
            }
            weights[size] = weight;
        }
        RealVector normalized = normalizeWeightVector(MatrixUtils.createRealVector(weights));
        stats.setWeightOfSubsetSize(normalized);
        stats.setRemainingWeights(normalized.copy());
    }

    /**
     * Fully enumerates subset sizes, smallest first, while the remaining budget can afford them.
     *
     * @param stats subset statistics; updated with the full-subset count and the remaining budget
     * @param predictionInput the input being explained
     * @param sdc carrier that receives the samples and the mask index
     */
    private void addCompleteSubsets(ShapStatistics stats, PredictionInput predictionInput, ShapDataCarrier sdc) {
        sdc.setMasksUsed(new HashMap<>());
        for (int size = 1; size <= stats.getNumSubsetSizes(); size++) {
            boolean paired = size <= stats.getLargestPairedSubsetSize();
            // long arithmetic: the subset count may be saturated at Integer.MAX_VALUE.
            long samplesNeeded = (long) stats.getNumSubsetsAtSize()[size] * (paired ? 2 : 1);
            if (stats.getNumSamplesRemaining() * stats.getRemainingWeights().getEntry(size) < samplesNeeded) {
                return;
            }
            stats.incrementNumFullSubsets();
            // Narrowing is safe: samplesNeeded fits inside the remaining (int) budget at this point.
            stats.decreaseNumSamplesRemainingBy((int) samplesNeeded);
            RealVector remainingWeights = stats.getRemainingWeights();
            remainingWeights.setEntry(size, 0.0);
            stats.setRemainingWeights(normalizeWeightVector(remainingWeights));

            double weightPerSample = stats.getWeightOfSubsetSize().getEntry(size) / samplesNeeded;
            Iterator<int[]> combinations = CombinatoricsUtils.combinationsIterator(sdc.getNumVarying(), size);
            while (combinations.hasNext()) {
                List<Integer> subset = Arrays.stream(combinations.next()).boxed().collect(Collectors.toList());
                addSample(predictionInput, subset, weightPerSample, false, true, sdc);
                if (paired) {
                    addSample(predictionInput, subset, weightPerSample, true, true, sdc);
                }
            }
        }
    }

    /**
     * Computes the sampling probabilities for the subset sizes that were not fully enumerated.
     *
     * <p>Each draw of a paired size yields two samples (the subset and its complement), so paired
     * sizes have their weight halved before renormalization.
     *
     * @param stats provides the size weights and receives the final remaining weights
     */
    private void renormalizeWeights(ShapStatistics stats) {
        int largestPaired = stats.getLargestPairedSubsetSize();
        RealVector weights = stats.getWeightOfSubsetSize();
        double[] divisors = IntStream.range(0, weights.getDimension())
                .mapToDouble(size -> size <= largestPaired ? 2.0 : 1.0)
                .toArray();
        // ebeDivide returns a new vector; the source weights are left untouched.
        RealVector perDrawWeights = weights.ebeDivide(MatrixUtils.createRealVector(divisors));
        int numFull = stats.getNumFullSubsets();
        stats.setFinalRemainingWeights(normalizeWeightVector(
                perDrawWeights.getSubVector(numFull + 1, stats.getNumSubsetSizes() - numFull)));
    }

    /**
     * Spends the remaining sample budget on randomly drawn coalitions of the sizes that were not
     * fully enumerated.
     *
     * @param stats subset statistics; the remaining budget is consumed
     * @param predictionInput the input being explained
     * @param sdc carrier that receives the samples
     */
    private void addNonCompleteSubsets(ShapStatistics stats, PredictionInput predictionInput, ShapDataCarrier sdc) {
        if (stats.getNumFullSubsets() >= stats.getNumSubsetSizes()) {
            return;
        }
        List<Integer> candidateSizes = IntStream
                .range(stats.getNumFullSubsets() + 1, stats.getNumSubsetSizes() + 1)
                .boxed()
                .collect(Collectors.toList());
        List<Double> sizeWeights = Arrays.stream(stats.getFinalRemainingWeights().toArray())
                .boxed()
                .collect(Collectors.toList());
        RandomChoice<Integer> sizeChooser = new RandomChoice<>(candidateSizes, sizeWeights);
        List<Integer> draws = sizeChooser.sample(stats.getNumSamplesRemaining() * 4, this.config.getPC().getRandom());
        List<Integer> featureOrder = IntStream.range(0, sdc.getNumVarying()).boxed().collect(Collectors.toList());

        int cursor = 0;
        while (stats.getNumSamplesRemaining() > 0) {
            if (cursor >= draws.size()) {
                draws = sizeChooser.sample(stats.getNumSamplesRemaining() * 4, this.config.getPC().getRandom());
                cursor = 0;
            }
            int size = draws.get(cursor++);
            // Shuffle with the configured random source instead of Collections' hidden global one.
            Collections.shuffle(featureOrder, this.config.getPC().getRandom());
            List<Integer> subset = featureOrder.subList(0, size);
            if (addSample(predictionInput, subset, 1.0, false, false, sdc)) {
                stats.decreaseNumSamplesRemainingBy(1);
            }
            if (stats.getNumSamplesRemaining() > 0 && size <= stats.getLargestPairedSubsetSize()
                    && addSample(predictionInput, subset, 1.0, true, false, sdc)) {
                stats.decreaseNumSamplesRemainingBy(1);
            }
        }
        normalizeSampleWeights(stats, sdc);
    }

    /**
     * Rescales the weights of the randomly drawn (non-fixed) samples.
     *
     * <p>After rescaling, their total equals the weight of the subset sizes that were not fully
     * enumerated.
     *
     * @param stats provides the size weights and the full-subset count
     * @param sdc carrier holding the samples
     */
    private void normalizeSampleWeights(ShapStatistics stats, ShapDataCarrier sdc) {
        int numFull = stats.getNumFullSubsets();
        // Index k holds subset size k, so the sampled sizes live at indexes numFull+1..numSubsetSizes.
        double weightLeft = MatrixUtilsExtensions.sum(stats.getWeightOfSubsetSize()
                .getSubVector(numFull + 1, stats.getNumSubsetSizes() - numFull));
        int numSamplesAdded = sdc.getSamplesAddedSize().intValue();
        double sampledWeightTotal = 0.0;
        for (int s = 0; s < numSamplesAdded; s++) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(s);
            if (!sample.isFixed()) {
                sampledWeightTotal += sample.getWeight();
            }
        }
        if (sampledWeightTotal == 0.0) {
            return;
        }
        for (int s = 0; s < numSamplesAdded; s++) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(s);
            if (!sample.isFixed()) {
                sample.setWeight((sample.getWeight() * weightLeft) / sampledWeightTotal);
            }
        }
    }

    /**
     * Evaluates the model on every synthetic sample.
     *
     * @param sdc carrier holding the samples, the model and the linked null output
     * @return a future matrix with one row per sample and one column per output, containing
     *         {@code link(mean model output over the sample's rows) - linkNull}
     */
    private CompletableFuture<RealMatrix> runSyntheticData(ShapDataCarrier sdc) {
        if (this.config.getBatchSize() <= 1) {
            return sdc.getLinkNull().thenCompose(linkNull -> sdc.getOutputSize().thenCompose(outputSize -> {
                int numSamplesAdded = sdc.getSamplesAddedSize().intValue();
                List<CompletableFuture<RealVector>> rowFutures = new ArrayList<>(numSamplesAdded);
                for (int s = 0; s < numSamplesAdded; s++) {
                    rowFutures.add(sdc.getModel().predictAsync(sdc.getSamplesAdded(s).getSyntheticData())
                            .thenApply(MatrixUtilsExtensions::matrixFromPredictionOutput)
                            .thenApply(m -> MatrixUtilsExtensions.rowSum(m).mapDivide(m.getRowDimension()))
                            .thenApply(this::link)
                            .thenApply(v -> v.subtract(linkNull)));
                }
                CompletableFuture<RealMatrix> result = CompletableFuture.supplyAsync(
                        () -> MatrixUtils.createRealMatrix(new double[numSamplesAdded][outputSize]),
                        this.config.getExecutor());
                for (int s = 0; s < rowFutures.size(); s++) {
                    int row = s;
                    CompletableFuture<RealVector> rowFuture = rowFutures.get(s);
                    result = result.thenCompose(acc -> rowFuture.thenApply(rowValues -> {
                        acc.setRowVector(row, rowValues);
                        return acc;
                    }));
                }
                return result;
            }));
        }

        int batchSize = this.config.getBatchSize();
        return sdc.getLinkNull().thenCompose(linkNull -> sdc.getOutputSize().thenCompose(outputSize -> {
            int numSamplesAdded = sdc.getSamplesAddedSize().intValue();
            CompletableFuture<RealMatrix> result = CompletableFuture.supplyAsync(
                    () -> MatrixUtils.createRealMatrix(new double[numSamplesAdded][outputSize]),
                    this.config.getExecutor());
            int start = 0;
            while (start < numSamplesAdded) {
                int batchStart = start;
                // Written this way so that a very large batch size cannot overflow.
                int batchEnd = batchStart + Math.min(batchSize, numSamplesAdded - batchStart);
                List<PredictionInput> batchInputs = new ArrayList<>();
                for (int s = batchStart; s < batchEnd; s++) {
                    batchInputs.addAll(sdc.getSamplesAdded(s).getSyntheticData());
                }
                result = sdc.getModel().predictAsync(batchInputs)
                        .thenApply(MatrixUtilsExtensions::matrixFromPredictionOutput)
                        .thenApply(m -> MatrixUtilsExtensions.batchRowMean(m, sdc.getRows()))
                        .thenCombine(result, (batchMeans, acc) -> {
                            for (int r = 0; r < batchMeans.getRowDimension(); r++) {
                                acc.setRowVector(batchStart + r,
                                        batchMeans.getRowVector(r).map(this::link).subtract(linkNull));
                            }
                            return acc;
                        });
                start = batchEnd;
            }
            return result;
        }));
    }

    /**
     * Selects the feature indexes to keep according to the configured regularizer.
     *
     * @param realMatrix the (weighted, augmented) design matrix
     * @param realVector the (weighted, augmented) targets
     * @return the selected column indexes; empty for regularizer types not handled below
     * @throws IllegalArgumentException if called with {@code RegularizerType.NONE}
     */
    private List<Integer> getRegularizationIndexes(RealMatrix realMatrix, RealVector realVector) {
        switch (this.config.getRegularizerType()) {
            case AUTO:
            case AIC:
                return MatrixUtilsExtensions.nonzero(
                        LassoLarsIC.fit(realMatrix, realVector, LassoLarsIC.Criterion.AIC).getCoefs());
            case BIC:
                return MatrixUtilsExtensions.nonzero(
                        LassoLarsIC.fit(realMatrix, realVector, LassoLarsIC.Criterion.BIC).getCoefs());
            case TOP_N_FEATURES:
                return LarsPath.fit(realMatrix, realVector,
                        this.config.getNRegularizationFeatures().intValue(), false).getActive();
            case NONE:
                throw new IllegalArgumentException("RegularizerType=NONE will never be able enter the switch statement");
            default:
                return List.of();
        }
    }

    /**
     * Builds the augmented, square-root-weighted regression problem used for feature selection,
     * then runs the configured regularizer on it.
     *
     * @param masks sample-by-feature 0/1 mask matrix
     * @param sampleWeights per-sample weights
     * @param targets per-sample linked outputs minus the linked null output
     * @param linkDelta {@code link(f(x)) - link(fnull)} for the output being solved
     * @param sdc carrier providing the varying-feature count
     * @return the selected feature indexes
     */
    private List<Integer> selectRegularizedFeatures(RealMatrix masks, RealVector sampleWeights,
            RealVector targets, double linkDelta, ShapDataCarrier sdc) {
        int numVarying = sdc.getNumVarying();
        int n = sampleWeights.getDimension();
        // One entry per sample (it is multiplied element-wise with the per-sample weights).
        RealVector maskSizes = MatrixUtilsExtensions.colSum(masks);
        RealVector augmentedWeights = MatrixUtils.createRealVector(new double[2 * n]);
        augmentedWeights.setSubVector(0, sampleWeights.ebeMultiply(maskSizes.map(size -> numVarying - size)));
        augmentedWeights.setSubVector(n, sampleWeights.ebeMultiply(maskSizes));
        RealVector sqrtWeights = augmentedWeights.map(Math::sqrt);

        RealVector augmentedTargets = MatrixUtils.createRealVector(new double[2 * targets.getDimension()]);
        augmentedTargets.setSubVector(0, targets);
        augmentedTargets.setSubVector(targets.getDimension(), targets.mapSubtract(linkDelta));
        RealVector weightedTargets = augmentedTargets.ebeMultiply(sqrtWeights);

        RealMatrix augmentedMasks =
                MatrixUtils.createRealMatrix(masks.getRowDimension() * 2, masks.getColumnDimension());
        augmentedMasks.setSubMatrix(masks.getData(), 0, 0);
        augmentedMasks.setSubMatrix(MatrixUtilsExtensions.map(masks, v -> v - 1.0).getData(),
                masks.getRowDimension(), 0);
        RealMatrix weightedMasks =
                MatrixUtilsExtensions.vectorRowProduct(augmentedMasks.transpose(), sqrtWeights).transpose();
        return getRegularizationIndexes(weightedMasks, weightedTargets);
    }

    /**
     * Solves for the attributions of one model output.
     *
     * @param realMatrix synthetic outputs (one row per sample, one column per output)
     * @param i the output column to solve for
     * @param realVector the model output for the explained input
     * @param realVector2 the mean background output ({@code fnull})
     * @param sdc carrier holding the samples
     * @return {@code [attributions, confidences]}, each with one entry per input feature
     */
    private RealVector[] solve(RealMatrix realMatrix, int i, RealVector realVector, RealVector realVector2,
            ShapDataCarrier sdc) {
        int numSamplesAdded = sdc.getSamplesAddedSize().intValue();
        int cols = sdc.getCols();
        RealMatrix masks = MatrixUtils.createRealMatrix(new double[numSamplesAdded][cols]);
        RealVector sampleWeights = MatrixUtils.createRealVector(new double[numSamplesAdded]);
        RealVector targets = MatrixUtils.createRealVector(new double[numSamplesAdded]);
        for (int s = 0; s < numSamplesAdded; s++) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(s);
            for (int c = 0; c < cols; c++) {
                masks.setEntry(s, c, sample.getMask()[c] ? 1.0 : 0.0);
            }
            targets.setEntry(s, realMatrix.getEntry(s, i));
            sampleWeights.setEntry(s, sample.getWeight());
        }

        double fractionEvaluated = numSamplesAdded / Math.pow(2.0, cols);
        double linkDelta = link(realVector.getEntry(i)) - link(realVector2.getEntry(i));
        ShapConfig.RegularizerType regularizer = this.config.getRegularizerType();
        boolean autoRegularize = fractionEvaluated < 0.2 && regularizer == ShapConfig.RegularizerType.AUTO;
        boolean explicitRegularize = regularizer != ShapConfig.RegularizerType.NONE
                && regularizer != ShapConfig.RegularizerType.AUTO;

        List<Integer> selected = autoRegularize || explicitRegularize
                ? selectRegularizedFeatures(masks, sampleWeights, targets, linkDelta, sdc)
                : sdc.getVaryingFeatureGroups();

        if (selected.isEmpty()) {
            // Regularization kept no feature. Previously this threw IndexOutOfBoundsException.
            // Mirror reference Kernel SHAP: zero attributions, unit uncertainty.
            RealVector zeros = MatrixUtils.createRealVector(new double[cols]);
            return new RealVector[]{zeros, zeros.mapAdd(1.0)};
        }
        int droppedIndex = selected.get(selected.size() - 1).intValue();
        if (selected.size() == 1) {
            // No free coefficients remain: the efficiency constraint alone gives the attribution,
            // with zero confidence width (what runWLRR yields for an empty coefficient set).
            RealVector phi = MatrixUtils.createRealVector(new double[cols]);
            phi.setEntry(droppedIndex, linkDelta);
            return new RealVector[]{phi, MatrixUtils.createRealVector(new double[cols])};
        }

        // Eliminate the last selected feature using the constraint that attributions sum to linkDelta.
        RealVector droppedColumn = masks.getColumnVector(droppedIndex);
        RealMatrix design = MatrixUtilsExtensions.vectorDifference(
                MatrixUtilsExtensions.getCols(masks, selected.subList(0, selected.size() - 1)),
                droppedColumn, MatrixUtilsExtensions.Axis.COLUMN);
        RealVector adjustedTargets = targets.subtract(droppedColumn.mapMultiply(linkDelta));
        return runWLRR(design, adjustedTargets, sampleWeights, linkDelta, droppedIndex, selected, sdc);
    }

    /**
     * Solves every model output in parallel on the configured executor.
     *
     * @param completableFuture future of the synthetic outputs matrix
     * @param realVector the model output for the explained input
     * @param sdc carrier holding the samples and {@code fnull}
     * @return a future of {@code [attributions, confidences]}, each with one row per output and one
     *         column per input feature
     */
    private CompletableFuture<RealMatrix[]> solveSystem(CompletableFuture<RealMatrix> completableFuture,
            RealVector realVector, ShapDataCarrier sdc) {
        return completableFuture.thenCompose(outputs -> sdc.getFnull().thenCompose(fnull ->
                sdc.getOutputSize().thenCompose(outputSize -> {
                    int numOutputs = outputSize.intValue();
                    List<CompletableFuture<RealVector[]>> perOutput = new ArrayList<>(numOutputs);
                    for (int o = 0; o < numOutputs; o++) {
                        int outputIndex = o;
                        perOutput.add(CompletableFuture.supplyAsync(
                                () -> solve(outputs, outputIndex, realVector, fnull, sdc),
                                this.config.getExecutor()));
                    }
                    RealMatrix blank = MatrixUtils.createRealMatrix(new double[numOutputs][sdc.getCols()]);
                    CompletableFuture<RealMatrix[]> result = CompletableFuture.supplyAsync(
                            () -> new RealMatrix[]{blank.copy(), blank.copy()}, this.config.getExecutor());
                    for (int o = 0; o < perOutput.size(); o++) {
                        int row = o;
                        CompletableFuture<RealVector[]> rowFuture = perOutput.get(o);
                        result = result.thenCompose(acc -> rowFuture.thenApply(phiAndConf -> {
                            acc[0].setRowVector(row, phiAndConf[0]);
                            acc[1].setRowVector(row, phiAndConf[1]);
                            return acc;
                        }));
                    }
                    return result;
                })));
    }

    /**
     * Fits the weighted linear regression and scatters the results back to feature positions.
     *
     * @param realMatrix design matrix (selected features minus the dropped one)
     * @param realVector adjusted targets
     * @param realVector2 per-sample weights
     * @param d {@code link(f(x)) - link(fnull)}; the attributions are constrained to sum to it
     * @param i index of the feature eliminated via the constraint
     * @param list selected feature indexes, in design-matrix column order (plus {@code i})
     * @param sdc carrier providing the feature count
     * @return {@code [attributions, confidences]}, each with one entry per input feature
     */
    private RealVector[] runWLRR(RealMatrix realMatrix, RealVector realVector, RealVector realVector2, double d,
            int i, List<Integer> list, ShapDataCarrier sdc) {
        WeightedLinearRegressionResults fit = WeightedLinearRegression.fit(realMatrix, realVector, realVector2, false);
        RealVector coefficients = fit.getCoefficients();
        RealVector conf = fit.getConf(1.0 - this.config.getConfidence());

        RealVector phi = MatrixUtils.createRealVector(new double[sdc.getCols()]);
        RealVector phiConf = phi.copy();
        int coefficientIndex = 0;
        for (Integer feature : list) {
            int index = feature.intValue();
            if (index != i) {
                phi.setEntry(index, coefficients.getEntry(coefficientIndex));
                phiConf.setEntry(index, conf.getEntry(coefficientIndex));
                coefficientIndex++;
            }
        }
        // The dropped feature absorbs whatever the fitted coefficients do not explain.
        phi.setEntry(i, d - MatrixUtilsExtensions.sum(coefficients));
        phiConf.setEntry(i, Math.sqrt(MatrixUtilsExtensions.sum(conf.map(v -> v * v))));
        return new RealVector[]{phi, phiConf};
    }

    @Override // org.kie.kogito.explainability.local.LocalExplainer
    public CompletableFuture<ShapResults> explainAsync(Prediction prediction, PredictionProvider predictionProvider) {
        return explainAsync(prediction, predictionProvider, null);
    }

    /**
     * Explains a prediction asynchronously.
     *
     * @param prediction the prediction to explain
     * @param predictionProvider the model that produced it
     * @param consumer intermediate-results consumer; not invoked by this explainer, which produces
     *        only a final result
     * @return a future of the SHAP results
     */
    @Override // org.kie.kogito.explainability.local.LocalExplainer
    public CompletableFuture<ShapResults> explainAsync(Prediction prediction, PredictionProvider predictionProvider,
            Consumer<ShapResults> consumer) {
        return explain(prediction, predictionProvider);
    }
}
