package org.kie.kogito.explainability.utils;

import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/kie/kogito/explainability/utils/MatrixUtilsTest.class */
/**
 * Unit tests for {@link MatrixUtils}: shape and column access, transposition, matrix
 * multiplication, square-matrix inversion and jittered inversion.
 */
class MatrixUtilsTest {
    /** Absolute tolerance used when comparing floating-point matrix results. */
    private static final double DELTA = 1.0E-6d;

    private static final double[][] matOneElem = {new double[]{5.0d}};
    private static final double[][] matRowVector = {new double[]{5.0d, 6.0d, 7.0d}};
    // Transpose of matRowVector.
    private static final double[][] matColVector = {new double[]{5.0d}, new double[]{6.0d}, new double[]{7.0d}};
    // matRowVector x matColVector, a 1x1 matrix: 25 + 36 + 49.
    private static final double[][] vectorProdRowCol = {new double[]{110.0d}};
    // matColVector x matRowVector, a 3x3 matrix.
    private static final double[][] vectorProdColRow = {new double[]{25.0d, 30.0d, 35.0d}, new double[]{30.0d, 36.0d, 42.0d}, new double[]{35.0d, 42.0d, 49.0d}};
    private static final double[][] mat4X3 = {new double[]{1.0d, 2.0d, 3.0d}, new double[]{10.0d, 5.0d, -3.0d}, new double[]{14.0d, -6.6d, 7.0d}, new double[]{0.0d, 5.0d, -3.0d}};
    // Transpose of mat4X3.
    private static final double[][] mat3X4 = {new double[]{1.0d, 10.0d, 14.0d, 0.0d}, new double[]{2.0d, 5.0d, -6.6d, 5.0d}, new double[]{3.0d, -3.0d, 7.0d, -3.0d}};
    private static final double[][] mat3X5 = {new double[]{1.0d, 10.0d, 3.0d, -4.0d, 0.0d}, new double[]{10.0d, 5.0d, -3.0d, 3.7d, 1.0d}, new double[]{14.0d, -6.6d, 7.0d, 14.0d, 3.0d}};
    // mat4X3 x mat3X5.
    private static final double[][] mat43X35Product = {new double[]{63.0d, 0.2d, 18.0d, 45.4d, 11.0d}, new double[]{18.0d, 144.8d, -6.0d, -63.5d, -4.0d}, new double[]{46.0d, 60.8d, 110.8d, 17.58d, 14.4d}, new double[]{8.0d, 44.8d, -36.0d, -23.5d, -4.0d}};
    private static final double[][] matSquareNonSingular = {new double[]{1.0d, 2.0d, 3.0d}, new double[]{10.0d, 5.0d, -3.0d}, new double[]{14.0d, -6.6d, 7.0d}};
    // Expected inverse of matSquareNonSingular, given to 8 decimal places.
    private static final double[][] matSNSInv = {new double[]{-0.02464332d, 0.05479896d, 0.03404669d}, new double[]{0.18158236d, 0.05674449d, -0.05350195d}, new double[]{0.22049287d, -0.05609598d, 0.02431907d}};
    // Singular: the third row equals 2 * (second row) - (first row).
    private static final double[][] matSquareSingular = {new double[]{1.0d, 2.0d, 3.0d}, new double[]{4.0d, 5.0d, 6.0d}, new double[]{7.0d, 8.0d, 9.0d}};
    private static final double[][] identity = {new double[]{1.0d, 0.0d, 0.0d}, new double[]{0.0d, 1.0d, 0.0d}, new double[]{0.0d, 0.0d, 1.0d}};

    /**
     * Asserts that {@code actual} is non-null, has the same number of rows as {@code expected},
     * and that every row matches element-wise within {@code delta}.
     *
     * <p>Checking the row count first matters: iterating only over the actual result's rows
     * would let an empty or truncated result pass without comparing anything.
     *
     * @param expected the expected matrix
     * @param actual   the matrix under test
     * @param delta    absolute tolerance per element
     * @param label    context prefix for failure messages
     */
    private static void assertMatrixEquals(double[][] expected, double[][] actual, double delta, String label) {
        Assertions.assertNotNull(actual, label + ": result");
        Assertions.assertEquals(expected.length, actual.length, label + ": row count");
        for (int row = 0; row < expected.length; row++) {
            Assertions.assertArrayEquals(expected[row], actual[row], delta, label + ": row " + row);
        }
    }

    @Test
    void testShape() {
        Assertions.assertArrayEquals(new int[]{3, 5}, MatrixUtils.getShape(mat3X5));
    }

    @Test
    void testGetCol() {
        Assertions.assertArrayEquals(new double[]{10.0d, 5.0d, -3.0d}, MatrixUtils.getCol(mat3X4, 1));
    }

    @Test
    void testGetColTooBig() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCol(mat3X4, 10));
    }

    @Test
    void testOneElemTranspose() {
        // Deep comparison of the 2D arrays also verifies the number of rows and columns.
        Assertions.assertArrayEquals(matOneElem, MatrixUtils.transpose(matOneElem));
    }

    @Test
    void testVectorTranspose() {
        Assertions.assertArrayEquals(matColVector, MatrixUtils.transpose(matRowVector));
    }

    @Test
    void testMatrixTranspose() {
        Assertions.assertArrayEquals(mat4X3, MatrixUtils.transpose(mat3X4));
    }

    @Test
    void testMatMulNormal() {
        assertMatrixEquals(mat43X35Product, MatrixUtils.matrixMultiply(mat4X3, mat3X5), DELTA, "mat4X3 * mat3X5");
    }

    @Test
    void testMatMulWrongShape() {
        // The inner dimensions differ (4 columns vs 3 rows), so the product is undefined.
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.matrixMultiply(mat3X4, mat3X5));
    }

    @Test
    void testVectorRowColMultiply() {
        assertMatrixEquals(vectorProdRowCol, MatrixUtils.matrixMultiply(matRowVector, matColVector), DELTA, "row * col");
    }

    @Test
    void testVectorColRowMultiply() {
        assertMatrixEquals(vectorProdColRow, MatrixUtils.matrixMultiply(matColVector, matRowVector), DELTA, "col * row");
    }

    @Test
    void testInvertNormal() {
        assertMatrixEquals(matSNSInv, MatrixUtils.invertSquareMatrix(matSquareNonSingular, 1.0E-9d), DELTA, "inverse");
    }

    @Test
    void testInvertSingular() {
        Assertions.assertThrows(ArithmeticException.class, () -> MatrixUtils.invertSquareMatrix(matSquareSingular, 1.0E-9d));
    }

    @Test
    void testJitterInvert() {
        // Fixed seeds keep the test deterministic. For each seed, the singular matrix times the
        // result of jitterInvert must be approximately the identity.
        for (int seed = 0; seed < 100; seed++) {
            // new Random(seed) is specified as equivalent to new Random() followed by setSeed(seed).
            Random random = new Random(seed);
            double[][] jittered = MatrixUtils.jitterInvert(matSquareSingular, 10, 1.0E-9d, random);
            double[][] product = MatrixUtils.matrixMultiply(matSquareSingular, jittered);
            assertMatrixEquals(identity, product, DELTA, "seed " + seed);
        }
    }
}
