package org.kie.pmml.models.tree.compiler.factories;

import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.PackageDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.DoubleLiteralExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.TypeExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.kie.pmml.compiler.commons.CommonTestingUtils;
import org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock;
import org.kie.pmml.compiler.commons.testutils.PMMLModelTestUtils;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.compiler.commons.utils.ModelUtils;
import org.kie.pmml.compiler.testutils.TestUtils;
import org.kie.pmml.models.tree.compiler.factories.KiePMMLNodeFactory;
import org.kie.pmml.models.tree.compiler.utils.KiePMMLTreeModelUtils;
import org.kie.pmml.models.tree.model.KiePMMLNode;

/**
 * Unit tests for {@link KiePMMLNodeFactory}.
 *
 * <p>Fixtures are loaded once from {@code TreeSample.pmml} and {@code TreeSimplified.pmml}. The tests cover
 * end-to-end node/source generation as well as the individual helpers that populate the generated
 * {@code evaluateNode} method bodies (score, score distributions, node functions, missing-value penalty, predicate).
 */
public class KiePMMLNodeFactoryTest {

    private static final String SOURCE_1 = "TreeSample.pmml";
    private static final String SOURCE_2 = "TreeSimplified.pmml";
    private static final String PACKAGE_NAME = "packagename";
    /** Base name of the generated node-evaluation methods/method references. */
    private static final String EVALUATE_NODE = "evaluateNode";

    private static PMML pmml1;
    private static Node node1;
    private static DataDictionary dataDictionary1;
    private static List<DerivedField> derivedFields1;
    private static PMML pmml2;
    private static Node nodeRoot;
    private static Node compoundPredicateNode;
    private static Node nodeLeaf;
    private static DataDictionary dataDictionary2;
    private static List<DerivedField> derivedFields2;

    /**
     * Loads both PMML fixtures and extracts the nodes, data dictionaries and derived fields used by the tests.
     */
    @BeforeClass
    public static void setupClass() throws Exception {
        pmml1 = TestUtils.loadFromFile(SOURCE_1);
        TreeModel treeModel1 = (TreeModel) pmml1.getModels().get(0);
        dataDictionary1 = pmml1.getDataDictionary();
        derivedFields1 = ModelUtils.getDerivedFields(pmml1.getTransformationDictionary(), treeModel1.getLocalTransformations());
        node1 = treeModel1.getNode();

        pmml2 = TestUtils.loadFromFile(SOURCE_2);
        TreeModel treeModel2 = (TreeModel) pmml2.getModels().get(0);
        dataDictionary2 = pmml2.getDataDictionary();
        derivedFields2 = ModelUtils.getDerivedFields(pmml2.getTransformationDictionary(), treeModel2.getLocalTransformations());
        nodeRoot = treeModel2.getNode();
        compoundPredicateNode = nodeRoot.getNodes().get(0);
        // Great-grandchild of the root along the first branch of each level
        nodeLeaf = compoundPredicateNode.getNodes().get(0).getNodes().get(0);
    }

    @Test
    public void getKiePMMLNode() {
        KiePMMLNode retrieved = KiePMMLNodeFactory.getKiePMMLNode(node1,
                CommonTestingUtils.getFieldsFromDataDictionaryAndDerivedFields(dataDictionary1, derivedFields1),
                PACKAGE_NAME,
                1.0,
                new HasClassLoaderMock());
        Assert.assertNotNull(retrieved);
        commonVerifyNode(retrieved, node1);
    }

    @Test
    public void getKiePMMLNodeSourcesMap() {
        KiePMMLNodeFactory.NodeNamesDTO nodeNamesDTO =
                new KiePMMLNodeFactory.NodeNamesDTO(node1, KiePMMLTreeModelUtils.createNodeClassName(), null, 1.0);
        Map<String, String> retrieved = KiePMMLNodeFactory.getKiePMMLNodeSourcesMap(nodeNamesDTO,
                CommonTestingUtils.getFieldsFromDataDictionaryAndDerivedFields(dataDictionary1, derivedFields1),
                PACKAGE_NAME);
        Assert.assertNotNull(retrieved);
        commonVerifyNodeSource(retrieved, PACKAGE_NAME);
    }

    @Test
    public void populateJavaParserDTOAndSourcesMap() {
        Map<String, String> sourcesMap = new HashMap<>();
        KiePMMLNodeFactory.NodeNamesDTO nodeNamesDTO =
                new KiePMMLNodeFactory.NodeNamesDTO(nodeRoot, KiePMMLTreeModelUtils.createNodeClassName(), null, 1.0);
        KiePMMLNodeFactory.JavaParserDTO javaParserDTO = new KiePMMLNodeFactory.JavaParserDTO(nodeNamesDTO, PACKAGE_NAME);
        KiePMMLNodeFactory.populateJavaParserDTOAndSourcesMap(javaParserDTO,
                sourcesMap,
                nodeNamesDTO,
                CommonTestingUtils.getFieldsFromDataDictionaryAndDerivedFields(dataDictionary2, derivedFields2),
                true);
        commonVerifyEvaluateNode(javaParserDTO, nodeNamesDTO, true);
    }

    @Test
    public void mergeNodeReferences() {
        KiePMMLNodeFactory.NodeNamesDTO nodeNamesDTO =
                new KiePMMLNodeFactory.NodeNamesDTO(nodeRoot, KiePMMLTreeModelUtils.createNodeClassName(), null, 1.0);
        KiePMMLNodeFactory.JavaParserDTO javaParserDTO = new KiePMMLNodeFactory.JavaParserDTO(nodeNamesDTO, PACKAGE_NAME);
        Node nestedNode = nodeRoot.getNodes().get(0);
        String nestedNodeClassName = nodeNamesDTO.childrenNodes.get(nestedNode);
        String fullNestedNodeClassName = String.format("%s.%s", PACKAGE_NAME, nestedNodeClassName);

        // Build "java.util.Arrays.asList(<fullNestedNodeClassName>::evaluateNode)" as input for the merge
        Expression nestedNodeReference = KiePMMLNodeFactory.getEvaluateNodeMethodReference(fullNestedNodeClassName);
        NodeList<Expression> arguments = NodeList.nodeList(nestedNodeReference);
        MethodCallExpr evaluateNodeReferences = new MethodCallExpr();
        evaluateNodeReferences.setScope(new TypeExpr(StaticJavaParser.parseClassOrInterfaceType(Arrays.class.getName())));
        evaluateNodeReferences.setName("asList");
        evaluateNodeReferences.setArguments(arguments);

        KiePMMLNodeFactory.NodeNamesDTO nestedNodeNamesDTO = new KiePMMLNodeFactory.NodeNamesDTO(nestedNode,
                nodeNamesDTO.getNestedNodeClassName(nestedNode),
                nodeNamesDTO.nodeClassName,
                nodeNamesDTO.missingValuePenalty);
        KiePMMLNodeFactory.mergeNodeReferences(javaParserDTO, nestedNodeNamesDTO, evaluateNodeReferences);

        // After merging, the reference must point to the method copied into the root node class
        MethodReferenceExpr merged = evaluateNodeReferences.getArguments().get(0).asMethodReferenceExpr();
        Assert.assertEquals(javaParserDTO.nodeClassName, merged.getScope().asNameExpr().toString());
        Assert.assertEquals(EVALUATE_NODE + nestedNodeClassName, merged.getIdentifier());
    }

    @Test
    public void populateEvaluateNode() {
        KiePMMLNodeFactory.NodeNamesDTO leafNodeNamesDTO = new KiePMMLNodeFactory.NodeNamesDTO(nodeLeaf,
                KiePMMLTreeModelUtils.createNodeClassName(), "PARENTNODECLASS", 1.0);
        KiePMMLNodeFactory.JavaParserDTO leafJavaParserDTO = new KiePMMLNodeFactory.JavaParserDTO(leafNodeNamesDTO, "packageName");
        KiePMMLNodeFactory.populateEvaluateNode(leafJavaParserDTO,
                leafNodeNamesDTO,
                CommonTestingUtils.getFieldsFromDataDictionaryAndDerivedFields(dataDictionary2, derivedFields2),
                false);
        commonVerifyEvaluateNode(leafJavaParserDTO, leafNodeNamesDTO, false);

        KiePMMLNodeFactory.NodeNamesDTO rootNodeNamesDTO =
                new KiePMMLNodeFactory.NodeNamesDTO(nodeRoot, KiePMMLTreeModelUtils.createNodeClassName(), null, 1.0);
        KiePMMLNodeFactory.JavaParserDTO rootJavaParserDTO = new KiePMMLNodeFactory.JavaParserDTO(rootNodeNamesDTO, "packageName");
        KiePMMLNodeFactory.populateEvaluateNode(rootJavaParserDTO,
                rootNodeNamesDTO,
                CommonTestingUtils.getFieldsFromDataDictionaryAndDerivedFields(dataDictionary2, derivedFields2),
                true);
        commonVerifyEvaluateNode(rootJavaParserDTO, rootNodeNamesDTO, true);
    }

    @Test
    public void populateEvaluateNodeWithNodeFunctions() {
        BlockStmt blockStmt = new BlockStmt();
        VariableDeclarator variableDeclarator = new VariableDeclarator();
        variableDeclarator.setType("Object");
        variableDeclarator.setName("nodeFunctions");
        blockStmt.addStatement(new VariableDeclarationExpr(variableDeclarator));
        Assert.assertFalse(variableDeclarator.getInitializer().isPresent());

        List<String> noNodeFunctions = Collections.emptyList();
        KiePMMLNodeFactory.populateEvaluateNodeWithNodeFunctions(blockStmt, noNodeFunctions);
        commonVerifyEvaluateNodeWithNodeFunctions(variableDeclarator, noNodeFunctions);

        List<String> nodeFunctions = IntStream.range(0, 2)
                .mapToObj(i -> "full.node.NodeClassName" + i)
                .collect(Collectors.toList());
        KiePMMLNodeFactory.populateEvaluateNodeWithNodeFunctions(blockStmt, nodeFunctions);
        commonVerifyEvaluateNodeWithNodeFunctions(variableDeclarator, nodeFunctions);
    }

    @Test
    public void getEvaluateNodeMethodReference() {
        String fullNodeClassName = "full.node.NodeClassName";
        MethodReferenceExpr retrieved = KiePMMLNodeFactory.getEvaluateNodeMethodReference(fullNodeClassName);
        Assert.assertEquals(fullNodeClassName, retrieved.getScope().toString());
        Assert.assertEquals(EVALUATE_NODE, retrieved.getIdentifier());
    }

    @Test
    public void populateEvaluateNodeWithScore() {
        BlockStmt blockStmt = new BlockStmt();
        VariableDeclarator variableDeclarator = new VariableDeclarator();
        variableDeclarator.setType("Object");
        variableDeclarator.setName("score");
        blockStmt.addStatement(new VariableDeclarationExpr(variableDeclarator));
        Assert.assertFalse(variableDeclarator.getInitializer().isPresent());

        KiePMMLNodeFactory.populateEvaluateNodeWithScore(blockStmt, null);
        commonVerifyEvaluateNodeWithScore(variableDeclarator, null);

        String stringScore = "scoreValue";
        KiePMMLNodeFactory.populateEvaluateNodeWithScore(blockStmt, stringScore);
        commonVerifyEvaluateNodeWithScore(variableDeclarator, stringScore);

        Double doubleScore = 54345.34;
        KiePMMLNodeFactory.populateEvaluateNodeWithScore(blockStmt, doubleScore);
        commonVerifyEvaluateNodeWithScore(variableDeclarator, doubleScore);
    }

    @Test
    public void populateEvaluateNodeWithScoreDistributions() {
        BlockStmt blockStmt = new BlockStmt();
        VariableDeclarator variableDeclarator = new VariableDeclarator();
        variableDeclarator.setType("List");
        variableDeclarator.setName("scoreDistributions");
        blockStmt.addStatement(new VariableDeclarationExpr(variableDeclarator));
        Assert.assertFalse(variableDeclarator.getInitializer().isPresent());

        List<ScoreDistribution> withoutProbability = PMMLModelTestUtils.getRandomPMMLScoreDistributions(false);
        KiePMMLNodeFactory.populateEvaluateNodeWithScoreDistributions(blockStmt, withoutProbability);
        commonVerifyEvaluateNodeWithScoreDistributions(variableDeclarator, withoutProbability);

        List<ScoreDistribution> withProbability = PMMLModelTestUtils.getRandomPMMLScoreDistributions(true);
        KiePMMLNodeFactory.populateEvaluateNodeWithScoreDistributions(blockStmt, withProbability);
        commonVerifyEvaluateNodeWithScoreDistributions(variableDeclarator, withProbability);
    }

    @Test
    public void populateEvaluateNodeWithMissingValuePenalty() {
        BlockStmt blockStmt = new BlockStmt();
        VariableDeclarator variableDeclarator = new VariableDeclarator();
        variableDeclarator.setType("double");
        variableDeclarator.setName("missingValuePenalty");
        blockStmt.addStatement(new VariableDeclarationExpr(variableDeclarator));
        Assert.assertFalse(variableDeclarator.getInitializer().isPresent());

        double missingValuePenalty = new Random().nextDouble();
        KiePMMLNodeFactory.populateEvaluateNodeWithMissingValuePenalty(blockStmt, missingValuePenalty);

        Assert.assertTrue(variableDeclarator.getInitializer().isPresent());
        Expression initializer = variableDeclarator.getInitializer().get();
        Assert.assertTrue(initializer instanceof DoubleLiteralExpr);
        Assert.assertEquals(missingValuePenalty, initializer.asDoubleLiteralExpr().asDouble(), 0.0);
    }

    @Test
    public void populateEvaluateNodeWithPredicateFunction() {
        BlockStmt blockStmt = new BlockStmt();
        KiePMMLNodeFactory.populateEvaluateNodeWithPredicate(blockStmt,
                compoundPredicateNode.getPredicate(),
                CommonTestingUtils.getFieldsFromDataDictionaryAndDerivedFields(dataDictionary2, derivedFields2));
        String expected = "{\n"
                + "    KiePMMLSimplePredicate predicate_0 = KiePMMLSimplePredicate.builder(\"temperature\", Collections.emptyList(), org.kie.pmml.api.enums.OPERATOR.GREATER_THAN).withValue(60.0).build();\n"
                + "    KiePMMLSimplePredicate predicate_1 = KiePMMLSimplePredicate.builder(\"temperature\", Collections.emptyList(), org.kie.pmml.api.enums.OPERATOR.LESS_THAN).withValue(100.0).build();\n"
                + "    KiePMMLSimplePredicate predicate_2 = KiePMMLSimplePredicate.builder(\"outlook\", Collections.emptyList(), org.kie.pmml.api.enums.OPERATOR.EQUAL).withValue(\"overcast\").build();\n"
                + "    KiePMMLSimplePredicate predicate_3 = KiePMMLSimplePredicate.builder(\"humidity\", Collections.emptyList(), org.kie.pmml.api.enums.OPERATOR.LESS_THAN).withValue(70.0).build();\n"
                + "    KiePMMLSimplePredicate predicate_4 = KiePMMLSimplePredicate.builder(\"windy\", Collections.emptyList(), org.kie.pmml.api.enums.OPERATOR.EQUAL).withValue(\"false\").build();\n"
                + "    KiePMMLCompoundPredicate predicate = KiePMMLCompoundPredicate.builder(Collections.emptyList(), org.kie.pmml.api.enums.BOOLEAN_OPERATOR.AND).withKiePMMLPredicates(Arrays.asList(predicate_0, predicate_1, predicate_2, predicate_3, predicate_4)).build();\n"
                + "}";
        Assert.assertTrue(JavaParserUtils.equalsNode(JavaParserUtils.parseBlock(expected), blockStmt));
    }

    @Test
    public void nodeNamesDTO() {
        KiePMMLNodeFactory.NodeNamesDTO retrieved =
                new KiePMMLNodeFactory.NodeNamesDTO(nodeRoot, KiePMMLTreeModelUtils.createNodeClassName(), PACKAGE_NAME, 1.0);
        Assert.assertEquals(nodeRoot.getNodes().size(), retrieved.childrenNodes.size());
    }

    /**
     * Verifies the {@code score} and {@code nodeFunctions} declarations of a generated evaluate-node body.
     *
     * @param javaParserDTO the DTO populated by the factory
     * @param nodeNamesDTO  the names DTO of the node under test
     * @param isRoot        {@code true} to inspect {@code evaluateRootNodeBody}; {@code false} to inspect the
     *                      first {@code evaluateNode<nodeClassName>} method of the node template
     */
    private void commonVerifyEvaluateNode(KiePMMLNodeFactory.JavaParserDTO javaParserDTO,
                                          KiePMMLNodeFactory.NodeNamesDTO nodeNamesDTO,
                                          boolean isRoot) {
        BlockStmt evaluateBody;
        if (isRoot) {
            evaluateBody = javaParserDTO.evaluateRootNodeBody;
        } else {
            MethodDeclaration evaluateNodeMethod =
                    javaParserDTO.nodeTemplate.getMethodsByName(EVALUATE_NODE + nodeNamesDTO.nodeClassName).get(0);
            evaluateBody = evaluateNodeMethod.getBody()
                    .orElseThrow(() -> new RuntimeException("No body in nested node evaluate node"));
        }
        VariableDeclarator scoreDeclarator = CommonCodegenUtils.getVariableDeclarator(evaluateBody, "score")
                .orElseThrow(() -> new RuntimeException("No SCORE variable declarator in generated methodCallExpr"));
        commonVerifyEvaluateNodeWithScore(scoreDeclarator, nodeNamesDTO.node.getScore());
        VariableDeclarator nodeFunctionsDeclarator = CommonCodegenUtils.getVariableDeclarator(evaluateBody, "nodeFunctions")
                .orElseThrow(() -> new RuntimeException("No NODE_FUNCTIONS variable declarator in generated methodCallExpr"));
        if (isRoot) {
            commonVerifyEvaluateNodeWithNodeFunctions(nodeFunctionsDeclarator,
                    nodeNamesDTO.getNestedNodesFullClassNames(javaParserDTO.packageName));
        } else {
            commonVerifyEvaluateNodeWithNodeFunctions(nodeFunctionsDeclarator,
                    new ArrayList<>(nodeNamesDTO.childrenNodes.values()));
        }
    }

    /**
     * Verifies that the initializer is {@code java.util.Collections.emptyList()} when {@code nodeFunctions} is
     * empty, and otherwise {@code java.util.Arrays.asList(...)} with one argument per expected node function.
     */
    private void commonVerifyEvaluateNodeWithNodeFunctions(VariableDeclarator variableDeclarator, List<String> nodeFunctions) {
        Assert.assertTrue(variableDeclarator.getInitializer().isPresent());
        Expression initializer = variableDeclarator.getInitializer().get();
        Assert.assertTrue(initializer instanceof MethodCallExpr);
        MethodCallExpr methodCallExpr = initializer.asMethodCallExpr();
        Expression scope = methodCallExpr.getScope()
                .orElseThrow(() -> new RuntimeException("No scope in generated methodCallExpr"));
        if (nodeFunctions.isEmpty()) {
            Assert.assertEquals(Collections.class.getName(), scope.toString());
            Assert.assertEquals("emptyList", methodCallExpr.getName().asString());
            Assert.assertTrue(methodCallExpr.getArguments().isEmpty());
        } else {
            Assert.assertEquals(Arrays.class.getName(), scope.toString());
            Assert.assertEquals("asList", methodCallExpr.getName().asString());
            Assert.assertEquals(nodeFunctions.size(), methodCallExpr.getArguments().size());
        }
    }

    /**
     * Verifies the {@code score} initializer: a {@code null} literal for a {@code null} score, otherwise a
     * {@link NameExpr} whose text is the score (quoted when the score is a {@link String}).
     */
    private void commonVerifyEvaluateNodeWithScore(VariableDeclarator variableDeclarator, Object score) {
        Assert.assertTrue(variableDeclarator.getInitializer().isPresent());
        Expression initializer = variableDeclarator.getInitializer().get();
        if (score == null) {
            Assert.assertTrue(initializer instanceof NullLiteralExpr);
        } else {
            Assert.assertTrue(initializer instanceof NameExpr);
            String expected = String.format(score instanceof String ? "\"%s\"" : "%s", score);
            Assert.assertEquals(expected, initializer.toString());
        }
    }

    /**
     * Verifies that the initializer is {@code Arrays.asList(new ...(...), ...)} with one object creation per
     * score distribution, matched by value (argument 2) and checked on record count, confidence and
     * probability (arguments 3, 4 and 5; probability may be a {@code null} literal).
     */
    private void commonVerifyEvaluateNodeWithScoreDistributions(VariableDeclarator variableDeclarator,
                                                                List<ScoreDistribution> scoreDistributions) {
        Assert.assertTrue(variableDeclarator.getInitializer().isPresent());
        Expression initializer = variableDeclarator.getInitializer().get();
        Assert.assertTrue(initializer instanceof MethodCallExpr);
        MethodCallExpr methodCallExpr = initializer.asMethodCallExpr();
        Expression scope = methodCallExpr.getScope()
                .orElseThrow(() -> new RuntimeException("No scope in generated methodCallExpr"));
        Assert.assertEquals("Arrays", scope.toString());
        Assert.assertEquals("asList", methodCallExpr.getName().toString());

        NodeList<Expression> arguments = methodCallExpr.getArguments();
        Assert.assertEquals(scoreDistributions.size(), arguments.size());
        arguments.forEach(argument -> Assert.assertTrue(argument instanceof ObjectCreationExpr));
        List<ObjectCreationExpr> objectCreationExprs = arguments.stream()
                .map(Expression::asObjectCreationExpr)
                .collect(Collectors.toList());

        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            Optional<ObjectCreationExpr> matching = objectCreationExprs.stream()
                    .filter(creation -> scoreDistribution.getValue()
                            .equals(creation.getArgument(2).asStringLiteralExpr().asString()))
                    .findFirst();
            Assert.assertTrue(matching.isPresent());
            Expression expectedRecordCount =
                    CommonCodegenUtils.getExpressionForObject(scoreDistribution.getRecordCount().intValue());
            Expression expectedConfidence =
                    CommonCodegenUtils.getExpressionForObject(scoreDistribution.getConfidence().doubleValue());
            Expression expectedProbability = scoreDistribution.getProbability() != null
                    ? CommonCodegenUtils.getExpressionForObject(scoreDistribution.getProbability().doubleValue())
                    : new NullLiteralExpr();
            ObjectCreationExpr found = matching.get();
            Assert.assertEquals(expectedRecordCount, found.getArgument(3));
            Assert.assertEquals(expectedConfidence, found.getArgument(4));
            Assert.assertEquals(expectedProbability, found.getArgument(5));
        }
    }

    /** Verifies that the generated node is named after the source PMML node id. */
    private void commonVerifyNode(KiePMMLNode kiePMMLNode, Node node) {
        Assert.assertEquals(node.getId(), kiePMMLNode.getName());
    }

    /** Verifies that exactly one source was generated and that it declares the expected package. */
    private void commonVerifyNodeSource(Map<String, String> sourcesMap, String packageName) {
        Assert.assertEquals(1, sourcesMap.size());
        String source = sourcesMap.values().iterator().next();
        PackageDeclaration packageDeclaration = JavaParserUtils.getFromSource(source).getPackageDeclaration()
                .orElseThrow(() -> new RuntimeException("No package declaration in generated source"));
        Assert.assertEquals(packageName, packageDeclaration.getName().asString());
    }
}
