Usage of `edu.cmu.tetrad.search.IndTestFisherZ` in the tetrad project (cmu-phil).
Class `FisherZScore`, method `getScore`:
/**
 * Wraps a Fisher Z independence test over the supplied data in a {@code ScoredIndTest}.
 * <p>
 * Side effects: stores {@code dataSet} and the "alpha" parameter in this object's fields.
 *
 * @param dataSet    the data to test; it is cast to {@code DataSet}, so any other
 *                   {@code DataModel} kind fails with a {@code ClassCastException}
 * @param parameters supplies the "alpha" significance level
 * @return a score backed by an {@code IndTestFisherZ} at that alpha
 */
@Override
public Score getScore(DataModel dataSet, Parameters parameters) {
    this.dataSet = dataSet;
    final double significance = parameters.getDouble("alpha");
    this.alpha = significance;
    final DataSet samples = (DataSet) dataSet;
    return new ScoredIndTest(new IndTestFisherZ(samples, significance));
}
Usage of `edu.cmu.tetrad.search.IndTestFisherZ` in the tetrad project (cmu-phil).
Class `FisherZ`, method `getTest`:
/**
 * Creates a Fisher Z independence test for either a covariance matrix or a raw data set.
 * <p>
 * Side effect: stores the "alpha" parameter in this object's {@code alpha} field.
 *
 * @param dataSet    an {@code ICovarianceMatrix} (checked first) or a {@code DataSet}
 * @param parameters supplies the "alpha" significance level
 * @return an {@code IndTestFisherZ} over {@code dataSet} at that alpha
 * @throws IllegalArgumentException if {@code dataSet} is null or is neither a covariance
 *                                  matrix nor a data set
 */
@Override
public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
    double alpha = parameters.getDouble("alpha");
    this.alpha = alpha;
    if (dataSet instanceof ICovarianceMatrix) {
        return new IndTestFisherZ((ICovarianceMatrix) dataSet, alpha);
    }
    if (dataSet instanceof DataSet) {
        return new IndTestFisherZ((DataSet) dataSet, alpha);
    }
    // Report what was actually received so misconfigured callers are easy to diagnose.
    String received = (dataSet == null) ? "null" : dataSet.getClass().getName();
    throw new IllegalArgumentException(
            "Expecting either a data set or a covariance matrix; got " + received + ".");
}
Usage of `edu.cmu.tetrad.search.IndTestFisherZ` in the tetrad project (cmu-phil).
Class `PerformanceTestsDan`, method `testIdaOutputForDan`:
/**
 * Writes {@code numRuns} simulated GFCI / PC runs to {@code gfci.output/<run number>/}.
 * <p>
 * For each run:
 * <ul>
 * <li>build a random graph over {@code numVars} nodes and apply {@code GraphUtils.fixLatents1};</li>
 * <li>parameterize the graph with {@code SemPm}/{@code SemIm} and simulate {@code numCases} rows;</li>
 * <li>restrict the rows to MEASURED variables and build a covariance matrix;</li>
 * <li>run GFCI (Fisher Z + SEM BIC) and PC (Fisher Z) on that covariance matrix.</li>
 * </ul>
 * The hyperparameters, variable lists, generating graph, edge coefficients, both search results,
 * the data and the {@code DagToPag} output are each written to their own text file.
 * <p>
 * All twelve streams of a run are opened with try-with-resources, so they are closed even if a
 * later file cannot be opened or the simulation or a search throws part-way through.
 *
 * @throws RuntimeException wrapping the {@link FileNotFoundException} if an output file
 *                          cannot be created
 */
private void testIdaOutputForDan() {
    final int numRuns = 100;
    final double alphaGFci = 0.01;
    final double alphaPc = 0.01;
    final int penaltyDiscount = 1;
    final int depth = 3;
    final int maxPathLength = 3;
    final int numVars = 15;
    final double edgesPerNode = 1.0;
    final int numCases = 1000;
    final int numLatents = 6;

    // The output root does not change between runs, so create it once.
    File dir0 = new File("gfci.output");
    dir0.mkdirs();

    for (int run = 0; run < numRuns; run++) {
        File dir = new File(dir0, "" + (run + 1));
        dir.mkdir();

        try (PrintStream out1 = new PrintStream(new File(dir, "hyperparameters.txt"));
             PrintStream out2 = new PrintStream(new File(dir, "variables.txt"));
             PrintStream out3 = new PrintStream(new File(dir, "dag.long.txt"));
             PrintStream out4 = new PrintStream(new File(dir, "dag.matrix.txt"));
             PrintStream out5 = new PrintStream(new File(dir, "coef.matrix.txt"));
             PrintStream out6 = new PrintStream(new File(dir, "pag.long.txt"));
             PrintStream out7 = new PrintStream(new File(dir, "pag.matrix.txt"));
             PrintStream out8 = new PrintStream(new File(dir, "pattern.long.txt"));
             PrintStream out9 = new PrintStream(new File(dir, "pattern.matrix.txt"));
             PrintStream out10 = new PrintStream(new File(dir, "data.txt"));
             PrintStream out11 = new PrintStream(new File(dir, "true.pag.long.txt"));
             PrintStream out12 = new PrintStream(new File(dir, "true.pag.matrix.txt"))) {

            out1.println("Num _vars = " + numVars);
            out1.println("Num edges = " + (int) (numVars * edgesPerNode));
            out1.println("Num cases = " + numCases);
            out1.println("Alpha for PC = " + alphaPc);
            out1.println("Alpha for FFCI = " + alphaGFci);
            out1.println("Penalty discount = " + penaltyDiscount);
            out1.println("Depth = " + depth);
            out1.println("Maximum reachable path length for dsep search and discriminating undirectedPaths = " + maxPathLength);

            List<Node> vars = new ArrayList<>();
            for (int i = 0; i < numVars; i++) {
                vars.add(new GraphNode("X" + (i + 1)));
            }

            Graph dag = GraphUtils.randomGraph(vars, 0, (int) (vars.size() * edgesPerNode), 5, 5, 5, false);
            // NOTE(review): presumably marks numLatents nodes of dag as latent; the MEASURED
            // filter below relies on the node types this leaves behind.
            GraphUtils.fixLatents1(numLatents, dag);

            out3.println(dag);
            printDanMatrix(vars, dag, out4);

            SemPm pm = new SemPm(dag);
            SemIm im = new SemIm(pm);

            // Row i holds im.getEdgeCoef(var, vars[i]) for every var, or 0 where no such
            // coefficient exists.
            NumberFormat nf = new DecimalFormat("0.0000");
            for (int i = 0; i < vars.size(); i++) {
                for (Node var : vars) {
                    if (im.existsEdgeCoef(var, vars.get(i))) {
                        double coef = im.getEdgeCoef(var, vars.get(i));
                        out5.print(nf.format(coef) + "\t");
                    } else {
                        out5.print(nf.format(0) + "\t");
                    }
                }
                out5.println();
            }
            out5.println();

            // First line: all variables; second line: only the MEASURED ones.
            out2.println(formatVarIndicesForDan(vars));

            List<Node> _vars = new ArrayList<>();
            for (Node node : vars) {
                if (node.getNodeType() == NodeType.MEASURED) {
                    _vars.add(node);
                }
            }
            out2.println(formatVarIndicesForDan(_vars));

            DataSet fullData = im.simulateData(numCases, false);
            DataSet data = DataUtils.restrictToMeasured(fullData);
            ICovarianceMatrix cov = new CovarianceMatrix(data);

            final IndTestFisherZ independenceTestGFci = new IndTestFisherZ(cov, alphaGFci);
            final edu.cmu.tetrad.search.SemBicScore scoreGfci = new edu.cmu.tetrad.search.SemBicScore(cov);
            out6.println("GFCI.PAG_of_the_true_DAG");
            GFci gFci = new GFci(independenceTestGFci, scoreGfci);
            gFci.setVerbose(false);
            gFci.setMaxDegree(depth);
            gFci.setMaxPathLength(maxPathLength);
            gFci.setCompleteRuleSetUsed(true);
            Graph pag = gFci.search();
            out6.println(pag);
            printDanMatrix(_vars, pag, out7);

            out8.println("Pattern_of_the_true_DAG OVER MEASURED VARIABLES");
            final IndTestFisherZ independencePc = new IndTestFisherZ(cov, alphaPc);
            Pc pc = new Pc(independencePc);
            pc.setVerbose(false);
            pc.setDepth(depth);
            Graph pattern = pc.search();
            out8.println(pattern);
            printDanMatrix(_vars, pattern, out9);

            out10.println(data);

            out11.println("True PAG_of_the_true_DAG");
            final Graph truePag = new DagToPag(dag).convert();
            out11.println(truePag);
            printDanMatrix(_vars, truePag, out12);
        } catch (FileNotFoundException e) {
            // The cause is preserved; no need to also print the stack trace here.
            throw new RuntimeException(e);
        }
    }
}

/**
 * Formats nodes the way {@code List.toString()} does, then strips the enclosing brackets and
 * every "X" character (so "[X1, X2]" becomes "1, 2").
 */
private static String formatVarIndicesForDan(List<Node> nodes) {
    return nodes.toString().replace("[", "").replace("]", "").replace("X", "");
}
Usage of `edu.cmu.tetrad.search.IndTestFisherZ` in the tetrad project (cmu-phil).
Class `PcFges`, method `search`:
/**
 * Runs FGES bounded by a FasStable adjacency search, or a bootstrap ensemble of this algorithm.
 * <p>
 * If "bootstrapSampleSize" is 1 or more, a {@code GeneralBootstrapTest} runs a fresh
 * {@code PcFges} (sharing {@code score}, {@code compareToTrue} and any {@code initialGraph})
 * over the data. It combines the runs with the ensemble chosen by "bootstrapEnsemble":
 * 0 = Preserved, 2 = Majority, any other value (default 1) = Highest.
 * <p>
 * Otherwise the data set's covariance matrix is used twice. First, FasStable with a Fisher Z
 * test at a fixed alpha of 0.001 (the "alpha" parameter is not read) builds a bound graph.
 * Then FGES, scored by {@code score}, is restricted to that bound.
 *
 * @param dataSet    the data; cast to {@code DataSet} on both paths
 * @param parameters algorithm settings; an optional "printStream" entry that is a
 *                   {@code PrintStream} becomes FGES's output stream
 * @return the bootstrap ensemble graph or the bounded FGES result
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        PcFges algorithm = new PcFges(score, compareToTrue);
        if (initialGraph != null) {
            algorithm.setInitialGraph(initialGraph);
        }

        GeneralBootstrapTest bootstrap = new GeneralBootstrapTest(
                (DataSet) dataSet, algorithm, parameters.getInt("bootstrapSampleSize"));
        bootstrap.setKnowledge(knowledge);

        BootstrapEdgeEnsemble ensemble;
        switch (parameters.getInt("bootstrapEnsemble", 1)) {
            case 0:
                ensemble = BootstrapEdgeEnsemble.Preserved;
                break;
            case 2:
                ensemble = BootstrapEdgeEnsemble.Majority;
                break;
            default:
                ensemble = BootstrapEdgeEnsemble.Highest;
                break;
        }
        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    ICovarianceMatrix covariance = new CovarianceMatrix((DataSet) dataSet);
    Graph boundGraph = new edu.cmu.tetrad.search.FasStable(new IndTestFisherZ(covariance, 0.001)).search();

    edu.cmu.tetrad.search.Fges fges = new edu.cmu.tetrad.search.Fges(score.getScore(covariance, parameters));
    fges.setVerbose(parameters.getBoolean("verbose"));
    fges.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed"));
    fges.setKnowledge(knowledge);
    fges.setMaxDegree(parameters.getInt("maxDegree"));
    fges.setSymmetricFirstStep(parameters.getBoolean("symmetricFirstStep"));
    System.out.println("Bound graph done");

    Object printStream = parameters.get("printStream");
    if (printStream instanceof PrintStream) {
        fges.setOut((PrintStream) printStream);
    }

    fges.setBoundGraph(boundGraph);
    return fges.search();
}
Usage of `edu.cmu.tetrad.search.IndTestFisherZ` in the tetrad project (cmu-phil).
Class `TestIndTestFisherZ`, method `testDirections`:
/**
 * Checks Fisher Z p-values on data simulated from two graphs with matching edge coefficients.
 * graph1 is the chain X -> Y -> Z; graph2 is the collider X -> Y <- Z.
 * <p>
 * The RNG is seeded first. The SemIm construction and simulateData calls below presumably
 * draw from it, so the asserted p-values depend on the exact statement order — do not reorder.
 */
@Test
public void testDirections() {
RandomUtil.getInstance().setSeed(48285934L);
Graph graph1 = new EdgeListGraph();
Graph graph2 = new EdgeListGraph();
Node x = new GraphNode("X");
Node y = new GraphNode("Y");
Node z = new GraphNode("Z");
graph1.addNode(x);
graph1.addNode(y);
graph1.addNode(z);
graph2.addNode(x);
graph2.addNode(y);
graph2.addNode(z);
// graph1: chain X -> Y -> Z.
graph1.addEdge(Edges.directedEdge(x, y));
graph1.addEdge(Edges.directedEdge(y, z));
// graph2: collider X -> Y <- Z.
graph2.addEdge(Edges.directedEdge(x, y));
graph2.addEdge(Edges.directedEdge(z, y));
SemPm pm1 = new SemPm(graph1);
SemPm pm2 = new SemPm(graph2);
SemIm im1 = new SemIm(pm1);
SemIm im2 = new SemIm(pm2);
// Copy graph1's coefficients into graph2: X->Y directly, and Z->Y takes graph1's Y->Z value.
im2.setEdgeCoef(x, y, im1.getEdgeCoef(x, y));
im2.setEdgeCoef(z, y, im1.getEdgeCoef(y, z));
DataSet data1 = im1.simulateData(500, false);
DataSet data2 = im2.simulateData(500, false);
IndependenceTest test1 = new IndTestFisherZ(data1, 0.05);
IndependenceTest test2 = new IndTestFisherZ(data2, 0.05);
// Chain: X and Y are adjacent, so the p-value should be near 0.
test1.isIndependent(data1.getVariable(x.getName()), data1.getVariable(y.getName()));
double p1 = test1.getPValue();
// Collider: conditioning on Y makes X and Z dependent, so the p-value should be near 0.
test2.isIndependent(data2.getVariable(x.getName()), data2.getVariable(z.getName()), data2.getVariable(y.getName()));
double p2 = test2.getPValue();
// Collider: X and Z unconditionally; the expected p-value (0.38) is pinned for this seed.
test2.isIndependent(data2.getVariable(x.getName()), data2.getVariable(z.getName()));
double p3 = test2.getPValue();
assertEquals(0, p1, 0.01);
assertEquals(0, p2, 0.01);
assertEquals(0.38, p3, 0.01);
}
Aggregations