Usage of org.drools.core.reteoo.ConditionalBranchNode in project drools (kiegroup).
Class SegmentUtilities, method createSegmentMemory.
/**
 * Returns the SegmentMemory for the segment containing {@code tupleSource}, restoring it from a
 * segment prototype registered on the knowledge base when one exists, or building it otherwise.
 * <p>
 * When building, the nodes of the segment are visited from the segment root towards the tip; each
 * node is given the next bit position, and beta/LIA nodes contribute to the segment's all-linked
 * test mask. The segment's tip node, its bit/position within the rule path, and its prototype are
 * recorded before returning.
 *
 * @param tupleSource any node of the segment; the segment root is found by walking up from it
 * @param mem         NOTE(review): this parameter is not read anywhere in this method body — confirm
 *                    whether callers still need to pass it
 * @param wm          working memory used to look up node memories for the segment's nodes
 * @return the restored or newly created segment memory
 */
public static SegmentMemory createSegmentMemory(LeftTupleSource tupleSource, Memory mem, InternalWorkingMemory wm) {
// find segment root
while (!SegmentUtilities.isRootNode(tupleSource, null)) {
tupleSource = tupleSource.getLeftTupleSource();
}
LeftTupleSource segmentRoot = tupleSource;
// accumulates node-type flags for this segment via updateNodeTypesMask (still 0 at the prototype lookup below)
int nodeTypesInSegment = 0;
// try to rebuild the segment from a prototype previously registered for this root (registration happens at the end of this method)
SegmentMemory smem = restoreSegmentFromPrototype(wm, segmentRoot, nodeTypesInSegment);
if (smem != null) {
// a restored segment rooted at a beta node whose right input is a RIA node also needs the RIA segment memory created
if (NodeTypeEnums.isBetaNode(segmentRoot) && ((BetaNode) segmentRoot).isRightInputIsRiaNode()) {
createRiaSegmentMemory((BetaNode) segmentRoot, wm);
}
return smem;
}
smem = new SegmentMemory(segmentRoot);
// Iterate all nodes on the same segment, assigning their position as a bit mask value
// allLinkedTestMask is the resulting mask used to test if all nodes are linked in
long nodePosMask = 1;
long allLinkedTestMask = 0;
// nodes after a branch CE can notify, but they cannot impact linking
// (processBranchNode / processQueryNode decide the value used for the remaining nodes)
boolean updateNodeBit = true;
while (true) {
nodeTypesInSegment = updateNodeTypesMask(tupleSource, nodeTypesInSegment);
if (NodeTypeEnums.isBetaNode(tupleSource)) {
allLinkedTestMask = processBetaNode((BetaNode) tupleSource, wm, smem, nodePosMask, allLinkedTestMask, updateNodeBit);
} else {
switch(tupleSource.getType()) {
case NodeTypeEnums.LeftInputAdapterNode:
allLinkedTestMask = processLiaNode((LeftInputAdapterNode) tupleSource, wm, smem, nodePosMask, allLinkedTestMask);
break;
case NodeTypeEnums.EvalConditionNode:
processEvalNode((EvalConditionNode) tupleSource, wm, smem);
break;
case NodeTypeEnums.ConditionalBranchNode:
updateNodeBit = processBranchNode((ConditionalBranchNode) tupleSource, wm, smem);
break;
case NodeTypeEnums.FromNode:
processFromNode((FromNode) tupleSource, wm, smem);
break;
case NodeTypeEnums.ReactiveFromNode:
processReactiveFromNode((MemoryFactory) tupleSource, wm, smem, nodePosMask);
break;
case NodeTypeEnums.TimerConditionNode:
processTimerNode((TimerNode) tupleSource, wm, smem, nodePosMask);
break;
case NodeTypeEnums.QueryElementNode:
updateNodeBit = processQueryNode((QueryElementNode) tupleSource, wm, segmentRoot, smem, nodePosMask);
break;
}
}
// the next node in the segment gets the next bit
nodePosMask = nodePosMask << 1;
// the segment continues only while the current node has exactly one sink
if (tupleSource.getSinkPropagator().size() == 1) {
LeftTupleSinkNode sink = tupleSource.getSinkPropagator().getFirstLeftTupleSink();
if (NodeTypeEnums.isLeftTupleSource(sink)) {
tupleSource = (LeftTupleSource) sink;
} else {
// rtn or rian
// While not technically in a segment, we want to be able to iterate easily from the last node memory to the ria/rtn memory
// we don't use createNodeMemory, as these memories may already have been created, but not added, by updateRiaAndTerminalMemory
Memory memory = wm.getNodeMemory((MemoryFactory) sink);
if (sink.getType() == NodeTypeEnums.RightInputAdaterNode) {
PathMemory riaPmem = ((RiaNodeMemory) memory).getRiaPathMemory();
smem.getNodeMemories().add(riaPmem);
RightInputAdapterNode rian = (RightInputAdapterNode) sink;
// the RIA's left-tuple-source object sinks get segment memories of their own
// NOTE(review): this calls a two-argument createSegmentMemory overload that is not visible here
ObjectSink[] nodes = rian.getObjectSinkPropagator().getSinks();
for (ObjectSink node : nodes) {
if (NodeTypeEnums.isLeftTupleSource(node)) {
createSegmentMemory((LeftTupleSource) node, wm);
}
}
} else if (NodeTypeEnums.isTerminalNode(sink)) {
smem.getNodeMemories().add(memory);
}
memory.setSegmentMemory(smem);
smem.setTipNode(sink);
break;
}
} else {
// not in same segment: zero or several sinks, so the segment ends at this node
smem.setTipNode(tupleSource);
break;
}
}
smem.setAllLinkedMaskTest(allLinkedTestMask);
// iterate to find root and determine the SegmentNodes position in the RuleSegment
// NOTE(review): ruleSegmentPosMask is an int while the node masks above are long; shifting past bit 31
// would overflow — confirm against the parameter type of setSegmentPosMaskBit
LeftTupleSource pathRoot = segmentRoot;
int ruleSegmentPosMask = 1;
int counter = 0;
while (pathRoot.getType() != NodeTypeEnums.LeftInputAdapterNode) {
LeftTupleSource leftTupleSource = pathRoot.getLeftTupleSource();
if (SegmentUtilities.isNonTerminalTipNode(leftTupleSource, null)) {
// for each new found segment, increase the mask bit position
ruleSegmentPosMask = ruleSegmentPosMask << 1;
counter++;
}
pathRoot = leftTupleSource;
}
smem.setSegmentPosMaskBit(ruleSegmentPosMask);
smem.setPos(counter);
// tupleSource is the last left tuple source visited in the loop (never the RTN/RIA sink itself)
nodeTypesInSegment = updateRiaAndTerminalMemory(tupleSource, tupleSource, smem, wm, false, nodeTypesInSegment);
// register this segment as the prototype for its root so the lookup at the top of this method can restore it
((KnowledgeBaseImpl) wm.getKnowledgeBase()).registerSegmentPrototype(segmentRoot, smem);
return smem;
}
Usage of org.drools.core.reteoo.ConditionalBranchNode in project drools (kiegroup).
Class SegmentCreationTest, method testBranchCESingleSegment.
/**
 * One rule, "A(), if ($a != null) do[t1], B()", whose nodes are LIA -> branch CE -> B join -> RTN.
 * The B join follows the branch CE: it sets its own node bit but is not part of the segment's
 * all-linked test mask, so only A decides whether the segment (and therefore the rule) is linked.
 */
@Test
public void testBranchCESingleSegment() throws Exception {
    String drl = " $a : A() \n"
            + " if ( $a != null ) do[t1] \n"
            + " B() \n";
    KieBase kbase = buildKnowledgeBase(drl);
    InternalWorkingMemory session = (InternalWorkingMemory) kbase.newKieSession();

    // walk the network: OTN(A) -> LIA -> branch CE -> B join -> RTN
    ObjectTypeNode aTypeNode = getObjectTypeNode(kbase, LinkingTest.A.class);
    LeftInputAdapterNode lia = (LeftInputAdapterNode) aTypeNode.getObjectSinkPropagator().getSinks()[0];
    ConditionalBranchNode branch = (ConditionalBranchNode) lia.getSinkPropagator().getSinks()[0];
    JoinNode joinB = (JoinNode) branch.getSinkPropagator().getSinks()[0];
    RuleTerminalNode terminal = (RuleTerminalNode) joinB.getSinkPropagator().getSinks()[0];

    FactHandle bHandle = session.insert(new LinkingTest.B());
    session.flushPropagations();

    SegmentMemory segment = ((LiaNodeMemory) session.getNodeMemory(lia)).getSegmentMemory();
    // only the LIA bit is required for the segment to link
    assertEquals(1, segment.getAllLinkedMaskTest());
    // B sets its bit (4), which alone does not satisfy the all-linked test
    assertEquals(4, segment.getLinkedNodeMask());
    assertFalse(segment.isSegmentLinked());

    PathMemory path = (PathMemory) session.getNodeMemory(terminal);
    assertEquals(1, path.getAllLinkedMaskTest());
    assertEquals(0, path.getLinkedSegmentMask());
    assertFalse(path.isRuleLinked());

    // inserting A sets the LIA bit, linking the segment and, through it, the rule
    session.insert(new LinkingTest.A());
    session.flushPropagations();
    assertEquals(5, segment.getLinkedNodeMask());
    assertTrue(segment.isSegmentLinked());
    assertEquals(1, path.getLinkedSegmentMask());
    assertTrue(path.isRuleLinked());

    // retracting B leaves the rule linked
    session.delete(bHandle);
    session.flushPropagations();
    assertEquals(1, path.getLinkedSegmentMask());
    assertTrue(path.isRuleLinked());
}
Usage of org.drools.core.reteoo.ConditionalBranchNode in project drools (kiegroup).
Class ConditionalBranchBuilder, method build.
/**
 * Builds the ConditionalBranchNode for a conditional-branch element, attaches it through
 * {@code utils.attachNode}, and makes the attached node the context's current tuple source.
 *
 * @param context build context supplying the node id, the current tuple source and the node factory
 * @param utils   helper used to attach the new node
 * @param rce     the rule condition element; must be a {@code ConditionalBranch}
 */
public void build(BuildContext context, BuildUtils utils, RuleConditionElement rce) {
    // the evaluator is created before the element becomes the current rule component
    ConditionalBranchEvaluator evaluator = buildConditionalBranchEvaluator(context, (ConditionalBranch) rce);

    context.pushRuleComponent(rce);
    ConditionalBranchNode branchNode = context.getComponentFactory()
            .getNodeFactoryService()
            .buildConditionalBranchNode(context.getNextId(),
                                        context.getTupleSource(),
                                        evaluator,
                                        context);
    // whatever attachNode hands back becomes the tuple source for the next element
    context.setTupleSource(utils.attachNode(context, branchNode));
    context.popRuleComponent();
}
Usage of org.drools.core.reteoo.ConditionalBranchNode in project drools (kiegroup).
Class SegmentCreationTest, method testBranchCEMultipleSegments.
/**
 * Three rules sharing A(): r1 = A(); r2 = A(), branch CE, B(); r3 = r2 followed by C().
 * Checks that the segment holding the B join (after the branch CE) never contributes to unlinking,
 * and that r2 and r3 stay linked once A is present, even after B and C are retracted.
 */
@Test
public void testBranchCEMultipleSegments() throws Exception {
    String rule1 = " $a : A() \n";
    String rule2 = " $a : A() \n" + " if ( $a != null ) do[t1] \n" + " B() \n";
    String rule3 = " $a : A() \n" + " if ( $a != null ) do[t1] \n" + " B() \n" + " C() \n";
    KieBase kbase = buildKnowledgeBase(rule1, rule2, rule3);
    InternalWorkingMemory session = (InternalWorkingMemory) kbase.newKieSession();

    ObjectTypeNode aTypeNode = getObjectTypeNode(kbase, LinkingTest.A.class);
    LeftInputAdapterNode lia = (LeftInputAdapterNode) aTypeNode.getObjectSinkPropagator().getSinks()[0];
    // the LIA's second sink is the branch CE shared by r2 and r3
    ConditionalBranchNode branch = (ConditionalBranchNode) lia.getSinkPropagator().getSinks()[1];
    JoinNode joinB = (JoinNode) branch.getSinkPropagator().getSinks()[0];
    RuleTerminalNode terminalR2 = (RuleTerminalNode) joinB.getSinkPropagator().getSinks()[0];
    JoinNode joinC = (JoinNode) joinB.getSinkPropagator().getSinks()[1];
    RuleTerminalNode terminalR3 = (RuleTerminalNode) joinC.getSinkPropagator().getSinks()[0];

    FactHandle bHandle = session.insert(new LinkingTest.B());
    FactHandle cHandle = session.insert(new LinkingTest.C());
    session.flushPropagations();

    SegmentMemory bSegment = ((BetaMemory) session.getNodeMemory(joinB)).getSegmentMemory();
    // the B join sits after the branch CE, so it is excluded from the all-linked test (mask 0)
    // and this segment never unlinks; B still sets its node bit (2)
    assertEquals(0, bSegment.getAllLinkedMaskTest());
    assertEquals(2, bSegment.getLinkedNodeMask());

    PathMemory pathR2 = (PathMemory) session.getNodeMemory(terminalR2);
    // only the first segment is required; the segment with bit 2 is already linked
    assertEquals(1, pathR2.getAllLinkedMaskTest());
    assertEquals(2, pathR2.getLinkedSegmentMask());
    assertEquals(3, pathR2.getSegmentMemories().length);
    assertFalse(pathR2.isRuleLinked());

    PathMemory pathR3 = (PathMemory) session.getNodeMemory(terminalR3);
    // again only the first segment has to link for the rule to link
    assertEquals(1, pathR3.getAllLinkedMaskTest());
    assertEquals(3, pathR3.getSegmentMemories().length);
    assertFalse(pathR3.isRuleLinked());

    SegmentMemory cSegment = ((BetaMemory) session.getNodeMemory(joinC)).getSegmentMemory();
    assertEquals(1, cSegment.getAllLinkedMaskTest());
    assertEquals(1, cSegment.getLinkedNodeMask());

    session.insert(new LinkingTest.A());
    session.flushPropagations();
    assertTrue(pathR2.isRuleLinked());
    assertTrue(pathR3.isRuleLinked());

    // retracting B and then C does not unlink either rule
    session.delete(bHandle);
    session.delete(cHandle);
    session.flushPropagations();

    // first and B segments linked (1 | 2); the B segment never unlinks
    assertEquals(3, pathR2.getLinkedSegmentMask());
    assertTrue(pathR2.isRuleLinked());
    // the C segment's bit (4) is not set, but r3 only requires the first segment, so it stays linked
    assertEquals(3, pathR3.getLinkedSegmentMask());
    assertTrue(pathR3.isRuleLinked());
}
Aggregations