/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.pmml_4_2.predictive.models;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.drools.core.impl.InternalKnowledgeBase;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kie.api.pmml.PMML4Result;
import org.kie.api.pmml.PMMLRequestData;
import org.kie.api.runtime.rule.RuleUnit;
import org.kie.api.runtime.rule.RuleUnitExecutor;
import org.kie.pmml.pmml_4_2.DroolsAbstractPMMLTest;

@RunWith(value=Parameterized.class)
public class SimpleRegressionTest
extends DroolsAbstractPMMLTest {
    private static final String source1 = "org/kie/pmml/pmml_4_2/test_regression.pmml";
    private static final String source2 = "org/kie/pmml/pmml_4_2/test_regression_clax.pmml";
    private static final double COMPARISON_DELTA = 1.0E-6;
    private double fld1;
    private double fld2;
    private String fld3;

    @Parameterized.Parameters
    public static Collection<Object[]> data() {
        return Arrays.asList({1.0, 1.0, "x"}, {0.9, 0.3, "x"}, {12.0, 25.0, "x"}, {0.2, 0.1, "x"}, {5, 8, "y"});
    }

    public SimpleRegressionTest(double fld1, double fld2, String fld3) {
        this.fld1 = fld1;
        this.fld2 = fld2;
        this.fld3 = fld3;
    }

    @Test
    public void testRegression() throws Exception {
        RuleUnitExecutor executor = this.createExecutor(source1);
        PMMLRequestData request = new PMMLRequestData("123", "LinReg");
        request.addRequestParam("fld1", (Object)this.fld1);
        request.addRequestParam("fld2", (Object)this.fld2);
        request.addRequestParam("fld3", (Object)this.fld3);
        PMML4Result resultHolder = new PMML4Result();
        List<String> possiblePackages = this.calculatePossiblePackageNames("LinReg", new String[0]);
        Class<? extends RuleUnit> unitClass = this.getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase)this.kbase, possiblePackages);
        Assert.assertNotNull(unitClass);
        int x = executor.run(unitClass);
        this.data.insert((Object)request);
        this.resultData.insert((Object)resultHolder);
        executor.run(unitClass);
        Assert.assertEquals((Object)"OK", (Object)resultHolder.getResultCode());
        Assert.assertNotNull((Object)resultHolder.getResultValue("Fld4", null, new Object[0]));
        Double value = resultHolder.getResultValue("Fld4", "value", Double.class, new Object[0]).orElse(null);
        Assert.assertNotNull((Object)value);
        double expectedValue = this.simpleRegressionResult(this.fld1, this.fld2, this.fld3);
        Assert.assertEquals((double)expectedValue, (double)value, (double)1.0E-6);
    }

    private double simpleRegressionResult(double fld1, double fld2, String fld3) {
        double result = 0.5 + 5.0 * fld1 * fld1 + 2.0 * fld2 + SimpleRegressionTest.fld3Coefficient(fld3) + 0.4 * fld1 * fld2;
        result = 1.0 / (1.0 + Math.exp(-result));
        return result;
    }

    @Test
    public void testClassification() throws Exception {
        RuleUnitExecutor executor = this.createExecutor(source2);
        PMMLRequestData request = new PMMLRequestData("123", "LinReg");
        request.addRequestParam("fld1", (Object)this.fld1);
        request.addRequestParam("fld2", (Object)this.fld2);
        request.addRequestParam("fld3", (Object)this.fld3);
        PMML4Result resultHolder = new PMML4Result();
        List<String> possiblePackages = this.calculatePossiblePackageNames("LinReg", new String[0]);
        Class<? extends RuleUnit> unitClass = this.getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase)this.kbase, possiblePackages);
        Assert.assertNotNull(unitClass);
        int x = executor.run(unitClass);
        this.data.insert((Object)request);
        this.resultData.insert((Object)resultHolder);
        executor.run(unitClass);
        Map<String, Double> probabilities = this.categoryProbabilities(this.fld1, this.fld2, this.fld3);
        String maxCategory = null;
        double maxValue = Double.MIN_VALUE;
        for (String key : probabilities.keySet()) {
            double value = probabilities.get(key);
            if (!(value > maxValue)) continue;
            maxCategory = key;
            maxValue = value;
        }
        Assert.assertNotNull((Object)resultHolder.getResultValue("RegOut", null, new Object[0]));
        Assert.assertNotNull((Object)resultHolder.getResultValue("RegProb", null, new Object[0]));
        Assert.assertNotNull((Object)resultHolder.getResultValue("RegProbA", null, new Object[0]));
        String regOut = resultHolder.getResultValue("RegOut", "value", String.class, new Object[0]).orElse(null);
        Double regProb = resultHolder.getResultValue("RegProb", "value", Double.class, new Object[0]).orElse(null);
        Double regProbA = resultHolder.getResultValue("RegProbA", "value", Double.class, new Object[0]).orElse(null);
        Assert.assertEquals((Object)("cat" + maxCategory), (Object)regOut);
        Assert.assertEquals((double)maxValue, (double)regProb, (double)1.0E-6);
        Assert.assertEquals((double)probabilities.get("A"), (double)regProbA, (double)1.0E-6);
    }

    private Map<String, Double> categoryProbabilities(double fld1, double fld2, String fld3) {
        HashMap<String, RegressionInterface> regressionTables = new HashMap<String, RegressionInterface>();
        regressionTables.put("A", (f1, f2, f3) -> 0.1 + fld1 + fld2 + SimpleRegressionTest.fld3Coefficient(fld3));
        regressionTables.put("B", (f1, f2, f3) -> 0.2 + 2.0 * fld1 + 2.0 * fld2 + SimpleRegressionTest.fld3Coefficient(fld3));
        regressionTables.put("C", (f1, f2, f3) -> 0.3 + 3.0 * fld1 + 3.0 * fld2 + SimpleRegressionTest.fld3Coefficient(fld3));
        regressionTables.put("D", (f1, f2, f3) -> 5.0);
        HashMap<String, Double> regressionTablesValues = new HashMap<String, Double>();
        double sum = 0.0;
        for (String item : regressionTables.keySet()) {
            double value = ((RegressionInterface)regressionTables.get(item)).apply(fld1, fld2, fld3);
            value = Math.exp(value);
            regressionTablesValues.put(item, value);
            sum += value;
        }
        for (String item : regressionTables.keySet()) {
            regressionTablesValues.put(item, (Double)regressionTablesValues.get(item) / sum);
        }
        return regressionTablesValues;
    }

    private static double fld3Coefficient(String fld3) {
        HashMap<String, Double> fld3ValueMap = new HashMap<String, Double>();
        fld3ValueMap.put("x", -3.0);
        fld3ValueMap.put("y", 3.0);
        if (!fld3ValueMap.containsKey(fld3)) {
            return 0.0;
        }
        return (Double)fld3ValueMap.get(fld3);
    }

    @FunctionalInterface
    private static interface RegressionInterface {
        public Double apply(double var1, double var3, String var5);
    }
}

