package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;

/**
 * Fits a Lasso model along a LARS regularization path and selects the point on that path that
 * minimises an information criterion (AIC or BIC).
 *
 * <p>For each step {@code k} of the path the criterion evaluated is
 * {@code nSamples * mse[k] / (var(yCentered) + eps) + penalty * df[k]}, where {@code df[k]} is the
 * number of coefficients at step {@code k} whose magnitude exceeds a small tolerance.
 */
public class LassoLarsIC {

    /** Information criterion used to choose the best step on the LARS path. */
    public enum Criterion {
        /** Akaike information criterion: penalty of 2 per non-zero coefficient. */
        AIC,
        /** Bayesian information criterion: penalty of ln(nSamples) per non-zero coefficient. */
        BIC
    }

    /** Default iteration cap per feature, used by the three-argument {@code fit} overload. */
    private static final int DEFAULT_ITERATIONS_PER_FEATURE = 200;

    private LassoLarsIC() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Fits the model with an iteration cap of {@code 200 * x.getColumnDimension()}.
     *
     * @param x feature matrix, one row per sample and one column per feature
     * @param y target vector, one entry per sample
     * @param criterion information criterion used to select the regularization strength
     * @return the selected coefficients, the corresponding alpha and the intercept
     */
    public static LassoLarsICResults fit(RealMatrix x, RealVector y, Criterion criterion) {
        return fit(x, y, criterion, x.getColumnDimension() * DEFAULT_ITERATIONS_PER_FEATURE);
    }

    /**
     * Centers the data, computes the LARS path, and returns the path step that minimises the chosen
     * information criterion.
     *
     * @param x feature matrix, one row per sample and one column per feature
     * @param y target vector, one entry per sample
     * @param criterion information criterion used to select the regularization strength
     * @param maxIterations iteration cap forwarded to {@code LarsPath.fit}
     * @return the selected coefficients, the corresponding alpha and the intercept
     *     {@code mean(y) - mean(x) . coefs}
     */
    public static LassoLarsICResults fit(RealMatrix x, RealVector y, Criterion criterion, int maxIterations) {
        int nSamples = x.getRowDimension();
        // A coefficient counts as "active" only above single-precision epsilon.
        double coefTolerance = Math.ulp(1.0f);
        // Double-precision epsilon keeps the variance normalisation away from a zero divisor.
        double varianceEpsilon = Math.ulp(1.0d);

        // Center features (per column) and target; the intercept is recovered from the means at the end.
        RealVector xMean = MatrixUtilsExtensions.rowSum(x).mapDivide(nSamples);
        double yMean = Arrays.stream(y.toArray()).sum() / nSamples;
        RealMatrix xCentered = MatrixUtilsExtensions.vectorDifference(x, xMean, MatrixUtilsExtensions.Axis.ROW);
        RealVector yCentered = y.mapSubtract(yMean);

        // NOTE(review): the boolean flag presumably selects the Lasso variant of LARS - confirm in LarsPath.
        LarsPathResults path = LarsPath.fit(xCentered, yCentered, maxIterations, true);
        RealMatrix coefs = path.getCoefs();

        double penaltyPerCoef = criterion == Criterion.AIC ? 2.0d : Math.log(nSamples);

        // One residual column per path step; previously this matrix was referenced through an
        // undeclared decompiler variable (r0), which did not compile.
        RealMatrix residuals = MatrixUtilsExtensions.vectorDifference(
                MatrixUtilsExtensions.matrixDot(xCentered, coefs), yCentered, MatrixUtilsExtensions.Axis.COLUMN);
        RealVector meanSquaredError =
                MatrixUtilsExtensions.rowSquareSum(residuals).mapDivide(residuals.getRowDimension());
        double yVariance = MatrixUtilsExtensions.variance(yCentered);

        RealVector degreesOfFreedom = countActiveCoefficients(coefs, coefTolerance);
        RealVector criterionValues = meanSquaredError
                .mapMultiply(nSamples)
                .mapDivide(yVariance + varianceEpsilon)
                .add(degreesOfFreedom.mapMultiply(penaltyPerCoef));
        int bestStep = criterionValues.getMinIndex();

        RealVector bestCoefs = coefs.getColumnVector(bestStep);
        double intercept = yMean - xMean.dotProduct(bestCoefs);
        return new LassoLarsICResults(bestCoefs, path.getAlphas().getEntry(bestStep), intercept);
    }

    /**
     * Counts, for every column (path step) of {@code coefs}, how many entries have an absolute value
     * strictly greater than {@code tolerance}.
     */
    private static RealVector countActiveCoefficients(RealMatrix coefs, double tolerance) {
        int nSteps = coefs.getColumnDimension();
        RealVector counts = MatrixUtils.createRealVector(new double[nSteps]);
        for (int step = 0; step < nSteps; step++) {
            int active = 0;
            for (double coef : coefs.getColumn(step)) {
                if (Math.abs(coef) > tolerance) {
                    active++;
                }
            }
            counts.setEntry(step, active);
        }
        return counts;
    }
}
