package org.kie.kogito.predictions.smile;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import org.kie.api.runtime.process.WorkItem;
import org.kie.kogito.internal.process.runtime.KogitoWorkItem;
import org.kie.kogito.prediction.api.PredictionOutcome;
import org.kie.kogito.prediction.api.PredictionService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.RandomForest;
import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.data.StringAttribute;

/**
 * {@link PredictionService} backed by a SMILE {@link RandomForest} classifier.
 *
 * <p>Training observations are accumulated in an in-memory {@link AttributeDataset}. The forest is
 * built lazily from that dataset on the first prediction after new data has been added, and reused
 * until more data arrives. Instances hold mutable state and perform no synchronization.
 */
public class SmileRandomForest extends AbstractPredictionEngine implements PredictionService {
    public static final String IDENTIFIER = "SMILERandomForest";
    private static final String UNABLE_PARSE_TEXT = "Unable to parse text";
    private static final Logger logger = LoggerFactory.getLogger(SmileRandomForest.class);
    /** Once more than this many observations were trained, the threshold below replaces the configured one. */
    private static final int MINIMUM_OBSERVATIONS = 1200;
    /** Confidence threshold forced once {@link #MINIMUM_OBSERVATIONS} has been exceeded. */
    private static final double MATURE_CONFIDENCE_THRESHOLD = 0.75d;

    private final AttributeDataset dataset;
    /** Feature attributes keyed by name; insertion-ordered so it agrees with {@link #attributeNames}. */
    private final Map<String, Attribute> smileAttributes;
    private final Attribute outcomeAttribute;
    private final AttributeType outcomeAttributeType;
    private final int numAttributes;
    private final int numberTrees;
    /** Feature names, in the column order used by the dataset, training rows and prediction vectors. */
    protected List<String> attributeNames;
    /** Distinct outcome labels successfully added to the dataset. */
    private final Set<String> outcomeSet = new HashSet<>();
    private int observations;
    /** Cached forest; {@code null} when it must be (re)built from the current dataset. */
    private RandomForest model;

    /**
     * Creates an engine from a {@link RandomForestConfiguration}.
     *
     * @param randomForestConfiguration source of input features, outcome name/type, threshold and tree count
     */
    public SmileRandomForest(RandomForestConfiguration randomForestConfiguration) {
        this(randomForestConfiguration.getInputFeatures(),
                randomForestConfiguration.getOutcomeName(),
                randomForestConfiguration.getOutcomeType(),
                randomForestConfiguration.getConfidenceThreshold(),
                randomForestConfiguration.getNumTrees());
    }

    /**
     * Creates an engine with an empty training dataset.
     *
     * @param inputFeatures       feature names mapped to their attribute types
     * @param outcomeName         name of the predicted attribute (also the key read from training output)
     * @param outcomeType         type of the predicted attribute
     * @param confidenceThreshold initial confidence threshold, passed to the superclass
     * @param numTrees            number of trees grown when the forest is built
     */
    public SmileRandomForest(Map<String, AttributeType> inputFeatures, String outcomeName, AttributeType outcomeType,
            double confidenceThreshold, int numTrees) {
        super(inputFeatures, outcomeName, outcomeType, confidenceThreshold);
        this.attributeNames = new ArrayList<>();
        this.observations = 0;
        this.numberTrees = numTrees;
        // Insertion-ordered so the dataset's column order (built from values() below) is exactly the
        // order of attributeNames, which is what training and prediction vectors follow.
        this.smileAttributes = new LinkedHashMap<>();
        for (Map.Entry<String, AttributeType> entry : inputFeatures.entrySet()) {
            String name = entry.getKey();
            this.smileAttributes.put(name, createAttribute(name, entry.getValue()));
            this.attributeNames.add(name);
        }
        this.numAttributes = this.smileAttributes.size();
        this.outcomeAttribute = createAttribute(outcomeName, outcomeType);
        this.outcomeAttributeType = outcomeType;
        this.dataset = new AttributeDataset("dataset",
                this.smileAttributes.values().toArray(new Attribute[this.numAttributes]),
                this.outcomeAttribute);
    }

    /**
     * Maps an {@link AttributeType} to the corresponding SMILE attribute.
     *
     * @return a {@link NominalAttribute} for NOMINAL and BOOLEAN, a {@link NumericAttribute} for
     *         NUMERIC, and a {@link StringAttribute} for anything else
     */
    protected Attribute createAttribute(String str, AttributeType attributeType) {
        if (attributeType == AttributeType.NOMINAL || attributeType == AttributeType.BOOLEAN) {
            return new NominalAttribute(str);
        }
        if (attributeType == AttributeType.NUMERIC) {
            return new NumericAttribute(str);
        }
        return new StringAttribute(str);
    }

    /**
     * Converts a predicted label back into a Java value of the outcome's type.
     *
     * @return a {@link Long} for NUMERIC, a {@link Boolean} for BOOLEAN, otherwise {@code str} itself
     * @throws NumberFormatException if a NUMERIC label is not a number
     */
    protected Object convertValue(String str, AttributeType attributeType) {
        if (attributeType == AttributeType.NUMERIC) {
            return parseLong(str);
        }
        if (attributeType == AttributeType.BOOLEAN) {
            return Boolean.valueOf(str);
        }
        return str;
    }

    /** Parses an integral label that may be written in decimal form (e.g. {@code "1.0"}). */
    private static Long parseLong(String value) {
        try {
            return Long.valueOf(value);
        } catch (NumberFormatException e) {
            // The response attribute may render a numeric label as a double ("1.0"), which
            // Long.valueOf rejects; class labels are integral, so truncation is lossless here.
            return (long) Double.parseDouble(value);
        }
    }

    /**
     * Adds one training row to the dataset and invalidates the cached forest.
     *
     * <p>Features that fail to parse are logged and left as {@code 0.0}. If the outcome fails to
     * parse, the row is logged and dropped.
     *
     * @param map input values keyed by feature name; every feature must be present
     * @param obj the observed outcome; must not be {@code null}
     * @throws NullPointerException if a feature value or the outcome is missing
     */
    public void addData(Map<String, Object> map, Object obj) {
        double[] features = featureVector(map);
        String label = Objects.requireNonNull(obj, "No outcome value provided").toString();
        try {
            this.dataset.add(features, this.outcomeAttribute.valueOf(label));
            // Record the label only once it is really part of the dataset, so the number of known
            // outcomes stays consistent with the classes the forest is trained on.
            this.outcomeSet.add(label);
            this.model = null;
        } catch (ParseException e) {
            logger.error(UNABLE_PARSE_TEXT, e);
        }
    }

    /**
     * Builds the feature vector for a prediction, in the dataset's column order.
     *
     * @param map input values keyed by feature name; every feature must be present
     * @return one value per feature; unparseable values are logged and left as {@code 0.0}
     * @throws NullPointerException if a feature value is missing
     */
    protected double[] buildFeatures(Map<String, Object> map) {
        return featureVector(map);
    }

    /** Encodes {@code inputData} into a vector ordered by {@link #attributeNames}. */
    private double[] featureVector(Map<String, Object> inputData) {
        double[] features = new double[this.numAttributes];
        for (int i = 0; i < this.numAttributes; i++) {
            String name = this.attributeNames.get(i);
            Object value = Objects.requireNonNull(inputData.get(name),
                    () -> "No value provided for attribute '" + name + "'");
            try {
                features[i] = this.smileAttributes.get(name).valueOf(value.toString());
            } catch (ParseException e) {
                logger.error(UNABLE_PARSE_TEXT, e);
            }
        }
        return features;
    }

    /** Returns the cached forest, growing it from the current dataset if it is missing or stale. */
    private RandomForest trainedModel() {
        if (this.model == null) {
            this.model = new RandomForest(this.dataset, this.numberTrees);
        }
        return this.model;
    }

    public String getIdentifier() {
        return IDENTIFIER;
    }

    /**
     * Predicts the outcome for {@code map}.
     *
     * <p>Side effect: once more than {@link #MINIMUM_OBSERVATIONS} observations have been trained, the
     * confidence threshold is overwritten with {@link #MATURE_CONFIDENCE_THRESHOLD}.
     *
     * @param workItem the task being predicted for; used only for logging
     * @param map      input values keyed by feature name
     * @return a zero-confidence outcome while fewer than two distinct outcomes are known. Otherwise
     *         the predicted value (under the outcome name) and its posterior probability (under
     *         {@code "confidence"})
     */
    public PredictionOutcome predict(WorkItem workItem, Map<String, Object> map) {
        logger.debug("Predicting with input data: {}", map);
        if (this.observations > MINIMUM_OBSERVATIONS) {
            this.confidenceThreshold = MATURE_CONFIDENCE_THRESHOLD;
        }
        Map<String, Object> outcomes = new HashMap<>();
        if (this.outcomeSet.size() < 2) {
            // A classifier needs at least two classes; report no confidence until then.
            outcomes.put("confidence", 0.0d);
            return new PredictionOutcome(0.0d, this.confidenceThreshold, outcomes);
        }
        double[] features = buildFeatures(map);
        double[] posteriori = new double[this.outcomeSet.size()];
        double predicted = trainedModel().predict(features, posteriori);
        String label = this.dataset.responseAttribute().toString(predicted);
        outcomes.put(this.outcomeAttribute.getName(), convertValue(label, this.outcomeAttributeType));
        double confidence = posteriori[(int) predicted];
        outcomes.put("confidence", confidence);
        if (logger.isDebugEnabled()) {
            // Only Kogito work items expose a string id; fall back to the item itself otherwise.
            Object taskId = workItem instanceof KogitoWorkItem ? ((KogitoWorkItem) workItem).getStringId() : workItem;
            logger.debug("task id {}, total {} observations, prediction = {}, confidence = {} (threshold = {})",
                    new Object[] { taskId, this.observations, label, confidence, this.confidenceThreshold });
        }
        return new PredictionOutcome(confidence, this.confidenceThreshold, outcomes);
    }

    /**
     * Records one observation and adds it to the training dataset.
     *
     * @param workItem the completed task (not used beyond the signature)
     * @param map      input values keyed by feature name
     * @param map2     output values; the outcome is read under the outcome attribute's name
     */
    public void train(WorkItem workItem, Map<String, Object> map, Map<String, Object> map2) {
        logger.debug("Training with input data: {}", map);
        logger.debug("Training with output data: {}", map2);
        this.observations++;
        addData(map, map2.get(this.outcomeAttribute.getName()));
    }
}
