Use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.
Example from class TestCptInvariantUpdater, method testUpdate4.
/**
 * Checks that {@link CptInvariantUpdater} and {@link RowSummingExactUpdater} report the same
 * marginal for X3 = category 0 after conditioning on X2 = category 0, using a randomly
 * parameterized Bayes IM over the graph X0 -> X1 -> X3, X0 -> X2 -> X3.
 */
@Test
public void testUpdate4() {
    Node x0Node = new GraphNode("X0");
    Node x1Node = new GraphNode("X1");
    Node x2Node = new GraphNode("X2");
    Node x3Node = new GraphNode("X3");

    Dag graph = new Dag();
    graph.addNode(x0Node);
    graph.addNode(x1Node);
    graph.addNode(x2Node);
    graph.addNode(x3Node);
    graph.addDirectedEdge(x0Node, x1Node);
    graph.addDirectedEdge(x0Node, x2Node);
    graph.addDirectedEdge(x1Node, x3Node);
    graph.addDirectedEdge(x2Node, x3Node);

    BayesPm bayesPm = new BayesPm(graph);
    MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);

    // Only X2 (evidence) and X3 (query) need IM indices.
    int x2 = bayesIm.getNodeIndex(x2Node);
    int x3 = bayesIm.getNodeIndex(x3Node);

    // Start from Evidence.tautology and fix X2 to category 0.
    Evidence evidence = Evidence.tautology(bayesIm);
    evidence.getProposition().setCategory(x2, 0);

    // The same evidence object is handed to both updaters so they answer the same query.
    BayesUpdater updater1 = new CptInvariantUpdater(bayesIm);
    updater1.setEvidence(evidence);
    BayesUpdater updater2 = new RowSummingExactUpdater(bayesIm);
    updater2.setEvidence(evidence);

    double marginal1 = updater1.getMarginal(x3, 0);
    double marginal2 = updater2.getMarginal(x3, 0);

    double tolerance = 0.000001;
    assertEquals(marginal1, marginal2, tolerance);
}
Use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.
Example from class TestCptInvariantUpdater, method sampleBayesIm1.
/**
 * Builds a two-node Bayes IM for the graph x -> z with fixed probability tables.
 *
 * <p>Values written (row, column -> probability):
 * <ul>
 *   <li>x: row 0 = (.3, .7)</li>
 *   <li>z: row 0 = (.8, .2), row 1 = (.4, .6) — presumably one row per category of the
 *       parent x; confirm against MlBayesIm's row ordering.</li>
 * </ul>
 *
 * @return the parameterized Bayes IM
 */
private BayesIm sampleBayesIm1() {
    Node x = new GraphNode("x");
    Node z = new GraphNode("z");

    Dag graph = new Dag();
    graph.addNode(x);
    graph.addNode(z);
    graph.addDirectedEdge(x, z);

    BayesPm bayesPm = new BayesPm(graph);
    BayesIm bayesIm1 = new MlBayesIm(bayesPm);

    // Look the indices up instead of assuming x is node 0 and z is node 1 in the IM.
    int xIndex = bayesIm1.getNodeIndex(x);
    int zIndex = bayesIm1.getNodeIndex(z);

    bayesIm1.setProbability(xIndex, 0, 0, .3);
    bayesIm1.setProbability(xIndex, 0, 1, .7);

    bayesIm1.setProbability(zIndex, 0, 0, .8);
    bayesIm1.setProbability(zIndex, 0, 1, .2);
    bayesIm1.setProbability(zIndex, 1, 0, .4);
    bayesIm1.setProbability(zIndex, 1, 1, .6);

    return bayesIm1;
}
Use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.
Example from class TestDag, method checkCopy.
/**
 * Asserts that copying {@code graph} through the {@link Dag} copy constructor produces a
 * graph equal to the original.
 *
 * @param graph the graph to copy and compare against
 */
private void checkCopy(Graph graph) {
    assertEquals(graph, new Dag(graph));
}
Use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.
Example from class TestDataLoadersRoundtrip, method testDiscreteRoundtrip.
/**
 * Simulates a small data set from a random Bayes IM and writes it tab-delimited to
 * {@code target/test_data/roundtrip.dat}. It then reads the file back with the original
 * variables supplied as known variables and asserts the result equals the simulated data.
 */
@Test
public void testDiscreteRoundtrip() {
    setUp();
    try {
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            nodes.add(new ContinuousVariable("X" + (i + 1)));
        }
        Dag dag = new Dag(GraphUtils.randomGraph(nodes, 0, 8, 30, 15, 15, false));

        BayesPm bayesPm1 = new BayesPm(dag);
        MlBayesIm bayesIm1 = new MlBayesIm(bayesPm1, MlBayesIm.RANDOM);
        DataSet dataSet = bayesIm1.simulateData(10, false);

        File file = new File("target/test_data/roundtrip.dat");
        // mkdirs() (unlike mkdir()) also creates "target" when it does not exist yet.
        file.getParentFile().mkdirs();

        // try-with-resources closes the file even if writing fails. Writing directly to the
        // FileWriter (not a PrintWriter) lets I/O errors surface instead of being swallowed.
        try (Writer writer = new FileWriter(file)) {
            DataWriter.writeRectangularData(dataSet, writer, '\t');
        }

        DataReader reader = new DataReader();
        // Supply the simulated variables so the reader presumably reuses them rather than
        // inferring column types from the text — confirm against DataReader.
        reader.setKnownVariables(dataSet.getVariables());
        DataSet readBack = reader.parseTabular(file);

        assertTrue(dataSet.equals(readBack));
    } catch (IOException e) {
        e.printStackTrace();
        fail(e.getMessage());
    }
}
Use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.
Example from class TestDataLoadersRoundtrip, method testContinuousRoundtrip.
/**
 * Simulates a small data set from a SEM over a random DAG and writes it comma-delimited to
 * {@code target/test_data/roundtrip.dat}. It then reads the file back with a comma delimiter
 * and asserts the result equals the simulated data.
 */
@Test
public void testContinuousRoundtrip() {
    setUp();
    try {
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            nodes.add(new ContinuousVariable("X" + (i + 1)));
        }
        Graph randomGraph = new Dag(GraphUtils.randomGraph(nodes, 0, 5, 30, 15, 15, false));

        SemPm semPm1 = new SemPm(randomGraph);
        SemIm semIm1 = new SemIm(semPm1);
        DataSet dataSet = semIm1.simulateData(10, false);

        File file = new File("target/test_data/roundtrip.dat");
        // The output directory must exist BEFORE opening the writer; otherwise FileWriter
        // throws when this test runs first or alone.
        file.getParentFile().mkdirs();

        // try-with-resources closes the file even if writing fails. Writing directly to the
        // FileWriter (not a PrintWriter) lets I/O errors surface instead of being swallowed.
        try (Writer writer = new FileWriter(file)) {
            DataWriter.writeRectangularData(dataSet, writer, ',');
        }

        DataReader reader = new DataReader();
        reader.setDelimiter(DelimiterType.COMMA);
        DataSet readBack = reader.parseTabular(file);

        assertTrue(dataSet.equals(readBack));
    } catch (IOException e) {
        e.printStackTrace();
        fail(e.getMessage());
    }
}
Aggregations