Search in sources:

Example 66 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

Class TestDeltaSextadTest, method test2.

/**
 * Checks that the delta sextad test's p-value for a single sextad does not change when the
 * variables within the first triple, or within the second triple, are reordered.
 *
 * <p>Data are simulated from a graph with {@code c} clusters of {@code p} measured variables.
 * Each measured variable has as parents all {@code m} latents of its own cluster. The j-th
 * latent of each cluster points to the j-th latent of the next cluster.
 */
@Test
public void test2() {
    int c = 2;
    int m = 2;
    int p = 6;
    Graph g = new EdgeListGraph();
    List<List<Node>> varClusters = new ArrayList<>();
    List<List<Node>> latents = new ArrayList<>();
    for (int y = 0; y < c; y++) {
        varClusters.add(new ArrayList<Node>());
        latents.add(new ArrayList<Node>());
    }
    // Measured variables V1..V(c*p), p per cluster.
    int e = 0;
    for (int y = 0; y < c; y++) {
        for (int i = 0; i < p; i++) {
            GraphNode n = new GraphNode("V" + ++e);
            varClusters.get(y).add(n);
            g.addNode(n);
        }
    }
    // Latent variables L1..L(c*m), m per cluster.
    int f = 0;
    for (int y = 0; y < c; y++) {
        for (int j = 0; j < m; j++) {
            Node latent = new GraphNode("L" + ++f);
            latent.setNodeType(NodeType.LATENT);
            latents.get(y).add(latent);
            g.addNode(latent);
        }
    }
    // Every latent in a cluster is a parent of every measured variable in that cluster.
    for (int y = 0; y < c; y++) {
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                g.addDirectedEdge(latents.get(y).get(i), varClusters.get(y).get(j));
            }
        }
    }
    // Chain corresponding latents across consecutive clusters.
    for (int y = 1; y < c; y++) {
        for (int j = 0; j < m; j++) {
            g.addDirectedEdge(latents.get(y - 1).get(j), latents.get(y).get(j));
        }
    }
    SemPm pm = new SemPm(g);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);
    // Six column indices of the simulated data, assigned to the sextad positions in random order.
    List<Integer> indices = new ArrayList<>();
    indices.add(0);
    indices.add(1);
    indices.add(2);
    indices.add(4);
    indices.add(5);
    indices.add(7);
    Collections.shuffle(indices);
    int x1 = indices.get(0);
    int x2 = indices.get(1);
    int x3 = indices.get(2);
    int x4 = indices.get(3);
    int x5 = indices.get(4);
    int x6 = indices.get(5);
    DeltaSextadTest test = new DeltaSextadTest(data);
    // The p-value should be invariant to changes of order within the first three or within
    // the last three variables.
    double a = test.getPValue(new IntSextad(x1, x2, x3, x4, x5, x6));
    double b = test.getPValue(new IntSextad(x2, x3, x1, x5, x4, x6));
    assertEquals(a, b, 1e-7);
}
Also used: IntSextad(edu.cmu.tetrad.search.IntSextad) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList) DeltaSextadTest(edu.cmu.tetrad.search.DeltaSextadTest) SemPm(edu.cmu.tetrad.sem.SemPm) List(java.util.List) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 67 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

Class TestDeltaSextadTest, method testBollenExample1.

/**
 * Bollen and Ting, Confirmatory Tetrad Analysis, p. 164 (Sympathy and Anger).
 *
 * <p>Simulates 3000 rows from {@code getSem1()} and uses data columns 0-5. It builds the ten
 * sextads obtained by splitting those six variables into two triples, with column 0 always
 * in the first triple. It then evaluates the delta sextad p-value jointly for every 3-element
 * subset of the ten sextads.
 *
 * <p>NOTE(review): no assertion is made on the p-values. This only checks that
 * {@code getPValue} can be evaluated for each subset without throwing.
 */
@Test
public void testBollenExample1() {
    SemIm sem = getSem1();
    DataSet data = sem.simulateData(3000, false);
    int m1 = 0;
    int m2 = 1;
    int m3 = 2;
    int m4 = 3;
    int m5 = 4;
    int m6 = 5;
    // All ten ways of splitting {m1, ..., m6} into two triples with m1 in the first triple.
    List<IntSextad> sextads = new ArrayList<>();
    sextads.add(new IntSextad(m1, m2, m3, m4, m5, m6));
    sextads.add(new IntSextad(m1, m2, m4, m3, m5, m6));
    sextads.add(new IntSextad(m1, m2, m5, m3, m4, m6));
    sextads.add(new IntSextad(m1, m2, m6, m3, m4, m5));
    sextads.add(new IntSextad(m1, m3, m4, m2, m5, m6));
    sextads.add(new IntSextad(m1, m3, m5, m2, m4, m6));
    sextads.add(new IntSextad(m1, m3, m6, m2, m4, m5));
    sextads.add(new IntSextad(m1, m4, m5, m2, m3, m6));
    sextads.add(new IntSextad(m1, m4, m6, m2, m3, m5));
    sextads.add(new IntSextad(m1, m5, m6, m2, m3, m4));
    DeltaSextadTest test = new DeltaSextadTest(data);
    int numSextads = 3;
    ChoiceGenerator gen = new ChoiceGenerator(sextads.size(), numSextads);
    int[] choice;
    while ((choice = gen.next()) != null) {
        IntSextad[] _sextads = new IntSextad[numSextads];
        for (int i = 0; i < numSextads; i++) {
            _sextads[i] = sextads.get(choice[i]);
        }
        // The p-value is not checked; see the NOTE in this method's Javadoc.
        test.getPValue(_sextads);
    }
}
Also used: IntSextad(edu.cmu.tetrad.search.IntSextad) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList) DeltaSextadTest(edu.cmu.tetrad.search.DeltaSextadTest) ChoiceGenerator(edu.cmu.tetrad.util.ChoiceGenerator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 68 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

Class TestDeltaSextadTest, method getSem1.

/**
 * Builds a SEM in which two latent variables, {@code l1} and {@code l2}, are both parents of
 * each of the eight measured variables {@code X1}..{@code X8}. The result is instantiated
 * with a default-constructed {@link Parameters}.
 *
 * @return a freshly instantiated {@link SemIm} over that graph
 */
private SemIm getSem1() {
    Graph graph = new EdgeListGraph();

    Node latent1 = new GraphNode("l1");
    Node latent2 = new GraphNode("l2");
    latent1.setNodeType(NodeType.LATENT);
    latent2.setNodeType(NodeType.LATENT);

    int numMeasures = 8;
    List<Node> indicators = new ArrayList<>();
    for (int k = 1; k <= numMeasures; k++) {
        indicators.add(new GraphNode("X" + k));
    }

    graph.addNode(latent1);
    graph.addNode(latent2);
    // Each indicator is added, then wired up from both latents, in X1..X8 order.
    for (Node indicator : indicators) {
        graph.addNode(indicator);
        graph.addDirectedEdge(latent1, indicator);
        graph.addDirectedEdge(latent2, indicator);
    }

    return new SemIm(new SemPm(graph), new Parameters());
}
Also used : Parameters(edu.cmu.tetrad.util.Parameters) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm)

Example 69 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

Class TestPc, method printStatsPcRegression.

/**
 * Runs one search algorithm on {@code numRuns} random simulations and prints the averages.
 * It measures adjacency precision and recall for the edges incident to a single target
 * variable, and returns the two averages.
 *
 * <p>Each run proceeds as follows:
 * <ul>
 *   <li>Generate a random graph over {@code numMeasures + numLatents} variables with
 *       {@code (int) (edgeFactor * (numMeasures + numLatents))} edges. {@code numLatents} is
 *       passed to the generator, presumably as the number of latent variables.</li>
 *   <li>Simulate {@code sampleSize} rows without latent columns from a {@link SemIm} on that
 *       graph.</li>
 *   <li>Run the algorithm selected by {@code t}.</li>
 *   <li>Compare the target's adjacencies in the trimmed output with its adjacencies in the PAG
 *       produced by {@code DagToPag} from the true graph. Both precision and recall use that
 *       PAG as the reference.</li>
 * </ul>
 *
 * @param algorithms      display names indexed by {@code t}; used only in the printed report
 * @param t               algorithm selector: 0 Pc, 1 Cpc, 2 Fges, 3 Fci, 4 GFci, 5 Rfci,
 *                        6 Cfci, 7 regression graph for the target
 * @param directed        currently unused
 * @param numRuns         number of simulations to average over
 * @param alpha           significance level for the Fisher Z independence test
 * @param penaltyDiscount penalty discount for the SEM BIC score (used by cases 2 and 4)
 * @param numMeasures     number of measured variables
 * @param numLatents      number of latent variables
 * @param edgeFactor      edges per variable in the random graph
 * @param sampleSize      number of simulated rows per run
 * @return {@code {average adjacency precision, average adjacency recall}}. A run whose
 *         precision or recall is 0/0 adds nothing to that sum but is still counted in the
 *         denominator. Both values are NaN when {@code numRuns} is 0.
 * @throws IllegalStateException if {@code t} is not in 0..7, or no measured variable exists
 */
private double[] printStatsPcRegression(String[] algorithms, int t, boolean directed, int numRuns, double alpha, double penaltyDiscount, int numMeasures, int numLatents, double edgeFactor, int sampleSize) {
    NumberFormat nf = new DecimalFormat("0.00");
    double sumAdjPrecision = 0.0;
    double sumAdjRecall = 0.0;
    int count = 0;
    for (int i = 0; i < numRuns; i++) {
        int numEdges = (int) (edgeFactor * (numMeasures + numLatents));
        List<Node> nodes = new ArrayList<>();
        for (int r = 0; r < numMeasures + numLatents; r++) {
            nodes.add(new ContinuousVariable("X" + (r + 1)));
        }
        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(sampleSize, false);
        // Reference graph for scoring. Alternatives used previously: the DAG itself, or
        // new Pc(new IndTestDSep(dag)).search().
        Graph comparison = new DagToPag(dag).convert();
        IndTestFisherZ test = new IndTestFisherZ(data, alpha);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data));
        score.setPenaltyDiscount(penaltyDiscount);

        // The target is the first variable that is measured in the generated graph. The node
        // type is read from the graph's own node with that name, in case the generator does
        // not mark the objects in `nodes` themselves.
        Node target = null;
        for (Node node : nodes) {
            Node dagNode = dag.getNode(node.getName());
            if (dagNode != null && dagNode.getNodeType() == NodeType.MEASURED) {
                target = node;
                break;
            }
        }
        if (target == null) {
            throw new IllegalStateException("No measured variable found to use as the target.");
        }
        String targetName = target.getName();

        Graph out;
        switch (t) {
            case 0:
                out = new Pc(test).search();
                break;
            case 1:
                out = new Cpc(test).search();
                break;
            case 2:
                out = new Fges(score).search();
                break;
            case 3:
                out = new Fci(test).search();
                break;
            case 4:
                out = new GFci(test, score).search();
                break;
            case 5:
                out = new Rfci(test).search();
                break;
            case 6:
                out = new Cfci(test).search();
                break;
            case 7:
                out = getRegressionGraph(data, target);
                break;
            default:
                throw new IllegalStateException("Unrecognized algorithm index: " + t);
        }
        out = trim(out, out.getNode(targetName));

        // Move the output onto the DAG's node objects, and make sure every DAG node is
        // present, so that the name lookups below always find the nodes in `out`.
        out = GraphUtils.replaceNodes(out, dag.getNodes());
        for (Node node : dag.getNodes()) {
            if (!out.containsNode(node)) {
                out.addNode(node);
            }
        }

        // The output and reference graphs need not share node objects, so resolve the target
        // and each neighbor by name in the graph being queried.
        Node outTarget = out.getNode(targetName);
        Node refTarget = comparison.getNode(targetName);

        int adjTp = 0;
        int adjFp = 0;
        int adjFn = 0;
        for (Node node : out.getAdjacentNodes(outTarget)) {
            Node refNode = comparison.getNode(node.getName());
            if (refTarget != null && refNode != null && comparison.isAdjacentTo(refTarget, refNode)) {
                adjTp++;
            } else {
                adjFp++;
            }
        }
        // False negatives are counted against the same reference graph as true positives,
        // so that precision and recall are measured consistently.
        List<Node> refAdjacents = refTarget == null
                ? new ArrayList<Node>()
                : comparison.getAdjacentNodes(refTarget);
        for (Node node : refAdjacents) {
            Node outNode = out.getNode(node.getName());
            if (outNode == null || !out.isAdjacentTo(outTarget, outNode)) {
                adjFn++;
            }
        }

        double adjPrecision = adjTp / (double) (adjTp + adjFp);
        double adjRecall = adjTp / (double) (adjTp + adjFn);
        // 0/0 yields NaN. Such runs contribute nothing to the sum but are still counted below.
        if (!Double.isNaN(adjPrecision)) {
            sumAdjPrecision += adjPrecision;
        }
        if (!Double.isNaN(adjRecall)) {
            sumAdjRecall += adjRecall;
        }
        count++;
    }
    double avgAdjPrecision = sumAdjPrecision / (double) count;
    double avgAdjRecall = sumAdjRecall / (double) count;
    double[] ret = new double[] { avgAdjPrecision, avgAdjRecall };
    System.out.println();
    System.out.println(algorithms[t] + " adj precision " + nf.format(avgAdjPrecision));
    System.out.println(algorithms[t] + " adj recall " + nf.format(avgAdjRecall));
    return ret;
}
Also used : DecimalFormat(java.text.DecimalFormat) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) NumberFormat(java.text.NumberFormat)

Example 70 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

Class TestPc, method testPcStable2.

/**
 * Checks that PcStableMax returns the same pattern after the columns of the data set are
 * reordered.
 *
 * <p>The RandomUtil seed is fixed first. The random calls that follow (graph generation,
 * simulation, column reordering) happen in a fixed order, which presumably makes the test
 * reproducible.
 */
@Test
public void testPcStable2() {
    RandomUtil.getInstance().setSeed(1450030184196L);

    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 10; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }
    Graph trueGraph = GraphUtils.randomGraph(variables, 0, 10, 30, 15, 15, false);
    SemIm im = new SemIm(new SemPm(trueGraph));
    DataSet data = im.simulateData(200, false);

    TetradLogger.getInstance().setForceLog(false);

    PcStableMax search = new PcStableMax(new IndTestFisherZ(data, 0.05));
    search.setVerbose(false);
    Graph expected = search.search();

    final int numReorderings = 1;
    for (int trial = 0; trial < numReorderings; trial++) {
        DataSet reordered = DataUtils.reorderColumns(data);
        PcStableMax reorderedSearch = new PcStableMax(new IndTestFisherZ(reordered, 0.05));
        reorderedSearch.setVerbose(false);
        Graph actual = reorderedSearch.search();
        assertTrue(expected.equals(actual));
    }
}
Also used : SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Aggregations

SemIm (edu.cmu.tetrad.sem.SemIm)81 SemPm (edu.cmu.tetrad.sem.SemPm)71 Test (org.junit.Test)46 DataSet (edu.cmu.tetrad.data.DataSet)28 ArrayList (java.util.ArrayList)28 Graph (edu.cmu.tetrad.graph.Graph)26 Node (edu.cmu.tetrad.graph.Node)19 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)16 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)16 SemEstimator (edu.cmu.tetrad.sem.SemEstimator)15 Dag (edu.cmu.tetrad.graph.Dag)10 DMSearch (edu.cmu.tetrad.search.DMSearch)9 StandardizedSemIm (edu.cmu.tetrad.sem.StandardizedSemIm)9 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)7 NumberFormat (java.text.NumberFormat)7 GraphNode (edu.cmu.tetrad.graph.GraphNode)5 IndependenceTest (edu.cmu.tetrad.search.IndependenceTest)4 DecimalFormat (java.text.DecimalFormat)4 List (java.util.List)4 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)3