package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.expr.CastExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.SimpleName;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.ThisExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.Type;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.regression.RegressionModel;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.api.iinterfaces.SerializableFunction;
import org.kie.pmml.commons.Constants;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:BOOT-INF/lib/kie-pmml-models-regression-compiler-8.16.2-SNAPSHOT.jar:org/kie/pmml/models/regression/compiler/factories/KiePMMLRegressionTableClassificationFactory.class */
public class KiePMMLRegressionTableClassificationFactory {
    public static final String KIE_PMML_REGRESSION_TABLE_CLASSIFICATION_TEMPLATE_JAVA = "KiePMMLRegressionTableClassificationTemplate.tmpl";
    public static final String KIE_PMML_REGRESSION_TABLE_CLASSIFICATION_TEMPLATE = "KiePMMLRegressionTableClassificationTemplate";
    private static final String MAIN_CLASS_NOT_FOUND = "Main class not found";
    public static final List<RegressionModel.NormalizationMethod> SUPPORTED_NORMALIZATION_METHODS = Arrays.asList(RegressionModel.NormalizationMethod.SOFTMAX, RegressionModel.NormalizationMethod.SIMPLEMAX, RegressionModel.NormalizationMethod.NONE, RegressionModel.NormalizationMethod.LOGIT, RegressionModel.NormalizationMethod.PROBIT, RegressionModel.NormalizationMethod.CLOGLOG, RegressionModel.NormalizationMethod.CAUCHIT);
    public static final List<RegressionModel.NormalizationMethod> UNSUPPORTED_NORMALIZATION_METHODS = Arrays.asList(RegressionModel.NormalizationMethod.EXP, RegressionModel.NormalizationMethod.LOGLOG);
    private static final Logger logger = LoggerFactory.getLogger(KiePMMLRegressionTableClassificationFactory.class.getName());
    private static AtomicInteger classArity = new AtomicInteger(0);

    private KiePMMLRegressionTableClassificationFactory() {
    }

    public static Map<String, KiePMMLTableSourceCategory> getRegressionTables(RegressionCompilationDTO regressionCompilationDTO) {
        logger.trace("getRegressionTables {}", regressionCompilationDTO.getRegressionTables());
        LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTables = KiePMMLRegressionTableRegressionFactory.getRegressionTables(regressionCompilationDTO);
        Map.Entry<String, String> regressionTable = getRegressionTable(regressionCompilationDTO, regressionTables);
        regressionTables.put(regressionTable.getKey(), new KiePMMLTableSourceCategory(regressionTable.getValue(), ""));
        return regressionTables;
    }

    public static Map.Entry<String, String> getRegressionTable(RegressionCompilationDTO regressionCompilationDTO, LinkedHashMap<String, KiePMMLTableSourceCategory> linkedHashMap) {
        logger.trace("getRegressionTable {}", linkedHashMap);
        String str = "KiePMMLRegressionTableClassification" + classArity.addAndGet(1);
        CompilationUnit kiePMMLModelCompilationUnit = JavaParserUtils.getKiePMMLModelCompilationUnit(str, regressionCompilationDTO.getPackageName(), KIE_PMML_REGRESSION_TABLE_CLASSIFICATION_TEMPLATE_JAVA, KIE_PMML_REGRESSION_TABLE_CLASSIFICATION_TEMPLATE);
        ClassOrInterfaceDeclaration orElseThrow = kiePMMLModelCompilationUnit.getClassByName(str).orElseThrow(() -> {
            return new KiePMMLException("Main class not found: " + str);
        });
        boolean isBinary = regressionCompilationDTO.isBinary(linkedHashMap.size());
        ConstructorDeclaration orElseThrow2 = orElseThrow.getDefaultConstructor().orElseThrow(() -> {
            return new KiePMMLInternalException(String.format(Constants.MISSING_DEFAULT_CONSTRUCTOR, orElseThrow.getName()));
        });
        setConstructor(regressionCompilationDTO, orElseThrow2, orElseThrow.getName(), null, regressionCompilationDTO.getModelNormalizationMethod(), isBinary);
        addMapPopulation(orElseThrow2.getBody(), linkedHashMap);
        return new AbstractMap.SimpleEntry(JavaParserUtils.getFullClassName(kiePMMLModelCompilationUnit), kiePMMLModelCompilationUnit.toString());
    }

    static void setConstructor(RegressionCompilationDTO regressionCompilationDTO, ConstructorDeclaration constructorDeclaration, SimpleName simpleName, Object obj, RegressionModel.NormalizationMethod normalizationMethod, boolean z) {
        constructorDeclaration.setName(simpleName);
        BlockStmt body = constructorDeclaration.getBody();
        CommonCodegenUtils.setAssignExpressionValue(body, "targetField", new StringLiteralExpr(regressionCompilationDTO.getTargetFieldName()));
        REGRESSION_NORMALIZATION_METHOD defaultREGRESSION_NORMALIZATION_METHOD = regressionCompilationDTO.getDefaultREGRESSION_NORMALIZATION_METHOD();
        CommonCodegenUtils.setAssignExpressionValue(body, "regressionNormalizationMethod", new NameExpr(defaultREGRESSION_NORMALIZATION_METHOD.getClass().getSimpleName() + "." + defaultREGRESSION_NORMALIZATION_METHOD.name()));
        OP_TYPE op_type = regressionCompilationDTO.getOP_TYPE();
        if (op_type != null) {
            CommonCodegenUtils.setAssignExpressionValue(body, "opType", new NameExpr(op_type.getClass().getSimpleName() + "." + op_type.name()));
        }
        CommonCodegenUtils.setAssignExpressionValue(body, "targetCategory", CommonCodegenUtils.getExpressionForObject(obj));
        CommonCodegenUtils.setAssignExpressionValue(body, "isBinary", CommonCodegenUtils.getExpressionForObject(Boolean.valueOf(z)));
        CommonCodegenUtils.setAssignExpressionValue(body, "probabilityMapFunction", createProbabilityMapFunctionExpression(normalizationMethod, z));
    }

    static Expression createProbabilityMapFunctionExpression(RegressionModel.NormalizationMethod normalizationMethod, boolean z) {
        if (UNSUPPORTED_NORMALIZATION_METHODS.contains(normalizationMethod)) {
            throw new KiePMMLInternalException(String.format("Unsupported NormalizationMethod %s", normalizationMethod));
        }
        return createProbabilityMapFunctionSupportedExpression(normalizationMethod, z);
    }

    static MethodReferenceExpr createProbabilityMapFunctionSupportedExpression(RegressionModel.NormalizationMethod normalizationMethod, boolean z) {
        String name = normalizationMethod.name();
        if (RegressionModel.NormalizationMethod.NONE.equals(normalizationMethod) && z) {
            name = name + "Binary";
        }
        String format = String.format("get%sProbabilityMap", name);
        CastExpr castExpr = new CastExpr();
        ClassOrInterfaceType typedClassOrInterfaceTypeByTypeNames = CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(LinkedHashMap.class.getCanonicalName(), Arrays.asList(String.class.getSimpleName(), Double.class.getSimpleName()));
        castExpr.setType((Type) CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypes(SerializableFunction.class.getCanonicalName(), Arrays.asList(typedClassOrInterfaceTypeByTypeNames, typedClassOrInterfaceTypeByTypeNames)));
        castExpr.setExpression((Expression) new ThisExpr());
        MethodReferenceExpr methodReferenceExpr = new MethodReferenceExpr();
        methodReferenceExpr.setScope(castExpr);
        methodReferenceExpr.setIdentifier(format);
        return methodReferenceExpr;
    }

    static void addMapPopulation(BlockStmt blockStmt, LinkedHashMap<String, KiePMMLTableSourceCategory> linkedHashMap) {
        linkedHashMap.forEach((str, kiePMMLTableSourceCategory) -> {
            ObjectCreationExpr objectCreationExpr = new ObjectCreationExpr();
            objectCreationExpr.setType(str);
            blockStmt.addStatement(new MethodCallExpr(new NameExpr("categoryTableMap"), "put", (NodeList<Expression>) NodeList.nodeList(new StringLiteralExpr(kiePMMLTableSourceCategory.getCategory()), objectCreationExpr)));
        });
    }
}
