package org.kie.kogito.explainability.utils;

import java.util.List;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/**
 * Tests for {@link LarsPath#fit}.
 *
 * <p>Covers the order in which features enter the active set, the alpha values along the path,
 * rank-deficient / collinear designs, and rejection of inputs whose dimensions do not match.
 *
 * <p>NOTE(review): the expected actives/alphas are hard-coded fixtures whose provenance is not
 * recorded here (presumably produced by a reference LARS implementation) — TODO confirm.
 */
class LarsPathTest {

    /** Absolute tolerance used when comparing alpha paths. */
    private static final double ALPHA_TOLERANCE = 1.0E-6d;

    // ---------------------------------------------------------------------------------------
    // Fixture 1: dense 10 x 10 design with expected active-set order and alphas.
    // correctAlphas has one more entry than there are features (it ends at 0.0).
    // ---------------------------------------------------------------------------------------
    final RealMatrix X = MatrixUtils.createRealMatrix(new double[][] {
        {0.92966881d, 0.17435502d, 0.86274567d, 0.02096693d, 0.61729408d, 0.27663037d, 0.07324771d, 0.86299396d, 0.20387837d, 0.2678897d},
        {0.46124402d, 0.21212798d, 0.54547663d, 0.85310364d, 0.23584478d, 0.89939373d, 0.90052444d, 0.48947526d, 0.97695481d, 0.31682039d},
        {0.66084177d, 0.54153099d, 0.76965712d, 0.08213559d, 0.9262654d, 0.68282777d, 0.500637d, 0.76781516d, 0.14606141d, 0.53844816d},
        {0.44602165d, 0.72739983d, 0.66221962d, 0.20234917d, 0.80836334d, 0.37038587d, 0.67539221d, 0.77099063d, 0.92992129d, 0.56789747d},
        {0.67568569d, 0.37884472d, 0.18745406d, 0.04757457d, 0.09661771d, 0.50471931d, 0.35367252d, 0.75794935d, 0.6424804d, 0.55250168d},
        {0.19722479d, 0.32117211d, 0.70339706d, 0.53906674d, 0.76903061d, 0.32923893d, 0.50025901d, 0.20776133d, 0.1088789d, 0.79303772d},
        {0.31128645d, 0.05883037d, 0.64210569d, 0.88726458d, 0.19756748d, 0.02448866d, 0.2172705d, 0.27894779d, 0.55028519d, 0.70483099d},
        {0.47339132d, 0.14034869d, 0.0816702d, 0.06699631d, 0.06823621d, 0.03639515d, 0.07545303d, 0.1208853d, 0.72845905d, 0.74802801d},
        {0.99628077d, 0.83760513d, 0.63542635d, 0.07380346d, 0.79007766d, 0.55288944d, 0.44548098d, 0.4055312d, 0.70605767d, 0.83153303d},
        {0.47161946d, 0.97424448d, 0.91217761d, 0.6264732d, 0.43486423d, 0.39281956d, 0.66218207d, 0.01484187d, 0.75595905d, 0.04462323d}
    });
    final RealVector y = MatrixUtils.createRealVector(new double[] {
        6.38923853d, -2.16396995d, 7.37162403d, 1.79236199d, 4.21888433d,
        0.41875855d, -3.69136276d, -0.50760573d, 4.89875242d, -4.03316984d
    });
    final List<Integer> correctActives = List.of(7, 3, 0, 4, 8, 5, 9, 2, 1, 6);
    final RealVector correctAlphas = MatrixUtils.createRealVector(new double[] {
        1.56169836d, 0.83963397d, 0.68991047d, 0.6664122d, 0.29931992d, 0.14315316d,
        0.1200302d, 0.00776273d, 0.00389666d, 0.00187156d, 0.0d
    });

    // ---------------------------------------------------------------------------------------
    // Fixture 2: 15 x 5 binary (0/1) design, including several all-zero rows.
    // ---------------------------------------------------------------------------------------
    final RealMatrix X2 = MatrixUtils.createRealMatrix(new double[][] {
        {0.0d, 1.0d, 0.0d, 1.0d, 0.0d},
        {0.0d, 1.0d, 0.0d, 1.0d, 0.0d},
        {0.0d, 1.0d, 1.0d, 0.0d, 1.0d},
        {1.0d, 1.0d, 0.0d, 0.0d, 1.0d},
        {0.0d, 1.0d, 1.0d, 0.0d, 1.0d},
        {0.0d, 0.0d, 1.0d, 0.0d, 1.0d},
        {0.0d, 0.0d, 1.0d, 1.0d, 1.0d},
        {0.0d, 0.0d, 0.0d, 0.0d, 0.0d},
        {0.0d, 0.0d, 0.0d, 0.0d, 0.0d},
        {1.0d, 1.0d, 0.0d, 1.0d, 0.0d},
        {0.0d, 0.0d, 1.0d, 0.0d, 1.0d},
        {1.0d, 0.0d, 0.0d, 0.0d, 0.0d},
        {0.0d, 1.0d, 1.0d, 0.0d, 0.0d},
        {0.0d, 0.0d, 0.0d, 0.0d, 0.0d},
        {0.0d, 1.0d, 1.0d, 1.0d, 1.0d}
    });
    final RealVector y2 = MatrixUtils.createRealVector(new double[] {
        1.09622926d, 1.09622926d, 1.07290478d, 0.5044599d, 1.07290478d,
        0.88775703d, 1.79883855d, 0.0d, 0.0d, 1.39511415d,
        0.88775703d, 0.29888489d, 1.0524775d, 0.0d, 1.98398629d
    });
    final List<Integer> correctActives2 = List.of(1, 2, 3, 4, 0);
    final RealVector correctAlphas2 = MatrixUtils.createRealVector(new double[] {
        0.61828706d, 0.54926307d, 0.36443261d, 0.183918d, 0.04809642d, 0.0d
    });

    // ---------------------------------------------------------------------------------------
    // Fixture 3: nearly collinear rows. Rows 0, 1 and 3 are exact multiples of {1,2,3,4,5};
    // row 2 differs from 3x that row only in column 2 (8 instead of 9). Only two features are
    // expected to enter the active set.
    // ---------------------------------------------------------------------------------------
    final RealMatrix XMinAlpha = MatrixUtils.createRealMatrix(new double[][] {
        {1.0d, 2.0d, 3.0d, 4.0d, 5.0d},
        {2.0d, 4.0d, 6.0d, 8.0d, 10.0d},
        {3.0d, 6.0d, 8.0d, 12.0d, 15.0d},
        {-1.0d, -2.0d, -3.0d, -4.0d, -5.0d}
    });
    final RealVector yMinAlpha = MatrixUtils.createRealVector(new double[] {55.0d, 110.0d, 162.0d, -55.0d});
    final List<Integer> correctActivesMinAlpha = List.of(4, 2);
    final RealVector correctAlphasMinAlpha = MatrixUtils.createRealVector(new double[] {
        1020.0d, 0.681818182d, 4.16888746E-13d
    });

    // ---------------------------------------------------------------------------------------
    // Fixture 4: rank-deficient design. Row i is 5*i + {0, 1, 2, 3, 4}, so the matrix has rank
    // at most 2. The target equals 10 x column 0, so an exact (zero-error) fit exists.
    // ---------------------------------------------------------------------------------------
    final RealMatrix XDGR = MatrixUtils.createRealMatrix(new double[][] {
        {0.0d, 1.0d, 2.0d, 3.0d, 4.0d},
        {5.0d, 6.0d, 7.0d, 8.0d, 9.0d},
        {10.0d, 11.0d, 12.0d, 13.0d, 14.0d},
        {15.0d, 16.0d, 17.0d, 18.0d, 19.0d},
        {20.0d, 21.0d, 22.0d, 23.0d, 24.0d},
        {25.0d, 26.0d, 27.0d, 28.0d, 29.0d},
        {30.0d, 31.0d, 32.0d, 33.0d, 34.0d},
        {35.0d, 36.0d, 37.0d, 38.0d, 39.0d},
        {40.0d, 41.0d, 42.0d, 43.0d, 44.0d},
        {45.0d, 46.0d, 47.0d, 48.0d, 49.0d}
    });
    final RealVector yDGR = MatrixUtils.createRealVector(new double[] {
        0.0d, 50.0d, 100.0d, 150.0d, 200.0d, 250.0d, 300.0d, 350.0d, 400.0d, 450.0d
    });
    /** Unit weights, one per sample of {@link #yDGR}, used to compute an unweighted MSE. */
    final RealVector dummyWeights = this.yDGR.map(d -> 1.0d);

    // ---------------------------------------------------------------------------------------
    // Fixture 5: 6 x 6 design used with the fourth fit() argument set to true.
    // NOTE(review): that flag appears to enable variable dropping (LASSO-style). The expected
    // active set has 5 entries but 8 alphas are expected, which is consistent with a feature
    // leaving the path — confirm against LarsPath.
    // ---------------------------------------------------------------------------------------
    final RealMatrix XVarDrop = MatrixUtils.createRealMatrix(new double[][] {
        {0.18321534d, -0.35029812d, 0.07666221d, -0.0143539d, 0.07493564d, 0.0264753d},
        {-0.02852663d, 0.35584493d, 0.49064148d, -0.25236788d, 0.61892759d, -0.31065164d},
        {0.21982559d, 0.49110281d, -0.5332698d, -0.68922198d, -0.52132598d, 0.56451554d},
        {0.61981706d, 0.26882512d, 0.50779668d, 0.36052515d, 0.13083871d, 0.41441906d},
        {-0.63640374d, -0.65128988d, -0.45031639d, 0.57505155d, 0.22006143d, -0.63980788d},
        {-0.35792762d, -0.11418486d, -0.09151419d, 0.02036707d, -0.52343739d, -0.05495039d}
    });
    final RealVector yVarDrop = MatrixUtils.createRealVector(new double[] {
        0.00357273d, 0.008411d, 0.33522509d, 0.07329731d, -0.24901509d, -0.17149104d
    });
    final List<Integer> correctActivesVarDrop = List.of(1, 3, 0, 2, 4);
    final RealVector correctAlphasVarDrop = MatrixUtils.createRealVector(new double[] {
        0.0643070982530952d, 0.0545709061429076d, 0.051526183371599d, 0.0273483558266873d,
        0.0086875666842567d, 0.0056182106014212d, 0.0037483228269213d, 0.0d
    });

    /**
     * Fits fixture 1 with the third argument set to {@code i}.
     *
     * <p>Checks that the first {@code i} active features and alphas match the reference values.
     */
    @ParameterizedTest
    @ValueSource(ints = {2, 3, 4, 5, 6, 7, 8, 9, 10})
    void testLars10(int i) {
        LarsPathResults fit = LarsPath.fit(this.X, this.y, i, false);
        Assertions.assertEquals(this.correctActives.subList(0, i), fit.getActive().subList(0, i));
        Assertions.assertArrayEquals(
                this.correctAlphas.getSubVector(0, i).toArray(),
                fit.getAlphas().getSubVector(0, i).toArray(),
                ALPHA_TOLERANCE);
    }

    /**
     * Fits the binary fixture 2 with the third argument set to {@code i}.
     *
     * <p>Checks the first {@code i} active features and alphas against the reference values.
     */
    @ParameterizedTest
    @ValueSource(ints = {5})
    void testLars5(int i) {
        LarsPathResults fit = LarsPath.fit(this.X2, this.y2, i, false);
        Assertions.assertEquals(this.correctActives2.subList(0, i), fit.getActive().subList(0, i));
        Assertions.assertArrayEquals(
                this.correctAlphas2.getSubVector(0, i).toArray(),
                fit.getAlphas().getSubVector(0, i).toArray(),
                ALPHA_TOLERANCE);
    }

    /**
     * Fits the near-collinear fixture 3 with a generous third argument (500).
     *
     * <p>Checks that the complete active set and alpha path equal the expected values.
     */
    @Test
    void testLarsMinAlpha() {
        LarsPathResults fit = LarsPath.fit(this.XMinAlpha, this.yMinAlpha, 500, false);
        Assertions.assertEquals(this.correctActivesMinAlpha, fit.getActive());
        Assertions.assertArrayEquals(
                this.correctAlphasMinAlpha.toArray(), fit.getAlphas().toArray(), ALPHA_TOLERANCE);
    }

    /**
     * Fits the rank-deficient fixture 4 and checks that the final coefficients fit the data.
     *
     * <p>The test takes the last column of the coefficient matrix as the end of the path. Since
     * an exact solution exists, the resulting MSE is expected to be effectively zero.
     */
    @Test
    void testLarsDGR() {
        RealMatrix coefs = LarsPath.fit(this.XDGR, this.yDGR, 500, false).getCoefs();
        RealVector finalCoefs = coefs.getColumnVector(coefs.getColumnDimension() - 1);
        double mse = WeightedLinearRegression.getMSE(this.XDGR, this.yDGR, this.dummyWeights, finalCoefs);
        Assertions.assertTrue(mse < 1.0E-16d, () -> "MSE of final LARS coefficients too large: " + mse);
    }

    /**
     * Fits fixture 5 with the fourth argument set to {@code true}.
     *
     * <p>Checks the full active set and alpha path.
     */
    @Test
    void testLarsVarDrop() {
        LarsPathResults fit = LarsPath.fit(this.XVarDrop, this.yVarDrop, 500, true);
        Assertions.assertEquals(this.correctActivesVarDrop, fit.getActive());
        Assertions.assertArrayEquals(
                this.correctAlphasVarDrop.toArray(), fit.getAlphas().toArray(), ALPHA_TOLERANCE);
    }

    /**
     * Checks that fitting with mismatched inputs throws {@link IllegalArgumentException}.
     *
     * <p>The design matrix has 6 rows while the target vector has 10 entries.
     */
    @Test
    void testLarsMismatchedInputs() {
        Assertions.assertThrows(
                IllegalArgumentException.class,
                () -> LarsPath.fit(this.XVarDrop, this.yDGR, 500, true));
    }
}
