package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.expr.DoubleLiteralExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.SimpleName;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.stmt.ReturnStmt;
import com.github.javaparser.ast.stmt.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.kie.pmml.api.enums.RESULT_FEATURE;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.commons.model.KiePMMLOutputField;
import org.kie.pmml.compiler.commons.testutils.CodegenTestUtils;
import org.kie.pmml.compiler.commons.testutils.PMMLModelTestUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;

/* loaded from: input_file:org/kie/pmml/models/regression/compiler/factories/KiePMMLRegressionTableRegressionFactoryTest.class */
/**
 * Unit tests for {@code KiePMMLRegressionTableRegressionFactory}.
 *
 * <p>Most tests operate on a freshly parsed copy of the
 * {@code KiePMMLRegressionTableRegressionTemplate} class (re-read before every test, because
 * the factory methods under test mutate it) and compare the generated JavaParser AST or its
 * pretty-printed source against expected values.
 */
public class KiePMMLRegressionTableRegressionFactoryTest extends AbstractKiePMMLRegressionTableRegressionFactoryTest {

    private static final List<RegressionModel.NormalizationMethod> SUPPORTED_NORMALIZATION_METHODS =
            Arrays.asList(RegressionModel.NormalizationMethod.SOFTMAX,
                          RegressionModel.NormalizationMethod.LOGIT,
                          RegressionModel.NormalizationMethod.EXP,
                          RegressionModel.NormalizationMethod.PROBIT,
                          RegressionModel.NormalizationMethod.CLOGLOG,
                          RegressionModel.NormalizationMethod.CAUCHIT,
                          RegressionModel.NormalizationMethod.NONE);
    private static final List<RegressionModel.NormalizationMethod> UNSUPPORTED_NORMALIZATION_METHODS =
            Arrays.asList(RegressionModel.NormalizationMethod.SIMPLEMAX,
                          RegressionModel.NormalizationMethod.LOGLOG);

    // Per-test instance (not static): tests mutate the template, so each test needs its own copy.
    private ClassOrInterfaceDeclaration modelTemplate;

    @Before
    public void setup() {
        CompilationUnit compilationUnit = JavaParserUtils.getFromFileName("KiePMMLRegressionTableRegressionTemplate.tmpl");
        modelTemplate = compilationUnit.getClassByName("KiePMMLRegressionTableRegressionTemplate")
                .orElseThrow(() -> new KiePMMLException("Main class not found"));
    }

    @Test
    public void getRegressionTableTest() {
        this.regressionTable = getRegressionTable(3.5d, "professional");
        // The helper keeps the generic map type inferred, so each value's getSource() is resolvable.
        assertNotNullAndGet(KiePMMLRegressionTableRegressionFactory.getRegressionTables(
                Collections.singletonList(this.regressionTable),
                RegressionModel.NormalizationMethod.CAUCHIT,
                Collections.emptyList(),
                "targetField",
                "packageName"))
                .values()
                .forEach(kiePMMLTableSourceCategory -> commonValidateKiePMMLRegressionTable(kiePMMLTableSourceCategory.getSource()));
    }

    @Test
    public void setConstructor() {
        this.regressionTable = getRegressionTable(3.5d, "professional");
        ConstructorDeclaration constructorDeclaration = modelTemplate.getDefaultConstructor()
                .orElseThrow(() -> new KiePMMLException("Default constructor not found"));
        SimpleName tableName = new SimpleName("TableName");
        KiePMMLRegressionTableRegressionFactory.setConstructor(this.regressionTable, constructorDeclaration, tableName, "targetField");
        Map<String, Expression> assignExpressionMap = new HashMap<>();
        assignExpressionMap.put("targetField", new StringLiteralExpr("targetField"));
        assignExpressionMap.put("intercept", new DoubleLiteralExpr(String.valueOf(3.5d)));
        // No super(...) arguments are expected, hence the empty first map.
        Assert.assertTrue(CodegenTestUtils.commonEvaluateConstructor(constructorDeclaration,
                                                                     tableName.asString(),
                                                                     new HashMap<>(),
                                                                     assignExpressionMap));
    }

    @Test
    public void populateOutputFieldsMapWithResult() {
        List<KiePMMLOutputField> outputFields = new ArrayList<>();
        KiePMMLOutputField targetOutputField = getOutputField("KOF-TARGET", RESULT_FEATURE.PREDICTED_VALUE, "TARGET");
        outputFields.add(targetOutputField);
        outputFields.addAll(getProbabilityOutputFields(2));
        BlockStmt body = new BlockStmt();
        KiePMMLRegressionTableRegressionFactory.populateOutputFieldsMapWithResult(body, outputFields);

        // Expect exactly one statement: outputFieldsMap.put("<target output name>", result)
        NodeList<Statement> statements = body.getStatements();
        Assert.assertEquals(1, statements.size());
        Assert.assertTrue(statements.get(0) instanceof ExpressionStmt);
        Expression expression = ((ExpressionStmt) statements.get(0)).getExpression();
        Assert.assertTrue(expression instanceof MethodCallExpr);
        MethodCallExpr methodCallExpr = (MethodCallExpr) expression;
        Assert.assertEquals("outputFieldsMap", methodCallExpr.getScope().get().asNameExpr().toString());
        Assert.assertEquals("put", methodCallExpr.getName().asString());
        NodeList<Expression> arguments = methodCallExpr.getArguments();
        Assert.assertEquals(2, arguments.size());
        Assert.assertTrue(arguments.get(0) instanceof StringLiteralExpr);
        Assert.assertEquals(targetOutputField.getName(), ((StringLiteralExpr) arguments.get(0)).asString());
        Assert.assertTrue(arguments.get(1) instanceof NameExpr);
        Assert.assertEquals("result", ((NameExpr) arguments.get(1)).getNameAsString());
    }

    @Test
    public void addNumericPredictorWithExponent() {
        String retrieved = KiePMMLRegressionTableRegressionFactory
                .addNumericPredictor(PMMLModelTestUtils.getNumericPredictor("predictorName", 2, 1.23d), new ClassOrInterfaceDeclaration(), 3)
                .getBody().get().toString();
        Assert.assertEquals(expectedNumericPredictorWithExponent(1.23d, 2), retrieved);
    }

    @Test
    public void addNumericPredictors() {
        List<NumericPredictor> numericPredictors = IntStream.range(0, 3)
                .mapToObj(i -> PMMLModelTestUtils.getNumericPredictor("predictorName-" + i, i, 1.23d * i))
                .collect(Collectors.toList());
        Assert.assertEquals(numericPredictors.size(),
                            KiePMMLRegressionTableRegressionFactory.addNumericPredictors(numericPredictors, new ClassOrInterfaceDeclaration()).size());
    }

    @Test
    public void addNumericPredictorWithoutExponent() {
        String retrieved = KiePMMLRegressionTableRegressionFactory
                .addNumericPredictor(PMMLModelTestUtils.getNumericPredictor("predictorName", 1, 1.23d), new ClassOrInterfaceDeclaration(), 3)
                .getBody().get().toString();
        Assert.assertEquals(expectedNumericPredictorWithoutExponent(1.23d), retrieved);
    }

    @Test
    public void getNumericPredictorWithExponentTemplate() {
        String retrieved = KiePMMLRegressionTableRegressionFactory
                .getNumericPredictorWithExponentTemplate(PMMLModelTestUtils.getNumericPredictor("predictorName", 2, 1.23d), getEvaluateMethodTemplate())
                .getBody().get().toString();
        Assert.assertEquals(expectedNumericPredictorWithExponent(1.23d, 2), retrieved);
    }

    @Test
    public void getNumericPredictorWithoutExponentTemplate() {
        // The "without exponent" template is exercised with exponent 2 on purpose: it must ignore it.
        String retrieved = KiePMMLRegressionTableRegressionFactory
                .getNumericPredictorWithoutExponentTemplate(PMMLModelTestUtils.getNumericPredictor("predictorName", 2, 1.23d), getEvaluateMethodTemplate())
                .getBody().get().toString();
        Assert.assertEquals(expectedNumericPredictorWithoutExponent(1.23d), retrieved);
    }

    @Test
    public void addCategoricalPredictors() {
        // Three predictor names ("predictorName-0..2"), each with three values -> three groups expected.
        List<CategoricalPredictor> categoricalPredictors = IntStream.range(0, 3)
                .boxed()
                .flatMap(index -> IntStream.range(0, 3)
                        .mapToObj(i -> PMMLModelTestUtils.getCategoricalPredictor("predictorName-" + index, i, 1.23d * i)))
                .collect(Collectors.toList());
        Assert.assertEquals(3, KiePMMLRegressionTableRegressionFactory.addCategoricalPredictors(categoricalPredictors, new ClassOrInterfaceDeclaration()).size());
    }

    @Test
    public void addGroupedCategoricalPredictor() {
        List<CategoricalPredictor> categoricalPredictors = IntStream.range(0, 3)
                .mapToObj(i -> PMMLModelTestUtils.getCategoricalPredictor("predictorName-" + i, i, 1.23d * i))
                .collect(Collectors.toList());
        MethodDeclaration retrieved = KiePMMLRegressionTableRegressionFactory.addGroupedCategoricalPredictor(categoricalPredictors, new ClassOrInterfaceDeclaration(), 3);
        Assert.assertEquals(String.format("evaluateCategoricalPredictor%d", 3), retrieved.getNameAsString());
        CategoricalPredictor first = categoricalPredictors.get(0);
        CategoricalPredictor second = categoricalPredictors.get(1);
        CategoricalPredictor third = categoricalPredictors.get(2);
        String expected = String.format("{\n    if (Objects.equals(%s, input))\n        return %s;\n    else if (Objects.equals(%s, input))\n        return %s;\n    else if (Objects.equals(%s, input))\n        return %s;\n    else\n        return 0.0;\n}",
                                        first.getValue(), first.getCoefficient(),
                                        second.getValue(), second.getCoefficient(),
                                        third.getValue(), third.getCoefficient())
                .replace("\n", System.lineSeparator());
        Assert.assertEquals(expected, retrieved.getBody().get().toString());
    }

    @Test
    public void addPredictorTerms() {
        List<PredictorTerm> predictorTerms = IntStream.range(0, 3)
                .mapToObj(i -> PMMLModelTestUtils.getPredictorTerm("predictorName-" + i, 1.23d * i, Collections.singletonList("fieldRef-" + i)))
                .collect(Collectors.toList());
        Map<String, MethodDeclaration> retrieved = KiePMMLRegressionTableRegressionFactory.addPredictorTerms(predictorTerms, new ClassOrInterfaceDeclaration());
        Assert.assertEquals(predictorTerms.size(), retrieved.size());
        for (int i = 0; i < predictorTerms.size(); i++) {
            String predictorTermName = predictorTerms.get(i).getName().getValue();
            Assert.assertTrue(retrieved.containsKey(predictorTermName));
            // Generated method names are 1-based.
            Assert.assertEquals(String.format("evaluatePredictorTerm%d", i + 1), retrieved.get(predictorTermName).getNameAsString());
        }
    }

    @Test
    public void addPredictorTerm() {
        MethodDeclaration retrieved = KiePMMLRegressionTableRegressionFactory.addPredictorTerm(
                PMMLModelTestUtils.getPredictorTerm("predictorName", 23.12d, Collections.singletonList("fieldRef")),
                new ClassOrInterfaceDeclaration(),
                3);
        Assert.assertEquals(String.format("evaluatePredictorTerm%d", 3), retrieved.getNameAsString());
        String expected = String.format("{\n    final AtomicReference<Double> result = new AtomicReference<>(1.0);\n    List<String> fieldRefs = Arrays.asList(\"%s\");\n    for (String key : resultMap.keySet()) {\n        if (fieldRefs.contains(key)) {\n            result.set(result.get() * (Double) resultMap.get(key));\n        }\n    }\n    double coefficient = %s;\n    return result.get() * coefficient;\n}", "fieldRef", 23.12d)
                .replace("\n", System.lineSeparator());
        Assert.assertEquals(expected, retrieved.getBody().get().toString());
    }

    @Test
    public void populateGetTargetCategoryTargetCategoryNull() {
        KiePMMLRegressionTableRegressionFactory.populateGetTargetCategory(modelTemplate, (Object) null);
        commonEvaluateGetTargetCategory(modelTemplate, new NameExpr("null"));
    }

    @Test
    public void populateGetTargetCategoryTargetCategoryString() {
        KiePMMLRegressionTableRegressionFactory.populateGetTargetCategory(modelTemplate, "CATEGORY");
        commonEvaluateGetTargetCategory(modelTemplate, new StringLiteralExpr("CATEGORY"));
    }

    @Test
    public void populateGetTargetCategoryTargetCategoryNoString() {
        KiePMMLRegressionTableRegressionFactory.populateGetTargetCategory(modelTemplate, 435);
        commonEvaluateGetTargetCategory(modelTemplate, new NameExpr("435"));
    }

    /**
     * Verifies that {@code populateUpdateResult} leaves the {@code updateResult} body empty for
     * {@code NONE} and populates it for every other supported normalization method.
     */
    @Test
    public void populateUpdateResultSupported() {
        for (RegressionModel.NormalizationMethod normalizationMethod : SUPPORTED_NORMALIZATION_METHODS) {
            KiePMMLRegressionTableRegressionFactory.populateUpdateResult(modelTemplate, normalizationMethod);
            BlockStmt body = getMethodBodyByName(modelTemplate, "updateResult");
            Assert.assertNotNull(body.getStatements());
            if (normalizationMethod.equals(RegressionModel.NormalizationMethod.NONE)) {
                Assert.assertTrue("Expected empty updateResult for " + normalizationMethod, body.getStatements().isEmpty());
            } else {
                Assert.assertFalse("Expected non-empty updateResult for " + normalizationMethod, body.getStatements().isEmpty());
            }
        }
    }

    /**
     * Verifies that {@code populateUpdateResult} throws {@link KiePMMLInternalException} for each
     * unsupported normalization method; any other exception propagates and fails the test.
     */
    @Test
    public void populateUpdateResultUnsupported() {
        for (RegressionModel.NormalizationMethod normalizationMethod : UNSUPPORTED_NORMALIZATION_METHODS) {
            try {
                KiePMMLRegressionTableRegressionFactory.populateUpdateResult(modelTemplate, normalizationMethod);
                Assert.fail("Expecting KiePMMLInternalException with normalizationMethod " + normalizationMethod);
            } catch (KiePMMLInternalException expected) {
                // expected: unsupported methods must be rejected
            }
        }
    }

    @Test
    public void populateOutputFieldsMap() {
        List<KiePMMLOutputField> outputFields = new ArrayList<>();
        outputFields.add(getOutputField("KOF-TARGET", RESULT_FEATURE.PREDICTED_VALUE, "TARGET"));
        outputFields.addAll(getProbabilityOutputFields(2));
        KiePMMLRegressionTableRegressionFactory.populateOutputFieldsMap(modelTemplate, outputFields);
        Assert.assertEquals(1, getMethodBodyByName(modelTemplate, "populateOutputFieldsMapWithResult").getStatements().size());
    }

    /** Asserts that {@code getTargetCategory} consists of a single {@code return <expression>;}. */
    private void commonEvaluateGetTargetCategory(ClassOrInterfaceDeclaration classOrInterfaceDeclaration, Expression expression) {
        NodeList<Statement> statements = getMethodBodyByName(classOrInterfaceDeclaration, "getTargetCategory").getStatements();
        Assert.assertEquals(1, statements.size());
        Assert.assertTrue(statements.get(0) instanceof ReturnStmt);
        Assert.assertEquals(expression, ((ReturnStmt) statements.get(0)).getExpression().get());
    }

    private KiePMMLOutputField getOutputField(String str, RESULT_FEATURE result_feature, String str2) {
        return KiePMMLOutputField.builder(str, Collections.emptyList()).withResultFeature(result_feature).withTargetField(str2).build();
    }

    /** Builds {@code count} PROBABILITY output fields named {@code KOF-PROB-i} targeting {@code PROB-i}. */
    private List<KiePMMLOutputField> getProbabilityOutputFields(int count) {
        return IntStream.range(0, count)
                .mapToObj(i -> getOutputField("KOF-PROB-" + i, RESULT_FEATURE.PROBABILITY, "PROB-" + i))
                .collect(Collectors.toList());
    }

    /** Asserts {@code value} is not null and returns it unchanged, preserving its static type. */
    private static <T> T assertNotNullAndGet(T value) {
        Assert.assertNotNull(value);
        return value;
    }

    /** Loads a fresh clone of the {@code KiePMMLEvaluateMethodTemplate} class. */
    private static ClassOrInterfaceDeclaration getEvaluateMethodTemplate() {
        return JavaParserUtils.getFromFileName("KiePMMLEvaluateMethodTemplate.tmpl").clone()
                .getClassByName("KiePMMLEvaluateMethodTemplate")
                .orElseThrow(() -> new KiePMMLException("Main class not found"));
    }

    /** Returns the body of the first method named {@code methodName} declared in {@code classDeclaration}. */
    private static BlockStmt getMethodBodyByName(ClassOrInterfaceDeclaration classDeclaration, String methodName) {
        return classDeclaration.getMethodsByName(methodName).get(0).getBody()
                .orElseThrow(() -> new KiePMMLException("Body of method " + methodName + " not found"));
    }

    /** Expected pretty-printed body of a numeric predictor whose exponent is considered. */
    private static String expectedNumericPredictorWithExponent(double coefficient, int exponent) {
        return String.format("{\n    double coefficient = %s;\n    double exponent = %s.0;\n    // Considering exponent because it is != 1\n    return Math.pow(input, exponent) * coefficient;\n}", coefficient, exponent)
                .replace("\n", System.lineSeparator());
    }

    /** Expected pretty-printed body of a numeric predictor whose exponent is ignored. */
    private static String expectedNumericPredictorWithoutExponent(double coefficient) {
        return String.format("{\n    double coefficient = %s;\n    // Ignoring exponent because it is 1\n    return input * coefficient;\n}", coefficient)
                .replace("\n", System.lineSeparator());
    }
}
