/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.drools.util.FileUtils;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.kie.pmml.api.compilation.PMMLCompilationContext;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.commons.utils.KiePMMLModelUtils;
import org.kie.pmml.compiler.api.dto.CommonCompilationDTO;
import org.kie.pmml.compiler.api.dto.CompilationDTO;
import org.kie.pmml.compiler.commons.mocks.PMMLCompilationContextMock;
import org.kie.pmml.compiler.commons.testutils.CodegenTestUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.compiler.factories.AbstractKiePMMLRegressionTableRegressionFactoryTest;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLClassificationTableFactory;
import org.kie.pmml.models.regression.model.KiePMMLClassificationTable;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;

/**
 * Unit tests for {@link KiePMMLClassificationTableFactory}.
 * <p>
 * Most scenarios share one fixture, built by {@link #getRegressionModel()} and
 * {@link #getCompilationDTO(RegressionModel)}. The fixture is a CAUCHIT-normalized
 * {@link RegressionModel} with two regression tables ("professional" and "clerical"), three
 * output fields and a single target mining field. The model is wrapped in a {@link PMML} whose
 * data dictionary declares that target as a categorical field.
 */
public class KiePMMLClassificationTableFactoryTest
extends AbstractKiePMMLRegressionTableRegressionFactoryTest {
    private static final String TEST_01_SOURCE = "KiePMMLClassificationTableFactoryTest_01.txt";
    private static final String TEST_02_SOURCE = "KiePMMLClassificationTableFactoryTest_02.txt";
    private static final String TARGET_FIELD = "targetField";
    private static final String UNSUPPORTED_METHOD_MESSAGE = "Unsupported NormalizationMethod %s";
    private static CompilationUnit COMPILATION_UNIT;
    private static ClassOrInterfaceDeclaration MODEL_TEMPLATE;
    private static MethodDeclaration STATIC_GETTER_METHOD;

    /**
     * Parses the classification-table template once for the whole class.
     * Caches its {@code getKiePMMLTable} method, which {@link #setStaticGetter()} clones.
     */
    @BeforeAll
    public static void setup() {
        COMPILATION_UNIT = JavaParserUtils.getFromFileName("KiePMMLClassificationTableTemplate.tmpl");
        MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName("KiePMMLClassificationTableTemplate").get();
        STATIC_GETTER_METHOD = MODEL_TEMPLATE.getMethodsByName("getKiePMMLTable").get(0);
    }

    @Test
    void getClassificationTable() {
        RegressionModel regressionModel = getRegressionModel();
        RegressionCompilationDTO compilationDTO = getCompilationDTO(regressionModel);
        KiePMMLClassificationTable retrieved = KiePMMLClassificationTableFactory.getClassificationTable(compilationDTO);
        Assertions.assertThat(retrieved).isNotNull();
        Assertions.assertThat(retrieved.getCategoryTableMap()).hasSameSizeAs(regressionModel.getRegressionTables());
        // Every regression table must be reachable through its target category.
        regressionModel.getRegressionTables().forEach(regressionTable ->
                Assertions.assertThat(retrieved.getCategoryTableMap())
                        .containsKey(regressionTable.getTargetCategory().toString()));
        Assertions.assertThat(retrieved.getRegressionNormalizationMethod().getName())
                .isEqualTo(regressionModel.getNormalizationMethod().value());
        Assertions.assertThat(retrieved.getOpType()).isEqualTo(OP_TYPE.CATEGORICAL);
        // The fixture has two categories, so the table is expected to be binary.
        boolean isBinary = regressionModel.getRegressionTables().size() == 2;
        Assertions.assertThat(retrieved.isBinary()).isEqualTo(isBinary);
        Assertions.assertThat(retrieved.getTargetField()).isEqualTo(TARGET_FIELD);
    }

    @Test
    void getClassificationTableBuilders() {
        RegressionModel regressionModel = getRegressionModel();
        RegressionCompilationDTO compilationDTO = getCompilationDTO(regressionModel);
        Map<String, KiePMMLTableSourceCategory> retrieved =
                KiePMMLClassificationTableFactory.getClassificationTableBuilders(compilationDTO);
        Assertions.assertThat(retrieved).isNotNull().hasSize(3);
        retrieved.values().forEach(kiePMMLTableSourceCategory ->
                commonValidateKiePMMLRegressionTable(kiePMMLTableSourceCategory.getSource()));
        // Compile every generated source together to check they are mutually consistent.
        Map<String, String> sources = retrieved.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSource()));
        CodegenTestUtils.commonValidateCompilation(sources);
    }

    @Test
    void getClassificationTableBuilder() {
        RegressionModel regressionModel = getRegressionModel();
        RegressionCompilationDTO compilationDTO = getCompilationDTO(regressionModel);
        LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap =
                getRegressionTablesMap(regressionModel, compilationDTO.getPackageName());
        Map.Entry<?, ?> retrieved =
                KiePMMLClassificationTableFactory.getClassificationTableBuilder(compilationDTO, regressionTablesMap);
        Assertions.assertThat(retrieved).isNotNull();
    }

    @Test
    void getProbabilityMapUnsupportedFunction() {
        // assertThatThrownBy fails the test if no exception is raised at all.
        // The previous try/catch form passed silently in that case.
        KiePMMLClassificationTableFactory.UNSUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod -> {
            String expected = String.format(UNSUPPORTED_METHOD_MESSAGE, normalizationMethod);
            for (boolean isBinary : new boolean[]{false, true}) {
                Assertions.assertThatThrownBy(() ->
                                KiePMMLClassificationTableFactory.getProbabilityMapFunction(normalizationMethod, isBinary))
                        .isInstanceOf(KiePMMLInternalException.class)
                        .hasMessage(expected);
            }
        });
    }

    @Test
    void getProbabilityMapSupportedFunction() {
        KiePMMLClassificationTableFactory.SUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod -> {
            Assertions.assertThat(KiePMMLClassificationTableFactory.getProbabilityMapFunction(normalizationMethod, false)).isNotNull();
            Assertions.assertThat(KiePMMLClassificationTableFactory.getProbabilityMapFunction(normalizationMethod, true)).isNotNull();
        });
    }

    @Test
    void setStaticGetter() throws IOException {
        String variableName = "variableName";
        RegressionModel regressionModel = getRegressionModel();
        RegressionCompilationDTO compilationDTO = getCompilationDTO(regressionModel);
        // The "defpack" prefix must match the keys expected in TEST_02_SOURCE.
        LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap =
                getRegressionTablesMap(regressionModel, "defpack");
        // Work on a clone so the shared template method is never mutated between tests.
        MethodDeclaration staticGetterMethod = STATIC_GETTER_METHOD.clone();
        KiePMMLClassificationTableFactory.setStaticGetter(compilationDTO, regressionTablesMap, staticGetterMethod, variableName);
        String text = FileUtils.getFileContent(TEST_02_SOURCE);
        MethodDeclaration expected = JavaParserUtils.parseMethod(text);
        Assertions.assertThat(JavaParserUtils.equalsNode(expected, staticGetterMethod)).isTrue();
    }

    @Test
    void getProbabilityMapFunctionExpressionWithSupportedMethods() throws IOException {
        // The expected-source template does not depend on the method, so read it only once.
        String text = FileUtils.getFileContent(TEST_01_SOURCE);
        KiePMMLClassificationTableFactory.SUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod -> {
            Expression retrieved = KiePMMLClassificationTableFactory.getProbabilityMapFunctionExpression(normalizationMethod, false);
            Expression expected = JavaParserUtils.parseExpression(String.format(text, normalizationMethod.name()));
            Assertions.assertThat(JavaParserUtils.equalsNode(expected, retrieved)).isTrue();
        });
    }

    @Test
    void getProbabilityMapFunctionExpressionWithUnSupportedMethods() {
        KiePMMLClassificationTableFactory.UNSUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod ->
                Assertions.assertThatThrownBy(() ->
                                KiePMMLClassificationTableFactory.getProbabilityMapFunctionExpression(normalizationMethod, false))
                        .as("Expecting KiePMMLInternalException with normalizationMethod " + normalizationMethod)
                        .isInstanceOf(KiePMMLInternalException.class));
    }

    @Test
    void getProbabilityMapFunctionSupportedExpression() throws IOException {
        MethodReferenceExpr retrieved = KiePMMLClassificationTableFactory
                .getProbabilityMapFunctionSupportedExpression(RegressionModel.NormalizationMethod.CAUCHIT, true);
        String text = FileUtils.getFileContent(TEST_01_SOURCE);
        Expression expected = JavaParserUtils.parseExpression(String.format(text, RegressionModel.NormalizationMethod.CAUCHIT.name()));
        Assertions.assertThat(JavaParserUtils.equalsNode(expected, retrieved)).isTrue();
    }

    /**
     * Builds the shared model fixture. It is CAUCHIT-normalized, has the regression tables
     * {@code getRegressionTable(3.5, "professional")} and {@code getRegressionTable(27.4, "clerical")},
     * three output fields, and a single TARGET mining field named {@link #TARGET_FIELD}.
     *
     * @return a fresh, not-yet-wrapped {@link RegressionModel}
     */
    private RegressionModel getRegressionModel() {
        RegressionModel regressionModel = new RegressionModel();
        regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
        regressionModel.addRegressionTables(getRegressionTable(3.5, "professional"),
                                            getRegressionTable(27.4, "clerical"));
        regressionModel.setModelName(KiePMMLModelUtils.getGeneratedClassName("RegressionModel"));
        Output output = new Output();
        output.addOutputFields(getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1"),
                               getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0"),
                               getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null));
        regressionModel.setOutput(output);
        MiningField targetMiningField = new MiningField();
        targetMiningField.setUsageType(MiningField.UsageType.TARGET);
        targetMiningField.setName(FieldName.create(TARGET_FIELD));
        MiningSchema miningSchema = new MiningSchema();
        miningSchema.addMiningFields(targetMiningField);
        regressionModel.setMiningSchema(miningSchema);
        return regressionModel;
    }

    /**
     * Wraps {@code regressionModel} in a {@link PMML} whose data dictionary declares
     * {@link #TARGET_FIELD} as categorical, then builds the compilation DTO for it.
     *
     * @param regressionModel the model under test, typically from {@link #getRegressionModel()}
     * @return the DTO carrying the model's regression tables and normalization method
     */
    private static RegressionCompilationDTO getCompilationDTO(RegressionModel regressionModel) {
        DataField dataField = new DataField();
        dataField.setName(FieldName.create(TARGET_FIELD));
        dataField.setOpType(OpType.CATEGORICAL);
        DataDictionary dataDictionary = new DataDictionary();
        dataDictionary.addDataFields(dataField);
        PMML pmml = new PMML();
        pmml.setDataDictionary(dataDictionary);
        pmml.addModels(regressionModel);
        CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(
                "PACKAGE_NAME", pmml, regressionModel, new PMMLCompilationContextMock(), "FILENAME");
        return RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());
    }

    /**
     * Builds an insertion-ordered map with one entry per regression table.
     * Each key is {@code keyPrefix + "." + TARGET_CATEGORY_UPPERCASE}. Each value is a
     * {@link KiePMMLTableSourceCategory} with an empty source and the table's target category.
     *
     * @param regressionModel the model whose regression tables are mapped
     * @param keyPrefix       the prefix prepended (with a dot) to every key
     * @return the populated map, in regression-table order
     */
    private static LinkedHashMap<String, KiePMMLTableSourceCategory> getRegressionTablesMap(RegressionModel regressionModel,
                                                                                           String keyPrefix) {
        LinkedHashMap<String, KiePMMLTableSourceCategory> toReturn = new LinkedHashMap<>();
        for (RegressionTable regressionTable : regressionModel.getRegressionTables()) {
            String targetCategory = regressionTable.getTargetCategory().toString();
            toReturn.put(keyPrefix + "." + targetCategory.toUpperCase(),
                         new KiePMMLTableSourceCategory("", targetCategory));
        }
        return toReturn;
    }

    /**
     * Creates an {@link OutputField}.
     *
     * @param name          the output field name
     * @param resultFeature the result feature to set
     * @param targetField   the target field name, or {@code null} to leave it unset
     * @return the configured output field
     */
    private OutputField getOutputField(String name, ResultFeature resultFeature, String targetField) {
        OutputField toReturn = new OutputField();
        toReturn.setName(FieldName.create(name));
        toReturn.setResultFeature(resultFeature);
        if (targetField != null) {
            toReturn.setTargetField(FieldName.create(targetField));
        }
        return toReturn;
    }
}

