package org.kie.pmml.pmml_4_2.predictive.models;

import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import org.kie.api.KieBaseConfiguration;
import org.kie.api.conf.KieBaseOption;
import org.kie.api.io.ResourceType;
import org.kie.api.pmml.PMML4Result;
import org.kie.api.pmml.PMMLRequestData;
import org.kie.internal.io.ResourceFactory;
import org.kie.internal.utils.KieHelper;
import org.kie.pmml.pmml_4_2.DroolsAbstractPMMLTest;
import org.kie.pmml.pmml_4_2.PMML4ExecutionHelper;
import org.kie.pmml.pmml_4_2.PMMLRequestDataBuilder;
import org.kie.pmml.pmml_4_2.model.ScoreCard;
import org.kie.pmml.pmml_4_2.model.mining.SegmentExecution;
import org.kie.pmml.pmml_4_2.model.mining.SegmentExecutionState;
import org.kie.pmml.pmml_4_2.model.tree.AbstractTreeToken;

/**
 * Tests for PMML 4.2 mining models (models composed of child segments).
 *
 * <p>Each test loads a PMML resource from the classpath, submits a {@link PMMLRequestData} through a
 * {@link PMML4ExecutionHelper}, and then checks the returned {@link PMML4Result}s and/or the state of
 * the child segment executions reported by {@link PMML4ExecutionHelper#getChildModelSegments()}.
 */
public class MiningmodelTest extends DroolsAbstractPMMLTest {
    /** When true, tests dump session objects / results to stdout to aid debugging. */
    private static final boolean VERBOSE = true;
    private static final String FILE_BASE = "org/kie/pmml/pmml_4_2/";
    private static final String source1 = FILE_BASE + "test_mining_model_simple.pmml";
    private static final String source2 = FILE_BASE + "test_mining_model_simple2.pmml";
    private static final String source3 = FILE_BASE + "filebased";
    private static final String source4 = FILE_BASE + "test_mining_model_selectall.pmml";
    private static final String source5 = FILE_BASE + "test_mining_model_modelchain.pmml";
    private static final String WEIGHTED_AVG = FILE_BASE + "test_mining_model_weighted_avg.pmml";
    private static final String SUMMED = FILE_BASE + "test_mining_model_summed.pmml";
    private static final String RESOURCES_TEST_ROOT = "src/test/resources/";

    @Test
    public void testSelectFirstSegmentFirst() {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "SampleMine", ResourceFactory.newClassPathResource(source1), (KieBaseConfiguration) null, true);
        PMMLRequestData request = new PMMLRequestDataBuilder("1234", "SampleMine")
                .addParameter("fld1", Double.valueOf(30.0d), Double.class)
                .addParameter("fld2", Double.valueOf(60.0d), Double.class)
                .addParameter("fld3", "false", String.class)
                .addParameter("fld4", "optA", String.class)
                .build();
        executionHelper.submitRequest(request);
        if (VERBOSE) {
            executionHelper.getExecutor().getSessionObjects().forEach(System.out::println);
            executionHelper.getMiningModelPojo().forEach(System.out::println);
        }
        executionHelper.getResultData().forEach(result -> {
            Assert.assertEquals(request.getCorrelationId(), result.getCorrelationId());
            // Only the top-level (non-segment) result carries the mining model's output.
            if (result.getSegmentationId() == null) {
                Assert.assertEquals("OK", result.getResultCode());
                Assert.assertNotNull(result.getResultValue("Fld5", (String) null, new Object[0]));
                Assert.assertEquals("tgtY", (String) result.getResultValue("Fld5", "value", String.class, new Object[0]).orElse(null));
            }
        });
    }

    @Test
    public void testSelectSecondSegmentFirst() {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "SampleMine",
                new KieHelper().addResource(ResourceFactory.newClassPathResource(source1), ResourceType.PMML).build(new KieBaseOption[0]),
                true);
        PMMLRequestData request = new PMMLRequestDataBuilder("1234", "SampleMine")
                .addParameter("fld1", Double.valueOf(45.0d), Double.class)
                .addParameter("fld2", Double.valueOf(60.0d), Double.class)
                .addParameter("fld6", "optA", String.class)
                .build();
        executionHelper.submitRequest(request);
        executionHelper.getResultData().forEach(result -> {
            Assert.assertEquals(request.getCorrelationId(), result.getCorrelationId());
            Assert.assertEquals("OK", result.getResultCode());
            if (result.getSegmentationId() == null) {
                Assert.assertNotNull(result.getResultValue("Fld5", (String) null, new Object[0]));
                Assert.assertEquals("tgtZ", (String) result.getResultValue("Fld5", "value", String.class, new Object[0]).orElse(null));
                AbstractTreeToken treeToken = (AbstractTreeToken) result.getResultValue("MissingTreeToken", (String) null, AbstractTreeToken.class, new Object[0]).orElse(null);
                Assert.assertNotNull(treeToken);
                Assert.assertEquals(0.6d, treeToken.getConfidence().doubleValue(), 0.0d);
                Assert.assertEquals("null", treeToken.getCurrent());
            }
        });
        Assert.assertEquals(1L, countCompletedSegments(executionHelper, request));
    }

    @Test
    public void testWithScorecard() {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "SampleScorecardMine", ResourceFactory.newClassPathResource(source2), (KieBaseConfiguration) null, true);
        PMMLRequestData request = new PMMLRequestDataBuilder("1234", "SampleScorecardMine")
                .addParameter("age", Double.valueOf(33.0d), Double.class)
                .addParameter("occupation", "SKYDIVER", String.class)
                .addParameter("residenceState", "KN", String.class)
                .addParameter("validLicense", true, Boolean.class)
                .build();
        executionHelper.submitRequest(request);
        executionHelper.getResultData().forEach(result -> {
            Assert.assertEquals(request.getCorrelationId(), result.getCorrelationId());
            Assert.assertEquals("OK", result.getResultCode());
            if (result.getSegmentationId() == null) {
                ScoreCard scoreCard = (ScoreCard) result.getResultValue("ScoreCard", (String) null, ScoreCard.class, new Object[0]).orElse(null);
                Assert.assertNotNull(scoreCard);
                assertRankingStartsWith(scoreCard.getRanking(),
                        new String[] {"LX00", "RES", "CX2"},
                        new double[] {-1.0d, -10.0d, -30.0d});
            }
        });
        Assert.assertEquals(1L, countCompletedSegments(executionHelper, request));
    }

    @Test
    public void testWithRegression() {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "SampleScorecardMine", ResourceFactory.newClassPathResource(source2), (KieBaseConfiguration) null, true);
        PMMLRequestData request = new PMMLRequestDataBuilder("123", "SampleScorecardMine")
                .addParameter("fld1r", Double.valueOf(1.0d), Double.class)
                .addParameter("fld2r", Double.valueOf(1.0d), Double.class)
                .addParameter("fld3r", "x", String.class)
                .build();
        executionHelper.submitRequest(request);
        executionHelper.getResultData().forEach(result -> {
            Assert.assertEquals(request.getCorrelationId(), result.getCorrelationId());
            Assert.assertEquals("OK", result.getResultCode());
            if (result.getSegmentationId() == null) {
                if (VERBOSE) {
                    System.out.println(result);
                }
                Assert.assertNotNull(result.getResultValue("RegOut", (String) null, new Object[0]));
                Assert.assertEquals("catC", (String) result.getResultValue("RegOut", "value", String.class, new Object[0]).orElse(null));

                Assert.assertNotNull(result.getResultValue("RegProb", (String) null, new Object[0]));
                Double regProb = (Double) result.getResultValue("RegProb", "value", Double.class, new Object[0]).orElse(null);
                // Check presence first so a missing value fails as an assertion, not an NPE on unboxing.
                Assert.assertNotNull(regProb);
                Assert.assertEquals(0.709228d, regProb.doubleValue(), 1.0E-6d);

                Assert.assertNotNull(result.getResultValue("RegProbA", (String) null, new Object[0]));
                Double regProbA = (Double) result.getResultValue("RegProbA", "value", Double.class, new Object[0]).orElse(null);
                Assert.assertNotNull(regProbA);
                Assert.assertEquals(0.010635d, regProbA.doubleValue(), 1.0E-6d);
            }
        });
        Assert.assertEquals(1L, countCompletedSegments(executionHelper, request));
    }

    @Test
    public void testSelectAll() {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "SampleSelectAllMine", ResourceFactory.newClassPathResource(source4), (KieBaseConfiguration) null, true);
        PMMLRequestData request = new PMMLRequestDataBuilder("1234", "SampleSelectAllMine")
                .addParameter("age", Double.valueOf(33.0d), Double.class)
                .addParameter("occupation", "SKYDIVER", String.class)
                .addParameter("residenceState", "KN", String.class)
                .addParameter("validLicense", true, Boolean.class)
                .build();
        executionHelper.submitRequest(request);
        executionHelper.getResultData().forEach(result -> {
            Assert.assertEquals("OK", result.getResultCode());
            Assert.assertEquals(request.getCorrelationId(), result.getCorrelationId());
            ScoreCard scoreCard = (ScoreCard) result.getResultValue("ScoreCard", (String) null, ScoreCard.class, new Object[0]).orElse(null);
            Assert.assertNotNull(scoreCard);
            Map<?, ?> ranking = scoreCard.getRanking();
            Assert.assertNotNull(ranking);
            // Every result comes from one of the two scorecard segments; tell them apart by reason code.
            Assert.assertTrue(ranking.containsKey("LX00") || ranking.containsKey("LC00"));
            if (ranking.containsKey("LX00")) {
                assertRankingStartsWith(ranking,
                        new String[] {"LX00", "RES", "CX2"},
                        new double[] {-1.0d, -10.0d, -30.0d});
                Assert.assertEquals(41.345d, scoreCard.getScore(), 1.0E-6d);
            } else {
                assertRankingStartsWith(ranking,
                        new String[] {"RST", "LC00", "DX2"},
                        new double[] {10.0d, -1.0d, -30.0d});
                Assert.assertEquals(21.345d, scoreCard.getScore(), 1.0E-6d);
            }
        });
        Assert.assertEquals(2L, countCompletedSegments(executionHelper, request));
    }

    @Test
    public void testSimpleModelChain() {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "SampleModelChainMine", ResourceFactory.newClassPathResource(source5), (KieBaseConfiguration) null, true);
        PMMLRequestData request = new PMMLRequestDataBuilder("1234", "SampleModelChainMine")
                .addParameter("age", Double.valueOf(33.0d), Double.class)
                .addParameter("occupation", "TEACHER", String.class)
                .addParameter("residenceState", "TN", String.class)
                .addParameter("validLicense", true, Boolean.class)
                .build();
        PMML4Result result = executionHelper.submitRequest(request);
        Assert.assertEquals("OK", result.getResultCode());
        Map<?, ?> resultVariables = result.getResultVariables();
        Assert.assertNotNull(resultVariables);
        Assert.assertTrue(resultVariables.containsKey("QualificationLevel"));
        Assert.assertTrue(resultVariables.containsKey("OverallScore"));
        String qualificationLevel = (String) result.getResultValue("QualificationLevel", "value", String.class, new Object[0]).orElse(null);
        Double overallScore = (Double) result.getResultValue("OverallScore", "value", Double.class, new Object[0]).orElse(null);
        Assert.assertNotNull(qualificationLevel);
        Assert.assertNotNull(overallScore);
        Assert.assertEquals("Well", qualificationLevel);
        Assert.assertEquals(56.345d, overallScore.doubleValue(), 1.0E-6d);
    }

    @Test
    public void testWeightedAverage() {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "SampleMiningModelAvg", ResourceFactory.newClassPathResource(WEIGHTED_AVG), (KieBaseConfiguration) null, true);
        PMMLRequestData request = new PMMLRequestDataBuilder("1234", executionHelper.getModelName())
                .addParameter("petal_length", Double.valueOf(6.45d), Double.class)
                .addParameter("petal_width", Double.valueOf(1.75d), Double.class)
                .addParameter("sepal_width", Double.valueOf(1.23d), Double.class)
                .build();
        PMML4Result result = executionHelper.submitRequest(request);
        Double value = (Double) result.getResultValue("WeightedAvg_Sepal_length", "value", Double.class, new Object[0]).orElse(null);
        Double weight = (Double) result.getResultValue("WeightedAvg_Sepal_length", "weight", Double.class, new Object[0]).orElse(null);
        // Fail with an assertion rather than an NPE on unboxing if either field is missing.
        Assert.assertNotNull(value);
        Assert.assertNotNull(weight);
        Assert.assertEquals(7.1833385d, value.doubleValue(), 1.0E-6d);
        Assert.assertEquals(1.0d, weight.doubleValue(), 0.01d);
    }

    @Test
    public void testSum() {
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                "SampleMiningModelSum", ResourceFactory.newClassPathResource(SUMMED), (KieBaseConfiguration) null, true);
        PMMLRequestData request = new PMMLRequestDataBuilder("1234", executionHelper.getModelName())
                .addParameter("petal_length", Double.valueOf(6.45d), Double.class)
                .addParameter("petal_width", Double.valueOf(1.75d), Double.class)
                .addParameter("sepal_width", Double.valueOf(1.23d), Double.class)
                .build();
        Double sum = (Double) executionHelper.submitRequest(request)
                .getResultValue("Sum_Sepal_length", "value", Double.class, new Object[0]).orElse(null);
        Assert.assertNotNull(sum);

        // The top-level sum must equal the sum of the per-segment "Sepal_length" outputs.
        double expectedSum = 0.0d;
        for (PMML4Result segmentResult : executionHelper.getResultData()) {
            if (segmentResult.getSegmentationId() == null) {
                continue;
            }
            Double segmentValue = (Double) segmentResult.getResultValue("Sepal_length", "value", Double.class, new Object[0]).orElse(null);
            if (segmentValue != null) {
                expectedSum += segmentValue.doubleValue();
            }
        }
        Assert.assertEquals(expectedSum, sum.doubleValue(), 1.0E-6d);
    }

    /**
     * Counts the child segment executions whose state is {@link SegmentExecutionState#COMPLETE},
     * asserting along the way that every segment carries the request's correlation id.
     *
     * @param executionHelper helper the request was submitted through
     * @param request         the submitted request
     * @return number of completed child segments
     */
    private static int countCompletedSegments(PMML4ExecutionHelper executionHelper, PMMLRequestData request) {
        int completed = 0;
        for (SegmentExecution segment : executionHelper.getChildModelSegments()) {
            Assert.assertEquals(request.getCorrelationId(), segment.getCorrelationId());
            if (segment.getState() == SegmentExecutionState.COMPLETE) {
                completed++;
            }
        }
        return completed;
    }

    /**
     * Asserts that {@code ranking} is a {@link LinkedHashMap} whose first entries, in iteration
     * order, are {@code keys[i] -> values[i]}.
     *
     * @param ranking scorecard ranking (reason code to points)
     * @param keys    expected reason codes, in expected iteration order
     * @param values  expected points, parallel to {@code keys}
     */
    private static void assertRankingStartsWith(Map<?, ?> ranking, String[] keys, double[] values) {
        Assert.assertNotNull(ranking);
        // Ordering is part of the contract being tested, so an insertion-ordered map is required.
        Assert.assertTrue(ranking instanceof LinkedHashMap);
        Iterator<?> keyIterator = ranking.keySet().iterator();
        for (int k = 0; k < keys.length; k++) {
            Assert.assertTrue(ranking.containsKey(keys[k]));
            Assert.assertEquals(Double.valueOf(values[k]), ranking.get(keys[k]));
            Assert.assertEquals(keys[k], keyIterator.next());
        }
    }
}
