/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.kie.pmml.api.enums.MINING_FUNCTION;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.commons.model.HasClassLoader;
import org.kie.pmml.compiler.api.dto.CommonCompilationDTO;
import org.kie.pmml.compiler.api.dto.CompilationDTO;
import org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils;
import org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionModelFactory;
import org.kie.pmml.models.regression.model.AbstractKiePMMLTable;
import org.kie.pmml.models.regression.model.KiePMMLClassificationTable;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;
import org.kie.pmml.models.regression.model.KiePMMLRegressionTable;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;
import org.kie.test.util.filesystem.FileUtils;

public class KiePMMLRegressionModelFactoryTest {
    private static CompilationUnit COMPILATION_UNIT;
    private static ClassOrInterfaceDeclaration MODEL_TEMPLATE;
    private static final String TEST_01_SOURCE = "KiePMMLRegressionModelFactoryTest_01.txt";
    private static final String modelName = "firstModel";
    private static final double tableIntercept = 3.5;
    private static final Object tableTargetCategory;
    private static List<RegressionTable> regressionTables;
    private static List<DataField> dataFields;
    private static List<MiningField> miningFields;
    private static MiningField targetMiningField;
    private static DataDictionary dataDictionary;
    private static TransformationDictionary transformationDictionary;
    private static MiningSchema miningSchema;
    private static RegressionModel regressionModel;
    private static PMML pmml;

    @BeforeClass
    public static void setup() {
        Random random = new Random();
        HashSet fieldNames = new HashSet();
        regressionTables = IntStream.range(0, 3).mapToObj(i -> {
            ArrayList categoricalPredictors = new ArrayList();
            ArrayList numericPredictors = new ArrayList();
            ArrayList predictorTerms = new ArrayList();
            IntStream.range(0, 3).forEach(j -> {
                String catFieldName = "CatPred-" + j;
                String numFieldName = "NumPred-" + j;
                categoricalPredictors.add(PMMLModelTestUtils.getCategoricalPredictor((String)catFieldName, (double)random.nextDouble(), (double)random.nextDouble()));
                numericPredictors.add(PMMLModelTestUtils.getNumericPredictor((String)numFieldName, (int)random.nextInt(), (double)random.nextDouble()));
                predictorTerms.add(PMMLModelTestUtils.getPredictorTerm((String)("PredTerm-" + j), (double)random.nextDouble(), Arrays.asList(catFieldName, numFieldName)));
                fieldNames.add(catFieldName);
                fieldNames.add(numFieldName);
            });
            return PMMLModelTestUtils.getRegressionTable(categoricalPredictors, numericPredictors, predictorTerms, (double)(3.5 + random.nextDouble()), (Object)(tableTargetCategory + "-" + i));
        }).collect(Collectors.toList());
        dataFields = new ArrayList<DataField>();
        miningFields = new ArrayList<MiningField>();
        fieldNames.forEach(fieldName -> {
            dataFields.add(PMMLModelTestUtils.getDataField((String)fieldName, (OpType)OpType.CATEGORICAL, (DataType)DataType.STRING));
            miningFields.add(PMMLModelTestUtils.getMiningField((String)fieldName, (MiningField.UsageType)MiningField.UsageType.ACTIVE));
        });
        targetMiningField = miningFields.get(0);
        targetMiningField.setUsageType(MiningField.UsageType.TARGET);
        dataDictionary = PMMLModelTestUtils.getDataDictionary(dataFields);
        transformationDictionary = new TransformationDictionary();
        miningSchema = PMMLModelTestUtils.getMiningSchema(miningFields);
        regressionModel = PMMLModelTestUtils.getRegressionModel((String)modelName, (MiningFunction)MiningFunction.REGRESSION, (MiningSchema)miningSchema, regressionTables);
        COMPILATION_UNIT = JavaParserUtils.getFromFileName((String)"KiePMMLRegressionModelTemplate.tmpl");
        MODEL_TEMPLATE = (ClassOrInterfaceDeclaration)COMPILATION_UNIT.getClassByName("KiePMMLRegressionModelTemplate").get();
        pmml = new PMML();
        pmml.setDataDictionary(dataDictionary);
        pmml.setTransformationDictionary(transformationDictionary);
        pmml.addModels(new Model[]{regressionModel});
    }

    @Test
    public void getKiePMMLRegressionModelClasses() throws IOException, IllegalAccessException, InstantiationException {
        CommonCompilationDTO compilationDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields((String)"PACKAGE_NAME", (PMML)pmml, (Model)regressionModel, (HasClassLoader)new HasClassLoaderMock());
        KiePMMLRegressionModel retrieved = KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses((RegressionCompilationDTO)RegressionCompilationDTO.fromCompilationDTO((CompilationDTO)compilationDTO));
        Assert.assertNotNull((Object)retrieved);
        Assert.assertEquals((Object)regressionModel.getModelName(), (Object)retrieved.getName());
        Assert.assertEquals((Object)MINING_FUNCTION.byName((String)regressionModel.getMiningFunction().value()), (Object)retrieved.getMiningFunction());
        Assert.assertEquals((Object)miningFields.get(0).getName().getValue(), (Object)retrieved.getTargetField());
        AbstractKiePMMLTable regressionTable = retrieved.getRegressionTable();
        Assert.assertNotNull((Object)regressionTable);
        Assert.assertTrue((boolean)(regressionTable instanceof KiePMMLClassificationTable));
        this.evaluateCategoricalRegressionTable((KiePMMLClassificationTable)regressionTable);
    }

    @Test
    public void getKiePMMLRegressionModelSourcesMap() throws IOException {
        CommonCompilationDTO compilationDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields((String)"PACKAGE_NAME", (PMML)pmml, (Model)regressionModel, (HasClassLoader)new HasClassLoaderMock());
        Map retrieved = KiePMMLRegressionModelFactory.getKiePMMLRegressionModelSourcesMap((RegressionCompilationDTO)RegressionCompilationDTO.fromCompilationDTO((CompilationDTO)compilationDTO));
        Assert.assertNotNull((Object)retrieved);
        int expectedSize = regressionTables.size() + 2;
        Assert.assertEquals((long)expectedSize, (long)retrieved.size());
    }

    @Test
    public void getRegressionTablesMap() {
        CommonCompilationDTO compilationDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields((String)"PACKAGE_NAME", (PMML)pmml, (Model)regressionModel, (HasClassLoader)new HasClassLoaderMock());
        Map retrieved = KiePMMLRegressionModelFactory.getRegressionTablesMap((RegressionCompilationDTO)RegressionCompilationDTO.fromCompilationDTO((CompilationDTO)compilationDTO));
        int expectedSize = regressionTables.size() + 1;
        Assert.assertEquals((long)expectedSize, (long)retrieved.size());
        Collection values = retrieved.values();
        regressionTables.forEach(regressionTable -> Assert.assertTrue((boolean)values.stream().anyMatch(kiePMMLTableSourceCategory -> kiePMMLTableSourceCategory.getCategory().equals(regressionTable.getTargetCategory()))));
    }

    @Test
    public void setStaticGetter() throws IOException {
        String nestedTable = "NestedTable";
        MINING_FUNCTION miningFunction = MINING_FUNCTION.byName((String)regressionModel.getMiningFunction().value());
        ClassOrInterfaceDeclaration modelTemplate = MODEL_TEMPLATE.clone();
        CommonCompilationDTO source = CommonCompilationDTO.fromGeneratedPackageNameAndFields((String)"PACKAGE_NAME", (PMML)pmml, (Model)regressionModel, (HasClassLoader)new HasClassLoaderMock());
        RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod((CompilationDTO)source, new ArrayList(), (RegressionModel.NormalizationMethod)regressionModel.getNormalizationMethod());
        KiePMMLRegressionModelFactory.setStaticGetter((CompilationDTO)compilationDTO, (ClassOrInterfaceDeclaration)modelTemplate, (String)nestedTable);
        HashMap<Integer, NameExpr> superInvocationExpressionsMap = new HashMap<Integer, NameExpr>();
        superInvocationExpressionsMap.put(0, new NameExpr(String.format("\"%s\"", regressionModel.getModelName())));
        HashMap<String, Object> assignExpressionMap = new HashMap<String, Object>();
        assignExpressionMap.put("targetField", new StringLiteralExpr(targetMiningField.getName().getValue()));
        assignExpressionMap.put("miningFunction", new NameExpr(miningFunction.getClass().getName() + "." + miningFunction.name()));
        assignExpressionMap.put("pmmlMODEL", new NameExpr(PMML_MODEL.class.getName() + "." + PMML_MODEL.REGRESSION_MODEL.name()));
        MethodCallExpr methodCallExpr = new MethodCallExpr();
        methodCallExpr.setScope((Expression)new NameExpr(nestedTable));
        methodCallExpr.setName("getKiePMMLTable");
        assignExpressionMap.put("regressionTable", methodCallExpr);
        MethodDeclaration retrieved = (MethodDeclaration)modelTemplate.getMethodsByName("getModel").get(0);
        String text = FileUtils.getFileContent((String)TEST_01_SOURCE);
        MethodDeclaration expected = JavaParserUtils.parseMethod((String)text);
        Assert.assertTrue((boolean)JavaParserUtils.equalsNode((Node)expected, (Node)retrieved));
    }

    private void evaluateCategoricalRegressionTable(KiePMMLClassificationTable regressionTable) {
        Assert.assertEquals((Object)REGRESSION_NORMALIZATION_METHOD.byName((String)regressionModel.getNormalizationMethod().value()), (Object)regressionTable.getRegressionNormalizationMethod());
        Assert.assertEquals((Object)OP_TYPE.CATEGORICAL, (Object)regressionTable.getOpType());
        Map categoryTableMap = regressionTable.getCategoryTableMap();
        for (RegressionTable originalRegressionTable : regressionTables) {
            Assert.assertTrue((boolean)categoryTableMap.containsKey(originalRegressionTable.getTargetCategory().toString()));
            this.evaluateRegressionTable((KiePMMLRegressionTable)categoryTableMap.get(originalRegressionTable.getTargetCategory().toString()), originalRegressionTable);
        }
    }

    private void evaluateRegressionTable(KiePMMLRegressionTable regressionTable, RegressionTable originalRegressionTable) {
        Assert.assertEquals((Object)originalRegressionTable.getIntercept(), (Object)regressionTable.getIntercept());
        Map numericFunctionMap = regressionTable.getNumericFunctionMap();
        for (Object numericPredictor : originalRegressionTable.getNumericPredictors()) {
            Assert.assertTrue((boolean)numericFunctionMap.containsKey(numericPredictor.getName().getValue()));
        }
        Map categoricalFunctionMap = regressionTable.getCategoricalFunctionMap();
        for (CategoricalPredictor categoricalPredictor : originalRegressionTable.getCategoricalPredictors()) {
            Assert.assertTrue((boolean)categoricalFunctionMap.containsKey(categoricalPredictor.getName().getValue()));
        }
        Map predictorTermsFunctionMap = regressionTable.getPredictorTermsFunctionMap();
        for (PredictorTerm predictorTerm : originalRegressionTable.getPredictorTerms()) {
            Assert.assertTrue((boolean)predictorTermsFunctionMap.containsKey(predictorTerm.getName().getValue()));
        }
    }

    static {
        tableTargetCategory = "professional";
    }
}

