package org.kie.pmml.pmml_4_2.predictive.models.mining;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kie.api.KieBaseConfiguration;
import org.kie.internal.io.ResourceFactory;
import org.kie.pmml.pmml_4_2.PMML4ExecutionHelper;
import org.kie.pmml.pmml_4_2.PMMLRequestDataBuilder;
import org.kie.pmml.pmml_4_2.model.mining.SegmentExecution;

/**
 * Parameterized test that submits a three-input request to the mining model
 * defined in {@link #PMML_SOURCE} and checks the {@code Result} value reported
 * by every executed child segment against locally computed expectations.
 *
 * <p>Each parameter set supplies {@code input1}, {@code input2} and
 * {@code input3}; see {@link #expectedResults(double, double, double)} for the
 * segments and values expected for a given set.
 */
@RunWith(Parameterized.class)
public class MiningModelSelectAllRegressionTest {
    private static final String PMML_SOURCE = "org/kie/pmml/pmml_4_2/test_mining_model_selectall_regression.pmml";
    private static final String MINING_MODEL = "SampleMiningModel";
    private static final String INPUT1_FIELD_NAME = "input1";
    private static final String INPUT2_FIELD_NAME = "input2";
    private static final String INPUT3_FIELD_NAME = "input3";
    private static final String OUTPUT_FIELD_NAME = "Result";
    /** Absolute tolerance used when comparing expected and actual segment results. */
    private static final double COMPARISON_DELTA = 0.001d;

    private final double input1;
    private final double input2;
    private final double input3;

    /**
     * Supplies the parameter sets; each row is {@code {input1, input2, input3}}.
     *
     * @return the rows handed to the constructor by the {@link Parameterized} runner
     */
    @Parameterized.Parameters
    public static Collection<Object[]> data() {
        return Arrays.asList(
                new Object[]{10.0d, 10.0d, 10.0d},
                new Object[]{200.0d, -1.0d, 2.0d},
                new Object[]{90.0d, 2.0d, 4.0d});
    }

    /**
     * Stores one parameter set for the test run.
     *
     * @param input1 value sent as the {@code input1} request parameter
     * @param input2 value sent as the {@code input2} request parameter
     * @param input3 value sent as the {@code input3} request parameter
     */
    public MiningModelSelectAllRegressionTest(double input1, double input2, double input3) {
        this.input1 = input1;
        this.input2 = input2;
        this.input3 = input3;
    }

    /**
     * Submits the current parameter set to the model, collects the {@code Result}
     * value of every child segment reported by the helper, and compares them with
     * the expected per-segment values.
     */
    @Test
    public void testMiningModelSelectAllRegression() {
        // The explicit cast on null keeps the factory overload selection unambiguous.
        PMML4ExecutionHelper executionHelper = PMML4ExecutionHelper.PMML4ExecutionHelperFactory.getExecutionHelper(
                MINING_MODEL,
                ResourceFactory.newClassPathResource(PMML_SOURCE),
                (KieBaseConfiguration) null,
                true);
        executionHelper.submitRequest(new PMMLRequestDataBuilder("1234", MINING_MODEL)
                .addParameter(INPUT1_FIELD_NAME, input1, Double.class)
                .addParameter(INPUT2_FIELD_NAME, input2, Double.class)
                .addParameter(INPUT3_FIELD_NAME, input3, Double.class)
                .build());

        Map<String, Double> expected = expectedResults(input1, input2, input3);
        Map<String, Double> actual = new HashMap<>();
        for (SegmentExecution segment : executionHelper.getChildModelSegments()) {
            // A missing value is recorded as null so compareMaps can report it per segment.
            Double value = segment.getResult()
                    .getResultValue(OUTPUT_FIELD_NAME, "value", Double.class)
                    .orElse(null);
            actual.put(segment.getSegmentId(), value);
        }
        compareMaps(expected, actual);
    }

    /**
     * Computes the expected result of each segment for the given inputs.
     * A segment appears in the map only when its condition on {@code input1}
     * holds: {@code segment1} for {@code input1 < 50}, {@code segment2} for
     * {@code input1 > 150}, {@code segment3} for {@code input1 < 100}.
     * NOTE(review): these conditions and formulas presumably mirror the segment
     * predicates and regression tables in the PMML file; verify against it.
     *
     * @return expected values keyed by segment id
     */
    private Map<String, Double> expectedResults(double d, double d2, double d3) {
        Map<String, Double> expected = new HashMap<>();
        if (d < 50.0d) {
            expected.put("segment1", 500.0d + (2.0d * d) + (5.0d * d2) + (d3 * d3));
        }
        if (d > 150.0d) {
            expected.put("segment2", (-500.0d) + d + d2 + d3);
        }
        if (d < 100.0d) {
            expected.put("segment3", 800.0d + (2.0d * d * d) + (2.0d * d2 * d2) + (2.0d * d3 * d3));
        }
        return expected;
    }

    /**
     * Asserts that both maps have the same size and that every expected segment
     * is present in the actual map with a non-null value within
     * {@link #COMPARISON_DELTA} of the expected one.
     *
     * @param map  expected values keyed by segment id
     * @param map2 actual values keyed by segment id; values may be null
     */
    private void compareMaps(Map<String, Double> map, Map<String, Double> map2) {
        Assert.assertEquals("Unexpected number of segment results", map.size(), map2.size());
        for (Map.Entry<String, Double> entry : map.entrySet()) {
            String segmentId = entry.getKey();
            Assert.assertTrue("Missing result for " + segmentId, map2.containsKey(segmentId));
            Double actualValue = map2.get(segmentId);
            // Fail with a clear message instead of a NullPointerException on unboxing.
            Assert.assertNotNull("Null result value for " + segmentId, actualValue);
            Assert.assertEquals("Wrong result for " + segmentId,
                    entry.getValue().doubleValue(), actualValue.doubleValue(), COMPARISON_DELTA);
        }
    }
}
