/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.compiler.commons.utils;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.xml.bind.JAXBException;
import org.apache.commons.lang3.RandomStringUtils;
import org.assertj.core.api.Assertions;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.drools.util.FileUtils;
import org.jpmml.model.PMMLUtil;
import org.junit.jupiter.api.Test;
import org.kie.pmml.compiler.commons.utils.KiePMMLUtil;
import org.xml.sax.SAXException;

/**
 * Unit tests for {@link KiePMMLUtil}.
 *
 * <p>Most tests unmarshal a sample PMML resource (opened through {@link FileUtils#getFileInputStream(String)}),
 * assert that some piece of information is initially missing, invoke the {@link KiePMMLUtil} method under test
 * and then assert that the information has been populated.
 */
public class KiePMMLUtilTest {
    private static final String NO_MODELNAME_SAMPLE_NAME = "NoModelNameSample.pmml";
    private static final String NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME = "NoModelNameNoSegmentIdSample.pmml";
    private static final String NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE = "NoModelNameNoSegmentIdNoSegmentTargetFieldSample.pmml";
    private static final String NO_TARGET_FIELD_SAMPLE = "NoTargetFieldSample.pmml";
    private static final String NO_OUTPUT_FIELD_TARGET_NAME_SAMPLE = "NoOutputFieldTargetNameSample.pmml";
    private static final String MINING_WITH_SAME_NESTED_MODEL_NAMES = "MiningWithSameNestedModelNames.pmml";

    /** Mining functions for which no target data type / op type mapping is expected (the util returns null). */
    private static final List<MiningFunction> NOT_MAPPED_MINING_FUNCTIONS = Arrays.asList(
            MiningFunction.ASSOCIATION_RULES, MiningFunction.MIXED, MiningFunction.SEQUENCES, MiningFunction.TIME_SERIES);

    @Test
    void loadString() throws IOException, JAXBException, SAXException {
        commonLoadString(NO_MODELNAME_SAMPLE_NAME);
        commonLoadString(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        commonLoadString(MINING_WITH_SAME_NESTED_MODEL_NAMES);
    }

    @Test
    void loadFile() throws JAXBException, IOException, SAXException {
        commonLoadFile(NO_MODELNAME_SAMPLE_NAME);
        commonLoadFile(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        commonLoadFile(MINING_WITH_SAME_NESTED_MODEL_NAMES);
    }

    @Test
    void populateMissingModelName() throws Exception {
        Model toPopulate = unmarshalResource(NO_MODELNAME_SAMPLE_NAME).getModels().get(0);
        Assertions.assertThat(toPopulate.getModelName()).isNull();
        KiePMMLUtil.populateMissingModelName(toPopulate, NO_MODELNAME_SAMPLE_NAME, 0);
        // Expected name: <file name><model class simple name><index>
        String expected = String.format("%s%s%s", NO_MODELNAME_SAMPLE_NAME, toPopulate.getClass().getSimpleName(), 0);
        Assertions.assertThat(toPopulate.getModelName()).isNotNull().isEqualTo(expected);
    }

    @Test
    void populateMissingMiningTargetField() throws Exception {
        PMML pmml = unmarshalResource(NO_TARGET_FIELD_SAMPLE);
        Model toPopulate = pmml.getModels().get(0);
        Assertions.assertThat(KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields())).isEmpty();
        Assertions.assertThat(toPopulate.getTargets().getTargets().get(0).getField()).isNull();

        KiePMMLUtil.populateMissingMiningTargetField(toPopulate, pmml.getDataDictionary().getDataFields());

        List<MiningField> miningTargetFields = KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
        Assertions.assertThat(miningTargetFields).hasSize(1);
        MiningField targetField = miningTargetFields.get(0);
        // The new target MiningField must reference an existing DataField...
        Assertions.assertThat(pmml.getDataDictionary().getDataFields().stream()
                .anyMatch(dataField -> dataField.getName().equals(targetField.getName()))).isTrue();
        // ...and the previously unnamed Target must now point to it.
        Assertions.assertThat(toPopulate.getTargets().getTargets().get(0).getField()).isEqualTo(targetField.getName());
    }

    @Test
    void populateMissingPredictedOutputFieldTarget() throws Exception {
        Model toPopulate = unmarshalResource(NO_OUTPUT_FIELD_TARGET_NAME_SAMPLE).getModels().get(0);
        OutputField outputField = toPopulate.getOutput().getOutputFields().get(0);
        Assertions.assertThat(outputField.getResultFeature()).isEqualTo(ResultFeature.PREDICTED_VALUE);
        Assertions.assertThat(outputField.getTargetField()).isNull();

        KiePMMLUtil.populateMissingPredictedOutputFieldTarget(toPopulate);

        MiningField targetField = KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields()).get(0);
        Assertions.assertThat(outputField.getTargetField()).isNotNull().isEqualTo(targetField.getName());
    }

    @Test
    void getTargetDataField() throws Exception {
        Model model = unmarshalResource(NO_TARGET_FIELD_SAMPLE).getModels().get(0);
        Optional<DataField> optionalDataField = KiePMMLUtil.getTargetDataField(model);
        Assertions.assertThat(optionalDataField).isPresent();
        Assertions.assertThat(optionalDataField.get().getName().getValue()).isEqualTo("targetgolfing");
    }

    @Test
    void getTargetDataType() {
        Assertions.assertThat(KiePMMLUtil.getTargetDataType(MiningFunction.REGRESSION, MathContext.DOUBLE))
                .isEqualTo(DataType.DOUBLE);
        Assertions.assertThat(KiePMMLUtil.getTargetDataType(MiningFunction.REGRESSION, MathContext.FLOAT))
                .isEqualTo(DataType.FLOAT);
        Assertions.assertThat(KiePMMLUtil.getTargetDataType(MiningFunction.CLASSIFICATION, MathContext.FLOAT))
                .isEqualTo(DataType.STRING);
        Assertions.assertThat(KiePMMLUtil.getTargetDataType(MiningFunction.CLUSTERING, MathContext.FLOAT))
                .isEqualTo(DataType.STRING);
        for (MiningFunction notMapped : NOT_MAPPED_MINING_FUNCTIONS) {
            Assertions.assertThat(KiePMMLUtil.getTargetDataType(notMapped, MathContext.DOUBLE)).isNull();
        }
    }

    @Test
    void getTargetOpType() {
        Assertions.assertThat(KiePMMLUtil.getTargetOpType(MiningFunction.REGRESSION)).isEqualTo(OpType.CONTINUOUS);
        Assertions.assertThat(KiePMMLUtil.getTargetOpType(MiningFunction.CLASSIFICATION)).isEqualTo(OpType.CATEGORICAL);
        Assertions.assertThat(KiePMMLUtil.getTargetOpType(MiningFunction.CLUSTERING)).isEqualTo(OpType.CATEGORICAL);
        for (MiningFunction notMapped : NOT_MAPPED_MINING_FUNCTIONS) {
            Assertions.assertThat(KiePMMLUtil.getTargetOpType(notMapped)).isNull();
        }
    }

    @Test
    void getTargetMiningField() {
        DataField dataField = new DataField();
        dataField.setName(FieldName.create("FIELD_NAME"));
        MiningField retrieved = KiePMMLUtil.getTargetMiningField(dataField);
        Assertions.assertThat(retrieved.getName().getValue()).isEqualTo(dataField.getName().getValue());
        Assertions.assertThat(retrieved.getUsageType()).isEqualTo(MiningField.UsageType.TARGET);
    }

    @Test
    void correctTargetFields() {
        MiningField miningField = new MiningField(FieldName.create("FIELD_NAME"));
        String targetName = "TARGET_NAME";
        Target namedTarget = new Target();
        namedTarget.setField(FieldName.create(targetName));
        Target unnamedTarget = new Target();
        Targets targets = new Targets();
        targets.addTargets(namedTarget, unnamedTarget);

        KiePMMLUtil.correctTargetFields(miningField, targets);

        // An already-named Target keeps its field; an unnamed one is bound to the given MiningField.
        Assertions.assertThat(namedTarget.getField().getValue()).isEqualTo(targetName);
        Assertions.assertThat(unnamedTarget.getField()).isEqualTo(miningField.getName());
    }

    @Test
    void populateCorrectMiningModel() throws Exception {
        MiningModel miningModel = firstModelAsMiningModel(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE);
        for (Segment segment : miningModel.getSegmentation().getSegments()) {
            Assertions.assertThat(segment.getId()).isNull();
            Assertions.assertThat(segment.getModel().getModelName()).isNull();
            Assertions.assertThat(KiePMMLUtil.getMiningTargetFields(segment.getModel().getMiningSchema())).isEmpty();
        }

        KiePMMLUtil.populateCorrectMiningModel(miningModel);

        for (Segment segment : miningModel.getSegmentation().getSegments()) {
            Assertions.assertThat(segment.getId()).isNotNull();
            Assertions.assertThat(segment.getModel().getModelName()).isNotNull();
            Assertions.assertThat(KiePMMLUtil.getMiningTargetFields(segment.getModel().getMiningSchema())).isNotEmpty();
        }
    }

    @Test
    void populateCorrectSegmentId() throws Exception {
        MiningModel miningModel = firstModelAsMiningModel(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        Segment toPopulate = miningModel.getSegmentation().getSegments().get(0);
        Assertions.assertThat(toPopulate.getId()).isNull();
        String modelName = "MODEL_NAME";
        int i = 0;

        KiePMMLUtil.populateCorrectSegmentId(toPopulate, modelName, i);

        String expected = String.format("%sSegment%s", modelName, i);
        Assertions.assertThat(toPopulate.getId()).isNotNull().isEqualTo(expected);
    }

    @Test
    void populateMissingSegmentModelName() throws Exception {
        MiningModel miningModel = firstModelAsMiningModel(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        Model toPopulate = miningModel.getSegmentation().getSegments().get(0).getModel();
        Assertions.assertThat(toPopulate.getModelName()).isNull();
        String segmentId = "SEG_ID";

        KiePMMLUtil.populateMissingSegmentModelName(toPopulate, segmentId);

        String expected = String.format("Segment%s%s", segmentId, toPopulate.getClass().getSimpleName());
        Assertions.assertThat(toPopulate.getModelName()).isNotNull().isEqualTo(expected);
    }

    @Test
    void populateMissingTargetFieldInSegment() throws Exception {
        MiningModel miningModel = firstModelAsMiningModel(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE);
        Model toPopulate = miningModel.getSegmentation().getSegments().get(0).getModel();
        Assertions.assertThat(KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema())).isEmpty();

        KiePMMLUtil.populateMissingTargetFieldInSegment(miningModel.getMiningSchema(), toPopulate);

        // Every target field of the parent model must now also be a target field of the segment model.
        List<MiningField> childrenTargetFields = KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema());
        Assertions.assertThat(childrenTargetFields)
                .isNotEmpty()
                .containsAll(KiePMMLUtil.getMiningTargetFields(miningModel.getMiningSchema()));
    }

    @Test
    void populateMissingOutputFieldDataType() {
        Random random = new Random();
        // Six DataFields with random names and random data types.
        List<DataField> dataFields = IntStream.range(0, 6)
                .mapToObj(i -> {
                    DataField toReturn = new DataField();
                    toReturn.setName(FieldName.create(randomName()));
                    toReturn.setDataType(randomDataType(random));
                    return toReturn;
                })
                .collect(Collectors.toList());

        // All DataFields but the last become ACTIVE MiningFields; the last one becomes the TARGET.
        // toCollection(ArrayList::new) because the list is appended to below and Collectors.toList()
        // gives no guarantee of mutability.
        List<MiningField> miningFields = dataFields.subList(0, dataFields.size() - 1).stream()
                .map(dataField -> {
                    MiningField toReturn = new MiningField();
                    toReturn.setName(FieldName.create(dataField.getName().getValue()));
                    toReturn.setUsageType(MiningField.UsageType.ACTIVE);
                    return toReturn;
                })
                .collect(Collectors.toCollection(ArrayList::new));
        DataField lastDataField = dataFields.get(dataFields.size() - 1);
        MiningField targetMiningField = new MiningField();
        targetMiningField.setName(FieldName.create(lastDataField.getName().getValue()));
        targetMiningField.setUsageType(MiningField.UsageType.TARGET);
        miningFields.add(targetMiningField);

        // OutputFields without a data type: three PROBABILITY ones, one PREDICTED_VALUE one,
        // and one explicitly targeting the target MiningField.
        List<OutputField> outputFields = IntStream.range(0, 3)
                .mapToObj(i -> {
                    OutputField toReturn = new OutputField();
                    toReturn.setName(FieldName.create(randomName()));
                    toReturn.setResultFeature(ResultFeature.PROBABILITY);
                    return toReturn;
                })
                .collect(Collectors.toCollection(ArrayList::new));
        OutputField targetOutputField = new OutputField();
        targetOutputField.setName(FieldName.create(randomName()));
        targetOutputField.setResultFeature(ResultFeature.PREDICTED_VALUE);
        outputFields.add(targetOutputField);
        OutputField targetingOutputField = new OutputField();
        targetingOutputField.setName(FieldName.create(randomName()));
        targetingOutputField.setTargetField(FieldName.create(targetMiningField.getName().getValue()));
        outputFields.add(targetingOutputField);
        for (OutputField outputField : outputFields) {
            Assertions.assertThat(outputField.getDataType()).isNull();
        }

        // Two more OutputFields that already carry a data type.
        for (int i = 0; i < 2; i++) {
            OutputField toAdd = new OutputField();
            toAdd.setName(FieldName.create(randomName()));
            toAdd.setDataType(randomDataType(random));
            outputFields.add(toAdd);
        }

        KiePMMLUtil.populateMissingOutputFieldDataType(outputFields, miningFields, dataFields);

        for (OutputField outputField : outputFields) {
            Assertions.assertThat(outputField.getDataType()).isNotNull();
        }
    }

    @Test
    void getSanitizedId() {
        String modelName = "MODEL_NAME";
        for (String id : Arrays.asList("2", "34.5", "3,45")) {
            String expected = String.format("%sSegment%s", modelName, id);
            Assertions.assertThat(KiePMMLUtil.getSanitizedId(id, modelName)).isEqualTo(expected);
        }
    }

    @Test
    void getMiningTargetFieldsFromMiningSchema() throws Exception {
        Model model = unmarshalResource(NO_MODELNAME_SAMPLE_NAME).getModels().get(0);
        assertSingleCarLocationTarget(KiePMMLUtil.getMiningTargetFields(model.getMiningSchema()));
    }

    @Test
    void getMiningTargetFieldsFromMiningFields() throws Exception {
        Model model = unmarshalResource(NO_MODELNAME_SAMPLE_NAME).getModels().get(0);
        assertSingleCarLocationTarget(KiePMMLUtil.getMiningTargetFields(model.getMiningSchema().getMiningFields()));
    }

    /**
     * Opens the given resource, unmarshals it with {@link PMMLUtil#unmarshal} and always closes the stream.
     */
    private static PMML unmarshalResource(String fileName) throws Exception {
        try (FileInputStream inputStream = FileUtils.getFileInputStream(fileName)) {
            return PMMLUtil.unmarshal(inputStream);
        }
    }

    /**
     * Unmarshals the given resource and returns its first model, asserting that it is a {@link MiningModel}.
     */
    private static MiningModel firstModelAsMiningModel(String fileName) throws Exception {
        Model retrieved = unmarshalResource(fileName).getModels().get(0);
        Assertions.assertThat(retrieved).isInstanceOf(MiningModel.class);
        return (MiningModel) retrieved;
    }

    /**
     * Asserts that the given list holds exactly one MiningField, named "car_location" with usage type "target".
     */
    private static void assertSingleCarLocationTarget(List<MiningField> retrieved) {
        Assertions.assertThat(retrieved).isNotNull().hasSize(1);
        MiningField targetField = retrieved.get(0);
        Assertions.assertThat(targetField.getName().getValue()).isEqualTo("car_location");
        Assertions.assertThat(targetField.getUsageType().value()).isEqualTo("target");
    }

    /** Returns a random 6-letter alphabetic name. */
    private static String randomName() {
        return RandomStringUtils.random(6, true, false);
    }

    /** Returns a uniformly chosen {@link DataType} constant. */
    private static DataType randomDataType(Random random) {
        DataType[] values = DataType.values();
        return values[random.nextInt(values.length)];
    }

    /**
     * Reads the whole resource as UTF-8 text, loads it via {@link KiePMMLUtil#load(String)} and validates the result.
     */
    private void commonLoadString(String fileName) throws IOException, JAXBException, SAXException {
        StringBuilder textBuilder = new StringBuilder();
        try (FileInputStream inputStream = FileUtils.getFileInputStream(fileName);
             Reader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            char[] buffer = new char[8192];
            int read;
            while ((read = reader.read(buffer)) != -1) {
                textBuilder.append(buffer, 0, read);
            }
        }
        PMML retrieved = KiePMMLUtil.load(textBuilder.toString());
        commonValidatePMML(retrieved);
    }

    /**
     * Loads the resource via {@link KiePMMLUtil#load(InputStream, String)} and validates the result;
     * the stream is closed once loading completes.
     */
    private void commonLoadFile(String fileName) throws IOException, JAXBException, SAXException {
        try (FileInputStream inputStream = FileUtils.getFileInputStream(fileName)) {
            PMML retrieved = KiePMMLUtil.load(inputStream, fileName);
            commonValidatePMML(retrieved);
        }
    }

    /**
     * Asserts that every top-level model has a name, recursing into {@link MiningModel}s.
     */
    private void commonValidatePMML(PMML toValidate) {
        Assertions.assertThat(toValidate).isNotNull();
        for (Model model : toValidate.getModels()) {
            Assertions.assertThat(model.getModelName()).isNotNull();
            if (model instanceof MiningModel) {
                commonValidateMiningModel((MiningModel) model);
            }
        }
    }

    /**
     * Asserts that every segment has an id and a named model (recursing into nested {@link MiningModel}s),
     * and that the segment model names are unique within this mining model.
     */
    private void commonValidateMiningModel(MiningModel toValidate) {
        Assertions.assertThat(toValidate).isNotNull();
        for (Segment segment : toValidate.getSegmentation().getSegments()) {
            Assertions.assertThat(segment.getId()).isNotNull();
            Model segmentModel = segment.getModel();
            Assertions.assertThat(segmentModel.getModelName()).isNotNull();
            if (segmentModel instanceof MiningModel) {
                commonValidateMiningModel((MiningModel) segmentModel);
            }
        }
        List<String> modelNames = toValidate.getSegmentation().getSegments().stream()
                .map(segment -> segment.getModel().getModelName())
                .collect(Collectors.toList());
        Assertions.assertThat(modelNames).doesNotHaveDuplicates();
    }
}

