package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;

/**
 * Unit tests for {@link MatrixUtilsExtensions}.
 *
 * <p>Covers conversion of {@link PredictionInput} / {@link PredictionOutput} values into Commons
 * Math vectors and matrices, column selection, row aggregation, vector differences, matrix
 * products, (pseudo-)inversion, {@code minPos}, variance and element swapping.
 */
class MatrixUtilsExtensionsTest {

    /** 3x5 fixture used by the conversion and column-selection tests. */
    private final double[][] mat3X5 = {
        {1.0, 10.0, 3.0, -4.0, 0.0},
        {10.0, 5.0, -3.0, 3.7, 1.0},
        {14.0, -6.6, 7.0, 14.0, 3.0}
    };

    /** Columns 0, 1 and 3 of {@link #mat3X5}. */
    private final double[][] mat3X5get013 = {
        {1.0, 10.0, -4.0},
        {10.0, 5.0, 3.7},
        {14.0, -6.6, 14.0}
    };

    /** Columns 0, 3, 1, 3, 0 of {@link #mat3X5}: repeated and out-of-order indices. */
    private final double[][] mat3X5get03130 = {
        {1.0, -4.0, 10.0, -4.0, 1.0},
        {10.0, 3.7, 5.0, 3.7, 10.0},
        {14.0, 14.0, -6.6, 14.0, 14.0}
    };

    /** Invertible 3x3 matrix. */
    private final double[][] matSquareNonSingular = {
        {1.0, 2.0, 3.0},
        {10.0, 5.0, -3.0},
        {14.0, -6.6, 7.0}
    };

    /** Inverse of {@link #matSquareNonSingular}, given to 8 decimal places. */
    private final double[][] matSNSInv = {
        {-0.02464332, 0.05479896, 0.03404669},
        {0.18158236, 0.05674449, -0.05350195},
        {0.22049287, -0.05609598, 0.02431907}
    };

    /** Singular 3x3 matrix (its rows are linearly dependent). */
    private final double[][] matSquareSingular = {
        {1.0, 2.0, 3.0},
        {4.0, 5.0, 6.0},
        {7.0, 8.0, 9.0}
    };

    /** Element-wise sum of the rows of {@link #matSquareSingular} (i.e. its column sums). */
    private final double[] mssSumRow = {12.0, 15.0, 18.0};

    private final RealVector v = MatrixUtils.createRealVector(new double[] {1.0, 2.0, 3.0});

    private final RealMatrix mssMatrix = MatrixUtils.createRealMatrix(matSquareSingular);

    /** {@link #mssMatrix} with {@link #v} subtracted from every row. */
    private final RealMatrix rowDiffResult = MatrixUtils.createRealMatrix(new double[][] {
        {0.0, 0.0, 0.0},
        {3.0, 3.0, 3.0},
        {6.0, 6.0, 6.0}
    });

    /** {@link #mssMatrix} with {@code v[i]} subtracted from every entry of row {@code i}. */
    private final RealMatrix colDiffResult = MatrixUtils.createRealMatrix(new double[][] {
        {0.0, 1.0, 2.0},
        {2.0, 3.0, 4.0},
        {4.0, 5.0, 6.0}
    });

    /** {@link #mssMatrix} with rows 0 and 2 exchanged. */
    private final RealMatrix swapResult = MatrixUtils.createRealMatrix(new double[][] {
        {7.0, 8.0, 9.0},
        {4.0, 5.0, 6.0},
        {1.0, 2.0, 3.0}
    });

    /** {@link #v} with entries 0 and 2 exchanged. */
    private final RealVector swapResultV = MatrixUtils.createRealVector(new double[] {3.0, 2.0, 1.0});

    /** 2x3 left operand for {@link #matrixDot()}. */
    private final RealMatrix dotInput1 = MatrixUtils.createRealMatrix(new double[][] {
        {0.0, 1.0, 2.0},
        {3.0, 4.0, 5.0}
    });

    /** 3x3 right operand for {@link #matrixDot()}. */
    private final RealMatrix dotInput2 = MatrixUtils.createRealMatrix(new double[][] {
        {0.0, 1.0, 2.0},
        {3.0, 4.0, 5.0},
        {6.0, 7.0, 8.0}
    });

    /** Ordinary matrix product {@code dotInput1 * dotInput2}. */
    private final RealMatrix dotResult = MatrixUtils.createRealMatrix(new double[][] {
        {15.0, 18.0, 21.0},
        {42.0, 54.0, 66.0}
    });

    /** Mixed-sign vector whose smallest positive entry is 1.0. */
    private final RealVector vMix =
            MatrixUtils.createRealVector(new double[] {-3.0, -2.0, -1.0, 1.0, 2.0, 3.0});

    /** Vector with no positive entries. */
    private final RealVector allNeg = MatrixUtils.createRealVector(new double[] {-3.0, -2.0, -1.0});

    /**
     * Variance fixture: the expected 2443.2222... equals the sum of squared deviations from the
     * mean divided by n (6), i.e. the population variance.
     */
    private final RealVector varInput =
            MatrixUtils.createRealVector(new double[] {0.0, 4.0, 16.0, 2.0, -128.0, -4.0});

    /** Wraps every value of {@code row} as a numerical feature named "f". */
    private static PredictionInput toPredictionInput(double[] row) {
        return new PredictionInput(Arrays.stream(row)
                .mapToObj(value -> FeatureFactory.newNumericalFeature("f", Double.valueOf(value)))
                .collect(Collectors.toList()));
    }

    /** Wraps every value of {@code row} as a NUMBER output named "o" with score 0. */
    private static PredictionOutput toPredictionOutput(double[] row) {
        return new PredictionOutput(Arrays.stream(row)
                .mapToObj(value -> new Output("o", Type.NUMBER, new Value(Double.valueOf(value)), 0.0d))
                .collect(Collectors.toList()));
    }

    @Test
    void testPICreation() {
        PredictionInput input = toPredictionInput(mat3X5[0]);
        Assertions.assertArrayEquals(
                mat3X5[0], MatrixUtilsExtensions.vectorFromPredictionInput(input).toArray());
    }

    @Test
    void testPIListCreation() {
        List<PredictionInput> inputs = Arrays.stream(mat3X5)
                .map(MatrixUtilsExtensionsTest::toPredictionInput)
                .collect(Collectors.toList());
        Assertions.assertArrayEquals(
                mat3X5, MatrixUtilsExtensions.matrixFromPredictionInput(inputs).getData());
    }

    @Test
    void testPOCreation() {
        PredictionOutput output = toPredictionOutput(mat3X5[0]);
        Assertions.assertArrayEquals(
                mat3X5[0], MatrixUtilsExtensions.vectorFromPredictionOutput(output).toArray());
    }

    @Test
    void testPOListCreation() {
        List<PredictionOutput> outputs = Arrays.stream(mat3X5)
                .map(MatrixUtilsExtensionsTest::toPredictionOutput)
                .collect(Collectors.toList());
        Assertions.assertArrayEquals(
                mat3X5, MatrixUtilsExtensions.matrixFromPredictionOutput(outputs).getData());
    }

    @Test
    void testGetCols() {
        RealMatrix cols =
                MatrixUtilsExtensions.getCols(MatrixUtils.createRealMatrix(mat3X5), List.of(0, 1, 3));
        // Comparing the full data also verifies the result has exactly the expected dimensions.
        Assertions.assertArrayEquals(mat3X5get013, cols.getData());
    }

    @Test
    void testGetDupCols() {
        RealMatrix cols = MatrixUtilsExtensions.getCols(
                MatrixUtils.createRealMatrix(mat3X5), List.of(0, 3, 1, 3, 0));
        Assertions.assertArrayEquals(mat3X5get03130, cols.getData());
    }

    @Test
    void testGetColsTooBig() {
        RealMatrix matrix = MatrixUtils.createRealMatrix(mat3X5);
        List<Integer> cols = List.of(0, 6);
        // Only the call under test sits inside the lambda, so the exception is attributed to it.
        Assertions.assertThrows(
                IllegalArgumentException.class, () -> MatrixUtilsExtensions.getCols(matrix, cols));
    }

    @Test
    void testGetNegCols() {
        RealMatrix matrix = MatrixUtils.createRealMatrix(mat3X5);
        List<Integer> cols = List.of(0, -6);
        Assertions.assertThrows(
                IllegalArgumentException.class, () -> MatrixUtilsExtensions.getCols(matrix, cols));
    }

    @Test
    void testGetNoCols() {
        RealMatrix matrix = MatrixUtils.createRealMatrix(mat3X5);
        List<Integer> cols = List.of();
        Assertions.assertThrows(
                IllegalArgumentException.class, () -> MatrixUtilsExtensions.getCols(matrix, cols));
    }

    @Test
    void testIntraSumRow() {
        Assertions.assertArrayEquals(
                mssSumRow, MatrixUtilsExtensions.rowSum(mssMatrix).toArray(), 1.0E-6d);
    }

    @Test
    void rowSum() {
        Assertions.assertArrayEquals(
                new double[] {12.0, 15.0, 18.0}, MatrixUtilsExtensions.rowSum(mssMatrix).toArray());
    }

    @Test
    void rowSquareSum() {
        // Element-wise sum of the squared rows: e.g. 1^2 + 4^2 + 7^2 = 66.
        Assertions.assertArrayEquals(
                new double[] {66.0, 93.0, 126.0},
                MatrixUtilsExtensions.rowSquareSum(mssMatrix).toArray());
    }

    @Test
    void rowDifference() {
        Assertions.assertEquals(
                rowDiffResult,
                MatrixUtilsExtensions.vectorDifference(mssMatrix, v, MatrixUtilsExtensions.Axis.ROW));
    }

    @Test
    void colDifference() {
        Assertions.assertEquals(
                colDiffResult,
                MatrixUtilsExtensions.vectorDifference(mssMatrix, v, MatrixUtilsExtensions.Axis.COLUMN));
    }

    @Test
    void matrixDot() {
        Assertions.assertEquals(dotResult, MatrixUtilsExtensions.matrixDot(dotInput1, dotInput2));
    }

    @Test
    void testInvertNormal() {
        RealMatrix inverse =
                MatrixUtilsExtensions.safeInvert(MatrixUtils.createRealMatrix(matSquareNonSingular));
        for (int i = 0; i < inverse.getRowDimension(); i++) {
            Assertions.assertArrayEquals(matSNSInv[i], inverse.getRow(i), 1.0E-4d);
        }
    }

    @Test
    void testInvertSingular() {
        RealMatrix matrix = mssMatrix;
        RealMatrix pseudoInverse = MatrixUtilsExtensions.safeInvert(matrix);
        // Moore-Penrose property: A * A+ * A == A.
        RealMatrix reconstructed = matrix.multiply(pseudoInverse).multiply(matrix);
        // Iterate over the rows of the reconstruction (same shape as A); the pseudo-inverse has
        // the transposed shape, so its row count is the wrong bound for non-square inputs.
        for (int i = 0; i < reconstructed.getRowDimension(); i++) {
            Assertions.assertArrayEquals(matrix.getRow(i), reconstructed.getRow(i), 1.0E-4d);
        }
    }

    @Test
    void testMinPos() {
        Assertions.assertEquals(1.0d, MatrixUtilsExtensions.minPos(vMix), 1.0E-4d);
    }

    @Test
    void testMinPosNoNeg() {
        // With no positive entries, minPos is expected to fall back to Double.MAX_VALUE.
        Assertions.assertEquals(Double.MAX_VALUE, MatrixUtilsExtensions.minPos(allNeg), 1.0E-4d);
    }

    @Test
    void testVar() {
        Assertions.assertEquals(2443.22222222d, MatrixUtilsExtensions.variance(varInput), 1.0E-4d);
    }

    @Test
    void testSwapRealMatrix() {
        RealMatrix copy = mssMatrix.copy();
        MatrixUtilsExtensions.swap(copy, 0, 2);
        Assertions.assertEquals(swapResult, copy);
    }

    @Test
    void testSwapRealVector() {
        RealVector copy = v.copy();
        MatrixUtilsExtensions.swap(copy, 0, 2);
        Assertions.assertEquals(swapResultV, copy);
    }

    @Test
    void testSwapIntArr() {
        int[] values = {1, 2, 3};
        MatrixUtilsExtensions.swap(values, 0, 2);
        Assertions.assertArrayEquals(new int[] {3, 2, 1}, values);
    }
}
