Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
In class TestPc, method printStats:
/**
 * Simulates {@code numRuns} random SEMs, runs the search selected by {@code t} on each data set,
 * scores the edge endpoints of every output graph against the generating DAG, prints a summary to
 * standard output, and returns the averaged statistics.
 *
 * <p>Each run builds a random DAG over {@code numMeasures + numLatents} continuous variables named
 * X1, X2, ... ({@code numLatents} is passed to the random-graph generator), simulates 1000 rows from
 * a {@code SemIm} built on that DAG, and runs one of: 0 = {@code Pc}, 1 = {@code Cpc},
 * 2 = {@code Fges}, 3 = {@code Fci}, 4 = {@code GFci}, 5 = {@code Rfci}, 6 = {@code Cfci}.
 * {@code Fges} uses a SEM BIC score (with {@code penaltyDiscount}), {@code GFci} uses both the score
 * and a Fisher Z test (at {@code alpha}), and the others use only the test.
 *
 * <p>For each output edge A *-* B (when {@code directed} is true, only directed and bidirected edges
 * are considered):
 * <ul>
 *   <li>an arrowhead at A is correct iff A is not an ancestor of B in the DAG and a trek connects
 *       A and B;</li>
 *   <li>a tail at A is correct iff the DAG has a directed path from A to B;</li>
 *   <li>a bidirected edge is correct iff neither endpoint is an ancestor of the other and a trek
 *       connects them.</li>
 * </ul>
 * Per-run precision is correct / (correct + incorrect). A run in which a precision is undefined
 * (no endpoints of that kind) is left out of that precision's average.
 *
 * @param algorithms      display names indexed by {@code t}; used only in the printed output
 * @param t               index of the algorithm to run (see above)
 * @param directed        if true, score only directed and bidirected output edges
 * @param numRuns         number of simulated data sets
 * @param alpha           significance level for the Fisher Z test
 * @param penaltyDiscount penalty discount for the SEM BIC score
 * @param numMeasures     number of non-latent variables (total nodes = numMeasures + numLatents)
 * @param numLatents      passed to the random-graph generator as the number of latents
 * @param edgeFactor      each DAG gets {@code (int) (edgeFactor * (numMeasures + numLatents))} edges
 * @return {avg arrow precision, avg tail precision, avg bidirected precision, avg number of
 *     arrowheads per run, avg number of tails per run, avg number of bidirected edges per run,
 *     negated avg search time in milliseconds}
 * @throws IllegalStateException if {@code t} is not in 0..6
 */
private double[] printStats(String[] algorithms, int t, boolean directed, int numRuns, double alpha, double penaltyDiscount, int numMeasures, int numLatents, double edgeFactor) {
    NumberFormat nf = new DecimalFormat("0.00");

    double sumArrowPrecision = 0.0;
    double sumTailPrecision = 0.0;
    double sumBidirectedPrecision = 0.0;
    int countAP = 0;
    int countTP = 0;
    int countBP = 0;

    int numArrows = 0;
    int numTails = 0;
    int numBidirected = 0;
    int count = 0;

    // Must be long: with an int, "totalElapsed += elapsed" still compiles (compound assignment
    // narrows implicitly) but silently truncates the long millisecond values.
    long totalElapsed = 0;

    // Loop-invariant: every run uses the same number of variables and edges.
    int numVars = numMeasures + numLatents;
    int numEdges = (int) (edgeFactor * numVars);

    for (int i = 0; i < numRuns; i++) {
        List<Node> nodes = new ArrayList<>();
        for (int r = 0; r < numVars; r++) {
            nodes.add(new ContinuousVariable("X" + (r + 1)));
        }

        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);

        IndTestFisherZ test = new IndTestFisherZ(data, alpha);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data));
        score.setPenaltyDiscount(penaltyDiscount);
        GraphSearch search = createStatsSearch(t, test, score);

        long start = System.currentTimeMillis();
        Graph out = search.search();
        long stop = System.currentTimeMillis();
        totalElapsed += stop - start;

        // Swap in the DAG's node objects, presumably so the DAG queries below recognize them.
        out = GraphUtils.replaceNodes(out, dag.getNodes());

        int arrowsTp = 0;
        int arrowsFp = 0;
        int tailsTp = 0;
        int tailsFp = 0;
        int bidirectedTp = 0;
        int bidirectedFp = 0;

        for (Edge edge : out.getEdges()) {
            if (directed && !(edge.isDirected() || Edges.isBidirectedEdge(edge))) {
                continue;
            }

            Node n1 = edge.getNode1();
            Node n2 = edge.getNode2();

            // Arrowhead at n1: n1 must not be an ancestor of n2, and the two must share a trek.
            if (edge.getEndpoint1() == Endpoint.ARROW) {
                if (!dag.isAncestorOf(n1, n2) && dag.existsTrek(n1, n2)) {
                    arrowsTp++;
                } else {
                    arrowsFp++;
                }
                numArrows++;
            }

            // Arrowhead at n2: n2 must not be an ancestor of n1, and the two must share a trek.
            if (edge.getEndpoint2() == Endpoint.ARROW) {
                if (!dag.isAncestorOf(n2, n1) && dag.existsTrek(n1, n2)) {
                    arrowsTp++;
                } else {
                    arrowsFp++;
                }
                numArrows++;
            }

            // Tail at n1: n1 must have a directed path to n2.
            if (edge.getEndpoint1() == Endpoint.TAIL) {
                if (dag.existsDirectedPathFromTo(n1, n2)) {
                    tailsTp++;
                } else {
                    tailsFp++;
                }
                numTails++;
            }

            // Tail at n2: n2 must have a directed path to n1.
            if (edge.getEndpoint2() == Endpoint.TAIL) {
                if (dag.existsDirectedPathFromTo(n2, n1)) {
                    tailsTp++;
                } else {
                    tailsFp++;
                }
                numTails++;
            }

            // Bidirected: neither endpoint is an ancestor of the other, yet they share a trek.
            if (Edges.isBidirectedEdge(edge)) {
                if (!dag.isAncestorOf(n1, n2) && !dag.isAncestorOf(n2, n1) && dag.existsTrek(n1, n2)) {
                    bidirectedTp++;
                } else {
                    bidirectedFp++;
                }
                numBidirected++;
            }
        }

        // 0/0 yields NaN when a run has no endpoints of a kind; such runs are skipped below.
        double arrowPrecision = arrowsTp / (double) (arrowsTp + arrowsFp);
        double tailPrecision = tailsTp / (double) (tailsTp + tailsFp);
        double bidirectedPrecision = bidirectedTp / (double) (bidirectedTp + bidirectedFp);

        if (!Double.isNaN(arrowPrecision)) {
            sumArrowPrecision += arrowPrecision;
            countAP++;
        }
        if (!Double.isNaN(tailPrecision)) {
            sumTailPrecision += tailPrecision;
            countTP++;
        }
        if (!Double.isNaN(bidirectedPrecision)) {
            sumBidirectedPrecision += bidirectedPrecision;
            countBP++;
        }

        count++;
    }

    double avgArrowPrecision = sumArrowPrecision / (double) countAP;
    double avgTailPrecision = sumTailPrecision / (double) countTP;
    double avgBidirectedPrecision = sumBidirectedPrecision / (double) countBP;
    double avgNumArrows = numArrows / (double) count;
    double avgNumTails = numTails / (double) count;
    double avgNumBidirected = numBidirected / (double) count;
    double avgElapsed = totalElapsed / (double) numRuns;

    // Elapsed time is returned negated, unlike the other entries.
    double[] ret = new double[] {
            avgArrowPrecision, avgTailPrecision, avgBidirectedPrecision,
            avgNumArrows, avgNumTails, avgNumBidirected,
            -avgElapsed
    };

    System.out.println();
    System.out.println(algorithms[t] + " arrow precision " + nf.format(avgArrowPrecision));
    System.out.println(algorithms[t] + " tail precision " + nf.format(avgTailPrecision));
    System.out.println(algorithms[t] + " bidirected precision " + nf.format(avgBidirectedPrecision));
    System.out.println(algorithms[t] + " avg num arrow " + nf.format(avgNumArrows));
    System.out.println(algorithms[t] + " avg num tails " + nf.format(avgNumTails));
    System.out.println(algorithms[t] + " avg num bidirected " + nf.format(avgNumBidirected));
    System.out.println(algorithms[t] + " avg elapsed " + nf.format(avgElapsed));

    return ret;
}

/**
 * Creates the search that {@code printStats} runs for algorithm index {@code t}:
 * 0 = Pc, 1 = Cpc, 2 = Fges, 3 = Fci, 4 = GFci, 5 = Rfci, 6 = Cfci.
 *
 * @throws IllegalStateException if {@code t} is not in 0..6
 */
private GraphSearch createStatsSearch(int t, IndTestFisherZ test, SemBicScore score) {
    switch (t) {
        case 0:
            return new Pc(test);
        case 1:
            return new Cpc(test);
        case 2:
            return new Fges(score);
        case 3:
            return new Fci(test);
        case 4:
            return new GFci(test, score);
        case 5:
            return new Rfci(test);
        case 6:
            return new Cfci(test);
        default:
            throw new IllegalStateException("Unknown algorithm index: " + t);
    }
}
Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
In class TestStandardizedSem, method rtest8:
public void rtest8() {
    // Not annotated with @Test, so this does not run by default. For a repeatable run, seed first:
    // RandomUtil.getInstance().setSeed(2958442283L);

    // X -> Y, X <-> Y, X -> Z, Y -> Z, with error terms shown.
    SemGraph semGraph = new SemGraph();
    Node nodeX = new ContinuousVariable("X");
    Node nodeY = new ContinuousVariable("Y");
    Node nodeZ = new ContinuousVariable("Z");
    for (Node node : new Node[] {nodeX, nodeY, nodeZ}) {
        semGraph.addNode(node);
    }
    semGraph.addDirectedEdge(nodeX, nodeY);
    semGraph.addBidirectedEdge(nodeX, nodeY);
    semGraph.addDirectedEdge(nodeX, nodeZ);
    semGraph.addDirectedEdge(nodeY, nodeZ);
    semGraph.setShowErrorTerms(true);

    SemPm pm = new SemPm(semGraph);
    SemIm generatingIm = new SemIm(pm);
    StandardizedSemIm standardized =
            new StandardizedSemIm(generatingIm, StandardizedSemIm.Initialization.CALCULATE_FROM_SEM);

    // Simulate, standardize every column, and re-estimate the model on the standardized data.
    DataSet raw = generatingIm.simulateDataCholesky(1000, false);
    DataSet standardizedData =
            ColtDataSet.makeContinuousData(raw.getVariables(), DataUtils.standardizeData(raw.getDoubleData()));
    SemIm estimatedIm = new SemEstimator(standardizedData, pm).estimate();

    // Exercise reduced-form simulation on both models; the generated data are not inspected.
    estimatedIm.simulateDataReducedForm(1000, false);
    standardized.simulateDataReducedForm(1000, false);

    // Draw an X <-> Y error covariance uniformly from its admissible range, clamping an
    // unbounded end to +/-1000, and check the model accepts it and stays standardized.
    StandardizedSemIm.ParameterRange range = standardized.getCovarianceRange(nodeX, nodeY);
    double rawHigh = range.getHigh();
    double rawLow = range.getLow();
    double upper = (rawHigh == Double.POSITIVE_INFINITY) ? 1000 : rawHigh;
    double lower = (rawLow == Double.NEGATIVE_INFINITY) ? -1000 : rawLow;
    double covariance = lower + RandomUtil.getInstance().nextDouble() * (upper - lower);

    assertTrue(standardized.setErrorCovariance(nodeX, nodeY, covariance));
    assertTrue(isStandardized(standardized));
}
Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
In class TestStandardizedSem, method test5:
@Test
public void test5() {
    RandomUtil.getInstance().setSeed(582374923L);

    SemGraph semGraph = new SemGraph();
    semGraph.setShowErrorTerms(true);
    Node v1 = new ContinuousVariable("X1");
    Node v2 = new ContinuousVariable("X2");
    Node v3 = new ContinuousVariable("X3");
    for (Node v : new Node[] {v1, v2, v3}) {
        semGraph.addNode(v);
    }
    semGraph.setShowErrorTerms(true);

    // Error-term handles. Only a disabled variant of this test uses them: it also adds X1 -> X2
    // and an E(X1) <-> E(X2) covariance, later set to -.24 on the standardized SEM.
    Node e1 = semGraph.getExogenous(v1);
    Node e2 = semGraph.getExogenous(v2);
    Node e3 = semGraph.getExogenous(v3);

    // Collider: X1 -> X3 <- X2.
    semGraph.addDirectedEdge(v1, v3);
    semGraph.addDirectedEdge(v2, v3);

    SemIm im = new SemIm(new SemPm(semGraph));

    DataSet simulated = im.simulateDataRecursive(1000, false);
    TetradMatrix standardizedMatrix = DataUtils.standardizeData(simulated.getDoubleData());
    DataSet standardizedData = ColtDataSet.makeData(simulated.getVariables(), standardizedMatrix);
    // Estimate on the standardized data; the estimated model is not asserted on.
    new SemEstimator(standardizedData, im.getSemPm()).estimate();

    StandardizedSemIm standardized = new StandardizedSemIm(im);
    assertTrue(isStandardized(standardized));
}
Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
In class TestStandardizedSem, method test2:
@Test
public void test2() {
    RandomUtil.getInstance().setSeed(5729384723L);

    SemGraph semGraph = new SemGraph();
    Node n1 = new ContinuousVariable("X1");
    Node n2 = new ContinuousVariable("X2");
    Node n3 = new ContinuousVariable("X3");
    Node n4 = new ContinuousVariable("X4");
    Node n5 = new ContinuousVariable("X5");
    for (Node n : new Node[] {n1, n2, n3, n4, n5}) {
        semGraph.addNode(n);
    }
    semGraph.setShowErrorTerms(true);

    // Directed edges as {from, to} pairs, added in this exact order.
    Node[][] edges = {
            {n1, n2},
            {n2, n3},
            {n4, n3},
            {n2, n4},
            {n1, n4},
            {n5, n4}
    };
    for (Node[] edge : edges) {
        semGraph.addDirectedEdge(edge[0], edge[1]);
    }

    // A StandardizedSemIm built from a SemIm on this graph must come out standardized.
    StandardizedSemIm standardized = new StandardizedSemIm(new SemIm(new SemPm(semGraph)));
    assertTrue(isStandardized(standardized));
}
Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
In class TestStandardizedSem, method test6:
@Test
public void test6() {
    // Unseeded; for a repeatable run use:
    // RandomUtil.getInstance().setSeed(582374923L);

    SemGraph semGraph = new SemGraph();
    semGraph.setShowErrorTerms(true);
    Node v1 = new ContinuousVariable("X1");
    Node v2 = new ContinuousVariable("X2");
    Node v3 = new ContinuousVariable("X3");
    for (Node v : new Node[] {v1, v2, v3}) {
        semGraph.addNode(v);
    }
    semGraph.setShowErrorTerms(true);

    Node e1 = semGraph.getExogenous(v1);
    Node e2 = semGraph.getExogenous(v2);
    Node e3 = semGraph.getExogenous(v3);

    // X1 -> X3 <- X2, X1 -> X2, plus a covariance between the error terms of X1 and X2.
    semGraph.addDirectedEdge(v1, v3);
    semGraph.addDirectedEdge(v2, v3);
    semGraph.addDirectedEdge(v1, v2);
    semGraph.addBidirectedEdge(e1, e2);

    SemIm im = new SemIm(new SemPm(semGraph));

    DataSet simulated = im.simulateDataRecursive(1000, false);
    TetradMatrix standardizedMatrix = DataUtils.standardizeData(simulated.getDoubleData());
    DataSet standardizedData = ColtDataSet.makeData(simulated.getVariables(), standardizedMatrix);
    // Estimate on the standardized data; the estimated model is not asserted on.
    new SemEstimator(standardizedData, im.getSemPm()).estimate();

    StandardizedSemIm standardized = new StandardizedSemIm(im);
    assertTrue(isStandardized(standardized));
}
Aggregations