package org.kie.pmml.models.mining.evaluator;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.drools.core.RuleBaseConfiguration;
import org.drools.core.impl.InternalKnowledgeBase;
import org.drools.core.impl.KnowledgeBaseFactory;
import org.drools.core.impl.KnowledgeBaseImpl;
import org.drools.core.util.StringUtils;
import org.kie.api.KieBase;
import org.kie.api.pmml.PMML4Result;
import org.kie.api.runtime.KieRuntimeFactory;
import org.kie.pmml.api.enums.MINING_FUNCTION;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.enums.ResultCode;
import org.kie.pmml.api.exceptions.KieEnumException;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.api.runtime.PMMLContext;
import org.kie.pmml.api.runtime.PMMLRuntime;
import org.kie.pmml.commons.model.KiePMMLModel;
import org.kie.pmml.commons.model.predicates.KiePMMLPredicate;
import org.kie.pmml.commons.model.tuples.KiePMMLNameValue;
import org.kie.pmml.commons.model.tuples.KiePMMLValueWeight;
import org.kie.pmml.evaluator.api.exceptions.KiePMMLModelException;
import org.kie.pmml.evaluator.core.executor.PMMLModelEvaluator;
import org.kie.pmml.evaluator.core.utils.Converter;
import org.kie.pmml.models.mining.model.KiePMMLMiningModel;
import org.kie.pmml.models.mining.model.enums.MULTIPLE_MODEL_METHOD;
import org.kie.pmml.models.mining.model.segmentation.KiePMMLSegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/kie-pmml-models-mining-evaluator-7.62.0.Final.jar:org/kie/pmml/models/mining/evaluator/PMMLMiningModelEvaluator.class */
/**
 * {@link PMMLModelEvaluator} for {@link KiePMMLMiningModel}s.
 * <p>
 * Every segment of the model's segmentation whose predicate matches the current request data is
 * evaluated through a {@link PMMLRuntime} built for that segment's package. The per-segment results
 * are then combined according to the segmentation's {@link MULTIPLE_MODEL_METHOD}.
 */
public class PMMLMiningModelEvaluator implements PMMLModelEvaluator<KiePMMLMiningModel> {
    private static final String EXPECTED_A_KIE_PMMLMINING_MODEL_RECEIVED = "Expected a KiePMMLMiningModel, received %s";
    private static final String TARGET_FIELD_REQUIRED_RETRIEVED = "TargetField required, retrieved %s";
    private static final Logger logger = LoggerFactory.getLogger(PMMLMiningModelEvaluator.class.getName());

    /**
     * Cache of knowledge bases built in {@link #getPMMLRuntime}, keyed by
     * {@code "<miningModelName>_<kModulePackageName>"}.
     * <p>
     * The field is static, so it is shared by every evaluator instance and every thread evaluating
     * through them. It is populated via {@code computeIfAbsent}, which on a plain {@code HashMap}
     * is not safe under concurrent use; hence a {@link ConcurrentHashMap}.
     */
    private static final Map<String, InternalKnowledgeBase> MAPPED_KIEBASES = new ConcurrentHashMap<>();

    /**
     * Pairs the (possibly weighted) prediction of one segment with its (possibly weighted)
     * probability values.
     */
    public static class KiePMMLNameValueProbabilityMapTuple {
        private final KiePMMLNameValue predictionValue;
        private final List<KiePMMLNameValue> probabilityValues;

        public KiePMMLNameValueProbabilityMapTuple(KiePMMLNameValue kiePMMLNameValue, List<KiePMMLNameValue> list) {
            this.predictionValue = kiePMMLNameValue;
            this.probabilityValues = list;
        }
    }

    /**
     * Pairs the raw {@link PMML4Result} of one segment with the probability map read from the
     * context right after that segment was evaluated.
     */
    static class PMML4ResultProbabilityMapTuple {
        private final PMML4Result pmml4Result;
        private final Map<String, Double> probabilityResultMap;

        public PMML4ResultProbabilityMapTuple(PMML4Result pMML4Result, Map<String, Double> map) {
            this.pmml4Result = pMML4Result;
            this.probabilityResultMap = map;
        }
    }

    @Override // org.kie.pmml.evaluator.core.executor.PMMLModelEvaluator
    public PMML_MODEL getPMMLModelType() {
        return PMML_MODEL.MINING_MODEL;
    }

    /**
     * Validates the model, then evaluates all of its matching segments and combines their results.
     *
     * @param kieBase the KieBase containing the segments' packages
     * @param kiePMMLMiningModel the mining model to evaluate
     * @param pMMLContext the evaluation context; its request data is read and extended by segment outputs
     * @return the combined result, keyed by the model's target field
     * @throws KiePMMLModelException if the model is missing
     * @throws KiePMMLInternalException if the model has no (non-blank) target field
     */
    @Override // org.kie.pmml.evaluator.core.executor.PMMLModelEvaluator
    public PMML4Result evaluate(KieBase kieBase, KiePMMLMiningModel kiePMMLMiningModel, PMMLContext pMMLContext) {
        validate(kiePMMLMiningModel);
        return evaluateMiningModel(kiePMMLMiningModel, pMMLContext, kieBase);
    }

    /**
     * Combines the per-segment results into a single {@link PMML4Result}.
     * <p>
     * For {@code CLASSIFICATION} models, both the classification and the probabilities are
     * combined, and the probabilities are stored in the context. For other mining functions, only
     * the prediction is combined and the context's probability result map is set to {@code null}.
     * If the combination method rejects the input (a {@link KieEnumException}), the failure is
     * logged and the result code is {@link ResultCode#FAIL}.
     *
     * @param kiePMMLMiningModel the model being evaluated
     * @param linkedHashMap segment id to that segment's prediction/probabilities, in evaluation order
     * @param pMMLContext the context that receives the combined probability map
     * @return the result, whose single result variable is the model's target field
     */
    PMML4Result getPMML4Result(KiePMMLMiningModel kiePMMLMiningModel, LinkedHashMap<String, KiePMMLNameValueProbabilityMapTuple> linkedHashMap, PMMLContext pMMLContext) {
        MULTIPLE_MODEL_METHOD multipleModelMethod = kiePMMLMiningModel.getSegmentation().getMultipleModelMethod();
        Object obj = null;
        LinkedHashMap<String, Double> linkedHashMap2 = null;
        ResultCode resultCode = ResultCode.OK;
        // Split the tuples into two parallel maps (predictions / probabilities) sharing the same keys.
        LinkedHashMap<String, KiePMMLNameValue> linkedHashMap3 = new LinkedHashMap<>();
        LinkedHashMap<String, List<KiePMMLNameValue>> linkedHashMap4 = new LinkedHashMap<>();
        linkedHashMap.forEach((str, kiePMMLNameValueProbabilityMapTuple) -> {
            linkedHashMap3.put(str, kiePMMLNameValueProbabilityMapTuple.predictionValue);
            linkedHashMap4.put(str, kiePMMLNameValueProbabilityMapTuple.probabilityValues);
        });
        try {
            if (MINING_FUNCTION.CLASSIFICATION.equals(kiePMMLMiningModel.getMiningFunction())) {
                obj = multipleModelMethod.applyClassification(linkedHashMap3);
                linkedHashMap2 = multipleModelMethod.applyProbability(linkedHashMap4);
            } else {
                obj = multipleModelMethod.applyPrediction(linkedHashMap3);
            }
        } catch (KieEnumException e) {
            // Keep the stack trace: the message alone hides where the combination failed.
            logger.warn(e.getMessage(), e);
            resultCode = ResultCode.FAIL;
        }
        pMMLContext.setProbabilityResultMap(linkedHashMap2);
        PMML4Result pMML4Result = new PMML4Result();
        pMML4Result.addResultVariable(kiePMMLMiningModel.getTargetField(), obj);
        pMML4Result.setResultObjectName(kiePMMLMiningModel.getTargetField());
        pMML4Result.setResultCode(resultCode.getName());
        return pMML4Result;
    }

    /**
     * Returns a {@link PMMLRuntime} over a knowledge base that contains only the package
     * {@code str} of {@code kieBase}, if that package exists.
     * <p>
     * The knowledge base is built once per {@code str2 + "_" + str} key and cached in
     * {@link #MAPPED_KIEBASES}. It uses the root class loader of {@code kieBase}.
     *
     * @param str the package name of the segment's model
     * @param kieBase the source KieBase; must be a {@link KnowledgeBaseImpl} (it is cast to one)
     * @param str2 the mining model name, used as the cache-key prefix
     * @return a runtime over the (possibly cached) single-package knowledge base
     */
    PMMLRuntime getPMMLRuntime(String str, KieBase kieBase, String str2) {
        return (PMMLRuntime) KieRuntimeFactory.of(MAPPED_KIEBASES.computeIfAbsent(str2 + "_" + str, str3 -> {
            // Raw List kept as in the original: the package type returned by KieBase differs from
            // the one addPackages declares, and the original relies on the unchecked conversion.
            List singletonList = kieBase.getKiePackage(str) != null ? Collections.singletonList(kieBase.getKiePackage(str)) : Collections.emptyList();
            RuleBaseConfiguration ruleBaseConfiguration = new RuleBaseConfiguration();
            ruleBaseConfiguration.setClassLoader(((KnowledgeBaseImpl) kieBase).getRootClassLoader());
            InternalKnowledgeBase newKnowledgeBase = KnowledgeBaseFactory.newKnowledgeBase(str, ruleBaseConfiguration);
            newKnowledgeBase.addPackages(singletonList);
            return newKnowledgeBase;
        })).get(PMMLRuntime.class);
    }

    /**
     * Wraps the result object of {@code pMML4Result} as a name/value pair, weighted when the
     * combination method requires it.
     *
     * @param pMML4Result a segment result
     * @param multiple_model_method the segmentation's combination method
     * @param d the segment weight
     * @return a pair named after the result object name
     * @throws KiePMMLException see {@link #getEventuallyWeightedResult}
     */
    KiePMMLNameValue getKiePMMLNameValue(PMML4Result pMML4Result, MULTIPLE_MODEL_METHOD multiple_model_method, double d) {
        String resultObjectName = pMML4Result.getResultObjectName();
        return new KiePMMLNameValue(resultObjectName, getEventuallyWeightedResult(pMML4Result.getResultVariables().get(resultObjectName), multiple_model_method, d));
    }

    /**
     * Converts each entry of a probability map into a (possibly weighted) name/value pair.
     *
     * @param map probability map of one segment (iteration order is preserved in the output)
     * @param multiple_model_method the segmentation's combination method
     * @param d the segment weight
     * @return one pair per map entry
     * @throws KiePMMLException see {@link #getEventuallyWeightedResult}
     */
    List<KiePMMLNameValue> getKiePMMLNameValues(Map<String, Double> map, MULTIPLE_MODEL_METHOD multiple_model_method, double d) {
        return map.entrySet().stream()
                .map(entry -> new KiePMMLNameValue(entry.getKey(), getEventuallyWeightedResult(entry.getValue(), multiple_model_method, d)))
                .collect(Collectors.toList());
    }

    /**
     * Returns {@code obj} unchanged for non-numeric combination methods. For numeric ones, it
     * returns a {@link KiePMMLValueWeight} of {@code obj}'s double value and weight {@code d}.
     *
     * @param obj the raw segment value
     * @param multiple_model_method the segmentation's combination method
     * @param d the segment weight
     * @return {@code obj} or its weighted wrapper
     * @throws KiePMMLException if a numeric method receives a non-number (including {@code null}),
     *         if the method is {@code WEIGHTED_MAJORITY_VOTE} (not implemented), or if it is unrecognized
     */
    Object getEventuallyWeightedResult(Object obj, MULTIPLE_MODEL_METHOD multiple_model_method, double d) {
        switch (multiple_model_method) {
            case MAJORITY_VOTE:
            case MODEL_CHAIN:
            case SELECT_ALL:
            case SELECT_FIRST:
                return obj;
            case MAX:
            case SUM:
            case MEDIAN:
            case AVERAGE:
            case WEIGHTED_SUM:
            case WEIGHTED_MEDIAN:
            case WEIGHTED_AVERAGE:
                if (obj instanceof Number) {
                    return new KiePMMLValueWeight(((Number) obj).doubleValue(), d);
                }
                // Null-safe: a missing segment result must surface as KiePMMLException, not an NPE.
                String retrieved = obj == null ? "null" : obj.getClass().getName();
                throw new KiePMMLException("Expected a number, retrieved " + retrieved);
            case WEIGHTED_MAJORITY_VOTE:
                throw new KiePMMLException(multiple_model_method + " not implemented, yet");
            default:
                throw new KiePMMLException("Unrecognized MULTIPLE_MODEL_METHOD " + multiple_model_method);
        }
    }

    /**
     * Ensures {@code kiePMMLModel} is a {@link KiePMMLMiningModel} with a non-blank target field.
     *
     * @param kiePMMLModel the model to check
     * @throws KiePMMLModelException if the model is {@code null} or not a mining model
     * @throws KiePMMLInternalException if the target field is {@code null} or blank
     */
    void validate(KiePMMLModel kiePMMLModel) {
        if (!(kiePMMLModel instanceof KiePMMLMiningModel)) {
            // Null-safe: a null model fails instanceof and must not NPE while building the message.
            String received = kiePMMLModel == null ? "null" : kiePMMLModel.getClass().getName();
            throw new KiePMMLModelException(String.format(EXPECTED_A_KIE_PMMLMINING_MODEL_RECEIVED, received));
        }
        validateMining((KiePMMLMiningModel) kiePMMLModel);
    }

    /**
     * Ensures the mining model has a non-null, non-blank target field.
     *
     * @param kiePMMLMiningModel the model to check
     * @throws KiePMMLInternalException if the target field is {@code null} or blank
     */
    void validateMining(KiePMMLMiningModel kiePMMLMiningModel) {
        if (kiePMMLMiningModel.getTargetField() == null || StringUtils.isEmpty(kiePMMLMiningModel.getTargetField().trim())) {
            throw new KiePMMLInternalException(String.format(TARGET_FIELD_REQUIRED_RETRIEVED, kiePMMLMiningModel.getTargetField()));
        }
    }

    /**
     * Evaluates every segment in order and collects the weighted results of the matching ones.
     * <p>
     * The outputs of each evaluated segment are added to the context's request data. Segments
     * evaluated afterwards (and their predicates) therefore see them.
     */
    private PMML4Result evaluateMiningModel(KiePMMLMiningModel kiePMMLMiningModel, PMMLContext pMMLContext, KieBase kieBase) {
        MULTIPLE_MODEL_METHOD multipleModelMethod = kiePMMLMiningModel.getSegmentation().getMultipleModelMethod();
        List<KiePMMLSegment> segments = kiePMMLMiningModel.getSegmentation().getSegments();
        LinkedHashMap<String, KiePMMLNameValueProbabilityMapTuple> linkedHashMap = new LinkedHashMap<>();
        for (KiePMMLSegment kiePMMLSegment : segments) {
            evaluateSegment(kiePMMLSegment, pMMLContext, kieBase, kiePMMLMiningModel.getName()).ifPresent(pMML4Result -> {
                pMML4Result.getResultVariables().forEach((str, obj) -> {
                    pMMLContext.getRequestData().addRequestParam(str, obj);
                });
                // NOTE(review): presumably the probability map was populated by the segment
                // evaluation that just ran; confirm against PMMLContext.
                PMML4ResultProbabilityMapTuple pMML4ResultProbabilityMapTuple = new PMML4ResultProbabilityMapTuple(pMML4Result, pMMLContext.getProbabilityMap());
                linkedHashMap.put(kiePMMLSegment.getId(), new KiePMMLNameValueProbabilityMapTuple(getKiePMMLNameValue(pMML4ResultProbabilityMapTuple.pmml4Result, multipleModelMethod, kiePMMLSegment.getWeight()), getKiePMMLNameValues(pMML4ResultProbabilityMapTuple.probabilityResultMap, multipleModelMethod, kiePMMLSegment.getWeight())));
            });
        }
        return getPMML4Result(kiePMMLMiningModel, linkedHashMap, pMMLContext);
    }

    /**
     * Evaluates a single segment if its predicate matches the current (unwrapped) request params.
     *
     * @param str the mining model name (cache-key prefix for {@link #getPMMLRuntime})
     * @return the segment result, or empty if the predicate does not match or the runtime returned {@code null}
     */
    private Optional<PMML4Result> evaluateSegment(KiePMMLSegment kiePMMLSegment, PMMLContext pMMLContext, KieBase kieBase, String str) {
        logger.trace("evaluateSegment {}", kiePMMLSegment.getId());
        KiePMMLPredicate kiePMMLPredicate = kiePMMLSegment.getKiePMMLPredicate();
        Optional<PMML4Result> empty = Optional.empty();
        Map<String, Object> unwrappedParametersMap = Converter.getUnwrappedParametersMap(pMMLContext.getRequestData().getMappedRequestParams());
        String name = kiePMMLSegment.getModel().getName();
        if (kiePMMLPredicate.evaluate(unwrappedParametersMap)) {
            PMMLRuntime pMMLRuntime = getPMMLRuntime(kiePMMLSegment.getModel().getKModulePackageName(), kieBase, str);
            logger.trace("{}: matching predicate, evaluating... ", kiePMMLSegment.getId());
            empty = Optional.ofNullable(pMMLRuntime.evaluate(name, pMMLContext));
        }
        return empty;
    }
}
