/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.expr.CastExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.SimpleName;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.ThisExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.Type;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.regression.RegressionModel;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.api.iinterfaces.SerializableFunction;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableRegressionFactory;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Code-generation factory that produces the Java source of a classification regression table.
 * The source is produced by cloning and customizing the
 * {@code KiePMMLRegressionTableClassificationTemplate} template.
 * <p>
 * The class is not instantiable; every member is static.
 */
public final class KiePMMLRegressionTableClassificationFactory {
    public static final String KIE_PMML_REGRESSION_TABLE_CLASSIFICATION_TEMPLATE_JAVA = "KiePMMLRegressionTableClassificationTemplate.tmpl";
    public static final String KIE_PMML_REGRESSION_TABLE_CLASSIFICATION_TEMPLATE = "KiePMMLRegressionTableClassificationTemplate";

    /** Normalization methods for which a {@code get<NAME>ProbabilityMap} reference is generated. */
    public static final List<RegressionModel.NormalizationMethod> SUPPORTED_NORMALIZATION_METHODS = Arrays.asList(
            RegressionModel.NormalizationMethod.SOFTMAX,
            RegressionModel.NormalizationMethod.SIMPLEMAX,
            RegressionModel.NormalizationMethod.NONE,
            RegressionModel.NormalizationMethod.LOGIT,
            RegressionModel.NormalizationMethod.PROBIT,
            RegressionModel.NormalizationMethod.CLOGLOG,
            RegressionModel.NormalizationMethod.CAUCHIT);

    /** Normalization methods explicitly rejected for classification tables. */
    public static final List<RegressionModel.NormalizationMethod> UNSUPPORTED_NORMALIZATION_METHODS = Arrays.asList(
            RegressionModel.NormalizationMethod.EXP,
            RegressionModel.NormalizationMethod.LOGLOG);

    private static final Logger logger = LoggerFactory.getLogger(KiePMMLRegressionTableClassificationFactory.class.getName());
    private static final String MAIN_CLASS_NOT_FOUND = "Main class not found";

    // Suffix counter so that every generated classification table gets a distinct class name.
    private static final AtomicInteger classArity = new AtomicInteger(0);

    private KiePMMLRegressionTableClassificationFactory() {
    }

    /**
     * Generates the sources of all regression tables of a classification model.
     * <p>
     * The per-category tables are created by
     * {@link KiePMMLRegressionTableRegressionFactory#getRegressionTables(RegressionCompilationDTO)}.
     * A classification table that instantiates each of them is then generated. It is added to the
     * same map under its full class name, with an empty category.
     *
     * @param compilationDTO compilation data of the regression model
     * @return the map returned by the regression factory, with the classification table added to it
     */
    public static Map<String, KiePMMLTableSourceCategory> getRegressionTables(RegressionCompilationDTO compilationDTO) {
        logger.trace("getRegressionTables {}", compilationDTO.getRegressionTables());
        LinkedHashMap<String, KiePMMLTableSourceCategory> toReturn =
                KiePMMLRegressionTableRegressionFactory.getRegressionTables(compilationDTO);
        Map.Entry<String, String> regressionTableEntry = getRegressionTable(compilationDTO, toReturn);
        toReturn.put(regressionTableEntry.getKey(), new KiePMMLTableSourceCategory(regressionTableEntry.getValue(), ""));
        return toReturn;
    }

    /**
     * Generates the classification table class that wraps the given per-category tables.
     *
     * @param compilationDTO       compilation data of the regression model
     * @param regressionTablesMap  per-category tables, keyed by the class name to instantiate
     * @return an entry whose key is the full class name of the generated class and whose value is its source
     * @throws KiePMMLException         if the cloned template does not contain the renamed class
     * @throws KiePMMLInternalException if the template class has no default constructor
     */
    public static Map.Entry<String, String> getRegressionTable(RegressionCompilationDTO compilationDTO, LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap) {
        logger.trace("getRegressionTable {}", regressionTablesMap);
        String className = "KiePMMLRegressionTableClassification" + classArity.incrementAndGet();
        CompilationUnit cloneCU = JavaParserUtils.getKiePMMLModelCompilationUnit(className,
                compilationDTO.getPackageName(),
                KIE_PMML_REGRESSION_TABLE_CLASSIFICATION_TEMPLATE_JAVA,
                KIE_PMML_REGRESSION_TABLE_CLASSIFICATION_TEMPLATE);
        ClassOrInterfaceDeclaration tableTemplate = cloneCU.getClassByName(className)
                .orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + className));
        // Binary-ness is decided by the DTO from the number of per-category tables.
        boolean isBinary = compilationDTO.isBinary(regressionTablesMap.size());
        ConstructorDeclaration constructorDeclaration = tableTemplate.getDefaultConstructor()
                .orElseThrow(() -> new KiePMMLInternalException(String.format("Missing default constructor in ClassOrInterfaceDeclaration %s ", tableTemplate.getName())));
        // The top-level classification table has no target category of its own.
        setConstructor(compilationDTO, constructorDeclaration, tableTemplate.getName(), null,
                compilationDTO.getModelNormalizationMethod(), isBinary);
        addMapPopulation(constructorDeclaration.getBody(), regressionTablesMap);
        return new AbstractMap.SimpleEntry<>(JavaParserUtils.getFullClassName(cloneCU), cloneCU.toString());
    }

    /**
     * Renames the given constructor and rewrites the field assignments in its body.
     * <p>
     * The assigned fields are {@code targetField}, {@code regressionNormalizationMethod},
     * {@code opType} (only when the DTO provides one), {@code targetCategory}, {@code isBinary}
     * and {@code probabilityMapFunction}.
     *
     * @param compilationDTO         compilation data of the regression model
     * @param constructorDeclaration constructor to modify in place
     * @param generatedClassName     new constructor name (the generated class name)
     * @param targetCategory         value for {@code targetCategory}; converted through
     *                               {@code CommonCodegenUtils.getExpressionForObject}
     * @param normalizationMethod    selects the generated probability-map method reference
     * @param isBinary               value for {@code isBinary}; also affects the {@code NONE} method reference
     */
    static void setConstructor(RegressionCompilationDTO compilationDTO, ConstructorDeclaration constructorDeclaration, SimpleName generatedClassName, Object targetCategory, RegressionModel.NormalizationMethod normalizationMethod, boolean isBinary) {
        constructorDeclaration.setName(generatedClassName);
        BlockStmt body = constructorDeclaration.getBody();
        CommonCodegenUtils.setAssignExpressionValue(body, "targetField", new StringLiteralExpr(compilationDTO.getTargetFieldName()));
        REGRESSION_NORMALIZATION_METHOD regressionNormalizationMethod = compilationDTO.getDefaultREGRESSION_NORMALIZATION_METHOD();
        CommonCodegenUtils.setAssignExpressionValue(body, "regressionNormalizationMethod", enumReference(regressionNormalizationMethod));
        OP_TYPE opType = compilationDTO.getOP_TYPE();
        if (opType != null) {
            CommonCodegenUtils.setAssignExpressionValue(body, "opType", enumReference(opType));
        }
        Expression targetCategoryExpression = CommonCodegenUtils.getExpressionForObject(targetCategory);
        CommonCodegenUtils.setAssignExpressionValue(body, "targetCategory", targetCategoryExpression);
        CommonCodegenUtils.setAssignExpressionValue(body, "isBinary", CommonCodegenUtils.getExpressionForObject(isBinary));
        Expression probabilityMapFunctionExpression = createProbabilityMapFunctionExpression(normalizationMethod, isBinary);
        CommonCodegenUtils.setAssignExpressionValue(body, "probabilityMapFunction", probabilityMapFunctionExpression);
    }

    /**
     * Builds the probability-map method reference for the given normalization method.
     *
     * @param normalizationMethod the model normalization method
     * @param isBinary            whether the model is binary
     * @return the method-reference expression
     * @throws KiePMMLInternalException if {@code normalizationMethod} is {@code null} or not in
     *                                  {@link #SUPPORTED_NORMALIZATION_METHODS}. This includes every
     *                                  value of {@link #UNSUPPORTED_NORMALIZATION_METHODS}.
     */
    static Expression createProbabilityMapFunctionExpression(RegressionModel.NormalizationMethod normalizationMethod, boolean isBinary) {
        // Validating against the supported list (rather than only the unsupported one) also
        // rejects null and any value missing from both lists, which would otherwise fail later
        // with an NPE or a reference to a non-existent template method.
        if (!SUPPORTED_NORMALIZATION_METHODS.contains(normalizationMethod)) {
            throw new KiePMMLInternalException(String.format("Unsupported NormalizationMethod %s", normalizationMethod));
        }
        return createProbabilityMapFunctionSupportedExpression(normalizationMethod, isBinary);
    }

    /**
     * Builds a method-reference expression of the form
     * {@code ((SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>>) this)::get<NAME>ProbabilityMap}.
     * <p>
     * {@code <NAME>} is the enum constant name, with {@code "Binary"} appended for {@code NONE}
     * on binary models. The referenced method is presumably declared by the template; confirm
     * before adding normalization methods.
     *
     * @param normalizationMethod the (non-null) normalization method
     * @param isBinary            whether the model is binary
     * @return the method-reference expression
     */
    static MethodReferenceExpr createProbabilityMapFunctionSupportedExpression(RegressionModel.NormalizationMethod normalizationMethod, boolean isBinary) {
        String normalizationName = normalizationMethod.name();
        if (normalizationMethod == RegressionModel.NormalizationMethod.NONE && isBinary) {
            normalizationName = normalizationName + "Binary";
        }
        String thisExpressionMethodName = String.format("get%sProbabilityMap", normalizationName);
        String stringClassName = String.class.getSimpleName();
        String doubleClassName = Double.class.getSimpleName();
        ClassOrInterfaceType linkedHashMapReferenceType = CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(
                LinkedHashMap.class.getCanonicalName(), Arrays.asList(stringClassName, doubleClassName));
        ClassOrInterfaceType consumerType = CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypes(
                SerializableFunction.class.getCanonicalName(), Arrays.asList(linkedHashMapReferenceType, linkedHashMapReferenceType));
        CastExpr castExpr = new CastExpr();
        castExpr.setType(consumerType);
        castExpr.setExpression(new ThisExpr());
        MethodReferenceExpr toReturn = new MethodReferenceExpr();
        toReturn.setScope(castExpr);
        toReturn.setIdentifier(thisExpressionMethodName);
        return toReturn;
    }

    /**
     * Appends one {@code categoryTableMap.put("<category>", new <className>());} statement to
     * {@code body} for every entry of {@code regressionTablesMap}, in iteration order.
     *
     * @param body                constructor body to extend
     * @param regressionTablesMap per-category tables, keyed by the class name to instantiate
     */
    static void addMapPopulation(BlockStmt body, LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap) {
        regressionTablesMap.forEach((className, tableSourceCategory) -> {
            ObjectCreationExpr objectCreationExpr = new ObjectCreationExpr();
            objectCreationExpr.setType(className);
            NodeList<Expression> expressions = new NodeList<>();
            expressions.add(new StringLiteralExpr(tableSourceCategory.getCategory()));
            expressions.add(objectCreationExpr);
            body.addStatement(new MethodCallExpr(new NameExpr("categoryTableMap"), "put", expressions));
        });
    }

    /**
     * Returns a {@code EnumType.CONSTANT} name expression for the given enum constant.
     * {@link Enum#getDeclaringClass()} is used instead of {@code getClass()} because a constant
     * with a body is an anonymous subclass, whose {@code getSimpleName()} is empty.
     */
    private static NameExpr enumReference(Enum<?> value) {
        return new NameExpr(value.getDeclaringClass().getSimpleName() + "." + value.name());
    }
}

