package org.kie.pmml.models.regression.compiler.executor;

import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.regression.RegressionModel;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.commons.model.HasClassLoader;
import org.kie.pmml.commons.model.tuples.KiePMMLNameOpType;
import org.kie.pmml.compiler.api.provider.ModelImplementationProvider;
import org.kie.pmml.compiler.commons.utils.ModelUtils;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionModelFactory;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModelWithSources;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/pmml/models/regression/compiler/executor/RegressionModelImplementationProvider.class */
/**
 * {@link ModelImplementationProvider} for PMML {@code RegressionModel}s.
 *
 * <p>Before a model is built, {@link #validate(DataDictionary, RegressionModel)} checks that the
 * number of {@code RegressionTable}s and the declared normalization method are consistent with the
 * model's mining function (regression vs. classification) and with its target field. The actual
 * building is delegated to {@link KiePMMLRegressionModelFactory}.
 */
public class RegressionModelImplementationProvider implements ModelImplementationProvider<RegressionModel, KiePMMLRegressionModel> {

    private static final Logger logger = LoggerFactory.getLogger(RegressionModelImplementationProvider.class.getName());
    private static final String INVALID_NORMALIZATION_METHOD = "Invalid Normalization Method ";

    /**
     * Returns the PMML model type handled by this provider.
     *
     * @return always {@link PMML_MODEL#REGRESSION_MODEL}
     */
    public PMML_MODEL getPMMLModelType() {
        logger.trace("getPMMLModelType");
        return PMML_MODEL.REGRESSION_MODEL;
    }

    /**
     * Validates {@code regressionModel} and builds its compiled {@link KiePMMLRegressionModel}.
     *
     * @param dataDictionary           the data dictionary the model refers to
     * @param transformationDictionary forwarded unchanged to the factory
     * @param regressionModel          the model to validate and build
     * @param hasClassLoader           supplies the class loader handed to the factory
     * @return the model built by {@link KiePMMLRegressionModelFactory}
     * @throws KiePMMLException if validation fails, or wrapping (with cause preserved) any
     *                          {@link IOException}, {@link IllegalAccessException} or
     *                          {@link InstantiationException} raised by the factory
     */
    public KiePMMLRegressionModel getKiePMMLModel(DataDictionary dataDictionary, TransformationDictionary transformationDictionary, RegressionModel regressionModel, HasClassLoader hasClassLoader) {
        logger.trace("getKiePMMLModel {} {} {}", dataDictionary, regressionModel, hasClassLoader);
        validate(dataDictionary, regressionModel);
        try {
            return KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses(dataDictionary, transformationDictionary, regressionModel, hasClassLoader.getClassLoader());
        } catch (IOException | IllegalAccessException | InstantiationException e) {
            throw new KiePMMLException(e.getMessage(), e);
        }
    }

    /**
     * Builds a {@link KiePMMLRegressionModelWithSources} carrying the generated sources map for
     * {@code regressionModel}.
     *
     * @param packageName              passed both to {@link KiePMMLRegressionModelWithSources} and to
     *                                 the sources-map factory; presumably the package of the
     *                                 generated sources — TODO confirm
     * @param dataDictionary           the data dictionary the model refers to
     * @param transformationDictionary forwarded unchanged to the factory
     * @param regressionModel          the model whose sources are generated
     * @param hasClassLoader           only logged here; not used to build the result
     * @return the model wrapper holding the model name, {@code packageName} and the sources map
     * @throws KiePMMLException wrapping any {@link IOException} raised while generating sources
     */
    public KiePMMLRegressionModel getKiePMMLModelWithSources(String packageName, DataDictionary dataDictionary, TransformationDictionary transformationDictionary, RegressionModel regressionModel, HasClassLoader hasClassLoader) {
        logger.trace("getKiePMMLModelWithSources {} {} {}", dataDictionary, regressionModel, hasClassLoader);
        // NOTE(review): unlike getKiePMMLModel, this path does not call validate(...) — confirm
        // whether skipping validation for source generation is intentional.
        try {
            return new KiePMMLRegressionModelWithSources(regressionModel.getModelName(), packageName, KiePMMLRegressionModelFactory.getKiePMMLRegressionModelSourcesMap(dataDictionary, transformationDictionary, regressionModel, packageName));
        } catch (IOException e) {
            throw new KiePMMLException(e);
        }
    }

    /**
     * Checks that {@code regressionModel} is internally consistent before it is built.
     *
     * <p>At least one {@code RegressionTable} is always required. Models whose mining function is
     * {@link MiningFunction#REGRESSION} are then checked as regressions; every other mining function
     * is validated as a classification.
     *
     * @throws KiePMMLException describing the first violated rule
     */
    protected void validate(DataDictionary dataDictionary, RegressionModel regressionModel) {
        if (regressionModel.getRegressionTables() == null || regressionModel.getRegressionTables().isEmpty()) {
            throw new KiePMMLException("At least one RegressionTable required");
        }
        if (isRegression(regressionModel)) {
            validateRegression(ModelUtils.getTargetFields(dataDictionary, regressionModel), regressionModel);
        } else {
            validateClassification(dataDictionary, regressionModel);
        }
    }

    /**
     * Validates a regression model: exactly one target field (matching the model's declared
     * {@code targetField}, if any), exactly one {@code RegressionTable}, and a normalization method
     * accepted by {@link #validateNormalizationMethod}.
     *
     * @param targetFields the target fields resolved for the model
     * @throws KiePMMLException if any of the rules above is violated
     */
    void validateRegression(List<KiePMMLNameOpType> targetFields, RegressionModel regressionModel) {
        validateRegressionTargetField(targetFields, regressionModel);
        int tableCount = regressionModel.getRegressionTables().size();
        if (tableCount != 1) {
            throw new KiePMMLException("Expected one RegressionTable, retrieved " + tableCount);
        }
        validateNormalizationMethod(regressionModel.getNormalizationMethod());
    }

    /**
     * Accepts every normalization method allowed for a regression model; only
     * {@code SIMPLEMAX} (and any constant not listed) is rejected.
     *
     * @throws KiePMMLException if {@code normalizationMethod} is not accepted
     */
    void validateNormalizationMethod(RegressionModel.NormalizationMethod normalizationMethod) {
        switch (normalizationMethod) {
            case NONE:
            case SOFTMAX:
            case LOGIT:
            case EXP:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                return;
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + normalizationMethod);
        }
    }

    /**
     * Validates a classification model by resolving its single categorical target and dispatching
     * on the op type {@link ModelUtils#getOpType} reports for it: {@code CATEGORICAL} and
     * {@code ORDINAL} are validated further, any other op type is rejected.
     */
    private void validateClassification(DataDictionary dataDictionary, RegressionModel regressionModel) {
        String categoricalTargetName = getCategoricalTargetName(dataDictionary, regressionModel);
        OP_TYPE opType = ModelUtils.getOpType(dataDictionary, regressionModel, categoricalTargetName);
        switch (opType) {
            case CATEGORICAL:
                validateClassificationCategorical(dataDictionary, regressionModel, categoricalTargetName);
                return;
            case ORDINAL:
                validateClassificationOrdinal(regressionModel);
                return;
            default:
                throw new KiePMMLException("Invalid target type " + opType);
        }
    }

    /** Chooses the binary or multi-class rules depending on how many values the target declares. */
    private void validateClassificationCategorical(DataDictionary dataDictionary, RegressionModel regressionModel, String targetFieldName) {
        if (isBinary(dataDictionary, targetFieldName)) {
            validateClassificationCategoricalBinary(regressionModel);
        } else {
            validateClassificationCategoricalNotBinary(regressionModel);
        }
    }

    /**
     * Binary categorical target: NONE, LOGIT, PROBIT, CLOGLOG, LOGLOG and CAUCHIT are accepted and
     * require exactly two {@code RegressionTable}s; SOFTMAX, EXP and any other method are rejected.
     */
    private void validateClassificationCategoricalBinary(RegressionModel regressionModel) {
        switch (regressionModel.getNormalizationMethod()) {
            case NONE:
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                if (regressionModel.getRegressionTables().size() != 2) {
                    throw new KiePMMLException("Expected two RegressionTables, retrieved " + regressionModel.getRegressionTables().size());
                }
                return;
            case SOFTMAX:
            case EXP:
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + regressionModel.getNormalizationMethod());
        }
    }

    /**
     * Non-binary categorical target: NONE requires at least three {@code RegressionTable}s,
     * SOFTMAX and SIMPLEMAX require at least two; any other method is rejected.
     */
    private void validateClassificationCategoricalNotBinary(RegressionModel regressionModel) {
        switch (regressionModel.getNormalizationMethod()) {
            case NONE:
                if (regressionModel.getRegressionTables().size() < 3) {
                    // The guard is "< 3", so the message must say "at least", not "exactly", three.
                    throw new KiePMMLException("Expected at least three RegressionTables, retrieved " + regressionModel.getRegressionTables().size());
                }
                return;
            case SOFTMAX:
            case SIMPLEMAX:
                if (regressionModel.getRegressionTables().size() < 2) {
                    throw new KiePMMLException("Expected at least two RegressionTables, retrieved " + regressionModel.getRegressionTables().size());
                }
                return;
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + regressionModel.getNormalizationMethod());
        }
    }

    /**
     * Ordinal target: NONE, LOGIT, PROBIT, CLOGLOG, LOGLOG and CAUCHIT are accepted and require at
     * least two {@code RegressionTable}s; SOFTMAX, EXP and any other method are rejected.
     */
    private void validateClassificationOrdinal(RegressionModel regressionModel) {
        switch (regressionModel.getNormalizationMethod()) {
            case NONE:
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                if (regressionModel.getRegressionTables().size() < 2) {
                    throw new KiePMMLException("Expected at least two RegressionTables, retrieved " + regressionModel.getRegressionTables().size());
                }
                return;
            case SOFTMAX:
            case EXP:
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + regressionModel.getNormalizationMethod());
        }
    }

    /**
     * Requires exactly one resolved target field and, when the model declares a
     * {@code targetField}, requires its name to match that field.
     */
    private void validateRegressionTargetField(List<KiePMMLNameOpType> targetFields, RegressionModel regressionModel) {
        if (targetFields.size() != 1) {
            throw new KiePMMLException("Expected one target field, retrieved " + targetFields.size());
        }
        String resolvedTargetName = targetFields.get(0).getName();
        if (regressionModel.getTargetField() != null && !Objects.equals(regressionModel.getTargetField().getValue(), resolvedTargetName)) {
            throw new KiePMMLException(String.format("Not-matching target fields: %s %s", regressionModel.getTargetField(), resolvedTargetName));
        }
    }

    /** True when the model's mining function is {@link MiningFunction#REGRESSION}. */
    private boolean isRegression(RegressionModel regressionModel) {
        return Objects.equals(MiningFunction.REGRESSION, regressionModel.getMiningFunction());
    }

    /**
     * True when the first {@code DataField} named {@code fieldName} declares exactly two
     * {@code Value}s; a field that is not found counts as having zero values.
     */
    private boolean isBinary(DataDictionary dataDictionary, String fieldName) {
        return dataDictionary.getDataFields().stream()
                .filter(dataField -> Objects.equals(dataField.getName().getValue(), fieldName))
                .mapToInt(dataField -> dataField.getValues().size())
                .findFirst()
                .orElse(0) == 2;
    }

    /**
     * Returns the name of the single target field that is also declared {@code CATEGORICAL} in the
     * data dictionary.
     *
     * @throws KiePMMLException if zero or more than one such target exists
     */
    private String getCategoricalTargetName(DataDictionary dataDictionary, RegressionModel regressionModel) {
        List<KiePMMLNameOpType> targetFields = ModelUtils.getTargetFields(dataDictionary, regressionModel);
        List<String> categoricalFieldNames = dataDictionary.getDataFields().stream()
                .filter(dataField -> OpType.CATEGORICAL.equals(dataField.getOpType()))
                .map(dataField -> dataField.getName().getValue())
                .collect(Collectors.toList());
        List<KiePMMLNameOpType> categoricalTargets = targetFields.stream()
                .filter(targetField -> categoricalFieldNames.contains(targetField.getName()))
                .collect(Collectors.toList());
        if (categoricalTargets.size() != 1) {
            throw new KiePMMLException(String.format("Expected exactly one categorical targets, found %s", categoricalTargets.size()));
        }
        return categoricalTargets.get(0).getName();
    }
}
