package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.Precision;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/kogito/explainability/utils/LarsPath.class */
/**
 * Computes a Least Angle Regression (LARS) path, optionally with the lasso modification that
 * drops active regressors whose coefficients would change sign.
 *
 * <p>All mutable state lives in a {@code LarsPathDataCarrier}. {@link #fit} drives the iteration
 * loop, and each private helper performs one stage of an iteration against that carrier.
 * NOTE(review): the step structure appears to mirror scikit-learn's {@code lars_path} without a
 * precomputed Gram matrix; confirm against that reference before changing the numerics.
 */
public class LarsPath {
    private static final Logger LOGGER = LoggerFactory.getLogger(LarsPath.class);

    /** Decimal places kept when rounding covariances, presumably to stabilise arg-max selection. */
    private static final int COV_PRECISION = 16;

    /** Smallest accepted last diagonal entry of the Cholesky factor when a regressor is added. */
    private static final double DEGENERACY_THRESHOLD = 1.0E-7d;

    /** Relative symmetry threshold for the Cholesky refactorization after a drop. */
    private static final double DROP_SYMMETRY_THRESHOLD = 1.0E-16d;

    /** Absolute positivity threshold for the Cholesky refactorization after a drop. */
    private static final double DROP_POSITIVITY_THRESHOLD = -1.0E-12d;

    /**
     * Pivot threshold for the factorization that appends a candidate regressor.
     *
     * <p>This is deliberately permissive. A (near-)singular candidate then produces a tiny or NaN
     * last diagonal, which {@link #checkRegressorDegeneracy} rejects. With Commons Math's default
     * pivot threshold (1e-10), every diagonal below roughly 1e-5 made the constructor throw, so the
     * 1e-7 degeneracy check could never be reached.
     */
    private static final double ADD_POSITIVITY_THRESHOLD = Double.NEGATIVE_INFINITY;

    private LarsPath() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Picks the covariance entry with the largest absolute value and records this iteration's alpha.
     *
     * <p>Sets {@code cIdx} to the arg-max of {@code |cov|} (rounded to {@link #COV_PRECISION}
     * decimals), {@code C_} to that signed entry and {@code C} to its absolute value. When
     * {@code cov} is empty, {@code C_ = 0} and {@code cIdx = 0}. Finally it stores
     * {@code alphas[nIter] = C / nSamples}.
     */
    private static void updateCovarianceTrackers(LarsPathDataCarrier carrier) {
        RealVector cov = carrier.getCov();
        if (cov.getDimension() > 0) {
            carrier.setcIdx(cov.map(Math::abs).map(d -> Precision.round(d, COV_PRECISION)).getMaxIndex());
            carrier.setC_(cov.getEntry(carrier.getcIdx()));
        } else {
            carrier.setC_(0.0d);
            carrier.setcIdx(0);
        }
        carrier.setC(Math.abs(carrier.getC_()));
        carrier.getAlphas().setEntry(carrier.getnIter(), carrier.getC() / carrier.getnSamples());
    }

    /**
     * Records whether the regressor just appended to the Cholesky factor is degenerate.
     *
     * <p>When {@code lastDiagonal} is below {@link #DEGENERACY_THRESHOLD} (or is NaN), this method:
     * <ul>
     *   <li>logs a warning;</li>
     *   <li>restores the un-shortened covariance vector;</li>
     *   <li>zeroes the candidate's entry (index 0) and swaps it back to index {@code cIdx};</li>
     *   <li>flags the carrier as degenerate.</li>
     * </ul>
     * Otherwise it clears the flag.
     *
     * @param lastDiagonal last diagonal entry of the Cholesky factor of the extended Gram matrix
     * @param carrier      LARS state to update
     */
    private static void checkRegressorDegeneracy(double lastDiagonal, LarsPathDataCarrier carrier) {
        if (lastDiagonal >= DEGENERACY_THRESHOLD) {
            carrier.setDegenerateRegressor(false);
            return;
        }
        LOGGER.warn("Regressors in active set degenerate. Dropping a regressor, after {} iterations. "
                + "Reduce max_iter or increase eps parameters.", carrier.getnIter());
        carrier.setCov(carrier.getCovNotShortened());
        carrier.getCov().setEntry(0, 0.0d);
        MatrixUtilsExtensions.swap(carrier.getCov(), carrier.getcIdx(), 0);
        carrier.setDegenerateRegressor(true);
    }

    /**
     * Returns {@code A * A^T}, where {@code A} holds the leading rows of {@code XT}.
     *
     * @param carrier          LARS state
     * @param includeCandidate if true, use rows {@code 0..nActive}; otherwise rows {@code 0..nActive-1}
     * @return the Gram matrix of the selected rows
     */
    private static RealMatrix computeGram(LarsPathDataCarrier carrier, boolean includeCandidate) {
        int lastRow = includeCandidate ? carrier.getnActive() : carrier.getnActive() - 1;
        RealMatrix rows = carrier.getXT().getSubMatrix(0, lastRow, 0, carrier.getnSamples() - 1);
        return MatrixUtilsExtensions.matrixDot(rows, rows.transpose());
    }

    /**
     * Refreshes the Cholesky factorization of the active-set Gram matrix.
     *
     * <p>If the previous iteration flagged a drop, the active set is not grown. The Gram matrix of
     * the first {@code nActive} rows of {@code XT} is simply refactorized.
     *
     * <p>Otherwise the candidate chosen by {@link #updateCovarianceTrackers} is moved into row
     * {@code nActive}, and the Gram matrix of rows {@code 0..nActive} is factorized. The candidate
     * then joins the active set unless {@link #checkRegressorDegeneracy} rejects it. A rejected
     * candidate has its row swaps undone, so {@code XT} and {@code indices} stay aligned with the
     * restored covariance vector.
     */
    private static void getCholeskyDecomposition(LarsPathDataCarrier carrier) {
        int nActive = carrier.getnActive();
        if (carrier.isDrop()) {
            carrier.setDecomp(new CholeskyDecomposition(
                    computeGram(carrier, false), DROP_SYMMETRY_THRESHOLD, DROP_POSITIVITY_THRESHOLD));
            return;
        }
        int cIdx = carrier.getcIdx();
        carrier.getSignActive().setEntry(nActive, Math.signum(carrier.getC_()));
        // cov indexes the rows nActive.. of XT, so cov index cIdx is XT row cIdx + nActive
        int candidateRow = cIdx + nActive;
        MatrixUtilsExtensions.swap(carrier.getCov(), cIdx, 0);
        MatrixUtilsExtensions.swap(carrier.getIndices(), candidateRow, nActive);
        MatrixUtilsExtensions.swap(carrier.getXT(), candidateRow, nActive);
        carrier.setX(carrier.getXT().transpose());
        carrier.setCovNotShortened(carrier.getCov().copy());
        // the candidate's entry (now at index 0) leaves the covariance vector
        carrier.setCov(carrier.getCov().getSubVector(1, carrier.getCov().getDimension() - 1));
        CholeskyDecomposition decomposition = new CholeskyDecomposition(computeGram(carrier, true),
                CholeskyDecomposition.DEFAULT_RELATIVE_SYMMETRY_THRESHOLD, ADD_POSITIVITY_THRESHOLD);
        carrier.setDecomp(decomposition);
        RealMatrix l = decomposition.getL();
        checkRegressorDegeneracy(l.getEntry(l.getRowDimension() - 1, l.getColumnDimension() - 1), carrier);
        if (carrier.isDegenerateRegressor()) {
            // swapping again undoes the row swaps, re-aligning XT/indices with the restored cov vector
            MatrixUtilsExtensions.swap(carrier.getIndices(), candidateRow, nActive);
            MatrixUtilsExtensions.swap(carrier.getXT(), candidateRow, nActive);
            carrier.setX(carrier.getXT().transpose());
            return;
        }
        carrier.getActive().add(carrier.getIndices()[nActive]);
        carrier.setnActive(nActive + 1);
    }

    /**
     * Stops the path once {@code alphas[nIter]} is no longer above the equality tolerance.
     *
     * <p>If alpha has overshot below {@code -tolerance}, the coefficients of row {@code nIter} are
     * linearly interpolated back towards row {@code nIter - 1}, to the point where alpha is exactly
     * 0: {@code coef = prev + ss * (coef - prev)} with {@code ss = prevAlpha / (prevAlpha - alpha)}.
     * Alpha is then set to 0.
     *
     * @return true when the path should stop; in that case row {@code nIter} of {@code coefs} is written
     */
    private static boolean minimumAlphaBreakCondition(LarsPathDataCarrier carrier) {
        int nIter = carrier.getnIter();
        RealVector alphas = carrier.getAlphas();
        double tolerance = carrier.getEqualityTolerance();
        double alpha = alphas.getEntry(nIter);
        if (alpha > tolerance) {
            return false;
        }
        RealMatrix coefs = carrier.getCoefs();
        RealVector coef = coefs.getRowVector(nIter);
        if (Math.abs(alpha) > tolerance) {
            if (nIter > 0) {
                RealVector prevCoef = coefs.getRowVector(nIter - 1);
                double prevAlpha = alphas.getEntry(nIter - 1);
                double ss = prevAlpha / (prevAlpha - alpha);
                // linear interpolation; previously computed (prevCoef + ss) .* (coef - prevCoef)
                coef = prevCoef.add(coef.subtract(prevCoef).mapMultiply(ss));
            }
            alphas.setEntry(nIter, 0.0d);
        }
        coefs.setRowVector(nIter, coef);
        return true;
    }

    /** True once the iteration cap is reached or every feature is active. */
    private static boolean maximumIterationBreakCondition(LarsPathDataCarrier carrier) {
        return carrier.getnIter() >= carrier.getMaxIterations() || carrier.getnActive() >= carrier.getnFeatures();
    }

    /**
     * Early-stops the path (lasso mode only, per {@link #fit}) when alpha increased compared with
     * the previous iteration, and logs a warning.
     *
     * @return true if the path should stop
     */
    private static boolean earlyStoppingBreakCondition(LarsPathDataCarrier carrier) {
        RealVector alphas = carrier.getAlphas();
        int nIter = carrier.getnIter();
        int nActive = carrier.getnActive();
        if (nIter <= 0 || alphas.getEntry(nIter - 1) >= alphas.getEntry(nIter)) {
            return false;
        }
        if (LOGGER.isWarnEnabled()) {
            LOGGER.warn(String.format("Early stopping the lars path, as the residues are small and the current value of alpha is no longer well controlled. %d iterations, alpha=%.3f, previous alpha=%.3f, with an active set of %d regressors.",
                    nIter, alphas.getEntry(nIter), alphas.getEntry(nIter - 1), nActive));
        }
        return true;
    }

    /**
     * Solves the Gram system for the active signs and normalizes the solution.
     *
     * <p>Stores {@code leastSquares = AA * G^-1 s} and {@code normalizationFactor = AA}, where
     * {@code AA = 1 / sqrt(sum((G^-1 s) .* s))}. A single zero solution is replaced by 1 with
     * {@code AA = 1}.
     */
    private static void getNormalizedLeastSquares(LarsPathDataCarrier carrier) {
        RealVector activeSigns = carrier.getSignActive().getSubVector(0, carrier.getnActive());
        RealVector leastSquares = carrier.getDecomp().getSolver().solve(activeSigns);
        double normalization;
        if (leastSquares.getDimension() == 1 && leastSquares.getEntry(0) == 0.0d) {
            leastSquares.setEntry(0, 1.0d);
            normalization = 1.0d;
        } else {
            // DoubleStream.sum is kept (rather than dotProduct) to preserve its summation behaviour
            normalization = 1.0d / Math.sqrt(Arrays.stream(leastSquares.ebeMultiply(activeSigns).toArray()).sum());
            leastSquares.mapMultiplyToSelf(normalization);
        }
        carrier.setLeastSquares(leastSquares);
        carrier.setNormalizationFactor(normalization);
    }

    /**
     * Computes how the inactive rows of {@code XT} correlate with the equiangular direction
     * {@code X_active * leastSquares}, and the step length {@code gamma} along that direction.
     *
     * <p>{@code gamma} is the minimum of three quantities:
     * <ul>
     *   <li>{@code C / AA};</li>
     *   <li>the {@code MatrixUtilsExtensions.minPos} of {@code (C - cov) / (AA - corrEqDir + tiny)};</li>
     *   <li>the {@code MatrixUtilsExtensions.minPos} of {@code (C + cov) / (AA + corrEqDir + tiny)}.</li>
     * </ul>
     * NOTE(review): {@code minPos} presumably returns the minimum positive entry. With no inactive
     * rows left, {@code corrEqDir} is empty and {@code gamma = C / AA}.
     */
    private static void getCorrelationDirection(LarsPathDataCarrier carrier) {
        int nActive = carrier.getnActive();
        int nFeatures = carrier.getnFeatures();
        int nSamples = carrier.getnSamples();
        RealVector eqDir = carrier.getXT().getSubMatrix(0, nActive - 1, 0, nSamples - 1)
                .transpose()
                .operate(carrier.getLeastSquares());
        double c = carrier.getC();
        double aa = carrier.getNormalizationFactor();
        RealVector corrEqDir;
        double gamma;
        if (nActive < nFeatures) {
            corrEqDir = carrier.getXT().getSubMatrix(nActive, nFeatures - 1, 0, nSamples - 1).operate(eqDir);
            corrEqDir.mapToSelf(v -> Precision.round(v, COV_PRECISION));
            RealVector cov = carrier.getCov();
            double tiny = carrier.getTiny();
            double g1 = MatrixUtilsExtensions.minPos(cov.map(v -> c - v).ebeDivide(corrEqDir.map(v -> aa - v + tiny)));
            double g2 = MatrixUtilsExtensions.minPos(cov.mapAdd(c).ebeDivide(corrEqDir.mapAdd(aa + tiny)));
            gamma = Math.min(Math.min(g1, g2), c / aa);
        } else {
            corrEqDir = MatrixUtils.createRealVector(new double[0]);
            gamma = c / aa;
        }
        carrier.setCorrEqDir(corrEqDir);
        carrier.setGamma(gamma);
    }

    /**
     * Computes {@code z = -coef[active] / (leastSquares + equalityTolerance)}, using row
     * {@code nIter} of {@code coefs}, and stores {@code zPos = MatrixUtilsExtensions.minPos(z)}.
     */
    private static void setZInformation(LarsPathDataCarrier carrier) {
        RealVector coef = carrier.getCoefs().getRowVector(carrier.getnIter());
        double[] negatedActiveCoefs = carrier.getActive().stream()
                .mapToDouble(feature -> -coef.getEntry(feature))
                .toArray();
        RealVector shiftedLeastSquares = carrier.getLeastSquares().mapAdd(carrier.getEqualityTolerance());
        carrier.setZ(MatrixUtils.createRealVector(negatedActiveCoefs).ebeDivide(shiftedLeastSquares));
        carrier.setzPos(MatrixUtilsExtensions.minPos(carrier.getZ()));
    }

    /**
     * Detects active coefficients that would cross zero before the step {@code gamma} completes.
     *
     * <p>When {@code zPos < gamma}, this method:
     * <ul>
     *   <li>collects into {@code idx} every position whose {@code z} is within the equality tolerance
     *       of {@code zPos};</li>
     *   <li>flips their active signs;</li>
     *   <li>in lasso mode, shortens {@code gamma} to {@code zPos};</li>
     *   <li>sets the drop flag.</li>
     * </ul>
     */
    private static void getActiveIndices(LarsPathDataCarrier carrier) {
        carrier.setDrop(false);
        double zPos = carrier.getzPos();
        if (zPos < carrier.getGamma()) {
            RealVector z = carrier.getZ();
            double tolerance = carrier.getEqualityTolerance();
            Set<Integer> idx = IntStream.range(0, z.getDimension())
                    .filter(position -> Math.abs(z.getEntry(position) - zPos) < tolerance)
                    .boxed()
                    .collect(Collectors.toSet());
            carrier.setIdx(idx);
            RealVector signActive = carrier.getSignActive();
            for (int position : idx) {
                signActive.setEntry(position, -signActive.getEntry(position));
            }
            if (carrier.isLasso()) {
                carrier.setGamma(zPos);
            }
            carrier.setDrop(true);
        }
    }

    /**
     * Writes row {@code nIter} of {@code coefs} as {@code coefs[nIter - 1] + gamma * leastSquares}
     * for the active features.
     *
     * <p>When row {@code nIter} does not exist yet, {@code coefs} and {@code alphas} are first grown
     * by {@code 2 * max(1, maxFeatures - nActive)} zero rows/entries.
     */
    private static void trackCoefficientsAndAlphas(LarsPathDataCarrier carrier) {
        int nIter = carrier.getnIter();
        if (nIter >= carrier.getCoefs().getRowDimension()) {
            int extraRows = 2 * Math.max(1, carrier.getMaxFeatures() - carrier.getnActive());
            RealMatrix grown = MatrixUtils.createRealMatrix(nIter + extraRows, carrier.getnFeatures());
            grown.setSubMatrix(carrier.getCoefs().getData(), 0, 0);
            carrier.setCoefs(grown);
            carrier.setAlphas(carrier.getAlphas().append(MatrixUtils.createRealVector(new double[extraRows])));
        }
        RealMatrix coefs = carrier.getCoefs();
        RealVector leastSquares = carrier.getLeastSquares();
        double gamma = carrier.getGamma();
        List<Integer> active = carrier.getActive();
        for (int position = 0; position < active.size(); position++) {
            int feature = active.get(position);
            coefs.setEntry(nIter, feature, coefs.getEntry(nIter - 1, feature) + gamma * leastSquares.getEntry(position));
        }
    }

    /** Applies {@code cov -= gamma * corrEqDir} when there is at least one inactive row. */
    private static void adjustCovarianceByCorrelationDirection(LarsPathDataCarrier carrier) {
        RealVector corrEqDir = carrier.getCorrEqDir();
        if (corrEqDir.getDimension() > 0) {
            carrier.setCov(carrier.getCov().subtract(corrEqDir.mapMultiply(carrier.getGamma())));
        }
    }

    /**
     * Removes the regressors at the positions in {@code idx} from the active set (lasso step).
     *
     * <p>This method:
     * <ul>
     *   <li>decrements {@code nActive};</li>
     *   <li>filters {@code active} and {@code signActive}, padding {@code signActive} with a
     *       trailing 0;</li>
     *   <li>bubbles each dropped row of {@code XT}/{@code indices} up to position {@code nActive};</li>
     *   <li>prepends to {@code cov} the covariance between that row and the current residual.</li>
     * </ul>
     */
    private static void dropFeature(LarsPathDataCarrier carrier) {
        int nActive = carrier.getnActive() - 1;
        carrier.setnActive(nActive);
        Set<Integer> idx = carrier.getIdx();
        List<Integer> previousActive = carrier.getActive();
        carrier.setActive(IntStream.range(0, previousActive.size())
                .filter(position -> !idx.contains(position))
                .mapToObj(previousActive::get)
                .collect(Collectors.toList()));
        RealMatrix coefs = carrier.getCoefs();
        RealVector y = carrier.getY();
        for (int dropped : idx) {
            for (int row = dropped; row < nActive; row++) {
                MatrixUtilsExtensions.swap(carrier.getXT(), row, row + 1);
                MatrixUtilsExtensions.swap(carrier.getIndices(), row, row + 1);
            }
            if (dropped < nActive) {
                // transpose once after the bubbling instead of after every single swap
                carrier.setX(carrier.getXT().transpose());
            }
            double droppedCov = carrier.getXT().getRowVector(nActive)
                    .dotProduct(activeResidual(carrier, coefs, y, nActive));
            carrier.setCov(MatrixUtils.createRealVector(new double[]{droppedCov}).append(carrier.getCov()));
        }
        RealVector signActive = carrier.getSignActive();
        double[] keptSigns = IntStream.range(0, signActive.getDimension())
                .filter(position -> !idx.contains(position))
                .mapToDouble(signActive::getEntry)
                .toArray();
        carrier.setSignActive(MatrixUtils.createRealVector(keptSigns).append(0.0d));
    }

    /**
     * Returns {@code y - X[:, 0..nActive-1] * coef[active]}, using row {@code nIter} of
     * {@code coefs}. With no active columns the residual is {@code y} itself; this case previously
     * failed, because {@code getSubMatrix} rejects an empty column range.
     */
    private static RealVector activeResidual(LarsPathDataCarrier carrier, RealMatrix coefs, RealVector y, int nActive) {
        if (nActive == 0) {
            return y;
        }
        int nIter = carrier.getnIter();
        RealVector activeCoefs = MatrixUtils.createRealVector(carrier.getActive().stream()
                .mapToDouble(feature -> coefs.getEntry(nIter, feature))
                .toArray());
        RealMatrix x = carrier.getX();
        return y.subtract(x.getSubMatrix(0, x.getRowDimension() - 1, 0, nActive - 1).operate(activeCoefs));
    }

    /**
     * Trims {@code alphas} and {@code coefs} to the first {@code nIter + 1} entries/rows (when they
     * are longer) and packages them, with {@code coefs} transposed, into a {@code LarsPathResults}.
     */
    private static LarsPathResults truncatedAndFormattedResults(LarsPathDataCarrier carrier) {
        int nIter = carrier.getnIter();
        RealVector alphas = carrier.getAlphas();
        RealMatrix coefs = carrier.getCoefs();
        if (nIter + 1 < alphas.getDimension()) {
            carrier.setAlphas(alphas.getSubVector(0, Math.min(nIter + 1, alphas.getDimension())));
            carrier.setCoefs(coefs.getSubMatrix(0, Math.min(nIter, coefs.getRowDimension() - 1), 0, carrier.getnFeatures() - 1));
        }
        return new LarsPathResults(carrier.getCoefs().transpose(), carrier.getAlphas(), carrier.getActive(), nIter);
    }

    /**
     * Computes the LARS (or lasso-LARS) path for {@code x} and {@code y}.
     *
     * @param x             design matrix; its row count must equal {@code y.getDimension()}
     * @param y             target vector
     * @param maxIterations forwarded to the {@code LarsPathDataCarrier} constructor. NOTE(review):
     *                      assumed to be the cap read back via {@code getMaxIterations()}; confirm.
     * @param lasso         forwarded to the carrier. NOTE(review): presumably what {@code isLasso()}
     *                      returns, which enables early stopping and feature dropping below.
     * @return transposed coefficient path, alphas, final active set and iteration count
     * @throws IllegalArgumentException if the row count of {@code x} differs from the size of {@code y}
     */
    public static LarsPathResults fit(RealMatrix x, RealVector y, int maxIterations, boolean lasso) {
        if (x.getRowDimension() != y.getDimension()) {
            throw new IllegalArgumentException(String.format(
                    "Number of rows of X (%d) must match number of entries in y (%d)!",
                    x.getRowDimension(), y.getDimension()));
        }
        LarsPathDataCarrier carrier = new LarsPathDataCarrier(x, y, maxIterations, lasso);
        while (true) {
            updateCovarianceTrackers(carrier);
            if (minimumAlphaBreakCondition(carrier) || maximumIterationBreakCondition(carrier)) {
                break;
            }
            getCholeskyDecomposition(carrier);
            if (carrier.isDegenerateRegressor()) {
                // candidate rejected: re-select from the updated covariances without advancing nIter
                continue;
            }
            if (carrier.isLasso() && earlyStoppingBreakCondition(carrier)) {
                break;
            }
            getNormalizedLeastSquares(carrier);
            getCorrelationDirection(carrier);
            setZInformation(carrier);
            getActiveIndices(carrier);
            carrier.incrementnIter();
            trackCoefficientsAndAlphas(carrier);
            adjustCovarianceByCorrelationDirection(carrier);
            if (carrier.isLasso() && carrier.isDrop()) {
                dropFeature(carrier);
            }
        }
        return truncatedAndFormattedResults(carrier);
    }
}
