/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.expr.AssignExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.ExplicitConstructorInvocationStmt;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.regression.RegressionModel;
import org.kie.memorycompiler.KieMemoryCompiler;
import org.kie.pmml.commons.exceptions.KiePMMLInternalException;
import org.kie.pmml.commons.model.KiePMMLOutputField;
import org.kie.pmml.commons.model.enums.MINING_FUNCTION;
import org.kie.pmml.commons.model.enums.PMML_MODEL;
import org.kie.pmml.commons.utils.KiePMMLModelUtils;
import org.kie.pmml.compiler.commons.factories.KiePMMLOutputFieldFactory;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.compiler.commons.utils.KiePMMLModelFactoryUtils;
import org.kie.pmml.compiler.commons.utils.ModelUtils;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableClassificationFactory;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableRegressionFactory;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Generates the Java sources of a {@link KiePMMLRegressionModel} (and, on request, compiles and
 * instantiates it) from a PMML {@link RegressionModel}.
 * <p>
 * The model class is produced by cloning the {@value #KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA}
 * template, renaming it after the sanitized model name and populating its default constructor.
 * The regression tables referenced by the model are generated as separate sources.
 */
public class KiePMMLRegressionModelFactory {
    private static final Logger logger = LoggerFactory.getLogger(KiePMMLRegressionModelFactory.class.getName());
    static final String KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA = "KiePMMLRegressionModelTemplate.tmpl";
    static final String KIE_PMML_REGRESSION_MODEL_TEMPLATE = "KiePMMLRegressionModelTemplate";

    private KiePMMLRegressionModelFactory() {
        // Static factory: not meant to be instantiated.
    }

    /**
     * Generates, compiles (with the current thread's context class loader) and instantiates the
     * {@link KiePMMLRegressionModel} for the given PMML model.
     *
     * @param dataDictionary           the PMML data dictionary
     * @param transformationDictionary the PMML transformation dictionary
     * @param model                    the PMML regression model
     * @return a new instance of the generated model class
     * @throws IOException               if the class template cannot be read
     * @throws IllegalAccessException    if the generated class cannot be instantiated
     * @throws InstantiationException    if the generated class cannot be instantiated
     * @throws KiePMMLInternalException  if the compiled sources do not contain the expected class
     */
    public static KiePMMLRegressionModel getKiePMMLRegressionModelClasses(DataDictionary dataDictionary, TransformationDictionary transformationDictionary, RegressionModel model) throws IOException, IllegalAccessException, InstantiationException {
        logger.trace("getKiePMMLRegressionModelClasses {} {}", dataDictionary, model);
        String className = KiePMMLModelUtils.getSanitizedClassName(model.getModelName());
        String packageName = KiePMMLModelUtils.getSanitizedPackageName(model.getModelName());
        Map<String, String> sourcesMap = getKiePMMLRegressionModelSourcesMap(dataDictionary, transformationDictionary, model, packageName);
        String fullClassName = packageName + "." + className;
        Map<String, Class<?>> compiledClasses = KieMemoryCompiler.compile(sourcesMap, Thread.currentThread().getContextClassLoader());
        Class<?> modelClass = compiledClasses.get(fullClassName);
        if (modelClass == null) {
            // Fail with a meaningful message instead of a bare NullPointerException.
            throw new KiePMMLInternalException(String.format("Compiled class %s not found", fullClassName));
        }
        return (KiePMMLRegressionModel) modelClass.newInstance();
    }

    /**
     * Generates the sources of the model class and of all its regression tables.
     *
     * @param dataDictionary           the PMML data dictionary
     * @param transformationDictionary the PMML transformation dictionary
     * @param model                    the PMML regression model
     * @param packageName              the package of the generated classes
     * @return a mutable map from fully-qualified class name to Java source
     * @throws IOException if the class template cannot be read
     */
    public static Map<String, String> getKiePMMLRegressionModelSourcesMap(DataDictionary dataDictionary, TransformationDictionary transformationDictionary, RegressionModel model, String packageName) throws IOException {
        logger.trace("getKiePMMLRegressionModelSourcesMap {} {} {}", dataDictionary, model, packageName);
        String className = KiePMMLModelUtils.getSanitizedClassName(model.getModelName());
        String modelName = model.getModelName();
        // May legitimately be null: the model could have no target field.
        String targetFieldName = ModelUtils.getTargetFieldName(dataDictionary, model).orElse(null);
        List<KiePMMLOutputField> outputFields = KiePMMLOutputFieldFactory.getOutputFields(model);
        Map<String, KiePMMLTableSourceCategory> tablesSourceMap = getRegressionTablesMap(dataDictionary, model, targetFieldName, outputFields, packageName);
        CompilationUnit templateCU = JavaParserUtils.getFromFileName(KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA);
        CompilationUnit cloneCU = templateCU.clone();
        cloneCU.setPackageDeclaration(packageName);
        ClassOrInterfaceDeclaration modelTemplate = cloneCU.getClassByName(KIE_PMML_REGRESSION_MODEL_TEMPLATE)
                .orElseThrow(() -> new RuntimeException("Main class not found"));
        modelTemplate.setName(className);
        String nestedTable = getNestedTableName(tablesSourceMap);
        ConstructorDeclaration constructorDeclaration = modelTemplate.getDefaultConstructor()
                .orElseThrow(() -> new KiePMMLInternalException(String.format("Missing default constructor in ClassOrInterfaceDeclaration %s ", modelTemplate.getName())));
        populateConstructor(className, nestedTable, constructorDeclaration, targetFieldName, MINING_FUNCTION.byName(model.getMiningFunction().value()), modelName);
        KiePMMLModelFactoryUtils.addTransformationsInClassOrInterfaceDeclaration(modelTemplate, transformationDictionary, model.getLocalTransformations());
        // Collectors.toMap gives no mutability guarantee, so build an explicit HashMap
        // because the model source is added to it afterwards.
        Map<String, String> toReturn = new HashMap<>();
        tablesSourceMap.forEach((tableName, tableSource) -> toReturn.put(packageName + "." + tableName, tableSource.getSource()));
        String fullClassName = packageName + "." + className;
        toReturn.put(fullClassName, cloneCU.toString());
        return toReturn;
    }

    /**
     * Selects the regression table the model instantiates directly. This is the only table when
     * there is exactly one. Otherwise it is the first one whose name starts with
     * {@code KiePMMLRegressionTableClassification}.
     */
    private static String getNestedTableName(Map<String, KiePMMLTableSourceCategory> tablesSourceMap) {
        if (tablesSourceMap.size() == 1) {
            return tablesSourceMap.keySet().iterator().next();
        }
        return tablesSourceMap.keySet().stream()
                .filter(tableName -> tableName.startsWith("KiePMMLRegressionTableClassification"))
                .findFirst()
                .orElseThrow(() -> new RuntimeException("Failed to find expected KiePMMLRegressionTableClassification"));
    }

    /**
     * Generates the regression-table sources for the model.
     * <p>
     * When {@link #isRegression} holds, only the first PMML regression table is passed to
     * {@link KiePMMLRegressionTableRegressionFactory}. Otherwise all tables are passed to
     * {@link KiePMMLRegressionTableClassificationFactory}.
     *
     * @return map from generated table class simple name to its source/category
     * @throws IOException              propagated from the table factories
     * @throws KiePMMLInternalException if a regression model declares no regression table
     */
    static Map<String, KiePMMLTableSourceCategory> getRegressionTablesMap(DataDictionary dataDictionary, RegressionModel model, String targetFieldName, List<KiePMMLOutputField> outputFields, String packageName) throws IOException {
        DataField targetDataField = dataDictionary.getDataFields().stream()
                .filter(field -> Objects.equals(targetFieldName, field.getName().getValue()))
                .findFirst()
                .orElse(null);
        OpType opType = targetDataField != null ? targetDataField.getOpType() : null;
        if (isRegression(model.getMiningFunction(), targetFieldName, opType)) {
            if (model.getRegressionTables().isEmpty()) {
                throw new KiePMMLInternalException(String.format("No RegressionTable found in model %s", model.getModelName()));
            }
            return KiePMMLRegressionTableRegressionFactory.getRegressionTables(Collections.singletonList(model.getRegressionTables().get(0)), model.getNormalizationMethod(), targetFieldName, packageName);
        }
        return KiePMMLRegressionTableClassificationFactory.getRegressionTables(model.getRegressionTables(), model.getNormalizationMethod(), opType, outputFields, targetFieldName, packageName);
    }

    /**
     * Customizes the template's default constructor for the generated class.
     * <p>
     * It renames the constructor and sets the first argument of the explicit {@code super(...)}/
     * {@code this(...)} call to the quoted model name. It then rewrites the value of the
     * assignments to {@code regressionTable}, {@code targetField}, {@code miningFunction} and
     * {@code pmmlMODEL}. Assignments whose target is not a simple name are left untouched.
     *
     * @param targetField the target field name; a {@code null} value generates a {@code null} literal
     */
    static void populateConstructor(String generatedClassName, String nestedTable, ConstructorDeclaration constructorDeclaration, String targetField, MINING_FUNCTION miningFunction, String modelName) {
        ObjectCreationExpr objectCreationExpr = new ObjectCreationExpr();
        objectCreationExpr.setType(nestedTable);
        constructorDeclaration.setName(generatedClassName);
        BlockStmt body = constructorDeclaration.getBody();
        body.getStatements().forEach(statement -> {
            if (statement instanceof ExplicitConstructorInvocationStmt) {
                ExplicitConstructorInvocationStmt superStatement = (ExplicitConstructorInvocationStmt) statement;
                NameExpr modelNameExpr = (NameExpr) superStatement.getArgument(0);
                modelNameExpr.setName(String.format("\"%s\"", modelName));
            }
        });
        // new StringLiteralExpr(null) fails, and a missing target field is a legal case.
        Expression targetFieldExpr = targetField != null ? new StringLiteralExpr(targetField) : new NullLiteralExpr();
        List<AssignExpr> assignExprs = body.findAll(AssignExpr.class);
        for (AssignExpr assignExpr : assignExprs) {
            if (!assignExpr.getTarget().isNameExpr()) {
                // e.g. field-access or array targets: not one of the placeholders we rewrite.
                continue;
            }
            switch (assignExpr.getTarget().asNameExpr().getNameAsString()) {
                case "regressionTable":
                    assignExpr.setValue(objectCreationExpr);
                    break;
                case "targetField":
                    assignExpr.setValue(targetFieldExpr);
                    break;
                case "miningFunction":
                    assignExpr.setValue(new NameExpr(getEnumConstantReference(miningFunction)));
                    break;
                case "pmmlMODEL":
                    assignExpr.setValue(new NameExpr(getEnumConstantReference(PMML_MODEL.REGRESSION_MODEL)));
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Returns the source-level reference of an enum constant, e.g. {@code pkg.Enum.CONSTANT}.
     * {@code getDeclaringClass()} is used instead of {@code getClass()}, which would return an
     * anonymous subclass for constants with a body. The canonical name is valid in source even
     * for nested enums.
     */
    private static String getEnumConstantReference(Enum<?> constant) {
        return constant.getDeclaringClass().getCanonicalName() + "." + constant.name();
    }

    /**
     * Tells whether the model must be generated as a regression (rather than classification) one:
     * the mining function is {@code REGRESSION} and either there is no target field or the target
     * field is {@code CONTINUOUS}.
     */
    static boolean isRegression(MiningFunction miningFunction, String targetField, OpType targetOpType) {
        return Objects.equals(MiningFunction.REGRESSION, miningFunction) && (targetField == null || Objects.equals(OpType.CONTINUOUS, targetOpType));
    }
}

