package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.body.Parameter;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.CastExpr;
import com.github.javaparser.ast.expr.DoubleLiteralExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.LambdaExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.SimpleName;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.ThisExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.Type;
import com.github.javaparser.ast.type.UnknownType;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.api.iinterfaces.SerializableFunction;
import org.kie.pmml.commons.Constants;
import org.kie.pmml.commons.utils.KiePMMLModelUtils;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:BOOT-INF/lib/kie-pmml-models-regression-compiler-8.17.1-SNAPSHOT.jar:org/kie/pmml/models/regression/compiler/factories/KiePMMLRegressionTableRegressionFactory.class */
/**
 * Factory that generates the Java source of regression-table classes.
 * <p>
 * For every {@link RegressionTable} a compilation unit is created from the
 * {@value #KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE_JAVA} template, its default constructor is
 * filled with the table-specific values (intercept, target field/category, result updater and the
 * numeric / categorical / predictor-term function maps) and the resulting source is returned as text.
 * <p>
 * Utility class: it only exposes static methods and cannot be instantiated.
 */
public class KiePMMLRegressionTableRegressionFactory {
    public static final String KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE_JAVA = "KiePMMLRegressionTableRegressionTemplate.tmpl";
    public static final String KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE = "KiePMMLRegressionTableRegressionTemplate";
    static final String MAIN_CLASS_NOT_FOUND = "Main class not found";
    static final String KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA = "KiePMMLEvaluateMethodTemplate.tmpl";
    static final String KIE_PMML_EVALUATE_METHOD_TEMPLATE = "KiePMMLEvaluateMethodTemplate";
    private static final String COEFFICIENT = "coefficient";
    private static final Logger logger = LoggerFactory.getLogger(KiePMMLRegressionTableRegressionFactory.class.getName());
    static final List<RegressionModel.NormalizationMethod> SUPPORTED_NORMALIZATION_METHODS = Arrays.asList(RegressionModel.NormalizationMethod.SOFTMAX, RegressionModel.NormalizationMethod.LOGIT, RegressionModel.NormalizationMethod.EXP, RegressionModel.NormalizationMethod.PROBIT, RegressionModel.NormalizationMethod.CLOGLOG, RegressionModel.NormalizationMethod.CAUCHIT, RegressionModel.NormalizationMethod.NONE);
    static final List<RegressionModel.NormalizationMethod> UNSUPPORTED_NORMALIZATION_METHODS = Arrays.asList(RegressionModel.NormalizationMethod.SIMPLEMAX, RegressionModel.NormalizationMethod.LOGLOG);

    /** Counter used to give every generated regression-table class a distinct name suffix. */
    private static final AtomicInteger classArity = new AtomicInteger(0);

    private KiePMMLRegressionTableRegressionFactory() {
        // Static factory: not meant to be instantiated.
    }

    /**
     * Generates the source of every regression table held by the given DTO.
     *
     * @param regressionCompilationDTO provides the regression tables and the compilation settings
     * @return an insertion-ordered map from the generated fully-qualified class name to its source
     *         paired with the table's target category (empty string when the category is {@code null})
     */
    public static LinkedHashMap<String, KiePMMLTableSourceCategory> getRegressionTables(RegressionCompilationDTO regressionCompilationDTO) {
        logger.trace("getRegressionTables {}", regressionCompilationDTO.getRegressionTables());
        LinkedHashMap<String, KiePMMLTableSourceCategory> toReturn = new LinkedHashMap<>();
        for (RegressionTable regressionTable : regressionCompilationDTO.getRegressionTables()) {
            Map.Entry<String, String> classNameToSource = getRegressionTable(regressionTable, regressionCompilationDTO);
            Object targetCategory = regressionTable.getTargetCategory();
            String category = targetCategory != null ? targetCategory.toString() : "";
            toReturn.put(classNameToSource.getKey(), new KiePMMLTableSourceCategory(classNameToSource.getValue(), category));
        }
        return toReturn;
    }

    /**
     * Generates the source of a single regression table.
     *
     * @param regressionTable          the table to translate
     * @param regressionCompilationDTO provides package name, target field name and default normalization method
     * @return an entry whose key is the generated fully-qualified class name and whose value is its source
     * @throws KiePMMLException         if the generated class cannot be found inside the template
     * @throws KiePMMLInternalException if the template class has no default constructor
     */
    public static Map.Entry<String, String> getRegressionTable(RegressionTable regressionTable, RegressionCompilationDTO regressionCompilationDTO) {
        logger.trace("getRegressionTable {}", regressionTable);
        String className = "KiePMMLRegressionTableRegression" + classArity.incrementAndGet();
        CompilationUnit compilationUnit = JavaParserUtils.getKiePMMLModelCompilationUnit(className, regressionCompilationDTO.getPackageName(), KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE_JAVA, KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE);
        ClassOrInterfaceDeclaration tableTemplate = compilationUnit.getClassByName(className)
                .orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + className));
        ConstructorDeclaration constructor = tableTemplate.getDefaultConstructor()
                .orElseThrow(() -> new KiePMMLInternalException(String.format(Constants.MISSING_DEFAULT_CONSTRUCTOR, tableTemplate.getName())));
        setConstructor(regressionTable, constructor, tableTemplate.getName(), regressionCompilationDTO.getTargetFieldName(), regressionTable.getTargetCategory(), regressionCompilationDTO.getDefaultNormalizationMethod());
        Map<String, Expression> numericExpressions = createNumericPredictorsExpressions(regressionTable.getNumericPredictors());
        Map<String, MethodDeclaration> predictorTermMethods = addPredictorTerms(regressionTable.getPredictorTerms(), tableTemplate);
        BlockStmt constructorBody = constructor.getBody();
        // The categorical step also appends the per-field coefficient-map declarations to the body,
        // so it must run before the function maps are populated below.
        Map<String, Expression> categoricalExpressions = createCategoricalPredictorsExpressions(regressionTable.getCategoricalPredictors(), constructorBody);
        CommonCodegenUtils.addMapPopulationExpressions(numericExpressions, constructorBody, "numericFunctionMap");
        CommonCodegenUtils.addMapPopulationExpressions(categoricalExpressions, constructorBody, "categoricalFunctionMap");
        CommonCodegenUtils.addMapPopulation(predictorTermMethods, constructorBody, "predictorTermsFunctionMap");
        return new AbstractMap.SimpleEntry<>(JavaParserUtils.getFullClassName(compilationUnit), compilationUnit.toString());
    }

    /**
     * Renames the constructor and sets the values assigned in its body.
     *
     * @param regressionTable        source of the intercept
     * @param constructorDeclaration constructor to modify
     * @param simpleName             name to give to the constructor
     * @param str                    target field name
     * @param obj                    target category
     * @param normalizationMethod    normalization method used to pick the result updater
     */
    static void setConstructor(RegressionTable regressionTable, ConstructorDeclaration constructorDeclaration, SimpleName simpleName, String str, Object obj, RegressionModel.NormalizationMethod normalizationMethod) {
        constructorDeclaration.setName(simpleName);
        BlockStmt body = constructorDeclaration.getBody();
        CommonCodegenUtils.setAssignExpressionValue(body, "intercept", new DoubleLiteralExpr(String.valueOf(regressionTable.getIntercept().doubleValue())));
        CommonCodegenUtils.setAssignExpressionValue(body, "targetField", new StringLiteralExpr(str));
        CommonCodegenUtils.setAssignExpressionValue(body, "targetCategory", CommonCodegenUtils.getExpressionForObject(obj));
        CommonCodegenUtils.setAssignExpressionValue(body, "resultUpdater", createResultUpdaterExpression(normalizationMethod));
    }

    /**
     * Creates the expression assigned to {@code resultUpdater}.
     *
     * @param normalizationMethod the normalization method
     * @return a {@code null} literal for the methods listed in {@link #UNSUPPORTED_NORMALIZATION_METHODS},
     *         otherwise a method reference to the matching {@code update<METHOD>Result} method
     */
    static Expression createResultUpdaterExpression(RegressionModel.NormalizationMethod normalizationMethod) {
        if (UNSUPPORTED_NORMALIZATION_METHODS.contains(normalizationMethod)) {
            return new NullLiteralExpr();
        }
        return createResultUpdaterSupportedExpression(normalizationMethod);
    }

    /**
     * Creates the method reference
     * {@code ((SerializableFunction<Double, Double>) this)::update<METHOD>Result}.
     *
     * @param normalizationMethod the normalization method whose name is embedded in the identifier
     * @return the method reference expression
     */
    static MethodReferenceExpr createResultUpdaterSupportedExpression(RegressionModel.NormalizationMethod normalizationMethod) {
        String methodName = String.format("update%sResult", normalizationMethod.name());
        CastExpr scope = new CastExpr();
        String doubleName = Double.class.getSimpleName();
        scope.setType(CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(SerializableFunction.class.getCanonicalName(), Arrays.asList(doubleName, doubleName)));
        scope.setExpression(new ThisExpr());
        MethodReferenceExpr toReturn = new MethodReferenceExpr();
        toReturn.setScope(scope);
        toReturn.setIdentifier(methodName);
        return toReturn;
    }

    /**
     * Creates one evaluation expression per numeric predictor, keyed by predictor name.
     *
     * @param list the numeric predictors
     * @return map from predictor name to its evaluation expression
     */
    static Map<String, Expression> createNumericPredictorsExpressions(List<NumericPredictor> list) {
        return list.stream()
                .collect(Collectors.toMap(numericPredictor -> numericPredictor.getName().getValue(),
                                          KiePMMLRegressionTableRegressionFactory::createNumericPredictorExpression));
    }

    /**
     * Creates the lambda (cast to {@code SerializableFunction<Double, Double>}) evaluating a numeric
     * predictor. The exponent is passed to the generated call only when it differs from {@code 1}.
     *
     * @param numericPredictor the predictor providing coefficient and exponent
     * @return the cast lambda expression
     */
    static CastExpr createNumericPredictorExpression(NumericPredictor numericPredictor) {
        boolean withExponent = !Objects.equals(1, numericPredictor.getExponent());
        MethodCallExpr evaluateCall = new MethodCallExpr();
        evaluateCall.setName(withExponent ? "evaluateNumericWithExponent" : "evaluateNumericWithoutExponent");
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(new NameExpr("input"));
        arguments.add(CommonCodegenUtils.getExpressionForObject(Double.valueOf(numericPredictor.getCoefficient().doubleValue())));
        if (withExponent) {
            arguments.add(CommonCodegenUtils.getExpressionForObject(Double.valueOf(numericPredictor.getExponent().doubleValue())));
        }
        evaluateCall.setArguments(arguments);
        return createInputLambdaCast(evaluateCall, Double.class.getSimpleName());
    }

    /**
     * Creates one evaluation expression per categorical field. Predictors are grouped by field name;
     * for each group a coefficient map is declared and populated inside {@code blockStmt}.
     *
     * @param list      the categorical predictors
     * @param blockStmt body receiving the coefficient-map declarations
     * @return map from field name to its evaluation expression
     */
    static Map<String, Expression> createCategoricalPredictorsExpressions(List<CategoricalPredictor> list, BlockStmt blockStmt) {
        Map<String, List<CategoricalPredictor>> predictorsByField = list.stream()
                .collect(Collectors.groupingBy(categoricalPredictor -> categoricalPredictor.getField().getValue()));
        Map<String, Expression> toReturn = new HashMap<>();
        for (Map.Entry<String, List<CategoricalPredictor>> group : predictorsByField.entrySet()) {
            String mapVariableName = KiePMMLModelUtils.getSanitizedVariableName(String.format("%sMap", group.getKey()));
            populateWithGroupedCategoricalPredictorMap(group.getValue(), blockStmt, mapVariableName);
            toReturn.put(group.getKey(), createCategoricalPredictorExpression(mapVariableName));
        }
        return toReturn;
    }

    /**
     * Appends to {@code blockStmt} the declaration {@code Map<String, Double> <str> = new HashMap<String, Double>()}
     * followed by the statements populating it with one value-to-coefficient pair per predictor.
     *
     * @param list      predictors sharing the same field
     * @param blockStmt body to extend
     * @param str       name of the declared map variable
     */
    static void populateWithGroupedCategoricalPredictorMap(List<CategoricalPredictor> list, BlockStmt blockStmt, String str) {
        String keyTypeName = String.class.getSimpleName();
        String valueTypeName = Double.class.getSimpleName();
        VariableDeclarator mapDeclarator = new VariableDeclarator(CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(Map.class.getName(), Arrays.asList(keyTypeName, valueTypeName)), str);
        ObjectCreationExpr mapCreation = new ObjectCreationExpr();
        mapCreation.setType(CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(HashMap.class.getName(), Arrays.asList(keyTypeName, valueTypeName)));
        mapDeclarator.setInitializer(mapCreation);
        blockStmt.addStatement(new VariableDeclarationExpr(mapDeclarator));
        // Insertion-ordered so the population statements follow the predictors' order.
        Map<String, Expression> coefficientsByValue = new LinkedHashMap<>();
        for (CategoricalPredictor categoricalPredictor : list) {
            coefficientsByValue.put(categoricalPredictor.getValue().toString(), CommonCodegenUtils.getExpressionForObject(Double.valueOf(categoricalPredictor.getCoefficient().doubleValue())));
        }
        CommonCodegenUtils.addMapPopulationExpressions(coefficientsByValue, blockStmt, str);
    }

    /**
     * Creates the lambda (cast to {@code SerializableFunction<String, Double>}) invoking
     * {@code evaluateCategoricalPredictor(input, <str>)}.
     *
     * @param str name of the coefficient-map variable passed as second argument
     * @return the cast lambda expression
     */
    static CastExpr createCategoricalPredictorExpression(String str) {
        MethodCallExpr evaluateCall = new MethodCallExpr();
        evaluateCall.setName("evaluateCategoricalPredictor");
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(new NameExpr("input"));
        arguments.add(new NameExpr(str));
        evaluateCall.setArguments(arguments);
        return createInputLambdaCast(evaluateCall, String.class.getSimpleName());
    }

    /**
     * Adds one evaluation method per predictor term to the given class.
     * <p>
     * Terms are numbered from 1 in encounter order; the number is used both for the generated method
     * name and for the fallback key {@code predictorTerm<N>} of unnamed terms. The counter is local to
     * the invocation, so separate invocations cannot influence each other's numbering.
     *
     * @param list                        the predictor terms
     * @param classOrInterfaceDeclaration class receiving the generated methods
     * @return map from term name (or fallback key) to the added method
     */
    static Map<String, MethodDeclaration> addPredictorTerms(List<PredictorTerm> list, ClassOrInterfaceDeclaration classOrInterfaceDeclaration) {
        AtomicInteger arityCounter = new AtomicInteger(0);
        return list.stream()
                .map(predictorTerm -> {
                    int arity = arityCounter.incrementAndGet();
                    String key = predictorTerm.getName() != null ? predictorTerm.getName().getValue() : "predictorTerm" + arity;
                    return new AbstractMap.SimpleEntry<>(key, addPredictorTerm(predictorTerm, classOrInterfaceDeclaration, arity));
                })
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Adds to the given class the method {@code evaluatePredictorTerm<i>}, built from the
     * {@code evaluatePredictor} method of the {@value #KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA} template
     * with its {@code fieldRefs} and {@code coefficient} variables initialised from the term.
     *
     * @param predictorTerm               the term to translate
     * @param classOrInterfaceDeclaration class receiving the generated method
     * @param i                           suffix of the generated method name
     * @return the added method
     * @throws KiePMMLInternalException wrapping any failure raised while building the method
     */
    static MethodDeclaration addPredictorTerm(PredictorTerm predictorTerm, ClassOrInterfaceDeclaration classOrInterfaceDeclaration, int i) {
        try {
            // Work on a copy so that the parsed template itself is never modified.
            CompilationUnit evaluateTemplate = JavaParserUtils.getFromFileName(KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA);
            CompilationUnit evaluateClone = evaluateTemplate.clone();
            ClassOrInterfaceDeclaration evaluateClass = evaluateClone.getClassByName(KIE_PMML_EVALUATE_METHOD_TEMPLATE)
                    .orElseThrow(() -> new RuntimeException(MAIN_CLASS_NOT_FOUND));
            MethodDeclaration methodTemplate = evaluateClass.getMethodsByName("evaluatePredictor").get(0);
            BlockStmt methodBody = methodTemplate.getBody()
                    .orElseThrow(() -> new KiePMMLInternalException(String.format(Constants.MISSING_BODY_TEMPLATE, methodTemplate.getName())));
            VariableDeclarator fieldRefsDeclarator = CommonCodegenUtils.getVariableDeclarator(methodBody, "fieldRefs")
                    .orElseThrow(() -> new KiePMMLInternalException(String.format(Constants.MISSING_VARIABLE_IN_BODY, "fieldRefs", methodBody)));
            List<Expression> fieldNames = predictorTerm.getFieldRefs().stream()
                    .map(fieldRef -> new StringLiteralExpr(fieldRef.getField().getValue()))
                    .collect(Collectors.toList());
            fieldRefsDeclarator.setInitializer(new MethodCallExpr(new NameExpr("Arrays"), Constants.AS_LIST, NodeList.nodeList(fieldNames)));
            VariableDeclarator coefficientDeclarator = CommonCodegenUtils.getVariableDeclarator(methodBody, COEFFICIENT)
                    .orElseThrow(() -> new KiePMMLInternalException(String.format(Constants.MISSING_VARIABLE_IN_BODY, COEFFICIENT, methodBody)));
            coefficientDeclarator.setInitializer(String.valueOf(predictorTerm.getCoefficient().doubleValue()));
            return CommonCodegenUtils.addMethod(methodTemplate, classOrInterfaceDeclaration, "evaluatePredictorTerm" + i);
        } catch (Exception e) {
            throw new KiePMMLInternalException(String.format("Failed to add PredictorTerm %s", predictorTerm), e);
        }
    }

    /**
     * Wraps the given call into {@code (SerializableFunction<inputTypeName, Double>) input -> call}.
     *
     * @param call          the method call forming the lambda body
     * @param inputTypeName simple name of the lambda's input type
     * @return the cast lambda expression
     */
    private static CastExpr createInputLambdaCast(MethodCallExpr call, String inputTypeName) {
        LambdaExpr lambda = new LambdaExpr();
        lambda.setParameters(NodeList.nodeList(new Parameter(new UnknownType(), "input")));
        lambda.setBody(new ExpressionStmt(call));
        ClassOrInterfaceType functionType = CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(SerializableFunction.class.getCanonicalName(), Arrays.asList(inputTypeName, Double.class.getSimpleName()));
        CastExpr toReturn = new CastExpr();
        toReturn.setType(functionType);
        toReturn.setExpression(lambda);
        return toReturn;
    }
}
