package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.junit.BeforeClass;
import org.junit.Test;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.commons.utils.KiePMMLModelUtils;
import org.kie.pmml.compiler.api.dto.CommonCompilationDTO;
import org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock;
import org.kie.pmml.compiler.commons.testutils.CodegenTestUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.model.KiePMMLClassificationTable;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.kie.test.util.filesystem.FileUtils;

/**
 * Unit tests for {@link KiePMMLClassificationTableFactory}.
 * <p>
 * Most tests share the same fixture: a categorical {@link RegressionModel} with two regression tables
 * ("professional" and "clerical"), CAUCHIT normalization and a single target field, built by
 * {@link #buildRegressionModel()} and wrapped into a {@link RegressionCompilationDTO} by
 * {@link #buildRegressionCompilationDTO(RegressionModel)}.
 */
public class KiePMMLClassificationTableFactoryTest extends AbstractKiePMMLRegressionTableRegressionFactoryTest {

    private static final String TEST_01_SOURCE = "KiePMMLClassificationTableFactoryTest_01.txt";
    private static final String TEST_02_SOURCE = "KiePMMLClassificationTableFactoryTest_02.txt";
    private static final String PACKAGE_NAME = "PACKAGE_NAME";
    private static final String TARGET_FIELD_NAME = "targetField";
    private static CompilationUnit COMPILATION_UNIT;
    private static ClassOrInterfaceDeclaration MODEL_TEMPLATE;
    private static MethodDeclaration STATIC_GETTER_METHOD;

    /**
     * Loads the classification-table template once and caches its {@code getKiePMMLTable} method,
     * which {@link #setStaticGetter()} clones and populates.
     */
    @BeforeClass
    public static void setup() {
        COMPILATION_UNIT = JavaParserUtils.getFromFileName("KiePMMLClassificationTableTemplate.tmpl");
        MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName("KiePMMLClassificationTableTemplate")
                .orElseThrow(() -> new IllegalStateException(
                        "KiePMMLClassificationTableTemplate not found in KiePMMLClassificationTableTemplate.tmpl"));
        STATIC_GETTER_METHOD = MODEL_TEMPLATE.getMethodsByName("getKiePMMLTable").get(0);
    }

    @Test
    public void getClassificationTable() {
        RegressionModel regressionModel = buildRegressionModel();
        KiePMMLClassificationTable classificationTable =
                KiePMMLClassificationTableFactory.getClassificationTable(buildRegressionCompilationDTO(regressionModel));
        Assertions.assertThat(classificationTable).isNotNull();
        Assertions.assertThat(classificationTable.getCategoryTableMap()).hasSameSizeAs(regressionModel.getRegressionTables());
        // every regression table must be reachable through its target category
        for (RegressionTable regressionTable : regressionModel.getRegressionTables()) {
            Assertions.assertThat(classificationTable.getCategoryTableMap())
                    .containsKey(regressionTable.getTargetCategory().toString());
        }
        Assertions.assertThat(classificationTable.getRegressionNormalizationMethod().getName())
                .isEqualTo(regressionModel.getNormalizationMethod().value());
        Assertions.assertThat(classificationTable.getOpType()).isEqualTo(OP_TYPE.CATEGORICAL);
        boolean expectedBinary = regressionModel.getRegressionTables().size() == 2;
        Assertions.assertThat(classificationTable.isBinary()).isEqualTo(expectedBinary);
        Assertions.assertThat(classificationTable.getTargetField()).isEqualTo(TARGET_FIELD_NAME);
    }

    @Test
    public void getClassificationTableBuilders() {
        RegressionModel regressionModel = buildRegressionModel();
        Map<String, KiePMMLTableSourceCategory> classificationTableBuilders =
                KiePMMLClassificationTableFactory.getClassificationTableBuilders(buildRegressionCompilationDTO(regressionModel));
        Assertions.assertThat(classificationTableBuilders).isNotNull();
        // NOTE(review): the fixture has 2 regression tables; the extra entry is presumably the
        // classification table source itself — confirm against the factory.
        Assertions.assertThat(classificationTableBuilders).hasSize(3);
        classificationTableBuilders.values()
                .forEach(tableSourceCategory -> commonValidateKiePMMLRegressionTable(tableSourceCategory.getSource()));
        Map<String, String> sourcesMap = classificationTableBuilders.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSource()));
        CodegenTestUtils.commonValidateCompilation(sourcesMap);
    }

    @Test
    public void getClassificationTableBuilder() {
        RegressionModel regressionModel = buildRegressionModel();
        RegressionCompilationDTO compilationDTO = buildRegressionCompilationDTO(regressionModel);
        LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap =
                buildTableSourceCategories(regressionModel, compilationDTO.getPackageName() + ".");
        Assertions.assertThat(KiePMMLClassificationTableFactory.getClassificationTableBuilder(compilationDTO, regressionTablesMap))
                .isNotNull();
    }

    @Test
    public void getProbabilityMapUnsupportedFunction() {
        for (boolean isBinary : new boolean[]{false, true}) {
            KiePMMLClassificationTableFactory.UNSUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod ->
                    // the call MUST throw: a silently-returning function is a failure
                    Assertions.assertThatThrownBy(
                                    () -> KiePMMLClassificationTableFactory.getProbabilityMapFunction(normalizationMethod, isBinary))
                            .isInstanceOf(KiePMMLInternalException.class)
                            .hasMessage(String.format("Unsupported NormalizationMethod %s", normalizationMethod)));
        }
    }

    @Test
    public void getProbabilityMapSupportedFunction() {
        for (boolean isBinary : new boolean[]{false, true}) {
            KiePMMLClassificationTableFactory.SUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod ->
                    Assertions.assertThat(KiePMMLClassificationTableFactory.getProbabilityMapFunction(normalizationMethod, isBinary))
                            .isNotNull());
        }
    }

    @Test
    public void setStaticGetter() throws IOException {
        RegressionModel regressionModel = buildRegressionModel();
        RegressionCompilationDTO compilationDTO = buildRegressionCompilationDTO(regressionModel);
        LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap =
                buildTableSourceCategories(regressionModel, "defpack.");
        // work on a copy so the cached template method stays pristine for other tests
        MethodDeclaration staticGetter = STATIC_GETTER_METHOD.clone();
        KiePMMLClassificationTableFactory.setStaticGetter(compilationDTO, regressionTablesMap, staticGetter, "variableName");
        String expectedSource = FileUtils.getFileContent(TEST_02_SOURCE);
        Assertions.assertThat(JavaParserUtils.equalsNode(JavaParserUtils.parseMethod(expectedSource), staticGetter)).isTrue();
    }

    @Test
    public void getProbabilityMapFunctionExpressionWithSupportedMethods() throws IOException {
        // the expected-expression template is loaded once and formatted per normalization method
        String expressionTemplate = FileUtils.getFileContent(TEST_01_SOURCE);
        KiePMMLClassificationTableFactory.SUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod -> {
            String expectedSource = String.format(expressionTemplate, normalizationMethod.name());
            Assertions.assertThat(JavaParserUtils.equalsNode(JavaParserUtils.parseExpression(expectedSource),
                    KiePMMLClassificationTableFactory.getProbabilityMapFunctionExpression(normalizationMethod, false)))
                    .isTrue();
        });
    }

    @Test
    public void getProbabilityMapFunctionExpressionWithUnSupportedMethods() {
        KiePMMLClassificationTableFactory.UNSUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod ->
                Assertions.assertThatThrownBy(
                                () -> KiePMMLClassificationTableFactory.getProbabilityMapFunctionExpression(normalizationMethod, false))
                        .as("Expecting KiePMMLInternalException with normalizationMethod " + normalizationMethod)
                        .isInstanceOf(KiePMMLInternalException.class));
    }

    @Test
    public void getProbabilityMapFunctionSupportedExpression() throws IOException {
        RegressionModel.NormalizationMethod normalizationMethod = RegressionModel.NormalizationMethod.CAUCHIT;
        String expectedSource = String.format(FileUtils.getFileContent(TEST_01_SOURCE), normalizationMethod.name());
        Assertions.assertThat(JavaParserUtils.equalsNode(JavaParserUtils.parseExpression(expectedSource),
                KiePMMLClassificationTableFactory.getProbabilityMapFunctionSupportedExpression(normalizationMethod, true)))
                .isTrue();
    }

    /**
     * Builds the shared categorical regression model fixture.
     *
     * @return a model with two regression tables ("professional", "clerical"), CAUCHIT normalization,
     * three output fields and a mining schema containing only the {@code targetField} target
     */
    private RegressionModel buildRegressionModel() {
        RegressionModel regressionModel = new RegressionModel();
        regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
        regressionModel.addRegressionTables(getRegressionTable(3.5d, "professional"),
                                            getRegressionTable(27.4d, "clerical"));
        regressionModel.setModelName(KiePMMLModelUtils.getGeneratedClassName("RegressionModel"));
        Output output = new Output();
        output.addOutputFields(getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1"),
                               getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0"),
                               getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null));
        regressionModel.setOutput(output);
        MiningField miningField = new MiningField();
        miningField.setUsageType(MiningField.UsageType.TARGET);
        miningField.setName(FieldName.create(TARGET_FIELD_NAME));
        MiningSchema miningSchema = new MiningSchema();
        miningSchema.addMiningFields(miningField);
        regressionModel.setMiningSchema(miningSchema);
        return regressionModel;
    }

    /**
     * Wraps the given model into a {@link PMML} whose data dictionary declares the categorical
     * {@code targetField}, and builds the matching compilation DTO.
     *
     * @param regressionModel the model to compile (typically from {@link #buildRegressionModel()})
     * @return the DTO for package {@code PACKAGE_NAME}, using the model's regression tables and normalization method
     */
    private static RegressionCompilationDTO buildRegressionCompilationDTO(RegressionModel regressionModel) {
        DataField dataField = new DataField();
        dataField.setName(FieldName.create(TARGET_FIELD_NAME));
        dataField.setOpType(OpType.CATEGORICAL);
        DataDictionary dataDictionary = new DataDictionary();
        dataDictionary.addDataFields(dataField);
        PMML pmml = new PMML();
        pmml.setDataDictionary(dataDictionary);
        pmml.addModels(regressionModel);
        return RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock()),
                regressionModel.getRegressionTables(),
                regressionModel.getNormalizationMethod());
    }

    /**
     * Builds one empty-source {@link KiePMMLTableSourceCategory} per regression table, keyed by
     * {@code keyPrefix} followed by the upper-cased target category.
     *
     * @param regressionModel model whose regression tables are mapped
     * @param keyPrefix prefix prepended to every key (e.g. a package name plus ".")
     * @return the entries in regression-table order
     */
    private static LinkedHashMap<String, KiePMMLTableSourceCategory> buildTableSourceCategories(RegressionModel regressionModel,
                                                                                                 String keyPrefix) {
        LinkedHashMap<String, KiePMMLTableSourceCategory> toReturn = new LinkedHashMap<>();
        for (RegressionTable regressionTable : regressionModel.getRegressionTables()) {
            String category = regressionTable.getTargetCategory().toString();
            toReturn.put(keyPrefix + category.toUpperCase(), new KiePMMLTableSourceCategory("", category));
        }
        return toReturn;
    }

    /**
     * Creates an {@link OutputField}.
     *
     * @param name the output field name
     * @param resultFeature the result feature of the field
     * @param targetField the target field name, or {@code null} to leave it unset
     * @return the configured output field
     */
    private static OutputField getOutputField(String name, ResultFeature resultFeature, String targetField) {
        OutputField outputField = new OutputField();
        outputField.setName(FieldName.create(name));
        outputField.setResultFeature(resultFeature);
        if (targetField != null) {
            outputField.setTargetField(FieldName.create(targetField));
        }
        return outputField;
    }
}
