/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.clustering.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.dmg.pmml.Array;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringField;
import org.dmg.pmml.clustering.ClusteringModel;
import org.dmg.pmml.clustering.MissingValueWeights;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.kie.pmml.commons.model.HasClassLoader;
import org.kie.pmml.compiler.api.dto.CommonCompilationDTO;
import org.kie.pmml.compiler.api.dto.CompilationDTO;
import org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils;
import org.kie.pmml.compiler.api.utils.ModelUtils;
import org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.clustering.compiler.dto.ClusteringCompilationDTO;
import org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils;
import org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringModelFactory;
import org.kie.pmml.models.clustering.model.KiePMMLAggregateFunction;
import org.kie.pmml.models.clustering.model.KiePMMLCluster;
import org.kie.pmml.models.clustering.model.KiePMMLClusteringField;
import org.kie.pmml.models.clustering.model.KiePMMLClusteringModel;
import org.kie.pmml.models.clustering.model.KiePMMLCompareFunction;
import org.kie.pmml.models.clustering.model.KiePMMLComparisonMeasure;
import org.kie.pmml.models.clustering.model.KiePMMLMissingValueWeights;
import org.kie.test.util.filesystem.FileUtils;

/**
 * Unit tests for {@link KiePMMLClusteringModelFactory}.
 * <p>
 * A random {@link ClusteringModel} (built from random clustering fields and clusters) is created once in
 * {@link #setup()}. It is shared by the model-level tests; the remaining tests build their own PMML fragments.
 */
public class KiePMMLClusteringModelFactoryTest {

    private static final String PACKAGE_NAME = "PACKAGE_NAME";
    private static final String TEST_01_SOURCE = "KiePMMLClusteringModelFactoryTest_01.txt";
    private static final String MODEL_NAME = "firstModel";
    private static final String TEMPLATE_SOURCE = "KiePMMLClusteringModelTemplate.tmpl";
    private static final String TEMPLATE_CLASS_NAME = "KiePMMLClusteringModelTemplate";
    // How many random clustering fields/clusters (and random array values) the fixtures generate.
    private static final int RANDOM_ITEMS = 3;

    private static ClassOrInterfaceDeclaration MODEL_TEMPLATE;
    private static MiningField targetMiningField;
    private static ClusteringModel clusteringModel;
    private static PMML pmml;

    /**
     * Builds the shared fixtures: the random clustering model, the PMML document that contains it, and the
     * parsed model template class.
     */
    @BeforeClass
    public static void setup() {
        List<ClusteringField> clusteringFields = new ArrayList<>();
        List<Cluster> clusters = new ArrayList<>();
        // Collected in a set so that each distinct field name gets exactly one DataField/MiningField.
        HashSet<String> fieldNames = new HashSet<>();
        for (int i = 0; i < RANDOM_ITEMS; i++) {
            ClusteringField clusteringField = PMMLModelTestUtils.getRandomClusteringField();
            clusteringFields.add(clusteringField);
            fieldNames.add(clusteringField.getField().getValue());
            clusters.add(PMMLModelTestUtils.getRandomCluster());
        }

        List<DataField> dataFields = new ArrayList<>();
        List<MiningField> miningFields = new ArrayList<>();
        for (String fieldName : fieldNames) {
            dataFields.add(PMMLModelTestUtils.getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));
            miningFields.add(PMMLModelTestUtils.getMiningField(fieldName, MiningField.UsageType.ACTIVE));
        }
        // Promote the first mining field to TARGET; setStaticGetter expects its name in the generated source.
        targetMiningField = miningFields.get(0);
        targetMiningField.setUsageType(MiningField.UsageType.TARGET);

        DataDictionary dataDictionary = PMMLModelTestUtils.getDataDictionary(dataFields);
        MiningSchema miningSchema = PMMLModelTestUtils.getMiningSchema(miningFields);
        clusteringModel = PMMLModelTestUtils.getClusteringModel(MODEL_NAME, MiningFunction.CLUSTERING,
                                                                miningSchema, clusteringFields, clusters);

        CompilationUnit compilationUnit = JavaParserUtils.getFromFileName(TEMPLATE_SOURCE);
        MODEL_TEMPLATE = compilationUnit.getClassByName(TEMPLATE_CLASS_NAME)
                .orElseThrow(() -> new IllegalStateException("Missing class " + TEMPLATE_CLASS_NAME
                                                                     + " in " + TEMPLATE_SOURCE));

        pmml = new PMML();
        pmml.setDataDictionary(dataDictionary);
        pmml.setTransformationDictionary(new TransformationDictionary());
        pmml.addModels(clusteringModel);
    }

    @Test
    public void getKiePMMLClusteringModel() {
        KiePMMLClusteringModel retrieved =
                KiePMMLClusteringModelFactory.getKiePMMLClusteringModel(ClusteringCompilationDTO.fromCompilationDTO(getCompilationDTO()));
        Assert.assertNotNull(retrieved);
        Assert.assertEquals(clusteringModel.getModelName(), retrieved.getName());
        Assert.assertEquals(clusteringModel.getModelClass().value(), retrieved.getModelClass().getName());

        List<Cluster> expectedClusters = clusteringModel.getClusters();
        List<KiePMMLCluster> retrievedClusters = retrieved.getClusters();
        Assert.assertEquals(expectedClusters.size(), retrievedClusters.size());
        for (int i = 0; i < expectedClusters.size(); i++) {
            commonEvaluateKiePMMLCluster(retrievedClusters.get(i), expectedClusters.get(i));
        }

        List<ClusteringField> expectedClusteringFields = clusteringModel.getClusteringFields();
        List<KiePMMLClusteringField> retrievedClusteringFields = retrieved.getClusteringFields();
        Assert.assertEquals(expectedClusteringFields.size(), retrievedClusteringFields.size());
        for (int i = 0; i < expectedClusteringFields.size(); i++) {
            commonEvaluateKiePMMLClusteringField(retrievedClusteringFields.get(i), expectedClusteringFields.get(i));
        }

        commonEvaluateKiePMMLComparisonMeasure(retrieved.getComparisonMeasure(), clusteringModel.getComparisonMeasure());
        commonEvaluateKiePMMLMissingValueWeights(retrieved.getMissingValueWeights(), clusteringModel.getMissingValueWeights());
    }

    @Test
    public void getKiePMMLClusteringModelSourcesMap() {
        Map<String, String> retrieved =
                KiePMMLClusteringModelFactory.getKiePMMLClusteringModelSourcesMap(ClusteringCompilationDTO.fromCompilationDTO(getCompilationDTO()));
        Assert.assertNotNull(retrieved);
        Assert.assertEquals(1, retrieved.size());
    }

    @Test
    public void getKiePMMLCluster() {
        Cluster cluster = new Cluster();
        cluster.setId("ID");
        cluster.setName("NAME");
        cluster.setArray(getRandomRealArray());
        KiePMMLCluster retrieved = KiePMMLClusteringModelFactory.getKiePMMLCluster(cluster);
        commonEvaluateKiePMMLCluster(retrieved, cluster);
    }

    @Test
    public void getKiePMMLClusteringField() {
        Random random = new Random();
        ClusteringField clusteringField = new ClusteringField();
        clusteringField.setField(FieldName.create("TEXT"));
        clusteringField.setFieldWeight(random.nextDouble());
        clusteringField.setCenterField((ClusteringField.CenterField) PMMLModelTestUtils.getRandomEnum(ClusteringField.CenterField.values()));
        clusteringField.setCompareFunction((CompareFunction) PMMLModelTestUtils.getRandomEnum(CompareFunction.values()));
        KiePMMLClusteringField retrieved = KiePMMLClusteringModelFactory.getKiePMMLClusteringField(clusteringField);
        commonEvaluateKiePMMLClusteringField(retrieved, clusteringField);
    }

    @Test
    public void getKiePMMLComparisonMeasure() {
        Random random = new Random();
        ComparisonMeasure comparisonMeasure = new ComparisonMeasure();
        comparisonMeasure.setKind((ComparisonMeasure.Kind) PMMLModelTestUtils.getRandomEnum(ComparisonMeasure.Kind.values()));
        comparisonMeasure.setCompareFunction((CompareFunction) PMMLModelTestUtils.getRandomEnum(CompareFunction.values()));
        int minimum = random.nextInt(10);
        comparisonMeasure.setMinimum(minimum);
        comparisonMeasure.setMaximum(minimum + random.nextInt(10));
        comparisonMeasure.setMeasure(new Euclidean());
        KiePMMLComparisonMeasure retrieved = KiePMMLClusteringModelFactory.getKiePMMLComparisonMeasure(comparisonMeasure);
        Assert.assertEquals(KiePMMLAggregateFunction.EUCLIDEAN, retrieved.getAggregateFunction());
        commonEvaluateKiePMMLComparisonMeasure(retrieved, comparisonMeasure);
    }

    @Test
    public void getKiePMMLMissingValueWeights() {
        Assert.assertNull(KiePMMLClusteringModelFactory.getKiePMMLMissingValueWeights(null));

        KiePMMLMissingValueWeights retrieved = KiePMMLClusteringModelFactory.getKiePMMLMissingValueWeights(new MissingValueWeights());
        Assert.assertNotNull(retrieved);
        Assert.assertNotNull(retrieved.getValues());
        Assert.assertTrue(retrieved.getValues().isEmpty());

        MissingValueWeights missingValueWeights = new MissingValueWeights();
        missingValueWeights.setArray(getRandomRealArray());
        retrieved = KiePMMLClusteringModelFactory.getKiePMMLMissingValueWeights(missingValueWeights);
        commonEvaluateKiePMMLMissingValueWeights(retrieved, missingValueWeights);
    }

    @Test
    public void setStaticGetter() throws IOException {
        ClassOrInterfaceDeclaration modelTemplate = MODEL_TEMPLATE.clone();
        CommonCompilationDTO<ClusteringModel> compilationDTO = getCompilationDTO();
        String expectedModelClass = KiePMMLClusteringModel.ModelClass.class.getCanonicalName() + "."
                + clusteringModel.getModelClass().name();
        ComparisonMeasure comparisonMeasure = clusteringModel.getComparisonMeasure();
        String expectedKind = KiePMMLComparisonMeasure.Kind.class.getCanonicalName() + "."
                + comparisonMeasure.getKind().name();
        KiePMMLAggregateFunction aggregateFunction =
                (KiePMMLAggregateFunction) KiePMMLClusteringConversionUtils.AGGREGATE_FN_MAP.get(comparisonMeasure.getMeasure().getClass());
        String expectedAggregateFunction = KiePMMLAggregateFunction.class.getCanonicalName() + "." + aggregateFunction.name();
        String expectedCompareFunction = KiePMMLCompareFunction.class.getCanonicalName() + "."
                + comparisonMeasure.getCompareFunction().name();
        String expectedTargetField = targetMiningField.getName().getValue();

        KiePMMLClusteringModelFactory.setStaticGetter(compilationDTO, modelTemplate);

        MethodDeclaration retrieved = modelTemplate.getMethodsByName("getModel").get(0);
        String text = String.format(FileUtils.getFileContent(TEST_01_SOURCE), expectedModelClass, expectedKind,
                                    expectedAggregateFunction, expectedCompareFunction, expectedTargetField);
        MethodDeclaration expected = JavaParserUtils.parseMethod(text);
        Assert.assertTrue(JavaParserUtils.equalsNode(expected, retrieved));
    }

    /**
     * Builds a fresh compilation DTO for the shared {@link #clusteringModel} inside {@link #pmml}.
     */
    private static CommonCompilationDTO<ClusteringModel> getCompilationDTO() {
        return CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, clusteringModel,
                                                                      new HasClassLoaderMock());
    }

    /**
     * Builds a REAL-typed PMML {@link Array} holding {@link #RANDOM_ITEMS} random doubles.
     */
    private static Array getRandomRealArray() {
        Random random = new Random();
        List<String> values = IntStream.range(0, RANDOM_ITEMS)
                .mapToObj(i -> String.valueOf(random.nextDouble()))
                .collect(Collectors.toList());
        return PMMLModelTestUtils.getArray(Array.Type.REAL, values);
    }

    /**
     * Asserts that the converted cluster carries the source cluster's id, name and array values.
     */
    private void commonEvaluateKiePMMLCluster(KiePMMLCluster retrieved, Cluster cluster) {
        Assert.assertNotNull(retrieved);
        Assert.assertTrue(retrieved.getId().isPresent());
        Assert.assertEquals(cluster.getId(), retrieved.getId().get());
        Assert.assertTrue(retrieved.getName().isPresent());
        Assert.assertEquals(cluster.getName(), retrieved.getName().get());
        commonEvaluateDoubles(retrieved.getValues(), cluster.getArray());
    }

    /**
     * Asserts that the converted clustering field matches the source field's name, weight, center flag and
     * compare function.
     */
    private void commonEvaluateKiePMMLClusteringField(KiePMMLClusteringField retrieved, ClusteringField clusteringField) {
        Assert.assertNotNull(retrieved);
        boolean isCenterField = clusteringField.getCenterField() == ClusteringField.CenterField.TRUE;
        Assert.assertEquals(clusteringField.getField().getValue(), retrieved.getField());
        Assert.assertEquals((Object) clusteringField.getFieldWeight(), (Object) retrieved.getFieldWeight());
        Assert.assertEquals((Object) isCenterField, (Object) retrieved.getCenterField());
        Assert.assertTrue(retrieved.getCompareFunction().isPresent());
        Assert.assertEquals(clusteringField.getCompareFunction().value(), retrieved.getCompareFunction().get().getName());
    }

    /**
     * Asserts that the converted comparison measure is present and has the source measure's kind and
     * compare function.
     */
    private void commonEvaluateKiePMMLComparisonMeasure(KiePMMLComparisonMeasure retrieved, ComparisonMeasure comparisonMeasure) {
        Assert.assertNotNull(retrieved);
        Assert.assertEquals(comparisonMeasure.getKind().value(), retrieved.getKind().getName());
        Assert.assertEquals(comparisonMeasure.getCompareFunction().value(), retrieved.getCompareFunction().getName());
    }

    /**
     * Asserts that the converted missing-value weights are present and hold the source array's values.
     */
    private void commonEvaluateKiePMMLMissingValueWeights(KiePMMLMissingValueWeights retrieved, MissingValueWeights missingValueWeights) {
        Assert.assertNotNull(retrieved);
        commonEvaluateDoubles(retrieved.getValues(), missingValueWeights.getArray());
    }

    /**
     * Compares the retrieved values element-by-element against the objects parsed from {@code array}.
     * A {@code null} array means there is nothing to compare, so nothing is asserted.
     */
    private void commonEvaluateDoubles(List<Double> retrievedValues, Array array) {
        if (array == null) {
            return;
        }
        List<?> expectedValues = ModelUtils.getObjectsFromArray(array);
        Assert.assertNotNull(retrievedValues);
        Assert.assertEquals(expectedValues.size(), retrievedValues.size());
        for (int i = 0; i < expectedValues.size(); i++) {
            Assert.assertEquals(expectedValues.get(i), retrievedValues.get(i));
        }
    }
}

