package org.drools.pmml.pmml_4_2.predictive.models;

import java.io.InputStream;
import java.util.Collection;

import org.dmg.pmml.pmml_4_2.descr.MISSINGVALUESTRATEGY;
import org.dmg.pmml.pmml_4_2.descr.PMML;
import org.dmg.pmml.pmml_4_2.descr.TreeModel;
import org.drools.pmml.pmml_4_2.DroolsAbstractPMMLTest;
import org.drools.pmml.pmml_4_2.PMML4Compiler;
import org.drools.pmml.pmml_4_2.PMML4Helper;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.kie.api.definition.type.FactType;
import org.kie.api.runtime.ClassObjectFilter;
import org.kie.api.runtime.KieSession;
import org.kie.internal.io.ResourceFactory;

/**
 * Tests for PMML 4.2 {@code TreeModel} support: scoring of a simple decision tree and the
 * handling of missing input values under several missing-value strategies.
 *
 * <p>Each test builds a session from a tree model on the classpath, feeds values through the
 * {@code in_FldN} entry points, and then inspects the resulting {@code TreeToken} fact and the
 * model's output data fields.
 */
public class DecisionTreeTest extends DroolsAbstractPMMLTest {
    /** Tree model whose inputs are fed through four entry points ({@code in_Fld1..in_Fld4}). */
    private static final String SIMPLE_TREE_SOURCE = "org/drools/pmml/pmml_4_2/test_tree_simple.xml";
    /** Tree model whose inputs are fed through three entry points ({@code in_Fld1..in_Fld3}). */
    private static final String MISSING_TREE_SOURCE = "org/drools/pmml/pmml_4_2/test_tree_missing.xml";
    /** Package in which the model's data-field fact types (e.g. {@code Fld9}) are declared. */
    private static final String PACKAGE_NAME = "org.drools.pmml.pmml_4_2.test";
    /** Default tolerance when comparing computed confidences. */
    private static final double EPSILON = 1.0E-6d;

    @After
    public void tearDown() {
        // A test may fail before any session is created; don't mask that failure with an NPE.
        KieSession session = getKSession();
        if (session != null) {
            session.dispose();
        }
    }

    @Test
    public void testSimpleTree() throws Exception {
        KieSession session = initModelSession(SIMPLE_TREE_SOURCE);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld5");

        insertInputsAndFire(session, 30.0d, 60.0d, "false", "optA");

        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtY");
        checkGeneratedRules();
    }

    /**
     * Returns the single {@code TreeToken} fact in the given session.
     *
     * @param kieSession the session to inspect
     * @return the only {@code TreeToken} instance; the test fails if there is not exactly one
     */
    protected Object getToken(KieSession kieSession) {
        FactType tokenType = getTreeTokenType(kieSession);
        Collection<?> tokens = kieSession.getObjects(new ClassObjectFilter(tokenType.getFactClass()));
        Assert.assertEquals("exactly one TreeToken expected", 1L, tokens.size());
        return tokens.iterator().next();
    }

    @Test
    public void testMissingTree() throws Exception {
        KieSession session = initModelSession(MISSING_TREE_SOURCE);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, 45.0d, 60.0d, "optA");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 0.6d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtZ");
        checkGeneratedRules();
    }

    @Test
    public void testMissingTreeWeighted1() throws Exception {
        KieSession session = initModelSession(MISSING_TREE_SOURCE);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, -1.0d, -1.0d, "optA");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 0.8d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(50.0d), tokenType.get(token, "totalCount"));
        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtX");
        checkGeneratedRules();
    }

    @Test
    public void testMissingTreeWeighted2() throws Exception {
        KieSession session = initModelSession(MISSING_TREE_SOURCE);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, -1.0d, -1.0d, "miss");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 0.6d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(100.0d), tokenType.get(token, "totalCount"));
        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtX");
        checkGeneratedRules();
    }

    @Test
    public void testMissingTreeDefault() throws Exception {
        KieSession session = initSessionWithStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, 70.0d, 40.0d, "miss");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 0.72d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(40.0d), tokenType.get(token, "totalCount"));
        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtX");
        checkGeneratedRules();
    }

    @Test
    public void testMissingTreeAllMissingDefault() throws Exception {
        KieSession session = initSessionWithStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, -1.0d, -1.0d, "miss");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 1.0d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(0.0d), tokenType.get(token, "totalCount"));
        checkGeneratedRules();
    }

    @Test
    public void testMissingTreeLastChoice() throws Exception {
        KieSession session = initSessionWithStrategy(MISSINGVALUESTRATEGY.LAST_PREDICTION);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, -1.0d, -1.0d, "optA");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 0.8d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(50.0d), tokenType.get(token, "totalCount"));
        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtX");
        checkGeneratedRules();
    }

    @Test
    public void testMissingTreeNull() throws Exception {
        KieSession session = initSessionWithStrategy(MISSINGVALUESTRATEGY.NULL_PREDICTION);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, -1.0d, -1.0d, "optA");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 0.0d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(0.0d), tokenType.get(token, "totalCount"));
        // With a null prediction, no output field fact is expected in the session.
        Assert.assertEquals(0L, session.getObjects(new ClassObjectFilter(outputType.getFactClass())).size());
        checkGeneratedRules();
    }

    @Test
    public void testMissingAggregate() throws Exception {
        KieSession session = initSessionWithStrategy(MISSINGVALUESTRATEGY.AGGREGATE_NODES);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, 45.0d, 90.0d, "miss");

        Object token = getToken(session);
        // Aggregated confidence is only checked to two decimal places.
        assertConfidence(tokenType, token, 0.47d, 0.01d);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(60.0d), tokenType.get(token, "totalCount"));
        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtY");
        checkGeneratedRules();
    }

    @Test
    public void testMissingTreeNone() throws Exception {
        KieSession session = initSessionWithStrategy(MISSINGVALUESTRATEGY.NONE);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, -1.0d, -1.0d, "miss");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 0.6d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(100.0d), tokenType.get(token, "totalCount"));
        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtX");
        checkGeneratedRules();
    }

    @Test
    public void testSimpleTreeOutput() throws Exception {
        KieSession session = initModelSession(MISSING_TREE_SOURCE);
        FactType outputType = session.getKieBase().getFactType(PACKAGE_NAME, "Fld9");
        FactType tokenType = getTreeTokenType(session);

        insertInputsAndFire(session, -1.0d, -1.0d, "optA");

        Object token = getToken(session);
        assertConfidence(tokenType, token, 0.8d, EPSILON);
        Assert.assertEquals("null", tokenType.get(token, "current"));
        Assert.assertEquals(Double.valueOf(50.0d), tokenType.get(token, "totalCount"));
        checkFirstDataFieldOfTypeStatus(outputType, true, false, "Missing", "tgtX");
        checkFirstDataFieldOfTypeStatus(session.getKieBase().getFactType(PACKAGE_NAME, "OutClass"), true, false, "Missing", "tgtX");
        checkFirstDataFieldOfTypeStatus(session.getKieBase().getFactType(PACKAGE_NAME, "OutProb"), true, false, "Missing", Double.valueOf(0.8d));
        checkGeneratedRules();
    }

    /**
     * Creates a session for the given classpath model, registers it (and its KIE base) with the
     * base test, and fires the initial rules before any input is inserted.
     *
     * @param source classpath location of the PMML model
     * @return the freshly initialised session
     */
    private KieSession initModelSession(String source) throws Exception {
        setKSession(getModelSession(source, false));
        setKbase(getKSession().getKieBase());
        KieSession session = getKSession();
        session.fireAllRules();
        return session;
    }

    /**
     * Loads the missing-values tree model, overrides the missing-value strategy of every
     * {@link TreeModel} in it, compiles it into a session, registers that session with the base
     * test, and fires the initial rules.
     *
     * @param strategy the missing-value strategy to apply to each tree model
     * @return the freshly initialised session
     */
    private KieSession initSessionWithStrategy(MISSINGVALUESTRATEGY strategy) throws Exception {
        PMML4Compiler compiler = new PMML4Compiler();
        PMML pmml;
        // Close the classpath stream once the model has been parsed.
        try (InputStream in = ResourceFactory.newClassPathResource(MISSING_TREE_SOURCE).getInputStream()) {
            pmml = compiler.loadModel(DroolsAbstractPMMLTest.PMML, in);
        }
        for (Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
            if (model instanceof TreeModel) {
                ((TreeModel) model).setMissingValueStrategy(strategy);
            }
        }
        KieSession session = getSession(compiler.generateTheory(pmml));
        setKSession(session);
        setKbase(getKSession().getKieBase());
        session.fireAllRules();
        return session;
    }

    /**
     * Inserts {@code values[i]} into entry point {@code "in_Fld" + (i + 1)}, in order, then fires
     * all rules.
     */
    private static void insertInputsAndFire(KieSession session, Object... values) {
        for (int i = 0; i < values.length; i++) {
            session.getEntryPoint("in_Fld" + (i + 1)).insert(values[i]);
        }
        session.fireAllRules();
    }

    /** Looks up the {@code TreeToken} fact type in the default PMML package; fails if absent. */
    private static FactType getTreeTokenType(KieSession session) {
        FactType tokenType = session.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
        Assert.assertNotNull(tokenType);
        return tokenType;
    }

    /** Asserts that the token's {@code confidence} field is set and within {@code delta} of {@code expected}. */
    private static void assertConfidence(FactType tokenType, Object token, double expected, double delta) {
        Object confidence = tokenType.get(token, "confidence");
        Assert.assertNotNull("TreeToken confidence should be set", confidence);
        Assert.assertEquals(expected, ((Double) confidence).doubleValue(), delta);
    }
}
