package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.shap.ShapConfig;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;

/* loaded from: input_file:org/kie/kogito/explainability/local/shap/ShapKernelExplainerTest.class */
class ShapKernelExplainerTest {
    double[][] backgroundRaw = {new double[]{1.0d, 2.0d, 3.0d, -4.0d, 5.0d}, new double[]{10.0d, 11.0d, 12.0d, -4.0d, 13.0d}, new double[]{2.0d, 3.0d, 4.0d, -4.0d, 6.0d}};
    double[][] toExplainRaw = {new double[]{5.0d, 6.0d, 7.0d, -4.0d, 8.0d}, new double[]{11.0d, 12.0d, 13.0d, -5.0d, 14.0d}, new double[]{0.0d, 0.0d, 1.0d, 4.0d, 2.0d}};
    double[][] backgroundNoVariance = {new double[]{1.0d, 2.0d, 3.0d}, new double[]{1.0d, 2.0d, 3.0d}};
    double[][] toExplainZeroVariance = {new double[]{1.0d, 2.0d, 3.0d}, new double[]{1.0d, 2.0d, 3.0d}};
    double[][][] zeroVarianceOneOutputSHAP = {new double[]{new double[]{0.0d, 0.0d, 0.0d}}, new double[]{new double[]{0.0d, 0.0d, 0.0d}}};
    double[][][] zeroVarianceMultiOutputSHAP = {new double[]{new double[]{0.0d, 0.0d, 0.0d}, new double[]{0.0d, 0.0d, 0.0d}}, new double[]{new double[]{0.0d, 0.0d, 0.0d}, new double[]{0.0d, 0.0d, 0.0d}}};
    double[][] toExplainOneVariance = {new double[]{3.0d, 2.0d, 3.0d}, new double[]{1.0d, 2.0d, 2.0d}};
    double[][][] oneVarianceOneOutputSHAP = {new double[]{new double[]{2.0d, 0.0d, 0.0d}}, new double[]{new double[]{0.0d, 0.0d, -1.0d}}};
    double[][][] oneVarianceMultiOutputSHAP = {new double[]{new double[]{2.0d, 0.0d, 0.0d}, new double[]{4.0d, 0.0d, 0.0d}}, new double[]{new double[]{0.0d, 0.0d, -1.0d}, new double[]{0.0d, 0.0d, -2.0d}}};
    double[][] toExplainLogit = {new double[]{0.1d, 0.12d, 0.14d, -0.08d, 0.16d}, new double[]{0.22d, 0.24d, 0.26d, -0.1d, 0.38d}, new double[]{-0.1d, 0.0d, 0.02d, 0.1d, 0.04d}};
    double[][] backgroundLogit = {new double[]{0.02380952d, 0.04761905d, 0.07142857d, -0.0952381d, 0.11904762d}, new double[]{0.23809524d, 0.26190476d, 0.28571429d, -0.0952381d, 0.30952381d}, new double[]{0.04761905d, 0.07142857d, 0.11904762d, -0.0952381d, 0.14285714d}};
    double[][][] logitSHAP = {new double[]{new double[]{-0.01420862d, 0.0d, -0.08377778d, 0.06825253d, -0.13625127d}}, new double[]{new double[]{0.50970797d, 0.0d, 0.44412765d, -0.02169177d, 0.80832232d}}, new double[]{new double[]{Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN}}};
    double[][][] multiVarianceOneOutputSHAP = {new double[]{new double[]{0.66666667d, 0.0d, 0.66666667d, 0.0d, 0.0d}}, new double[]{new double[]{6.66666667d, 0.0d, 6.66666667d, -1.0d, 6.0d}}, new double[]{new double[]{-4.33333333d, 0.0d, -5.33333333d, 8.0d, -6.0d}}};
    double[][][] multiVarianceMultiOutputSHAP = {new double[]{new double[]{0.66666667d, 0.0d, 0.66666667d, 0.0d, 0.0d}, new double[]{1.333333333d, 0.0d, 1.33333333d, 0.0d, 0.0d}}, new double[]{new double[]{6.66666667d, 0.0d, 6.66666667d, -1.0d, 6.0d}, new double[]{13.333333333d, 0.0d, 13.333333333d, -2.0d, 12.0d}}, new double[]{new double[]{-4.33333333d, 0.0d, -5.33333333d, 8.0d, -6.0d}, new double[]{-8.6666666667d, 0.0d, -10.666666666d, 16.0d, -12.0d}}};
    PerturbationContext pc = new PerturbationContext(new Random(0), 0);
    ShapConfig.Builder testConfig = ShapConfig.builder().withLink(ShapConfig.LinkType.IDENTITY).withPC(this.pc);
    double[][] backgroundAllZeros = new double[100][6];
    double[][] toExplainAllOnes = {new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d}};

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r1v1, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v11, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r1v13, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v15, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r1v17, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r1v19, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v21, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v23, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r1v25, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r1v27, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r1v3, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v35, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v5, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v7, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v9, types: [double[][], double[][][]] */
    ShapKernelExplainerTest() {
    }

    private List<PredictionInput> createPIFromMatrix(double[][] dArr) {
        ArrayList arrayList = new ArrayList();
        int[] iArr = {dArr.length, dArr[0].length};
        for (int i = 0; i < iArr[0]; i++) {
            ArrayList arrayList2 = new ArrayList();
            for (int i2 = 0; i2 < iArr[1]; i2++) {
                arrayList2.add(FeatureFactory.newNumericalFeature("f", Double.valueOf(dArr[i][i2])));
            }
            arrayList.add(new PredictionInput(arrayList2));
        }
        return arrayList;
    }

    private RealMatrix[] saliencyToMatrix(Saliency[] saliencyArr) {
        RealMatrix createRealMatrix = MatrixUtils.createRealMatrix(new double[saliencyArr.length][saliencyArr[0].getPerFeatureImportance().size()]);
        RealMatrix[] realMatrixArr = {createRealMatrix.copy(), createRealMatrix.copy()};
        for (int i = 0; i < saliencyArr.length; i++) {
            List perFeatureImportance = saliencyArr[i].getPerFeatureImportance();
            for (int i2 = 0; i2 < perFeatureImportance.size(); i2++) {
                realMatrixArr[0].setEntry(i, i2, ((FeatureImportance) perFeatureImportance.get(i2)).getScore());
                realMatrixArr[1].setEntry(i, i2, ((FeatureImportance) perFeatureImportance.get(i2)).getConfidence());
            }
        }
        return realMatrixArr;
    }

    private void shapTestCase(PredictionProvider predictionProvider, ShapConfig shapConfig, double[][] dArr, double[][][] dArr2) throws InterruptedException, TimeoutException, ExecutionException {
        List<PredictionInput> createPIFromMatrix = createPIFromMatrix(dArr);
        List list = (List) predictionProvider.predictAsync(createPIFromMatrix).get(5L, TimeUnit.SECONDS);
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < list.size(); i++) {
            arrayList.add(new SimplePrediction(createPIFromMatrix.get(i), (PredictionOutput) list.get(i)));
        }
        ShapKernelExplainer shapKernelExplainer = new ShapKernelExplainer(shapConfig);
        for (int i2 = 0; i2 < createPIFromMatrix.size(); i2++) {
            RealMatrix realMatrix = saliencyToMatrix(((ShapResults) shapKernelExplainer.explainAsync((Prediction) arrayList.get(i2), predictionProvider).get(5L, TimeUnit.SECONDS)).getSaliencies())[0];
            for (int i3 = 0; i3 < realMatrix.getRowDimension(); i3++) {
                Assertions.assertArrayEquals(dArr2[i2][i3], realMatrix.getRow(i3), 1.0E-6d);
            }
        }
    }

    private void shapTestCase(PredictionProvider predictionProvider, ShapKernelExplainer shapKernelExplainer, double[][] dArr, double[][][] dArr2) throws InterruptedException, TimeoutException, ExecutionException {
        List<PredictionInput> createPIFromMatrix = createPIFromMatrix(dArr);
        List list = (List) predictionProvider.predictAsync(createPIFromMatrix).get(5L, TimeUnit.SECONDS);
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < list.size(); i++) {
            arrayList.add(new SimplePrediction(createPIFromMatrix.get(i), (PredictionOutput) list.get(i)));
        }
        for (int i2 = 0; i2 < createPIFromMatrix.size(); i2++) {
            RealMatrix realMatrix = saliencyToMatrix(((ShapResults) shapKernelExplainer.explainAsync((Prediction) arrayList.get(i2), predictionProvider).get(5L, TimeUnit.SECONDS)).getSaliencies())[0];
            for (int i3 = 0; i3 < realMatrix.getRowDimension(); i3++) {
                Assertions.assertArrayEquals(dArr2[i2][i3], realMatrix.getRow(i3), 1.0E-6d);
            }
        }
    }

    @Test
    void testNoVarianceOneOutput() throws InterruptedException, TimeoutException, ExecutionException {
        shapTestCase(TestUtils.getSumSkipModel(1), this.testConfig.withBackground(createPIFromMatrix(this.backgroundNoVariance)).withNSamples(100).build(), this.toExplainZeroVariance, this.zeroVarianceOneOutputSHAP);
    }

    @Test
    void testOneVarianceOneOutput() throws InterruptedException, TimeoutException, ExecutionException {
        shapTestCase(TestUtils.getSumSkipModel(1), this.testConfig.withBackground(createPIFromMatrix(this.backgroundNoVariance)).withNSamples(100).build(), this.toExplainOneVariance, this.oneVarianceOneOutputSHAP);
    }

    @Test
    void testMultiVarianceOneOutput() throws InterruptedException, TimeoutException, ExecutionException {
        shapTestCase(TestUtils.getSumSkipModel(1), this.testConfig.withBackground(createPIFromMatrix(this.backgroundRaw)).withNSamples(35).build(), this.toExplainRaw, this.multiVarianceOneOutputSHAP);
    }

    @Test
    void testMultiVarianceOneOutputLogit() throws InterruptedException, TimeoutException, ExecutionException {
        shapTestCase(TestUtils.getSumSkipModel(1), ShapConfig.builder().withBackground(createPIFromMatrix(this.backgroundLogit)).withLink(ShapConfig.LinkType.LOGIT).withNSamples(100).withPC(this.pc).build(), this.toExplainLogit, this.logitSHAP);
    }

    @Test
    void testNoVarianceMultiOutput() throws InterruptedException, TimeoutException, ExecutionException {
        shapTestCase(TestUtils.getSumSkipTwoOutputModel(1), this.testConfig.withBackground(createPIFromMatrix(this.backgroundNoVariance)).build(), this.toExplainZeroVariance, this.zeroVarianceMultiOutputSHAP);
    }

    @Test
    void testOneVarianceMultiOutput() throws InterruptedException, TimeoutException, ExecutionException {
        shapTestCase(TestUtils.getSumSkipTwoOutputModel(1), this.testConfig.withBackground(createPIFromMatrix(this.backgroundNoVariance)).build(), this.toExplainOneVariance, this.oneVarianceMultiOutputSHAP);
    }

    @Test
    void testMultiVarianceMultiOutput() throws InterruptedException, TimeoutException, ExecutionException {
        shapTestCase(TestUtils.getSumSkipTwoOutputModel(1), this.testConfig.withBackground(createPIFromMatrix(this.backgroundRaw)).build(), this.toExplainRaw, this.multiVarianceMultiOutputSHAP);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [double[], double[][]] */
    @Test
    void testLargeBackground() throws InterruptedException, TimeoutException, ExecutionException {
        double[][] dArr = new double[100][10];
        for (int i = 0; i < 100; i++) {
            for (int i2 = 0; i2 < 10; i2++) {
                dArr[i][i2] = (i / 100.0d) + i2;
            }
        }
        double[][] dArr2 = {new double[]{new double[]{-0.495d, 0.0d, -4.495d, 0.005d, -8.595d, 0.005d, -18.495d, -6.695d, -8.385d, 5.505d}}};
        List<PredictionInput> createPIFromMatrix = createPIFromMatrix(dArr);
        List<PredictionInput> createPIFromMatrix2 = createPIFromMatrix(new double[]{new double[]{0.0d, 1.0d, -2.0d, 3.5d, -4.1d, 5.5d, -12.0d, 0.8d, 0.11d, 15.0d}});
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        ShapConfig build = this.testConfig.withBackground(createPIFromMatrix).build();
        List list = (List) sumSkipModel.predictAsync(createPIFromMatrix2).get();
        ArrayList arrayList = new ArrayList();
        for (int i3 = 0; i3 < list.size(); i3++) {
            arrayList.add(new SimplePrediction(createPIFromMatrix2.get(i3), (PredictionOutput) list.get(i3)));
        }
        ShapKernelExplainer shapKernelExplainer = new ShapKernelExplainer(build);
        for (int i4 = 0; i4 < createPIFromMatrix2.size(); i4++) {
            RealMatrix realMatrix = saliencyToMatrix(((ShapResults) shapKernelExplainer.explainAsync((Prediction) arrayList.get(i4), sumSkipModel).get(5L, TimeUnit.SECONDS)).getSaliencies())[0];
            for (int i5 = 0; i5 < realMatrix.getRowDimension(); i5++) {
                Assertions.assertArrayEquals(dArr2[i4][i5], realMatrix.getRow(i5), 0.01d);
            }
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v7, types: [double[][], double[][][]] */
    @Test
    void testParallel() throws InterruptedException, ExecutionException {
        double[][] dArr = new double[100][10];
        for (int i = 0; i < 100; i++) {
            for (int i2 = 0; i2 < 10; i2++) {
                dArr[i][i2] = (i / 100.0d) + i2;
            }
        }
        ?? r0 = {new double[]{new double[]{-0.495d, 0.0d, -4.495d, 0.005d, -8.595d, 0.005d, -18.495d, -6.695d, -8.385d, 5.505d}}};
        List<PredictionInput> createPIFromMatrix = createPIFromMatrix(dArr);
        List<PredictionInput> createPIFromMatrix2 = createPIFromMatrix(new double[]{new double[]{0.0d, 1.0d, -2.0d, 3.5d, -4.1d, 5.5d, -12.0d, 0.8d, 0.11d, 15.0d}});
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        ShapConfig build = this.testConfig.withBackground(createPIFromMatrix).build();
        List list = (List) sumSkipModel.predictAsync(createPIFromMatrix2).get();
        ArrayList arrayList = new ArrayList();
        for (int i3 = 0; i3 < list.size(); i3++) {
            arrayList.add(new SimplePrediction(createPIFromMatrix2.get(i3), (PredictionOutput) list.get(i3)));
        }
        CompletableFuture explainAsync = new ShapKernelExplainer(build).explainAsync((Prediction) arrayList.get(0), sumSkipModel);
        ForkJoinPool.commonPool().submit(() -> {
            Assertions.assertArrayEquals(r0[0][0], saliencyToMatrix(((ShapResults) explainAsync.join()).getSaliencies())[0].getRow(0), 0.01d);
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [double[], double[][]] */
    @Test
    void testTooLargeBackground() throws InterruptedException, TimeoutException, ExecutionException {
        double[][] dArr = new double[10][10];
        for (int i = 0; i < 10; i++) {
            for (int i2 = 0; i2 < 10; i2++) {
                dArr[i][i2] = (i / 10.0d) + i2;
            }
        }
        List<PredictionInput> createPIFromMatrix = createPIFromMatrix(dArr);
        List<PredictionInput> createPIFromMatrix2 = createPIFromMatrix(new double[]{new double[]{0.0d, 1.0d, 2.0d, 3.0d, 4.0d}});
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        ShapConfig build = this.testConfig.withBackground(createPIFromMatrix).build();
        List list = (List) sumSkipModel.predictAsync(createPIFromMatrix2).get(5L, TimeUnit.SECONDS);
        ArrayList arrayList = new ArrayList();
        for (int i3 = 0; i3 < list.size(); i3++) {
            arrayList.add(new SimplePrediction(createPIFromMatrix2.get(i3), (PredictionOutput) list.get(i3)));
        }
        Prediction prediction = (Prediction) arrayList.get(0);
        ShapKernelExplainer shapKernelExplainer = new ShapKernelExplainer(build);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            shapKernelExplainer.explainAsync(prediction, sumSkipModel);
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [double[], double[][]] */
    @Test
    void testPredictionWrongSize() throws InterruptedException, TimeoutException, ExecutionException {
        double[][] dArr = new double[5][5];
        for (int i = 0; i < 5; i++) {
            for (int i2 = 0; i2 < 5; i2++) {
                dArr[i][i2] = (i / 5.0d) + i2;
            }
        }
        List<PredictionInput> createPIFromMatrix = createPIFromMatrix(dArr);
        List<PredictionInput> createPIFromMatrix2 = createPIFromMatrix(new double[]{new double[]{0.0d, 1.0d, 2.0d, 3.0d, 4.0d}});
        PredictionProvider sumSkipTwoOutputModel = TestUtils.getSumSkipTwoOutputModel(1);
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        ShapConfig build = this.testConfig.withBackground(createPIFromMatrix).build();
        List list = (List) sumSkipTwoOutputModel.predictAsync(createPIFromMatrix2).get(5L, TimeUnit.SECONDS);
        ArrayList arrayList = new ArrayList();
        for (int i3 = 0; i3 < list.size(); i3++) {
            arrayList.add(new SimplePrediction(createPIFromMatrix2.get(i3), (PredictionOutput) list.get(i3)));
        }
        Prediction prediction = (Prediction) arrayList.get(0);
        ShapKernelExplainer shapKernelExplainer = new ShapKernelExplainer(build);
        Assertions.assertThrows(ExecutionException.class, () -> {
            shapKernelExplainer.explainAsync(prediction, sumSkipModel).get();
        });
    }

    @Test
    void testStateless() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        ShapConfig build = this.testConfig.withBackground(createPIFromMatrix(this.backgroundNoVariance)).withNSamples(100).build();
        ShapConfig build2 = this.testConfig.withBackground(createPIFromMatrix(this.backgroundRaw)).withNSamples(35).build();
        ShapConfig build3 = ShapConfig.builder().withBackground(createPIFromMatrix(this.backgroundLogit)).withLink(ShapConfig.LinkType.LOGIT).withNSamples(100).withPC(this.pc).build();
        ShapKernelExplainer shapKernelExplainer = new ShapKernelExplainer(build);
        for (int i = 0; i < 10; i++) {
            shapTestCase(sumSkipModel, shapKernelExplainer, this.toExplainOneVariance, this.oneVarianceOneOutputSHAP);
            shapKernelExplainer.setConfig(build2);
            shapTestCase(sumSkipModel, shapKernelExplainer, this.toExplainRaw, this.multiVarianceOneOutputSHAP);
            shapKernelExplainer.setConfig(build3);
            shapTestCase(sumSkipModel, shapKernelExplainer, this.toExplainLogit, this.logitSHAP);
            shapKernelExplainer.setConfig(build);
        }
    }

    @ValueSource(doubles = {0.001d, 0.1d, 0.25d, 0.5d})
    @ParameterizedTest
    void testErrorBounds(double d) throws InterruptedException, ExecutionException {
        for (double d2 : new double[]{0.95d, 0.975d, 0.99d}) {
            int[] iArr = new int[600];
            for (int i = 0; i < 100; i++) {
                PredictionProvider noisySumModel = TestUtils.getNoisySumModel(this.pc.getRandom(), d);
                ShapConfig build = this.testConfig.withBackground(createPIFromMatrix(this.backgroundAllZeros)).withConfidence(d2).build();
                List<PredictionInput> createPIFromMatrix = createPIFromMatrix(this.toExplainAllOnes);
                RealMatrix[] saliencyToMatrix = saliencyToMatrix(((ShapResults) new ShapKernelExplainer(build).explainAsync(new SimplePrediction(createPIFromMatrix.get(0), (PredictionOutput) ((List) noisySumModel.predictAsync(createPIFromMatrix).get()).get(0)), noisySumModel).get()).getSaliencies());
                RealMatrix realMatrix = saliencyToMatrix[0];
                RealMatrix realMatrix2 = saliencyToMatrix[1];
                for (int i2 = 0; i2 < realMatrix.getRowDimension(); i2++) {
                    for (int i3 = 0; i3 < realMatrix.getColumnDimension(); i3++) {
                        double entry = realMatrix2.getEntry(i2, i3);
                        double entry2 = realMatrix.getEntry(i2, i3);
                        iArr[(i * 6) + i3] = (((entry2 + entry) > 1.0d ? 1 : ((entry2 + entry) == 1.0d ? 0 : -1)) > 0) & ((1.0d > (entry2 - entry) ? 1 : (1.0d == (entry2 - entry) ? 0 : -1)) > 0) ? 1 : 0;
                    }
                }
            }
            Assertions.assertEquals(d2, Arrays.stream(iArr).sum() / 600.0d, 0.05d);
        }
    }
}
