package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.MatrixUtils;

/**
 * Unit tests for {@link MatrixUtils}.
 *
 * <p>Fixture names encode shapes: {@code matAXB} has A rows and B columns. Floating-point
 * results are compared row by row within an explicit tolerance. Purely structural operations
 * (vector construction, transpose, column selection) are compared exactly.
 */
class MatrixUtilsTest {
    /** Tolerance for element-wise arithmetic results. */
    private static final double TOLERANCE = 1.0E-6;
    /** Looser tolerance for matrix inversion, whose fixtures are rounded to 8 decimals. */
    private static final double INVERSE_TOLERANCE = 1.0E-4;
    /** Jitter magnitude passed to {@link MatrixUtils#jitterInvert}. */
    private static final double JITTER = 1.0E-9;

    private final double[][] matOneElem = {{5.0}};
    private final double[] vector = {5.0, 6.0, 7.0};
    // `vector` laid out as a 1x3 row and as a 3x1 column.
    private final double[][] matRowVector = {{5.0, 6.0, 7.0}};
    private final double[][] matColVector = {{5.0}, {6.0}, {7.0}};
    // row x col (inner product) and col x row (outer product) of `vector` with itself.
    private final double[][] vectorProdRowCol = {{110.0}};
    private final double[][] vectorProdColRow = {
        {25.0, 30.0, 35.0},
        {30.0, 36.0, 42.0},
        {35.0, 42.0, 49.0}};
    private final double[][] mat4X3 = {
        {1.0, 2.0, 3.0},
        {10.0, 5.0, -3.0},
        {14.0, -6.6, 7.0},
        {0.0, 5.0, -3.0}};
    // Transpose of mat4X3.
    private final double[][] mat3X4 = {
        {1.0, 10.0, 14.0, 0.0},
        {2.0, 5.0, -6.6, 5.0},
        {3.0, -3.0, 7.0, -3.0}};
    private final double[][] mat3X5 = {
        {1.0, 10.0, 3.0, -4.0, 0.0},
        {10.0, 5.0, -3.0, 3.7, 1.0},
        {14.0, -6.6, 7.0, 14.0, 3.0}};
    // Columns 0, 1, 3 of mat3X5.
    private final double[][] mat3X5get013 = {
        {1.0, 10.0, -4.0},
        {10.0, 5.0, 3.7},
        {14.0, -6.6, 14.0}};
    // Columns 0, 3, 1, 3, 0 of mat3X5 (duplicates allowed).
    private final double[][] mat3X5get03130 = {
        {1.0, -4.0, 10.0, -4.0, 1.0},
        {10.0, 3.7, 5.0, 3.7, 10.0},
        {14.0, 14.0, -6.6, 14.0, 14.0}};
    // mat4X3 x mat3X5.
    private final double[][] mat43X35Product = {
        {63.0, 0.2, 18.0, 45.4, 11.0},
        {18.0, 144.8, -6.0, -63.5, -4.0},
        {46.0, 60.8, 110.8, 17.58, 14.4},
        {8.0, 44.8, -36.0, -23.5, -4.0}};
    private final double[][] matSquareNonSingular = {
        {1.0, 2.0, 3.0},
        {10.0, 5.0, -3.0},
        {14.0, -6.6, 7.0}};
    // Inverse of matSquareNonSingular, rounded to 8 decimals.
    private final double[][] matSNSInv = {
        {-0.02464332, 0.05479896, 0.03404669},
        {0.18158236, 0.05674449, -0.05350195},
        {0.22049287, -0.05609598, 0.02431907}};
    private final double[][] matSquareSingular = {
        {1.0, 2.0, 3.0},
        {4.0, 5.0, 6.0},
        {7.0, 8.0, 9.0}};
    private final double[][] identity = {
        {1.0, 0.0, 0.0},
        {0.0, 1.0, 0.0},
        {0.0, 0.0, 1.0}};
    // identity with `vector` added to every row.
    private final double[][] matIdentityPlusVector = {
        {6.0, 6.0, 7.0},
        {5.0, 7.0, 7.0},
        {5.0, 6.0, 8.0}};
    private final double[][] mssPlusIdentity = {
        {2.0, 2.0, 3.0},
        {4.0, 6.0, 6.0},
        {7.0, 8.0, 10.0}};
    private final double[][] mssMinusIdentity = {
        {0.0, 2.0, 3.0},
        {4.0, 4.0, 6.0},
        {7.0, 8.0, 8.0}};
    // Expected sum of matSquareSingular over Axis.ROW (per-column totals).
    private final double[] mssSumRow = {12.0, 15.0, 18.0};
    // Expected sum of matSquareSingular over Axis.COLUMN (per-row totals).
    private final double[] mssSumCol = {6.0, 15.0, 24.0};
    // Fixed seed so the jitter-based inversion tests are reproducible when they fail.
    private final Random rn = new Random(0L);

    // ---------------------------------------------------------------- helpers

    /**
     * Asserts that two matrices have the same number of rows and that corresponding rows match
     * within {@code delta}. The row count is checked first so that a truncated or empty result
     * cannot pass vacuously.
     */
    private static void assertMatrixEquals(double[][] expected, double[][] actual, double delta) {
        Assertions.assertEquals(expected.length, actual.length, "row count mismatch");
        for (int row = 0; row < expected.length; row++) {
            Assertions.assertArrayEquals(expected[row], actual[row], delta, "mismatch in row " + row);
        }
    }

    /** Returns a new matrix holding every element of {@code matrix} multiplied by {@code factor}. */
    private static double[][] scaled(double[][] matrix, double factor) {
        double[][] result = new double[matrix.length][];
        for (int row = 0; row < matrix.length; row++) {
            result[row] = new double[matrix[row].length];
            for (int col = 0; col < matrix[row].length; col++) {
                result[row][col] = matrix[row][col] * factor;
            }
        }
        return result;
    }

    /** Wraps each value as a numerical feature named "f" in a single prediction input. */
    private static PredictionInput toPredictionInput(double[] values) {
        List<Feature> features = new ArrayList<>(values.length);
        for (double value : values) {
            features.add(FeatureFactory.newNumericalFeature("f", value));
        }
        return new PredictionInput(features);
    }

    /** Wraps each value as a NUMBER output named "o" (score 0) in a single prediction output. */
    private static PredictionOutput toPredictionOutput(double[] values) {
        List<Output> outputs = new ArrayList<>(values.length);
        for (double value : values) {
            outputs.add(new Output("o", Type.NUMBER, new Value(value), 0.0));
        }
        return new PredictionOutput(outputs);
    }

    // ------------------------------------------- prediction <-> matrix conversion

    @Test
    void testPICreation() {
        double[][] matrix = MatrixUtils.matrixFromPredictionInput(toPredictionInput(this.mat3X5[0]));
        Assertions.assertArrayEquals(new double[][] {this.mat3X5[0]}, matrix);
    }

    @Test
    void testPIListCreation() {
        List<PredictionInput> inputs = new ArrayList<>(this.mat3X5.length);
        for (double[] row : this.mat3X5) {
            inputs.add(toPredictionInput(row));
        }
        Assertions.assertArrayEquals(this.mat3X5, MatrixUtils.matrixFromPredictionInput(inputs));
    }

    @Test
    void testPOCreation() {
        double[][] matrix = MatrixUtils.matrixFromPredictionOutput(toPredictionOutput(this.mat3X5[0]));
        Assertions.assertArrayEquals(new double[][] {this.mat3X5[0]}, matrix);
    }

    @Test
    void testPOListCreation() {
        List<PredictionOutput> outputs = new ArrayList<>(this.mat3X5.length);
        for (double[] row : this.mat3X5) {
            outputs.add(toPredictionOutput(row));
        }
        Assertions.assertArrayEquals(this.mat3X5, MatrixUtils.matrixFromPredictionOutput(outputs));
    }

    // ----------------------------------------------------------- construction

    @Test
    void testRowVectorCreation() {
        // Deep comparison checks both the 1xN shape and every element, not just the first.
        Assertions.assertArrayEquals(this.matRowVector, MatrixUtils.rowVector(this.vector));
    }

    @Test
    void testColVectorCreation() {
        Assertions.assertArrayEquals(this.matColVector, MatrixUtils.columnVector(this.vector));
    }

    @Test
    void testShape() {
        Assertions.assertArrayEquals(new int[] {3, 5}, MatrixUtils.getShape(this.mat3X5));
    }

    // ------------------------------------------------------- column selection

    @Test
    void testGetColTooBig() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCol(this.mat3X4, 10));
    }

    @Test
    void testGetNegCol() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCol(this.mat3X4, -10));
    }

    @Test
    void testGetCol() {
        Assertions.assertArrayEquals(new double[] {10.0, 5.0, -3.0}, MatrixUtils.getCol(this.mat3X4, 1));
    }

    @Test
    void testGetCols() {
        Assertions.assertArrayEquals(this.mat3X5get013, MatrixUtils.getCols(this.mat3X5, List.of(0, 1, 3)));
    }

    @Test
    void testGetDupCols() {
        Assertions.assertArrayEquals(this.mat3X5get03130, MatrixUtils.getCols(this.mat3X5, List.of(0, 3, 1, 3, 0)));
    }

    @Test
    void testGetColsTooBig() {
        List<Integer> indices = List.of(0, 6);
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCols(this.mat3X5, indices));
    }

    @Test
    void testGetNegCols() {
        List<Integer> indices = List.of(0, -6);
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCols(this.mat3X5, indices));
    }

    @Test
    void testGetNoCols() {
        List<Integer> indices = List.of();
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCols(this.mat3X5, indices));
    }

    // --------------------------------------------------------------- transpose

    @Test
    void testOneElemTranspose() {
        Assertions.assertArrayEquals(this.matOneElem, MatrixUtils.transpose(this.matOneElem));
    }

    @Test
    void testVectorTranspose() {
        Assertions.assertArrayEquals(this.matColVector, MatrixUtils.transpose(this.matRowVector));
    }

    @Test
    void testMatrixTranspose() {
        Assertions.assertArrayEquals(this.mat4X3, MatrixUtils.transpose(this.mat3X4));
    }

    // ------------------------------------------------------ sums and differences

    @Test
    void testIntraSumRow() {
        Assertions.assertArrayEquals(this.mssSumRow, MatrixUtils.sum(this.matSquareSingular, MatrixUtils.Axis.ROW), TOLERANCE);
    }

    @Test
    void testIntraSumCol() {
        Assertions.assertArrayEquals(this.mssSumCol, MatrixUtils.sum(this.matSquareSingular, MatrixUtils.Axis.COLUMN), TOLERANCE);
    }

    @Test
    void testMatSum() {
        assertMatrixEquals(this.mssPlusIdentity, MatrixUtils.matrixSum(this.matSquareSingular, this.identity), TOLERANCE);
    }

    @Test
    void testMatSumWrongSizes() {
        Assertions.assertThrows(IllegalArgumentException.class,
                () -> MatrixUtils.matrixSum(this.matSquareSingular, this.mat4X3));
    }

    @Test
    void testMatDiff() {
        assertMatrixEquals(this.mssMinusIdentity, MatrixUtils.matrixDifference(this.matSquareSingular, this.identity), TOLERANCE);
    }

    @Test
    void testMatRowSum() {
        assertMatrixEquals(this.matIdentityPlusVector, MatrixUtils.matrixRowSum(this.identity, this.vector), TOLERANCE);
    }

    @Test
    void testMatRowDiff() {
        assertMatrixEquals(this.identity, MatrixUtils.matrixRowDifference(this.matIdentityPlusVector, this.vector), TOLERANCE);
    }

    // ---------------------------------------------------------- multiplication

    @Test
    void testMatMulScalar() {
        assertMatrixEquals(scaled(this.mat4X3, 3.0), MatrixUtils.matrixMultiply(this.mat4X3, 3.0), TOLERANCE);
    }

    @Test
    void testMatMulByZero() {
        assertMatrixEquals(scaled(this.mat4X3, 0.0), MatrixUtils.matrixMultiply(this.mat4X3, 0.0), TOLERANCE);
    }

    @Test
    void testMatMulNormal() {
        assertMatrixEquals(this.mat43X35Product, MatrixUtils.matrixMultiply(this.mat4X3, this.mat3X5), TOLERANCE);
    }

    @Test
    void testMatMulWrongShape() {
        // 3x4 times 3x5: inner dimensions (4 vs 3) do not agree.
        Assertions.assertThrows(IllegalArgumentException.class,
                () -> MatrixUtils.matrixMultiply(this.mat3X4, this.mat3X5));
    }

    @Test
    void testVectorRowColMultiply() {
        assertMatrixEquals(this.vectorProdRowCol, MatrixUtils.matrixMultiply(this.matRowVector, this.matColVector), TOLERANCE);
    }

    @Test
    void testVectorColRowMultiply() {
        assertMatrixEquals(this.vectorProdColRow, MatrixUtils.matrixMultiply(this.matColVector, this.matRowVector), TOLERANCE);
    }

    // --------------------------------------------------------------- inversion

    @Test
    void testInvertNormal() {
        double[][] inverse = MatrixUtils.jitterInvert(this.matSquareNonSingular, 1, JITTER, this.rn);
        assertMatrixEquals(this.matSNSInv, inverse, INVERSE_TOLERANCE);
    }

    @Test
    void testInvertSingular() {
        Assertions.assertThrows(ArithmeticException.class,
                () -> MatrixUtils.jitterInvert(this.matSquareSingular, 1, JITTER, this.rn));
    }

    @Test
    void testJitterInvert() {
        for (int trial = 0; trial < 100; trial++) {
            double[][] inverse = MatrixUtils.jitterInvert(this.matSquareSingular, 10, JITTER, this.rn);
            assertMatrixEquals(this.identity, MatrixUtils.matrixMultiply(this.matSquareSingular, inverse), INVERSE_TOLERANCE);
        }
    }

    @Test
    void testSecureJitterInvert() {
        for (int trial = 0; trial < 100; trial++) {
            double[][] inverse = MatrixUtils.jitterInvert(this.matSquareSingular, 10, JITTER);
            assertMatrixEquals(this.identity, MatrixUtils.matrixMultiply(this.matSquareSingular, inverse), INVERSE_TOLERANCE);
        }
    }
}
