package org.kie.pmml.models.regression.evaluator;

import org.drools.core.util.StringUtils;
import org.kie.api.KieBase;
import org.kie.api.pmml.PMML4Result;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.enums.ResultCode;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.api.runtime.PMMLContext;
import org.kie.pmml.evaluator.api.exceptions.KiePMMLModelException;
import org.kie.pmml.evaluator.core.executor.PMMLModelEvaluator;
import org.kie.pmml.evaluator.core.utils.Converter;
import org.kie.pmml.models.regression.model.AbstractKiePMMLTable;
import org.kie.pmml.models.regression.model.KiePMMLClassificationTable;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;
import org.kie.pmml.models.regression.model.KiePMMLRegressionTable;

/* loaded from: input_file:BOOT-INF/lib/kie-pmml-models-regression-evaluator-7.68.0.Final.jar:org/kie/pmml/models/regression/evaluator/PMMLRegressionModelEvaluator.class */
/**
 * {@link PMMLModelEvaluator} for {@link PMML_MODEL#REGRESSION_MODEL} models.
 * <p>
 * Before evaluating, the model's regression table is validated:
 * <ul>
 *     <li>a {@link KiePMMLRegressionTable} must declare a non-blank target field;</li>
 *     <li>a {@link KiePMMLClassificationTable} must have a CATEGORICAL or ORDINAL op type, a
 *     normalization method allowed for that op type (and, for categorical tables, for its
 *     binary/non-binary nature) and a compatible number of category tables.</li>
 * </ul>
 * The actual computation is delegated to {@link KiePMMLRegressionModel#evaluate}.
 */
public class PMMLRegressionModelEvaluator implements PMMLModelEvaluator<KiePMMLRegressionModel> {
    private static final String INVALID_NORMALIZATION_METHOD = "Invalid Normalization Method %s";
    private static final String EXPECTED_AT_LEAST_TWO_REGRESSION_TABLES_RETRIEVED = "Expected at least two RegressionTables, retrieved %s";
    private static final String EXPECTED_TWO_REGRESSION_TABLES_RETRIEVED = "Expected two RegressionTables, retrieved %s";
    private static final String TARGET_FIELD_REQUIRED_RETRIEVED = "TargetField required, retrieved %s";
    private static final String INVALID_TARGET_TYPE = "Invalid target type %s";

    @Override
    public PMML_MODEL getPMMLModelType() {
        return PMML_MODEL.REGRESSION_MODEL;
    }

    /**
     * Validates the model, evaluates it against the request data held by the given context and
     * wraps the outcome in a {@link PMML4Result}.
     *
     * @param kieBase                the KieBase forwarded to the model evaluation
     * @param kiePMMLRegressionModel the model to validate and evaluate
     * @param pMMLContext            context whose mapped request parameters are unwrapped and
     *                               passed to the model
     * @return a result storing the model output under the model's target field name, with that
     * name as result object name and result code {@code OK}
     * @throws KiePMMLModelException    if the model has no regression table or its
     *                                  classification table is inconsistent
     * @throws KiePMMLInternalException if the regression table is of an unexpected type or a
     *                                  plain regression table lacks a target field
     */
    @Override
    public PMML4Result evaluate(KieBase kieBase, KiePMMLRegressionModel kiePMMLRegressionModel, PMMLContext pMMLContext) {
        validate(kiePMMLRegressionModel);
        PMML4Result pMML4Result = new PMML4Result();
        String targetField = kiePMMLRegressionModel.getTargetField();
        pMML4Result.addResultVariable(targetField, kiePMMLRegressionModel.evaluate(kieBase, Converter.getUnwrappedParametersMap(pMMLContext.getRequestData().getMappedRequestParams()), pMMLContext));
        pMML4Result.setResultObjectName(targetField);
        pMML4Result.setResultCode(ResultCode.OK.getName());
        return pMML4Result;
    }

    /**
     * Checks that the model has a regression table and dispatches to the validation matching
     * its concrete type.
     */
    private void validate(KiePMMLRegressionModel kiePMMLRegressionModel) {
        AbstractKiePMMLTable regressionTable = kiePMMLRegressionModel.getRegressionTable();
        if (regressionTable == null) {
            throw new KiePMMLModelException("At least one RegressionTable required");
        }
        // Classification is checked first, mirroring the original dispatch order.
        if (regressionTable instanceof KiePMMLClassificationTable) {
            validateClassification((KiePMMLClassificationTable) regressionTable);
        } else if (regressionTable instanceof KiePMMLRegressionTable) {
            validateRegression((KiePMMLRegressionTable) regressionTable);
        } else {
            throw new KiePMMLInternalException("Unexpected regressionTable " + regressionTable);
        }
    }

    /** Requires a non-null target field that is not empty after trimming. */
    private void validateRegression(KiePMMLRegressionTable kiePMMLRegressionTable) {
        String targetField = kiePMMLRegressionTable.getTargetField();
        if (targetField == null || StringUtils.isEmpty(targetField.trim())) {
            throw new KiePMMLInternalException(String.format(TARGET_FIELD_REQUIRED_RETRIEVED, targetField));
        }
    }

    /** Dispatches on the op type; only CATEGORICAL and ORDINAL are accepted. */
    private void validateClassification(KiePMMLClassificationTable kiePMMLClassificationTable) {
        // A switch on a null enum throws a bare NullPointerException: report it as an invalid
        // target type instead, exactly like any other unsupported op type.
        if (kiePMMLClassificationTable.getOpType() == null) {
            throw new KiePMMLModelException(String.format(INVALID_TARGET_TYPE, kiePMMLClassificationTable.getOpType()));
        }
        switch (kiePMMLClassificationTable.getOpType()) {
            case CATEGORICAL:
                validateClassificationCategorical(kiePMMLClassificationTable);
                return;
            case ORDINAL:
                validateClassificationOrdinal(kiePMMLClassificationTable);
                return;
            default:
                throw new KiePMMLModelException(String.format(INVALID_TARGET_TYPE, kiePMMLClassificationTable.getOpType()));
        }
    }

    /** Routes a categorical table to the binary or non-binary validation. */
    private void validateClassificationCategorical(KiePMMLClassificationTable kiePMMLClassificationTable) {
        if (kiePMMLClassificationTable.isBinary()) {
            validateClassificationCategoricalBinary(kiePMMLClassificationTable);
        } else {
            validateClassificationCategoricalNotBinary(kiePMMLClassificationTable);
        }
    }

    /**
     * Binary categorical: normalization must be LOGIT, PROBIT, CAUCHIT, CLOGLOG, LOGLOG or NONE
     * and there must be exactly two category tables.
     */
    private void validateClassificationCategoricalBinary(KiePMMLClassificationTable kiePMMLClassificationTable) {
        requireNormalizationMethod(kiePMMLClassificationTable);
        switch (kiePMMLClassificationTable.getRegressionNormalizationMethod()) {
            case LOGIT:
            case PROBIT:
            case CAUCHIT:
            case CLOGLOG:
            case LOGLOG:
            case NONE:
                requireExactlyTwoCategoryTables(kiePMMLClassificationTable);
                return;
            default:
                throw invalidNormalizationMethod(kiePMMLClassificationTable);
        }
    }

    /**
     * Non-binary categorical: normalization must be NONE, SOFTMAX or SIMPLEMAX and there must be
     * at least two category tables.
     */
    private void validateClassificationCategoricalNotBinary(KiePMMLClassificationTable kiePMMLClassificationTable) {
        requireNormalizationMethod(kiePMMLClassificationTable);
        switch (kiePMMLClassificationTable.getRegressionNormalizationMethod()) {
            case NONE:
            case SOFTMAX:
            case SIMPLEMAX:
                requireAtLeastTwoCategoryTables(kiePMMLClassificationTable);
                return;
            default:
                throw invalidNormalizationMethod(kiePMMLClassificationTable);
        }
    }

    /**
     * Ordinal: normalization must be LOGIT, PROBIT, CAUCHIT, CLOGLOG, LOGLOG or NONE and there
     * must be at least two category tables.
     */
    private void validateClassificationOrdinal(KiePMMLClassificationTable kiePMMLClassificationTable) {
        requireNormalizationMethod(kiePMMLClassificationTable);
        switch (kiePMMLClassificationTable.getRegressionNormalizationMethod()) {
            case LOGIT:
            case PROBIT:
            case CAUCHIT:
            case CLOGLOG:
            case LOGLOG:
            case NONE:
                requireAtLeastTwoCategoryTables(kiePMMLClassificationTable);
                return;
            default:
                throw invalidNormalizationMethod(kiePMMLClassificationTable);
        }
    }

    /**
     * Fails with the invalid-normalization exception when the method is missing, so that the
     * following enum switch cannot throw a bare NullPointerException.
     */
    private static void requireNormalizationMethod(KiePMMLClassificationTable table) {
        if (table.getRegressionNormalizationMethod() == null) {
            throw invalidNormalizationMethod(table);
        }
    }

    /** Builds the exception reported for a missing or disallowed normalization method. */
    private static KiePMMLModelException invalidNormalizationMethod(KiePMMLClassificationTable table) {
        return new KiePMMLModelException(String.format(INVALID_NORMALIZATION_METHOD, table.getRegressionNormalizationMethod()));
    }

    /** Fails unless the table holds exactly two category tables. */
    private static void requireExactlyTwoCategoryTables(KiePMMLClassificationTable table) {
        int count = categoryTableCount(table);
        if (count != 2) {
            throw new KiePMMLModelException(String.format(EXPECTED_TWO_REGRESSION_TABLES_RETRIEVED, count));
        }
    }

    /** Fails unless the table holds at least two category tables. */
    private static void requireAtLeastTwoCategoryTables(KiePMMLClassificationTable table) {
        int count = categoryTableCount(table);
        if (count < 2) {
            throw new KiePMMLModelException(String.format(EXPECTED_AT_LEAST_TWO_REGRESSION_TABLES_RETRIEVED, count));
        }
    }

    /** Number of category tables, treating a missing map as empty rather than throwing an NPE. */
    private static int categoryTableCount(KiePMMLClassificationTable table) {
        return table.getCategoryTableMap() == null ? 0 : table.getCategoryTableMap().size();
    }
}
