package org.kie.pmml.compiler.commons.utils;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.Model;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.infinispan.xsite.GlobalXSiteAdminOperations;
import org.jpmml.model.PMMLUtil;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.xml.sax.SAXException;

/* loaded from: input_file:WEB-INF/lib/kie-pmml-compiler-commons-7.52.1-SNAPSHOT.jar:org/kie/pmml/compiler/commons/utils/KiePMMLUtil.class */
public class KiePMMLUtil {

    /** Format for generated segment ids: model name, then "Segment", then index (or original id). */
    public static final String SEGMENTID_TEMPLATE = "%sSegment%s";
    /** Format for generated top-level model names: file-name prefix, model class simple name, index. */
    static final String MODELNAME_TEMPLATE = "%s%s%s";
    /** Format for generated segment model names: "Segment", segment id, model class simple name. */
    static final String SEGMENTMODELNAME_TEMPLATE = "Segment%s%s";

    private KiePMMLUtil() {
        // utility class: not instantiable
    }

    /**
     * Parses a PMML document held in a string and normalizes it as {@link #load(InputStream, String)} does,
     * using an empty model-name prefix.
     * <p>
     * The text is encoded as UTF-8 before being handed to the XML parser, so the result does not depend on
     * the JVM's platform-default charset.
     *
     * @param source the PMML XML text
     * @return the unmarshalled and normalized {@link PMML}
     * @throws SAXException  if the XML cannot be parsed
     * @throws JAXBException if the XML cannot be unmarshalled into PMML objects
     */
    public static PMML load(String source) throws SAXException, JAXBException {
        return load(new ByteArrayInputStream(source.getBytes(StandardCharsets.UTF_8)), "");
    }

    /**
     * Unmarshals a PMML document and fills in information that later compilation steps rely on.
     * <ul>
     *     <li>Missing model names are generated, using {@code fileName} up to its first '.' as prefix
     *     (see {@link #populateMissingNames(PMML, String)}).</li>
     *     <li>Output fields without a data type get one derived from their target field
     *     (see {@link #populateMissingOutputFieldDataType(List, List, List)}). Only top-level models are
     *     visited for this step.</li>
     * </ul>
     *
     * @param inputStream the PMML XML stream; it is not closed by this method
     * @param fileName    the source file name, used only to build generated model names; {@code null} is
     *                    treated as the empty string
     * @return the unmarshalled and normalized {@link PMML}
     * @throws SAXException  if the XML cannot be parsed
     * @throws JAXBException if the XML cannot be unmarshalled into PMML objects
     */
    public static PMML load(InputStream inputStream, String fileName) throws SAXException, JAXBException {
        final PMML toReturn = PMMLUtil.unmarshal(inputStream);
        populateMissingNames(toReturn, getModelNamePrefix(fileName));
        final List<DataField> dataFields = toReturn.getDataDictionary().getDataFields();
        for (Model model : toReturn.getModels()) {
            if (model.getOutput() != null && model.getOutput().getOutputFields() != null) {
                populateMissingOutputFieldDataType(model.getOutput().getOutputFields(),
                                                   model.getMiningSchema().getMiningFields(),
                                                   dataFields);
            }
        }
        return toReturn;
    }

    /** Returns {@code fileName} truncated before its first '.', or "" when {@code fileName} is null. */
    private static String getModelNamePrefix(String fileName) {
        if (fileName == null) {
            return "";
        }
        final int dotIndex = fileName.indexOf('.');
        return dotIndex >= 0 ? fileName.substring(0, dotIndex) : fileName;
    }

    /**
     * Assigns a generated name to every top-level model whose name is null or empty.
     * Mining models are then normalized recursively.
     * <p>
     * A generated name is {@code <prefix><ModelClassSimpleName><index>}, where index is the model's
     * position in {@link PMML#getModels()}.
     *
     * @param pmml   the document to update in place
     * @param prefix prefix for generated model names
     */
    static void populateMissingNames(PMML pmml, String prefix) {
        final List<Model> models = pmml.getModels();
        for (int i = 0; i < models.size(); i++) {
            final Model model = models.get(i);
            if (model.getModelName() == null || model.getModelName().isEmpty()) {
                model.setModelName(String.format(MODELNAME_TEMPLATE, prefix, model.getClass().getSimpleName(), i));
            }
            if (model instanceof MiningModel) {
                populateCorrectMiningModel((MiningModel) model);
            }
        }
    }

    /**
     * Normalizes the segments of a mining model in place, recursing into nested mining models.
     * <ul>
     *     <li>Segments without an id get {@code <miningModelName>Segment<index>}. Existing ids are passed
     *     through {@link #getSanitizedId(String, String)}.</li>
     *     <li>Segment models without a name get {@code Segment<segmentId><ModelClassSimpleName>}.</li>
     * </ul>
     * A mining model without a segmentation is left untouched.
     *
     * @param miningModel the model to update in place
     */
    static void populateCorrectMiningModel(MiningModel miningModel) {
        if (miningModel.getSegmentation() == null) {
            // defensive: nothing to normalize (previously this path threw NullPointerException)
            return;
        }
        final List<Segment> segments = miningModel.getSegmentation().getSegments();
        for (int i = 0; i < segments.size(); i++) {
            final Segment segment = segments.get(i);
            final String segmentId = segment.getId();
            if (segmentId == null || segmentId.isEmpty()) {
                segment.setId(String.format(SEGMENTID_TEMPLATE, miningModel.getModelName(), i));
            } else {
                segment.setId(getSanitizedId(segmentId, miningModel.getModelName()));
            }
            final Model segmentModel = segment.getModel();
            if (segmentModel.getModelName() == null || segmentModel.getModelName().isEmpty()) {
                segmentModel.setModelName(String.format(SEGMENTMODELNAME_TEMPLATE, segment.getId(),
                                                        segmentModel.getClass().getSimpleName()));
            }
            if (segmentModel instanceof MiningModel) {
                populateCorrectMiningModel((MiningModel) segmentModel);
            }
        }
    }

    /**
     * Sets a data type on every output field that lacks one.
     * <ul>
     *     <li>If the output field names a target field, that field is looked up among the PREDICTED/TARGET
     *     mining fields. An exception is thrown if it is absent.</li>
     *     <li>Otherwise, if the result feature is absent or {@code PREDICTED_VALUE}, the first PREDICTED/TARGET
     *     mining field is used, if any.</li>
     *     <li>If no mining field was found and the result feature is {@code PROBABILITY}, the type is set to
     *     {@link DataType#DOUBLE}.</li>
     *     <li>Otherwise, if a mining field was found, the type is copied from the data field with the same
     *     name. An exception is thrown if no such data field exists.</li>
     * </ul>
     * Output fields matching none of these cases are left unchanged.
     *
     * @param outputFields the output fields to update in place
     * @param miningFields the model's mining fields
     * @param dataFields   the document's data dictionary fields
     * @throws KiePMMLException if a referenced target field or data field cannot be found
     */
    static void populateMissingOutputFieldDataType(List<OutputField> outputFields,
                                                   List<MiningField> miningFields,
                                                   List<DataField> dataFields) {
        final List<MiningField> targetMiningFields = miningFields.stream()
                .filter(miningField -> MiningField.UsageType.PREDICTED.equals(miningField.getUsageType())
                        || MiningField.UsageType.TARGET.equals(miningField.getUsageType()))
                .collect(Collectors.toList());
        for (OutputField outputField : outputFields) {
            if (outputField.getDataType() != null) {
                continue;
            }
            MiningField referencedField = null;
            if (outputField.getTargetField() != null) {
                referencedField = targetMiningFields.stream()
                        .filter(miningField -> outputField.getTargetField().equals(miningField.getName()))
                        .findFirst()
                        .orElseThrow(() -> new KiePMMLException("Failed to find a target field for OutputField " + outputField.getName().getValue()));
            }
            if (referencedField == null && (outputField.getResultFeature() == null
                    || outputField.getResultFeature().equals(ResultFeature.PREDICTED_VALUE))) {
                referencedField = targetMiningFields.stream().findFirst().orElse(null);
            }
            if (referencedField == null && ResultFeature.PROBABILITY.equals(outputField.getResultFeature())) {
                outputField.setDataType(DataType.DOUBLE);
            } else if (referencedField != null) {
                final FieldName name = referencedField.getName();
                final DataField dataField = dataFields.stream()
                        .filter(candidate -> candidate.getName().equals(name))
                        .findFirst()
                        .orElseThrow(() -> new KiePMMLException("Failed to find a DataField field for MiningField " + name.toString()));
                outputField.setDataType(dataField.getDataType());
            }
        }
    }

    /**
     * Sanitizes an existing segment id.
     * <p>
     * '.' and ',' are removed. If the remainder parses as an {@code int}, the id is instead replaced with
     * {@code <modelName>Segment<id>}. Note that this branch embeds the ORIGINAL id, so '.' and ',' are
     * kept there.
     * <p>
     * NOTE(review): numeric ids outside the {@code int} range fail {@link Integer#parseInt(String)} and are
     * returned un-prefixed. An id made only of '.'/',' becomes "". Confirm both are intended.
     *
     * @param id        the existing segment id (non-null)
     * @param modelName the owning mining model's name
     * @return the sanitized id
     */
    static String getSanitizedId(String id, String modelName) {
        // "," was previously referenced through an unrelated Infinispan constant holding the same value
        String toReturn = id.replace(".", "").replace(",", "");
        try {
            Integer.parseInt(toReturn);
            toReturn = String.format(SEGMENTID_TEMPLATE, modelName, id);
        } catch (NumberFormatException ignored) {
            // not an integer: keep the stripped id unchanged
        }
        return toReturn;
    }
}
