package org.kie.pmml.models.clustering.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.assertj.core.api.Assertions;
import org.dmg.pmml.Array;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringField;
import org.dmg.pmml.clustering.ClusteringModel;
import org.dmg.pmml.clustering.MissingValueWeights;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.kie.pmml.compiler.api.dto.CommonCompilationDTO;
import org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils;
import org.kie.pmml.compiler.api.utils.ModelUtils;
import org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.clustering.compiler.dto.ClusteringCompilationDTO;
import org.kie.pmml.models.clustering.model.KiePMMLAggregateFunction;
import org.kie.pmml.models.clustering.model.KiePMMLCluster;
import org.kie.pmml.models.clustering.model.KiePMMLClusteringField;
import org.kie.pmml.models.clustering.model.KiePMMLClusteringModel;
import org.kie.pmml.models.clustering.model.KiePMMLCompareFunction;
import org.kie.pmml.models.clustering.model.KiePMMLComparisonMeasure;
import org.kie.pmml.models.clustering.model.KiePMMLMissingValueWeights;
import org.kie.test.util.filesystem.FileUtils;

/**
 * Unit tests for {@link KiePMMLClusteringModelFactory}.
 *
 * <p>A shared PMML fixture (one {@link ClusteringModel} with randomly generated clustering fields
 * and clusters) is built once in {@link #setup()} and reused by the model-level tests; the
 * element-level tests build their own minimal PMML objects.
 */
public class KiePMMLClusteringModelFactoryTest {

    private static final String TEST_01_SOURCE = "KiePMMLClusteringModelFactoryTest_01.txt";
    private static final String MODEL_NAME = "firstModel";
    private static final String PACKAGE_NAME = "PACKAGE_NAME";
    /** Number of random clustering fields, clusters and array values generated per fixture. */
    private static final int ITEMS = 3;

    private static CompilationUnit compilationUnit;
    private static ClassOrInterfaceDeclaration modelTemplate;
    private static List<ClusteringField> clusteringFields;
    private static List<Cluster> clusters;
    private static List<DataField> dataFields;
    private static List<MiningField> miningFields;
    private static MiningField targetMiningField;
    private static DataDictionary dataDictionary;
    private static TransformationDictionary transformationDictionary;
    private static MiningSchema miningSchema;
    private static ClusteringModel clusteringModel;
    private static PMML pmml;

    /**
     * Builds the shared PMML fixture and loads the clustering model template.
     */
    @BeforeAll
    public static void setup() {
        // Field names go through a Set so each distinct name yields exactly one
        // DataField/MiningField, even if random clustering fields repeat a name.
        Set<String> fieldNames = new HashSet<>();
        clusteringFields = new ArrayList<>();
        clusters = new ArrayList<>();
        for (int i = 0; i < ITEMS; i++) {
            ClusteringField clusteringField = PMMLModelTestUtils.getRandomClusteringField();
            clusteringFields.add(clusteringField);
            fieldNames.add(clusteringField.getField().getValue());
            clusters.add(PMMLModelTestUtils.getRandomCluster());
        }
        dataFields = new ArrayList<>();
        miningFields = new ArrayList<>();
        for (String fieldName : fieldNames) {
            dataFields.add(PMMLModelTestUtils.getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));
            miningFields.add(PMMLModelTestUtils.getMiningField(fieldName, MiningField.UsageType.ACTIVE));
        }
        // The first mining field (in Set iteration order) is promoted to be the model target.
        targetMiningField = miningFields.get(0);
        targetMiningField.setUsageType(MiningField.UsageType.TARGET);
        dataDictionary = PMMLModelTestUtils.getDataDictionary(dataFields);
        transformationDictionary = new TransformationDictionary();
        miningSchema = PMMLModelTestUtils.getMiningSchema(miningFields);
        clusteringModel = PMMLModelTestUtils.getClusteringModel(MODEL_NAME, MiningFunction.CLUSTERING, miningSchema,
                                                                clusteringFields, clusters);
        compilationUnit = JavaParserUtils.getFromFileName("KiePMMLClusteringModelTemplate.tmpl");
        modelTemplate = compilationUnit.getClassByName("KiePMMLClusteringModelTemplate").get();
        pmml = new PMML();
        pmml.setDataDictionary(dataDictionary);
        pmml.setTransformationDictionary(transformationDictionary);
        pmml.addModels(new Model[]{clusteringModel});
    }

    @Test
    void getKiePMMLClusteringModel() {
        KiePMMLClusteringModel retrieved =
                KiePMMLClusteringModelFactory.getKiePMMLClusteringModel(getClusteringCompilationDTO());
        Assertions.assertThat(retrieved).isNotNull();
        Assertions.assertThat(retrieved.getName()).isEqualTo(clusteringModel.getModelName());
        Assertions.assertThat(retrieved.getModelClass().getName()).isEqualTo(clusteringModel.getModelClass().value());

        List<KiePMMLCluster> retrievedClusters = retrieved.getClusters();
        List<Cluster> expectedClusters = clusteringModel.getClusters();
        Assertions.assertThat(retrievedClusters).hasSameSizeAs(expectedClusters);
        for (int i = 0; i < expectedClusters.size(); i++) {
            commonEvaluateKiePMMLCluster(retrievedClusters.get(i), expectedClusters.get(i));
        }

        List<KiePMMLClusteringField> retrievedClusteringFields = retrieved.getClusteringFields();
        List<ClusteringField> expectedClusteringFields = clusteringModel.getClusteringFields();
        Assertions.assertThat(retrievedClusteringFields).hasSameSizeAs(expectedClusteringFields);
        for (int i = 0; i < expectedClusteringFields.size(); i++) {
            commonEvaluateKiePMMLClusteringField(retrievedClusteringFields.get(i), expectedClusteringFields.get(i));
        }

        commonEvaluateKiePMMLComparisonMeasure(retrieved.getComparisonMeasure(), clusteringModel.getComparisonMeasure());
        commonEvaluateKiePMMLMissingValueWeights(retrieved.getMissingValueWeights(),
                                                 clusteringModel.getMissingValueWeights());
    }

    @Test
    void getKiePMMLClusteringModelSourcesMap() {
        Map<String, String> sourcesMap =
                KiePMMLClusteringModelFactory.getKiePMMLClusteringModelSourcesMap(getClusteringCompilationDTO());
        Assertions.assertThat(sourcesMap).isNotNull();
        Assertions.assertThat(sourcesMap).hasSize(1);
    }

    @Test
    void getKiePMMLCluster() {
        Cluster cluster = new Cluster();
        cluster.setId("ID");
        cluster.setName("NAME");
        cluster.setArray(getRandomRealArray(new Random()));
        commonEvaluateKiePMMLCluster(KiePMMLClusteringModelFactory.getKiePMMLCluster(cluster), cluster);
    }

    @Test
    void getKiePMMLClusteringField() {
        ClusteringField clusteringField = new ClusteringField();
        Random random = new Random();
        clusteringField.setField(FieldName.create("TEXT"));
        clusteringField.setFieldWeight(random.nextDouble());
        clusteringField.setCenterField(PMMLModelTestUtils.getRandomEnum(ClusteringField.CenterField.values()));
        clusteringField.setCompareFunction(PMMLModelTestUtils.getRandomEnum(CompareFunction.values()));
        commonEvaluateKiePMMLClusteringField(KiePMMLClusteringModelFactory.getKiePMMLClusteringField(clusteringField),
                                             clusteringField);
    }

    @Test
    void getKiePMMLComparisonMeasure() {
        ComparisonMeasure comparisonMeasure = new ComparisonMeasure();
        comparisonMeasure.setKind(PMMLModelTestUtils.getRandomEnum(ComparisonMeasure.Kind.values()));
        comparisonMeasure.setCompareFunction(PMMLModelTestUtils.getRandomEnum(CompareFunction.values()));
        Random random = new Random();
        comparisonMeasure.setMinimum(random.nextInt(10));
        // Maximum is never below minimum.
        comparisonMeasure.setMaximum(comparisonMeasure.getMinimum().intValue() + random.nextInt(10));
        comparisonMeasure.setMeasure(new Euclidean());
        KiePMMLComparisonMeasure retrieved = KiePMMLClusteringModelFactory.getKiePMMLComparisonMeasure(comparisonMeasure);
        Assertions.assertThat(retrieved.getAggregateFunction()).isEqualTo(KiePMMLAggregateFunction.EUCLIDEAN);
        commonEvaluateKiePMMLComparisonMeasure(retrieved, comparisonMeasure);
    }

    @Test
    void getKiePMMLMissingValueWeights() {
        Assertions.assertThat(KiePMMLClusteringModelFactory.getKiePMMLMissingValueWeights((MissingValueWeights) null)).isNull();

        // A MissingValueWeights without an Array maps to an empty (but non-null) value list.
        KiePMMLMissingValueWeights emptyWeights =
                KiePMMLClusteringModelFactory.getKiePMMLMissingValueWeights(new MissingValueWeights());
        Assertions.assertThat(emptyWeights).isNotNull();
        Assertions.assertThat(emptyWeights.getValues()).isNotNull();
        Assertions.assertThat(emptyWeights.getValues()).isEmpty();

        MissingValueWeights missingValueWeights = new MissingValueWeights();
        missingValueWeights.setArray(getRandomRealArray(new Random()));
        commonEvaluateKiePMMLMissingValueWeights(
                KiePMMLClusteringModelFactory.getKiePMMLMissingValueWeights(missingValueWeights), missingValueWeights);
    }

    @Test
    void setStaticGetter() throws IOException {
        ClassOrInterfaceDeclaration template = modelTemplate.clone();
        CommonCompilationDTO<ClusteringModel> compilationDTO =
                CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, clusteringModel,
                                                                       new HasClassLoaderMock());
        // Expected fully-qualified enum constants substituted into the TEST_01_SOURCE template.
        String modelClass = KiePMMLClusteringModel.ModelClass.class.getCanonicalName() + "."
                + clusteringModel.getModelClass().name();
        ComparisonMeasure comparisonMeasure = clusteringModel.getComparisonMeasure();
        String kind = KiePMMLComparisonMeasure.Kind.class.getCanonicalName() + "."
                + comparisonMeasure.getKind().name();
        String aggregateFunction = KiePMMLAggregateFunction.class.getCanonicalName() + "."
                + ((KiePMMLAggregateFunction) KiePMMLClusteringConversionUtils.AGGREGATE_FN_MAP
                        .get(comparisonMeasure.getMeasure().getClass())).name();
        String compareFunction = KiePMMLCompareFunction.class.getCanonicalName() + "."
                + comparisonMeasure.getCompareFunction().name();
        String targetFieldName = targetMiningField.getName().getValue();

        KiePMMLClusteringModelFactory.setStaticGetter(compilationDTO, template);

        String expectedSource = String.format(FileUtils.getFileContent(TEST_01_SOURCE), modelClass, kind,
                                              aggregateFunction, compareFunction, targetFieldName);
        MethodDeclaration expected = JavaParserUtils.parseMethod(expectedSource);
        MethodDeclaration retrieved = template.getMethodsByName("getModel").get(0);
        Assertions.assertThat(JavaParserUtils.equalsNode(expected, retrieved)).isTrue();
    }

    /** Builds a {@link ClusteringCompilationDTO} for the shared fixture model. */
    private static ClusteringCompilationDTO getClusteringCompilationDTO() {
        CommonCompilationDTO<ClusteringModel> source =
                CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, clusteringModel,
                                                                       new HasClassLoaderMock());
        return ClusteringCompilationDTO.fromCompilationDTO(source);
    }

    /**
     * Builds a REAL-typed PMML {@link Array} holding {@link #ITEMS} random doubles.
     *
     * @param random source of the values
     * @return the array, with each value rendered via {@code String.valueOf(double)}
     */
    private static Array getRandomRealArray(Random random) {
        return PMMLModelTestUtils.getArray(Array.Type.REAL,
                                           IntStream.range(0, ITEMS)
                                                   .mapToObj(i -> String.valueOf(random.nextDouble()))
                                                   .collect(Collectors.toList()));
    }

    /** Asserts that {@code retrieved} carries the id, name and values of the PMML {@code source} cluster. */
    private static void commonEvaluateKiePMMLCluster(KiePMMLCluster retrieved, Cluster source) {
        Assertions.assertThat(retrieved).isNotNull();
        Assertions.assertThat(retrieved.getId()).hasValue(source.getId());
        Assertions.assertThat(retrieved.getName()).hasValue(source.getName());
        commonEvaluateDoubles(retrieved.getValues(), source.getArray());
    }

    /** Asserts that {@code retrieved} mirrors the PMML {@code source} clustering field. */
    private static void commonEvaluateKiePMMLClusteringField(KiePMMLClusteringField retrieved,
                                                             ClusteringField source) {
        Assertions.assertThat(retrieved).isNotNull();
        // The PMML CenterField enum is expected as a boolean: only CenterField.TRUE maps to true.
        boolean expectedCenterField = source.getCenterField() == ClusteringField.CenterField.TRUE;
        Assertions.assertThat(retrieved.getField()).isEqualTo(source.getField().getValue());
        Assertions.assertThat(retrieved.getFieldWeight()).isEqualTo(source.getFieldWeight());
        Assertions.assertThat(retrieved.getCenterField()).isEqualTo(expectedCenterField);
        Assertions.assertThat(retrieved.getCompareFunction()).isPresent();
        Assertions.assertThat(retrieved.getCompareFunction().get().getName())
                .isEqualTo(source.getCompareFunction().value());
    }

    /** Asserts that {@code retrieved} has the kind and compare function of the PMML {@code source} measure. */
    private static void commonEvaluateKiePMMLComparisonMeasure(KiePMMLComparisonMeasure retrieved,
                                                               ComparisonMeasure source) {
        Assertions.assertThat(retrieved).isNotNull();
        Assertions.assertThat(retrieved.getKind().getName()).isEqualTo(source.getKind().value());
        Assertions.assertThat(retrieved.getCompareFunction().getName()).isEqualTo(source.getCompareFunction().value());
    }

    /** Asserts that {@code retrieved} is non-null and carries the values of the PMML {@code source} weights. */
    private static void commonEvaluateKiePMMLMissingValueWeights(KiePMMLMissingValueWeights retrieved,
                                                                 MissingValueWeights source) {
        Assertions.assertThat(retrieved).isNotNull();
        commonEvaluateDoubles(retrieved.getValues(), source.getArray());
    }

    /**
     * Asserts that {@code retrieved} equals, element by element, the objects parsed from {@code source}.
     * When {@code source} is {@code null} nothing is checked.
     */
    private static void commonEvaluateDoubles(List<Double> retrieved, Array source) {
        if (source == null) {
            return;
        }
        List<?> expected = ModelUtils.getObjectsFromArray(source);
        Assertions.assertThat(retrieved).isNotNull();
        Assertions.assertThat(retrieved).hasSameSizeAs(expected);
        for (int i = 0; i < expected.size(); i++) {
            Assertions.assertThat(retrieved.get(i)).isEqualTo(expected.get(i));
        }
    }
}
