/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.drools.tree.evaluator;

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.drools.core.impl.InternalKnowledgeBase;
import org.drools.core.impl.KnowledgeBaseFactory;
import org.drools.core.reteoo.builder.NodeFactory;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kie.api.KieBase;
import org.kie.api.KieBaseConfiguration;
import org.kie.api.definition.type.FactType;
import org.kie.api.io.ResourceType;
import org.kie.api.pmml.PMML4Result;
import org.kie.api.runtime.KieSession;
import org.kie.internal.builder.KnowledgeBuilder;
import org.kie.internal.builder.KnowledgeBuilderConfiguration;
import org.kie.internal.builder.KnowledgeBuilderFactory;
import org.kie.internal.io.ResourceFactory;
import org.kie.pmml.api.enums.ResultCode;
import org.kie.pmml.evaluator.api.exceptions.KiePMMLModelException;
import org.kie.pmml.models.drools.executor.KiePMMLStatusHolder;
import org.kie.test.util.filesystem.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@RunWith(value=Parameterized.class)
public class DrlIrisTreeTest {
    private static final String SOURCE_1 = "IrisTreeGen.drl";
    private static final Logger logger = LoggerFactory.getLogger(DrlIrisTreeTest.class);
    private static final String PACKAGE = "iristreemodel";
    private static final String TARGET_FIELD = "Species";
    private static KieBase kbase;
    private double sepalLength;
    private double sepalWidth;
    private double petalLength;
    private double petalWidth;
    private String expectedResult;

    public DrlIrisTreeTest(double sepalLength, double sepalWidth, double petalLength, double petalWidth, String expectedResult) {
        this.sepalLength = sepalLength;
        this.sepalWidth = sepalWidth;
        this.petalLength = petalLength;
        this.petalWidth = petalWidth;
        this.expectedResult = expectedResult;
    }

    @Parameterized.Parameters
    public static Collection<Object[]> data() {
        return Arrays.asList({6.9, 3.1, 5.1, 2.3, "virginica"}, {5.8, 2.6, 4.0, 1.2, "versicolor"}, {5.7, 3.0, 4.2, 1.2, "versicolor"}, {5.0, 3.3, 1.4, 0.2, "setosa"}, {5.4, 3.9, 1.3, 0.4, "setosa"});
    }

    @BeforeClass
    public static void setUp() throws Exception {
        File drlFile = FileUtils.getFile((String)SOURCE_1);
        String content = new String(Files.readAllBytes(drlFile.toPath()));
        kbase = DrlIrisTreeTest.loadKnowledgeBaseFromString(null, null, null, content);
    }

    private static KieBase loadKnowledgeBaseFromString(KnowledgeBuilderConfiguration config, KieBaseConfiguration kBaseConfig, NodeFactory nodeFactory, String ... drlContentStrings) {
        InternalKnowledgeBase kbase;
        KnowledgeBuilder kbuilder = config == null ? KnowledgeBuilderFactory.newKnowledgeBuilder() : KnowledgeBuilderFactory.newKnowledgeBuilder((KnowledgeBuilderConfiguration)config);
        for (String drlContentString : drlContentStrings) {
            kbuilder.add(ResourceFactory.newByteArrayResource((byte[])drlContentString.getBytes()), ResourceType.DRL);
        }
        if (kbuilder.hasErrors()) {
            Assert.fail((String)kbuilder.getErrors().toString());
        }
        if (kBaseConfig == null) {
            kBaseConfig = KnowledgeBaseFactory.newKnowledgeBaseConfiguration();
        }
        InternalKnowledgeBase internalKnowledgeBase = kbase = kBaseConfig == null ? KnowledgeBaseFactory.newKnowledgeBase() : KnowledgeBaseFactory.newKnowledgeBase((KieBaseConfiguration)kBaseConfig);
        if (nodeFactory != null) {
            kbase.getConfiguration().getComponentFactory().setNodeFactoryProvider(nodeFactory);
        }
        kbase.addPackages(kbuilder.getKnowledgePackages());
        return kbase;
    }

    @Test
    public void testIrisTree() {
        HashMap<String, Object> inputData = new HashMap<String, Object>();
        inputData.put("SEPAL_LENGTH", this.sepalLength);
        inputData.put("SEPAL_WIDTH", this.sepalWidth);
        inputData.put("PETAL_LENGTH", this.petalLength);
        inputData.put("PETAL_WIDTH", this.petalWidth);
        this.commonExecute(inputData);
    }

    private void commonExecute(Map<String, Object> inputData) {
        KieSession kSession = kbase.newKieSession();
        ArrayList<Object> executionParams = new ArrayList<Object>();
        KiePMMLStatusHolder statusHolder = new KiePMMLStatusHolder();
        executionParams.add(statusHolder);
        PMML4Result pmml4Result = new PMML4Result();
        pmml4Result.setResultCode(ResultCode.FAIL.getName());
        pmml4Result.setResultObjectName(TARGET_FIELD);
        executionParams.add(pmml4Result);
        for (Map.Entry<String, Object> entry : inputData.entrySet()) {
            try {
                FactType factType = kSession.getKieBase().getFactType(PACKAGE, entry.getKey());
                Object toAdd = factType.newInstance();
                factType.set(toAdd, "value", entry.getValue());
                executionParams.add(toAdd);
            }
            catch (Exception e) {
                throw new KiePMMLModelException(e.getMessage(), (Throwable)e);
            }
        }
        executionParams.forEach(arg_0 -> ((KieSession)kSession).insert(arg_0));
        kSession.setGlobal("$pmml4Result", (Object)pmml4Result);
        kSession.fireAllRules();
        Assert.assertEquals((Object)ResultCode.OK.getName(), (Object)pmml4Result.getResultCode());
        Assert.assertNotNull(pmml4Result.getResultVariables().get(TARGET_FIELD));
        Assert.assertEquals((Object)this.expectedResult, pmml4Result.getResultVariables().get(TARGET_FIELD));
    }
}

