package org.kie.kogito.explainability.utils;

import java.security.SecureRandom;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;

/**
 * Static helpers for simple dense-matrix arithmetic on {@code double[][]} arrays, plus conversions
 * from {@link PredictionInput} / {@link PredictionOutput} values into matrices.
 *
 * <p>Matrices are indexed {@code m[row][col]}. Shapes are derived from the number of rows and the
 * length of the first row, so rows are assumed to all have the same length.
 */
public final class MatrixUtilsExtensions {

    /** Upper bound of the non-negative uniform noise added to each entry when inversion is retried. */
    private static final double JITTER_MAGNITUDE = 1.0E-8d;

    /** Axis selector for {@link #sum(double[][], Axis)}. */
    public enum Axis {
        ROW,
        COLUMN
    }

    private MatrixUtilsExtensions() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Converts a single prediction input into a 1 x n matrix of its feature values.
     *
     * @param predictionInput input whose features are converted via {@code getValue().asNumber()}
     * @return a one-row matrix holding each feature's numeric value, in feature order
     */
    public static double[][] matrixFromPredictionInput(PredictionInput predictionInput) {
        return rowVector(featureValues(predictionInput));
    }

    /**
     * Converts a list of prediction inputs into a matrix with one row per input.
     *
     * @param predictionInputs inputs; each becomes one row of numeric feature values
     * @return a matrix with {@code predictionInputs.size()} rows
     */
    public static double[][] matrixFromPredictionInput(List<PredictionInput> predictionInputs) {
        // double[][]::new produces a correctly typed array, so no cast from Object[] is needed.
        return predictionInputs.stream()
                .map(MatrixUtilsExtensions::featureValues)
                .toArray(double[][]::new);
    }

    /**
     * Converts a single prediction output into a 1 x n matrix of its output values.
     *
     * @param predictionOutput output whose values are converted via {@code getValue().asNumber()}
     * @return a one-row matrix holding each output's numeric value, in output order
     */
    public static double[][] matrixFromPredictionOutput(PredictionOutput predictionOutput) {
        return rowVector(outputValues(predictionOutput));
    }

    /**
     * Converts a list of prediction outputs into a matrix with one row per output.
     *
     * @param predictionOutputs outputs; each becomes one row of numeric output values
     * @return a matrix with {@code predictionOutputs.size()} rows
     */
    public static double[][] matrixFromPredictionOutput(List<PredictionOutput> predictionOutputs) {
        return predictionOutputs.stream()
                .map(MatrixUtilsExtensions::outputValues)
                .toArray(double[][]::new);
    }

    /** Extracts the numeric value of every feature of {@code predictionInput}, in order. */
    private static double[] featureValues(PredictionInput predictionInput) {
        return predictionInput.getFeatures().stream()
                .mapToDouble(feature -> feature.getValue().asNumber())
                .toArray();
    }

    /** Extracts the numeric value of every output of {@code predictionOutput}, in order. */
    private static double[] outputValues(PredictionOutput predictionOutput) {
        return predictionOutput.getOutputs().stream()
                .mapToDouble(output -> output.getValue().asNumber())
                .toArray();
    }

    /**
     * Wraps a vector as a 1 x n matrix.
     *
     * <p>The argument array is used directly as the single row (it is not copied).
     *
     * @param vector the row values
     * @return a matrix whose only row is {@code vector}
     */
    public static double[][] rowVector(double[] vector) {
        return new double[][]{vector};
    }

    /**
     * Copies a vector into a new n x 1 matrix.
     *
     * @param vector the column values
     * @return a new matrix with {@code vector.length} rows and one column
     */
    public static double[][] columnVector(double[] vector) {
        double[][] column = new double[vector.length][1];
        for (int row = 0; row < vector.length; row++) {
            column[row][0] = vector[row];
        }
        return column;
    }

    /**
     * Returns the shape of a matrix as {@code {rows, columns}}.
     *
     * <p>The column count is read from the first row; a matrix with no rows has shape {@code {0, 0}}.
     *
     * @param matrix the matrix to measure
     * @return a two-element array {@code {rows, columns}}
     */
    public static int[] getShape(double[][] matrix) {
        int rows = matrix.length;
        return new int[]{rows, rows == 0 ? 0 : matrix[0].length};
    }

    /**
     * Extracts a single column of a matrix.
     *
     * @param matrix the source matrix
     * @param column zero-based column index
     * @return a new array holding {@code matrix[r][column]} for every row {@code r}
     * @throws IllegalArgumentException if {@code column} is negative or not smaller than the column count
     */
    public static double[] getCol(double[][] matrix, int column) {
        int columns = getShape(matrix)[1];
        if (columns <= column || column < 0) {
            throw new IllegalArgumentException(String.format("Column index %d too large, matrix only has %d column(s)", column, columns));
        }
        return Arrays.stream(matrix).mapToDouble(row -> row[column]).toArray();
    }

    /**
     * Extracts the given columns of a matrix, in the order listed.
     *
     * @param matrix the source matrix
     * @param columnIdxs zero-based column indices; may contain repeats
     * @return a new matrix with the same number of rows and {@code columnIdxs.size()} columns
     * @throws IllegalArgumentException if {@code columnIdxs} is empty or contains an out-of-range index
     */
    public static double[][] getCols(double[][] matrix, List<Integer> columnIdxs) {
        if (columnIdxs.isEmpty()) {
            throw new IllegalArgumentException("Empty column idxs passed to getCols");
        }
        int[] shape = getShape(matrix);

        // Validate and unbox the indices once, instead of once per row.
        int[] columns = new int[columnIdxs.size()];
        for (int k = 0; k < columns.length; k++) {
            int column = columnIdxs.get(k);
            if (column >= shape[1] || column < 0) {
                throw new IllegalArgumentException(String.format("Column index %d out of bounds, matrix only has %d column(s)", column, shape[1]));
            }
            columns[k] = column;
        }

        double[][] selected = new double[shape[0]][columns.length];
        for (int row = 0; row < shape[0]; row++) {
            for (int k = 0; k < columns.length; k++) {
                selected[row][k] = matrix[row][columns[k]];
            }
        }
        return selected;
    }

    /**
     * Computes the element-wise sum {@code a + b}.
     *
     * @param a first matrix
     * @param b second matrix; must have the same shape as {@code a}
     * @return a new matrix holding the element-wise sum
     * @throws IllegalArgumentException if the shapes differ
     */
    public static double[][] matrixSum(double[][] a, double[][] b) {
        int[] shapeA = getShape(a);
        int[] shapeB = getShape(b);
        if (!Arrays.equals(shapeA, shapeB)) {
            throw new IllegalArgumentException("Shape of matrix A must match shape of matrix B. "
                    + String.format("Matrix A shape: %d x %d, ", shapeA[0], shapeA[1])
                    + String.format("Matrix B shape: %d x %d", shapeB[0], shapeB[1]));
        }
        double[][] result = new double[shapeA[0]][shapeA[1]];
        for (int row = 0; row < shapeA[0]; row++) {
            for (int col = 0; col < shapeA[1]; col++) {
                result[row][col] = a[row][col] + b[row][col];
            }
        }
        return result;
    }

    /**
     * Adds {@code vector} to every row of {@code matrix}.
     *
     * @param matrix the matrix
     * @param vector values added to each row; its length must equal the matrix's column count
     * @return a new matrix where {@code result[r][c] = matrix[r][c] + vector[c]}
     * @throws IllegalArgumentException if a row's length differs from {@code vector.length}
     */
    public static double[][] matrixRowSum(double[][] matrix, double[] vector) {
        int[] shape = getShape(matrix);
        double[][] result = new double[shape[0]][];
        for (int row = 0; row < shape[0]; row++) {
            double[] source = matrix[row];
            if (source.length != vector.length) {
                throw new IllegalArgumentException(String.format(
                        "Row vector length %d must match # columns of matrix row %d (%d)",
                        vector.length, row, source.length));
            }
            double[] target = new double[vector.length];
            for (int col = 0; col < vector.length; col++) {
                target[col] = source[col] + vector[col];
            }
            result[row] = target;
        }
        return result;
    }

    /**
     * Computes the element-wise difference {@code a - b}.
     *
     * @param a minuend matrix
     * @param b subtrahend matrix; must have the same shape as {@code a}
     * @return a new matrix holding {@code a - b}
     * @throws IllegalArgumentException if the shapes differ
     */
    public static double[][] matrixDifference(double[][] a, double[][] b) {
        return matrixSum(a, matrixMultiply(b, -1.0d));
    }

    /**
     * Subtracts {@code vector} from every row of {@code matrix}.
     *
     * @param matrix the matrix
     * @param vector values subtracted from each row; its length must equal the column count
     * @return a new matrix where {@code result[r][c] = matrix[r][c] - vector[c]}
     * @throws IllegalArgumentException if a row's length differs from {@code vector.length}
     */
    public static double[][] matrixRowDifference(double[][] matrix, double[] vector) {
        return matrixRowSum(matrix, Arrays.stream(vector).map(v -> -v).toArray());
    }

    /**
     * Computes the matrix product {@code a x b}.
     *
     * @param a left matrix (m x k)
     * @param b right matrix (k x n)
     * @return a new m x n matrix
     * @throws IllegalArgumentException if the column count of {@code a} differs from the row count of {@code b}
     */
    public static double[][] matrixMultiply(double[][] a, double[][] b) {
        int[] shapeA = getShape(a);
        int[] shapeB = getShape(b);
        if (shapeA[1] != shapeB[0]) {
            throw new IllegalArgumentException("# columns of matrix A must match # rows of matrix B. "
                    + String.format("Matrix A shape: %d x %d, ", shapeA[0], shapeA[1])
                    + String.format("Matrix B shape: %d x %d", shapeB[0], shapeB[1]));
        }
        double[][] product = new double[shapeA[0]][shapeB[1]];
        for (int row = 0; row < shapeA[0]; row++) {
            for (int col = 0; col < shapeB[1]; col++) {
                double acc = 0.0d;
                for (int k = 0; k < shapeA[1]; k++) {
                    acc += a[row][k] * b[k][col];
                }
                product[row][col] = acc;
            }
        }
        return product;
    }

    /**
     * Multiplies every entry of a matrix by a scalar.
     *
     * @param matrix the matrix
     * @param scalar the multiplier
     * @return a new matrix where {@code result[r][c] = matrix[r][c] * scalar}
     */
    public static double[][] matrixMultiply(double[][] matrix, double scalar) {
        int[] shape = getShape(matrix);
        double[][] result = new double[shape[0]][shape[1]];
        for (int row = 0; row < shape[0]; row++) {
            for (int col = 0; col < shape[1]; col++) {
                result[row][col] = matrix[row][col] * scalar;
            }
        }
        return result;
    }

    /**
     * Sums a matrix along one axis.
     *
     * <p>{@link Axis#ROW} adds the rows together and returns one value per column:
     * {@code result[c] = sum_r matrix[r][c]}. Any other value, including {@link Axis#COLUMN}, adds
     * the columns together and returns one value per row: {@code result[r] = sum_c matrix[r][c]}.
     *
     * @param matrix the matrix to sum
     * @param axis the axis that is collapsed
     * @return the per-column sums for {@code ROW}; otherwise the per-row sums
     */
    public static double[] sum(double[][] matrix, Axis axis) {
        int[] shape = getShape(matrix);
        if (axis == Axis.ROW) {
            double[] columnSums = new double[shape[1]];
            for (int row = 0; row < shape[0]; row++) {
                for (int col = 0; col < shape[1]; col++) {
                    columnSums[col] += matrix[row][col];
                }
            }
            return columnSums;
        }
        double[] rowSums = new double[shape[0]];
        for (int col = 0; col < shape[1]; col++) {
            for (int row = 0; row < shape[0]; row++) {
                rowSums[row] += matrix[row][col];
            }
        }
        return rowSums;
    }

    /**
     * Returns the transpose of a matrix.
     *
     * @param matrix an m x n matrix
     * @return a new n x m matrix with {@code result[c][r] = matrix[r][c]}
     */
    public static double[][] transpose(double[][] matrix) {
        int[] shape = getShape(matrix);
        double[][] transposed = new double[shape[1]][shape[0]];
        for (int row = 0; row < shape[0]; row++) {
            for (int col = 0; col < shape[1]; col++) {
                transposed[col][row] = matrix[row][col];
            }
        }
        return transposed;
    }

    /**
     * Picks the not-yet-used diagonal index whose entry has the largest absolute value.
     *
     * <p>Only unused indices are ever returned; on ties the smallest index wins.
     *
     * @param matrix the square matrix under elimination
     * @param used marks diagonal indices already used as pivots
     * @return the chosen index, or {@code -1} if no unused index has a comparable (non-NaN) entry
     */
    private static int findPivot(double[][] matrix, boolean[] used) {
        int best = -1;
        double bestAbs = -1.0d;
        for (int k = 0; k < used.length; k++) {
            if (!used[k]) {
                double abs = Math.abs(matrix[k][k]);
                if (abs > bestAbs) {
                    best = k;
                    bestAbs = abs;
                }
            }
        }
        return best;
    }

    /**
     * Inverts a square matrix using in-place Gauss-Jordan elimination with diagonal pivoting.
     *
     * <p>The input is copied first and is never modified. Because pivots are chosen only from the
     * diagonal (no row swaps), some invertible matrices with zero diagonal entries are rejected as
     * singular; {@link #jitterInvert(double[][], int, double, Random)} retries those with noise.
     *
     * @param matrix the square matrix to invert
     * @param threshold minimum absolute pivot value; smaller (or zero) pivots are treated as singular
     * @return a new matrix holding the inverse
     * @throws IllegalArgumentException if the matrix is not square
     * @throws ArithmeticException if a usable pivot cannot be found
     */
    private static double[][] invertSquareMatrix(double[][] matrix, double threshold) {
        int[] shape = getShape(matrix);
        if (shape[0] != shape[1]) {
            throw new IllegalArgumentException(String.format(
                    "Matrix must be square to be inverted, got shape %d x %d", shape[0], shape[1]));
        }
        int n = shape[0];
        double[][] inverse = new double[n][];
        for (int row = 0; row < n; row++) {
            inverse[row] = Arrays.copyOf(matrix[row], n);
        }
        boolean[] used = new boolean[n];
        for (int step = 0; step < n; step++) {
            int pivot = findPivot(inverse, used);
            if (pivot < 0) {
                throw new ArithmeticException("Matrix is singular and cannot be inverted");
            }
            double pivotValue = inverse[pivot][pivot];
            // The explicit zero check matters when threshold <= 0; otherwise we would divide by zero.
            if (Math.abs(pivotValue) < threshold || pivotValue == 0.0d) {
                throw new ArithmeticException("Matrix is singular and cannot be inverted");
            }
            used[pivot] = true;

            // Normalize the pivot row. Setting the pivot cell to 1 first makes it hold 1/pivot
            // afterwards, which is the in-place Gauss-Jordan trick for building the inverse.
            double[] pivotRow = inverse[pivot];
            pivotRow[pivot] = 1.0d;
            for (int col = 0; col < n; col++) {
                pivotRow[col] /= pivotValue;
            }

            // Eliminate the pivot column from every other row.
            for (int row = 0; row < n; row++) {
                if (row != pivot) {
                    double[] target = inverse[row];
                    double factor = target[pivot];
                    target[pivot] = 0.0d;
                    for (int col = 0; col < n; col++) {
                        target[col] -= pivotRow[col] * factor;
                    }
                }
            }
        }
        return inverse;
    }

    /**
     * Inverts a square matrix, retrying with small random noise added when it appears singular.
     *
     * <p>The first attempt uses the matrix as given. After each failed attempt, non-negative uniform
     * noise below {@value #JITTER_MAGNITUDE} is added cumulatively to every entry of a private copy.
     * The caller's matrix is never modified.
     *
     * @param matrix the square matrix to invert
     * @param nRetries maximum number of inversion attempts
     * @param zeroThreshold minimum absolute pivot value accepted during elimination
     * @param random source of the jitter noise
     * @return a new matrix holding the inverse of the (possibly jittered) matrix
     * @throws IllegalArgumentException if the matrix is not square
     * @throws ArithmeticException if every attempt fails; the cause is the last failure, if any
     */
    public static double[][] jitterInvert(double[][] matrix, int nRetries, double zeroThreshold, Random random) {
        double[][] working = matrix;
        boolean copied = false;
        ArithmeticException lastFailure = null;
        for (int attempt = 0; attempt < nRetries; attempt++) {
            try {
                return invertSquareMatrix(working, zeroThreshold);
            } catch (ArithmeticException e) {
                // Expected for (near-)singular input: remember it and retry on a jittered copy.
                lastFailure = e;
                if (!copied) {
                    working = copyOf(matrix);
                    copied = true;
                }
                jitterMatrix(working, JITTER_MAGNITUDE, random);
            }
        }
        ArithmeticException failure = new ArithmeticException("Matrix is singular and could not be inverted via jittering");
        if (lastFailure != null) {
            failure.initCause(lastFailure);
        }
        throw failure;
    }

    /**
     * Inverts a square matrix with jittered retries, drawing noise from a new {@link SecureRandom}.
     *
     * @param matrix the square matrix to invert
     * @param nRetries maximum number of inversion attempts
     * @param zeroThreshold minimum absolute pivot value accepted during elimination
     * @return a new matrix holding the inverse of the (possibly jittered) matrix
     * @throws IllegalArgumentException if the matrix is not square
     * @throws ArithmeticException if every attempt fails
     */
    public static double[][] jitterInvert(double[][] matrix, int nRetries, double zeroThreshold) {
        return jitterInvert(matrix, nRetries, zeroThreshold, new SecureRandom());
    }

    /** Returns a deep copy of {@code matrix}, preserving each row's length. */
    private static double[][] copyOf(double[][] matrix) {
        double[][] copy = new double[matrix.length][];
        for (int row = 0; row < matrix.length; row++) {
            copy[row] = Arrays.copyOf(matrix[row], matrix[row].length);
        }
        return copy;
    }

    /** Adds {@code magnitude * random.nextDouble()} to every entry of {@code matrix}, in place. */
    private static void jitterMatrix(double[][] matrix, double magnitude, Random random) {
        for (double[] row : matrix) {
            for (int col = 0; col < row.length; col++) {
                row[col] += magnitude * random.nextDouble();
            }
        }
    }
}
