package org.kie.pmml.pmml_4_2.predictive.models;

import org.drools.core.impl.InternalKnowledgeBase;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.kie.api.pmml.PMML4Result;
import org.kie.api.pmml.PMMLRequestData;
import org.kie.api.runtime.rule.RuleUnit;
import org.kie.api.runtime.rule.RuleUnitExecutor;
import org.kie.pmml.pmml_4_2.DroolsAbstractPMMLTest;

/* loaded from: input_file:org/kie/pmml/pmml_4_2/predictive/models/SimpleRegressionTest.class */
public class SimpleRegressionTest extends DroolsAbstractPMMLTest {
    private static final String source1 = "org/kie/pmml/pmml_4_2/test_regression.pmml";
    private static final String source2 = "org/kie/pmml/pmml_4_2/test_regression_clax.pmml";

    /** Model name used both in the request and to derive candidate rule-unit packages. */
    private static final String MODEL_NAME = "LinReg";
    /** Name passed to getStartingRuleUnit to locate the entry rule unit. */
    private static final String RULE_UNIT_INDICATOR = "RuleUnitIndicator";
    /** Absolute tolerance for floating-point comparisons of model outputs. */
    private static final double DELTA = 1.0E-6d;

    /**
     * Intentionally empty.
     *
     * <p>NOTE(review): if {@link DroolsAbstractPMMLTest} declares its own {@code @After tearDown()},
     * this public override shadows it under JUnit 4 and the superclass cleanup will not run —
     * confirm this is intended before removing or changing it.
     */
    @After
    public void tearDown() {
    }

    /**
     * Scores {@code test_regression.pmml} and checks the numeric target {@code Fld4}.
     */
    @Test
    public void testRegression() throws Exception {
        PMML4Result result = runModel(source1, buildRequest(0.9d, 0.3d, "x"));

        Assert.assertEquals("OK", result.getResultCode());
        Assert.assertNotNull(result.getResultValue("Fld4", (String) null, new Object[0]));

        Double fld4 = (Double) result.getResultValue("Fld4", "value", Double.class, new Object[0]).orElse(null);
        Assert.assertNotNull("Fld4 value missing", fld4);

        // Logistic of a linear predictor of 2.258. This constant is carried over from the
        // original test; presumably it is the model's linear combination for these inputs.
        // TODO confirm against test_regression.pmml.
        double expected = 1.0d / (1.0d + Math.exp(-2.2579999999999996d));
        Assert.assertEquals(expected, fld4.doubleValue(), DELTA);
    }

    /**
     * Scores {@code test_regression_clax.pmml} and checks the predicted category
     * ({@code RegOut}) together with the probability outputs {@code RegProb} and {@code RegProbA}.
     */
    @Test
    public void testClassification() throws Exception {
        PMML4Result result = runModel(source2, buildRequest(1.0d, 1.0d, "x"));

        Assert.assertNotNull(result.getResultValue("RegOut", (String) null, new Object[0]));
        Assert.assertNotNull(result.getResultValue("RegProb", (String) null, new Object[0]));
        Assert.assertNotNull(result.getResultValue("RegProbA", (String) null, new Object[0]));

        String category = (String) result.getResultValue("RegOut", "value", String.class, new Object[0]).orElse(null);
        Double probability = (Double) result.getResultValue("RegProb", "value", Double.class, new Object[0]).orElse(null);
        Double probabilityA = (Double) result.getResultValue("RegProbA", "value", Double.class, new Object[0]).orElse(null);

        Assert.assertEquals("catC", category);
        // Guard before unboxing so a missing value fails as an assertion, not an NPE.
        Assert.assertNotNull("RegProb value missing", probability);
        Assert.assertNotNull("RegProbA value missing", probabilityA);
        Assert.assertEquals(0.709228d, probability.doubleValue(), DELTA);
        Assert.assertEquals(0.010635d, probabilityA.doubleValue(), DELTA);
    }

    /**
     * Builds a request for {@link #MODEL_NAME} with correlation id {@code "123"} and the three
     * input fields {@code fld1}, {@code fld2} and {@code fld3}.
     */
    private static PMMLRequestData buildRequest(double fld1, double fld2, String fld3) {
        PMMLRequestData request = new PMMLRequestData("123", MODEL_NAME);
        request.addRequestParam("fld1", Double.valueOf(fld1));
        request.addRequestParam("fld2", Double.valueOf(fld2));
        request.addRequestParam("fld3", fld3);
        return request;
    }

    /**
     * Creates an executor for {@code pmmlSource}, feeds it {@code request}, and returns the
     * result holder after execution.
     *
     * <p>The starting rule unit is run once before the request and result holder are inserted
     * and once after. This two-phase sequence is preserved exactly from the original tests.
     */
    private PMML4Result runModel(String pmmlSource, PMMLRequestData request) throws Exception {
        RuleUnitExecutor executor = createExecutor(pmmlSource);
        PMML4Result resultHolder = new PMML4Result();

        Class<? extends RuleUnit> unitClass = getStartingRuleUnit(
                RULE_UNIT_INDICATOR,
                (InternalKnowledgeBase) this.kbase,
                calculatePossiblePackageNames(MODEL_NAME, new String[0]));
        Assert.assertNotNull(unitClass);

        executor.run(unitClass);
        this.data.insert(request);
        this.resultData.insert(resultHolder);
        executor.run(unitClass);
        return resultHolder;
    }
}
