package org.kie.dmn.pmml;

import java.io.InputStreamReader;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;

import org.drools.compiler.kie.builder.impl.DrlProject;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
import org.kie.api.KieServices;
import org.kie.api.builder.KieFileSystem;
import org.kie.api.builder.Message;
import org.kie.api.internal.assembler.KieAssemblers;
import org.kie.api.internal.utils.ServiceRegistry;
import org.kie.api.runtime.KieRuntimeFactory;
import org.kie.dmn.api.core.DMNContext;
import org.kie.dmn.api.core.DMNResult;
import org.kie.dmn.api.core.DMNRuntime;
import org.kie.dmn.api.core.DMNType;
import org.kie.dmn.core.api.DMNFactory;
import org.kie.dmn.core.assembler.DMNAssemblerService;
import org.kie.dmn.core.impl.CompositeTypeImpl;
import org.kie.dmn.core.impl.DMNModelImpl;
import org.kie.dmn.core.impl.SimpleTypeImpl;
import org.kie.dmn.core.internal.utils.DMNRuntimeBuilder;
import org.kie.dmn.core.pmml.DMNImportPMMLInfo;
import org.kie.dmn.core.pmml.DMNPMMLModelInfo;
import org.kie.dmn.core.util.DMNRuntimeUtil;
import org.kie.dmn.feel.lang.types.BuiltInType;
import org.kie.internal.builder.IncrementalResults;
import org.kie.internal.builder.InternalKieBuilder;
import org.kie.internal.io.ResourceFactory;
import org.kie.internal.services.KieAssemblersImpl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/dmn/pmml/DMNRuntimePMMLTest.class */
/**
 * JUnit 4 tests for DMN models that import and invoke PMML models.
 *
 * <p>Each test builds a {@link DMNRuntime} from DMN/PMML classpath resources located next to this
 * class, evaluates a decision and checks both the result and the PMML import metadata that the
 * compiled {@link DMNModelImpl} exposes.
 *
 * <p>The class is declared {@code abstract}, so it is not run directly. Presumably concrete
 * subclasses select a specific PMML evaluation backend (TODO confirm against the subclasses).
 */
public abstract class DMNRuntimePMMLTest {
    public static final Logger LOG = LoggerFactory.getLogger(DMNRuntimePMMLTest.class);

    /** Absolute tolerance used when comparing floating-point PMML outputs. */
    private static final double COMPARISON_DELTA = 1.0E-6d;

    /** Namespace shared by the DMN test models loaded in this class. */
    private static final String MODEL_NAMESPACE = "http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f";

    /** Import name of the regression PMML, which is also the namespace of the types derived from it. */
    private static final String CLAX_IMPORT = "test_regression_clax";

    /** Scorecard model compiled through the default runtime-creation utility. */
    @Test
    public void testBasic() {
        runDMNModelInvokingPMML(DMNRuntimeUtil.createRuntimeWithAdditionalResources("KiePMMLScoreCard.dmn", DMNRuntimePMMLTest.class, new String[]{"test_scorecard.pmml"}));
    }

    /** Scorecard model variant whose inputs are declared with explicit input types. */
    @Test
    public void testWithInputTypes() {
        runDMNModelInvokingPMML(DMNRuntimeUtil.createRuntimeWithAdditionalResources("KiePMMLScoreCard_wInputType.dmn", DMNRuntimePMMLTest.class, new String[]{"test_scorecard.pmml"}));
    }

    /**
     * Scorecard model compiled through {@link DMNRuntimeBuilder} instead of the KIE assembler
     * pipeline. Relative imports are resolved as classpath resources next to this class.
     */
    @Test
    public void testBasicNoKieAssembler() {
        DMNRuntime runtime = (DMNRuntime) DMNRuntimeBuilder.fromDefaults()
                .setRelativeImportResolver((namespace, modelName, location) ->
                        // NOTE(review): the third argument appears to be the import location; confirm.
                        // Fail with a descriptive message instead of an anonymous NPE when the
                        // resource is missing, and decode as UTF-8 rather than the platform charset.
                        new InputStreamReader(
                                Objects.requireNonNull(DMNRuntimePMMLTest.class.getResourceAsStream(location),
                                                       "Relative import resource not found: " + location),
                                StandardCharsets.UTF_8))
                .buildConfiguration()
                .fromResources(Arrays.asList(ResourceFactory.newClassPathResource("KiePMMLScoreCard.dmn", DMNRuntimePMMLTest.class)))
                .getOrElseThrow(cause -> new RuntimeException("Error compiling DMN model(s)", cause));
        runDMNModelInvokingPMML(runtime);
    }

    /**
     * Evaluates the {@code KiePMMLScoreCard} model in {@code runtime} with an empty context.
     *
     * <p>Verifies that the model compiles and evaluates without errors, that {@code my decision}
     * equals {@code 41.345}, and that the single {@code iris} PMML import describes the expected
     * input and output fields.
     *
     * @param runtime a runtime that already contains the {@code KiePMMLScoreCard} model
     */
    static void runDMNModelInvokingPMML(DMNRuntime runtime) {
        // Explicit cast so DMNModelImpl-only accessors (getPmmlImportInfo) resolve regardless of
        // getModel's declared return type.
        DMNModelImpl model = (DMNModelImpl) runtime.getModel(MODEL_NAMESPACE, "KiePMMLScoreCard");
        MatcherAssert.assertThat(model, CoreMatchers.notNullValue());
        MatcherAssert.assertThat(DMNRuntimeUtil.formatMessages(model.getMessages()), model.hasErrors(), CoreMatchers.is(false));

        DMNResult result = runtime.evaluateAll(model, DMNFactory.newContext());
        LOG.debug("{}", result);
        MatcherAssert.assertThat(DMNRuntimeUtil.formatMessages(result.getMessages()), result.hasErrors(), CoreMatchers.is(false));
        MatcherAssert.assertThat(result.getContext().get("my decision"), CoreMatchers.is(new BigDecimal("41.345")));

        Map<?, ?> pmmlImportInfo = model.getPmmlImportInfo();
        MatcherAssert.assertThat(pmmlImportInfo.keySet(), Matchers.hasSize(1));
        DMNImportPMMLInfo importInfo = (DMNImportPMMLInfo) pmmlImportInfo.values().iterator().next();
        MatcherAssert.assertThat(importInfo.getImportName(), CoreMatchers.is("iris"));
        MatcherAssert.assertThat(importInfo.getModels(), Matchers.hasSize(1));

        DMNPMMLModelInfo modelInfo = (DMNPMMLModelInfo) importInfo.getModels().iterator().next();
        MatcherAssert.assertThat(modelInfo.getName(), CoreMatchers.is("Sample Score"));
        for (String inputField : new String[]{"age", "occupation", "residenceState", "validLicense"}) {
            MatcherAssert.assertThat(modelInfo.getInputFields(), Matchers.hasEntry(CoreMatchers.is(inputField), Matchers.anything()));
        }
        // These fields must not be reported as inputs; calculatedScore must appear as an output.
        MatcherAssert.assertThat(modelInfo.getInputFields(), Matchers.not(Matchers.hasEntry(CoreMatchers.is("overallScore"), Matchers.anything())));
        MatcherAssert.assertThat(modelInfo.getInputFields(), Matchers.not(Matchers.hasEntry(CoreMatchers.is("calculatedScore"), Matchers.anything())));
        MatcherAssert.assertThat(modelInfo.getOutputFields(), Matchers.hasEntry(CoreMatchers.is("calculatedScore"), Matchers.anything()));
    }

    /**
     * Builds the PMML resource first, then adds the DMN model with an incremental build, and
     * finally evaluates the model through a KieContainer for the repository's default release id.
     */
    @Test
    public void testSteppedCompilation() {
        // Registers the DMN assembler on the KieAssemblers instance obtained from the static
        // ServiceRegistry. NOTE(review): this is likely state shared with other tests.
        ((KieAssemblersImpl) ServiceRegistry.getService(KieAssemblers.class)).accept(new DMNAssemblerService());
        final String pmmlPath = "src/main/resources/org/acme/test_scorecard.pmml";
        final String dmnPath = "src/main/resources/org/acme/KiePMMLScoreCard.dmn";

        KieServices kieServices = KieServices.Factory.get();
        KieFileSystem kieFileSystem = kieServices.newKieFileSystem();

        // Step 1: full build containing only the PMML resource.
        kieFileSystem.write(pmmlPath, kieServices.getResources().newClassPathResource("test_scorecard.pmml", DMNRuntimePMMLTest.class));
        // Explicit cast: createFileSet(...) is only available on InternalKieBuilder.
        InternalKieBuilder kieBuilder = (InternalKieBuilder) kieServices.newKieBuilder(kieFileSystem).buildAll(DrlProject.class);
        Assert.assertEquals(0L, kieBuilder.getResults().getMessages(new Message.Level[]{Message.Level.ERROR}).size());

        // Step 2: add the DMN model and build just that file incrementally.
        kieFileSystem.write(dmnPath, kieServices.getResources().newClassPathResource("KiePMMLScoreCard.dmn", DMNRuntimePMMLTest.class));
        IncrementalResults incrementalResults = kieBuilder.createFileSet(new String[]{dmnPath}).build();
        Assert.assertEquals(0L, incrementalResults.getAddedMessages().size());
        Assert.assertEquals(0L, incrementalResults.getRemovedMessages().size());

        // Step 3: evaluate through the container built from the default release id.
        runDMNModelInvokingPMML((DMNRuntime) KieRuntimeFactory.of(kieServices.newKieContainer(kieServices.getRepository().getDefaultReleaseId()).getKieBase()).get(DMNRuntime.class));
    }

    /**
     * Evaluates the {@code KiePMMLRegressionClax} model, whose PMML produces several outputs
     * grouped in a composite result. Checks the numeric results and the DMN types derived from
     * the PMML input and output fields.
     */
    @Test
    public void testMultiOutputs() {
        DMNRuntime runtime = DMNRuntimeUtil.createRuntimeWithAdditionalResources("KiePMMLRegressionClax.dmn", DMNRuntimePMMLTest.class, new String[]{"test_regression_clax.pmml"});
        // Explicit cast so getPmmlImportInfo()/getTypeRegistry() resolve.
        DMNModelImpl model = (DMNModelImpl) runtime.getModel(MODEL_NAMESPACE, "KiePMMLRegressionClax");
        MatcherAssert.assertThat(model, CoreMatchers.notNullValue());
        MatcherAssert.assertThat(DMNRuntimeUtil.formatMessages(model.getMessages()), model.hasErrors(), CoreMatchers.is(false));

        DMNContext context = DMNFactory.newContext();
        context.set("fld1", Double.valueOf(1.0d));
        context.set("fld2", Double.valueOf(1.0d));
        context.set("fld3", "x");
        DMNResult result = runtime.evaluateAll(model, context);
        LOG.debug("{}", result);
        MatcherAssert.assertThat(DMNRuntimeUtil.formatMessages(result.getMessages()), result.hasErrors(), CoreMatchers.is(false));

        Map<?, ?> decision = (Map<?, ?>) result.getContext().get("my decision");
        Assert.assertEquals("catD", (String) decision.get("RegOut"));
        Assert.assertEquals(0.8279559384018024d, ((BigDecimal) decision.get("RegProb")).doubleValue(), COMPARISON_DELTA);
        Assert.assertEquals(0.0022681396056233208d, ((BigDecimal) decision.get("RegProbA")).doubleValue(), COMPARISON_DELTA);

        // FEEL built-in types the PMML-derived simple types are expected to extend.
        DMNType numberType = model.getTypeRegistry().resolveType(model.getDefinitions().getURIFEEL(), BuiltInType.NUMBER.getName());
        DMNType stringType = model.getTypeRegistry().resolveType(model.getDefinitions().getURIFEEL(), BuiltInType.STRING.getName());

        Map<?, ?> pmmlImportInfo = model.getPmmlImportInfo();
        MatcherAssert.assertThat(pmmlImportInfo.keySet(), Matchers.hasSize(1));
        DMNImportPMMLInfo importInfo = (DMNImportPMMLInfo) pmmlImportInfo.values().iterator().next();
        MatcherAssert.assertThat(importInfo.getImportName(), CoreMatchers.is(CLAX_IMPORT));
        MatcherAssert.assertThat(importInfo.getModels(), Matchers.hasSize(1));
        DMNPMMLModelInfo modelInfo = (DMNPMMLModelInfo) importInfo.getModels().iterator().next();
        MatcherAssert.assertThat(modelInfo.getName(), CoreMatchers.is("LinReg"));

        Map<?, ?> inputFields = modelInfo.getInputFields();
        assertSimpleType(inputFields.get("fld1"), CLAX_IMPORT, BuiltInType.NUMBER, numberType);
        assertSimpleType(inputFields.get("fld2"), CLAX_IMPORT, BuiltInType.NUMBER, numberType);
        assertSimpleType(inputFields.get("fld3"), CLAX_IMPORT, BuiltInType.STRING, stringType);

        CompositeTypeImpl linRegOutput = (CompositeTypeImpl) modelInfo.getOutputFields().get("LinReg");
        Assert.assertEquals(CLAX_IMPORT, linRegOutput.getNamespace());
        Map<?, ?> outputFields = linRegOutput.getFields();
        assertSimpleType(outputFields.get("RegOut"), CLAX_IMPORT, BuiltInType.STRING, stringType);
        assertSimpleType(outputFields.get("RegProb"), CLAX_IMPORT, BuiltInType.NUMBER, numberType);
        assertSimpleType(outputFields.get("RegProbA"), CLAX_IMPORT, BuiltInType.NUMBER, numberType);
    }

    /**
     * Asserts that {@code type} is a {@link SimpleTypeImpl} with the given namespace, FEEL type and
     * base type.
     *
     * @param type               the type to check; a {@link ClassCastException} fails the test if it is
     *                           not a {@link SimpleTypeImpl}
     * @param expectedNamespace  expected {@code getNamespace()} value
     * @param expectedFeelType   expected {@code getFeelType()} value
     * @param expectedBaseType   expected {@code getBaseType()} value
     */
    private static void assertSimpleType(Object type, String expectedNamespace, BuiltInType expectedFeelType, DMNType expectedBaseType) {
        SimpleTypeImpl simpleType = (SimpleTypeImpl) type;
        Assert.assertEquals(expectedNamespace, simpleType.getNamespace());
        Assert.assertEquals(expectedFeelType, simpleType.getFeelType());
        Assert.assertEquals(expectedBaseType, simpleType.getBaseType());
    }
}
