Use of `edu.cmu.tetrad.sem.SemIm` in project tetrad by cmu-phil.
Class `TestDeltaSextadTest`, method `test2`.
/**
 * Checks that the delta sextad p-value does not depend on the order of the indices within
 * the first triple or within the second triple of a sextad.
 *
 * <p>Builds {@code c} clusters of {@code p} measured variables each. Every one of the {@code m}
 * latents of a cluster points to every measured variable in that cluster, and the j-th latent
 * of each cluster points to the j-th latent of the next cluster. A SEM on that graph is used to
 * simulate data, and the p-values of two reorderings of the same sextad are compared.
 */
@Test
public void test2() {
    int c = 2; // number of clusters
    int m = 2; // latents per cluster
    int p = 6; // measured variables per cluster

    Graph g = new EdgeListGraph();
    List<List<Node>> varClusters = new ArrayList<>();
    List<List<Node>> latents = new ArrayList<>();
    for (int y = 0; y < c; y++) {
        varClusters.add(new ArrayList<Node>());
        latents.add(new ArrayList<Node>());
    }

    // Measured variables V1..V(c*p), grouped by cluster.
    int e = 0;
    for (int y = 0; y < c; y++) {
        for (int i = 0; i < p; i++) {
            GraphNode n = new GraphNode("V" + ++e);
            varClusters.get(y).add(n);
            g.addNode(n);
        }
    }

    // Latent variables L1..L(c*m), grouped by cluster.
    int f = 0;
    for (int y = 0; y < c; y++) {
        for (int j = 0; j < m; j++) {
            Node latent = new GraphNode("L" + ++f);
            latent.setNodeType(NodeType.LATENT);
            latents.get(y).add(latent);
            g.addNode(latent);
        }
    }

    // Every latent of a cluster is a parent of every measured variable in that cluster.
    for (int y = 0; y < c; y++) {
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                g.addDirectedEdge(latents.get(y).get(i), varClusters.get(y).get(j));
            }
        }
    }

    // Chain the j-th latent of each cluster to the j-th latent of the next cluster.
    for (int y = 1; y < c; y++) {
        for (int j = 0; j < m; j++) {
            g.addDirectedEdge(latents.get(y - 1).get(j), latents.get(y).get(j));
        }
    }

    SemPm pm = new SemPm(g);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);

    // Column indices of the sextad. NOTE(review): these presumably index the measured columns
    // in V1..V12 order (0-5 first cluster, 6-11 second); confirm against simulateData.
    List<Integer> indices = new ArrayList<>();
    indices.add(0);
    indices.add(1);
    indices.add(2);
    indices.add(4);
    indices.add(5);
    indices.add(7);
    Collections.shuffle(indices);

    int x1 = indices.get(0);
    int x2 = indices.get(1);
    int x3 = indices.get(2);
    int x4 = indices.get(3);
    int x5 = indices.get(4);
    int x6 = indices.get(5);

    DeltaSextadTest test = new DeltaSextadTest(data);

    // Should be invariant to changes of order of the first three or of the last three variables.
    double a = test.getPValue(new IntSextad(x1, x2, x3, x4, x5, x6));
    double b = test.getPValue(new IntSextad(x2, x3, x1, x5, x4, x6));
    assertEquals(a, b, 1e-7);
}
Use of `edu.cmu.tetrad.sem.SemIm` in project tetrad by cmu-phil.
Class `TestDeltaSextadTest`, method `testBollenExample1`.
// Bollen and Ting, Confirmatory Tetrad Analysis, p. 164 Sympathy and Anger.
@Test
public void testBollenExample1() {
SemIm sem = getSem1();
DataSet data = sem.simulateData(3000, false);
List<Node> variables = data.getVariables();
int m1 = 0;
int m2 = 1;
int m3 = 2;
int m4 = 3;
int m5 = 4;
int m6 = 5;
IntSextad t1 = new IntSextad(m1, m2, m3, m4, m5, m6);
IntSextad t2 = new IntSextad(m1, m2, m4, m3, m5, m6);
IntSextad t3 = new IntSextad(m1, m2, m5, m3, m4, m6);
IntSextad t4 = new IntSextad(m1, m2, m6, m3, m4, m5);
IntSextad t5 = new IntSextad(m1, m3, m4, m2, m5, m6);
IntSextad t6 = new IntSextad(m1, m3, m5, m2, m4, m6);
IntSextad t7 = new IntSextad(m1, m3, m6, m2, m4, m5);
IntSextad t8 = new IntSextad(m1, m4, m5, m2, m3, m6);
IntSextad t9 = new IntSextad(m1, m4, m6, m2, m3, m5);
IntSextad t10 = new IntSextad(m1, m5, m6, m2, m3, m4);
List<IntSextad> sextads = new ArrayList<>();
sextads.add(t1);
sextads.add(t2);
sextads.add(t3);
sextads.add(t4);
sextads.add(t5);
sextads.add(t6);
sextads.add(t7);
sextads.add(t8);
sextads.add(t9);
sextads.add(t10);
DeltaSextadTest test = new DeltaSextadTest(data);
int numSextads = 3;
double alpha = 0.001;
ChoiceGenerator gen = new ChoiceGenerator(sextads.size(), numSextads);
int[] choice;
while ((choice = gen.next()) != null) {
IntSextad[] _sextads = new IntSextad[numSextads];
for (int i = 0; i < numSextads; i++) {
_sextads[i] = sextads.get(choice[i]);
}
double p = test.getPValue(_sextads);
}
}
Use of `edu.cmu.tetrad.sem.SemIm` in project tetrad by cmu-phil.
Class `TestDeltaSextadTest`, method `getSem1`.
/**
 * Builds an instantiated SEM over a graph with two latent nodes, {@code l1} and {@code l2},
 * each of which has a directed edge into every one of the eight measured nodes X1..X8.
 * Parameter values come from a default {@code Parameters} object.
 */
private SemIm getSem1() {
    final int numMeasures = 8;

    Graph graph = new EdgeListGraph();

    Node l1 = new GraphNode("l1");
    Node l2 = new GraphNode("l2");
    l1.setNodeType(NodeType.LATENT);
    l2.setNodeType(NodeType.LATENT);
    graph.addNode(l1);
    graph.addNode(l2);

    // Each measure is added to the graph and immediately given both latents as parents.
    for (int k = 1; k <= numMeasures; k++) {
        Node measure = new GraphNode("X" + k);
        graph.addNode(measure);
        graph.addDirectedEdge(l1, measure);
        graph.addDirectedEdge(l2, measure);
    }

    return new SemIm(new SemPm(graph), new Parameters());
}
Use of `edu.cmu.tetrad.sem.SemIm` in project tetrad by cmu-phil.
Class `TestPc`, method `printStatsPcRegression`.
/**
 * Estimates the adjacency precision and recall of algorithm {@code t}, evaluated only at the
 * neighborhood of one measured target node, averaged over {@code numRuns} random simulations.
 *
 * <p>Each run performs these steps:
 * <ol>
 *   <li>Draws a random DAG over {@code numMeasures + numLatents} nodes.</li>
 *   <li>Simulates {@code sampleSize} rows from a SEM on it.</li>
 *   <li>Runs the selected algorithm.</li>
 *   <li>Trims the output around the target (see {@code trim}).</li>
 *   <li>Compares the target's adjacencies to those in the PAG of the true DAG.</li>
 * </ol>
 *
 * @param algorithms      display names, indexed by {@code t}
 * @param t               algorithm selector: 0 PC, 1 CPC, 2 FGES, 3 FCI, 4 GFCI, 5 RFCI,
 *                        6 CFCI, 7 regression graph
 * @param directed        currently unused
 * @param numRuns         number of simulations to average over
 * @param alpha           significance level for the Fisher Z test
 * @param penaltyDiscount penalty discount for the SEM BIC score
 * @param numMeasures     number of measured nodes
 * @param numLatents      number of latent nodes
 * @param edgeFactor      edges per node in the random DAG
 * @param sampleSize      rows simulated per run
 * @return {@code {avgAdjPrecision, avgAdjRecall}}. Runs in which a ratio is undefined (0/0)
 *         are excluded from that average; the average is NaN if every run was excluded.
 * @throws IllegalArgumentException if {@code t} is not a known selector
 * @throws IllegalStateException    if a run produces no measured node to use as the target
 */
private double[] printStatsPcRegression(String[] algorithms, int t, boolean directed, int numRuns, double alpha, double penaltyDiscount, int numMeasures, int numLatents, double edgeFactor, int sampleSize) {
    NumberFormat nf = new DecimalFormat("0.00");
    int numNodes = numMeasures + numLatents;
    int numEdges = (int) (edgeFactor * numNodes);

    double sumAdjPrecision = 0.0;
    double sumAdjRecall = 0.0;
    int precisionCount = 0;
    int recallCount = 0;

    for (int i = 0; i < numRuns; i++) {
        List<Node> nodes = new ArrayList<>();
        for (int r = 0; r < numNodes; r++) {
            nodes.add(new ContinuousVariable("X" + (r + 1)));
        }

        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(sampleSize, false);

        // Reference graph: the PAG of the true DAG.
        Graph comparison = new DagToPag(dag).convert();

        IndTestFisherZ test = new IndTestFisherZ(data, alpha);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data));
        score.setPenaltyDiscount(penaltyDiscount);

        // The target is the first node still marked MEASURED after the random graph is drawn.
        Node target = null;
        for (Node node : nodes) {
            if (node.getNodeType() == NodeType.MEASURED) {
                target = node;
                break;
            }
        }
        if (target == null) {
            throw new IllegalStateException("No measured node available to use as the target.");
        }

        GraphSearch search;
        Graph out;
        switch (t) {
            case 0:
                search = new Pc(test);
                out = search.search();
                break;
            case 1:
                search = new Cpc(test);
                out = search.search();
                break;
            case 2:
                search = new Fges(score);
                out = search.search();
                break;
            case 3:
                search = new Fci(test);
                out = search.search();
                break;
            case 4:
                search = new GFci(test, score);
                out = search.search();
                break;
            case 5:
                search = new Rfci(test);
                out = search.search();
                break;
            case 6:
                search = new Cfci(test);
                out = search.search();
                break;
            case 7:
                out = getRegressionGraph(data, target);
                break;
            default:
                throw new IllegalArgumentException("Unknown algorithm selector: " + t);
        }

        target = out.getNode(target.getName());
        out = trim(out, target);

        // Express the output in the DAG's node objects and make sure every DAG node is present.
        out = GraphUtils.replaceNodes(out, dag.getNodes());
        for (Node node : dag.getNodes()) {
            if (!out.containsNode(node)) {
                out.addNode(node);
            }
        }
        // replaceNodes may swap node objects; re-resolve the target within the final graph.
        target = out.getNode(target.getName());

        int adjTp = 0;
        int adjFp = 0;
        int adjFn = 0;
        for (Node node : out.getAdjacentNodes(target)) {
            if (comparison.isAdjacentTo(target, node)) {
                adjTp++;
            } else {
                adjFp++;
            }
        }
        // Recall uses the same reference graph as precision. Using the DAG here would count
        // latent neighbours as misses; the searches, run on measured data only, cannot
        // presumably ever recover those.
        for (Node node : comparison.getAdjacentNodes(target)) {
            if (!out.isAdjacentTo(target, node)) {
                adjFn++;
            }
        }

        // 0/0 gives NaN: such runs are left out of that metric's average instead of counting as 0.
        double adjPrecision = adjTp / (double) (adjTp + adjFp);
        double adjRecall = adjTp / (double) (adjTp + adjFn);
        if (!Double.isNaN(adjPrecision)) {
            sumAdjPrecision += adjPrecision;
            precisionCount++;
        }
        if (!Double.isNaN(adjRecall)) {
            sumAdjRecall += adjRecall;
            recallCount++;
        }
    }

    double avgAdjPrecision = precisionCount == 0 ? Double.NaN : sumAdjPrecision / precisionCount;
    double avgAdjRecall = recallCount == 0 ? Double.NaN : sumAdjRecall / recallCount;

    System.out.println();
    System.out.println(algorithms[t] + " adj precision " + nf.format(avgAdjPrecision));
    System.out.println(algorithms[t] + " adj recall " + nf.format(avgAdjRecall));
    return new double[] { avgAdjPrecision, avgAdjRecall };
}
Use of `edu.cmu.tetrad.sem.SemIm` in project tetrad by cmu-phil.
Class `TestPc`, method `testPcStable2`.
/**
 * Verifies that {@code PcStableMax} produces the same pattern when the columns of the input
 * data set are reordered.
 *
 * <p>A fixed seed makes the random graph, the simulated data and the column shuffle
 * reproducible.
 */
@Test
public void testPcStable2() {
    RandomUtil.getInstance().setSeed(1450030184196L);

    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 10; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Graph trueGraph = GraphUtils.randomGraph(variables, 0, 10, 30, 15, 15, false);
    SemIm model = new SemIm(new SemPm(trueGraph));
    DataSet originalData = model.simulateData(200, false);

    TetradLogger.getInstance().setForceLog(false);

    IndependenceTest originalTest = new IndTestFisherZ(originalData, 0.05);
    PcStableMax originalSearch = new PcStableMax(originalTest);
    originalSearch.setVerbose(false);
    Graph expected = originalSearch.search();

    // Same algorithm on a column-permuted copy of the data must give an identical pattern.
    DataSet shuffledData = DataUtils.reorderColumns(originalData);
    IndependenceTest shuffledTest = new IndTestFisherZ(shuffledData, 0.05);
    PcStableMax shuffledSearch = new PcStableMax(shuffledTest);
    shuffledSearch.setVerbose(false);
    Graph actual = shuffledSearch.search();

    assertTrue(expected.equals(actual));
}
Aggregations