package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.OpType;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.kie.pmml.api.enums.MINING_FUNCTION;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.commons.utils.KiePMMLModelUtils;
import org.kie.pmml.compiler.commons.testutils.CodegenTestUtils;
import org.kie.pmml.compiler.commons.testutils.PMMLModelTestUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.model.KiePMMLRegressionClassificationTable;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;
import org.kie.pmml.models.regression.model.KiePMMLRegressionTable;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;

/**
 * Tests for {@link KiePMMLRegressionModelFactory}.
 *
 * <p>The shared fixture (built once in {@link #setup()}) contains three regression tables. Each table has
 * three categorical predictors, three numeric predictors and three predictor terms, with random
 * coefficients. It also contains a data dictionary and a mining schema covering every predictor field
 * name. The first mining field is re-flagged as the model target.
 */
public class KiePMMLRegressionModelFactoryTest {
    private static final String PACKAGE_NAME = "packagename";
    private static CompilationUnit COMPILATION_UNIT;
    private static ClassOrInterfaceDeclaration MODEL_TEMPLATE;
    private static final String modelName = "firstModel";
    private static final double tableIntercept = 3.5d;
    private static final Object tableTargetCategory = "professional";
    private static List<RegressionTable> regressionTables;
    private static List<DataField> dataFields;
    private static List<MiningField> miningFields;
    private static MiningField targetMiningField;
    private static DataDictionary dataDictionary;
    private static TransformationDictionary transformationDictionary;
    private static MiningSchema miningSchema;
    private static RegressionModel regressionModel;

    /**
     * Builds the shared PMML fixture and loads the {@code KiePMMLRegressionModelTemplate} source template.
     */
    @BeforeClass
    public static void setup() {
        Random random = new Random();
        // Receives every predictor field name referenced by any regression table.
        Set<String> fieldNames = new HashSet<>();
        regressionTables = IntStream.range(0, 3)
                .mapToObj(tableIndex -> buildRegressionTable(random, fieldNames, tableIndex))
                .collect(Collectors.toList());
        dataFields = new ArrayList<>();
        miningFields = new ArrayList<>();
        for (String fieldName : fieldNames) {
            dataFields.add(PMMLModelTestUtils.getDataField(fieldName, OpType.CATEGORICAL));
            miningFields.add(PMMLModelTestUtils.getMiningField(fieldName, MiningField.UsageType.ACTIVE));
        }
        // Whichever field the HashSet yields first becomes the target; the tests only rely on it being index 0.
        targetMiningField = miningFields.get(0);
        targetMiningField.setUsageType(MiningField.UsageType.TARGET);
        dataDictionary = PMMLModelTestUtils.getDataDictionary(dataFields);
        transformationDictionary = new TransformationDictionary();
        miningSchema = PMMLModelTestUtils.getMiningSchema(miningFields);
        regressionModel = PMMLModelTestUtils.getRegressionModel(modelName, MiningFunction.REGRESSION, miningSchema, regressionTables);
        COMPILATION_UNIT = JavaParserUtils.getFromFileName("KiePMMLRegressionModelTemplate.tmpl");
        MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName("KiePMMLRegressionModelTemplate")
                .orElseThrow(() -> new IllegalStateException("KiePMMLRegressionModelTemplate class not found in template"));
    }

    /**
     * Builds one regression table with three categorical predictors, three numeric predictors and three
     * predictor terms. Term {@code k} references the categorical and numeric predictors of the same index.
     *
     * @param random     source of the random coefficients, exponents and intercept offset
     * @param fieldNames collector that receives every predictor field name used by the table
     * @param tableIndex suffix appended to the table target category ({@code "professional-<index>"})
     * @return the generated table
     */
    private static RegressionTable buildRegressionTable(Random random, Set<String> fieldNames, int tableIndex) {
        List<CategoricalPredictor> categoricalPredictors = new ArrayList<>();
        List<NumericPredictor> numericPredictors = new ArrayList<>();
        List<PredictorTerm> predictorTerms = new ArrayList<>();
        for (int predictorIndex = 0; predictorIndex < 3; predictorIndex++) {
            String categoricalName = "CatPred-" + predictorIndex;
            String numericName = "NumPred-" + predictorIndex;
            categoricalPredictors.add(PMMLModelTestUtils.getCategoricalPredictor(categoricalName, random.nextDouble(), random.nextDouble()));
            numericPredictors.add(PMMLModelTestUtils.getNumericPredictor(numericName, random.nextInt(), random.nextDouble()));
            predictorTerms.add(PMMLModelTestUtils.getPredictorTerm("PredTerm-" + predictorIndex, random.nextDouble(), Arrays.asList(categoricalName, numericName)));
            fieldNames.add(categoricalName);
            fieldNames.add(numericName);
        }
        return PMMLModelTestUtils.getRegressionTable(categoricalPredictors, numericPredictors, predictorTerms, tableIntercept + random.nextDouble(), tableTargetCategory + "-" + tableIndex);
    }

    @Test
    public void getKiePMMLRegressionModelClasses() throws IOException, IllegalAccessException, InstantiationException {
        KiePMMLRegressionModel retrieved = KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses(dataDictionary, transformationDictionary, regressionModel, Thread.currentThread().getContextClassLoader());
        Assert.assertNotNull(retrieved);
        Assert.assertEquals(regressionModel.getModelName(), retrieved.getName());
        Assert.assertEquals(MINING_FUNCTION.byName(regressionModel.getMiningFunction().value()), retrieved.getMiningFunction());
        Assert.assertEquals(targetMiningField.getName().getValue(), retrieved.getTargetField());
        KiePMMLRegressionTable regressionTable = retrieved.getRegressionTable();
        Assert.assertNotNull(regressionTable);
        Assert.assertTrue(regressionTable instanceof KiePMMLRegressionClassificationTable);
        evaluateCategoricalRegressionTable((KiePMMLRegressionClassificationTable) regressionTable);
    }

    @Test
    public void getKiePMMLRegressionModelSourcesMap() throws IOException {
        Map<?, ?> retrieved = KiePMMLRegressionModelFactory.getKiePMMLRegressionModelSourcesMap(dataDictionary, transformationDictionary, regressionModel, PACKAGE_NAME);
        Assert.assertNotNull(retrieved);
        // One source per regression table plus two more.
        // NOTE(review): presumably the model class and the wrapping table — confirm.
        Assert.assertEquals(regressionTables.size() + 2, retrieved.size());
    }

    @Test
    public void getRegressionTablesMap() throws IOException {
        // Collect the category of every generated table source. The list has one element per map value,
        // so its size equals the map size.
        List<Object> categories = KiePMMLRegressionModelFactory
                .getRegressionTablesMap(dataDictionary, regressionModel, "targetFieldName", Collections.emptyList(), PACKAGE_NAME)
                .values()
                .stream()
                .map(tableSourceCategory -> tableSourceCategory.getCategory())
                .collect(Collectors.toList());
        // One entry per regression table plus one extra.
        // NOTE(review): presumably the classification wrapper table — confirm.
        Assert.assertEquals(regressionTables.size() + 1, categories.size());
        for (RegressionTable regressionTable : regressionTables) {
            Object targetCategory = regressionTable.getTargetCategory();
            Assert.assertTrue("Missing table for category " + targetCategory,
                    categories.stream().anyMatch(category -> category.equals(targetCategory)));
        }
    }

    @Test
    public void setConstructor() {
        // Note: setConstructor mutates the shared template's default constructor in place.
        ConstructorDeclaration constructorDeclaration = MODEL_TEMPLATE.getDefaultConstructor()
                .orElseThrow(() -> new IllegalStateException("Template has no default constructor"));
        MINING_FUNCTION miningFunction = MINING_FUNCTION.byName(regressionModel.getMiningFunction().value());
        String nestedTable = "NestedTable";
        String targetField = "targetField";
        KiePMMLRegressionModelFactory.setConstructor(regressionModel, nestedTable, constructorDeclaration, targetField);
        // Expected arguments of the super(...) call, keyed by argument position.
        Map<Integer, Expression> superInvocationExpressionsMap = new HashMap<>();
        superInvocationExpressionsMap.put(0, new NameExpr(String.format("\"%s\"", regressionModel.getModelName())));
        // Expected right-hand side of each field assignment, keyed by field name.
        Map<String, Expression> assignExpressionMap = new HashMap<>();
        assignExpressionMap.put("targetField", new StringLiteralExpr(targetField));
        assignExpressionMap.put("miningFunction", new NameExpr(miningFunction.getClass().getName() + "." + miningFunction.name()));
        assignExpressionMap.put("pmmlMODEL", new NameExpr(PMML_MODEL.class.getName() + "." + PMML_MODEL.REGRESSION_MODEL.name()));
        ObjectCreationExpr regressionTableExpression = new ObjectCreationExpr();
        regressionTableExpression.setType(nestedTable);
        assignExpressionMap.put("regressionTable", regressionTableExpression);
        Assert.assertTrue(CodegenTestUtils.commonEvaluateConstructor(constructorDeclaration, KiePMMLModelUtils.getSanitizedClassName(regressionModel.getModelName()), superInvocationExpressionsMap, assignExpressionMap));
    }

    @Test
    public void isRegression() {
        Assert.assertTrue(KiePMMLRegressionModelFactory.isRegression(MiningFunction.REGRESSION, (String) null, (OpType) null));
        Assert.assertTrue(KiePMMLRegressionModelFactory.isRegression(MiningFunction.REGRESSION, "TARGET", OpType.CONTINUOUS));
        Assert.assertFalse(KiePMMLRegressionModelFactory.isRegression(MiningFunction.REGRESSION, "TARGET", OpType.CATEGORICAL));
        Assert.assertFalse(KiePMMLRegressionModelFactory.isRegression(MiningFunction.CLASSIFICATION, (String) null, (OpType) null));
    }

    /**
     * Checks the classification table's normalization method and op type. It also checks that every
     * fixture table is present in its category map, keyed by {@code targetCategory.toString()}.
     */
    private void evaluateCategoricalRegressionTable(KiePMMLRegressionClassificationTable kiePMMLRegressionClassificationTable) {
        Assert.assertEquals(REGRESSION_NORMALIZATION_METHOD.byName(regressionModel.getNormalizationMethod().value()), kiePMMLRegressionClassificationTable.getRegressionNormalizationMethod());
        Assert.assertEquals(OP_TYPE.CATEGORICAL, kiePMMLRegressionClassificationTable.getOpType());
        Map<?, ?> categoryTableMap = kiePMMLRegressionClassificationTable.getCategoryTableMap();
        for (RegressionTable regressionTable : regressionTables) {
            String category = regressionTable.getTargetCategory().toString();
            Assert.assertTrue("Missing table for category " + category, categoryTableMap.containsKey(category));
            evaluateRegressionTable((KiePMMLRegressionTable) categoryTableMap.get(category), regressionTable);
        }
    }

    /**
     * Checks that the generated table has the source table's intercept and a function entry for every
     * numeric predictor, categorical predictor and predictor term name.
     */
    private void evaluateRegressionTable(KiePMMLRegressionTable kiePMMLRegressionTable, RegressionTable regressionTable) {
        Assert.assertEquals(regressionTable.getIntercept(), Double.valueOf(kiePMMLRegressionTable.getIntercept()));
        Map<?, ?> numericFunctionMap = kiePMMLRegressionTable.getNumericFunctionMap();
        for (NumericPredictor numericPredictor : regressionTable.getNumericPredictors()) {
            String name = numericPredictor.getName().getValue();
            Assert.assertTrue("Missing numeric function " + name, numericFunctionMap.containsKey(name));
        }
        Map<?, ?> categoricalFunctionMap = kiePMMLRegressionTable.getCategoricalFunctionMap();
        for (CategoricalPredictor categoricalPredictor : regressionTable.getCategoricalPredictors()) {
            String name = categoricalPredictor.getName().getValue();
            Assert.assertTrue("Missing categorical function " + name, categoricalFunctionMap.containsKey(name));
        }
        Map<?, ?> predictorTermsFunctionMap = kiePMMLRegressionTable.getPredictorTermsFunctionMap();
        for (PredictorTerm predictorTerm : regressionTable.getPredictorTerms()) {
            String name = predictorTerm.getName().getValue();
            Assert.assertTrue("Missing predictor term function " + name, predictorTermsFunctionMap.containsKey(name));
        }
    }
}
