package org.kie.pmml.models.drools.scorecard.compiler.factories;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.Array;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.scorecard.Attribute;
import org.dmg.pmml.scorecard.Characteristic;
import org.dmg.pmml.scorecard.Characteristics;
import org.dmg.pmml.scorecard.Scorecard;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.kie.pmml.commons.enums.ResultCode;
import org.kie.pmml.commons.exceptions.KiePMMLException;
import org.kie.pmml.commons.model.enums.BOOLEAN_OPERATOR;
import org.kie.pmml.commons.model.enums.DATA_TYPE;
import org.kie.pmml.commons.model.enums.OPERATOR;
import org.kie.pmml.compiler.commons.utils.ModelUtils;
import org.kie.pmml.compiler.testutils.TestUtils;
import org.kie.pmml.models.drools.ast.KiePMMLDroolsRule;
import org.kie.pmml.models.drools.ast.KiePMMLFieldOperatorValue;
import org.kie.pmml.models.drools.ast.factories.KiePMMLDataDictionaryASTFactory;
import org.kie.pmml.models.drools.commons.utils.KiePMMLDroolsModelUtils;
import org.kie.pmml.models.drools.scorecard.model.enums.REASONCODE_ALGORITHM;
import org.kie.pmml.models.drools.tuples.KiePMMLOperatorValue;
import org.kie.pmml.models.drools.tuples.KiePMMLReasonCodeAndValue;
import org.kie.pmml.models.drools.utils.KiePMMLASTTestUtils;

/* loaded from: input_file:org/kie/pmml/models/drools/scorecard/compiler/factories/KiePMMLScorecardModelCharacteristicASTFactoryTest.class */
/**
 * Tests for {@link KiePMMLScorecardModelCharacteristicASTFactory}: rule declaration from scorecard
 * characteristics/attributes and reason-code computation.
 */
public class KiePMMLScorecardModelCharacteristicASTFactoryTest {
    private static final String SOURCE_SAMPLE = "ScorecardSample.pmml";
    private static final String PARENT_PATH = "parent_path";
    private static final String STATUS_TO_SET = "status_to_set";
    private static final String REASON_CODE = "REASON_CODE";
    private static final Double CHARACTERISTIC_BASELINE_SCORE = 12.0d;

    private PMML samplePmml;
    private Scorecard scorecardModel;
    private DataDictionary dataDictionary;

    /**
     * Loads the sample PMML and checks it contains exactly one {@link Scorecard} model.
     */
    @Before
    public void setUp() throws Exception {
        samplePmml = TestUtils.loadFromFile(SOURCE_SAMPLE);
        Assert.assertNotNull(samplePmml);
        Assert.assertEquals(1L, samplePmml.getModels().size());
        Assert.assertTrue(samplePmml.getModels().get(0) instanceof Scorecard);
        scorecardModel = (Scorecard) samplePmml.getModels().get(0);
        dataDictionary = samplePmml.getDataDictionary();
    }

    @Test
    public void declareRulesFromCharacteristics() {
        final String parentPath = "_will";
        Characteristics characteristics = scorecardModel.getCharacteristics();
        List<KiePMMLDroolsRule> retrieved = getKiePMMLScorecardModelCharacteristicASTFactory()
                .declareRulesFromCharacteristics(characteristics, parentPath, (Number) null);
        List<Characteristic> characteristicList = characteristics.getCharacteristics();
        // Expect one rule per attribute plus one extra rule at index 0 (not validated here).
        // Checked before the loop so a short result reports a size mismatch instead of
        // an IndexOutOfBoundsException.
        int attributeCount = characteristicList.stream()
                .mapToInt(characteristic -> characteristic.getAttributes().size())
                .sum();
        Assert.assertEquals(attributeCount + 1, retrieved.size());
        int ruleIndex = 0;
        for (int i = 0; i < characteristicList.size(); i++) {
            Characteristic characteristic = characteristicList.get(i);
            boolean isLastCharacteristic = i == characteristicList.size() - 1;
            String expectedStatusToSet = isLastCharacteristic
                    ? "DONE"
                    : String.format("%s_%s", parentPath, characteristicList.get(i + 1).getName());
            String expectedStatusConstraint = parentPath + "_" + characteristic.getName();
            List<Attribute> attributes = characteristic.getAttributes();
            for (int j = 0; j < attributes.size(); j++) {
                Attribute attribute = attributes.get(j);
                KiePMMLDroolsRule rule = retrieved.get(++ruleIndex);
                Predicate predicate = attribute.getPredicate();
                Integer expectedAndConstraints = null;
                int expectedOperatorValues = 1;
                if (predicate instanceof SimplePredicate) {
                    expectedAndConstraints = 1;
                } else if (predicate instanceof CompoundPredicate) {
                    expectedAndConstraints = 1;
                    // getPredicates() lives on CompoundPredicate, so the cast is required
                    expectedOperatorValues = ((CompoundPredicate) predicate).getPredicates().size();
                }
                Integer expectedInConstraints = predicate instanceof SimpleSetPredicate ? 1 : null;
                commonValidateRule(rule, attribute, expectedStatusToSet, expectedStatusConstraint, j,
                                   isLastCharacteristic, expectedAndConstraints, expectedInConstraints,
                                   BOOLEAN_OPERATOR.AND, null, expectedOperatorValues);
            }
        }
    }

    @Test
    public void declareRuleFromCharacteristicNotLastCharacteristic() {
        Characteristic characteristic = getCharacteristic();
        List<KiePMMLDroolsRule> rules = new ArrayList<>();
        // Expected per-attribute values, aligned with the attribute order built by getCharacteristic()
        String[] expectedConstraints = {"value <= 5.0", "value >= 5.0 && value < 12.0"};
        int[] expectedOperatorValues = {1, 2};
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .declareRuleFromCharacteristic(characteristic, PARENT_PATH, rules, STATUS_TO_SET, false);
        List<Attribute> attributes = characteristic.getAttributes();
        Assert.assertEquals(attributes.size(), rules.size());
        for (int i = 0; i < rules.size(); i++) {
            commonValidateRule(rules.get(i), attributes.get(i), STATUS_TO_SET, PARENT_PATH + "_AgeScore", i,
                               false, 1, null, BOOLEAN_OPERATOR.AND, expectedConstraints[i],
                               expectedOperatorValues[i]);
        }
    }

    @Test
    public void declareRuleFromAttributeWithSimplePredicateNotLastCharacteristic() {
        Attribute attribute = getSimplePredicateAttribute();
        List<KiePMMLDroolsRule> rules = new ArrayList<>();
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .declareRuleFromAttribute(attribute, PARENT_PATH, 2, rules, STATUS_TO_SET, REASON_CODE,
                                          CHARACTERISTIC_BASELINE_SCORE, false);
        Assert.assertEquals(1L, rules.size());
        commonValidateRule(rules.get(0), attribute, STATUS_TO_SET, PARENT_PATH, 2, false, 1, null,
                           BOOLEAN_OPERATOR.AND, "value <= 5.0", 1);
    }

    @Test
    public void declareRuleFromAttributeWithSimplePredicateUseReasonCodesTrue() {
        Attribute attribute = getSimplePredicateAttribute();
        List<KiePMMLDroolsRule> rules = new ArrayList<>();
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .withReasonCodes((Number) null, REASONCODE_ALGORITHM.POINTS_ABOVE)
                .declareRuleFromAttribute(attribute, PARENT_PATH, 2, rules, STATUS_TO_SET, REASON_CODE,
                                          CHARACTERISTIC_BASELINE_SCORE, false);
        Assert.assertEquals(1L, rules.size());
        KiePMMLDroolsRule rule = rules.get(0);
        commonValidateRule(rule, attribute, STATUS_TO_SET, PARENT_PATH, 2, false, 1, null,
                           BOOLEAN_OPERATOR.AND, "value <= 5.0", 1);
        KiePMMLReasonCodeAndValue reasonCodeAndValue = rule.getReasonCodeAndValue();
        Assert.assertNotNull(reasonCodeAndValue);
        Assert.assertEquals(REASON_CODE, reasonCodeAndValue.getReasonCode());
        Assert.assertEquals(attribute.getPartialScore().doubleValue() - CHARACTERISTIC_BASELINE_SCORE,
                            reasonCodeAndValue.getValue(), 0.0d);
    }

    @Test
    public void declareRuleFromAttributeWithSimplePredicateLastCharacteristic() {
        Attribute attribute = getSimplePredicateAttribute();
        List<KiePMMLDroolsRule> rules = new ArrayList<>();
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .declareRuleFromAttribute(attribute, PARENT_PATH, 2, rules, STATUS_TO_SET, REASON_CODE,
                                          CHARACTERISTIC_BASELINE_SCORE, true);
        Assert.assertEquals(1L, rules.size());
        commonValidateRule(rules.get(0), attribute, STATUS_TO_SET, PARENT_PATH, 2, true, 1, null,
                           BOOLEAN_OPERATOR.AND, "value <= 5.0", 1);
    }

    @Test
    public void declareRuleFromAttributeWithCompoundPredicate() {
        Attribute attribute = getCompoundPredicateAttribute();
        List<KiePMMLDroolsRule> rules = new ArrayList<>();
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .declareRuleFromAttribute(attribute, PARENT_PATH, 2, rules, STATUS_TO_SET, REASON_CODE,
                                          CHARACTERISTIC_BASELINE_SCORE, false);
        Assert.assertEquals(1L, rules.size());
        commonValidateRule(rules.get(0), attribute, STATUS_TO_SET, PARENT_PATH, 2, false, 1, null,
                           BOOLEAN_OPERATOR.AND, "value >= 5.0 && value < 12.0", 2);
    }

    @Test
    public void declareRuleFromAttributeWithSimpleSetPredicate() {
        Attribute attribute = getSimpleSetPredicateAttribute();
        List<KiePMMLDroolsRule> rules = new ArrayList<>();
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .declareRuleFromAttribute(attribute, PARENT_PATH, 2, rules, STATUS_TO_SET, REASON_CODE,
                                          CHARACTERISTIC_BASELINE_SCORE, false);
        Assert.assertEquals(1L, rules.size());
        commonValidateRule(rules.get(0), attribute, STATUS_TO_SET, PARENT_PATH, 2, false, null, 1, null,
                           null, null);
    }

    /**
     * Validates a generated rule against the attribute it was built from.
     *
     * @param rule                   the rule under test
     * @param attribute              the source attribute (partial score and predicate are compared)
     * @param statusToSet            expected status the rule sets
     * @param statusConstraintPrefix expected status the rule matches; also the prefix of its name
     * @param attributeIndex         expected suffix of the rule name
     * @param isLastCharacteristic   whether the rule must be an accumulation result with {@link ResultCode#OK}
     * @param expectedAndConstraints expected number of AND constraints, or {@code null} to skip that check
     * @param expectedInConstraints  expected number of IN constraints, or {@code null} to skip that check
     * @param booleanOperator        expected operator of the first AND constraint
     * @param expectedConstraint     expected constraint string, or {@code null} to skip that check
     * @param expectedOperatorValues expected operator-value count; must be non-null when
     *                               {@code expectedAndConstraints} is non-null
     */
    private void commonValidateRule(KiePMMLDroolsRule rule, Attribute attribute, String statusToSet,
                                    String statusConstraintPrefix, int attributeIndex,
                                    boolean isLastCharacteristic, Integer expectedAndConstraints,
                                    Integer expectedInConstraints, BOOLEAN_OPERATOR booleanOperator,
                                    String expectedConstraint, Integer expectedOperatorValues) {
        Assert.assertEquals(String.format("%s_%s", statusConstraintPrefix, attributeIndex), rule.getName());
        Assert.assertEquals(statusToSet, rule.getStatusToSet());
        Assert.assertEquals(String.format("status == \"%s\"", statusConstraintPrefix), rule.getStatusConstraint());
        Assert.assertEquals(attribute.getPartialScore().doubleValue(), rule.getToAccumulate().doubleValue(), 0.0d);
        if (isLastCharacteristic) {
            Assert.assertTrue(rule.isAccumulationResult());
            Assert.assertEquals(ResultCode.OK, rule.getResultCode());
        } else {
            Assert.assertFalse(rule.isAccumulationResult());
            Assert.assertNull(rule.getResultCode());
        }
        Assert.assertNull(rule.getResult());
        if (expectedAndConstraints != null) {
            Assert.assertEquals(expectedAndConstraints.intValue(), rule.getAndConstraints().size());
            commonValidateAndConstraint((KiePMMLFieldOperatorValue) rule.getAndConstraints().get(0), attribute,
                                        booleanOperator, expectedConstraint, expectedOperatorValues.intValue());
        }
        if (expectedInConstraints != null) {
            Assert.assertEquals(expectedInConstraints.intValue(), rule.getInConstraints().size());
            commonValidateInConstraint(rule.getInConstraints(), attribute);
        }
    }

    /**
     * Checks operator, optional constraint string, and operator values of an AND constraint.
     */
    private void commonValidateAndConstraint(KiePMMLFieldOperatorValue fieldOperatorValue, Attribute attribute,
                                             BOOLEAN_OPERATOR booleanOperator, String expectedConstraint,
                                             int expectedOperatorValues) {
        Assert.assertEquals(booleanOperator, fieldOperatorValue.getOperator());
        if (expectedConstraint != null) {
            Assert.assertEquals(expectedConstraint, fieldOperatorValue.getConstraintsAsString());
        }
        List<KiePMMLOperatorValue> operatorValues = fieldOperatorValue.getKiePMMLOperatorValues();
        Assert.assertEquals(expectedOperatorValues, operatorValues.size());
        Predicate predicate = attribute.getPredicate();
        if (predicate instanceof SimplePredicate) {
            commonValidateKiePMMLOperatorValue(operatorValues.get(0), (SimplePredicate) predicate);
        } else if (predicate instanceof CompoundPredicate) {
            commonValidateKiePMMLOperatorValues(operatorValues, (CompoundPredicate) predicate);
        }
    }

    /**
     * Checks the first IN-constraint value list against the attribute's {@link SimpleSetPredicate}, if any.
     */
    private void commonValidateInConstraint(Map<String, List<Object>> inConstraints, Attribute attribute) {
        if (attribute.getPredicate() instanceof SimpleSetPredicate) {
            commonValidateObjectValues(inConstraints.values().iterator().next(),
                                       (SimpleSetPredicate) attribute.getPredicate());
        }
    }

    /**
     * Checks that an operator value mirrors the operator and (type-formatted) value of a simple predicate.
     */
    private void commonValidateKiePMMLOperatorValue(KiePMMLOperatorValue operatorValue,
                                                    SimplePredicate simplePredicate) {
        Assert.assertEquals(OPERATOR.byName(simplePredicate.getOperator().value()), operatorValue.getOperator());
        Assert.assertEquals(getExpectedValue(simplePredicate), operatorValue.getValue());
    }

    /**
     * Formats the predicate value according to the data type of its field in the data dictionary.
     *
     * @throws RuntimeException if the predicate's field is not declared in the data dictionary
     */
    private Object getExpectedValue(SimplePredicate simplePredicate) {
        DATA_TYPE dataType = dataDictionary.getDataFields().stream()
                .filter(dataField -> dataField.getName().equals(simplePredicate.getField()))
                .map(dataField -> DATA_TYPE.byName(dataField.getDataType().value()))
                .findFirst()
                .orElseThrow(() -> new RuntimeException("Failed to find DataField for " + simplePredicate.getField().getValue()));
        return KiePMMLDroolsModelUtils.getCorrectlyFormattedResult(simplePredicate.getValue(), dataType);
    }

    /**
     * Checks that each value equals the corresponding space-separated array token wrapped in double quotes.
     */
    private void commonValidateObjectValues(List<Object> values, SimpleSetPredicate simpleSetPredicate) {
        Assert.assertEquals(simpleSetPredicate.getArray().getN().intValue(), values.size());
        String[] tokens = ((String) simpleSetPredicate.getArray().getValue()).split(" ");
        for (int i = 0; i < values.size(); i++) {
            Assert.assertEquals("\"" + tokens[i] + "\"", values.get(i));
        }
    }

    /**
     * Checks operator values pairwise against the (simple) sub-predicates of a compound predicate.
     */
    private void commonValidateKiePMMLOperatorValues(List<KiePMMLOperatorValue> operatorValues,
                                                     CompoundPredicate compoundPredicate) {
        List<Predicate> predicates = compoundPredicate.getPredicates();
        Assert.assertEquals(predicates.size(), operatorValues.size());
        for (int i = 0; i < operatorValues.size(); i++) {
            commonValidateKiePMMLOperatorValue(operatorValues.get(i), (SimplePredicate) predicates.get(i));
        }
    }

    /**
     * @return an "AgeScore" characteristic with a simple-predicate attribute followed by a compound one
     */
    private Characteristic getCharacteristic() {
        Characteristic characteristic = new Characteristic();
        characteristic.setName("AgeScore");
        characteristic.addAttributes(new Attribute[]{getSimplePredicateAttribute(), getCompoundPredicateAttribute()});
        return characteristic;
    }

    /**
     * @return attribute with partial score 10.0 and predicate {@code age <= 5.0}
     */
    private Attribute getSimplePredicateAttribute() {
        Attribute attribute = new Attribute();
        attribute.setPartialScore(10.0d);
        attribute.setPredicate(KiePMMLASTTestUtils.getSimplePredicate("age", DataType.DOUBLE, 5.0d,
                                                                      SimplePredicate.Operator.LESS_OR_EQUAL,
                                                                      new HashMap<>()));
        return attribute;
    }

    /**
     * @return attribute with partial score 30.0 and predicate {@code age >= 5.0 AND age < 12.0}
     */
    private Attribute getCompoundPredicateAttribute() {
        Attribute attribute = new Attribute();
        attribute.setPartialScore(30.0d);
        attribute.setPredicate(getCompoundPredicate());
        return attribute;
    }

    /**
     * @return attribute with partial score -10.0 and predicate {@code occupation IS_IN (SKYDIVER, ASTRONAUT)}
     */
    private Attribute getSimpleSetPredicateAttribute() {
        Attribute attribute = new Attribute();
        attribute.setPartialScore(-10.0d);
        attribute.setPredicate(KiePMMLASTTestUtils.getSimpleSetPredicate("occupation", Array.Type.STRING,
                                                                         Arrays.asList("SKYDIVER", "ASTRONAUT"),
                                                                         SimpleSetPredicate.BooleanOperator.IS_IN,
                                                                         new HashMap<>()));
        return attribute;
    }

    /**
     * @return compound AND predicate {@code age >= 5.0 AND age < 12.0}
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private CompoundPredicate getCompoundPredicate() {
        CompoundPredicate compoundPredicate = new CompoundPredicate();
        compoundPredicate.setBooleanOperator(CompoundPredicate.BooleanOperator.AND);
        // A single map is shared by both sub-predicates; it stays raw because its value type is
        // declared outside this file.
        Map fieldTypeMap = new HashMap();
        compoundPredicate.addPredicates(new Predicate[]{
                KiePMMLASTTestUtils.getSimplePredicate("age", DataType.DOUBLE, 5.0d,
                                                       SimplePredicate.Operator.GREATER_OR_EQUAL, fieldTypeMap),
                KiePMMLASTTestUtils.getSimplePredicate("age", DataType.DOUBLE, 12.0d,
                                                       SimplePredicate.Operator.LESS_THAN, fieldTypeMap)});
        return compoundPredicate;
    }

    @Test
    public void getKiePMMLReasonCodeAndValueUseReasonCodesFalse() {
        Assert.assertNull(getKiePMMLScorecardModelCharacteristicASTFactory()
                                  .getKiePMMLReasonCodeAndValue(new Attribute(), "", 0));
    }

    @Test(expected = KiePMMLException.class)
    public void getKiePMMLReasonCodeAndValueUseReasonCodesTrueNoBaselineScore() {
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .withReasonCodes((Number) null, (REASONCODE_ALGORITHM) null)
                .getKiePMMLReasonCodeAndValue(new Attribute(), "", (Number) null);
    }

    @Test(expected = KiePMMLException.class)
    public void getKiePMMLReasonCodeAndValueUseReasonCodesTrueNoReasonCodeAlgorithm() {
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .withReasonCodes((Number) null, (REASONCODE_ALGORITHM) null)
                .getKiePMMLReasonCodeAndValue(new Attribute(), "", 12);
    }

    @Test(expected = KiePMMLException.class)
    public void getKiePMMLReasonCodeAndValueUseReasonCodesTrueNoReasonCode() {
        getKiePMMLScorecardModelCharacteristicASTFactory()
                .withReasonCodes((Number) null, REASONCODE_ALGORITHM.POINTS_ABOVE)
                .getKiePMMLReasonCodeAndValue(new Attribute(), "", 12);
    }

    @Test
    public void getKiePMMLReasonCodeAndValueUseReasonCodesTrue() {
        final double partialScore = 13.17d;
        final double modelBaselineScore = 13.0d;
        final double characteristicBaselineScore = 24.45d;
        Attribute attribute = new Attribute();
        attribute.setPartialScore(partialScore);

        // No characteristic baseline: the model-level baseline is used
        KiePMMLReasonCodeAndValue retrieved = getKiePMMLScorecardModelCharacteristicASTFactory()
                .withReasonCodes(modelBaselineScore, REASONCODE_ALGORITHM.POINTS_ABOVE)
                .getKiePMMLReasonCodeAndValue(attribute, "CHARACTERISTIC_REASON_CODE", (Number) null);
        Assert.assertNotNull(retrieved);
        Assert.assertEquals("CHARACTERISTIC_REASON_CODE", retrieved.getReasonCode());
        Assert.assertEquals(partialScore - modelBaselineScore, retrieved.getValue(), 0.0d);

        // Characteristic baseline given: it is used instead of the model-level one
        retrieved = getKiePMMLScorecardModelCharacteristicASTFactory()
                .withReasonCodes(modelBaselineScore, REASONCODE_ALGORITHM.POINTS_ABOVE)
                .getKiePMMLReasonCodeAndValue(attribute, "CHARACTERISTIC_REASON_CODE", characteristicBaselineScore);
        Assert.assertNotNull(retrieved);
        Assert.assertEquals("CHARACTERISTIC_REASON_CODE", retrieved.getReasonCode());
        double expectedValue = partialScore - characteristicBaselineScore;
        Assert.assertEquals(expectedValue, retrieved.getValue(), 0.0d);

        // Attribute reason code set: it takes precedence over the characteristic one
        attribute.setReasonCode("ATTRIBUTE_REASON_CODE");
        retrieved = getKiePMMLScorecardModelCharacteristicASTFactory()
                .withReasonCodes(modelBaselineScore, REASONCODE_ALGORITHM.POINTS_ABOVE)
                .getKiePMMLReasonCodeAndValue(attribute, "CHARACTERISTIC_REASON_CODE", characteristicBaselineScore);
        Assert.assertNotNull(retrieved);
        Assert.assertEquals("ATTRIBUTE_REASON_CODE", retrieved.getReasonCode());
        Assert.assertEquals(expectedValue, retrieved.getValue(), 0.0d);
    }

    /**
     * Builds a factory over the sample data dictionary's declared types and the scorecard's target type.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private KiePMMLScorecardModelCharacteristicASTFactory getKiePMMLScorecardModelCharacteristicASTFactory() {
        // Raw because the map's value type is declared outside this file
        Map fieldTypeMap = new HashMap();
        DATA_TYPE targetFieldType = ModelUtils.getTargetFieldType(dataDictionary, scorecardModel);
        KiePMMLDataDictionaryASTFactory.factory(fieldTypeMap).declareTypes(dataDictionary);
        Assert.assertFalse(fieldTypeMap.isEmpty());
        return KiePMMLScorecardModelCharacteristicASTFactory.factory(fieldTypeMap, Collections.emptyList(), targetFieldType);
    }
}
