Usage of edu.cmu.tetrad.graph.Dag in the tetrad project (cmu-phil).
Example from class TestCptInvariantUpdater, method testUpdate5.
/**
 * Checks that {@link CptInvariantUpdater} and {@link RowSummingExactUpdater} report the
 * same marginal for X3 = category 0 when the evidence fixes X1 = 1 and X2 = 0.
 *
 * <p>The DAG is X4 -> X0, X4 -> X2, X0 -> X1, X0 -> X2, X1 -> X3, X2 -> X3. Parameters are
 * drawn with {@code MlBayesIm.RANDOM}, and both updaters are run against the same IM.
 */
@Test
public void testUpdate5() {
    Node x0Node = new GraphNode("X0");
    Node x1Node = new GraphNode("X1");
    Node x2Node = new GraphNode("X2");
    Node x3Node = new GraphNode("X3");
    Node x4Node = new GraphNode("X4");

    Dag graph = new Dag();
    graph.addNode(x0Node);
    graph.addNode(x1Node);
    graph.addNode(x2Node);
    graph.addNode(x3Node);
    graph.addNode(x4Node);
    graph.addDirectedEdge(x0Node, x1Node);
    graph.addDirectedEdge(x0Node, x2Node);
    graph.addDirectedEdge(x1Node, x3Node);
    graph.addDirectedEdge(x2Node, x3Node);
    graph.addDirectedEdge(x4Node, x0Node);
    graph.addDirectedEdge(x4Node, x2Node);

    BayesPm bayesPm = new BayesPm(graph);
    MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);

    int x1 = bayesIm.getNodeIndex(x1Node);
    int x2 = bayesIm.getNodeIndex(x2Node);
    int x3 = bayesIm.getNodeIndex(x3Node);

    // Start from the tautological evidence and restrict X1 to category 1 and X2 to category 0.
    // NOTE(review): the proposition is indexed with bayesIm node indices; this assumes
    // Evidence uses the same variable ordering as the IM it was built from.
    Evidence evidence = Evidence.tautology(bayesIm);
    evidence.getProposition().setCategory(x1, 1);
    evidence.getProposition().setCategory(x2, 0);

    BayesUpdater updater1 = new CptInvariantUpdater(bayesIm);
    updater1.setEvidence(evidence);

    BayesUpdater updater2 = new RowSummingExactUpdater(bayesIm);
    updater2.setEvidence(evidence);

    double marginal1 = updater1.getMarginal(x3, 0);
    double marginal2 = updater2.getMarginal(x3, 0);

    // The row-summing updater is treated as the reference (expected) value.
    assertEquals(marginal2, marginal1, 0.000001);
}
Usage of edu.cmu.tetrad.graph.Dag in the tetrad project (cmu-phil).
Example from class TestBayesXml, method sampleBayesIm2.
/**
 * Builds a small Bayes IM over the three-node DAG a -> b, a -> c, b -> c.
 *
 * <p>Variable b is given three categories. a and c keep whatever default category count
 * {@link BayesPm} assigns. Parameters are filled using {@code MlBayesIm.RANDOM}. No seed is
 * set here, so values presumably vary between runs unless the caller seeds the RNG.
 *
 * @return a Bayes IM with randomly initialized parameters
 */
private static BayesIm sampleBayesIm2() {
    Node nodeA = new GraphNode("a");
    Node nodeB = new GraphNode("b");
    Node nodeC = new GraphNode("c");

    Dag dag = new Dag();
    for (Node node : new Node[] {nodeA, nodeB, nodeC}) {
        dag.addNode(node);
    }

    // a is a parent of both b and c; b is also a parent of c.
    dag.addDirectedEdge(nodeA, nodeB);
    dag.addDirectedEdge(nodeA, nodeC);
    dag.addDirectedEdge(nodeB, nodeC);

    BayesPm pm = new BayesPm(dag);
    pm.setNumCategories(nodeB, 3);

    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    return im;
}
Usage of edu.cmu.tetrad.graph.Dag in the tetrad project (cmu-phil).
Example from class TestDagScorer, method test1.
/**
 * Regression test for {@link DagScorer}.
 *
 * <p>It simulates data from a seeded random linear SEM over 10 continuous variables and
 * checks two things against previously recorded values:
 * <ul>
 *   <li>the fit score of the SEM estimated from that data;</li>
 *   <li>the score, BIC, degrees of freedom and free-parameter count reported by
 *       {@link DagScorer} for the same DAG.</li>
 * </ul>
 */
@Test
public void test1() {
    // Fixed seed so the random graph, SEM parameters and simulated data are reproducible.
    RandomUtil.getInstance().setSeed(492839483L);

    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }

    // NOTE(review): meaning of the randomGraph arguments (0, 10, 30, 15, 15, false) is
    // presumably latents / max edges / degree limits / connectivity — confirm in GraphUtils.
    Graph dag = new Dag(GraphUtils.randomGraph(nodes, 0, 10, 30, 15, 15, false));
    SemPm pm = new SemPm(dag);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);

    // Fit the SEM to the simulated data and compare its score with the recorded value.
    SemEstimator est = new SemEstimator(data, pm);
    SemIm estSem = est.estimate();
    double fml = estSem.getScore();
    assertEquals(0.0369, fml, 0.001);

    // replaceNodes returns a graph over the data set's variable objects. Its result must be
    // assigned; a bare call with the result discarded has no effect on dag.
    dag = GraphUtils.replaceNodes(dag, data.getVariables());

    Scorer scorer = new DagScorer(data);
    double _fml = scorer.score(dag);
    assertEquals(0.0358, _fml, 0.001);

    double bicScore = scorer.getBicScore();
    assertEquals(-205, bicScore, 1);

    // The expected values are consistent with dof = 10 * 11 / 2 distinct covariance entries
    // (55) minus 20 free parameters.
    int dof = scorer.getDof();
    assertEquals(35, dof);

    int numFreeParams = scorer.getNumFreeParams();
    assertEquals(20, numFreeParams);
}
Aggregations