/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.regression.tests;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import org.assertj.core.api.Assertions;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.kie.api.pmml.PMML4Result;
import org.kie.pmml.api.runtime.PMMLRuntime;
import org.kie.pmml.models.tests.AbstractPMMLTest;

public class LogisticRegressionIrisDataTest
extends AbstractPMMLTest {
    private static final String FILE_NAME_NO_SUFFIX = "LogisticRegressionIrisData";
    private static final String MODEL_NAME = "LogisticRegressionIrisData";
    private static final String TARGET_FIELD = "Species";
    private static final String PROBABILITY_SETOSA_FIELD = "Probability_setosa";
    private static final String PROBABILITY_VERSICOLOR_FIELD = "Probability_versicolor";
    private static final String PROBABILITY_VIRGINICA_FIELD = "Probability_virginica";
    private static final Percentage TOLERANCE_PERCENTAGE = Percentage.withPercentage((double)0.001);
    private static PMMLRuntime pmmlRuntime;
    private double sepalLength;
    private double sepalWidth;
    private double petalLength;
    private double petalWidth;
    private String expectedResult;

    public void initLogisticRegressionIrisDataTest(double sepalLength, double sepalWidth, double petalLength, double petalWidth, String expectedResult) {
        this.sepalLength = sepalLength;
        this.sepalWidth = sepalWidth;
        this.petalLength = petalLength;
        this.petalWidth = petalWidth;
        this.expectedResult = expectedResult;
    }

    @BeforeAll
    public static void setupClass() {
        pmmlRuntime = LogisticRegressionIrisDataTest.getPMMLRuntime((String)"LogisticRegressionIrisData");
    }

    public static Collection<Object[]> data() {
        return Arrays.asList({6.9, 3.1, 5.1, 2.3, "virginica"}, {5.8, 2.6, 4.0, 1.2, "versicolor"}, {5.7, 3.0, 4.2, 1.2, "versicolor"}, {5.0, 3.3, 1.4, 0.2, "setosa"}, {5.4, 3.9, 1.3, 0.4, "setosa"});
    }

    @MethodSource(value={"data"})
    @ParameterizedTest
    void testLogisticRegressionIrisData(double sepalLength, double sepalWidth, double petalLength, double petalWidth, String expectedResult) {
        this.initLogisticRegressionIrisDataTest(sepalLength, sepalWidth, petalLength, petalWidth, expectedResult);
        HashMap<String, Double> inputData = new HashMap<String, Double>();
        inputData.put("Sepal.Length", sepalLength);
        inputData.put("Sepal.Width", sepalWidth);
        inputData.put("Petal.Length", petalLength);
        inputData.put("Petal.Width", petalWidth);
        PMML4Result pmml4Result = this.evaluate(pmmlRuntime, inputData, "LogisticRegressionIrisData", "LogisticRegressionIrisData");
        Assertions.assertThat(pmml4Result.getResultVariables().get(TARGET_FIELD)).isNotNull();
        Assertions.assertThat(pmml4Result.getResultVariables().get(TARGET_FIELD)).isEqualTo((Object)expectedResult);
        Assertions.assertThat((double)((Double)pmml4Result.getResultVariables().get(PROBABILITY_SETOSA_FIELD))).isCloseTo(this.setosaProbability(), TOLERANCE_PERCENTAGE);
        Assertions.assertThat((double)((Double)pmml4Result.getResultVariables().get(PROBABILITY_VERSICOLOR_FIELD))).isCloseTo(this.versicolorProbability(), TOLERANCE_PERCENTAGE);
        Assertions.assertThat((double)((Double)pmml4Result.getResultVariables().get(PROBABILITY_VIRGINICA_FIELD))).isCloseTo(this.virginicaProbability(), TOLERANCE_PERCENTAGE);
    }

    private double setosaProbability() {
        return 0.0660297693761902 * this.sepalLength + 0.242847872054487 * this.sepalWidth + -0.224657116235727 * this.petalLength + -0.0574727291860025 * this.petalWidth + 0.11822288946815;
    }

    private double versicolorProbability() {
        return -0.0201536848255179 * this.sepalLength + -0.44561625761404 * this.sepalWidth + 0.22066920522933 * this.petalLength + -0.494306595747785 * this.petalWidth + 1.57705897385745;
    }

    private double virginicaProbability() {
        return -0.0458760845506725 * this.sepalLength + 0.202768385559553 * this.sepalWidth + 0.00398791100639665 * this.petalLength + 0.551779324933787 * this.petalWidth - 0.695281863325603;
    }
}

