package org.drools.compiler.integrationtests;

import java.util.Iterator;
import org.drools.core.InitialFact;
import org.drools.core.common.BaseNode;
import org.drools.core.common.NetworkNode;
import org.drools.core.common.RuleBasePartitionId;
import org.drools.core.reteoo.BetaNode;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.ObjectSink;
import org.drools.core.reteoo.ObjectSinkPropagator;
import org.drools.core.reteoo.ObjectSource;
import org.drools.core.reteoo.ObjectTypeNode;
import org.drools.core.reteoo.ReteDumper;
import org.drools.core.reteoo.TerminalNode;
import org.junit.Assert;
import org.junit.Test;
import org.kie.api.conf.KieBaseOption;
import org.kie.api.io.ResourceType;
import org.kie.internal.conf.MultithreadEvaluationOption;
import org.kie.internal.utils.KieHelper;

/* loaded from: input_file:org/drools/compiler/integrationtests/NodesPartitioningTest.class */
/**
 * Checks how Rete nodes are assigned to partitions when a KieBase is built with
 * {@link MultithreadEvaluationOption#YES}.
 *
 * <p>Entry-point and object-type nodes must be on the main partition. The sinks of each
 * object-type node must be spread over its partitioned propagators by partition id. Every other
 * checked node must share the partition of its input(s).
 */
public class NodesPartitioningTest {

    @Test
    public void test2Partitions() {
        checkDrl(ruleA(1) + ruleB(2) + ruleC(2) + ruleD(1) + ruleD(2) + ruleC(1) + ruleA(2) + ruleB(1));
    }

    @Test
    public void testPartitioningWithSharedNodes() {
        StringBuilder drl = new StringBuilder(400);
        // Both rule families constrain "Integer( this == n )" for the same n, so the
        // network contains patterns that are common to several rules.
        for (int n = 1; n < 4; n++) {
            drl.append(getRule(n));
        }
        for (int n = 1; n < 4; n++) {
            drl.append(getNotRule(n));
        }
        checkDrl(drl.toString());
    }

    /**
     * Compiles {@code str} with multithreaded evaluation enabled and checks the partitioning of
     * every node reachable from the resulting entry-point nodes.
     *
     * @param str the DRL source to compile
     */
    private void checkDrl(String str) {
        for (EntryPointNode entryPointNode : new KieHelper()
                .addContent(str, ResourceType.DRL)
                .build(MultithreadEvaluationOption.YES)
                .getRete()
                .getEntryPointNodes()
                .values()) {
            traverse(entryPointNode);
        }
    }

    /**
     * Checks {@code networkNode}, then recursively checks every sink reported for it by
     * {@link ReteDumper#getSinks}.
     *
     * @param networkNode the node to start from
     */
    private void traverse(NetworkNode networkNode) {
        checkNode(networkNode);
        BaseNode[] sinks = ReteDumper.getSinks(networkNode);
        if (sinks == null) {
            return;
        }
        for (BaseNode sink : sinks) {
            // The elements are already typed BaseNode, so the only thing to skip is null slots.
            if (sink != null) {
                traverse(sink);
            }
        }
    }

    /**
     * Asserts the partition invariants for a single node, based on its kind.
     *
     * @param networkNode the node to check
     */
    private void checkNode(NetworkNode networkNode) {
        if (networkNode instanceof EntryPointNode) {
            Assert.assertSame(RuleBasePartitionId.MAIN_PARTITION, networkNode.getPartitionId());
        } else if (networkNode instanceof ObjectTypeNode) {
            Assert.assertSame(RuleBasePartitionId.MAIN_PARTITION, networkNode.getPartitionId());
            checkPartitionedSinks((ObjectTypeNode) networkNode);
        } else if (networkNode instanceof ObjectSource) {
            ObjectSource parent = ((ObjectSource) networkNode).getParentObjectSource();
            // Direct children of an OTN are verified by checkPartitionedSinks instead; deeper
            // object sources must stay on their parent's partition.
            if (!(parent instanceof ObjectTypeNode)) {
                Assert.assertSame(parent.getPartitionId(), networkNode.getPartitionId());
            }
        } else if (networkNode instanceof BetaNode) {
            BetaNode betaNode = (BetaNode) networkNode;
            ObjectSource rightInput = betaNode.getRightInput();
            // An OTN right input is on the main partition, so it is not compared here.
            if (!(rightInput instanceof ObjectTypeNode)) {
                Assert.assertSame(rightInput.getPartitionId(), betaNode.getPartitionId());
            }
            Assert.assertSame(betaNode.getLeftTupleSource().getPartitionId(), betaNode.getPartitionId());
        } else if (networkNode instanceof TerminalNode) {
            Assert.assertSame(((TerminalNode) networkNode).getLeftTupleSource().getPartitionId(), networkNode.getPartitionId());
        }
    }

    /**
     * Asserts that every sink held by the OTN's partitioned propagator at index {@code i} has a
     * partition id congruent to {@code i} modulo the number of propagators. OTNs whose class
     * type is assignable to {@link InitialFact} are skipped.
     *
     * @param objectTypeNode the object-type node whose sinks are checked
     */
    private void checkPartitionedSinks(ObjectTypeNode objectTypeNode) {
        if (InitialFact.class.isAssignableFrom(objectTypeNode.getObjectType().getClassType())) {
            return;
        }
        ObjectSinkPropagator[] propagators = objectTypeNode.getObjectSinkPropagator().getPartitionedPropagators();
        int propagatorCount = propagators.length;
        for (int index = 0; index < propagatorCount; index++) {
            for (ObjectSink sink : propagators[index].getSinks()) {
                Assert.assertEquals(sink + " on " + sink.getPartitionId() + " is expected to be on propagator " + index,
                        index, sink.getPartitionId().getId() % propagatorCount);
            }
        }
    }

    /** Integer == i, String whose length is $i, then Integer equal to that String's length. */
    private String ruleA(int i) {
        return "rule Ra" + i + " when\n    $i : Integer( this == " + i + " )\n    $s : String( length == $i )\n    Integer( this == $s.length )\nthen end\n";
    }

    /** Integer == i, String equal to $i.toString, then Integer equal to that String's length. */
    private String ruleB(int i) {
        return "rule Rb" + i + " when\n    $i : Integer( this == " + i + " )\n    $s : String( this == $i.toString )\n    Integer( this == $s.length )\nthen end\n";
    }

    /** Integer == i, String whose length is $i, then Integer equal to $i + 1. */
    private String ruleC(int i) {
        return "rule Rc" + i + " when\n    $i : Integer( this == " + i + " )\n    $s : String( length == $i )\n    Integer( this == $i+1 )\nthen end\n";
    }

    /** Integer == i joined with a String whose length is $i. */
    private String ruleD(int i) {
        return "rule Rd" + i + " when\n    $i : Integer( this == " + i + " )\n    $s : String( length == $i )\nthen end\n";
    }

    /** Integer == i joined with a String equal to $i.toString. */
    private String getRule(int i) {
        return "rule R" + i + " when\n    $i : Integer( this == " + i + " )\n    String( this == $i.toString )\nthen end\n";
    }

    /** String equal to the literal "i" with no Integer == i present. */
    private String getNotRule(int i) {
        return "rule Rnot" + i + " when\n    String( this == \"" + i + "\" )\n    not Integer( this == " + i + " )\nthen end\n";
    }
}
