package org.kie.pmml.pmml_4_2.predictive.models;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kie.api.KieBaseConfiguration;
import org.kie.api.pmml.PMML4Result;
import org.kie.api.pmml.PMMLRequestData;
import org.kie.internal.io.ResourceFactory;
import org.kie.pmml.pmml_4_2.DroolsAbstractPMMLTest;
import org.kie.pmml.pmml_4_2.PMML4ExecutionHelper;

/**
 * Parameterized test that feeds each (fld1, fld2, fld3) triple from {@link #data()} to two PMML
 * models loaded from the classpath, then compares the model outputs with values computed locally
 * in this class.
 */
@RunWith(Parameterized.class)
public class SimpleRegressionTest extends DroolsAbstractPMMLTest {
    /** Model whose "Fld4" output is checked by {@link #testRegression()}. */
    private static final String source1 = "org/kie/pmml/pmml_4_2/test_regression.pmml";
    /** Model whose "RegOut"/"RegProb"/"RegProbA" outputs are checked by {@link #testClassification()}. */
    private static final String source2 = "org/kie/pmml/pmml_4_2/test_regression_clax.pmml";
    /** Absolute tolerance used for every floating-point comparison in this test. */
    private static final double COMPARISON_DELTA = 1.0E-6d;

    private final double fld1;
    private final double fld2;
    private final String fld3;

    /** Computes the unnormalized score of one category from the three input fields. */
    @FunctionalInterface
    public interface RegressionInterface {
        Double apply(double d, double d2, String str);
    }

    /**
     * Supplies the parameter rows; each row is passed to the constructor in (fld1, fld2, fld3) order.
     *
     * @return the parameter rows
     */
    @Parameterized.Parameters
    public static Collection<Object[]> data() {
        // All numeric values are written as doubles so they match the constructor's double
        // parameters directly instead of relying on reflective int-to-double widening.
        return Arrays.asList(
                new Object[]{1.0d, 1.0d, "x"},
                new Object[]{0.9d, 0.3d, "x"},
                new Object[]{12.0d, 25.0d, "x"},
                new Object[]{0.2d, 0.1d, "x"},
                new Object[]{5.0d, 8.0d, "y"});
    }

    /**
     * @param d   value sent as request parameter "fld1"
     * @param d2  value sent as request parameter "fld2"
     * @param str value sent as request parameter "fld3"
     */
    public SimpleRegressionTest(double d, double d2, String str) {
        this.fld1 = d;
        this.fld2 = d2;
        this.fld3 = str;
    }

    /**
     * Runs the regression model and checks that its "Fld4" value matches
     * {@link #simpleRegressionResult(double, double, String)}.
     */
    @Test
    public void testRegression() throws Exception {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "LinReg", ResourceFactory.newClassPathResource(source1), (KieBaseConfiguration) null);
        PMML4Result result = executionHelper.submitRequest(buildRequest());

        Assert.assertEquals("OK", result.getResultCode());
        Assert.assertNotNull(result.getResultValue("Fld4", (String) null, new Object[0]));
        Double fld4 = (Double) result.getResultValue("Fld4", "value", Double.class, new Object[0]).orElse(null);
        Assert.assertNotNull(fld4);
        Assert.assertEquals(simpleRegressionResult(fld1, fld2, fld3), fld4.doubleValue(), COMPARISON_DELTA);
    }

    /**
     * Expected "Fld4" value: the logistic function applied to a polynomial in the inputs.
     *
     * @param d   fld1 value
     * @param d2  fld2 value
     * @param str fld3 category
     * @return {@code 1 / (1 + exp(-z))}, where z is the linear predictor computed below
     */
    private double simpleRegressionResult(double d, double d2, String str) {
        double linearPredictor = 0.5d + 5.0d * d * d + 2.0d * d2 + fld3Coefficient(str) + 0.4d * d * d2;
        return 1.0d / (1.0d + Math.exp(-linearPredictor));
    }

    /**
     * Runs the classification model and checks the following against {@link #categoryProbabilities}:
     * the predicted category ("RegOut"), its probability ("RegProb") and the probability of category
     * "A" ("RegProbA").
     */
    @Test
    public void testClassification() throws Exception {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "LinReg", ResourceFactory.newClassPathResource(source2), (KieBaseConfiguration) null);
        PMML4Result result = executionHelper.submitRequest(buildRequest());

        Map<String, Double> expectedProbabilities = categoryProbabilities(fld1, fld2, fld3);

        // Arg-max over the expected probabilities. Seed with NEGATIVE_INFINITY: Double.MIN_VALUE is
        // the smallest *positive* double, so it is not a valid "lower than everything" sentinel.
        String expectedCategory = null;
        double expectedMaxProbability = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, Double> entry : expectedProbabilities.entrySet()) {
            double probability = entry.getValue().doubleValue();
            if (probability > expectedMaxProbability) {
                expectedCategory = entry.getKey();
                expectedMaxProbability = probability;
            }
        }
        // Guards against every probability being NaN, which would otherwise compare against "catnull".
        Assert.assertNotNull(expectedCategory);

        Assert.assertNotNull(result.getResultValue("RegOut", (String) null, new Object[0]));
        Assert.assertNotNull(result.getResultValue("RegProb", (String) null, new Object[0]));
        Assert.assertNotNull(result.getResultValue("RegProbA", (String) null, new Object[0]));
        String regOut = (String) result.getResultValue("RegOut", "value", String.class, new Object[0]).orElse(null);
        Double regProb = (Double) result.getResultValue("RegProb", "value", Double.class, new Object[0]).orElse(null);
        Double regProbA = (Double) result.getResultValue("RegProbA", "value", Double.class, new Object[0]).orElse(null);

        // Report a missing value as an assertion failure rather than an NPE on unboxing.
        Assert.assertNotNull(regProb);
        Assert.assertNotNull(regProbA);
        Assert.assertEquals("cat" + expectedCategory, regOut);
        Assert.assertEquals(expectedMaxProbability, regProb.doubleValue(), COMPARISON_DELTA);
        Assert.assertEquals(expectedProbabilities.get("A").doubleValue(), regProbA.doubleValue(), COMPARISON_DELTA);
    }

    /**
     * Builds the request used by both tests: id "123", model "LinReg", and this instance's
     * fld1/fld2/fld3 values as parameters.
     */
    private PMMLRequestData buildRequest() {
        PMMLRequestData request = new PMMLRequestData("123", "LinReg");
        request.addRequestParam("fld1", Double.valueOf(fld1));
        request.addRequestParam("fld2", Double.valueOf(fld2));
        request.addRequestParam("fld3", fld3);
        return request;
    }

    /**
     * Computes the expected per-category probabilities. Each category gets a score, and the
     * scores are normalized with softmax: {@code exp(score) / sum(exp(scores))}.
     *
     * @param d   fld1 value
     * @param d2  fld2 value
     * @param str fld3 category
     * @return map from category name ("A".."D") to its probability
     */
    private Map<String, Double> categoryProbabilities(double d, double d2, String str) {
        Map<String, RegressionInterface> scorers = new HashMap<>();
        scorers.put("A", (v1, v2, category) -> 0.1d + v1 + v2 + fld3Coefficient(category));
        scorers.put("B", (v1, v2, category) -> 0.2d + 2.0d * v1 + 2.0d * v2 + fld3Coefficient(category));
        scorers.put("C", (v1, v2, category) -> 0.3d + 3.0d * v1 + 3.0d * v2 + fld3Coefficient(category));
        scorers.put("D", (v1, v2, category) -> 5.0d);

        Map<String, Double> probabilities = new HashMap<>();
        double total = 0.0d;
        for (Map.Entry<String, RegressionInterface> entry : scorers.entrySet()) {
            double weight = Math.exp(entry.getValue().apply(d, d2, str).doubleValue());
            probabilities.put(entry.getKey(), weight);
            total += weight;
        }
        for (Map.Entry<String, Double> entry : probabilities.entrySet()) {
            entry.setValue(entry.getValue() / total);
        }
        return probabilities;
    }

    /**
     * Coefficient contributed by the categorical field fld3.
     *
     * @param str fld3 category; may be {@code null}
     * @return -3 for "x", 3 for "y", 0 for anything else (including {@code null})
     */
    private static double fld3Coefficient(String str) {
        if ("x".equals(str)) {
            return -3.0d;
        }
        if ("y".equals(str)) {
            return 3.0d;
        }
        return 0.0d;
    }
}
