package org.kie.pmml.models.regression.compiler.executor;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.regression.RegressionModel;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.commons.model.tuples.KiePMMLNameOpType;
import org.kie.pmml.compiler.api.dto.CompilationDTO;
import org.kie.pmml.compiler.api.provider.ModelImplementationProvider;
import org.kie.pmml.compiler.api.utils.ModelUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionModelFactory;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/kie-pmml-models-regression-compiler-7.70.0.Final.jar:org/kie/pmml/models/regression/compiler/executor/RegressionModelImplementationProvider.class */
public class RegressionModelImplementationProvider implements ModelImplementationProvider<RegressionModel, KiePMMLRegressionModel> {
    private static final Logger logger = LoggerFactory.getLogger(RegressionModelImplementationProvider.class.getName());
    private static final String INVALID_NORMALIZATION_METHOD = "Invalid Normalization Method ";

    /**
     * Returns the PMML model type handled by this provider.
     *
     * @return {@link PMML_MODEL#REGRESSION_MODEL}
     */
    @Override
    public PMML_MODEL getPMMLModelType() {
        logger.trace("getPMMLModelType");
        return PMML_MODEL.REGRESSION_MODEL;
    }

    /**
     * Validates the regression model carried by {@code compilationDTO} and builds the
     * corresponding {@link KiePMMLRegressionModel} through {@link KiePMMLRegressionModelFactory}.
     *
     * @param compilationDTO compilation context holding the fields and the model
     * @return the compiled regression model
     * @throws KiePMMLException if validation fails or the factory raises an
     *         {@link IOException}, {@link IllegalAccessException} or {@link InstantiationException}
     *         (the original exception is kept as the cause)
     */
    @Override
    public KiePMMLRegressionModel getKiePMMLModel(CompilationDTO<RegressionModel> compilationDTO) {
        logger.trace("getKiePMMLModel {} {} {} {}", compilationDTO.getPackageName(), compilationDTO.getFields(), compilationDTO.getModel(), compilationDTO.getHasClassloader());
        validate(compilationDTO.getFields(), compilationDTO.getModel());
        try {
            return KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses(RegressionCompilationDTO.fromCompilationDTO(compilationDTO));
        } catch (IOException | IllegalAccessException | InstantiationException e) {
            throw new KiePMMLException(e.getMessage(), e);
        }
    }

    /**
     * Generates the source map for the regression model carried by {@code compilationDTO}.
     * <p>
     * NOTE(review): unlike {@link #getKiePMMLModel(CompilationDTO)} this method does not call
     * {@link #validate(List, RegressionModel)}; confirm whether that is intentional.
     *
     * @param compilationDTO compilation context holding the fields and the model
     * @return the map produced by {@link KiePMMLRegressionModelFactory#getKiePMMLRegressionModelSourcesMap}
     * @throws KiePMMLException wrapping any {@link IOException} raised by the factory
     */
    @Override
    public Map<String, String> getSourcesMap(CompilationDTO<RegressionModel> compilationDTO) {
        logger.trace("getKiePMMLModelWithSources {} {} {} {}", compilationDTO.getPackageName(), compilationDTO.getFields(), compilationDTO.getModel(), compilationDTO.getHasClassloader());
        try {
            return KiePMMLRegressionModelFactory.getKiePMMLRegressionModelSourcesMap(RegressionCompilationDTO.fromCompilationDTO(compilationDTO));
        } catch (IOException e) {
            throw new KiePMMLException(e);
        }
    }

    /**
     * Reports that models from this provider are interpreted.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isInterpreted() {
        return true;
    }

    /**
     * Validates a regression model against the available fields.
     * <p>
     * At least one RegressionTable is always required. Models whose mining function is
     * {@link MiningFunction#REGRESSION} are checked as regressions; every other mining
     * function is checked as a classification.
     *
     * @param list            the fields (data/derived) visible to the model
     * @param regressionModel the model to validate
     * @throws KiePMMLException if the model is not valid
     */
    protected void validate(List<Field<?>> list, RegressionModel regressionModel) {
        if (regressionModel.getRegressionTables() == null || regressionModel.getRegressionTables().isEmpty()) {
            throw new KiePMMLException("At least one RegressionTable required");
        }
        if (isRegression(regressionModel)) {
            validateRegression(ModelUtils.getTargetFields(list, regressionModel), regressionModel);
        } else {
            validateClassification(list, regressionModel);
        }
    }

    /**
     * Validates a regression (non-classification) model.
     * <p>
     * The model must have exactly one target field, exactly one RegressionTable, and a
     * normalization method accepted by {@link #validateNormalizationMethod}.
     *
     * @param list            the model's target fields
     * @param regressionModel the model to validate
     * @throws KiePMMLException if any constraint is violated
     */
    void validateRegression(List<KiePMMLNameOpType> list, RegressionModel regressionModel) {
        validateRegressionTargetField(list, regressionModel);
        if (regressionModel.getRegressionTables().size() != 1) {
            throw new KiePMMLException("Expected one RegressionTable, retrieved " + regressionModel.getRegressionTables().size());
        }
        validateNormalizationMethod(regressionModel.getNormalizationMethod());
    }

    /**
     * Accepts the normalization methods allowed for regression models.
     * <p>
     * Allowed methods: NONE, SOFTMAX, LOGIT, EXP, PROBIT, CLOGLOG, LOGLOG, CAUCHIT.
     * Any other value (e.g. SIMPLEMAX) is rejected.
     *
     * @param normalizationMethod the method to check
     * @throws KiePMMLException if the method is not allowed
     */
    void validateNormalizationMethod(RegressionModel.NormalizationMethod normalizationMethod) {
        switch (normalizationMethod) {
            case NONE:
            case SOFTMAX:
            case LOGIT:
            case EXP:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                return;
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + normalizationMethod);
        }
    }

    /**
     * Validates a classification model.
     * <p>
     * The categorical target is located first; the check then dispatches on the target
     * op-type returned by {@link ModelUtils#getOpType}.
     */
    private void validateClassification(List<Field<?>> list, RegressionModel regressionModel) {
        String categoricalTargetName = getCategoricalTargetName(list, regressionModel);
        OP_TYPE opType = ModelUtils.getOpType(list, regressionModel, categoricalTargetName);
        switch (opType) {
            case CATEGORICAL:
                validateClassificationCategorical(list, regressionModel, categoricalTargetName);
                return;
            case ORDINAL:
                validateClassificationOrdinal(regressionModel);
                return;
            default:
                throw new KiePMMLException("Invalid target type " + opType);
        }
    }

    /**
     * Chooses the binary or multi-class check for a categorical target.
     * The choice is based on the number of declared values of the target DataField.
     */
    private void validateClassificationCategorical(List<Field<?>> list, RegressionModel regressionModel, String str) {
        if (isBinary(list, str)) {
            validateClassificationCategoricalBinary(regressionModel);
        } else {
            validateClassificationCategoricalNotBinary(regressionModel);
        }
    }

    /**
     * Validates a binary categorical classification.
     * <p>
     * Requires one of NONE, LOGIT, PROBIT, CLOGLOG, LOGLOG or CAUCHIT normalization,
     * together with exactly two RegressionTables.
     */
    private void validateClassificationCategoricalBinary(RegressionModel regressionModel) {
        switch (regressionModel.getNormalizationMethod()) {
            case NONE:
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                if (regressionModel.getRegressionTables().size() != 2) {
                    throw new KiePMMLException("Expected two RegressionTables, retrieved " + regressionModel.getRegressionTables().size());
                }
                return;
            default:
                // includes SOFTMAX and EXP, which are not valid for binary classification
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + regressionModel.getNormalizationMethod());
        }
    }

    /**
     * Validates a non-binary categorical classification.
     * <p>
     * NONE normalization requires at least three RegressionTables. SOFTMAX and SIMPLEMAX
     * require at least two. Every other method is rejected.
     */
    private void validateClassificationCategoricalNotBinary(RegressionModel regressionModel) {
        switch (regressionModel.getNormalizationMethod()) {
            case NONE:
                if (regressionModel.getRegressionTables().size() < 3) {
                    // the check is a lower bound, so the message says "at least"
                    throw new KiePMMLException("Expected at least three RegressionTables, retrieved " + regressionModel.getRegressionTables().size());
                }
                return;
            case SOFTMAX:
            case SIMPLEMAX:
                if (regressionModel.getRegressionTables().size() < 2) {
                    throw new KiePMMLException("Expected at least two RegressionTables, retrieved " + regressionModel.getRegressionTables().size());
                }
                return;
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + regressionModel.getNormalizationMethod());
        }
    }

    /**
     * Validates an ordinal classification.
     * <p>
     * Requires one of NONE, LOGIT, PROBIT, CLOGLOG, LOGLOG or CAUCHIT normalization,
     * together with at least two RegressionTables.
     */
    private void validateClassificationOrdinal(RegressionModel regressionModel) {
        switch (regressionModel.getNormalizationMethod()) {
            case NONE:
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                if (regressionModel.getRegressionTables().size() < 2) {
                    throw new KiePMMLException("Expected at least two RegressionTables, retrieved " + regressionModel.getRegressionTables().size());
                }
                return;
            default:
                // includes SOFTMAX and EXP, which are not valid for ordinal classification
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + regressionModel.getNormalizationMethod());
        }
    }

    /**
     * Requires exactly one target field.
     * <p>
     * If the model declares a {@code targetField}, its value must equal the name of that
     * single target.
     */
    private void validateRegressionTargetField(List<KiePMMLNameOpType> list, RegressionModel regressionModel) {
        if (list.size() != 1) {
            throw new KiePMMLException("Expected one target field, retrieved " + list.size());
        }
        if (regressionModel.getTargetField() != null && !Objects.equals(regressionModel.getTargetField().getValue(), list.get(0).getName())) {
            throw new KiePMMLException(String.format("Not-matching target fields: %s %s", regressionModel.getTargetField(), list.get(0).getName()));
        }
    }

    /** Returns {@code true} when the model's mining function is {@link MiningFunction#REGRESSION}. */
    private boolean isRegression(RegressionModel regressionModel) {
        return Objects.equals(MiningFunction.REGRESSION, regressionModel.getMiningFunction());
    }

    /**
     * Tells whether the target is binary.
     * <p>
     * The first {@link DataField} named {@code str} is binary if it declares exactly two
     * values. If no such DataField exists, the target is not binary.
     */
    private boolean isBinary(List<Field<?>> list, String str) {
        return list.stream()
                .filter(DataField.class::isInstance)
                .map(DataField.class::cast)
                .filter(dataField -> Objects.equals(dataField.getName().getValue(), str))
                .findFirst()
                .map(dataField -> dataField.getValues().size() == 2)
                .orElse(false);
    }

    /**
     * Returns the name of the single target field whose field op-type is CATEGORICAL.
     *
     * @throws KiePMMLException if zero or more than one categorical target is found
     */
    private String getCategoricalTargetName(List<Field<?>> list, RegressionModel regressionModel) {
        List<KiePMMLNameOpType> targetFields = ModelUtils.getTargetFields(list, regressionModel);
        List<String> categoricalFieldNames = list.stream()
                .filter(field -> OpType.CATEGORICAL.equals(field.getOpType()))
                .map(field -> field.getName().getValue())
                .collect(Collectors.toList());
        List<KiePMMLNameOpType> categoricalTargets = targetFields.stream()
                .filter(target -> categoricalFieldNames.contains(target.getName()))
                .collect(Collectors.toList());
        if (categoricalTargets.size() != 1) {
            throw new KiePMMLException(String.format("Expected exactly one categorical targets, found %s", categoricalTargets.size()));
        }
        return categoricalTargets.get(0).getName();
    }
}
