package org.kie.pmml.models.clustering.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.expr.BooleanLiteralExpr;
import com.github.javaparser.ast.expr.DoubleLiteralExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dmg.pmml.Array;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringField;
import org.dmg.pmml.clustering.ClusteringModel;
import org.dmg.pmml.clustering.MissingValueWeights;
import org.jboss.forge.roaster._shade.org.eclipse.jdt.internal.core.JavadocConstants;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.commons.Constants;
import org.kie.pmml.commons.model.HasClassLoader;
import org.kie.pmml.commons.utils.KiePMMLModelUtils;
import org.kie.pmml.compiler.commons.builders.KiePMMLModelCodegenUtils;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.clustering.model.KiePMMLCluster;
import org.kie.pmml.models.clustering.model.KiePMMLClusteringField;
import org.kie.pmml.models.clustering.model.KiePMMLClusteringModel;
import org.kie.pmml.models.clustering.model.KiePMMLComparisonMeasure;
import org.kie.pmml.models.clustering.model.KiePMMLMissingValueWeights;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/kie-pmml-models-clustering-compiler-7.59.1-SNAPSHOT.jar:org/kie/pmml/models/clustering/compiler/factories/KiePMMLClusteringModelFactory.class */
/**
 * Builds a {@link KiePMMLClusteringModel} from a PMML {@link ClusteringModel} by code generation.
 * <p>
 * A Java source template is loaded, its default constructor is populated with statements derived
 * from the PMML model, and the resulting source is compiled and instantiated through a
 * {@link HasClassLoader}.
 */
public class KiePMMLClusteringModelFactory {

    static final String KIE_PMML_CLUSTERING_MODEL_TEMPLATE_JAVA = "KiePMMLClusteringModelTemplate.tmpl";
    static final String KIE_PMML_CLUSTERING_MODEL_TEMPLATE = "KiePMMLClusteringModelTemplate";
    private static final Logger logger = LoggerFactory.getLogger(KiePMMLClusteringModelFactory.class.getName());

    private KiePMMLClusteringModelFactory() {
        // static factory: not instantiable
    }

    /**
     * Generates, compiles and instantiates the Java class for the given clustering model.
     *
     * @param dataDictionary           the PMML data dictionary
     * @param transformationDictionary the PMML transformation dictionary
     * @param clusteringModel          the PMML clustering model to translate
     * @param packageName              package of the generated class (prefixed to the sanitized model name)
     * @param hasClassLoader           used to compile and load the generated source
     * @return a new instance of the generated model class, created via its no-arg constructor
     * @throws KiePMMLException wrapping any exception raised while generating, compiling or instantiating
     */
    public static KiePMMLClusteringModel getKiePMMLClusteringModel(DataDictionary dataDictionary,
                                                                   TransformationDictionary transformationDictionary,
                                                                   ClusteringModel clusteringModel,
                                                                   String packageName,
                                                                   HasClassLoader hasClassLoader) {
        logger.trace("getKiePMMLClusteringModel {} {}", dataDictionary, clusteringModel);
        try {
            Map<String, String> sourcesMap = getKiePMMLClusteringModelSourcesMap(dataDictionary,
                                                                                 transformationDictionary,
                                                                                 clusteringModel,
                                                                                 packageName);
            String fullClassName = packageName + "." + KiePMMLModelUtils.getSanitizedClassName(clusteringModel.getModelName());
            Class<?> modelClass = hasClassLoader.compileAndLoadClass(sourcesMap, fullClassName);
            return (KiePMMLClusteringModel) modelClass.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new KiePMMLException(e);
        }
    }

    /**
     * Generates the Java source of the model class without compiling it.
     *
     * @param dataDictionary           the PMML data dictionary
     * @param transformationDictionary the PMML transformation dictionary
     * @param clusteringModel          the PMML clustering model to translate
     * @param packageName              package of the generated class
     * @return a single-entry map from the generated class' full name to its source code
     * @throws KiePMMLException if the template does not contain the expected main class
     */
    public static Map<String, String> getKiePMMLClusteringModelSourcesMap(DataDictionary dataDictionary,
                                                                          TransformationDictionary transformationDictionary,
                                                                          ClusteringModel clusteringModel,
                                                                          String packageName) {
        logger.trace("getKiePMMLClusteringModelSourcesMap {} {} {}", dataDictionary, clusteringModel, packageName);
        String className = KiePMMLModelUtils.getSanitizedClassName(clusteringModel.getModelName());
        CompilationUnit compilationUnit = JavaParserUtils.getKiePMMLModelCompilationUnit(className,
                                                                                         packageName,
                                                                                         KIE_PMML_CLUSTERING_MODEL_TEMPLATE_JAVA,
                                                                                         KIE_PMML_CLUSTERING_MODEL_TEMPLATE);
        ClassOrInterfaceDeclaration modelTemplate = compilationUnit.getClassByName(className)
                .orElseThrow(() -> new KiePMMLException("Main class not found: " + className));
        setConstructor(clusteringModel, dataDictionary, transformationDictionary, modelTemplate);
        Map<String, String> sourcesMap = new HashMap<>();
        sourcesMap.put(JavaParserUtils.getFullClassName(compilationUnit), compilationUnit.toString());
        return sourcesMap;
    }

    /**
     * Populates the template's default constructor. It first calls
     * {@link KiePMMLModelCodegenUtils#init}, then appends statements that assign {@code modelClass},
     * call {@code clusters.add(...)} once per PMML cluster and {@code clusteringFields.add(...)} once
     * per clustering field, assign {@code comparisonMeasure} and, only when the PMML model defines
     * them, assign {@code missingValueWeights}.
     *
     * @throws KiePMMLInternalException if the template class has no default constructor
     */
    static void setConstructor(ClusteringModel clusteringModel,
                               DataDictionary dataDictionary,
                               TransformationDictionary transformationDictionary,
                               ClassOrInterfaceDeclaration classOrInterfaceDeclaration) {
        KiePMMLModelCodegenUtils.init(classOrInterfaceDeclaration, dataDictionary, transformationDictionary, clusteringModel);
        final BlockStmt body = classOrInterfaceDeclaration.getDefaultConstructor()
                .orElseThrow(() -> new KiePMMLInternalException(String.format(Constants.MISSING_DEFAULT_CONSTRUCTOR,
                                                                              classOrInterfaceDeclaration.getName())))
                .getBody();

        body.addStatement(CommonCodegenUtils.assignExprFrom("modelClass",
                                                            KiePMMLClusteringConversionUtils.modelClassFrom(clusteringModel.getModelClass())));

        clusteringModel.getClusters().stream()
                .map(KiePMMLClusteringModelFactory::clusterCreationExprFrom)
                .map(creationExpr -> CommonCodegenUtils.methodCallExprFrom("clusters", "add", creationExpr))
                .forEach(body::addStatement);

        clusteringModel.getClusteringFields().stream()
                .map(KiePMMLClusteringModelFactory::clusteringFieldCreationExprFrom)
                .map(creationExpr -> CommonCodegenUtils.methodCallExprFrom("clusteringFields", "add", creationExpr))
                .forEach(body::addStatement);

        body.addStatement(CommonCodegenUtils.assignExprFrom("comparisonMeasure",
                                                            comparisonMeasureCreationExprFrom(clusteringModel.getComparisonMeasure())));

        if (clusteringModel.getMissingValueWeights() != null) {
            body.addStatement(CommonCodegenUtils.assignExprFrom("missingValueWeights",
                                                                missingValueWeightsCreationExprFrom(clusteringModel.getMissingValueWeights())));
        }
    }

    /**
     * Builds {@code new KiePMMLCluster(id, name, v1, v2, ...)} for the given PMML cluster.
     * The trailing double values come from the cluster's array only when that array has type "real"
     * and parses completely; otherwise no values are appended.
     */
    private static ObjectCreationExpr clusterCreationExprFrom(Cluster cluster) {
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(CommonCodegenUtils.literalExprFrom(cluster.getId()));
        arguments.add(CommonCodegenUtils.literalExprFrom(cluster.getName()));
        arguments.addAll(realArrayLiteralsFrom(cluster.getArray(), "cluster"));
        return new ObjectCreationExpr(null, new ClassOrInterfaceType(null, KiePMMLCluster.class.getCanonicalName()), arguments);
    }

    /**
     * Builds {@code new KiePMMLClusteringField(field, weight, isCenter, compareFunction, null)}.
     * The field weight defaults to {@code 1.0} and the center flag to {@code true} when absent;
     * a missing compare function is generated as {@code null}.
     */
    private static ObjectCreationExpr clusteringFieldCreationExprFrom(ClusteringField clusteringField) {
        double fieldWeight = clusteringField.getFieldWeight() == null ? 1.0d : clusteringField.getFieldWeight().doubleValue();
        boolean isCenterField = clusteringField.getCenterField() == null
                || clusteringField.getCenterField() == ClusteringField.CenterField.TRUE;

        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(CommonCodegenUtils.literalExprFrom(clusteringField.getField().getValue()));
        arguments.add(new DoubleLiteralExpr(fieldWeight));
        arguments.add(new BooleanLiteralExpr(isCenterField));
        if (clusteringField.getCompareFunction() == null) {
            arguments.add(new NullLiteralExpr());
        } else {
            arguments.add(CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.compareFunctionFrom(clusteringField.getCompareFunction())));
        }
        // NOTE(review): the fifth constructor argument is always generated as null here;
        // confirm against KiePMMLClusteringField which optional attribute it represents.
        arguments.add(new NullLiteralExpr());
        return new ObjectCreationExpr(null, new ClassOrInterfaceType(null, KiePMMLClusteringField.class.getCanonicalName()), arguments);
    }

    /**
     * Builds {@code new KiePMMLComparisonMeasure(kind, aggregateFunction, compareFunction)} from the PMML
     * comparison measure, converting each PMML enum through {@link KiePMMLClusteringConversionUtils}.
     */
    private static ObjectCreationExpr comparisonMeasureCreationExprFrom(ComparisonMeasure comparisonMeasure) {
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.comparisonMeasureKindFrom(comparisonMeasure.getKind())));
        arguments.add(CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.aggregateFunctionFrom(comparisonMeasure.getMeasure())));
        arguments.add(CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.compareFunctionFrom(comparisonMeasure.getCompareFunction())));
        return new ObjectCreationExpr(null, new ClassOrInterfaceType(null, KiePMMLComparisonMeasure.class.getCanonicalName()), arguments);
    }

    /**
     * Builds {@code new KiePMMLMissingValueWeights(v1, v2, ...)}; values are taken from the PMML array
     * only when it has type "real" and parses completely, otherwise the constructor gets no arguments.
     */
    private static ObjectCreationExpr missingValueWeightsCreationExprFrom(MissingValueWeights missingValueWeights) {
        NodeList<Expression> arguments = new NodeList<>();
        arguments.addAll(realArrayLiteralsFrom(missingValueWeights.getArray(), "missing value weights"));
        return new ObjectCreationExpr(null, new ClassOrInterfaceType(null, KiePMMLMissingValueWeights.class.getCanonicalName()), arguments);
    }

    /**
     * Parses a whitespace-separated PMML "real" array into double literal expressions.
     * <p>
     * Parsing is all-or-nothing: if any token is not a valid double the error is logged and an
     * empty list is returned, so callers never emit a constructor call with a partial value list.
     * Arrays that are {@code null}, have no value, or are not of type "real" also yield an empty list.
     *
     * @param array       the PMML array, may be {@code null}
     * @param description what the array belongs to, used only in the log message
     * @return the parsed literals, in array order
     */
    private static List<DoubleLiteralExpr> realArrayLiteralsFrom(Array array, String description) {
        if (array == null || array.getType() != Array.Type.REAL || array.getValue() == null) {
            return Collections.emptyList();
        }
        String value = (String) array.getValue();
        try {
            // PMML array text may contain runs of spaces, tabs or newlines between values
            return Arrays.stream(value.trim().split("\\s+"))
                    .mapToDouble(Double::parseDouble)
                    .mapToObj(DoubleLiteralExpr::new)
                    .collect(Collectors.toList());
        } catch (NumberFormatException e) {
            logger.error("Can't parse \"real\" {} with value \"{}\"", description, value, e);
            return Collections.emptyList();
        }
    }
}
