package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.expr.BooleanLiteralExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.SimpleName;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.stmt.ReturnStmt;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.regression.RegressionModel;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.compiler.commons.testutils.CodegenTestUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;

/* loaded from: input_file:org/kie/pmml/models/regression/compiler/factories/KiePMMLRegressionTableClassificationFactoryTest.class */
/**
 * Unit tests for {@link KiePMMLRegressionTableClassificationFactory}.
 * <p>
 * {@link #setup()} re-parses the {@code KiePMMLRegressionTableClassificationTemplate.tmpl} resource
 * before every test, so each test works on (and may freely mutate) its own copy of the template class.
 */
public class KiePMMLRegressionTableClassificationFactoryTest extends AbstractKiePMMLRegressionTableRegressionFactoryTest {

    /** Normalization methods for which {@code populateGetProbabilityMapMethod} must generate a non-empty body. */
    private static final List<RegressionModel.NormalizationMethod> SUPPORTED_NORMALIZATION_METHODS = Arrays.asList(RegressionModel.NormalizationMethod.SOFTMAX, RegressionModel.NormalizationMethod.SIMPLEMAX, RegressionModel.NormalizationMethod.NONE, RegressionModel.NormalizationMethod.LOGIT, RegressionModel.NormalizationMethod.PROBIT, RegressionModel.NormalizationMethod.CLOGLOG, RegressionModel.NormalizationMethod.CAUCHIT);

    /** Normalization methods for which {@code populateGetProbabilityMapMethod} must throw {@link KiePMMLInternalException}. */
    private static final List<RegressionModel.NormalizationMethod> UNSUPPORTED_NORMALIZATION_METHODS = Arrays.asList(RegressionModel.NormalizationMethod.EXP, RegressionModel.NormalizationMethod.LOGLOG);

    private CompilationUnit compilationUnit;
    private ClassOrInterfaceDeclaration modelTemplate;

    /**
     * Parses the classification template and extracts its main class declaration.
     */
    @Before
    public void setup() {
        this.compilationUnit = JavaParserUtils.getFromFileName("KiePMMLRegressionTableClassificationTemplate.tmpl");
        this.modelTemplate = this.compilationUnit.getClassByName("KiePMMLRegressionTableClassificationTemplate")
                .orElseThrow(() -> new AssertionError("KiePMMLRegressionTableClassificationTemplate not found in template"));
    }

    /**
     * Builds two regression tables and checks that every generated table source passes the common validation.
     */
    @Test
    public void getRegressionTables() {
        final Map<String, KiePMMLTableSourceCategory> regressionTables =
                KiePMMLRegressionTableClassificationFactory.getRegressionTables(
                        Arrays.asList(getRegressionTable(3.5d, "professional"), getRegressionTable(27.4d, "clerical")),
                        RegressionModel.NormalizationMethod.SOFTMAX,
                        OpType.CATEGORICAL,
                        getOutputFields(),
                        "targetField",
                        "packageName");
        Assert.assertNotNull(regressionTables);
        // NOTE(review): two input tables yield three sources, presumably one per category plus the
        // enclosing classification table; confirm against the factory.
        Assert.assertEquals(3, regressionTables.size());
        regressionTables.values().forEach(tableSourceCategory -> commonValidateKiePMMLRegressionTable(tableSourceCategory.getSource()));
    }

    /**
     * Checks that a classification table source is produced even when no category tables are supplied.
     */
    @Test
    public void getRegressionTable() {
        Assert.assertNotNull(KiePMMLRegressionTableClassificationFactory.getRegressionTable(
                new LinkedHashMap<>(),
                RegressionModel.NormalizationMethod.SOFTMAX,
                OpType.CATEGORICAL,
                getOutputFields(),
                "targetField",
                "packageName"));
    }

    /**
     * Checks that the template's default constructor is renamed and assigns the target field,
     * normalization method and op type, without any super-invocation arguments.
     */
    @Test
    public void setConstructor() {
        final ConstructorDeclaration constructorDeclaration = this.modelTemplate.getDefaultConstructor().get();
        final SimpleName generatedClassName = new SimpleName("GeneratedClassName");
        final String targetField = "targetField";
        final REGRESSION_NORMALIZATION_METHOD regressionNormalizationMethod = REGRESSION_NORMALIZATION_METHOD.CAUCHIT;
        final OP_TYPE opType = OP_TYPE.CATEGORICAL;
        KiePMMLRegressionTableClassificationFactory.setConstructor(constructorDeclaration, generatedClassName, targetField, regressionNormalizationMethod, opType);
        final Map<Integer, Expression> superInvocationExpressionsMap = new HashMap<>();
        final Map<String, Expression> assignExpressionMap = new HashMap<>();
        assignExpressionMap.put("targetField", new StringLiteralExpr(targetField));
        assignExpressionMap.put("regressionNormalizationMethod", new NameExpr(regressionNormalizationMethod.getClass().getSimpleName() + "." + regressionNormalizationMethod.name()));
        assignExpressionMap.put("opType", new NameExpr(opType.getClass().getSimpleName() + "." + opType.name()));
        Assert.assertTrue(CodegenTestUtils.commonEvaluateConstructor(constructorDeclaration, generatedClassName.asString(), superInvocationExpressionsMap, assignExpressionMap));
    }

    /**
     * Checks that {@code addMapPopulation} emits one {@code categoryTableMap.put(category, new KEY())}
     * statement per map entry, in map order.
     */
    @Test
    public void addMapPopulation() {
        final int tableCount = 3;
        final BlockStmt body = new BlockStmt();
        // LinkedHashMap keeps insertion order, so the i-th generated statement can be matched to the
        // i-th entry (assumes addMapPopulation iterates the map in its own order).
        final Map<String, KiePMMLTableSourceCategory> regressionTablesMap = new LinkedHashMap<>();
        // The map must be populated: with an empty map there is nothing to generate, and the
        // positional checks below would fail on an empty statement list.
        IntStream.range(0, tableCount).forEach(index -> regressionTablesMap.put("KEY" + index,
                new KiePMMLTableSourceCategory(index + "_SOURCE", "CATEGORY" + index)));

        KiePMMLRegressionTableClassificationFactory.addMapPopulation(body, regressionTablesMap);

        Assert.assertEquals(regressionTablesMap.size(), body.getStatements().size());
        body.getStatements().forEach(statement -> {
            Assert.assertTrue(statement instanceof ExpressionStmt);
            final Expression expression = ((ExpressionStmt) statement).getExpression();
            Assert.assertTrue(expression instanceof MethodCallExpr);
            final MethodCallExpr methodCallExpr = (MethodCallExpr) expression;
            Assert.assertEquals("categoryTableMap", methodCallExpr.getScope().get().asNameExpr().toString());
            Assert.assertEquals("put", methodCallExpr.getName().asString());
        });

        final List<MethodCallExpr> putCalls = body.getStatements().stream()
                .map(statement -> (MethodCallExpr) ((ExpressionStmt) statement).getExpression())
                .collect(Collectors.toList());
        IntStream.range(0, tableCount).forEach(index -> {
            final String key = "KEY" + index;
            final KiePMMLTableSourceCategory expected = regressionTablesMap.get(key);
            final NodeList<Expression> arguments = putCalls.get(index).getArguments();
            Assert.assertEquals(expected.getCategory(), arguments.get(0).asStringLiteralExpr().getValue());
            Assert.assertEquals(key, arguments.get(1).asObjectCreationExpr().getTypeAsString());
        });
    }

    /**
     * Checks that every supported normalization method produces a non-empty {@code getProbabilityMap} body.
     */
    @Test
    public void populateGetProbabilityMapMethodSupported() {
        SUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod -> {
            KiePMMLRegressionTableClassificationFactory.populateGetProbabilityMapMethod(normalizationMethod, this.modelTemplate);
            final MethodDeclaration probabilityMapMethod = this.modelTemplate.getMethodsByName("getProbabilityMap").get(0);
            final BlockStmt body = probabilityMapMethod.getBody().get();
            Assert.assertNotNull(body.getStatements());
            Assert.assertFalse(body.getStatements().isEmpty());
        });
    }

    /**
     * Checks that every unsupported normalization method is rejected with {@link KiePMMLInternalException}.
     */
    @Test
    public void populateGetProbabilityMapMethodUnsupported() {
        UNSUPPORTED_NORMALIZATION_METHODS.forEach(normalizationMethod -> {
            try {
                KiePMMLRegressionTableClassificationFactory.populateGetProbabilityMapMethod(normalizationMethod, this.modelTemplate);
                Assert.fail("Expecting KiePMMLInternalException with normalizationMethod " + normalizationMethod);
            } catch (KiePMMLInternalException expected) {
                // Expected. Any other exception propagates with its own stack trace, which is more
                // informative than a bare assertTrue(instanceof) failure.
            }
        });
    }

    /**
     * Checks {@code populateIsBinaryMethod} for every combination of op type and count in 1..3.
     * Only a categorical target with a count of 2 is expected to be binary.
     */
    @Test
    public void populateIsBinaryMethod() {
        for (OpType opType : Arrays.asList(OpType.CATEGORICAL, OpType.CONTINUOUS, OpType.ORDINAL)) {
            // NOTE(review): the int argument presumably is the number of regression tables; confirm in the factory.
            for (int count = 1; count <= 3; count++) {
                KiePMMLRegressionTableClassificationFactory.populateIsBinaryMethod(opType, count, this.modelTemplate);
                commonEvaluateIsBinaryMethod(this.modelTemplate, opType == OpType.CATEGORICAL && count == 2);
            }
        }
    }

    /**
     * Asserts that the {@code isBinary} method of the given class has a body consisting of exactly one
     * {@code return <boolean literal>;} statement whose literal equals {@code expected}.
     *
     * @param classOrInterfaceDeclaration the class containing the generated {@code isBinary} method
     * @param expected the boolean literal the method is expected to return
     */
    private void commonEvaluateIsBinaryMethod(ClassOrInterfaceDeclaration classOrInterfaceDeclaration, boolean expected) {
        final MethodDeclaration isBinaryMethod = classOrInterfaceDeclaration.getMethodsByName("isBinary").get(0);
        final BlockStmt body = isBinaryMethod.getBody().get();
        Assert.assertEquals(1, body.getStatements().size());
        Assert.assertTrue(body.getStatements().get(0) instanceof ReturnStmt);
        final ReturnStmt returnStmt = (ReturnStmt) body.getStatements().get(0);
        Assert.assertTrue(returnStmt.getExpression().isPresent());
        Assert.assertTrue(returnStmt.getExpression().get() instanceof BooleanLiteralExpr);
        Assert.assertEquals(expected, ((BooleanLiteralExpr) returnStmt.getExpression().get()).getValue());
    }

    /**
     * Returns the output fields shared by the table-generation tests: two probability fields bound to
     * target fields and one predicted-value field without a target field.
     */
    private List<OutputField> getOutputFields() {
        return Arrays.asList(getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1"),
                             getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0"),
                             getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null));
    }

    /**
     * Creates an {@link OutputField}.
     *
     * @param name the field name
     * @param resultFeature the result feature of the field
     * @param targetField the target field name, or {@code null} to leave it unset
     * @return the configured output field
     */
    private OutputField getOutputField(String name, ResultFeature resultFeature, String targetField) {
        final OutputField outputField = new OutputField();
        outputField.setName(FieldName.create(name));
        outputField.setResultFeature(resultFeature);
        if (targetField != null) {
            outputField.setTargetField(FieldName.create(targetField));
        }
        return outputField;
    }
}
