package org.kie.pmml.regression.tests;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.assertj.core.api.Assertions;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.kie.api.pmml.PMML4Result;
import org.kie.pmml.api.runtime.PMMLRuntime;
import org.kie.pmml.models.tests.AbstractPMMLTest;

/* loaded from: input_file:org/kie/pmml/regression/tests/LogisticRegressionSimplemaxNormalizationTest.class */
/**
 * Parameterized regression test for the {@code LogisticRegressionSimplemaxNormalization}
 * PMML model.
 *
 * <p>Each row supplied by {@link #data()} carries four numeric inputs, the expected value of
 * the {@code Species} result variable and the expected values of the three
 * {@code Probability_*} result variables. The model is evaluated once per row and every
 * probability is compared with a relative tolerance of {@code 0.001%}.
 */
public class LogisticRegressionSimplemaxNormalizationTest extends AbstractPMMLTest {

    // Both constants intentionally hold the same value: the model name matches the file name.
    private static final String FILE_NAME_NO_SUFFIX = "LogisticRegressionSimplemaxNormalization";
    private static final String MODEL_NAME = "LogisticRegressionSimplemaxNormalization";

    private static final String TARGET_FIELD = "Species";
    private static final String PROBABILITY_SETOSA_FIELD = "Probability_setosa";
    private static final String PROBABILITY_VERSICOLOR_FIELD = "Probability_versicolor";
    private static final String PROBABILITY_VIRGINICA_FIELD = "Probability_virginica";

    /** Relative tolerance (in percent) applied to every probability comparison. */
    private static final Percentage TOLERANCE_PERCENTAGE = Percentage.withPercentage(0.001d);

    /** Runtime shared by all invocations; assigned once in {@link #setupClass()}. */
    private static PMMLRuntime pmmlRuntime;

    // Per-invocation fixture values, populated by
    // initLogisticRegressionSimplemaxNormalizationTest.
    private double sepalLength;
    private double sepalWidth;
    private double petalLength;
    private double petalWidth;
    private String expectedResult;
    private double expectedSetosaProbability;
    private double expectedVersicolorProbability;
    private double expectedVirginicaProbability;

    /**
     * Stores the values of the current parameter row in the instance fields.
     *
     * @param sepalLength                   value sent as {@code Sepal.Length}
     * @param sepalWidth                    value sent as {@code Sepal.Width}
     * @param petalLength                   value sent as {@code Petal.Length}
     * @param petalWidth                    value sent as {@code Petal.Width}
     * @param expectedResult                expected {@code Species} value
     * @param expectedSetosaProbability     expected {@code Probability_setosa} value
     * @param expectedVersicolorProbability expected {@code Probability_versicolor} value
     * @param expectedVirginicaProbability  expected {@code Probability_virginica} value
     */
    public void initLogisticRegressionSimplemaxNormalizationTest(double sepalLength,
                                                                 double sepalWidth,
                                                                 double petalLength,
                                                                 double petalWidth,
                                                                 String expectedResult,
                                                                 double expectedSetosaProbability,
                                                                 double expectedVersicolorProbability,
                                                                 double expectedVirginicaProbability) {
        this.sepalLength = sepalLength;
        this.sepalWidth = sepalWidth;
        this.petalLength = petalLength;
        this.petalWidth = petalWidth;
        this.expectedResult = expectedResult;
        this.expectedSetosaProbability = expectedSetosaProbability;
        this.expectedVersicolorProbability = expectedVersicolorProbability;
        this.expectedVirginicaProbability = expectedVirginicaProbability;
    }

    /** Obtains the shared {@link PMMLRuntime} once, before any parameterized invocation. */
    @BeforeAll
    public static void setupClass() {
        pmmlRuntime = getPMMLRuntime(FILE_NAME_NO_SUFFIX);
    }

    /**
     * Supplies the parameter rows for the test.
     *
     * <p>Column order: sepal length, sepal width, petal length, petal width, expected species,
     * expected setosa / versicolor / virginica probability.
     *
     * @return a fixed-size collection with one {@code Object[]} per test invocation
     */
    public static Collection<Object[]> data() {
        // NOTE(review): the last row expects values outside [0, 1]; presumably simplemax
        // normalization does not clamp its output - confirm against the model before
        // "correcting" these reference numbers.
        return Arrays.asList(new Object[][]{
                {6.9, 3.1, 5.1, 2.3, "virginica", 0.0487181316027585, 0.0450959264075301, 0.906185941989711},
                {5.8, 2.6, 4.0, 1.2, "versicolor", 0.165004279225608, 0.59107423809292, 0.243921482681471},
                {5.4, 3.9, 1.3, 0.4, "setosa", 1.10684700233123, -0.180527000396087, 0.0736799980648569}
        });
    }

    /**
     * Evaluates the model for one parameter row and verifies the predicted species and the
     * three probability result variables.
     *
     * @throws Exception kept from the original signature; propagated unchanged if raised
     */
    @MethodSource("data")
    @ParameterizedTest
    void testLogisticRegressionWithNormalization(double sepalLength,
                                                 double sepalWidth,
                                                 double petalLength,
                                                 double petalWidth,
                                                 String expectedResult,
                                                 double expectedSetosaProbability,
                                                 double expectedVersicolorProbability,
                                                 double expectedVirginicaProbability) throws Exception {
        initLogisticRegressionSimplemaxNormalizationTest(sepalLength, sepalWidth, petalLength, petalWidth,
                expectedResult, expectedSetosaProbability, expectedVersicolorProbability,
                expectedVirginicaProbability);

        final Map<String, Object> inputData = new HashMap<>();
        inputData.put("Sepal.Length", sepalLength);
        inputData.put("Sepal.Width", sepalWidth);
        inputData.put("Petal.Length", petalLength);
        inputData.put("Petal.Width", petalWidth);

        final PMML4Result pmml4Result = evaluate(pmmlRuntime, inputData, FILE_NAME_NO_SUFFIX, MODEL_NAME);

        Assertions.assertThat(pmml4Result.getResultVariables().get(TARGET_FIELD))
                .isNotNull()
                .isEqualTo(expectedResult);
        assertProbability(pmml4Result, PROBABILITY_SETOSA_FIELD, expectedSetosaProbability);
        assertProbability(pmml4Result, PROBABILITY_VERSICOLOR_FIELD, expectedVersicolorProbability);
        assertProbability(pmml4Result, PROBABILITY_VIRGINICA_FIELD, expectedVirginicaProbability);
    }

    /**
     * Asserts that the named result variable is a {@link Double} close to the expected value.
     *
     * <p>The type check runs first so that a missing or mistyped variable fails with a message
     * naming the field instead of a bare {@code NullPointerException}/{@code ClassCastException}.
     */
    private static void assertProbability(PMML4Result pmml4Result, String fieldName, double expected) {
        final Object actual = pmml4Result.getResultVariables().get(fieldName);
        Assertions.assertThat(actual).as(fieldName).isInstanceOf(Double.class);
        Assertions.assertThat((Double) actual).as(fieldName).isCloseTo(expected, TOLERANCE_PERCENTAGE);
    }
}
