/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.pmml_4_2.predictive.models.mining;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kie.api.io.Resource;
import org.kie.api.pmml.PMMLRequestData;
import org.kie.internal.io.ResourceFactory;
import org.kie.pmml.pmml_4_2.PMML4ExecutionHelper;
import org.kie.pmml.pmml_4_2.PMMLRequestDataBuilder;
import org.kie.pmml.pmml_4_2.model.mining.SegmentExecution;

@RunWith(value=Parameterized.class)
public class MiningModelSelectAllRegressionTest {
    private static final String PMML_SOURCE = "org/kie/pmml/pmml_4_2/test_mining_model_selectall_regression.pmml";
    private static final String MINING_MODEL = "SampleMiningModel";
    private static final String INPUT1_FIELD_NAME = "input1";
    private static final String INPUT2_FIELD_NAME = "input2";
    private static final String INPUT3_FIELD_NAME = "input3";
    private static final String OUTPUT_FIELD_NAME = "Result";
    private static final double COMPARISON_DELTA = 0.001;
    private double input1;
    private double input2;
    private double input3;

    @Parameterized.Parameters
    public static Collection<Object[]> data() {
        return Arrays.asList({10.0, 10.0, 10.0}, {200.0, -1.0, 2.0}, {90.0, 2.0, 4.0});
    }

    public MiningModelSelectAllRegressionTest(double input1, double input2, double input3) {
        this.input1 = input1;
        this.input2 = input2;
        this.input3 = input3;
    }

    @Test
    public void testMiningModelSelectAllRegression() {
        Resource res = ResourceFactory.newClassPathResource((String)PMML_SOURCE);
        PMML4ExecutionHelper helper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper((String)MINING_MODEL, (Resource)res, null, (boolean)true);
        PMMLRequestDataBuilder rdb = new PMMLRequestDataBuilder("1234", MINING_MODEL).addParameter(INPUT1_FIELD_NAME, (Object)this.input1, Double.class).addParameter(INPUT2_FIELD_NAME, (Object)this.input2, Double.class).addParameter(INPUT3_FIELD_NAME, (Object)this.input3, Double.class);
        PMMLRequestData request = rdb.build();
        helper.submitRequest(request);
        Map<String, Double> expected = this.expectedResults(this.input1, this.input2, this.input3);
        HashMap<String, Double> executedSegments = new HashMap<String, Double>();
        for (SegmentExecution cms : helper.getChildModelSegments()) {
            executedSegments.put(cms.getSegmentId(), cms.getResult().getResultValue(OUTPUT_FIELD_NAME, "value", Double.class, new Object[0]).orElse(null));
        }
        this.compareMaps(expected, executedSegments);
    }

    private Map<String, Double> expectedResults(double input1, double input2, double input3) {
        HashMap<String, Double> expected = new HashMap<String, Double>();
        if (input1 < 50.0) {
            expected.put("segment1", 500.0 + 2.0 * input1 + 5.0 * input2 + input3 * input3);
        }
        if (input1 > 150.0) {
            expected.put("segment2", -500.0 + input1 + input2 + input3);
        }
        if (input1 < 100.0) {
            expected.put("segment3", 800.0 + 2.0 * input1 * input1 + 2.0 * input2 * input2 + 2.0 * input3 * input3);
        }
        return expected;
    }

    private void compareMaps(Map<String, Double> expected, Map<String, Double> result) {
        Assert.assertEquals((long)expected.size(), (long)result.size());
        for (String key : expected.keySet()) {
            Assert.assertTrue((boolean)result.containsKey(key));
            Assert.assertEquals((double)expected.get(key), (double)result.get(key), (double)0.001);
        }
    }
}

