package org.kie.kogito.explainability.utils;

import java.util.LinkedList;
import java.util.List;
import java.util.stream.DoubleStream;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Unit tests for {@link LinearModel}.
 *
 * <p>Covers fitting on an empty training set (for both the classification and the regression flavour)
 * and fitting on small synthetic regression and binary-classification training sets.
 */
class LinearModelTest {

    /** Number of features (and therefore weights) of every model built in these tests. */
    private static final int SIZE = 10;

    /** Number of synthetic samples generated for the non-empty fit tests. */
    private static final int SAMPLES = 100;

    @Test
    void testEmptyFitClassificationDoesNothing() {
        LinearModel linearModel = new LinearModel(SIZE, true);
        linearModel.fit(new LinkedList<>());
        // with nothing to learn from, the weights are expected to still be all zeros
        Assertions.assertArrayEquals(new double[SIZE], linearModel.getWeights());
    }

    @Test
    void testEmptyFitRegressionDoesNothing() {
        LinearModel linearModel = new LinearModel(SIZE, false);
        linearModel.fit(new LinkedList<>());
        // with nothing to learn from, the weights are expected to still be all zeros
        Assertions.assertArrayEquals(new double[SIZE], linearModel.getWeights());
    }

    @Test
    void testRegressionFit() {
        LinearModel linearModel = new LinearModel(SIZE, false);
        List<Pair<double[], Double>> trainingSet = new LinkedList<>();
        for (int i = 0; i < SAMPLES; i++) {
            double[] x = features(i);
            // regression target: the sum of the sample's features
            trainingSet.add(new ImmutablePair<>(x, DoubleStream.of(x).sum()));
        }
        assertThat(linearModel.fit(trainingSet)).isLessThan(1.0d);
    }

    @Test
    void testClassificationFit() {
        LinearModel linearModel = new LinearModel(SIZE, true);
        List<Pair<double[], Double>> trainingSet = new LinkedList<>();
        for (int i = 0; i < SAMPLES; i++) {
            // binary label alternating with the sample index: even -> 1.0, odd -> 0.0
            double label = i % 2 == 0 ? 1.0d : 0.0d;
            trainingSet.add(new ImmutablePair<>(features(i), label));
        }
        assertThat(linearModel.fit(trainingSet)).isLessThan(1.0d);
    }

    /**
     * Builds the synthetic feature vector for sample number {@code sample}: feature {@code j} is
     * {@code sample / (j + sample)} computed in double precision.
     *
     * <p>NOTE(review): for {@code sample == 0} and {@code j == 0} this evaluates {@code 0 / 0.0}, i.e. NaN,
     * so the first generated row always contains a NaN feature (and, in the regression test, a NaN
     * target). It is kept as-is to preserve the original test data; confirm that exercising a non-finite
     * input here is intended.
     *
     * @param sample index of the sample being generated
     * @return a fresh array of {@link #SIZE} features
     */
    private static double[] features(int sample) {
        double[] x = new double[SIZE];
        for (int j = 0; j < SIZE; j++) {
            x[j] = sample / (1.0d * j + sample);
        }
        return x;
    }
}
