package org.kie.pmml.models.regression.compiler.executor;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.assertj.core.api.Assertions;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.regression.RegressionModel;
import org.junit.Test;
import org.kie.memorycompiler.KieMemoryCompiler;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.commons.model.KiePMMLModelWithSources;
import org.kie.pmml.compiler.api.CommonTestingUtils;
import org.kie.pmml.compiler.api.dto.CommonCompilationDTO;
import org.kie.pmml.compiler.api.testutils.TestUtils;
import org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;

/* loaded from: input_file:org/kie/pmml/models/regression/compiler/executor/RegressionModelImplementationProviderTest.class */
/**
 * Unit tests for {@link RegressionModelImplementationProvider}: model type, model/source generation,
 * normalization-method validation and structural validation of regression models.
 */
public class RegressionModelImplementationProviderTest {
    private static final String RELEASE_ID = "org.drools:kie-pmml-models-testing:1.0";
    private static final String SOURCE_1 = "LinearRegressionSample.pmml";
    private static final String SOURCE_2 = "test_regression.pmml";
    private static final String SOURCE_3 = "test_regression_clax.pmml";
    private static final String PACKAGE_NAME = "packagename";
    private static final RegressionModelImplementationProvider PROVIDER = new RegressionModelImplementationProvider();
    private static final List<RegressionModel.NormalizationMethod> VALID_NORMALIZATION_METHODS = Arrays.asList(
            RegressionModel.NormalizationMethod.NONE,
            RegressionModel.NormalizationMethod.SOFTMAX,
            RegressionModel.NormalizationMethod.LOGIT,
            RegressionModel.NormalizationMethod.EXP,
            RegressionModel.NormalizationMethod.PROBIT,
            RegressionModel.NormalizationMethod.CLOGLOG,
            RegressionModel.NormalizationMethod.LOGLOG,
            RegressionModel.NormalizationMethod.CAUCHIT);

    @Test
    public void getPMMLModelType() {
        Assertions.assertThat(PROVIDER.getPMMLModelType()).isEqualTo(PMML_MODEL.REGRESSION_MODEL);
    }

    @Test
    public void getKiePMMLModel() throws Exception {
        final PMML pmml = loadSingleRegressionModelPMML(SOURCE_1);
        final RegressionModel regressionModel = (RegressionModel) pmml.getModels().get(0);
        final KiePMMLRegressionModel kiePMMLModel = PROVIDER.getKiePMMLModel(
                CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel,
                                                                       new HasClassLoaderMock()));
        Assertions.assertThat(kiePMMLModel).isNotNull();
        Assertions.assertThat(kiePMMLModel).isInstanceOf(Serializable.class);
    }

    @Test
    public void getKiePMMLModelWithSources() throws Exception {
        final PMML pmml = loadSingleRegressionModelPMML(SOURCE_1);
        final RegressionModel regressionModel = (RegressionModel) pmml.getModels().get(0);
        final KiePMMLModelWithSources kiePMMLModelWithSources = PROVIDER.getKiePMMLModelWithSources(
                CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel,
                                                                       new HasClassLoaderMock()));
        Assertions.assertThat(kiePMMLModelWithSources).isNotNull();
        final Map<String, String> sourcesMap = kiePMMLModelWithSources.getSourcesMap();
        Assertions.assertThat(sourcesMap).isNotNull();
        Assertions.assertThat(sourcesMap).isNotEmpty();
        // Compiling must succeed and yield classes. The previous per-entry check
        // (Class object isInstanceOf Serializable) was vacuous, because java.lang.Class
        // itself implements Serializable.
        final Map<String, Class<?>> compiled =
                KieMemoryCompiler.compile(sourcesMap, Thread.currentThread().getContextClassLoader());
        Assertions.assertThat(compiled).isNotEmpty();
    }

    @Test
    public void validateNormalizationMethodValid() {
        // Every supported method must pass validation without throwing.
        VALID_NORMALIZATION_METHODS.forEach(PROVIDER::validateNormalizationMethod);
    }

    @Test
    public void validateNormalizationMethodInvalid() {
        for (RegressionModel.NormalizationMethod normalizationMethod : RegressionModel.NormalizationMethod.values()) {
            if (VALID_NORMALIZATION_METHODS.contains(normalizationMethod)) {
                continue;
            }
            Assertions.assertThatThrownBy(() -> PROVIDER.validateNormalizationMethod(normalizationMethod))
                    .as("Expecting failure due to invalid normalization method " + normalizationMethod)
                    .isInstanceOf(KiePMMLException.class);
        }
    }

    @Test
    public void validateSource2() throws Exception {
        commonValidateSource(SOURCE_2);
    }

    @Test
    public void validateSource3() throws Exception {
        commonValidateSource(SOURCE_3);
    }

    @Test
    public void validateNoRegressionTables() throws Exception {
        final PMML pmml = loadSingleRegressionModelPMML(SOURCE_1);
        final RegressionModel regressionModel = (RegressionModel) pmml.getModels().get(0);
        // Case 1: the regression-table list exists but is empty.
        regressionModel.getRegressionTables().clear();
        Assertions.assertThatThrownBy(() -> PROVIDER.validate(
                        CommonTestingUtils.getFieldsFromDataDictionary(pmml.getDataDictionary()), regressionModel))
                .as("Expecting validation failure due to missing RegressionTables")
                .isInstanceOf(KiePMMLException.class);
        // Case 2: the regression-table list is null.
        Assertions.assertThatThrownBy(() -> PROVIDER.validate(
                        CommonTestingUtils.getFieldsFromDataDictionary(pmml.getDataDictionary()),
                        new RegressionModel(regressionModel.getMiningFunction(), regressionModel.getMiningSchema(),
                                            (List) null)))
                .as("Expecting validation failure due to missing RegressionTables")
                .isInstanceOf(KiePMMLException.class);
    }

    /**
     * Loads {@code fileName} and asserts that it holds exactly one {@link RegressionModel};
     * then runs provider validation on it, expecting no exception.
     */
    private void commonValidateSource(String str) throws Exception {
        final PMML pmml = loadSingleRegressionModelPMML(str);
        PROVIDER.validate(CommonTestingUtils.getFieldsFromDataDictionary(pmml.getDataDictionary()),
                          (RegressionModel) pmml.getModels().get(0));
    }

    /**
     * Loads a PMML resource and asserts that it is non-null and holds exactly one model,
     * which is a {@link RegressionModel}.
     *
     * @param fileName the PMML resource to load through {@link TestUtils#loadFromFile(String)}
     * @return the loaded PMML
     * @throws Exception if loading fails
     */
    private static PMML loadSingleRegressionModelPMML(String fileName) throws Exception {
        final PMML pmml = TestUtils.loadFromFile(fileName);
        Assertions.assertThat(pmml).isNotNull();
        Assertions.assertThat(pmml.getModels()).hasSize(1);
        Assertions.assertThat(pmml.getModels().get(0)).isInstanceOf(RegressionModel.class);
        return pmml;
    }
}
