
Example 1 with Pc

Use of edu.cmu.tetrad.search.Pc in project tetrad by cmu-phil.

From class IambnPc, method findMb.

public List<Node> findMb(String targetName) {
    Node target = getVariableForName(targetName);
    List<Node> cmb = new LinkedList<>();
    Pc pc = new Pc(independenceTest);
    boolean cont = true;
    // Forward phase.
    while (cont) {
        cont = false;
        List<Node> remaining = new LinkedList<>(variables);
        remaining.removeAll(cmb);
        remaining.remove(target);
        double strength = Double.NEGATIVE_INFINITY;
        Node f = null;
        for (Node v : remaining) {
            if (v == target) {
                continue;
            }
            double _strength = associationStrength(v, target, cmb);
            if (_strength > strength) {
                strength = _strength;
                f = v;
            }
        }
        if (f == null) {
            break;
        }
        if (!independenceTest.isIndependent(f, target, cmb)) {
            cmb.add(f);
            cont = true;
        }
    }
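    // Backward phase: run PC over the candidates plus the target, then
    // trim the resulting graph to the nodes of the target's Markov blanket.
    cmb.add(target);
    Graph graph = pc.search(cmb);
    MbUtils.trimToMbNodes(graph, target, false);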
    // cmb = DataGraphUtils.markovBlanketDag(target, graph).getNodes();
    cmb = graph.getNodes();
    cmb.remove(target);
    return cmb;
}
Also used : Pc(edu.cmu.tetrad.search.Pc) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node) LinkedList(java.util.LinkedList)
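The forward phase of `findMb` repeatedly picks the not-yet-selected variable with the strongest association to the target and admits it only if it is still dependent on the target given the current set; the loop stops as soon as that best candidate tests independent. A minimal stdlib-only sketch of that loop follows. The `CiTest` and `Association` interfaces are hypothetical stand-ins for `independenceTest.isIndependent` and `associationStrength`, and nodes are plain strings rather than Tetrad `Node` objects. The PC-based backward phase is not reproduced.

```java
import java.util.*;

class IambForward {
    /** Hypothetical CI oracle: true if x is judged independent of y given z. */
    interface CiTest {
        boolean isIndependent(String x, String y, Set<String> z);
    }

    /** Hypothetical association score: larger means x is more strongly associated with y given z. */
    interface Association {
        double strength(String x, String y, Set<String> z);
    }

    /** Grows a candidate Markov blanket of target, mirroring the forward loop of findMb. */
    static List<String> forwardPhase(List<String> variables, String target,
                                     CiTest test, Association assoc) {
        List<String> cmb = new ArrayList<>();
        boolean changed = true;
        while (changed) {
            changed = false;
            Set<String> given = new LinkedHashSet<>(cmb);
            String best = null;
            double bestStrength = Double.NEGATIVE_INFINITY;
            // Pick the remaining variable most strongly associated with the target.
            for (String v : variables) {
                if (v.equals(target) || cmb.contains(v)) {
                    continue;
                }
                double s = assoc.strength(v, target, given);
                if (s > bestStrength) {
                    bestStrength = s;
                    best = v;
                }
            }
            if (best == null) {
                break; // no scorable candidate remains
            }
            // Admit it only if it is still dependent on the target given the current set.
            if (!test.isIndependent(best, target, given)) {
                cmb.add(best);
                changed = true;
            }
        }
        return cmb;
    }
}
```

With a toy oracle in which only A and B depend on the target T, the loop admits A and B (in order of association) and stops when C, the next strongest, tests independent.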

Example 2 with Pc

Use of edu.cmu.tetrad.search.Pc in project tetrad by cmu-phil.

From class YeastPcCcdSearchWrapper, method PCAccuracy.

private static int[] PCAccuracy(double alpha, int ngenes, DataSet cds, IKnowledge bk, int[][] yeastReg, List names, DataOutputStream d, boolean v) {
    int[] falsePosNeg = new int[2];
    IndTestCramerT indTestCramerT = new IndTestCramerT(cds, alpha);
    Pc pcs = new Pc(indTestCramerT);
    pcs.setKnowledge(bk);
    Graph pcModel = pcs.search();
    int falsePositives = 0;
    int falseNegatives = 0;
    int[][] pcModelAdj = new int[ngenes][ngenes];
    /*
        for(int i = 0; i < ngenes; i++) {
          String namei1 = (String) names.get(i);
              String namei2 = (String) names.get(i + ngenes);
              for(int j = 0; j < ngenes; j++) {

                String namej1 = (String) names.get(j);
                String namej2 = (String) names.get(j + ngenes);

                //Set adjacency matrix for PC search
                if(pcModel.isAdjacent(cds.get(namei1).getVariable(),
                                     cds.get(namej1).getVariable()) ||
                  pcModel.isAdjacent(cds.get(namei1).getVariable(),
                                     cds.get(namej2).getVariable()) ||
                  pcModel.isAdjacent(cds.get(namei2).getVariable(),
                                     cds.get(namej1).getVariable()) ||
                  pcModel.isAdjacent(cds.get(namei2).getVariable(),
                                     cds.get(namej2).getVariable()))
                  pcModelAdj[i][j] = 1;
                else pcModelAdj[i][j] = 0;
              }
        }
        */
    int nvariables = names.size();
    for (int i = 0; i < nvariables; i++) {
        String namei = (String) names.get(i);
        for (int j = 0; j < nvariables; j++) {
            String namej = (String) names.get(j);
            pcModelAdj[i][j] = 0;
            Node vari = indTestCramerT.getVariable(namei);
            Node varj = indTestCramerT.getVariable(namej);
            if (!pcModel.isAdjacentTo(vari, varj)) {
                continue;
            }
            pcModelAdj[i][j] = 1;
        }
    }
    for (int i = 0; i < ngenes; i++) {
        for (int j = i; j < ngenes; j++) {
            if (yeastReg[i][j] == 0 && pcModelAdj[i][j] == 1) {
                falsePositives++;
            }
            if (yeastReg[i][j] == 1 && pcModelAdj[i][j] == 0) {
                falseNegatives++;
            }
        }
    }
    falsePosNeg[0] = falsePositives;
    falsePosNeg[1] = falseNegatives;
    if (v) {
        try {
            d.writeBytes("\n \n");
            d.writeBytes("  Results of PC search with alpha = " + alpha);
            d.writeBytes("  false+ " + falsePositives + "\t");
            d.writeBytes("false- " + falseNegatives + "\n");
            d.writeBytes("  Adjacency matrix of estimated model:  \n");
        } catch (Exception e) {
            e.printStackTrace();
        }
        printAdjMatrix(pcModelAdj, names, d);
    }
    return falsePosNeg;
}
Also used : Pc(edu.cmu.tetrad.search.Pc) Graph(edu.cmu.tetrad.graph.Graph) IndTestCramerT(edu.cmu.tetrad.search.IndTestCramerT) Node(edu.cmu.tetrad.graph.Node)
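The error count in `PCAccuracy` compares the estimated adjacency matrix with the reference matrix `yeastReg` over the upper triangle only (`j` starts at `i`), so each unordered pair is counted once and the diagonal is included. The sketch below isolates that comparison for two square 0/1 matrices; the class and method names are hypothetical.

```java
import java.util.*;

class AdjacencyErrors {
    /**
     * Returns {falsePositives, falseNegatives} for two square 0/1 adjacency matrices,
     * visiting only j >= i (upper triangle plus diagonal), as in PCAccuracy.
     */
    static int[] count(int[][] truth, int[][] estimate) {
        if (truth.length != estimate.length) {
            throw new IllegalArgumentException("matrices must have the same size");
        }
        int falsePositives = 0;
        int falseNegatives = 0;
        for (int i = 0; i < truth.length; i++) {
            for (int j = i; j < truth.length; j++) {
                if (truth[i][j] == 0 && estimate[i][j] == 1) {
                    falsePositives++; // edge estimated but absent from the reference
                }
                if (truth[i][j] == 1 && estimate[i][j] == 0) {
                    falseNegatives++; // reference edge missed by the estimate
                }
            }
        }
        return new int[] {falsePositives, falseNegatives};
    }
}
```

Because only `j >= i` is visited, a symmetric estimate contributes each spurious or missing edge once; an asymmetric estimate would have its lower triangle ignored.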

Example 3 with Pc

Use of edu.cmu.tetrad.search.Pc in project tetrad by cmu-phil.

From class PerformanceTestsDan, method testIdaOutputForDan.

private void testIdaOutputForDan() {
    int numRuns = 100;
    for (int run = 0; run < numRuns; run++) {
        double alphaGFci = 0.01;
        double alphaPc = 0.01;
        int penaltyDiscount = 1;
        int depth = 3;
        int maxPathLength = 3;
        final int numVars = 15;
        final double edgesPerNode = 1.0;
        final int numCases = 1000;
        // final int numLatents = RandomUtil.getInstance().nextInt(3) + 1;
        final int numLatents = 6;
        // writeToFile = false;
        PrintStream out1;
        PrintStream out2;
        PrintStream out3;
        PrintStream out4;
        PrintStream out5;
        PrintStream out6;
        PrintStream out7;
        PrintStream out8;
        PrintStream out9;
        PrintStream out10;
        PrintStream out11;
        PrintStream out12;
        File dir0 = new File("gfci.output");
        dir0.mkdirs();
        File dir = new File(dir0, "" + (run + 1));
        dir.mkdir();
        try {
            out1 = new PrintStream(new File(dir, "hyperparameters.txt"));
            out2 = new PrintStream(new File(dir, "variables.txt"));
            out3 = new PrintStream(new File(dir, "dag.long.txt"));
            out4 = new PrintStream(new File(dir, "dag.matrix.txt"));
            out5 = new PrintStream(new File(dir, "coef.matrix.txt"));
            out6 = new PrintStream(new File(dir, "pag.long.txt"));
            out7 = new PrintStream(new File(dir, "pag.matrix.txt"));
            out8 = new PrintStream(new File(dir, "pattern.long.txt"));
            out9 = new PrintStream(new File(dir, "pattern.matrix.txt"));
            out10 = new PrintStream(new File(dir, "data.txt"));
            out11 = new PrintStream(new File(dir, "true.pag.long.txt"));
            out12 = new PrintStream(new File(dir, "true.pag.matrix.txt"));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }
        out1.println("Num vars = " + numVars);
        out1.println("Num edges = " + (int) (numVars * edgesPerNode));
        out1.println("Num cases = " + numCases);
        out1.println("Alpha for PC = " + alphaPc);
        out1.println("Alpha for GFCI = " + alphaGFci);
        out1.println("Penalty discount = " + penaltyDiscount);
        out1.println("Depth = " + depth);
        out1.println("Maximum reachable path length for dsep search and discriminating paths = " + maxPathLength);
        List<Node> vars = new ArrayList<>();
        for (int i = 0; i < numVars; i++) vars.add(new GraphNode("X" + (i + 1)));
        // Graph dag = DataGraphUtils.randomDagQuick2(varsWithLatents, 0, (int) (varsWithLatents.size() * edgesPerNode));
        Graph dag = GraphUtils.randomGraph(vars, 0, (int) (vars.size() * edgesPerNode), 5, 5, 5, false);
        GraphUtils.fixLatents1(numLatents, dag);
        // List<Node> varsWithLatents = new ArrayList<Node>();
        // 
        // Graph dag = getLatentGraph(_vars, varsWithLatents, edgesPerNode, numLatents);
        out3.println(dag);
        printDanMatrix(vars, dag, out4);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        NumberFormat nf = new DecimalFormat("0.0000");
        for (int i = 0; i < vars.size(); i++) {
            for (Node var : vars) {
                if (im.existsEdgeCoef(var, vars.get(i))) {
                    double coef = im.getEdgeCoef(var, vars.get(i));
                    out5.print(nf.format(coef) + "\t");
                } else {
                    out5.print(nf.format(0) + "\t");
                }
            }
            out5.println();
        }
        out5.println();
        String vars_temp = vars.toString();
        vars_temp = vars_temp.replace("[", "");
        vars_temp = vars_temp.replace("]", "");
        vars_temp = vars_temp.replace("X", "");
        out2.println(vars_temp);
        List<Node> _vars = new ArrayList<>();
        for (Node node : vars) {
            if (node.getNodeType() == NodeType.MEASURED) {
                _vars.add(node);
            }
        }
        String _vars_temp = _vars.toString();
        _vars_temp = _vars_temp.replace("[", "");
        _vars_temp = _vars_temp.replace("]", "");
        _vars_temp = _vars_temp.replace("X", "");
        out2.println(_vars_temp);
        DataSet fullData = im.simulateData(numCases, false);
        DataSet data = DataUtils.restrictToMeasured(fullData);
        ICovarianceMatrix cov = new CovarianceMatrix(data);
        final IndTestFisherZ independenceTestGFci = new IndTestFisherZ(cov, alphaGFci);
        final edu.cmu.tetrad.search.SemBicScore scoreGfci = new edu.cmu.tetrad.search.SemBicScore(cov);
        out6.println("GFCI.PAG_of_the_true_DAG");
        GFci gFci = new GFci(independenceTestGFci, scoreGfci);
        gFci.setVerbose(false);
        gFci.setMaxDegree(depth);
        gFci.setMaxPathLength(maxPathLength);
        // gFci.setPossibleDsepSearchDone(true);
        gFci.setCompleteRuleSetUsed(true);
        Graph pag = gFci.search();
        out6.println(pag);
        printDanMatrix(_vars, pag, out7);
        out8.println("Pattern_of_the_true_DAG OVER MEASURED VARIABLES");
        final IndTestFisherZ independencePc = new IndTestFisherZ(cov, alphaPc);
        Pc pc = new Pc(independencePc);
        pc.setVerbose(false);
        pc.setDepth(depth);
        Graph pattern = pc.search();
        out8.println(pattern);
        printDanMatrix(_vars, pattern, out9);
        out10.println(data);
        out11.println("True PAG_of_the_true_DAG");
        final Graph truePag = new DagToPag(dag).convert();
        out11.println(truePag);
        printDanMatrix(_vars, truePag, out12);
        out1.close();
        out2.close();
        out3.close();
        out4.close();
        out5.close();
        out6.close();
        out7.close();
        out8.close();
        out9.close();
        out10.close();
        out11.close();
        out12.close();
    }
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) DecimalFormat(java.text.DecimalFormat) FileNotFoundException(java.io.FileNotFoundException) ArrayList(java.util.ArrayList) Pc(edu.cmu.tetrad.search.Pc) SemPm(edu.cmu.tetrad.sem.SemPm) PrintStream(java.io.PrintStream) IndTestFisherZ(edu.cmu.tetrad.search.IndTestFisherZ) GFci(edu.cmu.tetrad.search.GFci) CovarianceMatrix(edu.cmu.tetrad.data.CovarianceMatrix) DagToPag(edu.cmu.tetrad.search.DagToPag) File(java.io.File) SemIm(edu.cmu.tetrad.sem.SemIm) NumberFormat(java.text.NumberFormat)
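In the PC call above, `pc.setDepth(depth)` (with `depth = 3`) appears, going by its name and usage, to cap the size of the conditioning sets used while removing edges. The sketch below illustrates that adjacency (skeleton) phase under stated assumptions: a hypothetical `CiTest` oracle over string-named nodes, a negative depth meaning "no limit", and no orientation step. It follows the general PC adjacency search as commonly described, not Tetrad's `Pc` implementation.

```java
import java.util.*;

class PcSkeleton {
    /** Hypothetical CI oracle: true if x is judged independent of y given z. */
    interface CiTest {
        boolean isIndependent(String x, String y, Set<String> z);
    }

    final Map<String, Set<String>> adjacencies = new LinkedHashMap<>();
    final Map<String, Set<String>> sepsets = new HashMap<>(); // key "x|y" with x < y

    /** Removes edges using conditioning sets of size 0, 1, ..., depth (depth < 0: no limit). */
    static PcSkeleton search(List<String> nodes, CiTest test, int depth) {
        PcSkeleton g = new PcSkeleton();
        for (String n : nodes) {
            Set<String> others = new TreeSet<>(nodes);
            others.remove(n);
            g.adjacencies.put(n, others); // start from the complete undirected graph
        }
        for (int d = 0; depth < 0 || d <= depth; d++) {
            boolean anyPairTestable = false;
            for (String x : nodes) {
                for (String y : new ArrayList<>(g.adjacencies.get(x))) {
                    // Condition on size-d subsets of x's current neighbours other than y.
                    List<String> candidates = new ArrayList<>(g.adjacencies.get(x));
                    candidates.remove(y);
                    if (candidates.size() < d) {
                        continue;
                    }
                    anyPairTestable = true;
                    for (Set<String> z : subsets(candidates, d)) {
                        if (test.isIndependent(x, y, z)) {
                            g.adjacencies.get(x).remove(y);
                            g.adjacencies.get(y).remove(x);
                            g.sepsets.put(key(x, y), z);
                            break;
                        }
                    }
                }
            }
            if (!anyPairTestable) {
                break; // no adjacency has enough neighbours for a larger set
            }
        }
        return g;
    }

    static String key(String x, String y) {
        return x.compareTo(y) < 0 ? x + "|" + y : y + "|" + x;
    }

    /** All size-k subsets of items, in index order. */
    static List<Set<String>> subsets(List<String> items, int k) {
        List<Set<String>> out = new ArrayList<>();
        collect(items, k, 0, new ArrayList<>(), out);
        return out;
    }

    private static void collect(List<String> items, int k, int start,
                                List<String> current, List<Set<String>> out) {
        if (current.size() == k) {
            out.add(new LinkedHashSet<>(current));
            return;
        }
        for (int i = start; i < items.size(); i++) {
            current.add(items.get(i));
            collect(items, k, i + 1, current, out);
            current.remove(current.size() - 1);
        }
    }
}
```

For a chain A -> B -> C, where A and C are independent only given B, an unlimited search removes A - C with separating set {B}, while a depth of 0 never tries a conditioning set large enough to remove it.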

Example 4 with Pc

Use of edu.cmu.tetrad.search.Pc in project tetrad by cmu-phil.

From class TestFges, method testFromGraph.

@Test
public void testFromGraph() {
    int numNodes = 20;
    int numIterations = 20;
    for (int i = 0; i < numIterations; i++) {
        // System.out.println("Iteration " + (i + 1));
        Graph dag = GraphUtils.randomDag(numNodes, 0, numNodes, 10, 10, 10, false);
        Fges fges = new Fges(new GraphScore(dag));
        fges.setFaithfulnessAssumed(true);
        Graph pattern1 = fges.search();
        Graph pattern2 = new Pc(new IndTestDSep(dag)).search();
        // System.out.println(pattern2);
        assertEquals(pattern2, pattern1);
    }
}
Also used : RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) Pc(edu.cmu.tetrad.search.Pc) Fges(edu.cmu.tetrad.search.Fges) SemBicDTest(edu.cmu.tetrad.algcomparison.independence.SemBicDTest) SemBicTest(edu.cmu.tetrad.algcomparison.independence.SemBicTest) Test(org.junit.Test)
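`testFromGraph` draws a fresh random DAG on every iteration with `GraphUtils.randomDag(numNodes, 0, numNodes, 10, 10, 10, false)`. One common way to guarantee acyclicity, sketched here, is to fix a random node order and only draw edges that point forward in it. This is an illustrative assumption, not Tetrad's generator; the Tetrad call's remaining integer arguments (10, 10, 10) appear to be degree limits, and this sketch ignores them. Nodes are plain integers and the class name is hypothetical.

```java
import java.util.*;

class RandomDagSketch {
    /**
     * Draws numEdges distinct directed edges {from, to} over nodes 0..numNodes-1.
     * Acyclic by construction: every edge points forward in a random node order.
     */
    static List<int[]> randomDag(int numNodes, int numEdges, Random rnd) {
        int maxEdges = numNodes * (numNodes - 1) / 2;
        if (numEdges < 0 || numEdges > maxEdges) {
            throw new IllegalArgumentException("numEdges must be in [0, " + maxEdges + "]");
        }
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numNodes; i++) {
            order.add(i);
        }
        Collections.shuffle(order, rnd);
        // Every pair that respects the order is a candidate edge; keep a random subset.
        List<int[]> candidates = new ArrayList<>();
        for (int a = 0; a < numNodes; a++) {
            for (int b = a + 1; b < numNodes; b++) {
                candidates.add(new int[] {order.get(a), order.get(b)});
            }
        }
        Collections.shuffle(candidates, rnd);
        return new ArrayList<>(candidates.subList(0, numEdges));
    }

    /** Kahn's algorithm: true if the edge list contains no directed cycle. */
    static boolean isAcyclic(int numNodes, List<int[]> edges) {
        int[] inDegree = new int[numNodes];
        List<List<Integer>> children = new ArrayList<>();
        for (int i = 0; i < numNodes; i++) {
            children.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            children.get(e[0]).add(e[1]);
            inDegree[e[1]]++;
        }
        Deque<Integer> ready = new ArrayDeque<>();
        for (int i = 0; i < numNodes; i++) {
            if (inDegree[i] == 0) {
                ready.add(i);
            }
        }
        int visited = 0;
        while (!ready.isEmpty()) {
            int u = ready.poll();
            visited++;
            for (int v : children.get(u)) {
                if (--inDegree[v] == 0) {
                    ready.add(v);
                }
            }
        }
        return visited == numNodes;
    }
}
```

Passing a seeded `Random` makes each iteration reproducible, which helps when an `assertEquals` between the two patterns fails.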

Example 5 with Pc

Use of edu.cmu.tetrad.search.Pc in project tetrad by cmu-phil.

From class TestFges, method testFromGraphSimpleFges.

@Test
public void testFromGraphSimpleFges() {
    // This may fail if faithfulness is assumed but should pass if not.
    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");
    Graph g = new EdgeListGraph();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);
    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x1, x3);
    g.addDirectedEdge(x4, x2);
    g.addDirectedEdge(x4, x3);
    Graph pattern1 = new Pc(new IndTestDSep(g)).search();
    Fges fges = new Fges(new GraphScore(g));
    fges.setFaithfulnessAssumed(true);
    Graph pattern2 = fges.search();
    // System.out.println(pattern1);
    // System.out.println(pattern2);
    assertEquals(pattern1, pattern2);
}
Also used : RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) Pc(edu.cmu.tetrad.search.Pc) Fges(edu.cmu.tetrad.search.Fges) SemBicDTest(edu.cmu.tetrad.algcomparison.independence.SemBicDTest) SemBicTest(edu.cmu.tetrad.algcomparison.independence.SemBicTest) Test(org.junit.Test)
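Both TestFges examples run PC with `IndTestDSep` as its independence test, which appears to answer each query by checking d-separation in the supplied graph. A minimal sketch of one standard d-separation check (take the ancestral subgraph of the queried nodes, moralize it, delete the conditioning set, test connectivity) is shown below; it is not Tetrad's implementation, and nodes are plain strings. For the graph in `testFromGraphSimpleFges`, X1 and X4 are d-separated by the empty set but not given X2, because X2 is a common child (collider) of both.

```java
import java.util.*;

class DSeparation {
    private final Map<String, Set<String>> parents = new HashMap<>();

    void addNode(String n) {
        parents.putIfAbsent(n, new HashSet<>());
    }

    void addDirectedEdge(String from, String to) {
        addNode(from);
        addNode(to);
        parents.get(to).add(from);
    }

    /** True if x and y are d-separated given z. Assumes all nodes exist and x, y are not in z. */
    boolean isDSeparated(String x, String y, Set<String> z) {
        // 1. Collect x, y, the nodes of z, and all of their ancestors.
        Set<String> ancestral = new HashSet<>();
        Deque<String> stack = new ArrayDeque<>(z);
        stack.push(x);
        stack.push(y);
        while (!stack.isEmpty()) {
            String n = stack.pop();
            if (ancestral.add(n)) {
                stack.addAll(parents.get(n));
            }
        }
        // 2. Moralize the ancestral subgraph: link each node to its parents and marry co-parents.
        Map<String, Set<String>> moral = new HashMap<>();
        for (String n : ancestral) {
            moral.put(n, new HashSet<>());
        }
        for (String child : ancestral) {
            List<String> ps = new ArrayList<>(parents.get(child));
            for (int i = 0; i < ps.size(); i++) {
                link(moral, child, ps.get(i));
                for (int j = i + 1; j < ps.size(); j++) {
                    link(moral, ps.get(i), ps.get(j));
                }
            }
        }
        // 3. Delete z (pre-mark it as seen) and check whether y is reachable from x.
        Set<String> seen = new HashSet<>(z);
        Deque<String> queue = new ArrayDeque<>();
        queue.add(x);
        seen.add(x);
        while (!queue.isEmpty()) {
            String n = queue.poll();
            if (n.equals(y)) {
                return false;
            }
            for (String m : moral.get(n)) {
                if (seen.add(m)) {
                    queue.add(m);
                }
            }
        }
        return true;
    }

    private static void link(Map<String, Set<String>> g, String a, String b) {
        g.get(a).add(b);
        g.get(b).add(a);
    }
}
```

These are exactly the facts PC needs to recover the pattern of that four-node graph: X1 and X4 are separated by the empty set, so X2 and X3 become colliders, and X2 and X3 are separated given {X1, X4}.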

Aggregations

Pc (edu.cmu.tetrad.search.Pc): 6
Graph (edu.cmu.tetrad.graph.Graph): 3
RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 2
SemBicDTest (edu.cmu.tetrad.algcomparison.independence.SemBicDTest): 2
SemBicTest (edu.cmu.tetrad.algcomparison.independence.SemBicTest): 2
Node (edu.cmu.tetrad.graph.Node): 2
Fges (edu.cmu.tetrad.search.Fges): 2
Test (org.junit.Test): 2
CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 1
DataSet (edu.cmu.tetrad.data.DataSet): 1
ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 1
DagToPag (edu.cmu.tetrad.search.DagToPag): 1
GFci (edu.cmu.tetrad.search.GFci): 1
IndTestCramerT (edu.cmu.tetrad.search.IndTestCramerT): 1
IndTestDSep (edu.cmu.tetrad.search.IndTestDSep): 1
IndTestFisherZ (edu.cmu.tetrad.search.IndTestFisherZ): 1
IndependenceTest (edu.cmu.tetrad.search.IndependenceTest): 1
SemIm (edu.cmu.tetrad.sem.SemIm): 1
SemPm (edu.cmu.tetrad.sem.SemPm): 1
File (java.io.File): 1