use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class TestPcStableMax method testPcStable2.
@Test
public void testPcStable2() {
RandomUtil.getInstance().setSeed(1450030184196L);
List<Node> nodes = new ArrayList<>();
for (int i = 0; i < 10; i++) {
nodes.add(new ContinuousVariable("X" + (i + 1)));
}
Graph graph = GraphUtils.randomGraph(nodes, 0, 10, 30, 15, 15, false);
SemPm pm = new SemPm(graph);
SemIm im = new SemIm(pm);
DataSet data = im.simulateData(200, false);
TetradLogger.getInstance().setForceLog(false);
IndependenceTest test = new IndTestFisherZ(data, 0.05);
PcStableMax pc = new PcStableMax(test);
pc.setVerbose(false);
Graph pattern = pc.search();
for (int i = 0; i < 1; i++) {
DataSet data2 = DataUtils.reorderColumns(data);
IndependenceTest test2 = new IndTestFisherZ(data2, 0.05);
PcStableMax pc2 = new PcStableMax(test2);
pc2.setVerbose(false);
Graph pattern2 = pc2.search();
assertTrue(pattern.equals(pattern2));
}
}
use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class GraphWithParameters method regress.
/**
* creates a PatternWithParameters by running a regression, given a graph and data
*/
public static GraphWithParameters regress(DataSet dataSet, Graph graph) {
SemPm semPmEstDag = new SemPm(graph);
SemEstimator estimatorEstDag = new SemEstimator(dataSet, semPmEstDag);
estimatorEstDag.estimate();
SemIm semImEstDag = estimatorEstDag.getEstimatedSem();
GraphWithParameters estimatedGraph = new GraphWithParameters(semImEstDag, graph);
return estimatedGraph;
}
use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class TestDM method test5.
// Three latent collider case
@Test
public void test5() {
// setting seed for debug.
RandomUtil.getInstance().setSeed(29483818483L);
Graph graph = emptyGraph(6);
graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X3"));
graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X4"));
graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X4"));
graph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("X4"));
graph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("X5"));
SemPm pm = new SemPm(graph);
SemIm im = new SemIm(pm);
DataSet data = im.simulateData(100000, false);
DMSearch search = new DMSearch();
search.setInputs(new int[] { 0, 1, 2 });
search.setOutputs(new int[] { 3, 4, 5 });
search.setData(data);
search.setTrueInputs(search.getInputs());
Graph foundGraph = search.search();
print("Three Latent Collider Case");
Graph trueGraph = new EdgeListGraph();
trueGraph.addNode(new ContinuousVariable("X0"));
trueGraph.addNode(new ContinuousVariable("X1"));
trueGraph.addNode(new ContinuousVariable("X2"));
trueGraph.addNode(new ContinuousVariable("X3"));
trueGraph.addNode(new ContinuousVariable("X4"));
trueGraph.addNode(new ContinuousVariable("X5"));
trueGraph.addNode(new ContinuousVariable("L0"));
trueGraph.addNode(new ContinuousVariable("L1"));
trueGraph.addNode(new ContinuousVariable("L2"));
trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L0"));
trueGraph.addDirectedEdge(new ContinuousVariable("L0"), new ContinuousVariable("X3"));
trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L1"));
trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L1"));
trueGraph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("L1"));
trueGraph.addDirectedEdge(new ContinuousVariable("L1"), new ContinuousVariable("X4"));
trueGraph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("L2"));
trueGraph.addDirectedEdge(new ContinuousVariable("L2"), new ContinuousVariable("X5"));
assertTrue(foundGraph.equals(trueGraph));
}
use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class TestDM method test4.
// Three latent fork case
@Test
public void test4() {
// setting seed for debug.
RandomUtil.getInstance().setSeed(29483818483L);
Graph graph = emptyGraph(6);
graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X3"));
graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X3"));
graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X4"));
graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X5"));
graph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("X5"));
SemPm pm = new SemPm(graph);
SemIm im = new SemIm(pm);
DataSet data = im.simulateData(100000, false);
DMSearch search = new DMSearch();
search.setInputs(new int[] { 0, 1, 2 });
search.setOutputs(new int[] { 3, 4, 5 });
search.setData(data);
search.setTrueInputs(search.getInputs());
Graph foundGraph = search.search();
print("Three Latent Fork Case");
Graph trueGraph = new EdgeListGraph();
trueGraph.addNode(new ContinuousVariable("X0"));
trueGraph.addNode(new ContinuousVariable("X1"));
trueGraph.addNode(new ContinuousVariable("X2"));
trueGraph.addNode(new ContinuousVariable("X3"));
trueGraph.addNode(new ContinuousVariable("X4"));
trueGraph.addNode(new ContinuousVariable("X5"));
trueGraph.addNode(new ContinuousVariable("L0"));
trueGraph.addNode(new ContinuousVariable("L1"));
trueGraph.addNode(new ContinuousVariable("L2"));
trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L0"));
trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L0"));
trueGraph.addDirectedEdge(new ContinuousVariable("L0"), new ContinuousVariable("X3"));
trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L1"));
trueGraph.addDirectedEdge(new ContinuousVariable("L1"), new ContinuousVariable("X4"));
trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L2"));
trueGraph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("L2"));
trueGraph.addDirectedEdge(new ContinuousVariable("L2"), new ContinuousVariable("X5"));
assertTrue(foundGraph.equals(trueGraph));
}
use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class TestDM method test6.
// Four latent case.
@Test
public void test6() {
// setting seed for debug.
RandomUtil.getInstance().setSeed(29483818483L);
Graph graph = emptyGraph(8);
graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X4"));
graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X5"));
graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X6"));
graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X7"));
graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X5"));
graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X6"));
graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X7"));
graph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("X6"));
graph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("X7"));
graph.addDirectedEdge(new ContinuousVariable("X3"), new ContinuousVariable("X7"));
SemPm pm = new SemPm(graph);
SemIm im = new SemIm(pm);
DataSet data = im.simulateData(1000, false);
DMSearch search = new DMSearch();
search.setInputs(new int[] { 0, 1, 2, 3 });
search.setOutputs(new int[] { 4, 5, 6, 7 });
search.setData(data);
search.setTrueInputs(search.getInputs());
Graph foundGraph = search.search();
print("Four Latent Case");
print("search.getDmStructure().latentStructToEdgeListGraph(search.getDmStructure())");
Graph trueGraph = new EdgeListGraph();
trueGraph.addNode(new ContinuousVariable("X0"));
trueGraph.addNode(new ContinuousVariable("X1"));
trueGraph.addNode(new ContinuousVariable("X2"));
trueGraph.addNode(new ContinuousVariable("X3"));
trueGraph.addNode(new ContinuousVariable("X4"));
trueGraph.addNode(new ContinuousVariable("X5"));
trueGraph.addNode(new ContinuousVariable("X6"));
trueGraph.addNode(new ContinuousVariable("X7"));
trueGraph.addNode(new ContinuousVariable("L0"));
trueGraph.addNode(new ContinuousVariable("L1"));
trueGraph.addNode(new ContinuousVariable("L2"));
trueGraph.addNode(new ContinuousVariable("L3"));
trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L0"));
trueGraph.addDirectedEdge(new ContinuousVariable("L0"), new ContinuousVariable("X4"));
trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L1"));
trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L1"));
trueGraph.addDirectedEdge(new ContinuousVariable("L1"), new ContinuousVariable("X5"));
trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L2"));
trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L2"));
trueGraph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("L2"));
trueGraph.addDirectedEdge(new ContinuousVariable("L2"), new ContinuousVariable("X6"));
trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L3"));
trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L3"));
trueGraph.addDirectedEdge(new ContinuousVariable("X2"), new ContinuousVariable("L3"));
trueGraph.addDirectedEdge(new ContinuousVariable("X3"), new ContinuousVariable("L3"));
trueGraph.addDirectedEdge(new ContinuousVariable("L3"), new ContinuousVariable("X7"));
assertTrue(foundGraph.equals(trueGraph));
}
Aggregations