Usage of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class TestBayesIm, method testCopyConstructor:
/**
 * A randomly parameterized MlBayesIm, copied through the copy constructor,
 * must compare equal to its source.
 */
@Test
public void testCopyConstructor() {
    Graph diamond = GraphConverter.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
    BayesIm original = new MlBayesIm(new BayesPm(new Dag(diamond)), MlBayesIm.RANDOM);

    // Copy construction should preserve every parameter of the source IM.
    BayesIm copy = new MlBayesIm(original);
    assertEquals(original, copy);
}
Usage of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class TestBayesIm, method testAddRemoveParent:
/**
 * Tests whether the BayesIm does the right thing in a very simple case
 * where nodes are added to or removed from the graph. Start with graph a -> b,
 * parameterizing with two values for each node. Construct and fill in
 * probability tables in a BayesIm. Then add the edge c -> b "manually." This
 * should create a table of values for c that is unspecified, and it should
 * double up the rows for b. Then remove the node c. Now the table for b
 * should be completely unspecified.
 *
 * <p>NOTE: the row-duplication checks for b are currently disabled (kept
 * commented out below); only the "unspecified table" expectations are asserted.
 */
@Test
public void testAddRemoveParent() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Graph dag = new EdgeListGraph();
    dag.addNode(a);
    dag.addNode(b);
    dag.addDirectedEdge(a, b);

    BayesPm bayesPm = new BayesPm(dag);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    // Building a MANUAL IM from an existing IM over the same PM should reproduce it.
    BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.MANUAL);
    assertEquals(bayesIm, bayesIm2);

    // Add the new parent c -> b and rebuild the PM/IM from the previous ones.
    Node c = new GraphNode("c");
    dag.addNode(c);
    dag.addDirectedEdge(c, b);
    BayesPm bayesPm3 = new BayesPm(dag, bayesPm);
    BayesIm bayesIm3 = new MlBayesIm(bayesPm3, bayesIm2, MlBayesIm.MANUAL);

    // Make sure the rows got repeated (disabled).
    // assertTrue(rowsEqual(bayesIm3, bayesIm3.getNodeIndex(b), 0, 1));
    // assertTrue(!rowsEqual(bayesIm3, bayesIm3.getNodeIndex(b), 1, 2));
    // assertTrue(rowsEqual(bayesIm3, bayesIm3.getNodeIndex(b), 2, 3));

    // Make sure the new node c got '?'s.
    assertTrue(rowUnspecified(bayesIm3, bayesIm3.getNodeIndex(c), 0));

    // Remove c again and rebuild.
    dag.removeNode(c);
    BayesPm bayesPm4 = new BayesPm(dag, bayesPm3);
    BayesIm bayesIm4 = new MlBayesIm(bayesPm4, bayesIm3, MlBayesIm.MANUAL);

    // Make sure node b has exactly 2 rows, and every one of them is '?'.
    int bIndex = bayesIm4.getNodeIndex(b);
    assertEquals(2, bayesIm4.getNumRows(bIndex));
    assertTrue(rowUnspecified(bayesIm4, bIndex, 0));
    assertTrue(rowUnspecified(bayesIm4, bIndex, 1));
}
Usage of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class TestBayesIm, method testConstructManual:
/**
 * An MlBayesIm built from a PM (single-argument constructor) must expose a DAG
 * that, once its nodes are swapped back for the source graph's nodes, equals
 * the source graph.
 */
@Test
public void testConstructManual() {
    Graph source = GraphConverter.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
    Graph asDag = new Dag(source);
    BayesIm manualIm = new MlBayesIm(new BayesPm(asDag));

    Graph imDag = manualIm.getBayesPm().getDag();
    // Re-map onto the source graph's node objects so the comparison is structural.
    Graph relabeled = GraphUtils.replaceNodes(imDag, source.getNodes());
    assertEquals(relabeled, source);
}
Usage of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class TestCellProbabilities, method testCreateUsingBayesIm:
/**
 * With a fixed random seed, the cell table derived from a randomly
 * parameterized MlBayesIm must give the all-zero-index cell the expected
 * probability.
 */
@Test
public void testCreateUsingBayesIm() {
    // The seed must be set before any random parameters are drawn.
    RandomUtil.getInstance().setSeed(4828385834L);

    Graph diamond = GraphConverter.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
    Dag diamondDag = new Dag(diamond);
    BayesIm randomIm = new MlBayesIm(new BayesPm(diamondDag), MlBayesIm.RANDOM);

    StoredCellProbs table = StoredCellProbs.createCellTable(randomIm);
    int[] firstCell = new int[4]; // Java zero-initializes, i.e. {0, 0, 0, 0}
    double cellProbability = table.getCellProb(firstCell);
    assertEquals(0.0058, cellProbability, 0.0001);
}
Usage of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class TestDataLoadersRoundtrip, method testDiscreteRoundtrip:
/**
 * Simulates discrete data from a randomly parameterized MlBayesIm over a random
 * 5-node DAG, writes it as tab-delimited text, reads it back with the variables
 * declared as known, and checks that the data set survives the round trip.
 */
@Test
public void testDiscreteRoundtrip() {
    setUp();
    try {
        for (int i = 0; i < 1; i++) {
            List<Node> nodes = new ArrayList<>();
            for (int j = 0; j < 5; j++) {
                nodes.add(new ContinuousVariable("X" + (j + 1)));
            }
            Graph randomGraph = new Dag(GraphUtils.randomGraph(nodes, 0, 8, 30, 15, 15, false));
            Dag dag = new Dag(randomGraph);
            BayesPm bayesPm1 = new BayesPm(dag);
            MlBayesIm bayesIm1 = new MlBayesIm(bayesPm1, MlBayesIm.RANDOM);
            DataSet dataSet = bayesIm1.simulateData(10, false);

            // mkdirs (not mkdir) so this also works when "target" itself does not exist yet.
            new File("target/test_data").mkdirs();
            File file = new File("target/test_data/roundtrip.dat");

            // try-with-resources closes the file even if writing throws. PrintWriter
            // swallows IOExceptions internally, so its error flag must be checked
            // explicitly; otherwise a failed write would surface only as a confusing
            // equality failure below.
            try (PrintWriter writer = new PrintWriter(new FileWriter(file))) {
                DataWriter.writeRectangularData(dataSet, writer, '\t');
                if (writer.checkError()) {
                    fail("I/O error while writing " + file);
                }
            }

            DataReader reader = new DataReader();
            reader.setKnownVariables(dataSet.getVariables());
            DataSet readBack = reader.parseTabular(file);
            assertTrue(dataSet.equals(readBack));
        }
    } catch (IOException e) {
        // Keep the original exception as the cause so the failure report shows where it came from.
        throw new AssertionError("Round trip failed with an I/O error", e);
    }
}
Aggregations