Usage of edu.cmu.tetrad.sem.SemIm in the tetrad project by cmu-phil.
From class TestFci, method testSearch15:
@Test
public void testSearch15() {
    // Smoke test: build a random DAG, simulate a linear SEM from it, and run CFCI
    // with a Fisher Z test on the simulated data. The search result is not checked;
    // the test only requires that the search runs to completion without throwing.
    final int variableCount = 80;
    final int edgeCount = 80;
    final int latentCount = 40;
    final int rowCount = 1000;
    final boolean keepLatentColumns = false;

    List<Node> variables = new ArrayList<>();
    int index = 1;
    while (index <= variableCount) {
        variables.add(new ContinuousVariable("X" + index));
        index++;
    }

    Dag dag = new Dag(GraphUtils.randomGraph(variables, latentCount, edgeCount, 7, 5, 5, false));
    SemPm semPm = new SemPm(dag);
    SemIm semIm = new SemIm(semPm);
    DataSet simulated = semIm.simulateData(rowCount, keepLatentColumns);

    IndependenceTest fisherZ = new IndTestFisherZ(simulated, 0.05);
    Cfci cfci = new Cfci(fisherZ);
    cfci.search();
}
Usage of edu.cmu.tetrad.sem.SemIm in the tetrad project by cmu-phil.
From class TestFci, method testFciAnc:
/**
 * Compares ancestral judgments between two PAGs built from simulated data.
 *
 * <p>The first PAG is built from the full data set. The second is built after
 * {@code numVarsToMarginalize} randomly chosen columns are removed. For each run, the
 * method prints a table: columns are the marginal-PAG judgment (ancestral / nonancestral),
 * rows are the full-PAG judgment (ancestral / nonancestral / ambiguous). Cells hold
 * proportions of the column total. The method only prints and asserts nothing. Its
 * {@code @Test} annotation is commented out, so it does not run by default.
 */
// @Test
public void testFciAnc() {
    int numMeasures = 50;
    double edgeFactor = 2.0;
    int numRuns = 10;
    double alpha = 0.01;
    double penaltyDiscount = 4.0;
    int numVarsToMarginalize = 5;
    int numLatents = 10;
    System.out.println("num measures = " + numMeasures);
    System.out.println("edge factor = " + edgeFactor);
    System.out.println("alpha = " + alpha);
    System.out.println("penaltyDiscount = " + penaltyDiscount);
    System.out.println("num runs = " + numRuns);
    System.out.println("num vars to marginalize = " + numVarsToMarginalize);
    System.out.println("num latents = " + numLatents);
    System.out.println();

    // Depends only on the configuration above, so compute it once rather than per run.
    int numEdges = (int) (edgeFactor * (numMeasures + numLatents));

    for (int run = 0; run < numRuns; run++) {
        List<Node> nodes = new ArrayList<>();
        for (int r = 0; r < numMeasures + numLatents; r++) {
            String name = "X" + (r + 1);
            nodes.add(new ContinuousVariable(name));
        }
        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);
        Graph pag = getPag(alpha, penaltyDiscount, data);

        DataSet marginalData = data.copy();
        // Shuffle a private copy so the data set's own variable list is never reordered in
        // place. NOTE(review): whether getVariables() returns a live view is not visible
        // here; copying is safe either way.
        List<Node> variables = new ArrayList<>(marginalData.getVariables());
        Collections.shuffle(variables);
        for (int m = 0; m < numVarsToMarginalize; m++) {
            marginalData.removeColumn(marginalData.getColumn(variables.get(m)));
        }
        Graph margPag = getPag(alpha, penaltyDiscount, marginalData);

        // Counter naming: <full-PAG judgment><marginal-PAG judgment>,
        // e.g. nancAnc = nonancestral in pag, ancestral in margPag.
        int ancAnc = 0;
        int ancNanc = 0;
        int nancAnc = 0;
        int nancNanc = 0;
        int ambAnc = 0;
        int ambNanc = 0;
        int totalAncMarg = 0;
        int totalNancMarg = 0;
        for (Node n1 : marginalData.getVariables()) {
            for (Node n2 : marginalData.getVariables()) {
                if (n1 == n2) {
                    continue;
                }
                if (ancestral(n1, n2, margPag)) {
                    if (ancestral(n1, n2, pag)) {
                        ancAnc++;
                    } else if (nonAncestral(n1, n2, pag)) {
                        nancAnc++;
                    } else {
                        ambAnc++;
                    }
                    totalAncMarg++;
                } else if (nonAncestral(n1, n2, margPag)) {
                    if (ancestral(n1, n2, pag)) {
                        ancNanc++;
                    } else if (nonAncestral(n1, n2, pag)) {
                        nancNanc++;
                    } else {
                        ambNanc++;
                    }
                    totalNancMarg++;
                }
            }
        }

        TextTable table = new TextTable(5, 3);
        table.setToken(0, 1, "Ancestral");
        table.setToken(0, 2, "Nonancestral");
        table.setToken(1, 0, "Ancestral");
        table.setToken(2, 0, "Nonancestral");
        table.setToken(3, 0, "Ambiguous");
        table.setToken(4, 0, "Total");
        NumberFormat nf = new DecimalFormat("0.00");
        table.setToken(1, 1, formatProportion(nf, ancAnc, totalAncMarg));
        table.setToken(2, 1, formatProportion(nf, nancAnc, totalAncMarg));
        table.setToken(3, 1, formatProportion(nf, ambAnc, totalAncMarg));
        table.setToken(1, 2, formatProportion(nf, ancNanc, totalNancMarg));
        table.setToken(2, 2, formatProportion(nf, nancNanc, totalNancMarg));
        table.setToken(3, 2, formatProportion(nf, ambNanc, totalNancMarg));
        table.setToken(4, 1, String.valueOf(totalAncMarg));
        table.setToken(4, 2, String.valueOf(totalNancMarg));
        System.out.println(table);
    }
}

/**
 * Formats {@code count / total} with {@code nf}.
 *
 * <p>When {@code total} is zero, returns {@code "-"} instead of formatting a NaN quotient.
 */
private static String formatProportion(NumberFormat nf, int count, int total) {
    return total == 0 ? "-" : nf.format(count / (double) total);
}
Usage of edu.cmu.tetrad.sem.SemIm in the tetrad project by cmu-phil.
From class TestSemProposition, method testEvidence:
@Test
public void testEvidence() {
    Graph graph = constructGraph1();
    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    // getVariableNodes() is generic, so no raw List or (Node) cast is needed.
    List<Node> nodes = semIm.getVariableNodes();
    SemProposition proposition = SemProposition.tautology(semIm);

    // Every value of a freshly built tautology is expected to be NaN (unset).
    for (int i = 0; i < nodes.size(); i++) {
        assertTrue(Double.isNaN(proposition.getValue(i)));
    }

    // Values can be set and read back by index ...
    proposition.setValue(1, 0.5);
    assertEquals(0.5, proposition.getValue(1), 0.0);

    // ... and by node.
    Node node4 = nodes.get(3);
    proposition.setValue(node4, 0.7);
    assertEquals(0.7, proposition.getValue(node4), 0.0);
}
Usage of edu.cmu.tetrad.sem.SemIm in the tetrad project by cmu-phil.
From class TestSemVarMeans, method testMeansCholesky:
@Test
public void testMeansCholesky() {
    // Set known means on a SEM, simulate via simulateDataCholesky, re-estimate the SEM
    // from that sample, and require each estimated mean to lie within 0.6 of the true one.
    Graph graph = constructGraph1();
    SemPm pm = new SemPm(graph);

    // Mark every parameter as not randomly initialized.
    for (Parameter param : pm.getParameters()) {
        param.setInitializedRandomly(false);
    }

    SemIm trueIm = new SemIm(pm);
    double[] targetMeans = {5.0, 4.0, 3.0, 2.0, 1.0};
    RandomUtil.getInstance().setSeed(-379467L);

    List<Node> imNodes = trueIm.getVariableNodes();
    for (int k = 0; k < imNodes.size(); k++) {
        trueIm.setMean(imNodes.get(k), targetMeans[k]);
    }

    DataSet sample = trueIm.simulateDataCholesky(1000, false);
    SemEstimator estimator = new SemEstimator(sample, pm);
    estimator.estimate();
    SemIm fittedIm = estimator.getEstimatedSem();

    for (Node variable : pm.getVariableNodes()) {
        assertEquals(trueIm.getMean(variable), fittedIm.getMean(variable), 0.6);
    }
}
Usage of edu.cmu.tetrad.sem.SemIm in the tetrad project by cmu-phil.
From class TestSemVarMeans, method testMeansReducedForm:
@Test
public void testMeansReducedForm() {
    // Same round trip as the Cholesky variant, but the sample comes from
    // simulateDataReducedForm and the tolerance on each mean is 0.5.
    SemPm model = new SemPm(constructGraph1());

    List<Parameter> modelParams = model.getParameters();
    for (Parameter each : modelParams) {
        each.setInitializedRandomly(false);
    }

    SemIm generating = new SemIm(model);
    double[] wantedMeans = {5.0, 4.0, 3.0, 2.0, 1.0};
    RandomUtil.getInstance().setSeed(-379467L);

    List<Node> generatingNodes = generating.getVariableNodes();
    int pos = 0;
    while (pos < generatingNodes.size()) {
        generating.setMean(generatingNodes.get(pos), wantedMeans[pos]);
        pos++;
    }

    DataSet simulated = generating.simulateDataReducedForm(1000, false);
    SemEstimator fitter = new SemEstimator(simulated, model);
    fitter.estimate();
    SemIm recovered = fitter.getEstimatedSem();

    for (Node v : model.getVariableNodes()) {
        double expectedMean = generating.getMean(v);
        double actualMean = recovered.getMean(v);
        assertEquals(expectedMean, actualMean, 0.5);
    }
}
Aggregations