package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.expr.DoubleLiteralExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import com.github.javaparser.ast.expr.SimpleName;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.stmt.ReturnStmt;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.kie.pmml.api.enums.RESULT_FEATURE;
import org.kie.pmml.commons.model.KiePMMLOutputField;
import org.kie.pmml.commons.utils.KiePMMLModelUtils;
import org.kie.pmml.compiler.api.dto.CommonCompilationDTO;
import org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils;
import org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock;
import org.kie.pmml.compiler.commons.testutils.CodegenTestUtils;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.kie.test.util.filesystem.FileUtils;

/**
 * Unit tests for {@link KiePMMLRegressionTableRegressionFactory}.
 * <p>
 * Covers generation of regression tables, the table constructor, result-updater expressions, numeric and
 * categorical predictor expressions, and predictor-term methods. Expected code snippets are read from the
 * {@code KiePMMLRegressionTableRegressionFactoryTest_0x.txt} resources via {@link FileUtils#getFileContent},
 * filled in with {@link String#format}, and compared node-by-node with the generated JavaParser AST through
 * {@link JavaParserUtils#equalsNode}.
 */
public class KiePMMLRegressionTableRegressionFactoryTest extends AbstractKiePMMLRegressionTableRegressionFactoryTest {

    private static final String PACKAGE_NAME = "packagename";
    private static final String TEST_01_SOURCE = "KiePMMLRegressionTableRegressionFactoryTest_01.txt";
    private static final String TEST_02_SOURCE = "KiePMMLRegressionTableRegressionFactoryTest_02.txt";
    private static final String TEST_03_SOURCE = "KiePMMLRegressionTableRegressionFactoryTest_03.txt";
    private static final String TEST_04_SOURCE = "KiePMMLRegressionTableRegressionFactoryTest_04.txt";
    private static final String TEST_05_SOURCE = "KiePMMLRegressionTableRegressionFactoryTest_05.txt";
    private static final String TEST_06_SOURCE = "KiePMMLRegressionTableRegressionFactoryTest_06.txt";

    // Instance (not static) state, re-parsed before every test: setConstructor() is handed the template's
    // default constructor and the test then inspects that same node, so a shared parsed copy could carry
    // changes from one test into another.
    private CompilationUnit compilationUnit;
    private ClassOrInterfaceDeclaration modelTemplate;

    /**
     * Parses a fresh copy of the regression-table template and looks up its class declaration.
     */
    @Before
    public void setup() {
        compilationUnit = JavaParserUtils.getFromFileName("KiePMMLRegressionTableRegressionTemplate.tmpl");
        modelTemplate = compilationUnit.getClassByName("KiePMMLRegressionTableRegressionTemplate").get();
    }

    /**
     * Builds a minimal PMML with a single CAUCHIT regression model and a categorical target field, then
     * validates every generated regression-table source.
     */
    @Test
    public void getRegressionTableTest() {
        this.regressionTable = getRegressionTable(3.5d, "professional");
        // Must be typed as RegressionModel: the normalization-method and regression-table accessors, as well
        // as the generic compilation DTO factories, are not available on the Model base type.
        final RegressionModel regressionModel = new RegressionModel();
        regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
        regressionModel.addRegressionTables(this.regressionTable);
        regressionModel.setModelName(KiePMMLModelUtils.getGeneratedClassName("RegressionModel"));

        final DataField dataField = new DataField();
        dataField.setName(FieldName.create("targetField"));
        dataField.setOpType(OpType.CATEGORICAL);
        final DataDictionary dataDictionary = new DataDictionary();
        dataDictionary.addDataFields(dataField);

        final MiningField miningField = new MiningField();
        miningField.setUsageType(MiningField.UsageType.TARGET);
        miningField.setName(dataField.getName());
        final MiningSchema miningSchema = new MiningSchema();
        miningSchema.addMiningFields(miningField);
        regressionModel.setMiningSchema(miningSchema);

        final PMML pmml = new PMML();
        pmml.setDataDictionary(dataDictionary);
        pmml.addModels(regressionModel);

        // Pass the model's own tables: with an empty list the per-table validation below would most likely
        // iterate over nothing and the test could never fail.
        final RegressionCompilationDTO compilationDTO =
                RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                        CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel,
                                                                               new HasClassLoaderMock()),
                        regressionModel.getRegressionTables(),
                        regressionModel.getNormalizationMethod());
        final Map<String, KiePMMLTableSourceCategory> retrieved =
                KiePMMLRegressionTableRegressionFactory.getRegressionTables(compilationDTO);
        Assert.assertNotNull(retrieved);
        Assert.assertEquals(regressionModel.getRegressionTables().size(), retrieved.size());
        retrieved.values().forEach(tableSource -> commonValidateKiePMMLRegressionTable(tableSource.getSource()));
    }

    /**
     * Verifies the field assignments that setConstructor() writes into the template's default constructor.
     */
    @Test
    public void setConstructor() {
        final double intercept = 3.5d;
        final String targetField = "targetField";
        this.regressionTable = getRegressionTable(intercept, "professional");
        final ConstructorDeclaration constructorDeclaration = modelTemplate.getDefaultConstructor().get();
        final SimpleName tableName = new SimpleName("TableName");
        KiePMMLRegressionTableRegressionFactory.setConstructor(this.regressionTable, constructorDeclaration, tableName,
                                                               targetField, this.regressionTable.getTargetCategory(),
                                                               RegressionModel.NormalizationMethod.CAUCHIT);
        final Map<String, Expression> assignExpressionsMap = new HashMap<>();
        assignExpressionsMap.put("targetField", new StringLiteralExpr(targetField));
        assignExpressionsMap.put("intercept", new DoubleLiteralExpr(String.valueOf(intercept)));
        assignExpressionsMap.put("targetCategory",
                                 CommonCodegenUtils.getExpressionForObject(this.regressionTable.getTargetCategory()));
        assignExpressionsMap.put("resultUpdater",
                                 KiePMMLRegressionTableRegressionFactory.createResultUpdaterExpression(RegressionModel.NormalizationMethod.CAUCHIT));
        // The first map (super-invocation expressions) is intentionally empty.
        Assert.assertTrue(CodegenTestUtils.commonEvaluateConstructor(constructorDeclaration, tableName.asString(),
                                                                     new HashMap<>(), assignExpressionsMap));
    }

    /**
     * Every supported normalization method must produce the expression described by TEST_03_SOURCE.
     */
    @Test
    public void createResultUpdaterExpressionWithSupportedMethods() throws IOException {
        // The expected template is the same for every method: read it once instead of per iteration, and let
        // an IOException propagate with its full stack trace rather than being reduced to a message.
        final String expressionTemplate = FileUtils.getFileContent(TEST_03_SOURCE);
        for (RegressionModel.NormalizationMethod normalizationMethod :
                KiePMMLRegressionTableRegressionFactory.SUPPORTED_NORMALIZATION_METHODS) {
            final Expression expected =
                    JavaParserUtils.parseExpression(String.format(expressionTemplate, normalizationMethod.name()));
            Assert.assertTrue(JavaParserUtils.equalsNode(expected,
                                                         KiePMMLRegressionTableRegressionFactory.createResultUpdaterExpression(normalizationMethod)));
        }
    }

    /**
     * Every unsupported normalization method must yield a {@code null} literal expression.
     */
    @Test
    public void createResultUpdaterExpressionWithUnSupportedMethods() {
        for (RegressionModel.NormalizationMethod normalizationMethod :
                KiePMMLRegressionTableRegressionFactory.UNSUPPORTED_NORMALIZATION_METHODS) {
            Assert.assertTrue(KiePMMLRegressionTableRegressionFactory.createResultUpdaterExpression(normalizationMethod)
                                      instanceof NullLiteralExpr);
        }
    }

    @Test
    public void createResultUpdaterSupportedExpression() throws IOException {
        final RegressionModel.NormalizationMethod normalizationMethod = RegressionModel.NormalizationMethod.CAUCHIT;
        final Expression expected = JavaParserUtils.parseExpression(
                String.format(FileUtils.getFileContent(TEST_03_SOURCE), normalizationMethod.name()));
        Assert.assertTrue(JavaParserUtils.equalsNode(expected,
                                                     KiePMMLRegressionTableRegressionFactory.createResultUpdaterSupportedExpression(normalizationMethod)));
    }

    /**
     * One expression must be generated per numeric predictor.
     */
    @Test
    public void createNumericPredictorsExpressions() {
        final List<NumericPredictor> numericPredictors = IntStream.range(0, 3)
                .mapToObj(i -> PMMLModelTestUtils.getNumericPredictor("predictorName-" + i, i, 1.23d * i))
                .collect(Collectors.toList());
        Assert.assertEquals(numericPredictors.size(),
                            KiePMMLRegressionTableRegressionFactory.createNumericPredictorsExpressions(numericPredictors).size());
    }

    @Test
    public void createNumericPredictorExpressionWithExponent() throws IOException {
        final double coefficient = 1.23d;
        final int exponent = 2;
        final Expression expected = JavaParserUtils.parseExpression(
                String.format(FileUtils.getFileContent(TEST_01_SOURCE), coefficient, exponent));
        Assert.assertTrue(JavaParserUtils.equalsNode(expected,
                                                     KiePMMLRegressionTableRegressionFactory.createNumericPredictorExpression(
                                                             PMMLModelTestUtils.getNumericPredictor("predictorName", exponent, coefficient))));
    }

    @Test
    public void createNumericPredictorExpressionWithoutExponent() throws IOException {
        final double coefficient = 1.23d;
        final Expression expected = JavaParserUtils.parseExpression(
                String.format(FileUtils.getFileContent(TEST_02_SOURCE), coefficient));
        Assert.assertTrue(JavaParserUtils.equalsNode(expected,
                                                     KiePMMLRegressionTableRegressionFactory.createNumericPredictorExpression(
                                                             PMMLModelTestUtils.getNumericPredictor("predictorName", 1, coefficient))));
    }

    /**
     * Three predictor names with three values each must produce three grouped entries, and the block must
     * contain the map declaration and {@code put} calls for every group.
     */
    @Test
    public void createCategoricalPredictorsExpressions() {
        // The outer index picks the predictor name and the inner one the value/coefficient. The original
        // reused "i" for both lambdas, which does not compile.
        final List<CategoricalPredictor> categoricalPredictors = IntStream.range(0, 3)
                .boxed()
                .flatMap(nameIndex -> IntStream.range(0, 3)
                        .mapToObj(valueIndex -> PMMLModelTestUtils.getCategoricalPredictor("predictorName-" + nameIndex,
                                                                                          valueIndex,
                                                                                          1.23d * valueIndex)))
                .collect(Collectors.toList());
        final BlockStmt body = new BlockStmt();
        Assert.assertEquals(3, KiePMMLRegressionTableRegressionFactory.createCategoricalPredictorsExpressions(categoricalPredictors, body).size());
        final Map<String, List<CategoricalPredictor>> groupedByField = categoricalPredictors.stream()
                .collect(Collectors.groupingBy(predictor -> predictor.getField().getValue()));
        groupedByField.values().forEach(group -> commonEvaluateCategoryPredictors(body, group));
    }

    @Test
    public void populateWithGroupedCategoricalPredictorMap() throws IOException {
        final String mapName = "categoricalPredictorMapName";
        final List<CategoricalPredictor> categoricalPredictors = IntStream.range(0, 3)
                .mapToObj(i -> PMMLModelTestUtils.getCategoricalPredictor("predictorName-" + i, i, 1.23d * i))
                .collect(Collectors.toList());
        final BlockStmt body = new BlockStmt();
        KiePMMLRegressionTableRegressionFactory.populateWithGroupedCategoricalPredictorMap(categoricalPredictors, body, mapName);
        final String expected = String.format(FileUtils.getFileContent(TEST_04_SOURCE), mapName,
                                              categoricalPredictors.get(0).getValue(), categoricalPredictors.get(0).getCoefficient(),
                                              categoricalPredictors.get(1).getValue(), categoricalPredictors.get(1).getCoefficient(),
                                              categoricalPredictors.get(2).getValue(), categoricalPredictors.get(2).getCoefficient());
        Assert.assertTrue(JavaParserUtils.equalsNode(JavaParserUtils.parseBlock(expected), body));
    }

    @Test
    public void createCategoricalPredictorExpression() throws IOException {
        final String mapName = "categoricalPredictorMapName";
        final Expression expected =
                JavaParserUtils.parseExpression(String.format(FileUtils.getFileContent(TEST_05_SOURCE), mapName));
        Assert.assertTrue(JavaParserUtils.equalsNode(expected,
                                                     KiePMMLRegressionTableRegressionFactory.createCategoricalPredictorExpression(mapName)));
    }

    /**
     * Every predictor term must be mapped, by name, to a generated method named
     * {@code evaluatePredictorTerm<n>}, where n is the term's 1-based position.
     */
    @Test
    public void addPredictorTerms() {
        final List<PredictorTerm> predictorTerms = IntStream.range(0, 3)
                .mapToObj(i -> PMMLModelTestUtils.getPredictorTerm("predictorName-" + i, 1.23d * i,
                                                                   Collections.singletonList("fieldRef-" + i)))
                .collect(Collectors.toList());
        final Map<String, MethodDeclaration> retrieved =
                KiePMMLRegressionTableRegressionFactory.addPredictorTerms(predictorTerms, new ClassOrInterfaceDeclaration());
        Assert.assertEquals(predictorTerms.size(), retrieved.size());
        for (int i = 0; i < predictorTerms.size(); i++) {
            final String termName = predictorTerms.get(i).getName().getValue();
            Assert.assertTrue(retrieved.containsKey(termName));
            Assert.assertEquals(String.format("evaluatePredictorTerm%d", i + 1), retrieved.get(termName).getNameAsString());
        }
    }

    @Test
    public void addPredictorTerm() throws IOException {
        final int methodArity = 3;
        final double coefficient = 23.12d;
        final MethodDeclaration retrieved = KiePMMLRegressionTableRegressionFactory.addPredictorTerm(
                PMMLModelTestUtils.getPredictorTerm("predictorName", coefficient, Collections.singletonList("fieldRef")),
                new ClassOrInterfaceDeclaration(), methodArity);
        Assert.assertEquals(String.format("evaluatePredictorTerm%d", methodArity), retrieved.getNameAsString());
        final BlockStmt expected = JavaParserUtils.parseBlock(
                String.format(FileUtils.getFileContent(TEST_06_SOURCE), "fieldRef", coefficient));
        Assert.assertTrue(JavaParserUtils.equalsNode(expected, retrieved.getBody().get()));
    }

    @Test
    public void populateOutputFieldsMap() {
        // NOTE(review): this test only builds output-field fixtures. It never invokes production code and
        // asserts nothing, so it cannot fail. Wire it to the intended factory method, or remove it.
        // TODO confirm the target method.
        final List<KiePMMLOutputField> outputFields = new ArrayList<>();
        outputFields.add(getOutputField("KOF-TARGET", RESULT_FEATURE.PREDICTED_VALUE, "TARGET"));
        outputFields.addAll(IntStream.range(0, 2)
                                    .mapToObj(i -> getOutputField("KOF-PROB-" + i, RESULT_FEATURE.PROBABILITY, "PROB-" + i))
                                    .collect(Collectors.toList()));
    }

    /**
     * Asserts that {@code blockStmt} declares the map for the given group of predictors (which must all share
     * the same field) and contains one {@code put(value, coefficient)} call per predictor.
     *
     * @param blockStmt the generated block to inspect
     * @param list      non-empty predictors sharing one field; the first element determines the map name
     */
    private void commonEvaluateCategoryPredictors(BlockStmt blockStmt, List<CategoricalPredictor> list) {
        final String declaration = String.format("java.util.Map<String, Double> %s = new java.util.HashMap<String, Double>();",
                                                 KiePMMLModelUtils.getSanitizedVariableName(list.get(0).getField() + "Map"));
        Assert.assertTrue(blockStmt.getStatements().stream()
                                  .anyMatch(statement -> statement instanceof ExpressionStmt
                                          && ((ExpressionStmt) statement).getExpression() instanceof VariableDeclarationExpr
                                          && statement.toString().equals(declaration)));
        for (CategoricalPredictor categoricalPredictor : list) {
            final String putCall = String.format("%s.put(\"%s\", %s);",
                                                 KiePMMLModelUtils.getSanitizedVariableName(categoricalPredictor.getField() + "Map"),
                                                 categoricalPredictor.getValue(),
                                                 categoricalPredictor.getCoefficient());
            Assert.assertTrue(blockStmt.getStatements().stream()
                                      .anyMatch(statement -> statement instanceof ExpressionStmt
                                              && ((ExpressionStmt) statement).getExpression() instanceof MethodCallExpr
                                              && statement.toString().equals(putCall)));
        }
    }

    /**
     * Asserts that the first {@code getTargetCategory} method of the class consists of a single
     * {@code return} statement returning {@code expression}.
     */
    private void commonEvaluateGetTargetCategory(ClassOrInterfaceDeclaration classOrInterfaceDeclaration, Expression expression) {
        final NodeList<?> statements = classOrInterfaceDeclaration.getMethodsByName("getTargetCategory").get(0)
                .getBody().get().getStatements();
        Assert.assertEquals(1, statements.size());
        Assert.assertTrue(statements.get(0) instanceof ReturnStmt);
        // Statement has no getExpression(); the cast to ReturnStmt is required.
        Assert.assertEquals(expression, ((ReturnStmt) statements.get(0)).getExpression().get());
    }

    /**
     * Builds an output field with no extra nested fields, the given result feature, and the given target field.
     */
    private KiePMMLOutputField getOutputField(String str, RESULT_FEATURE result_feature, String str2) {
        return KiePMMLOutputField.builder(str, Collections.emptyList())
                .withResultFeature(result_feature)
                .withTargetField(str2)
                .build();
    }
}
