Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
Class TestGFci, method test1.
/**
 * Runs GFCI on data simulated from a fixed-seed random latent-variable DAG and checks that the
 * comparison counts between the estimated graph and the true PAG match a stored expectation.
 *
 * <p>The seed set on {@code RandomUtil} fixes the generated graph and data, so
 * {@code expectedCounts} is only valid for this seed and these parameter settings.
 */
public void test1() {
    RandomUtil.getInstance().setSeed(1450189593459L);

    int numNodes = 10;
    int numLatents = 5;
    int numEdges = 10;
    int sampleSize = 1000;
    double alpha = 0.01;
    double penaltyDiscount = 2;
    int depth = -1;
    int maxPathLength = -1;

    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numNodes; i++) {
        vars.add(new ContinuousVariable("X" + (i + 1)));
    }

    Graph dag = GraphUtils.randomGraphUniform(vars, numLatents, numEdges, 4, 4, 4, false);

    // Simulate data from the DAG, then keep only the measured variables.
    LargeScaleSimulation simulator = new LargeScaleSimulation(dag);
    simulator.setCoefRange(.5, 1.5);
    simulator.setVarRange(1, 3);
    DataSet data = simulator.simulateDataFisher(sampleSize);
    data = DataUtils.restrictToMeasured(data);

    ICovarianceMatrix cov = new CovarianceMatrix(data);
    IndTestFisherZ independenceTest = new IndTestFisherZ(cov, alpha);
    SemBicScore score = new SemBicScore(cov);
    score.setPenaltyDiscount(penaltyDiscount);
    independenceTest.setAlpha(alpha);

    GFci gFci = new GFci(independenceTest, score);
    gFci.setVerbose(false);
    gFci.setMaxDegree(depth);
    gFci.setMaxPathLength(maxPathLength);
    gFci.setCompleteRuleSetUsed(false);
    gFci.setFaithfulnessAssumed(true);
    Graph outGraph = gFci.search();

    // Reference PAG derived from the true DAG with the same rule-set / path-length settings.
    final DagToPag dagToPag = new DagToPag(dag);
    dagToPag.setCompleteRuleSetUsed(false);
    dagToPag.setMaxPathLength(maxPathLength);
    Graph truePag = dagToPag.convert();

    // Rebuild the estimated graph over the true PAG's node objects before comparing.
    outGraph = GraphUtils.replaceNodes(outGraph, truePag.getNodes());

    int[][] counts = SearchGraphUtils.graphComparison(outGraph, truePag, null);
    int[][] expectedCounts = {
            {0, 0, 0, 0, 0, 0},
            {0, 4, 0, 0, 0, 1},
            {0, 0, 3, 0, 0, 1},
            {0, 0, 0, 0, 0, 0},
            {0, 0, 0, 0, 0, 0},
            {0, 0, 0, 0, 0, 0},
            {0, 0, 0, 0, 0, 0},
            {0, 0, 0, 0, 0, 0}
    };

    // deepEquals also compares the number of rows. The former per-row loop threw
    // ArrayIndexOutOfBoundsException on extra rows and silently ignored missing ones.
    assertTrue(Arrays.deepEquals(counts, expectedCounts));
}
Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
Class PerformanceTests, method testPc.
/**
 * Times the PC search on data simulated from a random forward-edge DAG. Writes the timings, the
 * estimated graph and a comparison against the true DAG's pattern to the log stream {@code out}
 * opened by {@code init}; the stream is closed on exit, including when an exception is thrown.
 *
 * @param numVars    number of variables to generate
 * @param edgeFactor the DAG is generated with {@code (int) (numVars * edgeFactor)} edges
 * @param numCases   sample size of the simulated data set
 * @param alpha      significance level for the Fisher Z independence test
 */
public void testPc(int numVars, double edgeFactor, int numCases, double alpha) {
    int depth = -1;

    init(new File("long.pc." + numVars + "." + edgeFactor + "." + alpha + ".txt"),
            "Tests performance of the PC algorithm");

    try {
        long time1 = System.currentTimeMillis();

        System.out.println("Making list of vars");
        List<Node> vars = new ArrayList<>();
        for (int i = 0; i < numVars; i++) {
            vars.add(new ContinuousVariable("X" + i));
        }
        System.out.println("Finishing list of vars");

        System.out.println("Making graph");
        Graph graph = GraphUtils.randomGraphRandomForwardEdges(vars, 0, (int) (numVars * edgeFactor),
                30, 15, 15, false, true);
        System.out.println("Graph done");
        out.println("Graph done");

        System.out.println("Starting simulation");
        LargeScaleSimulation simulator = new LargeScaleSimulation(graph);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);
        System.out.println("Finishing simulation");

        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

        System.out.println("Making covariance matrix");
        ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data);
        // NOTE(review): this GC request falls inside the covariance timing window (before time3).
        System.gc();
        System.out.println("Covariance matrix done");

        long time3 = System.currentTimeMillis();
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms");

        IndTestFisherZ test = new IndTestFisherZ(cov, alpha);
        Pc pc = new Pc(test);
        pc.setVerbose(false);
        pc.setDepth(depth);
        Graph outGraph = pc.search();
        out.println(outGraph);

        long time4 = System.currentTimeMillis();

        // Summary block, repeating the earlier phase timings alongside the search time.
        out.println("# Vars = " + numVars);
        out.println("# Edges = " + (int) (numVars * edgeFactor));
        out.println("# Cases = " + numCases);
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms");
        // Labels name the algorithm actually run (Pc); they previously said "PC-Stable".
        out.println("Elapsed (running PC) " + (time4 - time3) + " ms");
        out.println("Total elapsed (cov + PC) " + (time4 - time2) + " ms");

        SearchGraphUtils.graphComparison(outGraph, SearchGraphUtils.patternForDag(graph), out);
    } finally {
        out.close();
    }
}
Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
Class PerformanceTests, method printStuffForKlea.
/**
 * Simulates a data set over 50,000 variables from a random forward-edge DAG. Writes the data to
 * {@code data.txt} and the generating graph to {@code graph.txt}.
 *
 * <p>Progress messages also go to the shared {@code out} stream, which this method does not
 * initialize. Presumably a caller must set it up first (e.g. via {@code init}) — TODO confirm.
 * Failures now propagate instead of being silently swallowed as before.
 *
 * @throws RuntimeException wrapping the {@link FileNotFoundException} if either output file
 *                          cannot be opened
 */
public void printStuffForKlea() {
    int numVars = 50000;
    double edgeFactor = 1.0;
    int numCases = 1000;

    // try-with-resources closes both files even if generation or simulation throws.
    try (PrintStream dataOut = new PrintStream(new FileOutputStream(new File("data.txt")));
         PrintStream graphOut = new PrintStream(new FileOutputStream(new File("graph.txt")))) {

        List<Node> vars = new ArrayList<>();
        for (int i = 0; i < numVars; i++) {
            vars.add(new ContinuousVariable("X" + (i + 1)));
        }

        Graph graph = GraphUtils.randomGraphRandomForwardEdges(vars, 0, (int) (numVars * edgeFactor),
                30, 15, 15, false, true);
        graphOut.println(graph);

        System.out.println("Graph done");
        out.println("Graph done");
        System.out.println("Starting simulation");

        LargeScaleSimulation simulator = new LargeScaleSimulation(graph);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);
        dataOut.println(data);
    } catch (FileNotFoundException e) {
        throw new RuntimeException(e);
    }
}
Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
Class PerformanceTests, method testCpc.
/**
 * Times the CPC search (Fisher Z test, alpha = 0.0001) on data simulated from a random
 * forward-edge DAG. Writes the timings, the estimated graph and a comparison against the true
 * DAG's pattern to the log stream {@code out} opened by {@code init}; the stream is closed on
 * exit, including when an exception is thrown.
 *
 * @param numVars    number of variables to generate
 * @param edgeFactor the DAG is generated with {@code (int) (numVars * edgeFactor)} edges
 * @param numCases   sample size of the simulated data set
 */
public void testCpc(int numVars, double edgeFactor, int numCases) {
    double alpha = 0.0001;
    int depth = -1;

    init(new File("long.cpc." + numVars + ".txt"), "Tests performance of the CPC algorithm");

    try {
        long time1 = System.currentTimeMillis();

        System.out.println("Making list of vars");
        List<Node> vars = new ArrayList<>();
        for (int i = 0; i < numVars; i++) {
            vars.add(new ContinuousVariable("X" + i));
        }
        System.out.println("Finishing list of vars");

        System.out.println("Making graph");
        Graph graph = GraphUtils.randomGraphRandomForwardEdges(vars, 0, (int) (numVars * edgeFactor),
                30, 15, 15, false, true);
        System.out.println("Graph done");
        out.println("Graph done");

        System.out.println("Starting simulation");
        LargeScaleSimulation simulator = new LargeScaleSimulation(graph);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);
        System.out.println("Finishing simulation");

        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

        System.out.println("Making covariance matrix");
        ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data);
        // NOTE(review): this GC request falls inside the covariance timing window (before time3).
        System.gc();
        System.out.println("Covariance matrix done");

        long time3 = System.currentTimeMillis();
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms");

        IndTestFisherZ test = new IndTestFisherZ(cov, alpha);
        Cpc cpc = new Cpc(test);
        cpc.setVerbose(false);
        cpc.setDepth(depth);
        Graph outGraph = cpc.search();
        out.println(outGraph);

        long time4 = System.currentTimeMillis();

        // Summary block, repeating the earlier phase timings alongside the search time.
        out.println("# Vars = " + numVars);
        out.println("# Edges = " + (int) (numVars * edgeFactor));
        out.println("# Cases = " + numCases);
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms");
        // Labels name the algorithm actually run (Cpc); they previously said "PC-Stable".
        out.println("Elapsed (running CPC) " + (time4 - time3) + " ms");
        out.println("Total elapsed (cov + CPC) " + (time4 - time2) + " ms");

        SearchGraphUtils.graphComparison(outGraph, SearchGraphUtils.patternForDag(graph), out);
    } finally {
        out.close();
    }
}
Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
Class PerformanceTests, method testGFciComparison.
/**
 * Runs GFCI several times on data simulated from random latent-variable DAGs and compares each
 * estimated PAG with the PAG derived from the true DAG.
 *
 * <p>For each run, the method does the following:
 * <ul>
 *   <li>logs the results of {@code printCorrectArrows} / {@code printCorrectTails} and the
 *       search time to {@code out};</li>
 *   <li>writes the DAG, estimated PAG and true PAG to {@code dag.<run>.txt},
 *       {@code estpag.<run>.txt} and {@code truepag.<run>.txt}.</li>
 * </ul>
 * Averaged timing statistics are printed at the end. The log stream is closed on exit,
 * including when an exception is thrown.
 *
 * <p>NOTE(review): all four parameters are overwritten with hard-coded values at the top of the
 * method, so the arguments supplied by callers are currently ignored.
 *
 * @param numVars    ignored; overwritten with 1000
 * @param edgeFactor ignored; overwritten with 1.0
 * @param numCases   ignored; overwritten with 1000
 * @param numLatents ignored; overwritten with 100
 * @throws RuntimeException wrapping a {@link FileNotFoundException} if a per-run dump file
 *                          cannot be opened
 */
public void testGFciComparison(int numVars, double edgeFactor, int numCases, int numLatents) {
    // Hard-coded experiment settings; these override the method arguments.
    numVars = 1000;
    edgeFactor = 1.0;
    numLatents = 100;
    numCases = 1000;

    int numRuns = 5;
    double alpha = 0.01;
    double penaltyDiscount = 3.0;
    int depth = 3;
    int maxPathLength = 3;
    boolean possibleDsepDone = true;
    boolean completeRuleSetUsed = false;
    boolean faithfulnessAssumed = true;

    init(new File("fci.algorithm.comparison" + numVars + "." + (int) (edgeFactor * numVars) + "."
            + numCases + ".txt"), "Num runs = " + numRuns);

    try {
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Alpha = " + alpha);
        out.println("Penalty discount = " + penaltyDiscount);
        out.println("Depth = " + depth);
        out.println("Maximum reachable path length for dsep search and discriminating undirectedPaths = " + maxPathLength);
        out.println("Num additional latent common causes = " + numLatents);
        // NOTE(review): possibleDsepDone is only logged; no setter on GFci is called with it.
        out.println("Possible Dsep Done = " + possibleDsepDone);
        out.println("Complete Rule Set Used = " + completeRuleSetUsed);
        out.println();

        // NOTE(review): the count and arrow/tail lists are collected but not reported below.
        List<GraphUtils.GraphComparison> ffciCounts = new ArrayList<>();
        List<double[]> ffciArrowStats = new ArrayList<>();
        List<double[]> ffciTailStats = new ArrayList<>();
        List<Long> ffciElapsedTimes = new ArrayList<>();

        for (int run = 0; run < numRuns; run++) {
            out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");

            System.out.println("Making list of vars");
            List<Node> vars = new ArrayList<>();
            for (int i = 0; i < numVars; i++) {
                vars.add(new ContinuousVariable("X" + (i + 1)));
            }
            System.out.println("Finishing list of vars");

            Graph dag = getLatentGraph(vars, edgeFactor, numLatents);
            System.out.println("Graph done");

            // Build the reference PAG with the same rule-set setting GFci uses below.
            final DagToPag dagToPag = new DagToPag(dag);
            dagToPag.setCompleteRuleSetUsed(completeRuleSetUsed);
            dagToPag.setMaxPathLength(maxPathLength);
            Graph truePag = dagToPag.convert();
            System.out.println("True PAG_of_the_true_DAG done");

            // Data.
            System.out.println("Starting simulation");
            LargeScaleSimulation simulator = new LargeScaleSimulation(dag);
            simulator.setCoefRange(.5, 1.5);
            simulator.setVarRange(1, 3);
            DataSet data = simulator.simulateDataFisher(numCases);
            data = DataUtils.restrictToMeasured(data);
            System.out.println("Finishing simulation");

            System.out.println("Making covariance matrix");
            ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data);
            System.out.println("Covariance matrix done");

            // Independence test and score.
            final IndTestFisherZ independenceTest = new IndTestFisherZ(cov, alpha);
            final SemBicScore score = new SemBicScore(cov);
            score.setPenaltyDiscount(penaltyDiscount);

            out.println("\n\n\n========================TGFCI run " + (run + 1));

            long ta1 = System.currentTimeMillis();
            GFci fci = new GFci(independenceTest, score);
            fci.setMaxDegree(depth);
            fci.setMaxPathLength(maxPathLength);
            fci.setCompleteRuleSetUsed(completeRuleSetUsed);
            fci.setFaithfulnessAssumed(faithfulnessAssumed);
            Graph estPag = fci.search();
            long ta2 = System.currentTimeMillis();

            // Rebuild the estimate over the true PAG's node objects before comparing.
            estPag = GraphUtils.replaceNodes(estPag, truePag.getNodes());

            ffciArrowStats.add(printCorrectArrows(dag, estPag, truePag));
            ffciTailStats.add(printCorrectTails(dag, estPag, truePag));
            ffciCounts.add(SearchGraphUtils.getGraphComparison2(estPag, truePag));

            long elapsed = ta2 - ta1;
            ffciElapsedTimes.add(elapsed);
            out.println("\nElapsed: " + elapsed + " ms");

            // try-with-resources closes every dump file, even if a later one fails to open.
            try (PrintStream dagOut = new PrintStream(new File("dag." + run + ".txt"));
                 PrintStream estPagOut = new PrintStream(new File("estpag." + run + ".txt"));
                 PrintStream truePagOut = new PrintStream(new File("truepag." + run + ".txt"))) {
                dagOut.println(dag);
                estPagOut.println(estPag);
                truePagOut.println(truePag);
            } catch (FileNotFoundException e) {
                throw new RuntimeException(e);
            }
        }

        printAverageStatistics(ffciElapsedTimes, new ArrayList<Double>());
    } finally {
        out.close();
    }
}
Aggregations