package org.drools.mvel.integrationtests;

import java.util.Collection;
import java.util.Iterator;
import org.assertj.core.api.Assertions;
import org.drools.core.InitialFact;
import org.drools.core.common.BaseNode;
import org.drools.core.common.NetworkNode;
import org.drools.core.common.RuleBasePartitionId;
import org.drools.core.reteoo.BetaNode;
import org.drools.core.reteoo.CompositePartitionAwareObjectSinkAdapter;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.ObjectSink;
import org.drools.core.reteoo.ObjectSinkPropagator;
import org.drools.core.reteoo.ObjectSource;
import org.drools.core.reteoo.ObjectTypeNode;
import org.drools.core.reteoo.TerminalNode;
import org.drools.testcoverage.common.util.KieBaseTestConfiguration;
import org.drools.testcoverage.common.util.KieBaseUtil;
import org.drools.testcoverage.common.util.KieUtil;
import org.drools.testcoverage.common.util.TestParametersUtil;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kie.api.conf.KieBaseOption;
import org.kie.internal.conf.MultithreadEvaluationOption;

@RunWith(Parameterized.class)
/**
 * Verifies node partitioning of the Rete network when a kie base is built with
 * {@link MultithreadEvaluationOption#YES}.
 *
 * <p>Every network reachable from the entry points is walked and checked:
 * <ul>
 *   <li>entry-point nodes and object-type nodes live on {@link RuleBasePartitionId#MAIN_PARTITION};</li>
 *   <li>any other object source shares the partition of its parent, unless that parent is an
 *       object-type node;</li>
 *   <li>beta nodes share the partition of their left tuple source and, unless it is an
 *       object-type node, of their right input;</li>
 *   <li>terminal nodes share the partition of their left tuple source;</li>
 *   <li>sinks attached to a partitioned object-type-node propagator sit on the propagator whose
 *       index equals the sink's partition id modulo the number of propagators.</li>
 * </ul>
 */
public class NodesPartitioningTest {

    private final KieBaseTestConfiguration kieBaseTestConfiguration;

    /** Simple fact type used by {@link #testChangePartitionOfAlphaSourceOfAlpha()}. */
    public static class Account {
        private final int number;
        private final String uuid;
        private final Customer owner;

        public Account(int number, String uuid, Customer owner) {
            this.number = number;
            this.uuid = uuid;
            this.owner = owner;
        }

        public int getNumber() {
            return number;
        }

        public String getUuid() {
            return uuid;
        }

        public Customer getOwner() {
            return owner;
        }
    }

    /** Simple fact type used by {@link #testChangePartitionOfAlphaSourceOfAlpha()}. */
    public static class Customer {
        private final String uuid;

        public Customer(String uuid) {
            this.uuid = uuid;
        }

        public String getUuid() {
            return uuid;
        }
    }

    public NodesPartitioningTest(KieBaseTestConfiguration kieBaseTestConfiguration) {
        this.kieBaseTestConfiguration = kieBaseTestConfiguration;
    }

    /** Runs every test once per cloud kie-base configuration supplied by the test utilities. */
    @Parameterized.Parameters(name = "KieBase type={0}")
    public static Collection<Object[]> getParameters() {
        return TestParametersUtil.getKieBaseCloudConfigurations(true);
    }

    @Test
    public void test2Partitions() {
        checkDrl(ruleA(1) + ruleB(2) + ruleC(2) + ruleD(1) + ruleD(2) + ruleC(1) + ruleA(2) + ruleB(1));
    }

    @Test
    public void testPartitioningWithSharedNodes() {
        StringBuilder drl = new StringBuilder(400);
        for (int i = 1; i <= 3; i++) {
            drl.append(getRule(i));
        }
        for (int i = 1; i <= 3; i++) {
            drl.append(getNotRule(i));
        }
        checkDrl(drl.toString());
    }

    @Test
    public void testChangePartitionOfAlphaSourceOfAlpha() {
        checkDrl("import " + Account.class.getCanonicalName() + ";\n" +
                 "import " + Customer.class.getCanonicalName() + ";\n" +
                 "rule \"customerDoesNotHaveSpecifiedAccount_2\"\n" +
                 "when\n" +
                 "    $account : Account (number == 1, uuid == \"customerDoesNotHaveSpecifiedAccount\")\n" +
                 "    Customer (uuid == \"customerDoesNotHaveSpecifiedAccount\")\n" +
                 "then\n" +
                 "end\n" +
                 "\n" +
                 "rule \"customerDoesNotHaveSpecifiedAccount_1\"\n" +
                 "when\n" +
                 "    $account : Account (number == 2, uuid == \"customerDoesNotHaveSpecifiedAccount\")\n" +
                 "    Customer (uuid == \"customerDoesNotHaveSpecifiedAccount\")\n" +
                 "then\n" +
                 "end");
    }

    /**
     * Compiles {@code drl} into a kie base with multithreaded evaluation enabled and checks the
     * partitioning of every node reachable from each entry-point node.
     *
     * @param drl the rule source to compile
     */
    private void checkDrl(String drl) {
        // NOTE(review): getRete() is an engine-internal accessor; confirm that the static return type
        // of newKieBaseFromKieModuleWithAdditionalOptions exposes it (a cast may be required).
        for (EntryPointNode entryPointNode : KieBaseUtil.newKieBaseFromKieModuleWithAdditionalOptions(
                KieUtil.getKieModuleFromDrls("test", kieBaseTestConfiguration, new String[] { drl }),
                kieBaseTestConfiguration,
                new KieBaseOption[] { MultithreadEvaluationOption.YES })
                .getRete().getEntryPointNodes().values()) {
            traverse(entryPointNode);
        }
    }

    /**
     * Checks {@code node}, then recursively checks every sink that is itself a {@link BaseNode}.
     * Nodes reachable along several paths (shared nodes) are checked once per path.
     */
    private void traverse(BaseNode node) {
        checkNode(node);
        NetworkNode[] sinks = node.getSinks();
        if (sinks == null) {
            return;
        }
        for (NetworkNode sink : sinks) {
            if (sink instanceof BaseNode) {
                traverse((BaseNode) sink);
            }
        }
    }

    /**
     * Asserts the partitioning invariant that applies to the concrete kind of {@code node}.
     * The order of the checks matters: entry-point and object-type nodes are themselves object
     * sources, so they must be matched before the generic {@link ObjectSource} branch.
     */
    private void checkNode(NetworkNode node) {
        if (node instanceof EntryPointNode) {
            assertSamePartition(node, RuleBasePartitionId.MAIN_PARTITION);
        } else if (node instanceof ObjectTypeNode) {
            assertSamePartition(node, RuleBasePartitionId.MAIN_PARTITION);
            checkPartitionedSinks((ObjectTypeNode) node);
        } else if (node instanceof ObjectSource) {
            ObjectSource parent = ((ObjectSource) node).getParentObjectSource();
            // Direct children of an object-type node are distributed across partitions, so they
            // are validated by checkPartitionedSinks instead of against their (main-partition) parent.
            if (!(parent instanceof ObjectTypeNode)) {
                assertSamePartition(node, parent.getPartitionId());
            }
        } else if (node instanceof BetaNode) {
            BetaNode betaNode = (BetaNode) node;
            ObjectSource rightInput = betaNode.getRightInput();
            if (!(rightInput instanceof ObjectTypeNode)) {
                assertSamePartition(node, rightInput.getPartitionId());
            }
            assertSamePartition(node, betaNode.getLeftTupleSource().getPartitionId());
        } else if (node instanceof TerminalNode) {
            assertSamePartition(node, ((TerminalNode) node).getLeftTupleSource().getPartitionId());
        }
    }

    /** Asserts that {@code node} is on exactly (by identity) the {@code expected} partition. */
    private static void assertSamePartition(NetworkNode node, RuleBasePartitionId expected) {
        Assertions.assertThat(node.getPartitionId())
                  .as("partition of %s", node)
                  .isSameAs(expected);
    }

    /**
     * Checks that every sink of the object-type node's propagator(s) is attached to the
     * propagator whose index equals {@code sink partition id % number of propagators}.
     * Object-type nodes for {@link InitialFact} are skipped.
     */
    private void checkPartitionedSinks(ObjectTypeNode objectTypeNode) {
        if (InitialFact.class.isAssignableFrom(objectTypeNode.getObjectType().getClassType())) {
            return;
        }
        // Typed as the interface: the propagator is only a composite adapter in the partitioned case.
        ObjectSinkPropagator propagator = objectTypeNode.getObjectSinkPropagator();
        ObjectSinkPropagator[] partitionedPropagators = propagator instanceof CompositePartitionAwareObjectSinkAdapter
                ? ((CompositePartitionAwareObjectSinkAdapter) propagator).getPartitionedPropagators()
                : new ObjectSinkPropagator[] { propagator };

        for (int i = 0; i < partitionedPropagators.length; i++) {
            for (ObjectSink sink : partitionedPropagators[i].getSinks()) {
                Assertions.assertThat(sink.getPartitionId().getId() % partitionedPropagators.length)
                          .as("%s on %s is expected to be on propagator %d", sink, sink.getPartitionId(), i)
                          .isEqualTo(i);
            }
        }
    }

    private String ruleA(int i) {
        return "rule Ra" + i + " when\n" +
               "    $i : Integer( this == " + i + " )\n" +
               "    $s : String( length == $i )\n" +
               "    Integer( this == $s.length )\n" +
               "then end\n";
    }

    private String ruleB(int i) {
        return "rule Rb" + i + " when\n" +
               "    $i : Integer( this == " + i + " )\n" +
               "    $s : String( this == $i.toString )\n" +
               "    Integer( this == $s.length )\n" +
               "then end\n";
    }

    private String ruleC(int i) {
        return "rule Rc" + i + " when\n" +
               "    $i : Integer( this == " + i + " )\n" +
               "    $s : String( length == $i )\n" +
               "    Integer( this == $i+1 )\n" +
               "then end\n";
    }

    private String ruleD(int i) {
        return "rule Rd" + i + " when\n" +
               "    $i : Integer( this == " + i + " )\n" +
               "    $s : String( length == $i )\n" +
               "then end\n";
    }

    private String getRule(int i) {
        return "rule R" + i + " when\n" +
               "    $i : Integer( this == " + i + " )" +
               "    String( this == $i.toString )\n" +
               "then end\n";
    }

    private String getNotRule(int i) {
        return "rule Rnot" + i + " when\n" +
               "    String( this == \"" + i + "\" )\n" +
               "    not Integer( this == " + i + " )" +
               "then end\n";
    }
}
