package org.kie.pmml.models.clustering.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.expr.BooleanLiteralExpr;
import com.github.javaparser.ast.expr.DoubleLiteralExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.dmg.pmml.Array;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringField;
import org.dmg.pmml.clustering.ClusteringModel;
import org.dmg.pmml.clustering.MissingValueWeights;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.commons.Constants;
import org.kie.pmml.compiler.api.dto.CompilationDTO;
import org.kie.pmml.compiler.commons.codegenfactories.KiePMMLModelFactoryUtils;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.clustering.compiler.dto.ClusteringCompilationDTO;
import org.kie.pmml.models.clustering.model.KiePMMLCluster;
import org.kie.pmml.models.clustering.model.KiePMMLClusteringField;
import org.kie.pmml.models.clustering.model.KiePMMLClusteringModel;
import org.kie.pmml.models.clustering.model.KiePMMLComparisonMeasure;
import org.kie.pmml.models.clustering.model.KiePMMLMissingValueWeights;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:BOOT-INF/lib/kie-pmml-models-clustering-compiler-8.40.1-SNAPSHOT.jar:org/kie/pmml/models/clustering/compiler/factories/KiePMMLClusteringModelFactory.class */
public class KiePMMLClusteringModelFactory {
    static final String KIE_PMML_CLUSTERING_MODEL_TEMPLATE_JAVA = "KiePMMLClusteringModelTemplate.tmpl";
    static final String KIE_PMML_CLUSTERING_MODEL_TEMPLATE = "KiePMMLClusteringModelTemplate";
    static final String GET_CLUSTERS = "getClusters";
    static final String GET_CLUSTERING_FIELDS = "getClusteringFields";
    private static final Logger logger = LoggerFactory.getLogger(KiePMMLClusteringModelFactory.class.getName());

    /** Weight applied to a {@link ClusteringField} that does not declare {@code fieldWeight}. */
    private static final double DEFAULT_FIELD_WEIGHT = 1.0d;

    private KiePMMLClusteringModelFactory() {
        // Static factory methods only; never instantiated.
    }

    /**
     * Builds an in-memory {@link KiePMMLClusteringModel} from the PMML model held by the given DTO.
     *
     * @param clusteringCompilationDTO DTO wrapping the source {@link ClusteringModel} together with the
     *                                 already-converted mining/output/target/transformation data
     * @return the populated clustering model
     * @throws KiePMMLException wrapping any exception raised while converting or building the model
     */
    public static KiePMMLClusteringModel getKiePMMLClusteringModel(ClusteringCompilationDTO clusteringCompilationDTO) {
        logger.trace("getKiePMMLClusteringModel {}", clusteringCompilationDTO);
        try {
            ClusteringModel model = clusteringCompilationDTO.getModel();
            KiePMMLClusteringModel.ModelClass modelClass =
                    KiePMMLClusteringConversionUtils.modelClassFrom(model.getModelClass());
            List<KiePMMLCluster> clusters = getKiePMMLClusters(model.getClusters());
            List<KiePMMLClusteringField> clusteringFields = getKiePMMLClusteringFields(model.getClusteringFields());
            KiePMMLComparisonMeasure comparisonMeasure = getKiePMMLComparisonMeasure(model.getComparisonMeasure());
            return (KiePMMLClusteringModel) KiePMMLClusteringModel
                    .builder(clusteringCompilationDTO.getFileName(),
                             clusteringCompilationDTO.getModelName(),
                             clusteringCompilationDTO.getMINING_FUNCTION())
                    .withModelClass(modelClass)
                    .withClusters(clusters)
                    .withClusteringFields(clusteringFields)
                    .withComparisonMeasure(comparisonMeasure)
                    .withMissingValueWeights(getKiePMMLMissingValueWeights(model.getMissingValueWeights()))
                    // NOTE(review): 'withTargetField2' may be a decompiler rename of 'withTargetField'; confirm
                    // against the builder API.
                    .withTargetField2(clusteringCompilationDTO.getTargetFieldName())
                    .withMiningFields(clusteringCompilationDTO.getKieMiningFields())
                    .withOutputFields(clusteringCompilationDTO.getKieOutputFields())
                    .withKiePMMLMiningFields(clusteringCompilationDTO.getKiePMMLMiningFields())
                    .withKiePMMLOutputFields(clusteringCompilationDTO.getKiePMMLOutputFields())
                    .withKiePMMLTargets(clusteringCompilationDTO.getKiePMMLTargetFields())
                    .withKiePMMLTransformationDictionary(clusteringCompilationDTO.getKiePMMLTransformationDictionary())
                    .withKiePMMLLocalTransformations(clusteringCompilationDTO.getKiePMMLLocalTransformations())
                    .build();
        } catch (Exception e) {
            throw new KiePMMLException(e);
        }
    }

    /**
     * Generates the Java source of the clustering model class from the template and populates it from the DTO.
     *
     * @param clusteringCompilationDTO DTO providing the class/package names and the source {@link ClusteringModel}
     * @return a single-entry map from the generated class' fully qualified name to its source code
     * @throws KiePMMLException if the template does not contain a class named after the DTO's simple class name
     */
    public static Map<String, String> getKiePMMLClusteringModelSourcesMap(ClusteringCompilationDTO clusteringCompilationDTO) {
        logger.trace("getKiePMMLClusteringModelSourcesMap {}", clusteringCompilationDTO);
        String simpleClassName = clusteringCompilationDTO.getSimpleClassName();
        CompilationUnit compilationUnit = JavaParserUtils.getKiePMMLModelCompilationUnit(
                simpleClassName,
                clusteringCompilationDTO.getPackageName(),
                KIE_PMML_CLUSTERING_MODEL_TEMPLATE_JAVA,
                KIE_PMML_CLUSTERING_MODEL_TEMPLATE);
        ClassOrInterfaceDeclaration modelTemplate = compilationUnit.getClassByName(simpleClassName)
                .orElseThrow(() -> new KiePMMLException("Main class not found: " + simpleClassName));
        setStaticGetter(clusteringCompilationDTO, modelTemplate);
        populateGetClustersMethod(modelTemplate, clusteringCompilationDTO.getModel());
        populateGetClusteringFieldsMethod(modelTemplate, clusteringCompilationDTO.getModel());
        Map<String, String> toReturn = new HashMap<>();
        toReturn.put(JavaParserUtils.getFullClassName(compilationUnit), compilationUnit.toString());
        return toReturn;
    }

    /**
     * Converts PMML clusters into {@link KiePMMLCluster}s.
     *
     * @param clusters the PMML clusters; may be {@code null}
     * @return the converted clusters, or an empty list when {@code clusters} is {@code null}
     */
    static List<KiePMMLCluster> getKiePMMLClusters(List<Cluster> clusters) {
        if (clusters == null) {
            return Collections.emptyList();
        }
        return clusters.stream()
                .map(KiePMMLClusteringModelFactory::getKiePMMLCluster)
                .collect(Collectors.toList());
    }

    /** Converts a single PMML {@link Cluster} (id, name and "real" array values) into a {@link KiePMMLCluster}. */
    static KiePMMLCluster getKiePMMLCluster(Cluster cluster) {
        return new KiePMMLCluster(cluster.getId(), cluster.getName(), getClusterDoubleValues(cluster));
    }

    /**
     * Converts PMML clustering fields into {@link KiePMMLClusteringField}s.
     *
     * @param clusteringFields the PMML clustering fields; may be {@code null}
     * @return the converted fields, or an empty list when {@code clusteringFields} is {@code null}
     */
    static List<KiePMMLClusteringField> getKiePMMLClusteringFields(List<ClusteringField> clusteringFields) {
        if (clusteringFields == null) {
            return Collections.emptyList();
        }
        return clusteringFields.stream()
                .map(KiePMMLClusteringModelFactory::getKiePMMLClusteringField)
                .collect(Collectors.toList());
    }

    /**
     * Converts a single PMML {@link ClusteringField}.
     *
     * A missing weight defaults to {@link #DEFAULT_FIELD_WEIGHT}; a missing {@code centerField} counts as
     * a center field; a missing compare function is passed as {@code null}.
     */
    static KiePMMLClusteringField getKiePMMLClusteringField(ClusteringField clusteringField) {
        return new KiePMMLClusteringField(
                clusteringField.getField().getValue(),
                Double.valueOf(fieldWeightOf(clusteringField)),
                Boolean.valueOf(isCenterField(clusteringField)),
                clusteringField.getCompareFunction() != null
                        ? KiePMMLClusteringConversionUtils.compareFunctionFrom(clusteringField.getCompareFunction())
                        : null,
                null);
    }

    /** Converts the PMML {@link ComparisonMeasure} (kind, aggregate measure, compare function). */
    static KiePMMLComparisonMeasure getKiePMMLComparisonMeasure(ComparisonMeasure comparisonMeasure) {
        return new KiePMMLComparisonMeasure(
                KiePMMLClusteringConversionUtils.comparisonMeasureKindFrom(comparisonMeasure.getKind()),
                KiePMMLClusteringConversionUtils.aggregateFunctionFrom(comparisonMeasure.getMeasure()),
                KiePMMLClusteringConversionUtils.compareFunctionFrom(comparisonMeasure.getCompareFunction()));
    }

    /**
     * Converts optional PMML missing-value weights.
     *
     * @return the converted weights, or {@code null} when {@code missingValueWeights} is {@code null}
     */
    static KiePMMLMissingValueWeights getKiePMMLMissingValueWeights(MissingValueWeights missingValueWeights) {
        return missingValueWeights == null
                ? null
                : new KiePMMLMissingValueWeights(getMissingValueWeightsDoubleValues(missingValueWeights));
    }

    /**
     * Initialises the static model getter in the generated class.
     *
     * After initialisation, it overwrites the arguments of the {@code withModelClass},
     * {@code withComparisonMeasure} and {@code withMissingValueWeights} calls chained on the
     * {@link Constants#TO_RETURN} variable of the {@link Constants#GET_MODEL} method.
     *
     * @throws KiePMMLException if the {@code TO_RETURN} variable or its initializer is missing
     */
    static void setStaticGetter(CompilationDTO<ClusteringModel> compilationDTO, ClassOrInterfaceDeclaration classOrInterfaceDeclaration) {
        KiePMMLModelFactoryUtils.initStaticGetter(compilationDTO, classOrInterfaceDeclaration);
        BlockStmt getModelBody = CommonCodegenUtils.getMethodDeclarationBlockStmt(classOrInterfaceDeclaration, Constants.GET_MODEL);
        MethodCallExpr builderChain = CommonCodegenUtils.getVariableDeclarator(getModelBody, Constants.TO_RETURN)
                .orElseThrow(() -> new KiePMMLException(String.format("Missing expected variable '%s' in body %s", Constants.TO_RETURN, getModelBody)))
                .getInitializer()
                .orElseThrow(() -> new KiePMMLException(String.format("Missing '%s' initializer in %s", Constants.TO_RETURN, getModelBody)))
                .asMethodCallExpr();
        ClusteringModel model = compilationDTO.getModel();
        CommonCodegenUtils.getChainedMethodCallExprFrom("withModelClass", builderChain)
                .setArgument(0, CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.modelClassFrom(model.getModelClass())));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withComparisonMeasure", builderChain)
                .setArgument(0, comparisonMeasureCreationExprFrom(model.getComparisonMeasure()));
        Expression missingValueWeightsExpr = model.getMissingValueWeights() != null
                ? missingValueWeightsCreationExprFrom(model.getMissingValueWeights())
                : new NullLiteralExpr();
        CommonCodegenUtils.getChainedMethodCallExprFrom("withMissingValueWeights", builderChain)
                .setArgument(0, missingValueWeightsExpr);
    }

    /**
     * Fills the generated {@code getClusters} method with one {@link KiePMMLCluster} creation expression per
     * PMML cluster.
     *
     * @throws KiePMMLInternalException if the template lacks the {@code getClusters} method
     */
    static void populateGetClustersMethod(ClassOrInterfaceDeclaration classOrInterfaceDeclaration, ClusteringModel clusteringModel) {
        List<Expression> clusterExprs = clusteringModel.getClusters().stream()
                .map(KiePMMLClusteringModelFactory::clusterCreationExprFrom)
                .collect(Collectors.toList());
        CommonCodegenUtils.populateListInListGetter(
                clusterExprs,
                CommonCodegenUtils.getMethodDeclaration(classOrInterfaceDeclaration, GET_CLUSTERS)
                        .orElseThrow(() -> new KiePMMLInternalException(String.format(Constants.MISSING_METHOD_IN_CLASS, classOrInterfaceDeclaration, GET_CLUSTERS))),
                Constants.TO_RETURN);
    }

    /**
     * Fills the generated {@code getClusteringFields} method with one {@link KiePMMLClusteringField} creation
     * expression per PMML clustering field.
     *
     * @throws KiePMMLInternalException if the template lacks the {@code getClusteringFields} method
     */
    static void populateGetClusteringFieldsMethod(ClassOrInterfaceDeclaration classOrInterfaceDeclaration, ClusteringModel clusteringModel) {
        List<Expression> clusteringFieldExprs = clusteringModel.getClusteringFields().stream()
                .map(KiePMMLClusteringModelFactory::clusteringFieldCreationExprFrom)
                .collect(Collectors.toList());
        CommonCodegenUtils.populateListInListGetter(
                clusteringFieldExprs,
                CommonCodegenUtils.getMethodDeclaration(classOrInterfaceDeclaration, GET_CLUSTERING_FIELDS)
                        .orElseThrow(() -> new KiePMMLInternalException(String.format(Constants.MISSING_METHOD_IN_CLASS, classOrInterfaceDeclaration, GET_CLUSTERING_FIELDS))),
                Constants.TO_RETURN);
    }

    /** Builds {@code new KiePMMLCluster(id, name, Arrays.asList(values...))} for the given cluster. */
    private static ObjectCreationExpr clusterCreationExprFrom(Cluster cluster) {
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(CommonCodegenUtils.literalExprFrom(cluster.getId()));
        arguments.add(CommonCodegenUtils.literalExprFrom(cluster.getName()));
        arguments.add(CommonCodegenUtils.createArraysAsListFromList(getClusterDoubleValues(cluster)).getExpression());
        return objectCreationExprFor(KiePMMLCluster.class, arguments);
    }

    /**
     * Builds a {@code new KiePMMLClusteringField(...)} expression.
     *
     * The arguments follow the same order and defaults as {@link #getKiePMMLClusteringField(ClusteringField)},
     * so the generated model matches the in-memory one.
     */
    private static ObjectCreationExpr clusteringFieldCreationExprFrom(ClusteringField clusteringField) {
        Expression compareFunctionExpr = clusteringField.getCompareFunction() == null
                ? new NullLiteralExpr()
                : CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.compareFunctionFrom(clusteringField.getCompareFunction()));
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(CommonCodegenUtils.literalExprFrom(clusteringField.getField().getValue()));
        arguments.add(new DoubleLiteralExpr(fieldWeightOf(clusteringField)));
        arguments.add(new BooleanLiteralExpr(isCenterField(clusteringField)));
        arguments.add(compareFunctionExpr);
        arguments.add(new NullLiteralExpr());
        return objectCreationExprFor(KiePMMLClusteringField.class, arguments);
    }

    /** Builds {@code new KiePMMLComparisonMeasure(kind, aggregateFunction, compareFunction)}. */
    private static ObjectCreationExpr comparisonMeasureCreationExprFrom(ComparisonMeasure comparisonMeasure) {
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.comparisonMeasureKindFrom(comparisonMeasure.getKind())));
        arguments.add(CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.aggregateFunctionFrom(comparisonMeasure.getMeasure())));
        arguments.add(CommonCodegenUtils.literalExprFrom(KiePMMLClusteringConversionUtils.compareFunctionFrom(comparisonMeasure.getCompareFunction())));
        return objectCreationExprFor(KiePMMLComparisonMeasure.class, arguments);
    }

    /** Builds {@code new KiePMMLMissingValueWeights(Arrays.asList(values...))}. */
    private static ObjectCreationExpr missingValueWeightsCreationExprFrom(MissingValueWeights missingValueWeights) {
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(CommonCodegenUtils.createArraysAsListFromList(getMissingValueWeightsDoubleValues(missingValueWeights)).getExpression());
        return objectCreationExprFor(KiePMMLMissingValueWeights.class, arguments);
    }

    /** Builds a {@code new <type>(arguments...)} expression using the canonical name of {@code type}. */
    private static ObjectCreationExpr objectCreationExprFor(Class<?> type, NodeList<Expression> arguments) {
        return new ObjectCreationExpr(null, new ClassOrInterfaceType(null, type.getCanonicalName()), arguments);
    }

    /** Returns the declared field weight, or {@link #DEFAULT_FIELD_WEIGHT} when absent. */
    private static double fieldWeightOf(ClusteringField clusteringField) {
        Number fieldWeight = clusteringField.getFieldWeight();
        return fieldWeight == null ? DEFAULT_FIELD_WEIGHT : fieldWeight.doubleValue();
    }

    /** A field is a center field when {@code centerField} is absent or explicitly {@code TRUE}. */
    private static boolean isCenterField(ClusteringField clusteringField) {
        ClusteringField.CenterField centerField = clusteringField.getCenterField();
        return centerField == null || centerField == ClusteringField.CenterField.TRUE;
    }

    /** Parses the cluster's array values; returns an empty list when the cluster has no array. */
    private static List<Double> getClusterDoubleValues(Cluster cluster) {
        return cluster.getArray() != null ? getDoubleValuesFromArray(cluster.getArray()) : Collections.emptyList();
    }

    /** Parses the missing-value weights' array values; returns an empty list when there is no array. */
    private static List<Double> getMissingValueWeightsDoubleValues(MissingValueWeights missingValueWeights) {
        return missingValueWeights.getArray() != null ? getDoubleValuesFromArray(missingValueWeights.getArray()) : Collections.emptyList();
    }

    /**
     * Parses a PMML {@code REAL} array into doubles.
     *
     * Tokens are separated by any run of whitespace. Leading, trailing and repeated blanks are tolerated,
     * as are newlines. Arrays of another type, or with a null/blank value, yield an empty list.
     *
     * If any token is not a valid number, the error is logged and an empty list is returned. A partially
     * parsed list would silently describe a truncated vector, so it is never returned.
     */
    private static List<Double> getDoubleValuesFromArray(Array array) {
        List<Double> toReturn = new ArrayList<>();
        if (array.getType() != Array.Type.REAL) {
            return toReturn;
        }
        String value = (String) array.getValue();
        // StringUtils.split returns null for null input and an empty array for blank input.
        String[] tokens = StringUtils.split(value);
        if (tokens == null) {
            return toReturn;
        }
        try {
            for (String token : tokens) {
                toReturn.add(Double.parseDouble(token));
            }
        } catch (NumberFormatException e) {
            logger.error("Can't parse \"real\" cluster with value \"" + value + "\"", e);
            toReturn.clear();
        }
        return toReturn;
    }
}
