package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.TDistribution;

/* loaded from: input_file:org/kie/kogito/explainability/utils/WeightedLinearRegression.class */
/**
 * Weighted least-squares linear regression solved through the normal equations
 * {@code (X^T W X) beta = X^T W y}, where {@code W} is the diagonal matrix of per-sample weights.
 *
 * <p>When an intercept is requested, a column of {@code 1.0} is appended as the LAST column of the
 * design matrix, so the intercept is the last coefficient.
 */
public class WeightedLinearRegression {

    /** Weighted residual and total sums of squares of a fitted model. */
    private static final class ModelSquareSums {
        final double residualSquareSum;
        final double totalSquareSum;

        ModelSquareSums(double residualSquareSum, double totalSquareSum) {
            this.residualSquareSum = residualSquareSum;
            this.totalSquareSum = totalSquareSum;
        }
    }

    private WeightedLinearRegression() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Fits a weighted linear regression.
     *
     * @param features rectangular feature matrix, one row per sample
     * @param observations target value of each sample
     * @param sampleWeights weight of each sample; must have one entry per observation
     * @param intercept whether to append a constant column of ones (fitted as the last coefficient)
     * @param random source of randomness forwarded to {@code MatrixUtilsExtensions.jitterInvert}
     * @return the coefficients, degrees of freedom ({@code nSamples - nCoefficients}), weighted MSE,
     *     coefficient standard errors and two-sided p-values
     * @throws IllegalArgumentException if the feature matrix is empty or ragged, or if the number of
     *     feature rows or weights does not match the number of observations
     * @throws ArithmeticException if {@code X^T W X} cannot be inverted, if the weights sum to zero,
     *     or if the weighted total variance of the observations is zero
     */
    public static WeightedLinearRegressionResults fit(double[][] features, double[] observations, double[] sampleWeights,
            boolean intercept, Random random) throws IllegalArgumentException, ArithmeticException {
        if (features.length == 0) {
            throw new IllegalArgumentException("Cannot fit a regression on an empty feature matrix");
        }
        int nSamples = observations.length;
        if (features.length != nSamples) {
            throw new IllegalArgumentException(String.format("Num sample mismatch: Number of rows in the features (%d)", features.length)
                    + String.format(" must match number of observations (%d)", nSamples));
        }
        if (sampleWeights.length != nSamples) {
            throw new IllegalArgumentException(String.format("Num sample mismatch: Number of sample weights (%d)", sampleWeights.length)
                    + String.format(" must match number of observations (%d)", nSamples));
        }

        double[][] design = adjustFeatureMatrix(features, intercept);
        int nCoefficients = design[0].length;

        // Accumulate X^T W X and X^T W y in one pass over the samples
        // (same per-cell summation order and product grouping as a per-cell loop).
        double[][] xtwx = new double[nCoefficients][nCoefficients];
        double[][] xtwy = new double[nCoefficients][1];
        for (int row = 0; row < nSamples; row++) {
            double[] x = design[row];
            double weight = sampleWeights[row];
            for (int i = 0; i < nCoefficients; i++) {
                double weightedXi = weight * x[i];
                xtwy[i][0] += weightedXi * observations[row];
                for (int j = 0; j < nCoefficients; j++) {
                    xtwx[i][j] += weightedXi * x[j];
                }
            }
        }

        double[][] inverse;
        try {
            // NOTE(review): 10 and 1e-9 are forwarded unchanged to jitterInvert; presumably a retry count
            // and a jitter magnitude - confirm against MatrixUtilsExtensions.
            inverse = MatrixUtilsExtensions.jitterInvert(xtwx, 10, 1.0E-9d, random);
        } catch (ArithmeticException e) {
            // Only inversion failures get this explanation; errors from later steps keep their own message.
            ArithmeticException failure = new ArithmeticException("Weighted Linear Regression: Matrix cannot be inverted! This can be caused by a very under-specified model, where the ratio of samples to features is roughly less than 0.10. This model has a ratio of " + ((double) nSamples / nCoefficients) + ".");
            failure.initCause(e);
            throw failure;
        }

        double[][] coefficients = MatrixUtilsExtensions.matrixMultiply(inverse, xtwy);
        double mse = getMSE(design, observations, sampleWeights, coefficients);
        double[] standardErrors = getStandardErrors(design, observations, sampleWeights, coefficients, inverse);
        double[] pValues = getPValues(nCoefficients, nSamples, standardErrors, coefficients);
        return new WeightedLinearRegressionResults(coefficients, intercept, nSamples - nCoefficients, mse, standardErrors, pValues);
    }

    /**
     * Copies the feature matrix into a fresh design matrix, appending a trailing column of ones when
     * {@code intercept} is true. The input is not modified.
     *
     * @throws IllegalArgumentException if a row's length differs from the first row's
     */
    private static double[][] adjustFeatureMatrix(double[][] features, boolean intercept) {
        int nFeatures = features[0].length;
        int nColumns = intercept ? nFeatures + 1 : nFeatures;
        double[][] design = new double[features.length][nColumns];
        for (int row = 0; row < features.length; row++) {
            if (features[row].length != nFeatures) {
                throw new IllegalArgumentException(String.format(
                        "Feature row %d has %d columns, but the first row has %d", row, features[row].length, nFeatures));
            }
            System.arraycopy(features[row], 0, design[row], 0, nFeatures);
            if (intercept) {
                design[row][nFeatures] = 1.0d;
            }
        }
        return design;
    }

    /** Linear prediction {@code sum_j designRow[j] * coefficients[j][0]} for one sample. */
    private static double predict(double[] designRow, double[][] coefficients) {
        double prediction = 0.0d;
        for (int j = 0; j < designRow.length; j++) {
            prediction += designRow[j] * coefficients[j][0];
        }
        return prediction;
    }

    /**
     * Computes the weighted residual sum of squares and the weighted total sum of squares around the
     * weighted mean of the observations.
     *
     * @throws ArithmeticException if the weights sum to zero or the weighted total variance is zero
     */
    private static ModelSquareSums getRSSandTSS(double[][] design, double[] observations, double[] weights, double[][] coefficients) {
        int nSamples = observations.length;
        double weightedSum = 0.0d;
        double weightSum = 0.0d;
        for (int row = 0; row < nSamples; row++) {
            weightedSum += weights[row] * observations[row];
            weightSum += weights[row];
        }
        if (weightSum == 0.0d) {
            throw new ArithmeticException("Weights cannot sum to zero!");
        }
        double weightedMean = weightedSum / weightSum;

        double totalSquareSum = 0.0d;
        double residualSquareSum = 0.0d;
        for (int row = 0; row < nSamples; row++) {
            double residual = observations[row] - predict(design[row], coefficients);
            double deviation = observations[row] - weightedMean;
            totalSquareSum += weights[row] * deviation * deviation;
            residualSquareSum += weights[row] * residual * residual;
        }
        if (totalSquareSum == 0.0d) {
            throw new ArithmeticException("Total variance of observations is zero. Use more samples to correct this error");
        }
        return new ModelSquareSums(residualSquareSum, totalSquareSum);
    }

    /**
     * Returns the standard error of each coefficient: {@code sqrt(inverse[i][i] * RSS / (n - p))}.
     * The inverse is read, not modified. When {@code n - p <= 0} the division yields an infinite or NaN
     * residual variance, which propagates into the result.
     */
    private static double[] getStandardErrors(double[][] design, double[] observations, double[] weights,
            double[][] coefficients, double[][] inverse) {
        int nCoefficients = design[0].length;
        double residualVariance = getRSSandTSS(design, observations, weights, coefficients).residualSquareSum
                / (observations.length - nCoefficients);
        double[] standardErrors = new double[nCoefficients];
        for (int i = 0; i < nCoefficients; i++) {
            standardErrors[i] = Math.sqrt(inverse[i][i] * residualVariance);
        }
        return standardErrors;
    }

    /**
     * Two-sided p-values of the t-statistics {@code coefficient / standardError} under a Student t
     * distribution with {@code nSamples - nCoefficients} degrees of freedom. Without positive degrees
     * of freedom, every p-value is reported as {@link Double#POSITIVE_INFINITY}.
     */
    private static double[] getPValues(int nCoefficients, int nSamples, double[] standardErrors, double[][] coefficients) {
        int degreesOfFreedom = nSamples - nCoefficients;
        if (degreesOfFreedom <= 0) {
            return IntStream.range(0, nCoefficients).mapToDouble(i -> Double.POSITIVE_INFINITY).toArray();
        }
        TDistribution tDistribution = new TDistribution(degreesOfFreedom);
        double[] pValues = new double[standardErrors.length];
        for (int i = 0; i < standardErrors.length; i++) {
            double tStatistic = coefficients[i][0] / standardErrors[i];
            // Two-sided test: use |t| so negative coefficients are not mapped to p > 1, and evaluate the
            // lower tail directly (cdf(-|t|) == 1 - cdf(|t|) by symmetry) to avoid cancellation.
            pValues[i] = 2.0d * tDistribution.cumulativeProbability(-Math.abs(tStatistic));
        }
        return pValues;
    }

    /**
     * Weighted mean squared error of the fit: {@code sum(w * residual^2) / sum(w)}.
     *
     * @throws ArithmeticException if the weights sum to zero
     */
    private static double getMSE(double[][] design, double[] observations, double[] weights, double[][] coefficients) {
        double weightedSquaredError = 0.0d;
        double weightSum = 0.0d;
        for (int row = 0; row < observations.length; row++) {
            double residual = observations[row] - predict(design[row], coefficients);
            weightedSquaredError += weights[row] * residual * residual;
            weightSum += weights[row];
        }
        if (weightSum == 0.0d) {
            throw new ArithmeticException("Weights cannot sum to zero!");
        }
        return weightedSquaredError / weightSum;
    }
}
