package org.drools.pmml.pmml_4_2.predictive.models;

import org.drools.pmml.pmml_4_2.DroolsAbstractPMMLTest;
import org.drools.pmml.pmml_4_2.PMML4Helper;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.kie.api.definition.type.FactType;
import org.kie.api.runtime.ClassObjectFilter;
import org.kie.api.runtime.rule.QueryResultsRow;
import org.kie.api.runtime.rule.Variable;

/**
 * Tests for PMML 4.2 neural network models compiled into Drools rules.
 *
 * <p>Each test builds a KIE session from a PMML resource, inserts input values through the
 * model's {@code in_*} entry points, fires the rules and checks the resulting output facts.
 */
public class NeuralNetworkTest extends DroolsAbstractPMMLTest {
    private static final boolean VERBOSE = true;
    private static final String source1 = "org/drools/pmml/pmml_4_2/test_ann_regression.xml";
    private static final String source3 = "org/drools/pmml/pmml_4_2/test_miningSchema.xml";
    private static final String source2 = "org/drools/pmml/pmml_4_2/test_ann_iris.xml";
    private static final String source22 = "org/drools/pmml/pmml_4_2/test_ann_iris_v2.xml";
    private static final String source23 = "org/drools/pmml/pmml_4_2/test_ann_iris_prediction.xml";
    private static final String source4 = "org/drools/pmml/pmml_4_2/test_ann_mixed_inputs2.xml";
    private static final String source6 = "org/drools/pmml/pmml_4_2/mock_ptsd.xml";
    private static final String source7 = "org/drools/pmml/pmml_4_2/mock_cold.xml";
    private static final String source8 = "org/drools/pmml/pmml_4_2/mock_breastcancer.xml";
    private static final String source9 = "org/drools/pmml/pmml_4_2/test_nn_clax_output.xml";
    private static final String packageName = "org.drools.pmml.pmml_4_2.test";
    private static final String smartVent = "org/drools/pmml/pmml_4_2/smartvent.xml";

    /** Disposes the session created by the test, if one was created. */
    @After
    public void tearDown() {
        // Guard so that a test failing before setKSession(...) reports its own error
        // instead of an NPE from this method.
        if (getKSession() != null) {
            getKSession().dispose();
        }
    }

    @Test
    public void testANN() throws Exception {
        setKSession(getModelSession(source1, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        Assert.assertEquals(33L, getNumAssertedSynapses());
        insertInto("in_Gender", "male");
        insertInto("in_NoOfClaims", "3");
        insertInto("in_Scrambled", 7);
        insertInto("in_Domicile", "urban");
        insertInto("in_AgeOfCar", Double.valueOf(8.0d));
        getKSession().fireAllRules();
        // NOTE(review): fireAllRules() returns after evaluation; the reason for this pause is
        // not visible here - confirm whether it is still needed.
        Thread.sleep(200L);
        Assert.assertEquals(828.0d, Math.floor(queryDoubleField("OutAmOfClaims", "NeuralInsurance")), 0.0d);
    }

    /** Only checks that the model compiles into a usable session. */
    @Test
    public void testANNCompilation() throws Exception {
        setKSession(getModelSession(source3, true));
        setKbase(getKSession().getKieBase());
    }

    @Test
    public void testCold() throws Exception {
        setKSession(getModelSession(source7, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        insertInto("in_Temp", Double.valueOf(28.0d));
        getKSession().fireAllRules();
        Assert.assertEquals(0.44d, queryDoubleField("Cold", "MockCold"), 1.0E-6d);
    }

    /**
     * Resolves an item id via two chained {@code getItemId} queries: first for {@code name},
     * then for {@code name + "_" + subName} using the first result's {@code $id} as the
     * second query argument.
     */
    private String getQId(String name, String subName) {
        QueryResultsRow parentRow = getKSession()
                .getQueryResults("getItemId", new Object[]{name, Variable.v, Variable.v})
                .iterator().next();
        String parentId = (String) parentRow.get("$id");
        QueryResultsRow childRow = getKSession()
                .getQueryResults("getItemId", new Object[]{name + "_" + subName, parentId, Variable.v})
                .iterator().next();
        return (String) childRow.get("$id");
    }

    @Test
    public void testIris() throws Exception {
        setKSession(getModelSession(source2, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        Assert.assertEquals(21L, getNumAssertedSynapses());
        insertInto("in_PetalLen", Double.valueOf(2.2d));
        insertInto("in_PetalWid", Double.valueOf(4.1d));
        insertInto("in_SepalLen", Double.valueOf(2.3d));
        insertInto("in_SepalWid", Double.valueOf(1.8d));
        getKSession().fireAllRules();
        FactType mlp7 = getKbase().getFactType(packageName, "Test_MLP_7");
        FactType mlp8 = getKbase().getFactType(packageName, "Test_MLP_8");
        FactType mlp9 = getKbase().getFactType(packageName, "Test_MLP_9");
        Assert.assertEquals(0.001d, truncN(getDoubleFieldValue(mlp7).doubleValue(), 3), 1.0E-4d);
        Assert.assertEquals(0.282d, truncN(getDoubleFieldValue(mlp8).doubleValue(), 3), 1.0E-4d);
        Assert.assertEquals(0.716d, truncN(getDoubleFieldValue(mlp9).doubleValue(), 3), 1.0E-4d);
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "SpecSetosa"), true, false, "Test_MLP", Double.valueOf(0.001111d));
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "SpecVirgin"), true, false, "Test_MLP", Double.valueOf(0.716639d));
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "SpecVersic"), true, false, "Test_MLP", Double.valueOf(0.282249d));
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "SpecOut"), true, false, "Test_MLP", "virginica");
    }

    @Test
    public void testIris2() throws Exception {
        setKSession(getModelSession(source22, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        Assert.assertEquals(12L, getNumAssertedSynapses());
        insertInto("in_PetalLen", 101);
        // Integer input 1 (the decompiled form wrongly read Integer.valueOf(VERBOSE),
        // which does not compile).
        insertInto("in_PetalWid", 1);
        insertInto("in_SepalLen", 151);
        insertInto("in_SepalWid", 30);
        getKSession().fireAllRules();
        FactType mlp0 = getKbase().getFactType(packageName, "Test_MLP_0");
        FactType mlp1 = getKbase().getFactType(packageName, "Test_MLP_1");
        FactType mlp2 = getKbase().getFactType(packageName, "Test_MLP_2");
        Assert.assertEquals(1.542d, truncN(getDoubleFieldValue(mlp0).doubleValue(), 3), 0.0d);
        Assert.assertEquals(0.0d, truncN(getDoubleFieldValue(mlp1).doubleValue(), 3), 0.0d);
        Assert.assertEquals(3.0d, truncN(getDoubleFieldValue(mlp2).doubleValue(), 3), 0.0d);
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "OutSpecies"), true, false, "Test_MLP", "versicolor");
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "OutProb"), true, false, "Test_MLP", Double.valueOf(0.999999d));
    }

    @Test
    public void testIris3() throws Exception {
        setKSession(getModelSession(source23, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        Assert.assertEquals(6L, getNumAssertedSynapses());
        insertInto("in_PetalNum", 101);
        insertInto("in_PetalWid", 2);
        insertInto("in_Species", "virginica");
        insertInto("in_SepalWid", 30);
        getKSession().fireAllRules();
        Assert.assertEquals(24.0d, queryIntegerField("OutSepLen", "Neuiris"), 0.0d);
    }

    @Test
    public void testSimpleANN() throws Exception {
        setKSession(getModelSession(source3, true));
        setKbase(getKSession().getKieBase());
        insertInto("in_Feat2", 4);
        insertInto("in_Feat1", Double.valueOf(3.5d));
        getKSession().fireAllRules();
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "MockOutput2"), true, false, "Test_MLP", Double.valueOf(1.0d));
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "MockOutput1"), true, false, "Test_MLP", Double.valueOf(0.0d));
    }

    @Test
    public void testHeart() throws Exception {
        setKSession(getModelSession(source4, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        Assert.assertEquals(81L, getNumAssertedSynapses());
        insertInto("in_Feat1", Double.valueOf(83.0d));
        insertInto("in_Feat2", Double.valueOf(1.0d));
        insertInto("in_Feat3", Double.valueOf(5.0d));
        insertInto("in_Feat4", "asympt");
        insertInto("in_Feat5", "yes");
        insertInto("in_Feat6", "t");
        insertInto("in_Feat7", Double.valueOf(1.0d));
        insertInto("in_Feat8", "normal");
        insertInto("in_Feat9", "male");
        insertInto("in_Feat10", "flat");
        insertInto("in_Feat11", "normal");
        insertInto("in_Feat12", Double.valueOf(3.3d));
        insertInto("in_Feat13", Double.valueOf(2.5d));
        getKSession().fireAllRules();
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "OutN"), true, false, "HEART_MLP", ">50_1");
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "OutP"), true, false, "HEART_MLP", Double.valueOf(0.943336d));
    }

    /** Re-inserting inputs must not accumulate extra output facts. */
    @Test
    public void testOverride() throws Exception {
        setKSession(getModelSession(source3, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        insertInto("in_Feat1", Double.valueOf(2.2d));
        getKSession().fireAllRules();
        insertInto("in_Feat2", 5);
        getKSession().fireAllRules();
        FactType out1 = getKbase().getFactType(packageName, "Out1");
        FactType out2 = getKbase().getFactType(packageName, "Out2");
        FactType feat2 = getKbase().getFactType(packageName, "Feat2");
        Assert.assertEquals(1L, countFacts(out1));
        Assert.assertEquals(1L, countFacts(out2));
        Assert.assertEquals(2L, countFacts(feat2));
        insertInto("in_Feat1", Double.valueOf(2.5d));
        insertInto("in_Feat2", 6);
        getKSession().fireAllRules();
        Assert.assertEquals(1L, countFacts(out1));
        Assert.assertEquals(1L, countFacts(out2));
        Assert.assertEquals(2L, countFacts(feat2));
    }

    @Test
    public void testSmartVent() throws Exception {
        setKSession(getModelSession(smartVent, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        insertInto("in_PIP", Double.valueOf(28.0d));
        insertInto("in_PEEP", Double.valueOf(5.0d));
        insertInto("in_RATE", Double.valueOf(30.0d));
        insertInto("in_IT", Double.valueOf(0.4d));
        insertInto("in_Ph", Double.valueOf(7.281d));
        insertInto("in_CO2", Double.valueOf(39.3d));
        insertInto("in_PaO2", Double.valueOf(126.0d));
        insertInto("in_FIO2", Double.valueOf(100.0d));
        getKSession().fireAllRules();
        Assert.assertEquals(24.0d, queryDoubleField("Out_sPIP", "SmartVent"), 0.5d);
        Assert.assertEquals(5.0d, queryDoubleField("Out_sPEEP", "SmartVent"), 0.1d);
        Assert.assertEquals(30.0d, queryDoubleField("Out_sRATE", "SmartVent"), 0.5d);
        Assert.assertEquals(0.4d, queryDoubleField("Out_sIT", "SmartVent"), 0.05d);
        Assert.assertEquals(-1.0d, queryDoubleField("Out_sFIO2", "SmartVent"), 0.05d);
        insertInto("in_RATE", Double.valueOf(20.0d));
        insertInto("in_PaO2", Double.valueOf(75.0d));
        insertInto("in_Ph", Double.valueOf(7.31d));
        insertInto("in_CO2", Double.valueOf(37.0d));
        insertInto("in_IT", Double.valueOf(0.4d));
        insertInto("in_PIP", Double.valueOf(20.0d));
        insertInto("in_PEEP", Double.valueOf(4.0d));
        insertInto("in_FIO2", Double.valueOf(38.0d));
        getKSession().fireAllRules();
        Assert.assertEquals(18.0d, queryDoubleField("Out_sPIP", "SmartVent"), 0.5d);
        Assert.assertEquals(4.12d, queryDoubleField("Out_sPEEP", "SmartVent"), 0.1d);
        Assert.assertEquals(19.0d, queryDoubleField("Out_sRATE", "SmartVent"), 0.5d);
        Assert.assertEquals(0.4d, queryDoubleField("Out_sIT", "SmartVent"), 0.05d);
        Assert.assertEquals(-1.0d, queryDoubleField("Out_sFIO2", "SmartVent"), 0.05d);
    }

    @Test
    public void testClaxOutput() throws Exception {
        setKSession(getModelSession(source9, true));
        setKbase(getKSession().getKieBase());
        getKSession().fireAllRules();
        insertInto("in_Temp", Double.valueOf(28.0d));
        getKSession().fireAllRules();
        // NOTE(review): purpose of this pause is not visible here - confirm whether it is needed.
        Thread.sleep(200L);
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "ColdCat"), true, false, "MockCold", "SURE");
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "ColdYES"), true, false, "MockCold", Double.valueOf(0.6475435612444598d));
        checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName, "ColdNO"), true, false, "MockCold", Double.valueOf(0.0036540476859388943d));
    }

    /** Inserts {@code value} through the named entry point of the current session. */
    private void insertInto(String entryPoint, Object value) {
        getKSession().getEntryPoint(entryPoint).insert(value);
    }

    /** Counts the objects in the current session whose class is {@code type}'s fact class. */
    private int countFacts(FactType type) {
        return getKSession().getObjects(new ClassObjectFilter(type.getFactClass())).size();
    }

    /** Counts the {@code Synapse} facts (from the default PMML package) in the current session. */
    private int getNumAssertedSynapses() {
        return countFacts(getKSession().getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "Synapse"));
    }

    /**
     * Keeps {@code i} decimal digits of {@code d}, rounding down (floor).
     *
     * @param d the value to truncate
     * @param i number of decimal digits to keep
     * @return {@code floor(d * 10^i) * 10^-i}
     */
    private double truncN(double d, int i) {
        return Math.floor(d * Math.pow(10.0d, i)) * Math.pow(10.0d, -i);
    }
}
