/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.compiler.commons.utils;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.xml.bind.JAXBException;
import org.apache.commons.lang3.RandomStringUtils;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.jpmml.model.PMMLUtil;
import org.junit.Assert;
import org.junit.Test;
import org.kie.pmml.compiler.commons.utils.KiePMMLUtil;
import org.kie.test.util.filesystem.FileUtils;
import org.xml.sax.SAXException;

/**
 * Unit tests for {@link KiePMMLUtil}.
 *
 * <p>Most tests follow the same pattern:
 * <ul>
 *   <li>unmarshal a PMML sample opened through {@link FileUtils#getFileInputStream(String)};</li>
 *   <li>assert the "before" state;</li>
 *   <li>invoke the {@code KiePMMLUtil} method under test;</li>
 *   <li>assert the "after" state.</li>
 * </ul>
 */
public class KiePMMLUtilTest {

    private static final String NO_MODELNAME_SAMPLE_NAME = "NoModelNameSample.pmml";
    private static final String NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME = "NoModelNameNoSegmentIdSample.pmml";
    private static final String NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE = "NoModelNameNoSegmentIdNoSegmentTargetFieldSample.pmml";
    private static final String NO_TARGET_FIELD_SAMPLE = "NoTargetFieldSample.pmml";
    private static final String MINING_WITH_SAME_NESTED_MODEL_NAMES = "MiningWithSameNestedModelNames.pmml";

    /** Each sample, loaded from its textual content, must yield fully named models/segments. */
    @Test
    public void loadString() throws IOException, JAXBException, SAXException {
        commonLoadString(NO_MODELNAME_SAMPLE_NAME);
        commonLoadString(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        commonLoadString(MINING_WITH_SAME_NESTED_MODEL_NAMES);
    }

    /** Each sample, loaded from an input stream, must yield fully named models/segments. */
    @Test
    public void loadFile() throws JAXBException, IOException, SAXException {
        commonLoadFile(NO_MODELNAME_SAMPLE_NAME);
        commonLoadFile(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        commonLoadFile(MINING_WITH_SAME_NESTED_MODEL_NAMES);
    }

    /** A missing model name is filled as {@code <fileName><ModelClassSimpleName><index>}. */
    @Test
    public void populateMissingModelName() throws Exception {
        PMML pmml = unmarshalSample(NO_MODELNAME_SAMPLE_NAME);
        Model toPopulate = pmml.getModels().get(0);
        Assert.assertNull(toPopulate.getModelName());

        KiePMMLUtil.populateMissingModelName(toPopulate, NO_MODELNAME_SAMPLE_NAME, 0);

        Assert.assertNotNull(toPopulate.getModelName());
        String expected = String.format("%s%s%s", NO_MODELNAME_SAMPLE_NAME, toPopulate.getClass().getSimpleName(), 0);
        Assert.assertEquals(expected, toPopulate.getModelName());
    }

    /**
     * A model without a target mining field gets exactly one, taken from the data dictionary,
     * and its unnamed {@link Target} is bound to that field.
     */
    @Test
    public void populateMissingMiningTargetField() throws Exception {
        PMML pmml = unmarshalSample(NO_TARGET_FIELD_SAMPLE);
        Model toPopulate = pmml.getModels().get(0);
        List<MiningField> miningTargetFields = KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
        Assert.assertTrue(miningTargetFields.isEmpty());
        Assert.assertNull(toPopulate.getTargets().getTargets().get(0).getField());

        KiePMMLUtil.populateMissingMiningTargetField(toPopulate, pmml.getDataDictionary().getDataFields());

        miningTargetFields = KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
        Assert.assertEquals(1, miningTargetFields.size());
        MiningField targetField = miningTargetFields.get(0);
        // The new target must reference an existing DataField...
        Assert.assertTrue(pmml.getDataDictionary().getDataFields().stream()
                .anyMatch(dataField -> dataField.getName().equals(targetField.getName())));
        // ...and the previously unnamed Target must now point at it.
        Assert.assertEquals(targetField.getName(), toPopulate.getTargets().getTargets().get(0).getField());
    }

    /** The target DataField of the NoTargetField sample is expected to be {@code "targetgolfing"}. */
    @Test
    public void getTargetDataField() throws Exception {
        PMML pmml = unmarshalSample(NO_TARGET_FIELD_SAMPLE);
        Model model = pmml.getModels().get(0);

        Optional<DataField> optionalDataField = KiePMMLUtil.getTargetDataField(model);

        Assert.assertTrue(optionalDataField.isPresent());
        DataField retrieved = optionalDataField.get();
        String expected = String.format("target%s", "golfing");
        Assert.assertEquals(expected, retrieved.getName().getValue());
    }

    /**
     * Covers the expected data type for each mining function:
     * <ul>
     *   <li>REGRESSION follows the MathContext;</li>
     *   <li>CLASSIFICATION and CLUSTERING map to STRING;</li>
     *   <li>the other functions map to {@code null}.</li>
     * </ul>
     */
    @Test
    public void getTargetDataType() {
        MiningFunction miningFunction = MiningFunction.REGRESSION;
        MathContext mathContext = MathContext.DOUBLE;
        DataType retrieved = KiePMMLUtil.getTargetDataType(miningFunction, mathContext);
        Assert.assertEquals(DataType.DOUBLE, retrieved);

        mathContext = MathContext.FLOAT;
        retrieved = KiePMMLUtil.getTargetDataType(miningFunction, mathContext);
        Assert.assertEquals(DataType.FLOAT, retrieved);

        miningFunction = MiningFunction.CLASSIFICATION;
        retrieved = KiePMMLUtil.getTargetDataType(miningFunction, mathContext);
        Assert.assertEquals(DataType.STRING, retrieved);

        miningFunction = MiningFunction.CLUSTERING;
        retrieved = KiePMMLUtil.getTargetDataType(miningFunction, mathContext);
        Assert.assertEquals(DataType.STRING, retrieved);

        List<MiningFunction> notMappedMiningFunctions = Arrays.asList(MiningFunction.ASSOCIATION_RULES,
                MiningFunction.MIXED, MiningFunction.SEQUENCES, MiningFunction.TIME_SERIES);
        notMappedMiningFunctions.forEach(minFun ->
                Assert.assertNull(KiePMMLUtil.getTargetDataType(minFun, MathContext.DOUBLE)));
    }

    /**
     * Covers the expected op type for each mining function:
     * <ul>
     *   <li>REGRESSION maps to CONTINUOUS;</li>
     *   <li>CLASSIFICATION and CLUSTERING map to CATEGORICAL;</li>
     *   <li>the other functions map to {@code null}.</li>
     * </ul>
     */
    @Test
    public void getTargetOpType() {
        Assert.assertEquals(OpType.CONTINUOUS, KiePMMLUtil.getTargetOpType(MiningFunction.REGRESSION));
        Assert.assertEquals(OpType.CATEGORICAL, KiePMMLUtil.getTargetOpType(MiningFunction.CLASSIFICATION));
        Assert.assertEquals(OpType.CATEGORICAL, KiePMMLUtil.getTargetOpType(MiningFunction.CLUSTERING));

        List<MiningFunction> notMappedMiningFunctions = Arrays.asList(MiningFunction.ASSOCIATION_RULES,
                MiningFunction.MIXED, MiningFunction.SEQUENCES, MiningFunction.TIME_SERIES);
        notMappedMiningFunctions.forEach(minFun -> Assert.assertNull(KiePMMLUtil.getTargetOpType(minFun)));
    }

    /** The mining field built from a DataField keeps its name and has usage type TARGET. */
    @Test
    public void getTargetMiningField() {
        DataField dataField = new DataField();
        dataField.setName(FieldName.create("FIELD_NAME"));

        MiningField retrieved = KiePMMLUtil.getTargetMiningField(dataField);

        Assert.assertEquals(dataField.getName().getValue(), retrieved.getName().getValue());
        Assert.assertEquals(MiningField.UsageType.TARGET, retrieved.getUsageType());
    }

    /** Only unnamed targets are bound to the mining field; already-named targets are untouched. */
    @Test
    public void correctTargetFields() {
        MiningField miningField = new MiningField(FieldName.create("FIELD_NAME"));
        Targets targets = new Targets();
        Target namedTarget = new Target();
        String targetName = "TARGET_NAME";
        namedTarget.setField(FieldName.create(targetName));
        Target unnamedTarget = new Target();
        targets.addTargets(namedTarget, unnamedTarget);

        KiePMMLUtil.correctTargetFields(miningField, targets);

        Assert.assertEquals(targetName, namedTarget.getField().getValue());
        Assert.assertEquals(miningField.getName(), unnamedTarget.getField());
    }

    /** After population, every segment has an id, a named model and at least one target field. */
    @Test
    public void populateCorrectMiningModel() throws Exception {
        PMML pmml = unmarshalSample(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE);
        Model retrieved = pmml.getModels().get(0);
        Assert.assertTrue(retrieved instanceof MiningModel);
        MiningModel miningModel = (MiningModel) retrieved;
        miningModel.getSegmentation().getSegments().forEach(segment -> {
            Assert.assertNull(segment.getId());
            Assert.assertNull(segment.getModel().getModelName());
            Assert.assertTrue(KiePMMLUtil.getMiningTargetFields(segment.getModel().getMiningSchema()).isEmpty());
        });

        KiePMMLUtil.populateCorrectMiningModel(miningModel);

        miningModel.getSegmentation().getSegments().forEach(segment -> {
            Assert.assertNotNull(segment.getId());
            Assert.assertNotNull(segment.getModel().getModelName());
            Assert.assertFalse(KiePMMLUtil.getMiningTargetFields(segment.getModel().getMiningSchema()).isEmpty());
        });
    }

    /** A missing segment id is filled as {@code <modelName>Segment<index>}. */
    @Test
    public void populateCorrectSegmentId() throws Exception {
        PMML pmml = unmarshalSample(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        Model retrieved = pmml.getModels().get(0);
        Assert.assertTrue(retrieved instanceof MiningModel);
        MiningModel miningModel = (MiningModel) retrieved;
        Segment toPopulate = miningModel.getSegmentation().getSegments().get(0);
        Assert.assertNull(toPopulate.getId());
        String modelName = "MODEL_NAME";
        int i = 0;

        KiePMMLUtil.populateCorrectSegmentId(toPopulate, modelName, i);

        Assert.assertNotNull(toPopulate.getId());
        String expected = String.format("%sSegment%s", modelName, i);
        Assert.assertEquals(expected, toPopulate.getId());
    }

    /** A missing segment model name is filled as {@code Segment<segmentId><ModelClassSimpleName>}. */
    @Test
    public void populateMissingSegmentModelName() throws Exception {
        PMML pmml = unmarshalSample(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        Model retrieved = pmml.getModels().get(0);
        Assert.assertTrue(retrieved instanceof MiningModel);
        MiningModel miningModel = (MiningModel) retrieved;
        Model toPopulate = miningModel.getSegmentation().getSegments().get(0).getModel();
        Assert.assertNull(toPopulate.getModelName());
        String segmentId = "SEG_ID";

        KiePMMLUtil.populateMissingSegmentModelName(toPopulate, segmentId);

        Assert.assertNotNull(toPopulate.getModelName());
        String expected = String.format("Segment%s%s", segmentId, toPopulate.getClass().getSimpleName());
        Assert.assertEquals(expected, toPopulate.getModelName());
    }

    /** A segment model without target fields inherits every target field of its parent mining model. */
    @Test
    public void populateMissingTargetFieldInSegment() throws Exception {
        PMML pmml = unmarshalSample(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE);
        Model retrieved = pmml.getModels().get(0);
        Assert.assertTrue(retrieved instanceof MiningModel);
        MiningModel miningModel = (MiningModel) retrieved;
        Model toPopulate = miningModel.getSegmentation().getSegments().get(0).getModel();
        Assert.assertTrue(KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema()).isEmpty());

        KiePMMLUtil.populateMissingTargetFieldInSegment(miningModel.getMiningSchema(), toPopulate);

        List<MiningField> childrenTargetFields = KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema());
        Assert.assertFalse(childrenTargetFields.isEmpty());
        KiePMMLUtil.getMiningTargetFields(miningModel.getMiningSchema())
                .forEach(parentTargetField -> Assert.assertTrue(childrenTargetFields.contains(parentTargetField)));
    }

    /**
     * Output fields without a data type must get one after population.
     *
     * <p>The fixture holds three kinds of untyped output field:
     * <ul>
     *   <li>PROBABILITY fields;</li>
     *   <li>a PREDICTED_VALUE field;</li>
     *   <li>a field whose targetField references the TARGET mining field.</li>
     * </ul>
     * It also holds output fields that are already typed.
     */
    @Test
    public void populateMissingOutputFieldDataType() {
        Random random = new Random();
        List<String> fieldNames = IntStream.range(0, 6)
                .mapToObj(i -> RandomStringUtils.random(6, true, false))
                .collect(Collectors.toList());
        List<DataField> dataFields = fieldNames.stream().map(fieldName -> {
            DataField toReturn = new DataField();
            toReturn.setName(FieldName.create(fieldName));
            DataType dataType = DataType.values()[random.nextInt(DataType.values().length)];
            toReturn.setDataType(dataType);
            return toReturn;
        }).collect(Collectors.toList());
        // All data fields but the last become ACTIVE mining fields...
        List<MiningField> miningFields = IntStream.range(0, dataFields.size() - 1)
                .mapToObj(dataFields::get)
                .map(dataField -> {
                    MiningField toReturn = new MiningField();
                    toReturn.setName(FieldName.create(dataField.getName().getValue()));
                    toReturn.setUsageType(MiningField.UsageType.ACTIVE);
                    return toReturn;
                }).collect(Collectors.toList());
        // ...and the last one becomes the TARGET mining field.
        DataField lastDataField = dataFields.get(dataFields.size() - 1);
        MiningField targetMiningField = new MiningField();
        targetMiningField.setName(FieldName.create(lastDataField.getName().getValue()));
        targetMiningField.setUsageType(MiningField.UsageType.TARGET);
        miningFields.add(targetMiningField);

        List<OutputField> outputFields = IntStream.range(0, 3).mapToObj(i -> {
            OutputField toReturn = new OutputField();
            toReturn.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
            toReturn.setResultFeature(ResultFeature.PROBABILITY);
            return toReturn;
        }).collect(Collectors.toList());
        OutputField targetOutputField = new OutputField();
        targetOutputField.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
        targetOutputField.setResultFeature(ResultFeature.PREDICTED_VALUE);
        outputFields.add(targetOutputField);
        OutputField targetingOutputField = new OutputField();
        targetingOutputField.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
        targetingOutputField.setTargetField(FieldName.create(targetMiningField.getName().getValue()));
        outputFields.add(targetingOutputField);
        // Precondition: every output field created so far is untyped.
        outputFields.forEach(outputField -> Assert.assertNull(outputField.getDataType()));
        // Add output fields that already carry a data type.
        IntStream.range(0, 2).forEach(i -> {
            OutputField toAdd = new OutputField();
            toAdd.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
            DataType dataType = DataType.values()[random.nextInt(DataType.values().length)];
            toAdd.setDataType(dataType);
            outputFields.add(toAdd);
        });

        KiePMMLUtil.populateMissingOutputFieldDataType(outputFields, miningFields, dataFields);

        outputFields.forEach(outputField -> Assert.assertNotNull(outputField.getDataType()));
    }

    /** Sanitized ids are expected to be {@code <modelName>Segment<id>} for each sample id. */
    @Test
    public void getSanitizedId() {
        String modelName = "MODEL_NAME";
        for (String id : Arrays.asList("2", "34.5", "3,45")) {
            String expected = String.format("%sSegment%s", modelName, id);
            Assert.assertEquals("id " + id, expected, KiePMMLUtil.getSanitizedId(id, modelName));
        }
    }

    /** The MiningSchema overload finds the single {@code car_location} target field. */
    @Test
    public void getMiningTargetFieldsFromMiningSchema() throws Exception {
        PMML toPopulate = unmarshalSample(NO_MODELNAME_SAMPLE_NAME);
        Model model = toPopulate.getModels().get(0);
        assertSingleCarLocationTarget(KiePMMLUtil.getMiningTargetFields(model.getMiningSchema()));
    }

    /** The List&lt;MiningField&gt; overload finds the single {@code car_location} target field. */
    @Test
    public void getMiningTargetFieldsFromMiningFields() throws Exception {
        PMML toPopulate = unmarshalSample(NO_MODELNAME_SAMPLE_NAME);
        Model model = toPopulate.getModels().get(0);
        assertSingleCarLocationTarget(KiePMMLUtil.getMiningTargetFields(model.getMiningSchema().getMiningFields()));
    }

    /** Asserts that {@code retrieved} holds exactly one field: {@code car_location} with usage "target". */
    private static void assertSingleCarLocationTarget(List<MiningField> retrieved) {
        Assert.assertNotNull(retrieved);
        Assert.assertEquals(1, retrieved.size());
        MiningField targetField = retrieved.get(0);
        Assert.assertEquals("car_location", targetField.getName().getValue());
        Assert.assertEquals("target", targetField.getUsageType().value());
    }

    /**
     * Opens {@code fileName} through {@link FileUtils#getFileInputStream(String)} and unmarshals it
     * with JPMML. The stream is always closed, even when unmarshalling fails.
     */
    private static PMML unmarshalSample(String fileName) throws IOException, JAXBException, SAXException {
        try (InputStream inputStream = FileUtils.getFileInputStream(fileName)) {
            return PMMLUtil.unmarshal(inputStream);
        }
    }

    /** Reads {@code fileName} as UTF-8 text, loads it via {@link KiePMMLUtil#load(String)} and validates it. */
    private void commonLoadString(String fileName) throws IOException, JAXBException, SAXException {
        StringBuilder textBuilder = new StringBuilder();
        try (Reader reader = new BufferedReader(
                new InputStreamReader(FileUtils.getFileInputStream(fileName), StandardCharsets.UTF_8))) {
            char[] buffer = new char[4096];
            int read;
            while ((read = reader.read(buffer)) != -1) {
                textBuilder.append(buffer, 0, read);
            }
        }
        commonValidatePMML(KiePMMLUtil.load(textBuilder.toString()));
    }

    /** Loads {@code fileName} from a stream via {@code KiePMMLUtil.load(InputStream, String)} and validates it. */
    private void commonLoadFile(String fileName) throws IOException, JAXBException, SAXException {
        // Close the stream ourselves; closing a FileInputStream twice is harmless if load() also closes it.
        try (InputStream inputStream = FileUtils.getFileInputStream(fileName)) {
            commonValidatePMML(KiePMMLUtil.load(inputStream, fileName));
        }
    }

    /** Every top-level model must be named; mining models are validated recursively. */
    private void commonValidatePMML(PMML toValidate) {
        Assert.assertNotNull(toValidate);
        for (Model model : toValidate.getModels()) {
            Assert.assertNotNull(model.getModelName());
            if (model instanceof MiningModel) {
                commonValidateMiningModel((MiningModel) model);
            }
        }
    }

    /**
     * Checks a mining model's segments:
     * <ul>
     *   <li>every segment has an id;</li>
     *   <li>every segment's model is named;</li>
     *   <li>nested mining models are validated recursively;</li>
     *   <li>sibling segment model names are unique.</li>
     * </ul>
     */
    private void commonValidateMiningModel(MiningModel toValidate) {
        Assert.assertNotNull(toValidate);
        for (Segment segment : toValidate.getSegmentation().getSegments()) {
            Assert.assertNotNull(segment.getId());
            Model segmentModel = segment.getModel();
            Assert.assertNotNull(segmentModel.getModelName());
            if (segmentModel instanceof MiningModel) {
                commonValidateMiningModel((MiningModel) segmentModel);
            }
        }
        List<String> modelNames = toValidate.getSegmentation().getSegments().stream()
                .map(segment -> segment.getModel().getModelName())
                .collect(Collectors.toList());
        // Duplicated names would collapse in distinct(), shrinking the count.
        Assert.assertEquals(modelNames.size(), modelNames.stream().distinct().count());
    }
}

