package org.kie.pmml.models.drools.tree.evaluator;

import java.util.HashMap;
import java.util.Map;
import org.dmg.pmml.PMML;
import org.dmg.pmml.tree.TreeModel;
import org.drools.compiler.builder.impl.KnowledgeBuilderImpl;
import org.drools.compiler.kproject.ReleaseIdImpl;
import org.drools.modelcompiler.ExecutableModelProject;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.kie.api.KieBase;
import org.kie.api.builder.ReleaseId;
import org.kie.api.conf.KieBaseOption;
import org.kie.api.definition.KieDescr;
import org.kie.api.pmml.PMML4Result;
import org.kie.api.pmml.PMMLRequestData;
import org.kie.internal.utils.KieHelper;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.enums.ResultCode;
import org.kie.pmml.api.runtime.PMMLContext;
import org.kie.pmml.compiler.commons.CommonTestingUtils;
import org.kie.pmml.compiler.testutils.TestUtils;
import org.kie.pmml.evaluator.core.PMMLContextImpl;
import org.kie.pmml.evaluator.core.utils.PMMLRequestDataBuilder;
import org.kie.pmml.models.drools.tree.compiler.executor.TreeModelImplementationProvider;
import org.kie.pmml.models.drools.tree.evaluator.implementations.HasKnowledgeBuilderMock;
import org.kie.pmml.models.drools.tree.model.KiePMMLTreeModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/pmml/models/drools/tree/evaluator/PMMLTreeModelEvaluatorTest.class */
/**
 * Tests for {@link PMMLTreeModelEvaluator} driven by the {@code TreeSample.pmml} model.
 *
 * <p>The model is loaded and compiled once for the whole class in {@link #setUp()}; each test
 * then submits one or more sets of input parameters and checks the value reported for the
 * {@code whatIdo} target field, or checks that evaluation fails when no value is expected.
 */
public class PMMLTreeModelEvaluatorTest {
    private static final Logger logger = LoggerFactory.getLogger(PMMLTreeModelEvaluatorTest.class);

    private static final String SOURCE_1 = "TreeSample.pmml";
    private static final String PACKAGE_NAME = "PACKAGE_NAME";
    private static final String MODEL_NAME = "golfing";
    private static final String CORRELATION_ID = "CORRELATION_ID";

    // Expected values of the target field.
    private static final String SCORE = "SCORE";
    private static final String WILL_PLAY = "will play";
    private static final String NO_PLAY = "no play";
    private static final String MAY_PLAY = "may play";
    private static final String WHO_PLAY = "who play";

    // Input parameter names.
    private static final String HUMIDITY = "humidity";
    private static final String TEMPERATURE = "temperature";
    private static final String OUTLOOK = "outlook";
    private static final String WINDY = "windy";

    // Values used for the "outlook" parameter.
    private static final String SUNNY = "sunny";
    private static final String OVERCAST = "overcast";
    private static final String RAIN = "rain";

    // Name under which the evaluator is expected to publish its result.
    private static final String TARGET_FIELD = "whatIdo";

    private static final ReleaseId RELEASE_ID = new ReleaseIdImpl("org", "test", "1.0.0");
    private static final TreeModelImplementationProvider provider = new TreeModelImplementationProvider();

    private static KiePMMLTreeModel kiePMMLModel;
    private static PMMLTreeModelEvaluator evaluator;
    private static KieBase kieBase;

    /**
     * Loads {@code TreeSample.pmml}, verifies it contains exactly one {@link TreeModel}, compiles
     * it into a {@link KiePMMLTreeModel} and builds the {@link KieBase} shared by every test.
     *
     * @throws Exception if the PMML file cannot be loaded or compiled
     */
    @BeforeClass
    public static void setUp() throws Exception {
        evaluator = new PMMLTreeModelEvaluator();
        final PMML pmml = TestUtils.loadFromFile(SOURCE_1);
        Assert.assertNotNull(pmml);
        Assert.assertEquals(1, pmml.getModels().size());
        Assert.assertTrue(pmml.getModels().get(0) instanceof TreeModel);
        final KnowledgeBuilderImpl knowledgeBuilder = new KnowledgeBuilderImpl();
        kiePMMLModel = provider.getKiePMMLModel(PACKAGE_NAME,
                                                CommonTestingUtils.getFieldsFromDataDictionary(pmml.getDataDictionary()),
                                                pmml.getTransformationDictionary(),
                                                (TreeModel) pmml.getModels().get(0),
                                                new HasKnowledgeBuilderMock(knowledgeBuilder));
        // The package descriptor generated during compilation is what the KieBase is built from.
        kieBase = new KieHelper()
                .addContent((KieDescr) knowledgeBuilder.getPackageDescrs(kiePMMLModel.getKModulePackageName()).get(0))
                .setReleaseId(RELEASE_ID)
                .build(ExecutableModelProject.class, new KieBaseOption[0]);
        Assert.assertNotNull(kieBase);
    }

    /** The evaluator must declare itself as handling {@link PMML_MODEL#TREE_MODEL}. */
    @Test
    public void getPMMLModelType() {
        Assert.assertEquals(PMML_MODEL.TREE_MODEL, evaluator.getPMMLModelType());
    }

    /** Input combinations for which evaluation is expected to fail without a target value. */
    @Test
    public void evaluateNull() throws Exception {
        final Map<String, Object> inputData = new HashMap<>();
        inputData.put(OUTLOOK, SUNNY);
        commonEvaluate(MODEL_NAME, inputData, null);
        inputData.clear();
        inputData.put(OUTLOOK, SUNNY);
        inputData.put(TEMPERATURE, 65.0);
        commonEvaluate(MODEL_NAME, inputData, null);
        inputData.clear();
        inputData.put(OUTLOOK, OVERCAST);
        commonEvaluate(MODEL_NAME, inputData, null);
        inputData.clear();
        inputData.put(OUTLOOK, RAIN);
        commonEvaluate(MODEL_NAME, inputData, null);
        inputData.clear();
        inputData.put(OUTLOOK, OVERCAST);
        inputData.put(TEMPERATURE, 80.0);
        commonEvaluate(MODEL_NAME, inputData, null);
    }

    /** Input for which the expected target value is {@code "will play"}. */
    @Test
    public void evaluateWillPlay() throws Exception {
        final Map<String, Object> inputData = new HashMap<>();
        inputData.put(OUTLOOK, SUNNY);
        inputData.put(TEMPERATURE, 65.0);
        inputData.put(HUMIDITY, 65.0);
        commonEvaluate(MODEL_NAME, inputData, WILL_PLAY);
    }

    /** Input combinations for which the expected target value is {@code "no play"}. */
    @Test
    public void evaluateNoPlay() throws Exception {
        final Map<String, Object> inputData = new HashMap<>();
        inputData.put(OUTLOOK, SUNNY);
        inputData.put(TEMPERATURE, 65.0);
        inputData.put(HUMIDITY, 95.0);
        commonEvaluate(MODEL_NAME, inputData, NO_PLAY);
        inputData.clear();
        inputData.put(OUTLOOK, SUNNY);
        inputData.put(HUMIDITY, 95.0);
        inputData.put(TEMPERATURE, 95.0);
        commonEvaluate(MODEL_NAME, inputData, NO_PLAY);
        inputData.clear();
        inputData.put(OUTLOOK, SUNNY);
        inputData.put(TEMPERATURE, 95.0);
        commonEvaluate(MODEL_NAME, inputData, NO_PLAY);
        inputData.clear();
        inputData.put(OUTLOOK, SUNNY);
        inputData.put(TEMPERATURE, 45.0);
        commonEvaluate(MODEL_NAME, inputData, NO_PLAY);
        inputData.clear();
        inputData.put(OUTLOOK, RAIN);
        inputData.put(HUMIDITY, 45.0);
        commonEvaluate(MODEL_NAME, inputData, NO_PLAY);
    }

    /** Input for which the expected target value is {@code "may play"}. */
    @Test
    public void evaluateMayPlay() throws Exception {
        final Map<String, Object> inputData = new HashMap<>();
        inputData.put(OUTLOOK, OVERCAST);
        inputData.put(TEMPERATURE, 70.0);
        inputData.put(HUMIDITY, 60.0);
        inputData.put(WINDY, "false");
        commonEvaluate(MODEL_NAME, inputData, MAY_PLAY);
    }

    /** Input combinations (without an outlook) for which the expected target value is {@code "who play"}. */
    @Test
    public void evaluateWhoPlay() throws Exception {
        final Map<String, Object> inputData = new HashMap<>();
        inputData.put(TEMPERATURE, 75.0);
        inputData.put(WINDY, "true");
        inputData.put(HUMIDITY, 75.0);
        commonEvaluate(MODEL_NAME, inputData, WHO_PLAY);
        inputData.clear();
        inputData.put(WINDY, "false");
        inputData.put(TEMPERATURE, 65.0);
        inputData.put(HUMIDITY, 75.0);
        commonEvaluate(MODEL_NAME, inputData, WHO_PLAY);
    }

    /**
     * Wraps the given parameters in a fresh {@link PMMLContext} and evaluates them.
     *
     * @param modelName      name of the model to request
     * @param inputData      parameter name to value; values must be non-null
     * @param expectedResult expected target value, or {@code null} when evaluation is expected to fail
     */
    private void commonEvaluate(String modelName, Map<String, Object> inputData, String expectedResult) {
        commonEvaluate(new PMMLContextImpl(getPMMLRequestData(modelName, inputData)), expectedResult);
    }

    /**
     * Evaluates the shared model against {@code pmmlContext} and verifies the outcome.
     *
     * @param pmmlContext    context carrying the request data
     * @param expectedResult expected target value, or {@code null} when evaluation is expected to
     *                       report {@link ResultCode#FAIL} without a target variable
     */
    private void commonEvaluate(PMMLContext pmmlContext, String expectedResult) {
        final PMML4Result result = evaluator.evaluate(kieBase, kiePMMLModel, pmmlContext);
        Assert.assertNotNull(result);
        // Parameterized form avoids building the string when trace logging is disabled.
        logger.trace("{}", result);
        Assert.assertEquals(TARGET_FIELD, result.getResultObjectName());
        final Map<String, Object> resultVariables = result.getResultVariables();
        Assert.assertNotNull(resultVariables);
        if (expectedResult == null) {
            Assert.assertEquals(ResultCode.FAIL.getName(), result.getResultCode());
            Assert.assertFalse(resultVariables.containsKey(TARGET_FIELD));
        } else {
            Assert.assertEquals(ResultCode.OK.getName(), result.getResultCode());
            Assert.assertFalse(resultVariables.isEmpty());
            Assert.assertTrue(resultVariables.containsKey(TARGET_FIELD));
            Assert.assertEquals(expectedResult, resultVariables.get(TARGET_FIELD));
        }
    }

    /**
     * Builds the request data for {@code modelName}, adding one parameter per map entry and
     * using each value's runtime class as the parameter type.
     *
     * @param modelName  name of the model to request
     * @param parameters parameter name to value; values must be non-null because their class is read
     * @return the built request data
     */
    private PMMLRequestData getPMMLRequestData(String modelName, Map<String, Object> parameters) {
        final PMMLRequestDataBuilder builder = new PMMLRequestDataBuilder(CORRELATION_ID, modelName);
        for (Map.Entry<String, Object> entry : parameters.entrySet()) {
            final Object value = entry.getValue();
            builder.addParameter(entry.getKey(), value, value.getClass());
        }
        return builder.build();
    }
}
