package org.kie.kogito.explainability.utils;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/kie/kogito/explainability/utils/WeightedLinearRegressionResultsTest.class */
/**
 * Unit tests for {@link WeightedLinearRegressionResults}.
 *
 * <p>The tests construct results objects directly from hand-written vectors and check the
 * accessors ({@code getCoefficients}, {@code getStdErrors}, {@code getPValues},
 * {@code getIntercept}, {@code getMSE}) as well as {@code predict}.
 */
class WeightedLinearRegressionResultsTest {
    // NOTE(review): these three fields are not referenced by any test in this class; each test
    // builds its own local vectors. They are retained unchanged in case other code relies on them.
    RealVector coefficients = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d});
    RealVector flatCoef = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d});
    RealVector stdErrs = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d});

    /**
     * With the intercept flag set to {@code false}, the test expects every supplied value to be
     * reported as a coefficient and the intercept to be {@code 0.0}.
     */
    @Test
    void testWLRResultsNoIntercept() {
        RealVector rawCoefficients = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d});
        RealVector expectedCoefficients = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d});
        RealVector stdErrors = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d});
        // A fresh double[4] is all zeros.
        RealVector pValues = MatrixUtils.createRealVector(new double[4]);

        // NOTE(review): the meaning of the int argument (1) is not visible from this file;
        // presumably degrees of freedom - confirm against the constructor.
        WeightedLinearRegressionResults results =
                new WeightedLinearRegressionResults(rawCoefficients, false, 1, 0.01d, stdErrors, pValues);

        Assertions.assertArrayEquals(expectedCoefficients.toArray(), results.getCoefficients().toArray());
        Assertions.assertArrayEquals(stdErrors.toArray(), results.getStdErrors().toArray());
        Assertions.assertArrayEquals(pValues.toArray(), results.getPValues().toArray());
        Assertions.assertEquals(0.0d, results.getIntercept());
        Assertions.assertEquals(0.01d, results.getMSE());
    }

    /**
     * With the intercept flag set to {@code true}, the test expects the last supplied value
     * ({@code 3.0}) to be reported as the intercept and the remaining values as coefficients.
     */
    @Test
    void testWLRResultWithIntercept() {
        RealVector rawCoefficients = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d});
        // Same values as above minus the trailing element, which is expected to be the intercept.
        RealVector expectedCoefficients = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d});
        RealVector stdErrors = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d});
        RealVector pValues = MatrixUtils.createRealVector(new double[4]);

        WeightedLinearRegressionResults results =
                new WeightedLinearRegressionResults(rawCoefficients, true, 1, 0.01d, stdErrors, pValues);

        Assertions.assertArrayEquals(expectedCoefficients.toArray(), results.getCoefficients().toArray());
        // The standard errors are expected back at full length (all four entries).
        Assertions.assertArrayEquals(stdErrors.toArray(), results.getStdErrors().toArray());
        Assertions.assertEquals(3.0d, results.getIntercept());
        Assertions.assertEquals(0.01d, results.getMSE());
    }

    /**
     * Checks {@code predict} on a 3x4 feature matrix against hand-computed values, within a
     * tolerance of {@code 1e-6}.
     */
    @Test
    void testPredictions() {
        RealVector rawCoefficients = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d, 5.0d});
        RealVector stdErrors = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d, 5.0d});
        RealVector pValues = MatrixUtils.createRealVector(new double[5]);
        RealMatrix features = MatrixUtils.createRealMatrix(new double[][]{
                {1.0d, 5.0d, 3.0d, -2.0d},
                {10.0d, -1.0d, 0.0d, 4.0d},
                {-2.0d, 7.5d, 6.0d, -3.3d}
        });
        // Each expected value equals row . [5, 1, -1, 3] + 5, e.g. 5 + 5 - 3 - 6 + 5 = 6.
        double[] expectedPredictions = {6.0d, 66.0d, -13.4d};

        WeightedLinearRegressionResults results =
                new WeightedLinearRegressionResults(rawCoefficients, true, 1, 0.01d, stdErrors, pValues);
        double[] actualPredictions = results.predict(features).toArray();

        Assertions.assertArrayEquals(expectedPredictions, actualPredictions, 1.0E-6d);
    }

    /**
     * Checks that {@code predict} throws {@link IllegalArgumentException} when the feature matrix
     * has two columns while the results were built from five values with the intercept flag set.
     */
    @Test
    void testPredictionsWrongNumFeatures() {
        RealVector rawCoefficients = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d, 5.0d});
        RealVector stdErrors = MatrixUtils.createRealVector(new double[]{5.0d, 1.0d, -1.0d, 3.0d, 5.0d});
        RealVector pValues = MatrixUtils.createRealVector(new double[5]);
        RealMatrix tooFewFeatures = MatrixUtils.createRealMatrix(new double[][]{
                {1.0d, 5.0d},
                {10.0d, -1.0d},
                {-2.0d, 7.5d}
        });

        WeightedLinearRegressionResults results =
                new WeightedLinearRegressionResults(rawCoefficients, true, 1, 0.01d, stdErrors, pValues);

        Assertions.assertThrows(IllegalArgumentException.class, () -> results.predict(tooFewFeatures));
    }
}
