package org.drools.core.reteoo;
import org.drools.core.common.InternalWorkingMemory;
import org.drools.core.common.LeftTupleSets;
import org.drools.core.common.LeftTupleSetsImpl;
import org.drools.core.common.Memory;
import org.drools.core.common.MemoryFactory;
import org.drools.core.common.NetworkNode;
import org.drools.core.common.StreamTupleEntryQueue;
import org.drools.core.common.SynchronizedLeftTupleSets;
import org.drools.core.phreak.SegmentUtilities;
import org.drools.core.reteoo.QueryElementNode.QueryElementNodeMemory;
import org.drools.core.reteoo.TimerNode.TimerNodeMemory;
import org.drools.core.util.AtomicBitwiseLong;
import org.drools.core.util.LinkedList;
import org.drools.core.util.LinkedListNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import static org.drools.core.phreak.SegmentUtilities.getQuerySegmentMemory;
public class SegmentMemory extends LinkedList<SegmentMemory>
implements
LinkedListNode<SegmentMemory> {
/** Logger used for the link/unlink trace messages of this class. */
protected static final Logger log = LoggerFactory.getLogger(SegmentMemory.class);
/**
 * Trace flag sampled once when this class is initialised; a later change of the
 * logger's level is not reflected in this constant.
 */
protected static final boolean isLogTraceEnabled = log.isTraceEnabled();
/** Root node of the segment; assigned by the constructors. */
private NetworkNode rootNode;
/** Tip node of the segment; {@link #getNodesInSegment()} walks from here back to the root node. */
private NetworkNode tipNode;
/** Memories of the nodes that belong to this segment. */
private LinkedList<Memory> nodeMemories;
/** Bit mask updated by the linkNode/unlinkNode methods and tested by {@link #isSegmentLinked()}. */
private AtomicBitwiseLong linkedNodeMask;
/** Bit mask of nodes flagged dirty; cleared by {@link #resetDirtyNodeMask()}. */
private AtomicBitwiseLong dirtyNodeMask;
/** Bits that must all be present in the linked mask for {@link #isSegmentLinked()} to return true. */
private long allLinkedMaskTest;
/** Path memories that are notified when this segment is linked or unlinked. */
private List<PathMemory> pathMemories;
/** Mask bit handed to the path memories on link/unlink notifications. */
private long segmentPosMaskBit;
/** Position value; within this class it is only stored and written to trace logs. */
private int pos;
/** Staged left tuples; volatile, replaceable through {@link #setStagedTuples(LeftTupleSets)}. */
private volatile LeftTupleSets stagedLeftTuples;
/** Plain flag exposed through {@link #isActive()} / {@link #setActive(boolean)}. */
private boolean active;
/** Previous neighbour, backing the {@code LinkedListNode} implementation. */
private SegmentMemory previous;
/** Next neighbour, backing the {@code LinkedListNode} implementation. */
private SegmentMemory next;
/** Optional stream tuple queue; {@code null} when the single-argument constructor is used. */
private StreamTupleEntryQueue queue;
/**
 * Creates a segment memory without a stream queue.
 *
 * @param rootNode the root node of the segment
 */
public SegmentMemory(NetworkNode rootNode) {
    this(rootNode, null);
}

/**
 * Creates a segment memory with empty masks, an empty node-memory list, an empty
 * path-memory list and a fresh {@link LeftTupleSetsImpl} for staged tuples.
 *
 * @param rootNode the root node of the segment
 * @param queue    the stream queue to use, may be {@code null}
 */
public SegmentMemory(NetworkNode rootNode, StreamTupleEntryQueue queue) {
    this.rootNode = rootNode;
    this.queue = queue;

    // Both masks start out with no bits set.
    this.linkedNodeMask = new AtomicBitwiseLong();
    this.dirtyNodeMask = new AtomicBitwiseLong();

    this.nodeMemories = new LinkedList<Memory>();
    this.pathMemories = new ArrayList<PathMemory>(1);
    this.stagedLeftTuples = new LeftTupleSetsImpl();
}
/**
 * @return the stream queue attached to this segment, or {@code null} if none was set
 */
public StreamTupleEntryQueue getStreamQueue() {
    return this.queue;
}

/**
 * Replaces the stream queue attached to this segment.
 *
 * @param queue the new queue, may be {@code null}
 */
public void setStreamQueue(StreamTupleEntryQueue queue) {
    this.queue = queue;
}
/**
 * @return the root node this segment was created with
 */
public NetworkNode getRootNode() {
    return this.rootNode;
}

/**
 * @return the tip node of this segment, or {@code null} if it has not been set
 */
public NetworkNode getTipNode() {
    return this.tipNode;
}

/**
 * Sets the tip node of this segment.
 *
 * @param tipNode the new tip node
 */
public void setTipNode(NetworkNode tipNode) {
    this.tipNode = tipNode;
}
/**
 * Returns the root node viewed as a {@link LeftTupleSink}.
 *
 * @return the root node cast to {@code LeftTupleSink}
 * @throws ClassCastException if the root node is not a {@code LeftTupleSink}
 */
public LeftTupleSink getSinkFactory() {
    final LeftTupleSink rootAsSink = (LeftTupleSink) rootNode;
    return rootAsSink;
}

/**
 * Currently a no-op: the supplied sink is ignored and no state is changed.
 *
 * @param sink ignored
 */
public void setSinkFactory(LeftTupleSink sink) {
    // nothing to do
}
/**
 * Obtains the memory for the given factory from the working memory and records it
 * in this segment's node-memory list.
 *
 * @param memoryFactory the factory (node) whose memory is requested
 * @param wm            the working memory that supplies the node memory
 * @return the memory returned by {@code wm.getNodeMemory(memoryFactory)}
 */
public Memory createNodeMemory(MemoryFactory memoryFactory,
                               InternalWorkingMemory wm) {
    final Memory obtained = wm.getNodeMemory(memoryFactory);
    this.nodeMemories.add(obtained);
    return obtained;
}
/**
 * @return the live list of node memories held by this segment (not a copy)
 */
public LinkedList<Memory> getNodeMemories() {
    return this.nodeMemories;
}
/**
 * @return the current value of the linked-node mask
 */
public long getLinkedNodeMask() {
    final long current = linkedNodeMask.get();
    return current;
}

/**
 * Overwrites the linked-node mask with the given value.
 *
 * @param linkedNodeMask the new mask value
 */
public void setLinkedNodeMask(long linkedNodeMask) {
    this.linkedNodeMask.set(linkedNodeMask);
}
/**
 * @return the current value of the dirty-node mask
 */
public long getDirtyNodeMask() {
    final long current = dirtyNodeMask.get();
    return current;
}

/**
 * Sets the dirty-node mask to zero.
 */
public void resetDirtyNodeMask() {
    this.dirtyNodeMask.set(0);
}

/**
 * ORs the given bits into the dirty-node mask.
 *
 * @param mask the bits to add
 */
public void updateDirtyNodeMask(long mask) {
    this.dirtyNodeMask.getAndBitwiseOr(mask);
}

/**
 * Applies {@code getAndBitwiseReset(mask)} to the dirty-node mask.
 *
 * @param mask the bits handed to the reset operation
 */
public void updateCleanNodeMask(long mask) {
    this.dirtyNodeMask.getAndBitwiseReset(mask);
}
/**
 * Builds a comma-separated string from the string form of every path memory of this
 * segment. Used by the trace log statements in this class.
 *
 * @return the joined text, or an empty string when there are no path memories
 */
public String getRuleNames() {
    final StringBuilder joined = new StringBuilder();
    String separator = "";
    for (int idx = 0; idx < pathMemories.size(); idx++) {
        joined.append(separator);
        joined.append(pathMemories.get(idx));
        // Every entry after the first one is preceded by the delimiter.
        separator = ", ";
    }
    return joined.toString();
}
/**
 * ORs the given bits into the linked-node mask and then notifies the path memories
 * through {@link #notifyRuleLinkSegment(InternalWorkingMemory)}.
 *
 * @param mask the node bits to mark as linked
 * @param wm   the working memory passed on to the notification
 */
public void linkNode(long mask,
                     InternalWorkingMemory wm) {
    this.linkedNodeMask.getAndBitwiseOr(mask);

    if (isLogTraceEnabled) {
        log.trace("LinkNode notify=true nmask={} smask={} spos={} rules={}", mask, linkedNodeMask, pos, getRuleNames());
    }

    notifyRuleLinkSegment(wm);
}
/**
 * ORs the given bits into the linked-node mask and, if the segment is linked
 * afterwards, calls {@code linkNodeWithoutRuleNotify} on every path memory.
 *
 * @param mask the node bits to mark as linked
 */
public void linkNodeWithoutRuleNotify(long mask) {
    this.linkedNodeMask.getAndBitwiseOr(mask);

    if (isLogTraceEnabled) {
        log.trace("LinkNode notify=false nmask={} smask={} spos={} rules={}", mask, linkedNodeMask, pos, getRuleNames());
    }

    if (!isSegmentLinked()) {
        return;
    }

    // Indexed loop on purpose: avoids allocating an Iterator.
    final int pathCount = pathMemories.size();
    for (int idx = 0; idx < pathCount; idx++) {
        PathMemory pathMemory = pathMemories.get(idx);
        pathMemory.linkNodeWithoutRuleNotify(segmentPosMaskBit);
    }
}
/**
 * ORs the given bits into the dirty-node mask and, if the segment is linked, calls
 * {@code linkSegment} on every path memory.
 *
 * @param wm   the working memory passed on to the path memories
 * @param mask the node bits to mark as dirty
 */
public void notifyRuleLinkSegment(InternalWorkingMemory wm, long mask) {
    this.dirtyNodeMask.getAndBitwiseOr(mask);

    if (!isSegmentLinked()) {
        return;
    }

    // Indexed loop on purpose: avoids allocating an Iterator.
    final int pathCount = pathMemories.size();
    for (int idx = 0; idx < pathCount; idx++) {
        PathMemory pathMemory = pathMemories.get(idx);
        pathMemory.linkSegment(segmentPosMaskBit, wm);
    }
}
/**
 * If the segment is linked, calls {@code linkSegment} on every path memory; otherwise
 * does nothing. Unlike the two-argument overload, the dirty mask is not touched.
 *
 * @param wm the working memory passed on to the path memories
 */
public void notifyRuleLinkSegment(InternalWorkingMemory wm) {
    if (!isSegmentLinked()) {
        return;
    }

    // Indexed loop on purpose: avoids allocating an Iterator.
    final int pathCount = pathMemories.size();
    for (int idx = 0; idx < pathCount; idx++) {
        PathMemory pathMemory = pathMemories.get(idx);
        pathMemory.linkSegment(segmentPosMaskBit, wm);
    }
}
/**
 * XORs the given bits into the linked-node mask, marks them dirty and notifies the
 * path memories.
 * <p>
 * If the segment was linked before the update and is no longer linked afterwards,
 * every path memory receives {@code unlinkedSegment}. Otherwise (some node unlinking
 * does not unlink the segment, such as nodes after a Branch CE) every path memory
 * whose rule is linked receives {@code doLinkRule}.
 * <p>
 * NOTE(review): XOR toggles the bits, so calling this for a node whose bit is not
 * currently set would set it. Presumably callers only unlink linked nodes; confirm.
 *
 * @param mask the node bits to unlink
 * @param wm   the working memory passed on to the path memories
 */
public void unlinkNode(long mask,
                       InternalWorkingMemory wm) {
    // Must be sampled before the mask is modified.
    final boolean wasLinked = isSegmentLinked();

    this.linkedNodeMask.getAndBitwiseXor(mask);
    this.dirtyNodeMask.getAndBitwiseOr(mask);

    if (isLogTraceEnabled) {
        log.trace("UnlinkNode notify=true nmask={} smask={} spos={} rules={}", mask, linkedNodeMask, pos, getRuleNames());
    }

    final boolean segmentBecameUnlinked = wasLinked && !isSegmentLinked();

    // Indexed loop on purpose: avoids allocating an Iterator.
    final int pathCount = pathMemories.size();
    for (int idx = 0; idx < pathCount; idx++) {
        PathMemory pathMemory = pathMemories.get(idx);
        if (segmentBecameUnlinked) {
            pathMemory.unlinkedSegment(segmentPosMaskBit, wm);
        } else if (pathMemory.isRuleLinked()) {
            // Segment state did not flip, but a linked rule still needs to be notified.
            pathMemory.doLinkRule(wm);
        }
    }
}
/**
 * XORs the given bits into the linked-node mask without notifying any path memory
 * and without touching the dirty mask.
 * <p>
 * NOTE(review): XOR toggles the bits, so this would set a bit that is not currently
 * set. Presumably callers only unlink linked nodes; confirm.
 *
 * @param mask the node bits to unlink
 */
public void unlinkNodeWithoutRuleNotify(long mask) {
    this.linkedNodeMask.getAndBitwiseXor(mask);

    if (isLogTraceEnabled) {
        log.trace("UnlinkNode notify=false nmask={} smask={} spos={} rules={}", mask, linkedNodeMask, pos, getRuleNames());
    }
}
/**
 * @return the mask of bits that must all be linked for the segment to count as linked
 */
public long getAllLinkedMaskTest() {
    return this.allLinkedMaskTest;
}

/**
 * Sets the mask of bits that must all be linked for the segment to count as linked.
 *
 * @param allLinkedTestMask the new test mask
 */
public void setAllLinkedMaskTest(long allLinkedTestMask) {
    this.allLinkedMaskTest = allLinkedTestMask;
}

/**
 * @return {@code true} when every bit of the all-linked test mask is present in the
 *         linked-node mask
 */
public boolean isSegmentLinked() {
    final long linkedRequiredBits = linkedNodeMask.get() & allLinkedMaskTest;
    return linkedRequiredBits == allLinkedMaskTest;
}
/**
 * @return the live list of path memories attached to this segment (not a copy)
 */
public List<PathMemory> getPathMemories() {
    return this.pathMemories;
}

/**
 * Replaces the list of path memories; the given list is stored as is, not copied.
 *
 * @param ruleSegments the new path-memory list
 */
public void setPathMemories(List<PathMemory> ruleSegments) {
    this.pathMemories = ruleSegments;
}
/**
 * @return the mask bit that is passed to the path memories on link/unlink
 */
public long getSegmentPosMaskBit() {
    return this.segmentPosMaskBit;
}

/**
 * Sets the mask bit that is passed to the path memories on link/unlink.
 *
 * @param nodeSegmenMask the new mask bit
 */
public void setSegmentPosMaskBit(long nodeSegmenMask) {
    this.segmentPosMaskBit = nodeSegmenMask;
}
/**
 * @return the current value of the active flag
 */
public boolean isActive() {
    return this.active;
}

/**
 * Sets the active flag.
 *
 * @param evaluating the new flag value
 */
public void setActive(boolean evaluating) {
    this.active = evaluating;
}
/**
 * @return the stored position value
 */
public int getPos() {
    return this.pos;
}

/**
 * Stores the position value.
 *
 * @param pos the new position
 */
public void setPos(int pos) {
    this.pos = pos;
}
/**
 * @return the staged left tuple sets currently held by this segment
 */
public LeftTupleSets getStagedLeftTuples() {
    return this.stagedLeftTuples;
}

/**
 * Replaces the staged left tuple sets held by this segment.
 *
 * @param stagedTuples the new staged tuple sets
 */
public void setStagedTuples(LeftTupleSets stagedTuples) {
    this.stagedLeftTuples = stagedTuples;
}
/**
 * @return the next segment memory in the list, or {@code null}
 */
public SegmentMemory getNext() {
    return next;
}

/**
 * Sets the next segment memory in the list.
 *
 * @param next the new successor, may be {@code null}
 */
public void setNext(SegmentMemory next) {
    this.next = next;
}

/**
 * @return the previous segment memory in the list, or {@code null}
 */
public SegmentMemory getPrevious() {
    return previous;
}

/**
 * Sets the previous segment memory in the list.
 *
 * @param previous the new predecessor, may be {@code null}
 */
public void setPrevious(SegmentMemory previous) {
    this.previous = previous;
}

/**
 * Clears both neighbour links of this node.
 */
public void nullPrevNext() {
    this.previous = null;
    this.next = null;
}
/**
 * Hash code derived solely from the id of the root node.
 * <p>
 * {@link #equals(Object)} accepts a segment memory without a root node, so this
 * method must not throw in that case; it returns {@code 0}. For a non-null root
 * node the value is unchanged: {@code 31 * rootNode.getId()}.
 *
 * @return the hash code of this segment memory
 */
@Override
public int hashCode() {
    final int prime = 31;
    if (rootNode == null) {
        return 0;
    }
    return prime * rootNode.getId();
}
/**
 * Two segment memories are equal when the superclass considers them equal, they are
 * of exactly the same class, and their root nodes have the same id (or both have no
 * root node).
 * <p>
 * A {@code null} argument is rejected explicitly instead of relying on the
 * superclass, and a missing root node on either side no longer causes a
 * {@code NullPointerException}.
 *
 * @param obj the object to compare with
 * @return {@code true} if {@code obj} is equal to this segment memory
 */
@Override
public boolean equals(Object obj) {
    if (this == obj) {
        return true;
    }
    if (obj == null) {
        return false;
    }
    if (!super.equals(obj)) {
        return false;
    }
    if (getClass() != obj.getClass()) {
        return false;
    }
    SegmentMemory other = (SegmentMemory) obj;
    if (rootNode == null || other.rootNode == null) {
        // Equal only when both sides lack a root node.
        return rootNode == other.rootNode;
    }
    return rootNode.getId() == other.rootNode.getId();
}
/**
 * Captures the current configuration of this segment memory in a new
 * {@link Prototype}.
 *
 * @return a prototype built from this segment memory
 */
public Prototype asPrototype() {
    final Prototype snapshot = new Prototype(this);
    return snapshot;
}
/**
 * Collects the nodes of this segment by walking from the tip node through each
 * node's left tuple source until the root node is reached.
 *
 * @return a new list holding the nodes ordered from the root node to the tip node
 */
public List<NetworkNode> getNodesInSegment() {
    // Fully qualified: the simple name LinkedList refers to the Drools list in this file.
    final java.util.LinkedList<NetworkNode> rootToTip = new java.util.LinkedList<NetworkNode>();
    NetworkNode cursor = tipNode;
    while (true) {
        // Prepending while walking backwards yields root-to-tip order.
        rootToTip.addFirst(cursor);
        if (cursor == rootNode) {
            break;
        }
        cursor = ((LeftTupleSinkNode) cursor).getLeftTupleSource();
    }
    return rootToTip;
}
/**
 * Resets this segment memory: clears the dirty mask, restores the linked mask from
 * the prototype (or zero when no prototype is given), drains the stream queue if
 * there is one, and resets all staged left tuples.
 *
 * @param prototype the prototype supplying the linked mask, may be {@code null}
 */
public void reset(Prototype prototype) {
    long restoredLinkedMask = 0;
    if (prototype != null) {
        restoredLinkedMask = prototype.linkedNodeMask;
    }

    this.dirtyNodeMask.set(0);
    this.linkedNodeMask.set(restoredLinkedMask);

    if (this.queue != null) {
        // The drained entries are discarded.
        this.queue.takeAllForFlushing();
    }
    this.stagedLeftTuples.resetAll();
}
/**
 * Snapshot of a {@link SegmentMemory}'s configuration from which equivalent segment
 * memories can be created for a working memory via
 * {@link #newSegmentMemory(InternalWorkingMemory)}.
 */
public static class Prototype {
    private final NetworkNode rootNode;
    private final NetworkNode tipNode;
    private final long linkedNodeMask;
    private final long allLinkedMaskTest;
    private final long segmentPosMaskBit;
    private final int pos;
    /** One entry per node memory of the source segment; entries may be {@code null}. */
    private final List<MemoryPrototype> memories = new ArrayList<MemoryPrototype>();
    private final boolean hasQueue;
    private final boolean hasSyncStagedLeftTuple;
    /** Lazily computed on the first call of {@link #newSegmentMemory(InternalWorkingMemory)}. */
    private List<NetworkNode> nodesInSegment;

    /**
     * Copies the relevant state out of the given segment memory.
     *
     * @param smem the segment memory to snapshot
     */
    private Prototype(SegmentMemory smem) {
        rootNode = smem.rootNode;
        tipNode = smem.tipNode;
        linkedNodeMask = smem.linkedNodeMask.get();
        allLinkedMaskTest = smem.allLinkedMaskTest;
        segmentPosMaskBit = smem.segmentPosMaskBit;
        pos = smem.pos;

        Memory current = smem.nodeMemories.getFirst();
        while (current != null) {
            memories.add(MemoryPrototype.get(current));
            current = current.getNext();
        }

        hasQueue = smem.queue != null;
        hasSyncStagedLeftTuple = smem.getStagedLeftTuples() instanceof SynchronizedLeftTupleSets;
    }

    /**
     * Builds a new segment memory from this prototype for the given working memory.
     * For every node of the segment the node memory is fetched from the working
     * memory, attached to the new segment and populated from the matching memory
     * prototype, if there is one.
     *
     * @param wm the working memory supplying the node memories
     * @return the newly created segment memory
     */
    public SegmentMemory newSegmentMemory(InternalWorkingMemory wm) {
        final SegmentMemory created = new SegmentMemory(rootNode);
        created.tipNode = tipNode;
        created.linkedNodeMask = new AtomicBitwiseLong(linkedNodeMask);
        created.allLinkedMaskTest = allLinkedMaskTest;
        created.segmentPosMaskBit = segmentPosMaskBit;
        created.pos = pos;

        // Memory prototypes are consumed in step with the nodes of the segment.
        int protoIndex = 0;
        for (NetworkNode node : getNodesInSegment(created)) {
            Memory nodeMemory = wm.getNodeMemory((MemoryFactory) node);
            nodeMemory.setSegmentMemory(created);
            created.getNodeMemories().add(nodeMemory);

            MemoryPrototype memoryProto = memories.get(protoIndex);
            protoIndex++;
            if (memoryProto != null) {
                memoryProto.populateMemory(wm, nodeMemory);
            }
        }

        if (hasQueue && created.getStreamQueue() == null) {
            // need to make sure there is one Queue, for the rule, when a stream mode node is found.
            StreamTupleEntryQueue streamQueue = SegmentUtilities.initAndGetTupleQueue(created.getTipNode(), wm);
            created.setStreamQueue(streamQueue);
        }

        if (hasSyncStagedLeftTuple) {
            created.setStagedTuples(new SynchronizedLeftTupleSets());
        }

        return created;
    }

    /**
     * Returns the cached node list, computing it from the given segment memory on
     * first use.
     *
     * @param smem the segment memory used to compute the list when it is not cached yet
     * @return the nodes of the segment
     */
    private List<NetworkNode> getNodesInSegment(SegmentMemory smem) {
        if (nodesInSegment != null) {
            return nodesInSegment;
        }
        nodesInSegment = smem.getNodesInSegment();
        return nodesInSegment;
    }
}
public abstract static class MemoryPrototype {
public static MemoryPrototype get(Memory memory) {
if (memory instanceof BetaMemory) {
return new BetaMemoryPrototype((BetaMemory)memory);
}
if (memory instanceof LeftInputAdapterNode.LiaNodeMemory) {
return new LiaMemoryPrototype((LeftInputAdapterNode.LiaNodeMemory)memory);
}
if (memory instanceof QueryElementNode.QueryElementNodeMemory) {
return new QueryMemoryPrototype((QueryElementNode.QueryElementNodeMemory)memory);
}
if (memory instanceof TimerNodeMemory) {
return new TimerMemoryPrototype((TimerNodeMemory)memory);
}
if (memory instanceof AccumulateNode.AccumulateMemory) {
return new AccumulateMemoryPrototype((AccumulateNode.AccumulateMemory)memory);
}
return null;
}
public abstract void populateMemory(InternalWorkingMemory wm, Memory memory);
}
/**
 * Snapshot of a {@link BetaMemory}: its node position mask bit and, when present,
 * the right input adapter node of its RIA rule memory.
 */
public static class BetaMemoryPrototype extends MemoryPrototype {
    private final long nodePosMaskBit;
    /** {@code null} when the source beta memory had no RIA rule memory. */
    private final RightInputAdapterNode riaNode;

    private BetaMemoryPrototype(BetaMemory betaMemory) {
        this.nodePosMaskBit = betaMemory.getNodePosMaskBit();
        this.riaNode = betaMemory.getRiaRuleMemory() == null
                ? null
                : betaMemory.getRiaRuleMemory().getRightInputAdapterNode();
    }

    /**
     * Restores the mask bit on the given beta memory and, if a RIA node was
     * captured, re-attaches the RIA path memory looked up in the working memory.
     *
     * @param wm     the working memory used to look up the RIA node memory
     * @param memory the target memory; must be a {@link BetaMemory}
     */
    @Override
    public void populateMemory(InternalWorkingMemory wm, Memory memory) {
        final BetaMemory target = (BetaMemory) memory;
        target.setNodePosMaskBit(nodePosMaskBit);

        if (riaNode == null) {
            return;
        }
        RightInputAdapterNode.RiaNodeMemory riaNodeMemory =
                (RightInputAdapterNode.RiaNodeMemory) wm.getNodeMemory(riaNode);
        target.setRiaRuleMemory(riaNodeMemory.getRiaPathMemory());
    }
}
/**
 * Snapshot of a {@code LiaNodeMemory}: only its node position mask bit.
 */
public static class LiaMemoryPrototype extends MemoryPrototype {
    private final long nodePosMaskBit;

    private LiaMemoryPrototype(LeftInputAdapterNode.LiaNodeMemory liaMemory) {
        nodePosMaskBit = liaMemory.getNodePosMaskBit();
    }

    /**
     * Restores the mask bit on the given memory.
     *
     * @param wm        unused
     * @param liaMemory the target memory; must be a {@code LiaNodeMemory}
     */
    @Override
    public void populateMemory(InternalWorkingMemory wm, Memory liaMemory) {
        final LeftInputAdapterNode.LiaNodeMemory target = (LeftInputAdapterNode.LiaNodeMemory) liaMemory;
        target.setNodePosMaskBit(nodePosMaskBit);
    }
}
/**
 * Snapshot of a {@code QueryElementNodeMemory}: its node position mask bit and the
 * query element node it belongs to.
 */
public static class QueryMemoryPrototype extends MemoryPrototype {
    private final long nodePosMaskBit;
    private final QueryElementNode queryNode;

    private QueryMemoryPrototype(QueryElementNode.QueryElementNodeMemory queryMemory) {
        this.nodePosMaskBit = queryMemory.getNodePosMaskBit();
        this.queryNode = queryMemory.getNode();
    }

    /**
     * Looks up the query segment memory for the captured query node, attaches it to
     * the given memory and restores the mask bit.
     *
     * @param wm  the working memory used for the lookup
     * @param mem the target memory; must be a {@code QueryElementNodeMemory} whose
     *            segment memory has already been set
     */
    @Override
    public void populateMemory(InternalWorkingMemory wm, Memory mem) {
        final QueryElementNodeMemory target = (QueryElementNodeMemory) mem;
        final LeftTupleSource segmentRoot = (LeftTupleSource) target.getSegmentMemory().getRootNode();
        target.setQuerySegmentMemory(getQuerySegmentMemory(wm, segmentRoot, queryNode));
        target.setNodePosMaskBit(nodePosMaskBit);
    }
}
/**
 * Snapshot of a {@link TimerNodeMemory}: only its node position mask bit.
 */
public static class TimerMemoryPrototype extends MemoryPrototype {
    private final long nodePosMaskBit;

    private TimerMemoryPrototype(TimerNodeMemory timerMemory) {
        this.nodePosMaskBit = timerMemory.getNodePosMaskBit();
    }

    /**
     * Restores the mask bit on the given memory.
     *
     * @param wm  unused
     * @param mem the target memory; must be a {@link TimerNodeMemory}
     */
    @Override
    public void populateMemory(InternalWorkingMemory wm, Memory mem) {
        final TimerNodeMemory target = (TimerNodeMemory) mem;
        target.setNodePosMaskBit(nodePosMaskBit);
    }
}
/**
 * Snapshot of an {@code AccumulateMemory}, implemented by snapshotting the beta
 * memory it wraps.
 */
public static class AccumulateMemoryPrototype extends MemoryPrototype {
    private final BetaMemoryPrototype betaProto;

    private AccumulateMemoryPrototype(AccumulateNode.AccumulateMemory accMemory) {
        this.betaProto = new BetaMemoryPrototype(accMemory.getBetaMemory());
    }

    /**
     * Populates the beta memory wrapped by the given accumulate memory.
     *
     * @param wm        the working memory passed on to the beta prototype
     * @param accMemory the target memory; must be an {@code AccumulateMemory}
     */
    @Override
    public void populateMemory(InternalWorkingMemory wm, Memory accMemory) {
        final AccumulateNode.AccumulateMemory target = (AccumulateNode.AccumulateMemory) accMemory;
        betaProto.populateMemory(wm, target.getBetaMemory());
    }
}
/**
 * Snapshot of a {@code FromMemory}, implemented by snapshotting the beta memory it
 * wraps.
 */
public static class FromMemoryPrototype extends MemoryPrototype {
    private final BetaMemoryPrototype betaProto;

    private FromMemoryPrototype(FromNode.FromMemory fromMemory) {
        this.betaProto = new BetaMemoryPrototype(fromMemory.getBetaMemory());
    }

    /**
     * Populates the beta memory wrapped by the given from memory.
     *
     * @param wm         the working memory passed on to the beta prototype
     * @param fromMemory the target memory; must be a {@code FromMemory}
     */
    @Override
    public void populateMemory(InternalWorkingMemory wm, Memory fromMemory) {
        final FromNode.FromMemory target = (FromNode.FromMemory) fromMemory;
        betaProto.populateMemory(wm, target.getBetaMemory());
    }
}
}