package org.kie.pmml.compiler.commons.utils;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.xml.bind.JAXBException;
import org.apache.commons.lang3.RandomStringUtils;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.jpmml.model.PMMLUtil;
import org.junit.Assert;
import org.junit.Test;
import org.kie.test.util.filesystem.FileUtils;
import org.xml.sax.SAXException;

/* loaded from: input_file:org/kie/pmml/compiler/commons/utils/KiePMMLUtilTest.class */
public class KiePMMLUtilTest {
    private static final String NO_MODELNAME_SAMPLE_NAME = "NoModelNameSample.pmml";
    private static final String NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME = "NoModelNameNoSegmentIdSample.pmml";
    private static final String NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE = "NoModelNameNoSegmentIdNoSegmentTargetFieldSample.pmml";
    private static final String NO_TARGET_FIELD_SAMPLE = "NoTargetFieldSample.pmml";
    private static final String MINING_WITH_SAME_NESTED_MODEL_NAMES = "MiningWithSameNestedModelNames.pmml";

    @Test
    public void loadString() throws IOException, JAXBException, SAXException {
        commonLoadString(NO_MODELNAME_SAMPLE_NAME);
        commonLoadString(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        commonLoadString(MINING_WITH_SAME_NESTED_MODEL_NAMES);
    }

    @Test
    public void loadFile() throws JAXBException, IOException, SAXException {
        commonLoadFile(NO_MODELNAME_SAMPLE_NAME);
        commonLoadFile(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME);
        commonLoadFile(MINING_WITH_SAME_NESTED_MODEL_NAMES);
    }

    @Test
    public void populateMissingModelName() throws Exception {
        Model model = (Model) PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_MODELNAME_SAMPLE_NAME)).getModels().get(0);
        Assert.assertNull(model.getModelName());
        KiePMMLUtil.populateMissingModelName(model, NO_MODELNAME_SAMPLE_NAME, 0);
        Assert.assertNotNull(model.getModelName());
        Assert.assertEquals(String.format("%s%s%s", NO_MODELNAME_SAMPLE_NAME, model.getClass().getSimpleName(), 0), model.getModelName());
    }

    @Test
    public void populateMissingMiningTargetField() throws Exception {
        PMML unmarshal = PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_TARGET_FIELD_SAMPLE));
        Model model = (Model) unmarshal.getModels().get(0);
        Assert.assertTrue(KiePMMLUtil.getMiningTargetFields(model.getMiningSchema().getMiningFields()).isEmpty());
        Assert.assertNull(((Target) model.getTargets().getTargets().get(0)).getField());
        KiePMMLUtil.populateMissingMiningTargetField(model, unmarshal.getDataDictionary().getDataFields());
        List miningTargetFields = KiePMMLUtil.getMiningTargetFields(model.getMiningSchema().getMiningFields());
        Assert.assertEquals(1L, miningTargetFields.size());
        MiningField miningField = (MiningField) miningTargetFields.get(0);
        Assert.assertTrue(unmarshal.getDataDictionary().getDataFields().stream().anyMatch(dataField -> {
            return dataField.getName().equals(miningField.getName());
        }));
        Assert.assertEquals(miningField.getName(), ((Target) model.getTargets().getTargets().get(0)).getField());
    }

    @Test
    public void getTargetDataField() throws Exception {
        Optional targetDataField = KiePMMLUtil.getTargetDataField((Model) PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_TARGET_FIELD_SAMPLE)).getModels().get(0));
        Assert.assertTrue(targetDataField.isPresent());
        Assert.assertEquals(String.format("target%s", "golfing"), ((DataField) targetDataField.get()).getName().getValue());
    }

    @Test
    public void getTargetDataType() {
        MiningFunction miningFunction = MiningFunction.REGRESSION;
        Assert.assertEquals(DataType.DOUBLE, KiePMMLUtil.getTargetDataType(miningFunction, MathContext.DOUBLE));
        MathContext mathContext = MathContext.FLOAT;
        Assert.assertEquals(DataType.FLOAT, KiePMMLUtil.getTargetDataType(miningFunction, mathContext));
        Assert.assertEquals(DataType.STRING, KiePMMLUtil.getTargetDataType(MiningFunction.CLASSIFICATION, mathContext));
        Assert.assertEquals(DataType.STRING, KiePMMLUtil.getTargetDataType(MiningFunction.CLUSTERING, mathContext));
        Arrays.asList(MiningFunction.ASSOCIATION_RULES, MiningFunction.MIXED, MiningFunction.SEQUENCES, MiningFunction.TIME_SERIES).forEach(miningFunction2 -> {
            Assert.assertNull(KiePMMLUtil.getTargetDataType(miningFunction2, MathContext.DOUBLE));
        });
    }

    @Test
    public void getTargetOpType() {
        Assert.assertEquals(OpType.CONTINUOUS, KiePMMLUtil.getTargetOpType(MiningFunction.REGRESSION));
        Assert.assertEquals(OpType.CATEGORICAL, KiePMMLUtil.getTargetOpType(MiningFunction.CLASSIFICATION));
        Assert.assertEquals(OpType.CATEGORICAL, KiePMMLUtil.getTargetOpType(MiningFunction.CLUSTERING));
        Arrays.asList(MiningFunction.ASSOCIATION_RULES, MiningFunction.MIXED, MiningFunction.SEQUENCES, MiningFunction.TIME_SERIES).forEach(miningFunction -> {
            Assert.assertNull(KiePMMLUtil.getTargetOpType(miningFunction));
        });
    }

    @Test
    public void getTargetMiningField() {
        DataField dataField = new DataField();
        dataField.setName(FieldName.create("FIELD_NAME"));
        MiningField targetMiningField = KiePMMLUtil.getTargetMiningField(dataField);
        Assert.assertEquals(dataField.getName().getValue(), targetMiningField.getName().getValue());
        Assert.assertEquals(MiningField.UsageType.TARGET, targetMiningField.getUsageType());
    }

    @Test
    public void correctTargetFields() {
        MiningField miningField = new MiningField(FieldName.create("FIELD_NAME"));
        Targets targets = new Targets();
        Target target = new Target();
        target.setField(FieldName.create("TARGET_NAME"));
        Target target2 = new Target();
        targets.addTargets(new Target[]{target, target2});
        KiePMMLUtil.correctTargetFields(miningField, targets);
        Assert.assertEquals("TARGET_NAME", target.getField().getValue());
        Assert.assertEquals(miningField.getName(), target2.getField());
    }

    @Test
    public void populateCorrectMiningModel() throws Exception {
        MiningModel miningModel = (Model) PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE)).getModels().get(0);
        Assert.assertTrue(miningModel instanceof MiningModel);
        MiningModel miningModel2 = miningModel;
        miningModel2.getSegmentation().getSegments().forEach(segment -> {
            Assert.assertNull(segment.getId());
            Assert.assertNull(segment.getModel().getModelName());
            Assert.assertTrue(KiePMMLUtil.getMiningTargetFields(segment.getModel().getMiningSchema()).isEmpty());
        });
        KiePMMLUtil.populateCorrectMiningModel(miningModel2);
        miningModel2.getSegmentation().getSegments().forEach(segment2 -> {
            Assert.assertNotNull(segment2.getId());
            Assert.assertNotNull(segment2.getModel().getModelName());
            Assert.assertFalse(KiePMMLUtil.getMiningTargetFields(segment2.getModel().getMiningSchema()).isEmpty());
        });
    }

    @Test
    public void populateCorrectSegmentId() throws Exception {
        MiningModel miningModel = (Model) PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME)).getModels().get(0);
        Assert.assertTrue(miningModel instanceof MiningModel);
        Segment segment = (Segment) miningModel.getSegmentation().getSegments().get(0);
        Assert.assertNull(segment.getId());
        KiePMMLUtil.populateCorrectSegmentId(segment, "MODEL_NAME", 0);
        Assert.assertNotNull(segment.getId());
        Assert.assertEquals(String.format("%sSegment%s", "MODEL_NAME", 0), segment.getId());
    }

    @Test
    public void populateMissingSegmentModelName() throws Exception {
        MiningModel miningModel = (Model) PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME)).getModels().get(0);
        Assert.assertTrue(miningModel instanceof MiningModel);
        Model model = ((Segment) miningModel.getSegmentation().getSegments().get(0)).getModel();
        Assert.assertNull(model.getModelName());
        KiePMMLUtil.populateMissingSegmentModelName(model, "SEG_ID");
        Assert.assertNotNull(model.getModelName());
        Assert.assertEquals(String.format("Segment%s%s", "SEG_ID", model.getClass().getSimpleName()), model.getModelName());
    }

    @Test
    public void populateMissingTargetFieldInSegment() throws Exception {
        MiningModel miningModel = (Model) PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE)).getModels().get(0);
        Assert.assertTrue(miningModel instanceof MiningModel);
        MiningModel miningModel2 = miningModel;
        Model model = ((Segment) miningModel2.getSegmentation().getSegments().get(0)).getModel();
        Assert.assertTrue(KiePMMLUtil.getMiningTargetFields(model.getMiningSchema()).isEmpty());
        KiePMMLUtil.populateMissingTargetFieldInSegment(miningModel.getMiningSchema(), model);
        List miningTargetFields = KiePMMLUtil.getMiningTargetFields(model.getMiningSchema());
        Assert.assertFalse(miningTargetFields.isEmpty());
        KiePMMLUtil.getMiningTargetFields(miningModel2.getMiningSchema()).forEach(miningField -> {
            Assert.assertTrue(miningTargetFields.contains(miningField));
        });
    }

    @Test
    public void populateMissingOutputFieldDataType() {
        Random random = new Random();
        List list = (List) ((List) IntStream.range(0, 6).mapToObj(i -> {
            return RandomStringUtils.random(6, true, false);
        }).collect(Collectors.toList())).stream().map(str -> {
            DataField dataField = new DataField();
            dataField.setName(FieldName.create(str));
            dataField.setDataType(DataType.values()[random.nextInt(DataType.values().length)]);
            return dataField;
        }).collect(Collectors.toList());
        IntStream range = IntStream.range(0, list.size() - 1);
        list.getClass();
        List list2 = (List) range.mapToObj(list::get).map(dataField -> {
            MiningField miningField = new MiningField();
            miningField.setName(FieldName.create(dataField.getName().getValue()));
            miningField.setUsageType(MiningField.UsageType.ACTIVE);
            return miningField;
        }).collect(Collectors.toList());
        DataField dataField2 = (DataField) list.get(list.size() - 1);
        MiningField miningField = new MiningField();
        miningField.setName(FieldName.create(dataField2.getName().getValue()));
        miningField.setUsageType(MiningField.UsageType.TARGET);
        list2.add(miningField);
        List list3 = (List) IntStream.range(0, 3).mapToObj(i2 -> {
            OutputField outputField = new OutputField();
            outputField.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
            outputField.setResultFeature(ResultFeature.PROBABILITY);
            return outputField;
        }).collect(Collectors.toList());
        OutputField outputField = new OutputField();
        outputField.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
        outputField.setResultFeature(ResultFeature.PREDICTED_VALUE);
        list3.add(outputField);
        OutputField outputField2 = new OutputField();
        outputField2.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
        outputField2.setTargetField(FieldName.create(miningField.getName().getValue()));
        list3.add(outputField2);
        list3.forEach(outputField3 -> {
            Assert.assertNull(outputField3.getDataType());
        });
        IntStream.range(0, 2).forEach(i3 -> {
            OutputField outputField4 = new OutputField();
            outputField4.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
            outputField4.setDataType(DataType.values()[random.nextInt(DataType.values().length)]);
            list3.add(outputField4);
        });
        KiePMMLUtil.populateMissingOutputFieldDataType(list3, list2, list);
        list3.forEach(outputField4 -> {
            Assert.assertNotNull(outputField4.getDataType());
        });
    }

    @Test
    public void getSanitizedId() {
        Assert.assertEquals(String.format("%sSegment%s", "MODEL_NAME", "2"), KiePMMLUtil.getSanitizedId("2", "MODEL_NAME"));
        Assert.assertEquals(String.format("%sSegment%s", "MODEL_NAME", "34.5"), KiePMMLUtil.getSanitizedId("34.5", "MODEL_NAME"));
        Assert.assertEquals(String.format("%sSegment%s", "MODEL_NAME", "3,45"), KiePMMLUtil.getSanitizedId("3,45", "MODEL_NAME"));
    }

    @Test
    public void getMiningTargetFieldsFromMiningSchema() throws Exception {
        List miningTargetFields = KiePMMLUtil.getMiningTargetFields(((Model) PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_MODELNAME_SAMPLE_NAME)).getModels().get(0)).getMiningSchema());
        Assert.assertNotNull(miningTargetFields);
        Assert.assertEquals(1L, miningTargetFields.size());
        MiningField miningField = (MiningField) miningTargetFields.get(0);
        Assert.assertEquals("car_location", miningField.getName().getValue());
        Assert.assertEquals("target", miningField.getUsageType().value());
    }

    @Test
    public void getMiningTargetFieldsFromMiningFields() throws Exception {
        List miningTargetFields = KiePMMLUtil.getMiningTargetFields(((Model) PMMLUtil.unmarshal(FileUtils.getFileInputStream(NO_MODELNAME_SAMPLE_NAME)).getModels().get(0)).getMiningSchema().getMiningFields());
        Assert.assertNotNull(miningTargetFields);
        Assert.assertEquals(1L, miningTargetFields.size());
        MiningField miningField = (MiningField) miningTargetFields.get(0);
        Assert.assertEquals("car_location", miningField.getName().getValue());
        Assert.assertEquals("target", miningField.getUsageType().value());
    }

    private void commonLoadString(String str) throws IOException, JAXBException, SAXException {
        FileInputStream fileInputStream = FileUtils.getFileInputStream(str);
        StringBuilder sb = new StringBuilder();
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(fileInputStream, Charset.forName(StandardCharsets.UTF_8.name())));
        Throwable th = null;
        while (true) {
            try {
                try {
                    int read = bufferedReader.read();
                    if (read == -1) {
                        break;
                    } else {
                        sb.append((char) read);
                    }
                } finally {
                }
            } catch (Throwable th2) {
                if (bufferedReader != null) {
                    if (th != null) {
                        try {
                            bufferedReader.close();
                        } catch (Throwable th3) {
                            th.addSuppressed(th3);
                        }
                    } else {
                        bufferedReader.close();
                    }
                }
                throw th2;
            }
        }
        if (bufferedReader != null) {
            if (0 != 0) {
                try {
                    bufferedReader.close();
                } catch (Throwable th4) {
                    th.addSuppressed(th4);
                }
            } else {
                bufferedReader.close();
            }
        }
        commonValidatePMML(KiePMMLUtil.load(sb.toString()));
    }

    private void commonLoadFile(String str) throws IOException, JAXBException, SAXException {
        commonValidatePMML(KiePMMLUtil.load(FileUtils.getFileInputStream(str), str));
    }

    private void commonValidatePMML(PMML pmml) {
        Assert.assertNotNull(pmml);
        for (Model model : pmml.getModels()) {
            Assert.assertNotNull(model.getModelName());
            if (model instanceof MiningModel) {
                commonValidateMiningModel((MiningModel) model);
            }
        }
    }

    private void commonValidateMiningModel(MiningModel miningModel) {
        Assert.assertNotNull(miningModel);
        for (Segment segment : miningModel.getSegmentation().getSegments()) {
            Assert.assertNotNull(segment.getId());
            Model model = segment.getModel();
            Assert.assertNotNull(model.getModelName());
            if (model instanceof MiningModel) {
                commonValidateMiningModel((MiningModel) model);
            }
        }
        Assert.assertEquals(r0.size(), ((List) miningModel.getSegmentation().getSegments().stream().map(segment2 -> {
            return segment2.getModel().getModelName();
        }).collect(Collectors.toList())).stream().distinct().count());
    }
}
