/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.compiler.executor;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.regression.RegressionModel;
import org.junit.Assert;
import org.junit.Test;
import org.kie.memorycompiler.KieMemoryCompiler;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.commons.model.HasClassLoader;
import org.kie.pmml.commons.model.KiePMMLModelWithSources;
import org.kie.pmml.compiler.api.CommonTestingUtils;
import org.kie.pmml.compiler.api.dto.CommonCompilationDTO;
import org.kie.pmml.compiler.api.dto.CompilationDTO;
import org.kie.pmml.compiler.api.testutils.TestUtils;
import org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock;
import org.kie.pmml.models.regression.compiler.executor.RegressionModelImplementationProvider;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;

/**
 * Tests for {@link RegressionModelImplementationProvider}.
 *
 * <p>Covers:
 * <ul>
 * <li>the reported model type;</li>
 * <li>model and source generation from a sample PMML file;</li>
 * <li>normalization-method validation;</li>
 * <li>structural validation of regression models.</li>
 * </ul>
 */
public class RegressionModelImplementationProviderTest {

    private static final RegressionModelImplementationProvider PROVIDER = new RegressionModelImplementationProvider();
    private static final String SOURCE_1 = "LinearRegressionSample.pmml";
    private static final String SOURCE_2 = "test_regression.pmml";
    private static final String SOURCE_3 = "test_regression_clax.pmml";
    private static final String PACKAGE_NAME = "packagename";
    // Normalization methods the provider must accept; every other NormalizationMethod constant must be rejected.
    private static final List<RegressionModel.NormalizationMethod> VALID_NORMALIZATION_METHODS = Arrays.asList(
            RegressionModel.NormalizationMethod.NONE,
            RegressionModel.NormalizationMethod.SOFTMAX,
            RegressionModel.NormalizationMethod.LOGIT,
            RegressionModel.NormalizationMethod.EXP,
            RegressionModel.NormalizationMethod.PROBIT,
            RegressionModel.NormalizationMethod.CLOGLOG,
            RegressionModel.NormalizationMethod.LOGLOG,
            RegressionModel.NormalizationMethod.CAUCHIT);

    @Test
    public void getPMMLModelType() {
        Assert.assertEquals(PMML_MODEL.REGRESSION_MODEL, PROVIDER.getPMMLModelType());
    }

    @Test
    public void getKiePMMLModel() throws Exception {
        PMML pmml = TestUtils.loadFromFile(SOURCE_1);
        RegressionModel regressionModel = getSingleRegressionModel(pmml);
        CommonCompilationDTO<RegressionModel> compilationDTO =
                CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                       pmml,
                                                                       regressionModel,
                                                                       new HasClassLoaderMock());
        KiePMMLRegressionModel retrieved = PROVIDER.getKiePMMLModel(compilationDTO);
        Assert.assertNotNull(retrieved);
        Assert.assertTrue(retrieved instanceof Serializable);
    }

    @Test
    public void getKiePMMLModelWithSources() throws Exception {
        PMML pmml = TestUtils.loadFromFile(SOURCE_1);
        RegressionModel regressionModel = getSingleRegressionModel(pmml);
        CommonCompilationDTO<RegressionModel> compilationDTO =
                CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                       pmml,
                                                                       regressionModel,
                                                                       new HasClassLoaderMock());
        KiePMMLModelWithSources retrieved = PROVIDER.getKiePMMLModelWithSources(compilationDTO);
        Assert.assertNotNull(retrieved);
        // Typed (not raw) so that iterating compiled.values() below yields Class<?> rather than Object.
        Map<String, String> sourcesMap = retrieved.getSourcesMap();
        Assert.assertNotNull(sourcesMap);
        Assert.assertFalse(sourcesMap.isEmpty());
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        Map<String, Class<?>> compiled = KieMemoryCompiler.compile(sourcesMap, classLoader);
        for (Class<?> clazz : compiled.values()) {
            // NOTE(review): this is always true, because java.lang.Class itself implements Serializable.
            // The intended check is probably Serializable.class.isAssignableFrom(clazz).
            // Confirm that every generated class is serializable before tightening it.
            Assert.assertTrue(clazz instanceof Serializable);
        }
    }

    @Test
    public void validateNormalizationMethodValid() {
        for (RegressionModel.NormalizationMethod normalizationMethod : VALID_NORMALIZATION_METHODS) {
            PROVIDER.validateNormalizationMethod(normalizationMethod);
        }
    }

    @Test
    public void validateNormalizationMethodInvalid() {
        for (RegressionModel.NormalizationMethod normalizationMethod : RegressionModel.NormalizationMethod.values()) {
            if (VALID_NORMALIZATION_METHODS.contains(normalizationMethod)) {
                continue;
            }
            try {
                PROVIDER.validateNormalizationMethod(normalizationMethod);
                Assert.fail("Expecting failure due to invalid normalization method " + normalizationMethod);
            } catch (KiePMMLException expected) {
                // expected: every method outside VALID_NORMALIZATION_METHODS must be rejected
            }
        }
    }

    @Test
    public void validateSource2() throws Exception {
        commonValidateSource(SOURCE_2);
    }

    @Test
    public void validateSource3() throws Exception {
        commonValidateSource(SOURCE_3);
    }

    @Test
    public void validateNoRegressionTables() throws Exception {
        PMML pmml = TestUtils.loadFromFile(SOURCE_1);
        RegressionModel regressionModel = getSingleRegressionModel(pmml);
        // Case 1: the regression-table list exists but is empty.
        regressionModel.getRegressionTables().clear();
        assertValidationFails(pmml, regressionModel, "Expecting validation failure due to missing RegressionTables");
        // Case 2: the model is constructed with a null regression-table list.
        RegressionModel withNullTables = new RegressionModel(regressionModel.getMiningFunction(),
                                                             regressionModel.getMiningSchema(),
                                                             null);
        assertValidationFails(pmml, withNullTables, "Expecting validation failure due to missing RegressionTables");
    }

    /**
     * Loads {@code sourceFile} and checks that its only model passes provider validation.
     *
     * @param sourceFile name of the PMML resource to load
     */
    private void commonValidateSource(String sourceFile) throws Exception {
        PMML pmml = TestUtils.loadFromFile(sourceFile);
        RegressionModel regressionModel = getSingleRegressionModel(pmml);
        PROVIDER.validate(CommonTestingUtils.getFieldsFromDataDictionary(pmml.getDataDictionary()), regressionModel);
    }

    /**
     * Asserts that {@code pmml} is non-null and holds exactly one model, and that this model is a
     * {@link RegressionModel}.
     *
     * @param pmml the loaded PMML document
     * @return the single regression model
     */
    private static RegressionModel getSingleRegressionModel(PMML pmml) {
        Assert.assertNotNull(pmml);
        Assert.assertEquals(1, pmml.getModels().size());
        Model model = pmml.getModels().get(0);
        Assert.assertTrue(model instanceof RegressionModel);
        return (RegressionModel) model;
    }

    /**
     * Asserts that validating {@code model} throws {@link KiePMMLException}.
     *
     * <p>The model is validated against the fields of {@code pmml}'s data dictionary.
     *
     * @param pmml    document whose data dictionary supplies the fields
     * @param model   model expected to be rejected
     * @param message failure message used if no exception is thrown
     */
    private static void assertValidationFails(PMML pmml, RegressionModel model, String message) {
        try {
            PROVIDER.validate(CommonTestingUtils.getFieldsFromDataDictionary(pmml.getDataDictionary()), model);
            Assert.fail(message);
        } catch (KiePMMLException expected) {
            // expected: the provider must reject a model without regression tables
        }
    }
}

