package org.kie.pmml.compiler.commons.utils;

import jakarta.xml.bind.JAXBException;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.RandomStringUtils;
import org.assertj.core.api.Assertions;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.drools.util.FileUtils;
import org.jpmml.model.PMMLUtil;
import org.junit.jupiter.api.Test;
import org.xml.sax.SAXException;

/**
 * Unit tests for {@code KiePMMLUtil}, driven by sample PMML files resolved through {@code FileUtils}.
 */
public class KiePMMLUtilTest {
    private static final String NO_MODELNAME_SAMPLE_NAME = "NoModelNameSample.pmml";
    private static final String NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME = "NoModelNameNoSegmentIdSample.pmml";
    private static final String NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE = "NoModelNameNoSegmentIdNoSegmentTargetFieldSample.pmml";
    private static final String NO_TARGET_FIELD_SAMPLE = "NoTargetFieldSample.pmml";
    private static final String NO_OUTPUT_FIELD_TARGET_NAME_SAMPLE = "NoOutputFieldTargetNameSample.pmml";
    private static final String MINING_WITH_SAME_NESTED_MODEL_NAMES = "MiningWithSameNestedModelNames.pmml";

    /**
     * Runs the String-based load check against each of the three sample files, in order.
     */
    @Test
    void loadString() throws IOException, JAXBException, SAXException {
        for (String sampleName : Arrays.asList(NO_MODELNAME_SAMPLE_NAME,
                                               NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME,
                                               MINING_WITH_SAME_NESTED_MODEL_NAMES)) {
            commonLoadString(sampleName);
        }
    }

    /**
     * Runs the stream-based load check against each of the three sample files, in order.
     */
    @Test
    void loadFile() throws JAXBException, IOException, SAXException {
        for (String sampleName : Arrays.asList(NO_MODELNAME_SAMPLE_NAME,
                                               NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME,
                                               MINING_WITH_SAME_NESTED_MODEL_NAMES)) {
            commonLoadFile(sampleName);
        }
    }

    /**
     * Checks that {@code KiePMMLUtil.populateMissingModelName} gives a name to a model that has none.
     * <p>
     * The expected name is the file name, the model's simple class name and the index, concatenated.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void populateMissingModelName() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(NO_MODELNAME_SAMPLE_NAME)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final Model toPopulate = pmml.getModels().get(0);
        Assertions.assertThat(toPopulate.getModelName()).isNull();

        KiePMMLUtil.populateMissingModelName(toPopulate, NO_MODELNAME_SAMPLE_NAME, 0);

        final String expected = String.format("%s%s%s",
                                              NO_MODELNAME_SAMPLE_NAME,
                                              toPopulate.getClass().getSimpleName(),
                                              0);
        Assertions.assertThat(toPopulate.getModelName()).isNotNull();
        Assertions.assertThat(toPopulate.getModelName()).isEqualTo(expected);
    }

    /**
     * Checks that {@code KiePMMLUtil.populateMissingMiningTargetField} adds a target mining field to a model
     * that starts without one.
     * <p>
     * After the call exactly one target mining field is expected; its name must match one of the data dictionary
     * fields and must also be set on the model's first {@code Target}.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void populateMissingMiningTargetField() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(NO_TARGET_FIELD_SAMPLE)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final Model toPopulate = pmml.getModels().get(0);

        // Precondition: neither the mining schema nor the first Target names a target field.
        Assertions.assertThat(KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields()))
                .isEmpty();
        Assertions.assertThat(toPopulate.getTargets().getTargets().get(0).getField()).isNull();

        KiePMMLUtil.populateMissingMiningTargetField(toPopulate, pmml.getDataDictionary().getDataFields());

        final List<MiningField> miningTargetFields =
                KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
        Assertions.assertThat(miningTargetFields).hasSize(1);
        final MiningField targetField = miningTargetFields.get(0);
        final boolean declaredInDictionary = pmml.getDataDictionary().getDataFields().stream()
                .anyMatch(dataField -> dataField.getName().equals(targetField.getName()));
        Assertions.assertThat(declaredInDictionary).isTrue();
        Assertions.assertThat(toPopulate.getTargets().getTargets().get(0).getField())
                .isEqualTo(targetField.getName());
    }

    /**
     * Checks that {@code KiePMMLUtil.populateMissingPredictedOutputFieldTarget} sets the target field of a
     * predicted-value output field that starts without one.
     * <p>
     * The expected target is the name of the model's first target mining field.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void populateMissingPredictedOutputFieldTarget() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(NO_OUTPUT_FIELD_TARGET_NAME_SAMPLE)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final Model toPopulate = pmml.getModels().get(0);
        final OutputField outputField = toPopulate.getOutput().getOutputFields().get(0);

        // Precondition: a PREDICTED_VALUE output field with no target set.
        Assertions.assertThat(outputField.getResultFeature()).isEqualTo(ResultFeature.PREDICTED_VALUE);
        Assertions.assertThat(outputField.getTargetField()).isNull();

        KiePMMLUtil.populateMissingPredictedOutputFieldTarget(toPopulate);

        final MiningField targetField =
                KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields()).get(0);
        Assertions.assertThat(outputField.getTargetField()).isNotNull();
        Assertions.assertThat(outputField.getTargetField()).isEqualTo(targetField.getName());
    }

    /**
     * Checks that {@code KiePMMLUtil.getTargetDataField} returns a data field named {@code "targetgolfing"}
     * for the first model of the no-target-field sample.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void getTargetDataField() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(NO_TARGET_FIELD_SAMPLE)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final Optional<DataField> targetDataField = KiePMMLUtil.getTargetDataField(pmml.getModels().get(0));
        Assertions.assertThat(targetDataField).isPresent();
        Assertions.assertThat(targetDataField.get().getName()).isEqualTo(String.format("target%s", "golfing"));
    }

    /**
     * Pins the data type returned by {@code KiePMMLUtil.getTargetDataType} for each mining function:
     * regression follows the math context, classification and clustering give {@code STRING},
     * and the remaining functions checked here give {@code null}.
     */
    @Test
    void getTargetDataType() {
        Assertions.assertThat(KiePMMLUtil.getTargetDataType(MiningFunction.REGRESSION, MathContext.DOUBLE))
                .isEqualTo(DataType.DOUBLE);
        Assertions.assertThat(KiePMMLUtil.getTargetDataType(MiningFunction.REGRESSION, MathContext.FLOAT))
                .isEqualTo(DataType.FLOAT);
        Assertions.assertThat(KiePMMLUtil.getTargetDataType(MiningFunction.CLASSIFICATION, MathContext.FLOAT))
                .isEqualTo(DataType.STRING);
        Assertions.assertThat(KiePMMLUtil.getTargetDataType(MiningFunction.CLUSTERING, MathContext.FLOAT))
                .isEqualTo(DataType.STRING);
        for (MiningFunction withoutDataType : Arrays.asList(MiningFunction.ASSOCIATION_RULES,
                                                            MiningFunction.MIXED,
                                                            MiningFunction.SEQUENCES,
                                                            MiningFunction.TIME_SERIES)) {
            Assertions.assertThat(KiePMMLUtil.getTargetDataType(withoutDataType, MathContext.DOUBLE)).isNull();
        }
    }

    /**
     * Pins the op type returned by {@code KiePMMLUtil.getTargetOpType} for each mining function:
     * regression gives {@code CONTINUOUS}, classification and clustering give {@code CATEGORICAL},
     * and the remaining functions checked here give {@code null}.
     */
    @Test
    void getTargetOpType() {
        final OpType regressionOpType = KiePMMLUtil.getTargetOpType(MiningFunction.REGRESSION);
        Assertions.assertThat(regressionOpType).isEqualTo(OpType.CONTINUOUS);
        final OpType classificationOpType = KiePMMLUtil.getTargetOpType(MiningFunction.CLASSIFICATION);
        Assertions.assertThat(classificationOpType).isEqualTo(OpType.CATEGORICAL);
        final OpType clusteringOpType = KiePMMLUtil.getTargetOpType(MiningFunction.CLUSTERING);
        Assertions.assertThat(clusteringOpType).isEqualTo(OpType.CATEGORICAL);
        for (MiningFunction withoutOpType : Arrays.asList(MiningFunction.ASSOCIATION_RULES,
                                                          MiningFunction.MIXED,
                                                          MiningFunction.SEQUENCES,
                                                          MiningFunction.TIME_SERIES)) {
            Assertions.assertThat(KiePMMLUtil.getTargetOpType(withoutOpType)).isNull();
        }
    }

    /**
     * Checks that {@code KiePMMLUtil.getTargetMiningField} returns a mining field that carries the
     * data field's name and the {@code TARGET} usage type.
     */
    @Test
    void getTargetMiningField() {
        final DataField source = new DataField();
        source.setName("FIELD_NAME");

        final MiningField retrieved = KiePMMLUtil.getTargetMiningField(source);

        Assertions.assertThat(retrieved.getName()).isEqualTo(source.getName());
        Assertions.assertThat(retrieved.getUsageType()).isEqualTo(MiningField.UsageType.TARGET);
    }

    /**
     * Checks that {@code KiePMMLUtil.correctTargetFields} leaves an already-named target untouched and
     * sets the mining field's name on a target that has no field.
     */
    @Test
    void correctTargetFields() {
        final MiningField miningField = new MiningField("FIELD_NAME");
        final Target namedTarget = new Target();
        namedTarget.setField("TARGET_NAME");
        final Target unnamedTarget = new Target();
        final Targets targets = new Targets();
        targets.addTargets(namedTarget, unnamedTarget);

        KiePMMLUtil.correctTargetFields(miningField, targets);

        Assertions.assertThat(namedTarget.getField()).isEqualTo("TARGET_NAME");
        Assertions.assertThat(unnamedTarget.getField()).isEqualTo(miningField.getName());
    }

    /**
     * Checks that {@code KiePMMLUtil.populateCorrectMiningModel} fills in, for every segment, the segment id,
     * the nested model name and at least one target mining field, all of which are absent beforehand.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void populateCorrectMiningModel() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream =
                     FileUtils.getFileInputStream(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final Model retrieved = pmml.getModels().get(0);
        Assertions.assertThat(retrieved).isInstanceOf(MiningModel.class);
        // Downcast only after the type has been asserted.
        final MiningModel miningModel = (MiningModel) retrieved;

        for (Segment segment : miningModel.getSegmentation().getSegments()) {
            Assertions.assertThat(segment.getId()).isNull();
            Assertions.assertThat(segment.getModel().getModelName()).isNull();
            Assertions.assertThat(KiePMMLUtil.getMiningTargetFields(segment.getModel().getMiningSchema()))
                    .isEmpty();
        }

        KiePMMLUtil.populateCorrectMiningModel(miningModel);

        for (Segment segment : miningModel.getSegmentation().getSegments()) {
            Assertions.assertThat(segment.getId()).isNotNull();
            Assertions.assertThat(segment.getModel().getModelName()).isNotNull();
            Assertions.assertThat(KiePMMLUtil.getMiningTargetFields(segment.getModel().getMiningSchema()))
                    .isNotEmpty();
        }
    }

    /**
     * Checks that {@code KiePMMLUtil.populateCorrectSegmentId} assigns an id to a segment that has none.
     * <p>
     * The expected id is {@code <modelName>Segment<index>}.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void populateCorrectSegmentId() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final Model retrieved = pmml.getModels().get(0);
        Assertions.assertThat(retrieved).isInstanceOf(MiningModel.class);
        // Downcast only after the type has been asserted.
        final MiningModel miningModel = (MiningModel) retrieved;
        final Segment toPopulate = miningModel.getSegmentation().getSegments().get(0);
        Assertions.assertThat(toPopulate.getId()).isNull();

        KiePMMLUtil.populateCorrectSegmentId(toPopulate, "MODEL_NAME", 0);

        Assertions.assertThat(toPopulate.getId()).isNotNull();
        Assertions.assertThat(toPopulate.getId()).isEqualTo(String.format("%sSegment%s", "MODEL_NAME", 0));
    }

    /**
     * Checks that {@code KiePMMLUtil.populateMissingSegmentModelName} names a segment's nested model
     * that has no name.
     * <p>
     * The expected name is {@code Segment<segmentId><model simple class name>}.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void populateMissingSegmentModelName() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(NO_MODELNAME_NO_SEGMENTID_SAMPLE_NAME)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final Model retrieved = pmml.getModels().get(0);
        Assertions.assertThat(retrieved).isInstanceOf(MiningModel.class);
        // Downcast only after the type has been asserted.
        final MiningModel miningModel = (MiningModel) retrieved;
        final Model toPopulate = miningModel.getSegmentation().getSegments().get(0).getModel();
        Assertions.assertThat(toPopulate.getModelName()).isNull();

        KiePMMLUtil.populateMissingSegmentModelName(toPopulate, "SEG_ID");

        final String expected = String.format("Segment%s%s", "SEG_ID", toPopulate.getClass().getSimpleName());
        Assertions.assertThat(toPopulate.getModelName()).isNotNull();
        Assertions.assertThat(toPopulate.getModelName()).isEqualTo(expected);
    }

    /**
     * Checks that {@code KiePMMLUtil.populateMissingTargetFieldInSegment} adds target mining fields to a
     * segment's nested model that starts without any.
     * <p>
     * Every target field of the enclosing mining model must afterwards appear among the nested model's targets.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void populateMissingTargetFieldInSegment() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream =
                     FileUtils.getFileInputStream(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final Model retrieved = pmml.getModels().get(0);
        Assertions.assertThat(retrieved).isInstanceOf(MiningModel.class);
        // Downcast only after the type has been asserted.
        final MiningModel miningModel = (MiningModel) retrieved;
        final Model toPopulate = miningModel.getSegmentation().getSegments().get(0).getModel();
        Assertions.assertThat(KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema())).isEmpty();

        KiePMMLUtil.populateMissingTargetFieldInSegment(miningModel.getMiningSchema(), toPopulate);

        final List<MiningField> segmentTargetFields =
                KiePMMLUtil.getMiningTargetFields(toPopulate.getMiningSchema());
        Assertions.assertThat(segmentTargetFields).isNotEmpty();
        for (MiningField parentTargetField : KiePMMLUtil.getMiningTargetFields(miningModel.getMiningSchema())) {
            Assertions.assertThat(segmentTargetFields).contains(parentTargetField);
        }
    }

    /**
     * Checks that {@code KiePMMLUtil.populateMissingOutputFieldDataType} leaves every output field with a
     * non-null data type.
     * <p>
     * Fixture: six data fields with random names and data types; one mining field per data field, where the
     * last is the {@code TARGET} and the rest are {@code ACTIVE}; and seven output fields, namely three
     * {@code PROBABILITY} fields, one {@code PREDICTED_VALUE} field, one that references the target mining
     * field, and two that already have a data type.
     */
    @Test
    void populateMissingOutputFieldDataType() {
        final Random random = new Random();
        final DataType[] dataTypes = DataType.values();

        final List<DataField> dataFields = IntStream.range(0, 6)
                .mapToObj(i -> RandomStringUtils.random(6, true, false))
                .map(fieldName -> {
                    DataField toReturn = new DataField();
                    toReturn.setName(fieldName);
                    toReturn.setDataType(dataTypes[random.nextInt(dataTypes.length)]);
                    return toReturn;
                })
                .collect(Collectors.toList());

        // All data fields but the last become ACTIVE mining fields...
        final List<MiningField> miningFields = dataFields.subList(0, dataFields.size() - 1).stream()
                .map(dataField -> {
                    MiningField toReturn = new MiningField();
                    toReturn.setName(dataField.getName());
                    toReturn.setUsageType(MiningField.UsageType.ACTIVE);
                    return toReturn;
                })
                .collect(Collectors.toList());
        // ...and the last one becomes the TARGET mining field.
        final MiningField targetMiningField = new MiningField();
        targetMiningField.setName(dataFields.get(dataFields.size() - 1).getName());
        targetMiningField.setUsageType(MiningField.UsageType.TARGET);
        miningFields.add(targetMiningField);

        // Output fields without a data type: three PROBABILITY ones...
        final List<OutputField> outputFields = IntStream.range(0, 3)
                .mapToObj(i -> {
                    OutputField toReturn = new OutputField();
                    toReturn.setName(RandomStringUtils.random(6, true, false));
                    toReturn.setResultFeature(ResultFeature.PROBABILITY);
                    return toReturn;
                })
                .collect(Collectors.toList());
        // ...one PREDICTED_VALUE...
        final OutputField predictedOutputField = new OutputField();
        predictedOutputField.setName(RandomStringUtils.random(6, true, false));
        predictedOutputField.setResultFeature(ResultFeature.PREDICTED_VALUE);
        outputFields.add(predictedOutputField);
        // ...and one that points at the target mining field.
        final OutputField targetingOutputField = new OutputField();
        targetingOutputField.setName(RandomStringUtils.random(6, true, false));
        targetingOutputField.setTargetField(targetMiningField.getName());
        outputFields.add(targetingOutputField);
        outputFields.forEach(outputField -> Assertions.assertThat(outputField.getDataType()).isNull());

        // Two more output fields that already carry a data type.
        IntStream.range(0, 2).forEach(i -> {
            OutputField toAdd = new OutputField();
            toAdd.setName(RandomStringUtils.random(6, true, false));
            toAdd.setDataType(dataTypes[random.nextInt(dataTypes.length)]);
            outputFields.add(toAdd);
        });

        KiePMMLUtil.populateMissingOutputFieldDataType(outputFields, miningFields, dataFields);

        outputFields.forEach(outputField -> Assertions.assertThat(outputField.getDataType()).isNotNull());
    }

    /**
     * Checks that {@code KiePMMLUtil.getSanitizedId} returns {@code <modelName>Segment<id>} for an
     * integer-like, a dotted and a comma-separated id.
     */
    @Test
    void getSanitizedId() {
        final String modelName = "MODEL_NAME";
        for (String id : Arrays.asList("2", "34.5", "3,45")) {
            final String expected = String.format("%sSegment%s", modelName, id);
            Assertions.assertThat(KiePMMLUtil.getSanitizedId(id, modelName)).isEqualTo(expected);
        }
    }

    /**
     * Checks that {@code KiePMMLUtil.getMiningTargetFields}, given a mining schema, returns the single
     * target field {@code car_location} of the sample's first model.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void getMiningTargetFieldsFromMiningSchema() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(NO_MODELNAME_SAMPLE_NAME)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final List<MiningField> miningTargetFields =
                KiePMMLUtil.getMiningTargetFields(pmml.getModels().get(0).getMiningSchema());
        Assertions.assertThat(miningTargetFields).isNotNull();
        Assertions.assertThat(miningTargetFields).hasSize(1);
        final MiningField targetField = miningTargetFields.get(0);
        Assertions.assertThat(targetField.getName()).isEqualTo("car_location");
        Assertions.assertThat(targetField.getUsageType().value()).isEqualTo("target");
    }

    /**
     * Checks that {@code KiePMMLUtil.getMiningTargetFields}, given a list of mining fields, returns the
     * single target field {@code car_location} of the sample's first model.
     *
     * @throws Exception if the sample file cannot be read or unmarshalled
     */
    @Test
    void getMiningTargetFieldsFromMiningFields() throws Exception {
        final PMML pmml;
        // Release the file handle as soon as the document has been parsed.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(NO_MODELNAME_SAMPLE_NAME)) {
            pmml = PMMLUtil.unmarshal(inputStream);
        }
        final List<MiningField> miningTargetFields =
                KiePMMLUtil.getMiningTargetFields(pmml.getModels().get(0).getMiningSchema().getMiningFields());
        Assertions.assertThat(miningTargetFields).isNotNull();
        Assertions.assertThat(miningTargetFields).hasSize(1);
        final MiningField targetField = miningTargetFields.get(0);
        Assertions.assertThat(targetField.getName()).isEqualTo("car_location");
        Assertions.assertThat(targetField.getUsageType().value()).isEqualTo("target");
    }

    /**
     * Reads the named sample file as UTF-8 text, loads it through {@code KiePMMLUtil.load(String)} and
     * validates the resulting PMML.
     *
     * @param str name of the sample file, resolved through {@code FileUtils.getFileInputStream}
     * @throws IOException   if the file cannot be read
     * @throws JAXBException declared for the load call
     * @throws SAXException  declared for the load call
     */
    private void commonLoadString(String str) throws IOException, JAXBException, SAXException {
        final StringBuilder content = new StringBuilder();
        // Both resources are declared so the file stream is closed even if wrapping it fails.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(str);
             BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            final char[] buffer = new char[4096];
            int read;
            while ((read = reader.read(buffer)) != -1) {
                content.append(buffer, 0, read);
            }
        }
        commonValidatePMML(KiePMMLUtil.load(content.toString()));
    }

    /**
     * Loads the named sample file through {@code KiePMMLUtil.load(InputStream, String)} and validates the
     * resulting PMML.
     *
     * @param str name of the sample file; also passed to {@code load} as its second argument
     * @throws IOException   if the file cannot be opened or closed
     * @throws JAXBException declared for the load call
     * @throws SAXException  declared for the load call
     */
    private void commonLoadFile(String str) throws IOException, JAXBException, SAXException {
        final PMML retrieved;
        // The test opens the stream, so the test closes it; a second close would be harmless.
        try (FileInputStream inputStream = FileUtils.getFileInputStream(str)) {
            retrieved = KiePMMLUtil.load(inputStream, str);
        }
        commonValidatePMML(retrieved);
    }

    /**
     * Asserts that the PMML is non-null and that each of its models has a name; mining models are
     * additionally checked by {@code commonValidateMiningModel}.
     *
     * @param pmml the loaded document to validate
     */
    private void commonValidatePMML(PMML pmml) {
        Assertions.assertThat(pmml).isNotNull();
        pmml.getModels().forEach(model -> {
            Assertions.assertThat(model.getModelName()).isNotNull();
            if (!(model instanceof MiningModel)) {
                return;
            }
            commonValidateMiningModel((MiningModel) model);
        });
    }

    /**
     * Asserts that every segment of the mining model has an id and a named nested model, recursing into
     * nested mining models, and that the nested model names are unique across this model's direct segments.
     *
     * @param miningModel the mining model to validate
     */
    private void commonValidateMiningModel(MiningModel miningModel) {
        Assertions.assertThat(miningModel).isNotNull();
        final List<Segment> segments = miningModel.getSegmentation().getSegments();
        for (Segment segment : segments) {
            Assertions.assertThat(segment.getId()).isNotNull();
            final Model segmentModel = segment.getModel();
            Assertions.assertThat(segmentModel.getModelName()).isNotNull();
            if (segmentModel instanceof MiningModel) {
                commonValidateMiningModel((MiningModel) segmentModel);
            }
        }
        final List<String> modelNames = segments.stream()
                .map(segment -> segment.getModel().getModelName())
                .collect(Collectors.toList());
        // Same check as comparing the distinct count with the size, but a failure lists the duplicates.
        Assertions.assertThat(modelNames).doesNotHaveDuplicates();
    }
}
