package org.kie.pmml.compiler.commons.utils;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Targets;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.jpmml.model.PMMLUtil;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.xml.sax.SAXException;

/* loaded from: input_file:BOOT-INF/lib/kie-pmml-compiler-commons-8.34.0.Final.jar:org/kie/pmml/compiler/commons/utils/KiePMMLUtil.class */
/**
 * Static helpers that unmarshal a PMML document and then fill in values the document left out:
 * model names, target mining fields, output-field data types / target references and segment ids.
 */
public class KiePMMLUtil {

    /** Format for generated segment ids: a model name, the literal {@code Segment}, then an index or the original id. */
    public static final String SEGMENTID_TEMPLATE = "%sSegment%s";

    /** Format for generated model names: a name prefix, the model class simple name, then the model index. */
    static final String MODELNAME_TEMPLATE = "%s%s%s";

    /** Format for segment model names: the literal {@code Segment}, the segment id, then the model class simple name. */
    static final String SEGMENTMODELNAME_TEMPLATE = "Segment%s%s";

    /** Format for generated target field names: the literal {@code target} followed by a sanitized model name. */
    static final String TARGETFIELD_TEMPLATE = "target%s";

    private KiePMMLUtil() {
        // Static utility holder: never instantiated.
    }

    /**
     * Unmarshals the PMML document contained in the given string, using an empty name prefix.
     *
     * @param str the PMML document text
     * @return the unmarshalled and completed {@link PMML}
     * @throws SAXException  propagated from unmarshalling
     * @throws JAXBException propagated from unmarshalling
     */
    public static PMML load(String str) throws SAXException, JAXBException {
        // Encode explicitly: the no-arg getBytes() would use the platform default charset,
        // making the bytes handed to the XML parser machine-dependent.
        byte[] content = str.getBytes(StandardCharsets.UTF_8);
        return load(new ByteArrayInputStream(content), "");
    }

    /**
     * Unmarshals a PMML document from the stream and populates the missing pieces of every model in it.
     *
     * @param inputStream stream holding the PMML document
     * @param str         name used as prefix for generated model names; anything from the first
     *                    {@code '.'} onwards is dropped
     * @return the unmarshalled and completed {@link PMML}
     * @throws SAXException  propagated from unmarshalling
     * @throws JAXBException propagated from unmarshalling
     */
    public static PMML load(InputStream inputStream, String str) throws SAXException, JAXBException {
        PMML pmml = PMMLUtil.unmarshal(inputStream);
        int dotIndex = str.indexOf('.');
        String namePrefix = dotIndex >= 0 ? str.substring(0, dotIndex) : str;
        List<DataField> dataFields = pmml.getDataDictionary().getDataFields();
        List<Model> models = pmml.getModels();
        for (int i = 0; i < models.size(); i++) {
            Model model = models.get(i);
            // The name must be set first: the generated target field name is derived from it.
            populateMissingModelName(model, namePrefix, i);
            populateMissingOutputFieldDataType(model, dataFields);
            populateMissingMiningTargetField(model, dataFields);
            populateMissingPredictedOutputFieldTarget(model);
            if (model instanceof MiningModel) {
                populateCorrectMiningModel((MiningModel) model);
            }
        }
        return pmml;
    }

    /**
     * Gives the model a generated name when its current name is {@code null} or empty.
     *
     * @param model the model to name
     * @param str   prefix of the generated name
     * @param i     index of the model, used as suffix of the generated name
     */
    static void populateMissingModelName(Model model, String str, int i) {
        String currentName = model.getModelName();
        if (currentName != null && !currentName.isEmpty()) {
            return;
        }
        model.setModelName(String.format(MODELNAME_TEMPLATE, str, model.getClass().getSimpleName(), i));
    }

    /**
     * When the model's mining schema has no target/predicted field, creates one (if the mining
     * function allows it), registers the matching data field in {@code list} and assigns the new
     * field to targets that have none.
     *
     * @param model the model to complete
     * @param list  data fields to which the generated target data field is appended
     */
    static void populateMissingMiningTargetField(Model model, List<DataField> list) {
        if (!getMiningTargetFields(model.getMiningSchema()).isEmpty()) {
            return;
        }
        getTargetDataField(model).ifPresent(dataField -> {
            list.add(dataField);
            MiningField targetMiningField = getTargetMiningField(dataField);
            model.getMiningSchema().addMiningFields(targetMiningField);
            correctTargetFields(targetMiningField, model.getTargets());
        });
    }

    /**
     * Points the first output field that is a predicted value (or has no result feature) and has no
     * target field to the first target mining field of the model, if there is one.
     * Does nothing when the model has no output or no mining schema.
     *
     * @param model the model to complete
     */
    static void populateMissingPredictedOutputFieldTarget(Model model) {
        if (model.getOutput() == null || model.getMiningSchema() == null) {
            return;
        }
        model.getOutput().getOutputFields().stream()
                .filter(outputField -> (outputField.getResultFeature() == null
                        || outputField.getResultFeature().equals(ResultFeature.PREDICTED_VALUE))
                        && outputField.getTargetField() == null)
                .findFirst()
                .ifPresent(outputField -> {
                    List<MiningField> targetFields = getMiningTargetFields(model.getMiningSchema());
                    if (!targetFields.isEmpty()) {
                        outputField.setTargetField(targetFields.get(0).getName());
                    }
                });
    }

    /**
     * Builds the data field describing a generated target for the model.
     *
     * @param model the model whose mining function, math context and name drive the result
     * @return the generated data field, or empty when no data type or op type is available for the
     *         model's mining function
     */
    static Optional<DataField> getTargetDataField(Model model) {
        DataType targetDataType = getTargetDataType(model.getMiningFunction(), model.getMathContext());
        OpType targetOpType = getTargetOpType(model.getMiningFunction());
        if (targetDataType == null || targetOpType == null) {
            return Optional.empty();
        }
        // Only ASCII letters and digits of the model name are kept in the field name.
        String sanitizedModelName = model.getModelName().replaceAll("[^A-Za-z0-9]", "");
        String targetFieldName = String.format(TARGETFIELD_TEMPLATE, sanitizedModelName);
        DataField dataField = new DataField();
        // NOTE(review): "setName2" looks like a decompiler rename of DataField.setName - confirm
        // against the library version this file is compiled with.
        dataField.setName2(FieldName.create(targetFieldName));
        dataField.setOpType(targetOpType);
        dataField.setDataType(targetDataType);
        return Optional.of(dataField);
    }

    /**
     * Maps a mining function to the data type of its generated target.
     *
     * @param miningFunction the mining function of the model
     * @param mathContext    math context, read only for {@code REGRESSION}
     * @return the data type looked up from the math context value for {@code REGRESSION},
     *         {@link DataType#STRING} for {@code CLASSIFICATION} and {@code CLUSTERING},
     *         {@code null} otherwise
     */
    static DataType getTargetDataType(MiningFunction miningFunction, MathContext mathContext) {
        switch (miningFunction) {
            case REGRESSION:
                return DataType.fromValue(mathContext.value());
            case CLASSIFICATION:
            case CLUSTERING:
                return DataType.STRING;
            default:
                return null;
        }
    }

    /**
     * Maps a mining function to the op type of its generated target.
     *
     * @param miningFunction the mining function of the model
     * @return {@link OpType#CONTINUOUS} for {@code REGRESSION}, {@link OpType#CATEGORICAL} for
     *         {@code CLASSIFICATION} and {@code CLUSTERING}, {@code null} otherwise
     */
    static OpType getTargetOpType(MiningFunction miningFunction) {
        switch (miningFunction) {
            case REGRESSION:
                return OpType.CONTINUOUS;
            case CLASSIFICATION:
            case CLUSTERING:
                return OpType.CATEGORICAL;
            default:
                return null;
        }
    }

    /**
     * Creates a mining field with usage type {@code TARGET} named after the given data field.
     *
     * @param dataField the data field providing the name
     * @return the new mining field
     */
    static MiningField getTargetMiningField(DataField dataField) {
        MiningField targetField = new MiningField();
        targetField.setName(dataField.getName());
        targetField.setUsageType(MiningField.UsageType.TARGET);
        return targetField;
    }

    /**
     * Assigns the mining field's name to every target that has no field set.
     * Does nothing when {@code targets} is {@code null} or holds no target.
     *
     * @param miningField the mining field whose name is assigned
     * @param targets     the targets to correct, may be {@code null}
     */
    static void correctTargetFields(MiningField miningField, Targets targets) {
        if (targets == null || targets.getTargets().isEmpty()) {
            return;
        }
        targets.getTargets().stream()
                .filter(target -> target.getField() == null)
                .forEach(target -> target.setField(miningField.getName()));
    }

    /**
     * Completes every segment of a mining model: segment id, nested model name, target fields and
     * predicted output-field target. Recurses into nested mining models.
     *
     * @param miningModel the mining model whose segments are completed
     */
    static void populateCorrectMiningModel(MiningModel miningModel) {
        List<Segment> segments = miningModel.getSegmentation().getSegments();
        for (int i = 0; i < segments.size(); i++) {
            Segment segment = segments.get(i);
            // The id is fixed first because the nested model name is built from it.
            populateCorrectSegmentId(segment, miningModel.getModelName(), i);
            Model segmentModel = segment.getModel();
            populateMissingSegmentModelName(segmentModel, segment.getId());
            populateMissingTargetFieldInSegment(miningModel.getMiningSchema(), segmentModel);
            populateMissingPredictedOutputFieldTarget(segmentModel);
            if (segmentModel instanceof MiningModel) {
                populateCorrectMiningModel((MiningModel) segmentModel);
            }
        }
    }

    /**
     * Sets the segment id: a generated one when the id is {@code null} or empty, otherwise the
     * sanitized form of the existing id.
     *
     * @param segment the segment to update
     * @param str     model name used in generated ids
     * @param i       index of the segment, used when the id is generated
     */
    static void populateCorrectSegmentId(Segment segment, String str, int i) {
        String currentId = segment.getId();
        String correctedId;
        if (currentId == null || currentId.isEmpty()) {
            correctedId = String.format(SEGMENTID_TEMPLATE, str, i);
        } else {
            correctedId = getSanitizedId(currentId, str);
        }
        segment.setId(correctedId);
    }

    /**
     * Sets the name of a segment's model from the segment id and the model class simple name.
     * Note that, despite the method name, any existing model name is overwritten.
     *
     * @param model the model inside the segment
     * @param str   the segment id
     */
    static void populateMissingSegmentModelName(Model model, String str) {
        String segmentModelName = String.format(SEGMENTMODELNAME_TEMPLATE, str, model.getClass().getSimpleName());
        model.setModelName(segmentModelName);
    }

    /**
     * Copies the target/predicted mining fields of the given schema into the model's own mining
     * schema when the model declares none.
     *
     * @param miningSchema schema providing the target fields
     * @param model        the model to complete
     */
    static void populateMissingTargetFieldInSegment(MiningSchema miningSchema, Model model) {
        List<MiningField> parentTargetFields = getMiningTargetFields(miningSchema.getMiningFields());
        if (getMiningTargetFields(model.getMiningSchema().getMiningFields()).isEmpty()) {
            model.getMiningSchema().addMiningFields(parentTargetFields.toArray(new MiningField[parentTargetFields.size()]));
        }
    }

    /**
     * Fills in the data type of the model's output fields that lack one.
     * Does nothing when the model has no output or no output fields.
     *
     * @param model the model whose output fields are completed
     * @param list  data fields used to look up data types
     */
    static void populateMissingOutputFieldDataType(Model model, List<DataField> list) {
        if (model.getOutput() == null || model.getOutput().getOutputFields() == null) {
            return;
        }
        populateMissingOutputFieldDataType(model.getOutput().getOutputFields(),
                                           model.getMiningSchema().getMiningFields(),
                                           list);
    }

    /**
     * Fills in the data type of every output field that has none.
     * <p>
     * The mining field used as reference is the target field named by the output field; failing
     * that, for a predicted value (or missing result feature), the first target mining field.
     * With a reference, the data type of the data field with the same name is used. Without one,
     * a {@code PROBABILITY} output gets {@link DataType#DOUBLE}; anything else is left untouched.
     *
     * @param list  output fields to complete
     * @param list2 mining fields among which target fields are searched
     * @param list3 data fields used to look up data types
     * @throws KiePMMLException when an output field names a target field that is not among the
     *                          target mining fields, or no data field matches the reference field
     */
    static void populateMissingOutputFieldDataType(List<OutputField> list, List<MiningField> list2, List<DataField> list3) {
        List<MiningField> targetFields = getMiningTargetFields(list2);
        for (OutputField outputField : list) {
            if (outputField.getDataType() != null) {
                continue;
            }
            MiningField referenceField = null;
            if (outputField.getTargetField() != null) {
                referenceField = targetFields.stream()
                        .filter(candidate -> outputField.getTargetField().equals(candidate.getName()))
                        .findFirst()
                        .orElseThrow(() -> new KiePMMLException("Failed to find a target field for OutputField " + outputField.getName().getValue()));
            }
            ResultFeature resultFeature = outputField.getResultFeature();
            boolean predictedValue = resultFeature == null || resultFeature.equals(ResultFeature.PREDICTED_VALUE);
            if (referenceField == null && predictedValue && !targetFields.isEmpty()) {
                referenceField = targetFields.get(0);
            }
            if (referenceField != null) {
                FieldName referenceName = referenceField.getName();
                DataField matchingDataField = list3.stream()
                        .filter(dataField -> dataField.getName().equals(referenceName))
                        .findFirst()
                        .orElseThrow(() -> new KiePMMLException("Failed to find a DataField field for MiningField " + referenceName.toString()));
                outputField.setDataType(matchingDataField.getDataType());
            } else if (ResultFeature.PROBABILITY.equals(resultFeature)) {
                outputField.setDataType(DataType.DOUBLE);
            }
        }
    }

    /**
     * Sanitizes a segment id by removing every {@code '.'} and {@code ','}. When what is left parses
     * as an {@code int}, a generated id built from the model name and the original id is returned
     * instead.
     *
     * @param str  the original segment id
     * @param str2 model name used when the id is regenerated
     * @return the stripped id, or the generated id when the stripped id is an integer
     */
    static String getSanitizedId(String str, String str2) {
        String stripped = str.replace(".", "").replace(",", "");
        try {
            Integer.parseInt(stripped);
        } catch (NumberFormatException ignored) {
            // Not an integer: the stripped id is used as it is.
            return stripped;
        }
        // The generated id embeds the original id, not the stripped one.
        return String.format(SEGMENTID_TEMPLATE, str2, str);
    }

    /**
     * Returns the target/predicted mining fields of the given schema.
     *
     * @param miningSchema the schema to inspect
     * @return the mining fields whose usage type is {@code PREDICTED} or {@code TARGET}
     */
    static List<MiningField> getMiningTargetFields(MiningSchema miningSchema) {
        return getMiningTargetFields(miningSchema.getMiningFields());
    }

    /**
     * Returns the mining fields whose usage type is {@code PREDICTED} or {@code TARGET}.
     *
     * @param list the mining fields to filter
     * @return a new list with the matching fields, in their original order
     */
    static List<MiningField> getMiningTargetFields(List<MiningField> list) {
        return list.stream()
                .filter(miningField -> MiningField.UsageType.PREDICTED.equals(miningField.getUsageType())
                        || MiningField.UsageType.TARGET.equals(miningField.getUsageType()))
                .collect(Collectors.toList());
    }
}
