package org.kie.pmml.models.regression.model;

import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.DoubleUnaryOperator;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.iinterfaces.SerializableFunction;
import org.kie.pmml.api.runtime.PMMLRuntimeContext;
import org.kie.pmml.commons.Constants;
import org.kie.pmml.commons.model.KiePMMLExtension;
import org.kie.pmml.models.regression.model.AbstractKiePMMLTable;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;

/* loaded from: input_file:BOOT-INF/lib/kie-pmml-models-regression-model-8.33.1-SNAPSHOT.jar:org/kie/pmml/models/regression/model/KiePMMLClassificationTable.class */
/**
 * Regression table used for classification.
 * <p>
 * It holds one {@link KiePMMLRegressionTable} per target category. At evaluation time it scores every category and
 * turns the raw scores into a probability map through the configured probability function. The category with the
 * highest probability is the predicted value.
 * <p>
 * The static {@code get*ProbabilityMap} helpers implement the score-to-probability conversions. Which one is used is
 * decided by whoever calls {@link Builder#withProbabilityMapFunction} (presumably according to
 * {@link #getRegressionNormalizationMethod()} — confirm against the builder's caller).
 */
public final class KiePMMLClassificationTable extends AbstractKiePMMLTable {

    private static final long serialVersionUID = 458989873257189359L;

    /**
     * Shared standard normal distribution used by the PROBIT link.
     * <p>
     * The no-arg constructor also creates a random generator (only needed for sampling), so a single instance is
     * reused instead of allocating one per evaluation; {@code cumulativeProbability} does not touch that generator.
     */
    private static final NormalDistribution STANDARD_NORMAL = new NormalDistribution();

    /** 1/&pi;, used by the CAUCHIT link (folded at compile time). */
    private static final double INVERSE_PI = 1.0d / Math.PI;

    /** Complement of a probability: {@code p -> 1 - p}; gives the second category of a binary map. */
    private static final DoubleUnaryOperator COMPLEMENT = p -> 1.0d - p;

    // Stored and exposed only; not read by the evaluation logic in this class.
    private REGRESSION_NORMALIZATION_METHOD regressionNormalizationMethod;
    // Stored and exposed only; not read by the evaluation logic in this class.
    private OP_TYPE opType;
    // Category -> regression table. Insertion order matters: the binary/NONE helpers treat the first/last
    // entries specially, and evaluateRegression preserves this order when building the score map.
    private final Map<String, KiePMMLRegressionTable> categoryTableMap;
    // Converts the per-category raw scores into per-category probabilities.
    private SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>> probabilityMapFunction;
    private boolean isBinary;

    /**
     * Builder for {@link KiePMMLClassificationTable}; each {@code withXxx} method mutates the instance being built.
     */
    public static class Builder extends AbstractKiePMMLTable.Builder<KiePMMLClassificationTable> {

        /**
         * @param name       forwarded to the {@link AbstractKiePMMLTable} constructor (presumably the table name)
         * @param extensions forwarded to the {@link AbstractKiePMMLTable} constructor
         */
        protected Builder(String name, List<KiePMMLExtension> extensions) {
            super("KiePMMLRegressionClassificationTable-", () -> new KiePMMLClassificationTable(name, extensions));
        }

        public Builder withRegressionNormalizationMethod(REGRESSION_NORMALIZATION_METHOD regressionNormalizationMethod) {
            ((KiePMMLClassificationTable) this.toBuild).regressionNormalizationMethod = regressionNormalizationMethod;
            return this;
        }

        public Builder withOpType(OP_TYPE opType) {
            ((KiePMMLClassificationTable) this.toBuild).opType = opType;
            return this;
        }

        /**
         * Adds all the given category tables (in the given map's iteration order); {@code null} is ignored.
         */
        public Builder withCategoryTableMap(Map<String, KiePMMLRegressionTable> categoryTableMap) {
            if (categoryTableMap != null) {
                ((KiePMMLClassificationTable) this.toBuild).categoryTableMap.putAll(categoryTableMap);
            }
            return this;
        }

        public Builder withProbabilityMapFunction(SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>> probabilityMapFunction) {
            ((KiePMMLClassificationTable) this.toBuild).probabilityMapFunction = probabilityMapFunction;
            return this;
        }

        /**
         * Sets the binary flag; {@code null} leaves the current value ({@code false} by default) unchanged.
         */
        public Builder withIsBinary(Boolean isBinary) {
            if (isBinary != null) {
                ((KiePMMLClassificationTable) this.toBuild).isBinary = isBinary;
            }
            return this;
        }
    }

    /**
     * @param name       forwarded to the {@link AbstractKiePMMLTable} constructor (presumably the table name)
     * @param extensions forwarded to the {@link AbstractKiePMMLTable} constructor
     * @return a new {@link Builder}
     */
    public static Builder builder(String name, List<KiePMMLExtension> extensions) {
        return new Builder(name, extensions);
    }

    /**
     * Evaluates every category table against {@code input} and derives the predicted category.
     * <p>
     * The raw per-category scores are converted by the probability function. The resulting map is stored in
     * {@code context} via {@code setProbabilityResultMap}.
     *
     * @param input   input values, forwarded unchanged to each category table
     * @param context runtime context; receives the probability map
     * @return the key (category) with the highest probability
     * @throws java.util.NoSuchElementException if the probability function returns an empty map
     */
    @Override
    public Object evaluateRegression(Map<String, Object> input, PMMLRuntimeContext context) {
        final LinkedHashMap<String, Double> scores = new LinkedHashMap<>();
        for (Map.Entry<String, KiePMMLRegressionTable> entry : categoryTableMap.entrySet()) {
            scores.put(entry.getKey(), (Double) entry.getValue().evaluateRegression(input, context));
        }
        final LinkedHashMap<String, Double> probabilities = probabilityMapFunction.apply(scores);
        context.setProbabilityResultMap(probabilities);
        return Collections.max(probabilities.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    public boolean isBinary() {
        return isBinary;
    }

    public REGRESSION_NORMALIZATION_METHOD getRegressionNormalizationMethod() {
        return regressionNormalizationMethod;
    }

    public OP_TYPE getOpType() {
        return opType;
    }

    /**
     * @return the live, mutable category -> table map (not a copy)
     */
    public Map<String, KiePMMLRegressionTable> getCategoryTableMap() {
        return categoryTableMap;
    }

    private KiePMMLClassificationTable(String name, List<KiePMMLExtension> extensions) {
        super(name, extensions);
        this.categoryTableMap = new LinkedHashMap<>();
    }

    /**
     * Builds a two-entry probability map for a binary target.
     * <p>
     * The first category gets {@code firstItemOperator(score of first category)}. The second category gets
     * {@code secondItemOperator(probability of first category)}. The second category's raw score is ignored.
     *
     * @param regressionsMap     per-category raw scores, in category order
     * @param firstItemOperator  maps the first category's score to its probability
     * @param secondItemOperator maps the first category's probability to the second category's probability
     * @return a new map with the same two keys, in the same order
     * @throws KiePMMLException if {@code regressionsMap} does not contain exactly two entries
     */
    public static LinkedHashMap<String, Double> getProbabilityMap(LinkedHashMap<String, Double> regressionsMap,
                                                                  DoubleUnaryOperator firstItemOperator,
                                                                  DoubleUnaryOperator secondItemOperator) {
        if (regressionsMap.size() != 2) {
            throw new KiePMMLException(String.format(Constants.EXPECTED_TWO_ENTRIES_RETRIEVED, regressionsMap.size()));
        }
        final String[] keys = regressionsMap.keySet().toArray(new String[0]);
        final double firstProbability = firstItemOperator.applyAsDouble(regressionsMap.get(keys[0]));
        final double secondProbability = secondItemOperator.applyAsDouble(firstProbability);
        final LinkedHashMap<String, Double> toReturn = new LinkedHashMap<>();
        toReturn.put(keys[0], firstProbability);
        toReturn.put(keys[1], secondProbability);
        return toReturn;
    }

    /**
     * Softmax: {@code p_i = exp(s_i) / sum_j exp(s_j)}.
     * <p>
     * Scores are shifted by their (finite) maximum before exponentiating. This is mathematically identical, but it
     * prevents {@code exp} overflowing to infinity, which produced {@code NaN} probabilities for large scores.
     *
     * @param regressionsMap per-category raw scores
     * @return a new map with the same keys, in the same order; empty input yields an empty map
     */
    public static LinkedHashMap<String, Double> getSOFTMAXProbabilityMap(LinkedHashMap<String, Double> regressionsMap) {
        double maxScore = Double.NEGATIVE_INFINITY;
        for (double score : regressionsMap.values()) {
            maxScore = Math.max(maxScore, score);
        }
        // With a non-finite max (empty input, all -Infinity, +Infinity, NaN), shifting would only create extra NaNs.
        final double shift = Double.isFinite(maxScore) ? maxScore : 0.0d;
        final LinkedHashMap<String, Double> toReturn = new LinkedHashMap<>();
        double sum = 0.0d;
        for (Map.Entry<String, Double> entry : regressionsMap.entrySet()) {
            final double exp = Math.exp(entry.getValue() - shift);
            toReturn.put(entry.getKey(), exp);
            sum += exp;
        }
        for (Map.Entry<String, Double> entry : toReturn.entrySet()) {
            entry.setValue(entry.getValue() / sum);
        }
        return toReturn;
    }

    /**
     * Simplemax: {@code p_i = s_i / sum_j s_j}.
     * <p>
     * There is no guard for a zero sum; that case yields {@code NaN}/infinite values.
     *
     * @param regressionsMap per-category raw scores
     * @return a new map with the same keys, in the same order
     */
    public static LinkedHashMap<String, Double> getSIMPLEMAXProbabilityMap(LinkedHashMap<String, Double> regressionsMap) {
        double sum = 0.0d;
        for (double score : regressionsMap.values()) {
            sum += score;
        }
        final LinkedHashMap<String, Double> toReturn = new LinkedHashMap<>();
        for (Map.Entry<String, Double> entry : regressionsMap.entrySet()) {
            toReturn.put(entry.getKey(), entry.getValue() / sum);
        }
        return toReturn;
    }

    /**
     * No normalization.
     * <p>
     * Every category except the last keeps its raw score as probability. The last category gets {@code 1 - sum}
     * of the others. No clamping is applied.
     *
     * @param regressionsMap per-category raw scores, in category order
     * @return a new map with the same keys, in the same order
     */
    public static LinkedHashMap<String, Double> getNONEProbabilityMap(LinkedHashMap<String, Double> regressionsMap) {
        final LinkedHashMap<String, Double> toReturn = new LinkedHashMap<>();
        final int lastIndex = regressionsMap.size() - 1;
        double sumOfOthers = 0.0d;
        int index = 0;
        for (Map.Entry<String, Double> entry : regressionsMap.entrySet()) {
            final double score = entry.getValue();
            if (index < lastIndex) {
                sumOfOthers += score;
                toReturn.put(entry.getKey(), score);
            } else {
                toReturn.put(entry.getKey(), 1.0d - sumOfOthers);
            }
            index++;
        }
        return toReturn;
    }

    /**
     * No normalization, binary target.
     * <p>
     * The first category's score is clamped to {@code [0, 1]}; the second category gets its complement.
     *
     * @param regressionsMap per-category raw scores, in category order
     * @return a new two-entry map with the same keys, in the same order
     * @throws KiePMMLException if {@code regressionsMap} does not contain exactly two entries
     */
    public static LinkedHashMap<String, Double> getNONEBinaryProbabilityMap(LinkedHashMap<String, Double> regressionsMap) {
        return getProbabilityMap(regressionsMap, score -> Math.max(0.0d, Math.min(1.0d, score)), COMPLEMENT);
    }

    /**
     * Logit link, binary target: {@code p = 1 / (1 + exp(-s))} for the first category; the second gets {@code 1 - p}.
     *
     * @throws KiePMMLException if {@code regressionsMap} does not contain exactly two entries
     */
    public static LinkedHashMap<String, Double> getLOGITProbabilityMap(LinkedHashMap<String, Double> regressionsMap) {
        return getProbabilityMap(regressionsMap, score -> 1.0d / (1.0d + Math.exp(-score)), COMPLEMENT);
    }

    /**
     * Probit link, binary target: {@code p = Phi(s)}, where Phi is the standard normal CDF, for the first category;
     * the second gets {@code 1 - p}.
     *
     * @throws KiePMMLException if {@code regressionsMap} does not contain exactly two entries
     */
    public static LinkedHashMap<String, Double> getPROBITProbabilityMap(LinkedHashMap<String, Double> regressionsMap) {
        return getProbabilityMap(regressionsMap, STANDARD_NORMAL::cumulativeProbability, COMPLEMENT);
    }

    /**
     * Complementary log-log link, binary target: {@code p = 1 - exp(-exp(s))} for the first category; the second
     * gets {@code 1 - p}.
     *
     * @throws KiePMMLException if {@code regressionsMap} does not contain exactly two entries
     */
    public static LinkedHashMap<String, Double> getCLOGLOGProbabilityMap(LinkedHashMap<String, Double> regressionsMap) {
        return getProbabilityMap(regressionsMap, score -> 1.0d - Math.exp(-Math.exp(score)), COMPLEMENT);
    }

    /**
     * Cauchit link, binary target: {@code p = 0.5 + atan(s) / pi} for the first category; the second gets
     * {@code 1 - p}.
     *
     * @throws KiePMMLException if {@code regressionsMap} does not contain exactly two entries
     */
    public static LinkedHashMap<String, Double> getCAUCHITProbabilityMap(LinkedHashMap<String, Double> regressionsMap) {
        return getProbabilityMap(regressionsMap, score -> 0.5d + (INVERSE_PI * Math.atan(score)), COMPLEMENT);
    }
}
