Search in sources:

Example 1 with IndTestFisherZ

use of edu.cmu.tetrad.search.IndTestFisherZ in project tetrad by cmu-phil.

the class FisherZScore method getScore.

/**
 * Builds a score backed by a Fisher Z independence test.
 *
 * <p>Remembers the supplied data model and the "alpha" parameter on this
 * instance, then wraps a new {@code IndTestFisherZ} in a {@code ScoredIndTest}.
 *
 * @param dataSet    the data to score; it is cast to {@code DataSet}, so any
 *                   other kind of data model fails with a {@code ClassCastException}
 * @param parameters source of the "alpha" value handed to the test
 * @return a {@code ScoredIndTest} wrapping the Fisher Z test
 */
@Override
public Score getScore(DataModel dataSet, Parameters parameters) {
    this.dataSet = dataSet;
    this.alpha = parameters.getDouble("alpha");
    return new ScoredIndTest(new IndTestFisherZ((DataSet) dataSet, this.alpha));
}
Also used : IndTestFisherZ(edu.cmu.tetrad.search.IndTestFisherZ) ScoredIndTest(edu.cmu.tetrad.search.ScoredIndTest)

Example 2 with IndTestFisherZ

use of edu.cmu.tetrad.search.IndTestFisherZ in project tetrad by cmu-phil.

the class FisherZ method getTest.

/**
 * Creates a Fisher Z independence test for the given data.
 *
 * <p>The "alpha" parameter is read from {@code parameters}, stored on this
 * instance, and passed to the test. Both covariance matrices and data sets are
 * accepted; the covariance-matrix check comes first.
 *
 * @param dataSet    an {@code ICovarianceMatrix} or a {@code DataSet}
 * @param parameters source of the "alpha" value handed to the test
 * @return a new {@code IndTestFisherZ} over the supplied data
 * @throws IllegalArgumentException if {@code dataSet} is neither a data set
 *                                  nor a covariance matrix
 */
@Override
public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
    double alpha = parameters.getDouble("alpha");
    this.alpha = alpha;
    if (dataSet instanceof ICovarianceMatrix) {
        return new IndTestFisherZ((ICovarianceMatrix) dataSet, alpha);
    } else if (dataSet instanceof DataSet) {
        return new IndTestFisherZ((DataSet) dataSet, alpha);
    }
    // Message typo fixed ("eithet" -> "either"); this text surfaces to users.
    throw new IllegalArgumentException("Expecting either a data set or a covariance matrix.");
}
Also used : IndTestFisherZ(edu.cmu.tetrad.search.IndTestFisherZ)

Example 3 with IndTestFisherZ

use of edu.cmu.tetrad.search.IndTestFisherZ in project tetrad by cmu-phil.

the class PerformanceTestsDan method testIdaOutputForDan.

/**
 * Runs 100 simulation rounds and dumps each round's artifacts to
 * {@code gfci.output/<run>/}.
 *
 * <p>Each round builds a random graph over 15 variables, applies
 * {@code GraphUtils.fixLatents1} with 6 latents, simulates 1000 cases from a
 * SEM over that graph, restricts the data to measured variables, and then runs
 * GFCI and PC on the covariance matrix. The hyperparameters, variable lists,
 * DAG, coefficient matrix, GFCI PAG, PC pattern, data and the PAG derived from
 * the true DAG are written to separate text files.
 *
 * <p>All output streams are opened with try-with-resources so they are closed
 * even when opening a later file or any simulation/search step fails.
 *
 * @throws RuntimeException wrapping the {@code FileNotFoundException} if an
 *                          output file cannot be opened
 */
private void testIdaOutputForDan() {
    int numRuns = 100;
    for (int run = 0; run < numRuns; run++) {
        double alphaGFci = 0.01;
        double alphaPc = 0.01;
        int penaltyDiscount = 1;
        int depth = 3;
        int maxPathLength = 3;
        final int numVars = 15;
        final double edgesPerNode = 1.0;
        final int numCases = 1000;
        final int numLatents = 6;

        // One sub-directory per run, named by the 1-based run index.
        File dir0 = new File("gfci.output");
        dir0.mkdirs();
        File dir = new File(dir0, "" + (run + 1));
        dir.mkdir();

        // If a directory could not be created, opening the first file below
        // fails with FileNotFoundException, which is rethrown unchecked.
        try (PrintStream out1 = new PrintStream(new File(dir, "hyperparameters.txt"));
             PrintStream out2 = new PrintStream(new File(dir, "variables.txt"));
             PrintStream out3 = new PrintStream(new File(dir, "dag.long.txt"));
             PrintStream out4 = new PrintStream(new File(dir, "dag.matrix.txt"));
             PrintStream out5 = new PrintStream(new File(dir, "coef.matrix.txt"));
             PrintStream out6 = new PrintStream(new File(dir, "pag.long.txt"));
             PrintStream out7 = new PrintStream(new File(dir, "pag.matrix.txt"));
             PrintStream out8 = new PrintStream(new File(dir, "pattern.long.txt"));
             PrintStream out9 = new PrintStream(new File(dir, "pattern.matrix.txt"));
             PrintStream out10 = new PrintStream(new File(dir, "data.txt"));
             PrintStream out11 = new PrintStream(new File(dir, "true.pag.long.txt"));
             PrintStream out12 = new PrintStream(new File(dir, "true.pag.matrix.txt"))) {

            // --- hyperparameters.txt ---
            out1.println("Num _vars = " + numVars);
            out1.println("Num edges = " + (int) (numVars * edgesPerNode));
            out1.println("Num cases = " + numCases);
            out1.println("Alpha for PC = " + alphaPc);
            // NOTE(review): the label says "FFCI" but the value is alphaGFci; left
            // unchanged in case downstream tooling parses this file - confirm.
            out1.println("Alpha for FFCI = " + alphaGFci);
            out1.println("Penalty discount = " + penaltyDiscount);
            out1.println("Depth = " + depth);
            out1.println("Maximum reachable path length for dsep search and discriminating undirectedPaths = " + maxPathLength);

            // --- true graph: X1..X15, then fixLatents1 is applied with numLatents ---
            List<Node> vars = new ArrayList<>();
            for (int i = 0; i < numVars; i++) {
                vars.add(new GraphNode("X" + (i + 1)));
            }
            Graph dag = GraphUtils.randomGraph(vars, 0, (int) (vars.size() * edgesPerNode), 5, 5, 5, false);
            GraphUtils.fixLatents1(numLatents, dag);
            out3.println(dag);
            printDanMatrix(vars, dag, out4);

            // --- coef.matrix.txt: row i, column j holds the coefficient that
            // im reports for (vars[j], vars[i]), or 0 when none exists ---
            SemPm pm = new SemPm(dag);
            SemIm im = new SemIm(pm);
            NumberFormat nf = new DecimalFormat("0.0000");
            for (int i = 0; i < vars.size(); i++) {
                for (Node var : vars) {
                    if (im.existsEdgeCoef(var, vars.get(i))) {
                        double coef = im.getEdgeCoef(var, vars.get(i));
                        out5.print(nf.format(coef) + "\t");
                    } else {
                        out5.print(nf.format(0) + "\t");
                    }
                }
                out5.println();
            }
            out5.println();

            // --- variables.txt: line 1 lists all variables, line 2 only the
            // measured ones; brackets and the "X" prefix are stripped ---
            String vars_temp = vars.toString();
            vars_temp = vars_temp.replace("[", "");
            vars_temp = vars_temp.replace("]", "");
            vars_temp = vars_temp.replace("X", "");
            out2.println(vars_temp);
            List<Node> _vars = new ArrayList<>();
            for (Node node : vars) {
                if (node.getNodeType() == NodeType.MEASURED) {
                    _vars.add(node);
                }
            }
            String _vars_temp = _vars.toString();
            _vars_temp = _vars_temp.replace("[", "");
            _vars_temp = _vars_temp.replace("]", "");
            _vars_temp = _vars_temp.replace("X", "");
            out2.println(_vars_temp);

            // --- simulate, then keep only the measured columns ---
            DataSet fullData = im.simulateData(numCases, false);
            DataSet data = DataUtils.restrictToMeasured(fullData);
            ICovarianceMatrix cov = new CovarianceMatrix(data);

            // --- GFCI on the covariance matrix ---
            final IndTestFisherZ independenceTestGFci = new IndTestFisherZ(cov, alphaGFci);
            final edu.cmu.tetrad.search.SemBicScore scoreGfci = new edu.cmu.tetrad.search.SemBicScore(cov);
            out6.println("GFCI.PAG_of_the_true_DAG");
            GFci gFci = new GFci(independenceTestGFci, scoreGfci);
            gFci.setVerbose(false);
            gFci.setMaxDegree(depth);
            gFci.setMaxPathLength(maxPathLength);
            gFci.setCompleteRuleSetUsed(true);
            Graph pag = gFci.search();
            out6.println(pag);
            printDanMatrix(_vars, pag, out7);

            // --- PC on the same covariance matrix ---
            out8.println("Pattern_of_the_true_DAG OVER MEASURED VARIABLES");
            final IndTestFisherZ independencePc = new IndTestFisherZ(cov, alphaPc);
            Pc pc = new Pc(independencePc);
            pc.setVerbose(false);
            pc.setDepth(depth);
            Graph pattern = pc.search();
            out8.println(pattern);
            printDanMatrix(_vars, pattern, out9);

            // --- data and the PAG converted from the true DAG ---
            out10.println(data);
            out11.println("True PAG_of_the_true_DAG");
            final Graph truePag = new DagToPag(dag).convert();
            out11.println(truePag);
            printDanMatrix(_vars, truePag, out12);
        } catch (FileNotFoundException e) {
            // The cause is preserved, so no separate stack-trace dump is needed.
            throw new RuntimeException(e);
        }
    }
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) DecimalFormat(java.text.DecimalFormat) FileNotFoundException(java.io.FileNotFoundException) ArrayList(java.util.ArrayList) Pc(edu.cmu.tetrad.search.Pc) SemPm(edu.cmu.tetrad.sem.SemPm) PrintStream(java.io.PrintStream) IndTestFisherZ(edu.cmu.tetrad.search.IndTestFisherZ) GFci(edu.cmu.tetrad.search.GFci) CovarianceMatrix(edu.cmu.tetrad.data.CovarianceMatrix) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) DagToPag(edu.cmu.tetrad.search.DagToPag) File(java.io.File) SemIm(edu.cmu.tetrad.sem.SemIm) NumberFormat(java.text.NumberFormat)

Example 4 with IndTestFisherZ

use of edu.cmu.tetrad.search.IndTestFisherZ in project tetrad by cmu-phil.

the class PcFges method search.

/**
 * Runs the PC-bounded FGES search, either directly or through the bootstrap
 * wrapper.
 *
 * <p>When the "bootstrapSampleSize" parameter is at least 1, the search is
 * delegated to a {@code GeneralBootstrapTest} around a fresh {@code PcFges}.
 * Otherwise a {@code FasStable} search (Fisher Z test, alpha fixed at 0.001)
 * produces a bound graph, which is handed to {@code Fges} before it searches.
 *
 * @param dataSet    the data to search over; cast to {@code DataSet} in both branches
 * @param parameters search settings ("bootstrapSampleSize", "verbose",
 *                   "faithfulnessAssumed", "maxDegree", "symmetricFirstStep",
 *                   "printStream", "bootstrapEnsemble")
 * @return the graph returned by the underlying search
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        // Bootstrap path: wrap a fresh PcFges and let the bootstrap driver run it.
        PcFges wrapped = new PcFges(score, compareToTrue);
        if (initialGraph != null) {
            wrapped.setInitialGraph(initialGraph);
        }
        DataSet sample = (DataSet) dataSet;
        GeneralBootstrapTest bootstrap =
                new GeneralBootstrapTest(sample, wrapped, parameters.getInt("bootstrapSampleSize"));
        bootstrap.setKnowledge(knowledge);

        // 0 -> Preserved, 2 -> Majority, anything else (including 1) -> Highest.
        int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }
        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    // Direct path: compute the bound graph first, then configure and run FGES.
    DataSet tabular = (DataSet) dataSet;
    ICovarianceMatrix cov = new CovarianceMatrix(tabular);
    // The alpha for the bounding search is hard-coded rather than read from parameters.
    IndTestFisherZ boundTest = new IndTestFisherZ(cov, 0.001);
    edu.cmu.tetrad.search.FasStable fas = new edu.cmu.tetrad.search.FasStable(boundTest);
    Graph bound = fas.search();

    edu.cmu.tetrad.search.Fges fges = new edu.cmu.tetrad.search.Fges(score.getScore(cov, parameters));
    fges.setVerbose(parameters.getBoolean("verbose"));
    fges.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed"));
    fges.setKnowledge(knowledge);
    fges.setMaxDegree(parameters.getInt("maxDegree"));
    fges.setSymmetricFirstStep(parameters.getBoolean("symmetricFirstStep"));
    System.out.println("Bound graph done");

    // Redirect search output only when a PrintStream was supplied.
    Object sink = parameters.get("printStream");
    if (sink instanceof PrintStream) {
        fges.setOut((PrintStream) sink);
    }
    fges.setBoundGraph(bound);
    return fges.search();
}
Also used : PrintStream(java.io.PrintStream) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) IndTestFisherZ(edu.cmu.tetrad.search.IndTestFisherZ) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph)

Example 5 with IndTestFisherZ

use of edu.cmu.tetrad.search.IndTestFisherZ in project tetrad by cmu-phil.

the class TestIndTestFisherZ method testDirections.

/**
 * Checks Fisher Z p-values on data simulated from two three-node graphs over
 * X, Y, Z: a chain X -> Y -> Z and a collider X -> Y <- Z.
 *
 * <p>The random seed is fixed, and the collider model copies its coefficients
 * from the chain model, so the expected p-values below are reproducible. The
 * order of model construction and simulation is kept as-is because it
 * determines the simulated data.
 */
@Test
public void testDirections() {
    RandomUtil.getInstance().setSeed(48285934L);

    Graph chain = new EdgeListGraph();
    Graph collider = new EdgeListGraph();
    Node x = new GraphNode("X");
    Node y = new GraphNode("Y");
    Node z = new GraphNode("Z");

    // Chain: X -> Y -> Z
    chain.addNode(x);
    chain.addNode(y);
    chain.addNode(z);
    chain.addEdge(Edges.directedEdge(x, y));
    chain.addEdge(Edges.directedEdge(y, z));

    // Collider: X -> Y <- Z
    collider.addNode(x);
    collider.addNode(y);
    collider.addNode(z);
    collider.addEdge(Edges.directedEdge(x, y));
    collider.addEdge(Edges.directedEdge(z, y));

    SemPm chainPm = new SemPm(chain);
    SemPm colliderPm = new SemPm(collider);
    SemIm chainIm = new SemIm(chainPm);
    SemIm colliderIm = new SemIm(colliderPm);

    // Reuse the chain's coefficients on the collider's edges.
    colliderIm.setEdgeCoef(x, y, chainIm.getEdgeCoef(x, y));
    colliderIm.setEdgeCoef(z, y, chainIm.getEdgeCoef(y, z));

    DataSet chainData = chainIm.simulateData(500, false);
    DataSet colliderData = colliderIm.simulateData(500, false);
    IndependenceTest chainTest = new IndTestFisherZ(chainData, 0.05);
    IndependenceTest colliderTest = new IndTestFisherZ(colliderData, 0.05);

    // X vs Y in the chain, unconditional.
    chainTest.isIndependent(
            chainData.getVariable(x.getName()),
            chainData.getVariable(y.getName()));
    double pChainXY = chainTest.getPValue();

    // X vs Z given Y in the collider.
    colliderTest.isIndependent(
            colliderData.getVariable(x.getName()),
            colliderData.getVariable(z.getName()),
            colliderData.getVariable(y.getName()));
    double pColliderXZGivenY = colliderTest.getPValue();

    // X vs Z in the collider, unconditional.
    colliderTest.isIndependent(
            colliderData.getVariable(x.getName()),
            colliderData.getVariable(z.getName()));
    double pColliderXZ = colliderTest.getPValue();

    assertEquals(0, pChainXY, 0.01);
    assertEquals(0, pColliderXZGivenY, 0.01);
    assertEquals(0.38, pColliderXZ, 0.01);
}
Also used : IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestFisherZ(edu.cmu.tetrad.search.IndTestFisherZ) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest)

Aggregations

IndTestFisherZ (edu.cmu.tetrad.search.IndTestFisherZ): 5, SemIm (edu.cmu.tetrad.sem.SemIm): 2, SemPm (edu.cmu.tetrad.sem.SemPm): 2, PrintStream (java.io.PrintStream): 2, TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph): 1, CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 1, DataSet (edu.cmu.tetrad.data.DataSet): 1, ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 1, EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 1, Graph (edu.cmu.tetrad.graph.Graph): 1, DagToPag (edu.cmu.tetrad.search.DagToPag): 1, GFci (edu.cmu.tetrad.search.GFci): 1, IndependenceTest (edu.cmu.tetrad.search.IndependenceTest): 1, Pc (edu.cmu.tetrad.search.Pc): 1, ScoredIndTest (edu.cmu.tetrad.search.ScoredIndTest): 1, BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble): 1, GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 1, File (java.io.File): 1, FileNotFoundException (java.io.FileNotFoundException): 1, DecimalFormat (java.text.DecimalFormat): 1