package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;

/* loaded from: input_file:org/kie/kogito/explainability/utils/MatrixUtilsExtensions.class */
public class MatrixUtilsExtensions {
    /** Format template (matrix label, row count, column count) used when reporting matrix shapes in error messages. */
    private static final String SHAPE_STRING = "Matrix %s shape: %d x %d";

    /* loaded from: input_file:org/kie/kogito/explainability/utils/MatrixUtilsExtensions$Axis.class */
    /** Selects whether a vector operation is applied along the rows or along the columns of a matrix. */
    public enum Axis {
        /** Apply the operation to each row of the matrix. */
        ROW,
        /** Apply the operation to each column of the matrix. */
        COLUMN
    }

    /**
     * Not instantiable: this class only exposes static helpers.
     *
     * @throws IllegalStateException always, should the constructor ever be invoked
     */
    private MatrixUtilsExtensions() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Converts the features of a single prediction input into a vector of their numeric values.
     *
     * @param predictionInput the input whose feature values are read via {@code asNumber()}
     * @return a vector with one entry per feature
     */
    public static RealVector vectorFromPredictionInput(PredictionInput predictionInput) {
        double[] featureValues = predictionInput.getFeatures()
                .stream()
                .mapToDouble(f -> f.getValue().asNumber())
                .toArray();
        return MatrixUtils.createRealVector(featureValues);
    }

    /**
     * Converts a list of prediction inputs into a matrix with one row per input and one column per feature.
     *
     * @param list the prediction inputs; each input contributes one row made of its features' numeric values
     * @return the matrix built from the collected rows
     */
    public static RealMatrix matrixFromPredictionInput(List<PredictionInput> list) {
        double[][] rows = list.stream()
                .map(input -> input.getFeatures()
                        .stream()
                        .mapToDouble(feature -> feature.getValue().asNumber())
                        .toArray())
                // The generator must allocate an array of rows (double[][]), not a single double[].
                .toArray(double[][]::new);
        return MatrixUtils.createRealMatrix(rows);
    }

    /**
     * Converts the outputs of a single prediction output into a vector of their numeric values.
     *
     * @param predictionOutput the prediction output whose output values are read via {@code asNumber()}
     * @return a vector with one entry per output
     */
    public static RealVector vectorFromPredictionOutput(PredictionOutput predictionOutput) {
        double[] outputValues = predictionOutput.getOutputs()
                .stream()
                .mapToDouble(o -> o.getValue().asNumber())
                .toArray();
        return MatrixUtils.createRealVector(outputValues);
    }

    /**
     * Converts a list of prediction outputs into a matrix with one row per prediction and one column per output.
     *
     * @param list the prediction outputs; each one contributes one row made of its outputs' numeric values
     * @return the matrix built from the collected rows
     */
    public static RealMatrix matrixFromPredictionOutput(List<PredictionOutput> list) {
        double[][] rows = list.stream()
                .map(prediction -> prediction.getOutputs()
                        .stream()
                        .mapToDouble(output -> output.getValue().asNumber())
                        .toArray())
                // The generator must allocate an array of rows (double[][]), not a single double[].
                .toArray(double[][]::new);
        return MatrixUtils.createRealMatrix(rows);
    }

    /**
     * Computes a pseudo-inverse of the given matrix from its singular value decomposition.
     *
     * <p>Each diagonal entry of S greater than {@code 1e-6} is replaced by its reciprocal; every other
     * diagonal entry is set to zero. The result is {@code V * S^T * U^T} using the modified S.
     *
     * @param realMatrix the matrix to pseudo-invert
     * @return the pseudo-inverse
     */
    public static RealMatrix getPsuedoInverse(RealMatrix realMatrix) {
        SingularValueDecomposition svd = new SingularValueDecomposition(realMatrix);
        RealMatrix uMatrix = svd.getU();
        RealMatrix vMatrix = svd.getV();
        RealMatrix sigma = svd.getS();

        int diagonalLength = sigma.getRowDimension();
        for (int k = 0; k < diagonalLength; k++) {
            double singularValue = sigma.getEntry(k, k);
            // Values at or below the cutoff are treated as zero instead of being inverted.
            double inverted = singularValue > 1.0E-6d ? 1.0d / singularValue : 0.0d;
            sigma.setEntry(k, k, inverted);
        }

        RealMatrix sigmaTimesUt = sigma.transpose().multiply(uMatrix.transpose());
        return vMatrix.multiply(sigmaTimesUt);
    }

    /**
     * Inverts a matrix, falling back to {@link #getPsuedoInverse(RealMatrix)} when the matrix is reported singular.
     *
     * @param realMatrix the matrix to invert
     * @return the result of {@code MatrixUtils.inverse} called with {@code 1e-6} as its second argument,
     *         or the pseudo-inverse if that call throws {@link SingularMatrixException}
     */
    public static RealMatrix safeInvert(RealMatrix realMatrix) {
        RealMatrix inverse;
        try {
            inverse = MatrixUtils.inverse(realMatrix, 1.0E-6d);
        } catch (SingularMatrixException singular) {
            // No exact inverse is available; use the SVD-based pseudo-inverse instead.
            inverse = getPsuedoInverse(realMatrix);
        }
        return inverse;
    }

    /**
     * Adds all row vectors of the matrix together.
     *
     * @param realMatrix the matrix whose rows are summed
     * @return a vector of length {@code getColumnDimension()} holding the element-wise sum of every row
     */
    public static RealVector rowSum(RealMatrix realMatrix) {
        final int rowCount = realMatrix.getRowDimension();
        final int columnCount = realMatrix.getColumnDimension();
        // Start from an all-zero accumulator as wide as one row.
        RealVector total = MatrixUtils.createRealVector(new double[columnCount]);
        for (int row = 0; row < rowCount; row++) {
            RealVector current = realMatrix.getRowVector(row);
            total = total.add(current);
        }
        return total;
    }

    /**
     * Adds all column vectors of the matrix together.
     *
     * @param realMatrix the matrix whose columns are summed
     * @return a vector of length {@code getRowDimension()} holding the element-wise sum of every column
     */
    public static RealVector colSum(RealMatrix realMatrix) {
        final int rowCount = realMatrix.getRowDimension();
        final int columnCount = realMatrix.getColumnDimension();
        // Start from an all-zero accumulator as tall as one column.
        RealVector total = MatrixUtils.createRealVector(new double[rowCount]);
        for (int column = 0; column < columnCount; column++) {
            RealVector current = realMatrix.getColumnVector(column);
            total = total.add(current);
        }
        return total;
    }

    /**
     * Adds together the element-wise squares of all row vectors of the matrix.
     *
     * @param realMatrix the matrix whose squared rows are summed
     * @return a vector of length {@code getColumnDimension()} holding, per column, the sum of squared entries
     */
    public static RealVector rowSquareSum(RealMatrix realMatrix) {
        final int rowCount = realMatrix.getRowDimension();
        final int columnCount = realMatrix.getColumnDimension();
        RealVector total = MatrixUtils.createRealVector(new double[columnCount]);
        for (int row = 0; row < rowCount; row++) {
            RealVector current = realMatrix.getRowVector(row);
            RealVector squared = current.ebeMultiply(current);
            total = total.add(squared);
        }
        return total;
    }

    /**
     * Subtracts a vector from every row or every column of a matrix.
     *
     * @param realMatrix the matrix to subtract from; it is not modified
     * @param realVector the vector subtracted from each row (for {@link Axis#ROW}) or each column (otherwise)
     * @param axis {@link Axis#ROW} to operate row-wise; any other value operates column-wise
     * @return a new matrix of the same dimensions holding the differences
     */
    public static RealMatrix vectorDifference(RealMatrix realMatrix, RealVector realVector, Axis axis) {
        final int rowCount = realMatrix.getRowDimension();
        final int columnCount = realMatrix.getColumnDimension();
        RealMatrix difference = realMatrix.createMatrix(rowCount, columnCount);
        if (axis == Axis.ROW) {
            for (int row = 0; row < rowCount; row++) {
                RealVector shifted = realMatrix.getRowVector(row).subtract(realVector);
                difference.setRowVector(row, shifted);
            }
        } else {
            for (int column = 0; column < columnCount; column++) {
                RealVector shifted = realMatrix.getColumnVector(column).subtract(realVector);
                difference.setColumnVector(column, shifted);
            }
        }
        return difference;
    }

    /**
     * Multiplies every row of a matrix element-wise by the given vector.
     *
     * @param realMatrix the matrix A whose rows are scaled; it is not modified
     * @param realVector the vector b; its dimension must equal the column count of A
     * @return a new matrix of the same dimensions as A where row {@code i} is {@code A[i] * b} element-wise
     * @throws IllegalArgumentException if the column count of A differs from the dimension of b
     */
    public static RealMatrix vectorRowProduct(RealMatrix realMatrix, RealVector realVector) {
        int rowDimension = realMatrix.getRowDimension();
        int columnDimension = realMatrix.getColumnDimension();
        int dimension = realVector.getDimension();
        if (columnDimension != dimension) {
            // Separate the message fragments so the shapes stay readable.
            throw new IllegalArgumentException(
                    "Columns of matrix A must match size of vector b. "
                            + String.format(SHAPE_STRING, "A", rowDimension, columnDimension)
                            + ". "
                            + String.format("Size of vector b: %d", dimension));
        }
        RealMatrix product = MatrixUtils.createRealMatrix(rowDimension, columnDimension);
        for (int row = 0; row < rowDimension; row++) {
            product.setRowVector(row, realMatrix.getRowVector(row).ebeMultiply(realVector));
        }
        return product;
    }

    /**
     * Applies a function to every entry of a matrix, row by row.
     *
     * @param realMatrix the source matrix; it is not modified
     * @param univariateFunction the function applied to each entry
     * @return a copy of the matrix whose rows have been replaced by their mapped counterparts
     */
    public static RealMatrix map(RealMatrix realMatrix, UnivariateFunction univariateFunction) {
        RealMatrix mapped = realMatrix.copy();
        final int rowCount = realMatrix.getRowDimension();
        for (int row = 0; row < rowCount; row++) {
            RealVector transformedRow = realMatrix.getRowVector(row).map(univariateFunction);
            mapped.setRowVector(row, transformedRow);
        }
        return mapped;
    }

    /**
     * Extracts the requested columns of a matrix, in the order given.
     *
     * @param realMatrix the source matrix; it is not modified
     * @param list the indexes of the columns to extract; must not be empty, and the same index may be repeated
     * @return a new matrix with the source's row count and one column per requested index
     * @throws IllegalArgumentException if {@code list} is empty or contains an index outside
     *         {@code [0, getColumnDimension())}
     */
    public static RealMatrix getCols(RealMatrix realMatrix, List<Integer> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Empty column idxs passed to getCols");
        }
        int sourceColumns = realMatrix.getColumnDimension();
        RealMatrix selected = MatrixUtils.createRealMatrix(new double[realMatrix.getRowDimension()][list.size()]);
        for (int target = 0; target < list.size(); target++) {
            // Read and unbox the requested index once per iteration.
            int source = list.get(target);
            if (source < 0 || source >= sourceColumns) {
                throw new IllegalArgumentException(String.format(
                        "Column index %d out of bounds, matrix only has %d column(s)", source, sourceColumns));
            }
            selected.setColumnVector(target, realMatrix.getColumnVector(source));
        }
        return selected;
    }

    /**
     * Computes the matrix product {@code A * B} as dot products of the rows of A with the columns of B.
     *
     * @param realMatrix the left operand A
     * @param realMatrix2 the right operand B; its row count must equal the column count of A
     * @return a new matrix with the row count of A and the column count of B
     * @throws IllegalArgumentException if the column count of A differs from the row count of B
     */
    public static RealMatrix matrixDot(RealMatrix realMatrix, RealMatrix realMatrix2) {
        int rowDimension = realMatrix.getRowDimension();
        int columnDimension = realMatrix.getColumnDimension();
        int rowDimension2 = realMatrix2.getRowDimension();
        int columnDimension2 = realMatrix2.getColumnDimension();
        if (columnDimension != rowDimension2) {
            // Separate the message fragments so the shapes stay readable.
            throw new IllegalArgumentException(
                    "Columns of matrix A must match rows of matrix B. "
                            + String.format(SHAPE_STRING, "A", rowDimension, columnDimension)
                            + ". "
                            + String.format(SHAPE_STRING, "B", rowDimension2, columnDimension2));
        }

        // Extract each column of B once instead of once per (row, column) pair.
        RealVector[] columns = new RealVector[columnDimension2];
        for (int j = 0; j < columnDimension2; j++) {
            columns[j] = realMatrix2.getColumnVector(j);
        }

        RealMatrix product = MatrixUtils.createRealMatrix(rowDimension, columnDimension2);
        for (int i = 0; i < rowDimension; i++) {
            // The row of A is the same for every column of B, so fetch it outside the inner loop.
            RealVector row = realMatrix.getRowVector(i);
            for (int j = 0; j < columnDimension2; j++) {
                product.setEntry(i, j, row.dotProduct(columns[j]));
            }
        }
        return product;
    }

    /**
     * Finds the smallest strictly positive entry of a vector.
     *
     * @param realVector the vector to scan
     * @return the smallest entry greater than zero, or {@link Double#MAX_VALUE} if no entry is
     *         both positive and below that value
     */
    public static double minPos(RealVector realVector) {
        double smallest = Double.MAX_VALUE;
        final int size = realVector.getDimension();
        for (int index = 0; index < size; index++) {
            double candidate = realVector.getEntry(index);
            if (candidate <= 0.0d || Double.isNaN(candidate)) {
                // Zero, negative and NaN entries never qualify.
                continue;
            }
            if (candidate < smallest) {
                smallest = candidate;
            }
        }
        return smallest;
    }

    /**
     * Lists the positions of all entries of a vector that are not equal to zero.
     *
     * @param realVector the vector to scan
     * @return the indexes, in ascending order, of every entry for which {@code entry != 0.0} holds
     */
    public static List<Integer> nonzero(RealVector realVector) {
        // Typed list instead of a raw ArrayList, so the method compiles without unchecked warnings.
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < realVector.getDimension(); i++) {
            if (realVector.getEntry(i) != 0.0d) {
                indexes.add(i);
            }
        }
        return indexes;
    }

    /**
     * Computes the population variance of a vector's entries, i.e. the mean of the squared
     * deviations from the mean, dividing by the vector's dimension.
     *
     * @param realVector the vector whose entries are analysed
     * @return the variance of the entries
     */
    public static double variance(RealVector realVector) {
        final int size = realVector.getDimension();
        final double mean = Arrays.stream(realVector.toArray()).sum() / size;
        RealVector squaredDeviations = realVector.map(value -> Math.pow(value - mean, 2.0d));
        double totalSquaredDeviation = Arrays.stream(squaredDeviations.toArray()).sum();
        return totalSquaredDeviation / size;
    }

    /**
     * Adds up all entries of a vector.
     *
     * @param realVector the vector whose entries are summed
     * @return the sum of the entries
     */
    public static double sum(RealVector realVector) {
        double[] entries = realVector.toArray();
        return Arrays.stream(entries).sum();
    }

    /**
     * Exchanges two rows of a matrix in place.
     *
     * @param realMatrix the matrix to modify
     * @param i index of the first row
     * @param i2 index of the second row
     */
    public static void swap(RealMatrix realMatrix, int i, int i2) {
        // Read both rows before writing either one.
        double[] firstRow = realMatrix.getRow(i);
        double[] secondRow = realMatrix.getRow(i2);
        realMatrix.setRow(i, secondRow);
        realMatrix.setRow(i2, firstRow);
    }

    /**
     * Exchanges two entries of a vector in place.
     *
     * @param realVector the vector to modify
     * @param i index of the first entry
     * @param i2 index of the second entry
     */
    public static void swap(RealVector realVector, int i, int i2) {
        // Read both entries before writing either one.
        double firstEntry = realVector.getEntry(i);
        double secondEntry = realVector.getEntry(i2);
        realVector.setEntry(i, secondEntry);
        realVector.setEntry(i2, firstEntry);
    }

    /**
     * Exchanges two elements of an int array in place.
     *
     * @param iArr the array to modify
     * @param i index of the first element
     * @param i2 index of the second element
     */
    public static void swap(int[] iArr, int i, int i2) {
        // Read both elements before writing either one.
        final int first = iArr[i];
        final int second = iArr[i2];
        iArr[i] = second;
        iArr[i2] = first;
    }
}
