Search in sources:

Example 1 with GFci

Use of edu.cmu.tetrad.search.GFci in project tetrad by cmu-phil.

In class Gfci, method search.

/**
 * Runs GFCI on {@code dataSet}, either as a single search or wrapped in a bootstrap ensemble.
 *
 * <p>If the {@code bootstrapSampleSize} parameter is less than 1, a {@link GFci} instance is
 * built from the {@code test} and {@code score} fields and run directly. It is configured from
 * {@code parameters} ({@code maxDegree}, {@code verbose}, {@code faithfulnessAssumed},
 * {@code maxPathLength}, {@code completeRuleSetUsed}, and an optional {@code printStream}).
 * Otherwise the search is delegated to a {@link GeneralBootstrapTest}. That wrapper receives
 * the {@code bootstrapSampleSize} value, the {@code knowledge} field, the edge-ensemble
 * strategy selected by {@code bootstrapEnsemble}, and the full parameter set.
 *
 * @param dataSet    the data to search; in the bootstrap path it is cast to {@link DataSet}
 *                   without a check, so other {@link DataModel} types raise
 *                   {@link ClassCastException}
 * @param parameters the algorithm parameters, read as described above
 * @return the graph produced by the direct or the bootstrap search
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    if (parameters.getInt("bootstrapSampleSize") < 1) {
        GFci search = new GFci(test.getTest(dataSet, parameters), score.getScore(dataSet, parameters));
        search.setMaxDegree(parameters.getInt("maxDegree"));
        search.setKnowledge(knowledge);
        search.setVerbose(parameters.getBoolean("verbose"));
        search.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed"));
        search.setMaxPathLength(parameters.getInt("maxPathLength"));
        search.setCompleteRuleSetUsed(parameters.getBoolean("completeRuleSetUsed"));
        // The output stream is optional; a missing or non-PrintStream value is ignored.
        Object obj = parameters.get("printStream");
        if (obj instanceof PrintStream) {
            search.setOut((PrintStream) obj);
        }
        return search.search();
    }

    // Bootstrap path: knowledge is attached to the bootstrap wrapper, not to the inner algorithm.
    Gfci algorithm = new Gfci(test, score);
    DataSet data = (DataSet) dataSet;
    GeneralBootstrapTest search = new GeneralBootstrapTest(data, algorithm, parameters.getInt("bootstrapSampleSize"));
    search.setKnowledge(knowledge);
    search.setEdgeEnsemble(toEdgeEnsemble(parameters.getInt("bootstrapEnsemble", 1)));
    search.setParameters(parameters);
    search.setVerbose(parameters.getBoolean("verbose"));
    return search.search();
}

/**
 * Maps the integer {@code bootstrapEnsemble} parameter to an edge-ensemble strategy.
 *
 * <p>0 selects {@code Preserved} and 2 selects {@code Majority}. 1 and any unrecognised value
 * select {@code Highest}, which was already the fallback of the original inline switch.
 *
 * @param code the raw {@code bootstrapEnsemble} parameter value
 * @return the corresponding ensemble strategy, never {@code null}
 */
private static BootstrapEdgeEnsemble toEdgeEnsemble(int code) {
    switch (code) {
        case 0:
            return BootstrapEdgeEnsemble.Preserved;
        case 2:
            return BootstrapEdgeEnsemble.Majority;
        case 1:
        default:
            return BootstrapEdgeEnsemble.Highest;
    }
}
Also used: PrintStream (java.io.PrintStream), GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest), GFci (edu.cmu.tetrad.search.GFci), BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble)

Example 2 with GFci

Use of edu.cmu.tetrad.search.GFci in project tetrad by cmu-phil.

In class PerformanceTestsDan, method testIdaOutputForDan.

/**
 * Runs 100 simulated GFCI/PC experiments and writes each run's inputs and outputs to text
 * files under {@code gfci.output/<run number>/}.
 *
 * <p>Each run does the following:
 * <ul>
 *   <li>Builds a random graph over 15 nodes named {@code X1..X15} and calls
 *       {@code GraphUtils.fixLatents1(numLatents, dag)} on it. Presumably this turns that many
 *       nodes latent; only nodes still of type {@code MEASURED} are treated as measured below.</li>
 *   <li>Parameterises the graph as a {@code SemPm}/{@code SemIm}, simulates {@code numCases}
 *       rows, and restricts the data to measured variables.</li>
 *   <li>Runs GFCI and PC on the covariance matrix of that data, and converts the generating
 *       graph with {@code DagToPag}.</li>
 * </ul>
 *
 * <p>All twelve output streams are opened in one try-with-resources statement. They are
 * therefore closed even if opening a later file, or any of the searches, throws part-way
 * through a run. A failure to open a file is rethrown as a {@link RuntimeException} wrapping
 * the {@link FileNotFoundException}.
 */
private void testIdaOutputForDan() {
    // Hyperparameters shared by every run.
    final int numRuns = 100;
    final double alphaGFci = 0.01;
    final double alphaPc = 0.01;
    // NOTE(review): penaltyDiscount is only written to hyperparameters.txt; it is never passed
    // to scoreGfci, so the score uses SemBicScore's own setting. Confirm this is intended.
    final int penaltyDiscount = 1;
    final int depth = 3;
    final int maxPathLength = 3;
    final int numVars = 15;
    final double edgesPerNode = 1.0;
    final int numCases = 1000;
    final int numLatents = 6;

    for (int run = 0; run < numRuns; run++) {
        File dir0 = new File("gfci.output");
        dir0.mkdirs();
        File dir = new File(dir0, "" + (run + 1));
        dir.mkdir();

        try (PrintStream out1 = new PrintStream(new File(dir, "hyperparameters.txt"));
             PrintStream out2 = new PrintStream(new File(dir, "variables.txt"));
             PrintStream out3 = new PrintStream(new File(dir, "dag.long.txt"));
             PrintStream out4 = new PrintStream(new File(dir, "dag.matrix.txt"));
             PrintStream out5 = new PrintStream(new File(dir, "coef.matrix.txt"));
             PrintStream out6 = new PrintStream(new File(dir, "pag.long.txt"));
             PrintStream out7 = new PrintStream(new File(dir, "pag.matrix.txt"));
             PrintStream out8 = new PrintStream(new File(dir, "pattern.long.txt"));
             PrintStream out9 = new PrintStream(new File(dir, "pattern.matrix.txt"));
             PrintStream out10 = new PrintStream(new File(dir, "data.txt"));
             PrintStream out11 = new PrintStream(new File(dir, "true.pag.long.txt"));
             PrintStream out12 = new PrintStream(new File(dir, "true.pag.matrix.txt"))) {

            out1.println("Num _vars = " + numVars);
            out1.println("Num edges = " + (int) (numVars * edgesPerNode));
            out1.println("Num cases = " + numCases);
            out1.println("Alpha for PC = " + alphaPc);
            out1.println("Alpha for FFCI = " + alphaGFci);
            out1.println("Penalty discount = " + penaltyDiscount);
            out1.println("Depth = " + depth);
            out1.println("Maximum reachable path length for dsep search and discriminating undirectedPaths = " + maxPathLength);

            // Generating graph over X1..Xn, with some nodes then converted by fixLatents1.
            List<Node> vars = new ArrayList<>();
            for (int i = 0; i < numVars; i++) {
                vars.add(new GraphNode("X" + (i + 1)));
            }
            Graph dag = GraphUtils.randomGraph(vars, 0, (int) (vars.size() * edgesPerNode), 5, 5, 5, false);
            GraphUtils.fixLatents1(numLatents, dag);
            out3.println(dag);
            printDanMatrix(vars, dag, out4);

            SemPm pm = new SemPm(dag);
            SemIm im = new SemIm(pm);
            printCoefficientMatrix(im, vars, out5);

            // variables.txt: all node indices on line 1, measured node indices on line 2.
            out2.println(nodeIndexString(vars));
            List<Node> measured = new ArrayList<>();
            for (Node node : vars) {
                if (node.getNodeType() == NodeType.MEASURED) {
                    measured.add(node);
                }
            }
            out2.println(nodeIndexString(measured));

            DataSet fullData = im.simulateData(numCases, false);
            DataSet data = DataUtils.restrictToMeasured(fullData);
            ICovarianceMatrix cov = new CovarianceMatrix(data);

            // GFCI estimate over the measured variables.
            final IndTestFisherZ independenceTestGFci = new IndTestFisherZ(cov, alphaGFci);
            final edu.cmu.tetrad.search.SemBicScore scoreGfci = new edu.cmu.tetrad.search.SemBicScore(cov);
            out6.println("GFCI.PAG_of_the_true_DAG");
            GFci gFci = new GFci(independenceTestGFci, scoreGfci);
            gFci.setVerbose(false);
            gFci.setMaxDegree(depth);
            gFci.setMaxPathLength(maxPathLength);
            gFci.setCompleteRuleSetUsed(true);
            Graph pag = gFci.search();
            out6.println(pag);
            printDanMatrix(measured, pag, out7);

            // PC estimate over the measured variables.
            out8.println("Pattern_of_the_true_DAG OVER MEASURED VARIABLES");
            final IndTestFisherZ independencePc = new IndTestFisherZ(cov, alphaPc);
            Pc pc = new Pc(independencePc);
            pc.setVerbose(false);
            pc.setDepth(depth);
            Graph pattern = pc.search();
            out8.println(pattern);
            printDanMatrix(measured, pattern, out9);

            out10.println(data);

            // Reference PAG computed directly from the generating graph.
            out11.println("True PAG_of_the_true_DAG");
            final Graph truePag = new DagToPag(dag).convert();
            out11.println(truePag);
            printDanMatrix(measured, truePag, out12);
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
    }
}

/**
 * Writes the SEM edge-coefficient matrix for {@code vars} to {@code out}.
 *
 * <p>There is one line per element of {@code vars}. On line {@code i}, the {@code j}-th
 * tab-terminated cell is {@code im.getEdgeCoef(vars.get(j), vars.get(i))} when
 * {@code im.existsEdgeCoef(vars.get(j), vars.get(i))} holds, and 0 otherwise. Each value is
 * formatted as {@code 0.0000}. A blank line follows the matrix.
 */
private static void printCoefficientMatrix(SemIm im, List<Node> vars, PrintStream out) {
    NumberFormat nf = new DecimalFormat("0.0000");
    for (Node rowNode : vars) {
        for (Node colNode : vars) {
            double coef = im.existsEdgeCoef(colNode, rowNode) ? im.getEdgeCoef(colNode, rowNode) : 0;
            out.print(nf.format(coef) + "\t");
        }
        out.println();
    }
    out.println();
}

/**
 * Returns {@code nodes.toString()} with every '[', ']' and 'X' character removed.
 *
 * <p>Assuming each node prints as its name, nodes named {@code X1, X2, ...} become a
 * comma-separated list of their numeric suffixes.
 */
private static String nodeIndexString(List<Node> nodes) {
    return nodes.toString().replace("[", "").replace("]", "").replace("X", "");
}
Also used: DataSet (edu.cmu.tetrad.data.DataSet), ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix), DecimalFormat (java.text.DecimalFormat), FileNotFoundException (java.io.FileNotFoundException), ArrayList (java.util.ArrayList), Pc (edu.cmu.tetrad.search.Pc), SemPm (edu.cmu.tetrad.sem.SemPm), PrintStream (java.io.PrintStream), IndTestFisherZ (edu.cmu.tetrad.search.IndTestFisherZ), GFci (edu.cmu.tetrad.search.GFci), CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix), DagToPag (edu.cmu.tetrad.search.DagToPag), File (java.io.File), SemIm (edu.cmu.tetrad.sem.SemIm), NumberFormat (java.text.NumberFormat)

Aggregations

GFci (edu.cmu.tetrad.search.GFci): 2, PrintStream (java.io.PrintStream): 2, CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 1, DataSet (edu.cmu.tetrad.data.DataSet): 1, ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 1, DagToPag (edu.cmu.tetrad.search.DagToPag): 1, IndTestFisherZ (edu.cmu.tetrad.search.IndTestFisherZ): 1, Pc (edu.cmu.tetrad.search.Pc): 1, SemIm (edu.cmu.tetrad.sem.SemIm): 1, SemPm (edu.cmu.tetrad.sem.SemPm): 1, BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble): 1, GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 1, File (java.io.File): 1, FileNotFoundException (java.io.FileNotFoundException): 1, DecimalFormat (java.text.DecimalFormat): 1, NumberFormat (java.text.NumberFormat): 1, ArrayList (java.util.ArrayList): 1