package org.kie.pmml.models.drools.tree.compiler.factories;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import org.dmg.pmml.PMML;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.kie.pmml.api.enums.DATA_TYPE;
import org.kie.pmml.compiler.commons.utils.ModelUtils;
import org.kie.pmml.compiler.testutils.TestUtils;
import org.kie.pmml.models.drools.ast.factories.KiePMMLDataDictionaryASTFactory;

/* loaded from: input_file:org/kie/pmml/models/drools/tree/compiler/factories/KiePMMLTreeModelNodeASTFactoryTest.class */
/**
 * Tests for {@link KiePMMLTreeModelNodeASTFactory}.
 *
 * <p>Two fixtures are used: {@code TreeSample.pmml} (the "golfing" tree) and
 * {@code irisTree.pmml}. Each must contain exactly one {@link TreeModel}.
 */
public class KiePMMLTreeModelNodeASTFactoryTest {

    private static final String SOURCE_GOLFING = "TreeSample.pmml";
    private static final String SOURCE_IRIS = "irisTree.pmml";

    private PMML golfingPmml;
    private TreeModel golfingModel;
    private PMML irisPmml;
    private TreeModel irisModel;

    /**
     * Loads both PMML fixtures and extracts their single tree model.
     */
    @Before
    public void setUp() throws Exception {
        golfingPmml = loadSingleTreeModelPmml(SOURCE_GOLFING);
        golfingModel = (TreeModel) golfingPmml.getModels().get(0);
        irisPmml = loadSingleTreeModelPmml(SOURCE_IRIS);
        irisModel = (TreeModel) irisPmml.getModels().get(0);
    }

    /**
     * Checks that declaring rules from the golfing root node leaves the field-type map populated.
     */
    @Test
    public void declareRulesFromRootGolfingNode() {
        Node rootNode = golfingModel.getNode();
        Assert.assertEquals("will play", rootNode.getScore());
        HashMap fieldTypeMap = new HashMap();
        DATA_TYPE targetType = ModelUtils.getTargetFieldType(golfingPmml.getDataDictionary(), golfingModel);
        KiePMMLDataDictionaryASTFactory.factory(fieldTypeMap).declareTypes(golfingPmml.getDataDictionary());
        newFactory(fieldTypeMap, targetType).declareRulesFromRootNode(rootNode, "_will");
        // NOTE(review): fieldTypeMap is handed to declareTypes(...) before the call under test.
        // If that already fills it, this assertion does not exercise declareRulesFromRootNode.
        // Confirm, and consider asserting on what declareRulesFromRootNode produces instead.
        Assert.assertFalse(fieldTypeMap.isEmpty());
    }

    /**
     * Checks that declaring rules from the iris root node leaves the field-type map populated.
     */
    @Test
    public void declareRulesFromRootIrisNode() {
        Node rootNode = irisModel.getNode();
        Assert.assertEquals("setosa", rootNode.getScore());
        HashMap fieldTypeMap = new HashMap();
        DATA_TYPE targetType = ModelUtils.getTargetFieldType(irisPmml.getDataDictionary(), irisModel);
        KiePMMLDataDictionaryASTFactory.factory(fieldTypeMap).declareTypes(irisPmml.getDataDictionary());
        newFactory(fieldTypeMap, targetType).declareRulesFromRootNode(rootNode, "_setosa");
        // NOTE(review): same caveat as the golfing variant, since the map is filled by declareTypes(...) first.
        Assert.assertFalse(fieldTypeMap.isEmpty());
    }

    /**
     * Checks that an intermediate rule is declared for the golfing root node's first child.
     */
    @Test
    public void declareIntermediateRuleFromGolfingNode() {
        Node childNode = (Node) golfingModel.getNode().getNodes().get(0);
        Assert.assertEquals("will play", childNode.getScore());
        HashMap fieldTypeMap = new HashMap();
        DATA_TYPE targetType = ModelUtils.getTargetFieldType(golfingPmml.getDataDictionary(), golfingModel);
        KiePMMLDataDictionaryASTFactory.factory(fieldTypeMap).declareTypes(golfingPmml.getDataDictionary());
        ArrayList rules = new ArrayList();
        newFactory(fieldTypeMap, targetType).declareIntermediateRuleFromNode(childNode, "_will play", rules);
        Assert.assertFalse(rules.isEmpty());
    }

    /**
     * Checks that an intermediate rule is declared for the iris root node's second child
     * ("versicolor"). The "_setosa" parent path matches the root node's score.
     */
    @Test
    public void declareIntermediateRuleFromIrisNode() {
        Node childNode = (Node) irisModel.getNode().getNodes().get(1);
        Assert.assertEquals("versicolor", childNode.getScore());
        HashMap fieldTypeMap = new HashMap();
        DATA_TYPE targetType = ModelUtils.getTargetFieldType(irisPmml.getDataDictionary(), irisModel);
        KiePMMLDataDictionaryASTFactory.factory(fieldTypeMap).declareTypes(irisPmml.getDataDictionary());
        ArrayList rules = new ArrayList();
        newFactory(fieldTypeMap, targetType).declareIntermediateRuleFromNode(childNode, "_setosa", rules);
        Assert.assertFalse(rules.isEmpty());
    }

    /**
     * Checks that childless nodes are final leaves and that adding a child changes the result.
     */
    @Test
    public void isFinalLeaf() {
        // One factory suffices: the previous version built four, one of which was used only for
        // a call whose result was discarded.
        KiePMMLTreeModelNodeASTFactory factory = newFactory(new HashMap(), DATA_TYPE.STRING);
        Assert.assertTrue(factory.isFinalLeaf(new LeafNode()));
        ClassifierNode classifierNode = new ClassifierNode();
        Assert.assertTrue(factory.isFinalLeaf(classifierNode));
        classifierNode.addNodes(new LeafNode());
        Assert.assertFalse(factory.isFinalLeaf(classifierNode));
    }

    /**
     * Loads a PMML fixture and asserts that it holds exactly one {@link TreeModel}.
     *
     * @param fileName fixture resource name passed to {@link TestUtils#loadFromFile}
     * @return the loaded document, whose first (and only) model is a {@link TreeModel}
     * @throws Exception propagated from {@link TestUtils#loadFromFile}
     */
    private static PMML loadSingleTreeModelPmml(String fileName) throws Exception {
        PMML pmml = TestUtils.loadFromFile(fileName);
        Assert.assertNotNull("Failed to load " + fileName, pmml);
        Assert.assertEquals("Unexpected model count in " + fileName, 1L, pmml.getModels().size());
        Assert.assertTrue("First model in " + fileName + " is not a TreeModel",
                pmml.getModels().get(0) instanceof TreeModel);
        return pmml;
    }

    /**
     * Builds the factory under test with the settings every test in this class uses.
     * Those settings are an empty list as the second argument and
     * {@code RETURN_NULL_PREDICTION} as the no-true-child strategy.
     *
     * @param fieldTypeMap map previously passed to {@code KiePMMLDataDictionaryASTFactory}
     * @param targetType   data type of the model's target field
     * @return a new factory instance
     */
    private static KiePMMLTreeModelNodeASTFactory newFactory(HashMap fieldTypeMap, DATA_TYPE targetType) {
        return KiePMMLTreeModelNodeASTFactory.factory(fieldTypeMap, Collections.emptyList(),
                TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION, targetType);
    }
}
