Usage of edu.cmu.tetrad.sem.SemPm in the tetrad project (cmu-phil).
Class TestPurify, method test1.
@Test
public void test1() {
    // Fixed seed so the simulated data, and therefore the purification result, is repeatable.
    RandomUtil.getInstance().setSeed(48290483L);
    SemGraph graph = new SemGraph();

    // Three latent factors L1..L3.
    Node[] latents = new Node[3];
    for (int k = 0; k < latents.length; k++) {
        latents[k] = new GraphNode("L" + (k + 1));
        latents[k].setNodeType(NodeType.LATENT);
    }
    Node l1 = latents[0];
    Node l2 = latents[1];
    Node l3 = latents[2];

    // Measured indicators.
    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");
    Node x4b = new GraphNode("X4b");
    Node x5 = new GraphNode("X5");
    Node x6 = new GraphNode("X6");
    Node x7 = new GraphNode("X7");
    Node x8 = new GraphNode("X8");
    Node x8b = new GraphNode("X8b");
    Node x9 = new GraphNode("X9");
    Node x10 = new GraphNode("X10");
    Node x11 = new GraphNode("X11");
    Node x12 = new GraphNode("X12");
    Node x12b = new GraphNode("X12b");

    for (Node node : new Node[] {l1, l2, l3, x1, x2, x3, x4, x4b, x5,
            x6, x7, x8, x8b, x9, x10, x11, x12, x12b}) {
        graph.addNode(node);
    }

    // Measurement model. Two impurities are deliberate:
    // X5 loads on both L1 and L2, and X1 -> X4 is a direct edge between indicators.
    Node[][] edges = {
            {l1, x1}, {l1, x2}, {l1, x3}, {l1, x4}, {l1, x4b}, {l1, x5},
            {l2, x5}, {l2, x6}, {l2, x7}, {l2, x8}, {l2, x8b},
            {l3, x9}, {l3, x10}, {l3, x11}, {l3, x12}, {l3, x12b},
            {x1, x4}
    };
    for (Node[] edge : edges) {
        graph.addDirectedEdge(edge[0], edge[1]);
    }

    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    DataSet simulated = semIm.simulateData(1000, false);

    // Initial (impure) clustering: X5 deliberately appears in both of the first two clusters.
    List<Node> firstCluster = new ArrayList<>();
    for (Node node : new Node[] {x1, x2, x3, x4, x4b, x5}) {
        firstCluster.add(node);
    }
    List<Node> secondCluster = new ArrayList<>();
    for (Node node : new Node[] {x5, x6, x7, x8, x8b}) {
        secondCluster.add(node);
    }
    List<Node> thirdCluster = new ArrayList<>();
    for (Node node : new Node[] {x9, x10, x11, x12, x12b}) {
        thirdCluster.add(node);
    }
    List<List<Node>> clusters = new ArrayList<>();
    clusters.add(firstCluster);
    clusters.add(secondCluster);
    clusters.add(thirdCluster);

    TetradTest tetradTest = new ContinuousTetradTest(simulated, TestType.TETRAD_WISHART, 0.05);
    IPurify purifier = new PurifyTetradBased2(tetradTest);
    purifier.setTrueGraph(graph);
    List<List<Node>> purified = purifier.purify(clusters);

    // The two impure clusters shrink; the pure third cluster keeps all five indicators.
    assertEquals(3, purified.get(0).size());
    assertEquals(2, purified.get(1).size());
    assertEquals(5, purified.get(2).size());
}
Usage of edu.cmu.tetrad.sem.SemPm in the tetrad project (cmu-phil).
Class TestRicf, method test2.
@Test
public void test2() {
    // Ten continuous variables named X1..X10.
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 10; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 10, 30, 15, 15, false));
    SemIm semIm = new SemIm(new SemPm(dag));
    CovarianceMatrix covariance = new CovarianceMatrix(semIm.simulateData(1000, false));

    // Smoke test: fit RICF on the generating graph and touch the coefficient estimate.
    // There is no assertion; the test only checks that nothing throws.
    Ricf.RicfResult fit = new Ricf().ricf(new SemGraph(dag), covariance, 0.001);
    fit.getBhat();
}
Usage of edu.cmu.tetrad.sem.SemPm in the tetrad project (cmu-phil).
Class TestRicf, method test4.
@Test
public void test4() {
    // First graph: 5 variables, no edges requested.
    List<Node> firstVariables = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        firstVariables.add(new ContinuousVariable("X" + k));
    }
    Graph g1 = GraphUtils.randomGraph(firstVariables, 0, 5, 0, 0, 0, false);

    // Second graph: same names but freshly constructed node objects.
    List<Node> secondVariables = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        secondVariables.add(new ContinuousVariable("X" + k));
    }
    Graph g2 = GraphUtils.randomGraph(secondVariables, 0, 5, 0, 0, 0, false);

    // Simulate from g1 only; both graphs are then fit against that covariance.
    SemIm semIm = new SemIm(new SemPm(g1));
    ICovarianceMatrix covariance = new CovarianceMatrix(semIm.simulateData(1000, false));

    // Smoke test: neither fit should throw.
    new Ricf().ricf(new SemGraph(g1), covariance, 0.001);
    new Ricf().ricf(new SemGraph(g2), covariance, 0.001);
}
Usage of edu.cmu.tetrad.sem.SemPm in the tetrad project (cmu-phil).
Class TestIndTestWaldLR, method testIsIndependent.
/**
 * Checks the multinomial logistic regression Wald test against a d-separation oracle.
 *
 * <p>Each trial builds a random 10-variable graph and simulates 1000 rows from it. It then
 * discretizes two columns into 2 equal-count bins, producing mixed data. Finally it asks
 * whether X2 is independent of X1 given {X3, X4, X5}, both of the Wald test (on the data)
 * and of {@code IndTestDSep} (on the true graph). A trial passes when the two agree.
 */
@Test
public void testIsIndependent() {
    // Fixed seed; presumably RandomUtil drives graph generation and simulation, making
    // the per-trial outcomes reproducible.
    RandomUtil.getInstance().setSeed(1450705713157L);
    final int numTrials = 10;
    int numPassed = 0;
    for (int trial = 0; trial < numTrials; trial++) {
        List<Node> nodes = new ArrayList<>();
        for (int i1 = 0; i1 < 10; i1++) {
            nodes.add(new ContinuousVariable("X" + (i1 + 1)));
        }
        Graph graph = GraphUtils.randomGraph(nodes, 0, 10, 3, 3, 3, false);
        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);

        // Make the data mixed: columns 0 and 3 become 2-category variables.
        Discretizer discretizer = new Discretizer(data);
        discretizer.setVariablesCopied(true);
        discretizer.equalCounts(data.getVariable(0), 2);
        discretizer.equalCounts(data.getVariable(3), 2);
        data = discretizer.discretize();

        Node x1 = data.getVariable("X1");
        Node x2 = data.getVariable("X2");
        Node x3 = data.getVariable("X3");
        Node x4 = data.getVariable("X4");
        Node x5 = data.getVariable("X5");
        List<Node> cond = new ArrayList<>();
        cond.add(x3);
        cond.add(x4);
        cond.add(x5);

        // The d-separation oracle must be queried with the graph's own node objects,
        // so look them up by name.
        Node x1Graph = graph.getNode(x1.getName());
        Node x2Graph = graph.getNode(x2.getName());
        List<Node> condGraph = new ArrayList<>();
        for (Node node : cond) {
            condGraph.add(graph.getNode(node.getName()));
        }

        // Using the Wald LR test since it's most up to date.
        IndependenceTest test = new IndTestMultinomialLogisticRegressionWald(data, 0.05, false);
        IndTestDSep dsep = new IndTestDSep(graph);
        boolean correct = test.isIndependent(x2, x1, cond) == dsep.isIndependent(x2Graph, x1Graph, condGraph);
        if (correct) {
            numPassed++;
        }
    }
    // Not every seed yields agreement on all trials. The seed above is pinned so that
    // every trial agrees with the d-separation oracle.
    assertEquals(numTrials, numPassed);
}
Usage of edu.cmu.tetrad.sem.SemPm in the tetrad project (cmu-phil).
Class TestLingamPattern, method test1.
/**
 * Regression test for {@code LingamPattern} p-values on a small non-Gaussian SEM.
 *
 * <p>The test simulates 1000 rows from a random 6-variable DAG with per-variable error
 * distributions. It estimates a pattern with FGES, runs LingamPattern on it, and compares
 * the reported p-values to pinned expectations within 0.01.
 */
@Test
public void test1() {
    // Fixed seed so the simulated data and the pinned p-values are reproducible.
    RandomUtil.getInstance().setSeed(4938492L);
    int sampleSize = 1000;
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 6; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 6, 4, 4, 4, false));

    // One distribution per variable (presumably in node order -- confirm against
    // simulateDataNonNormal). Only the fourth entry is non-Gaussian.
    List<Distribution> variableDistributions = new ArrayList<>();
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Uniform(-1, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));

    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    DataSet dataSet = simulateDataNonNormal(semIm, sampleSize, variableDistributions);
    Score score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
    Graph estPattern = new Fges(score).search();
    LingamPattern lingam = new LingamPattern(estPattern, dataSet);
    lingam.search();
    double[] pvals = lingam.getPValues();
    double[] expectedPVals = { 0.18, 0.29, 0.88, 0.00, 0.01, 0.58 };

    // Compare lengths first. Otherwise extra p-values would crash with an
    // ArrayIndexOutOfBoundsException, and missing ones would be silently skipped.
    assertEquals(expectedPVals.length, pvals.length);
    for (int i = 0; i < pvals.length; i++) {
        assertEquals(expectedPVals[i], pvals[i], 0.01);
    }
}
Aggregations