use of org.drools.beliefs.graph.GraphNode in project drools by kiegroup.
the class BayesProjectionTest method testProjection3.
@Test
public void testProjection3() {
    // Projects from node1 into sep. A, B and C are in node1; only A and C are in the sep.
    // This tests a non-separator var (B) that sits in the middle of the vars, so the
    // projection must sum out an inner variable rather than the first or last one.
    BayesVariable<String> a = new BayesVariable<String>("A", 0, new String[] { "A1", "A2" }, new double[][] { { 0.1, 0.2 } });
    BayesVariable<String> b = new BayesVariable<String>("B", 1, new String[] { "B1", "B2" }, new double[][] { { 0.1, 0.2 } });
    BayesVariable<String> c = new BayesVariable<String>("C", 2, new String[] { "C1", "C2" }, new double[][] { { 0.1, 0.2 } });

    Graph<BayesVariable> graph = new BayesNetwork();
    GraphNode x0 = addNode(graph);
    GraphNode x1 = addNode(graph);
    GraphNode x2 = addNode(graph);
    // A fourth node is still added to the graph (it gets no content and is never referenced
    // again). NOTE(review): it looks unnecessary for this test — confirm before removing.
    addNode(graph);
    x0.setContent(a);
    x1.setContent(b);
    x2.setContent(c);

    JunctionTreeClique node1 = new JunctionTreeClique(0, graph, bitSet("0111"));
    JunctionTreeClique node2 = new JunctionTreeClique(1, graph, bitSet("0101"));
    SeparatorState sep = new JunctionTreeSeparator(0, node1, node2, bitSet("0101"), graph).createState();

    // Fill node1's potentials with 0.1, 0.2, 0.3, ...; scaleDouble presumably rounds each
    // step to 3 decimal places so floating-point error does not accumulate.
    double v = 0.1;
    for (int i = 0; i < node1.getPotentials().length; i++) {
        node1.getPotentials()[i] = v;
        v = scaleDouble(3, v + 0.1);
    }

    BayesVariable[] vars = new BayesVariable[] { a, b, c };
    BayesVariable[] sepVars = new BayesVariable[] { a, c };
    int[] sepVarPos = PotentialMultiplier.createSubsetVarPos(vars, sepVars);
    int sepVarNumberOfStates = PotentialMultiplier.createNumberOfStates(sepVars);
    int[] sepVarMultipliers = PotentialMultiplier.createIndexMultipliers(sepVars, sepVarNumberOfStates);

    double[] projectedSepPotentials = new double[sep.getPotentials().length];
    BayesProjection p = new BayesProjection(vars, node1.getPotentials(), sepVarPos, sepVarMultipliers, projectedSepPotentials);
    p.project();

    // Remember it's been normalized: summing B out gives 0.4, 0.6, 1.2, 1.4 (total 3.6),
    // which normalizes to the values below.
    assertArray(new double[] { 0.111, 0.167, 0.333, 0.389 }, scaleDouble(3, projectedSepPotentials));
}
use of org.drools.beliefs.graph.GraphNode in project drools by kiegroup.
the class GraphTest method connectParentToChildren.
/**
 * Links {@code parent} to every node in {@code children}, one new {@link EdgeImpl} per child,
 * processed in argument order. Each edge has {@code parent} as its out-node and the child as
 * its in-node. The edges are not returned or stored here.
 * NOTE(review): relies on the EdgeImpl setters attaching the edge to the nodes — confirm in EdgeImpl.
 *
 * @param parent   node placed on the "out" side of every edge
 * @param children nodes placed on the "in" side, one edge each
 */
public static void connectParentToChildren(GraphNode parent, GraphNode... children) {
    for (int idx = 0; idx < children.length; idx++) {
        GraphNode target = children[idx];
        EdgeImpl edge = new EdgeImpl();
        // Out side first, then in side, matching the original call order.
        edge.setOutGraphNode(parent);
        edge.setInGraphNode(target);
    }
}
use of org.drools.beliefs.graph.GraphNode in project drools by kiegroup.
the class GraphTest method connectChildToParents.
/**
 * Links every node in {@code parents} to {@code child}, one new {@link EdgeImpl} per parent,
 * processed in argument order. Each edge has the parent as its out-node and {@code child} as
 * its in-node. The edges are not returned or stored here.
 * NOTE(review): relies on the EdgeImpl setters attaching the edge to the nodes — confirm in EdgeImpl.
 *
 * @param child   node placed on the "in" side of every edge
 * @param parents nodes placed on the "out" side, one edge each
 */
public static void connectChildToParents(GraphNode child, GraphNode... parents) {
    for (int idx = 0; idx < parents.length; idx++) {
        GraphNode source = parents[idx];
        EdgeImpl edge = new EdgeImpl();
        // Out side first, then in side, matching the original call order.
        edge.setOutGraphNode(source);
        edge.setInGraphNode(child);
    }
}
use of org.drools.beliefs.graph.GraphNode in project drools by kiegroup.
the class JunctionTreeBuilderTest method testEliminationCandidate2.
@Test
public void testEliminationCandidate2() {
    Graph<BayesVariable> graph = new BayesNetwork();
    // Node 0 is added but never connected, so it takes no part in the elimination below.
    addNode(graph);
    GraphNode x1 = addNode(graph);
    GraphNode x2 = addNode(graph);
    GraphNode x3 = addNode(graph);
    GraphNode x4 = addNode(graph);

    // x1 -> {x2, x3, x4} and x3 -> x4.
    connectParentToChildren(x1, x2, x3, x4);
    connectParentToChildren(x3, x4);

    JunctionTreeBuilder jtBuilder = new JunctionTreeBuilder(graph);
    jtBuilder.moralize();

    // Eliminating x1 requires its neighbours {x2, x3, x4} to form a clique. x3-x4 is already
    // an edge, so only x2-x3 and x2-x4 have to be added.
    EliminationCandidate vt1 = new EliminationCandidate(graph, jtBuilder.getAdjacencyMatrix(), x1);
    assertEquals(2, vt1.getNewEdgesRequired());
    // The resulting clique is x1..x4; node 0 is excluded.
    assertEquals(bitSet("11110"), vt1.getCliqueBitSit());
}
use of org.drools.beliefs.graph.GraphNode in project drools by kiegroup.
the class JunctionTreeBuilderTest method testJunctionWithPruning3.
@Test
public void testJunctionWithPruning3() {
    Graph<BayesVariable> graph = new BayesNetwork();
    // Eight nodes; they are not referenced again, since the cliques below are described
    // purely as bit sets over node indices.
    for (int i = 0; i < 8; i++) {
        addNode(graph);
    }

    OpenBitSet clique1 = bitSet("00001111");
    OpenBitSet clique2 = bitSet("00011110");
    OpenBitSet clique3 = bitSet("11100000");
    OpenBitSet clique4 = bitSet("01100001");

    // Expected separator bit sets: the intersection of the two cliques each one joins.
    // Cliques 2 and 3 share no bits, so no 2-3 separator appears in the expected tree.
    OpenBitSet intersect1And2 = (OpenBitSet) clique2.clone();
    intersect1And2.and(clique1);
    OpenBitSet intersect1And4 = (OpenBitSet) clique1.clone();
    intersect1And4.and(clique4);
    OpenBitSet intersect3And4 = (OpenBitSet) clique3.clone();
    intersect3And4.and(clique4);

    List<OpenBitSet> list = new ArrayList<OpenBitSet>();
    list.add(clique1);
    list.add(clique2);
    list.add(clique3);
    list.add(clique4);

    JunctionTreeBuilder jtBuilder = new JunctionTreeBuilder(graph);
    JunctionTreeClique root = jtBuilder.junctionTree(list, false).getRoot();

    // Expected tree: clique1 -> clique2 (leaf), and clique1 -> clique4 -> clique3 (leaf).
    assertEquals(clique1, root.getBitSet());
    assertEquals(2, root.getChildren().size());

    JunctionTreeSeparator sep = root.getChildren().get(0);
    assertEquals(clique1, sep.getParent().getBitSet());
    assertEquals(clique2, sep.getChild().getBitSet());
    assertEquals(intersect1And2, sep.getBitSet());
    assertEquals(0, sep.getChild().getChildren().size());

    sep = root.getChildren().get(1);
    assertEquals(clique1, sep.getParent().getBitSet());
    assertEquals(clique4, sep.getChild().getBitSet());
    assertEquals(intersect1And4, sep.getBitSet());
    assertEquals(1, sep.getChild().getChildren().size());

    JunctionTreeClique jtNode = sep.getChild();
    assertEquals(clique4, jtNode.getBitSet());
    assertEquals(1, jtNode.getChildren().size());

    sep = jtNode.getChildren().get(0);
    assertEquals(clique4, sep.getParent().getBitSet());
    assertEquals(clique3, sep.getChild().getBitSet());
    assertEquals(intersect3And4, sep.getBitSet());
    assertEquals(0, sep.getChild().getChildren().size());
}
Aggregations