Search in sources :

Example 61 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

The class TestSemXml, method sampleSemIm1.

/**
 * Builds a randomly structured SEM instantiated model over five continuous variables
 * named X1 through X5, for use as a test fixture.
 *
 * @return a new {@link SemIm} over a random DAG on X1..X5
 */
private static SemIm sampleSemIm1() {
    List<Node> variables = new ArrayList<>();
    int index = 1;
    while (index <= 5) {
        variables.add(new ContinuousVariable("X" + index));
        index++;
    }
    // The numeric arguments are passed unchanged to GraphUtils.randomGraph.
    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 30, 15, 15, true));
    return new SemIm(new SemPm(dag));
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) Dag(edu.cmu.tetrad.graph.Dag) SemIm(edu.cmu.tetrad.sem.SemIm)

Example 62 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

The class TestStatUtils, method testConditionalCorrelation.

/**
 * Checks that StatUtils' partial statistics, computed from a covariance matrix with no
 * conditioning variables passed, agree to within 0.1 with the plain sample correlation,
 * covariance, variance and standard deviation of the first two simulated columns.
 */
@Test
public void testConditionalCorrelation() {
    RandomUtil.getInstance().setSeed(30299533L);
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }
    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 3, 3, 3, false));
    SemIm semIm = new SemIm(new SemPm(dag));
    DataSet sample = semIm.simulateData(1000, false);

    // Statistics computed directly from the raw columns.
    double[] first = sample.getDoubleData().getColumn(0).toArray();
    double[] second = sample.getDoubleData().getColumn(1).toArray();
    double directCorr = StatUtils.correlation(first, second);
    double directCov = StatUtils.covariance(first, second);
    double directVar = StatUtils.variance(first);
    double directSd = StatUtils.sd(first);

    // The same statistics via the "partial" routines with nothing conditioned on.
    ICovarianceMatrix covariances = new CovarianceMatrix(sample);
    TetradMatrix covMatrix = covariances.getMatrix();
    double partialCorr = StatUtils.partialCorrelation(covMatrix, 0, 1);
    double partialCov = StatUtils.partialCovariance(covMatrix, 0, 1);
    double partialVar = StatUtils.partialVariance(covMatrix, 0);
    double partialSd = StatUtils.partialStandardDeviation(covMatrix, 0);

    assertEquals(directCorr, partialCorr, 0.1);
    assertEquals(directCov, partialCov, 0.1);
    assertEquals(directVar, partialVar, 0.1);
    assertEquals(directSd, partialSd, 0.1);
}
Also used : Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) Dag(edu.cmu.tetrad.graph.Dag) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 63 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

The class TestDeltaSextadTest, method test2.

/**
 * Checks that the delta sextad test's p-value is unchanged when the first three variables,
 * and separately the last three variables, of a sextad are permuted among themselves.
 *
 * <p>Data are simulated from a SEM with {@code c} clusters of {@code p} measured variables.
 * Each cluster has {@code m} latent variables, each a parent of every measure in that
 * cluster. Latent j of cluster y - 1 is a parent of latent j of cluster y.
 */
@Test
public void test2() {
    int c = 2; // number of clusters
    int m = 2; // latent variables per cluster
    int p = 6; // measured variables per cluster
    Graph g = new EdgeListGraph();
    List<List<Node>> varClusters = new ArrayList<>();
    List<List<Node>> latents = new ArrayList<>();
    for (int y = 0; y < c; y++) {
        varClusters.add(new ArrayList<Node>());
        latents.add(new ArrayList<Node>());
    }
    // Measured variables V1..V(c*p), added to the graph before any latent.
    int e = 0;
    for (int y = 0; y < c; y++) {
        for (int i = 0; i < p; i++) {
            GraphNode n = new GraphNode("V" + ++e);
            varClusters.get(y).add(n);
            g.addNode(n);
        }
    }
    // Latent variables L1..L(c*m).
    int f = 0;
    for (int y = 0; y < c; y++) {
        for (int j = 0; j < m; j++) {
            Node latent = new GraphNode("L" + ++f);
            latent.setNodeType(NodeType.LATENT);
            latents.get(y).add(latent);
            g.addNode(latent);
        }
    }
    for (int y = 0; y < c; y++) {
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                g.addDirectedEdge(latents.get(y).get(i), varClusters.get(y).get(j));
            }
        }
    }
    for (int y = 1; y < c; y++) {
        for (int j = 0; j < m; j++) {
            g.addDirectedEdge(latents.get(y - 1).get(j), latents.get(y).get(j));
        }
    }
    SemPm pm = new SemPm(g);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);
    // Column indices into the simulated data. Presumably these follow the measured-variable
    // order V1..V12: five from the first cluster and one (7) from the second. TODO confirm.
    List<Integer> indices = new ArrayList<>();
    indices.add(0);
    indices.add(1);
    indices.add(2);
    indices.add(4);
    indices.add(5);
    indices.add(7);
    Collections.shuffle(indices);
    int x1 = indices.get(0);
    int x2 = indices.get(1);
    int x3 = indices.get(2);
    int x4 = indices.get(3);
    int x5 = indices.get(4);
    int x6 = indices.get(5);
    DeltaSextadTest test = new DeltaSextadTest(data);
    // Reordering the first three or the last three variables must not change the p-value.
    double a = test.getPValue(new IntSextad(x1, x2, x3, x4, x5, x6));
    double b = test.getPValue(new IntSextad(x2, x3, x1, x5, x4, x6));
    assertEquals(a, b, 1e-7);
}
Also used : IntSextad(edu.cmu.tetrad.search.IntSextad) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList) DeltaSextadTest(edu.cmu.tetrad.search.DeltaSextadTest) SemPm(edu.cmu.tetrad.sem.SemPm) List(java.util.List) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 64 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

The class TestDeltaSextadTest, method getSem1.

/**
 * Builds a two-factor SEM. The latent variables l1 and l2 are both parents of each of the
 * eight measured variables X1..X8. The model is instantiated with a default-constructed
 * {@link Parameters}.
 *
 * @return the instantiated model
 */
private SemIm getSem1() {
    final int numMeasures = 8;
    Node l1 = new GraphNode("l1");
    Node l2 = new GraphNode("l2");
    l1.setNodeType(NodeType.LATENT);
    l2.setNodeType(NodeType.LATENT);

    List<Node> indicators = new ArrayList<>();
    for (int k = 1; k <= numMeasures; k++) {
        indicators.add(new GraphNode("X" + k));
    }

    // Latents go in first, then each indicator together with its two incoming edges.
    Graph graph = new EdgeListGraph();
    graph.addNode(l1);
    graph.addNode(l2);
    for (Node indicator : indicators) {
        graph.addNode(indicator);
        graph.addDirectedEdge(l1, indicator);
        graph.addDirectedEdge(l2, indicator);
    }
    return new SemIm(new SemPm(graph), new Parameters());
}
Also used : Parameters(edu.cmu.tetrad.util.Parameters) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm)

Example 65 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

The class TestPc, method printStatsPcRegression.

/**
 * Repeatedly simulates data from a random linear SEM, runs one search algorithm on it, and
 * reports the average adjacency precision and recall around a single target variable.
 *
 * <p>Each run builds a random graph over {@code numMeasures + numLatents} continuous
 * variables and simulates {@code sampleSize} rows. It then runs the algorithm selected by
 * {@code t}, trims the output around the first MEASURED variable, and scores that variable's
 * adjacencies. Both hits and misses are scored against the PAG converted from the true DAG.
 *
 * @param algorithms      display names indexed by {@code t}; used only for printing
 * @param t               algorithm index: 0 = PC, 1 = CPC, 2 = FGES, 3 = FCI, 4 = GFCI,
 *                        5 = RFCI, 6 = CFCI, 7 = regression graph
 * @param directed        currently unused
 * @param numRuns         number of simulated runs to average over
 * @param alpha           significance level for the Fisher Z test
 * @param penaltyDiscount penalty discount for the SEM BIC score
 * @param numMeasures     number of measured variables
 * @param numLatents      latent count passed to {@code GraphUtils.randomGraphRandomForwardEdges}
 * @param edgeFactor      number of edges per variable; the edge count is truncated to an int
 * @param sampleSize      number of rows simulated per run
 * @return {@code {average adjacency precision, average adjacency recall}}
 * @throws IllegalStateException if {@code t} is not in 0..7, or no variable is MEASURED
 */
private double[] printStatsPcRegression(String[] algorithms, int t, boolean directed, int numRuns, double alpha, double penaltyDiscount, int numMeasures, int numLatents, double edgeFactor, int sampleSize) {
    NumberFormat nf = new DecimalFormat("0.00");
    double sumAdjPrecision = 0.0;
    double sumAdjRecall = 0.0;
    int count = 0;
    for (int i = 0; i < numRuns; i++) {
        int numVars = numMeasures + numLatents;
        int numEdges = (int) (edgeFactor * numVars);
        List<Node> nodes = new ArrayList<>();
        for (int r = 0; r < numVars; r++) {
            nodes.add(new ContinuousVariable("X" + (r + 1)));
        }
        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(sampleSize, false);
        // Presumably simulateData(..., false) omits latent columns, so the reference graph
        // is the PAG over the measured variables rather than the DAG itself.
        Graph comparison = new DagToPag(dag).convert();
        IndTestFisherZ test = new IndTestFisherZ(data, alpha);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data));
        score.setPenaltyDiscount(penaltyDiscount);
        Node target = firstMeasuredNodeForPcRegression(nodes);

        Graph out = runPcRegressionSearch(t, test, score, data, target);
        out = trim(out, out.getNode(target.getName()));
        out = GraphUtils.replaceNodes(out, dag.getNodes());
        for (Node node : dag.getNodes()) {
            if (!out.containsNode(node)) {
                out.addNode(node);
            }
        }
        // replaceNodes presumably swaps out's nodes for dag's. Look the target up by name
        // so the same object is used in dag, out and comparison.
        target = dag.getNode(target.getName());

        int adjTp = 0;
        int adjFp = 0;
        int adjFn = 0;
        for (Node node : out.getAdjacentNodes(target)) {
            if (comparison.isAdjacentTo(target, node)) {
                adjTp++;
            } else {
                adjFp++;
            }
        }
        // Misses are counted against the same reference as hits. Counting them against the
        // DAG would include latent neighbors that a search over the data cannot return.
        for (Node node : comparison.getAdjacentNodes(target)) {
            if (!out.isAdjacentTo(target, node)) {
                adjFn++;
            }
        }
        double adjPrecision = adjTp / (double) (adjTp + adjFp);
        double adjRecall = adjTp / (double) (adjTp + adjFn);
        // An undefined (0/0) metric adds nothing to its sum but still counts as a run,
        // so in the average it behaves like a 0.
        if (!Double.isNaN(adjPrecision)) {
            sumAdjPrecision += adjPrecision;
        }
        if (!Double.isNaN(adjRecall)) {
            sumAdjRecall += adjRecall;
        }
        count++;
    }
    double avgAdjPrecision = sumAdjPrecision / (double) count;
    double avgAdjRecall = sumAdjRecall / (double) count;
    double[] ret = new double[] { avgAdjPrecision, avgAdjRecall };
    System.out.println();
    System.out.println(algorithms[t] + " adj precision " + nf.format(avgAdjPrecision));
    System.out.println(algorithms[t] + " adj recall " + nf.format(avgAdjRecall));
    return ret;
}

/** Returns the first MEASURED node in {@code nodes}; used as the scoring target. */
private static Node firstMeasuredNodeForPcRegression(List<Node> nodes) {
    for (Node node : nodes) {
        if (node.getNodeType() == NodeType.MEASURED) {
            return node;
        }
    }
    throw new IllegalStateException("No measured variable available to use as the target.");
}

/** Runs the algorithm selected by index {@code t} and returns its output graph. */
private Graph runPcRegressionSearch(int t, IndTestFisherZ test, SemBicScore score, DataSet data, Node target) {
    switch (t) {
        case 0:
            return new Pc(test).search();
        case 1:
            return new Cpc(test).search();
        case 2:
            return new Fges(score).search();
        case 3:
            return new Fci(test).search();
        case 4:
            return new GFci(test, score).search();
        case 5:
            return new Rfci(test).search();
        case 6:
            return new Cfci(test).search();
        case 7:
            return getRegressionGraph(data, target);
        default:
            throw new IllegalStateException();
    }
}
Also used : DecimalFormat(java.text.DecimalFormat) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) NumberFormat(java.text.NumberFormat)

Aggregations

SemPm (edu.cmu.tetrad.sem.SemPm)77 SemIm (edu.cmu.tetrad.sem.SemIm)71 Test (org.junit.Test)44 ArrayList (java.util.ArrayList)29 DataSet (edu.cmu.tetrad.data.DataSet)28 Graph (edu.cmu.tetrad.graph.Graph)25 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)18 Node (edu.cmu.tetrad.graph.Node)18 SemEstimator (edu.cmu.tetrad.sem.SemEstimator)16 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)15 Dag (edu.cmu.tetrad.graph.Dag)10 DMSearch (edu.cmu.tetrad.search.DMSearch)9 StandardizedSemIm (edu.cmu.tetrad.sem.StandardizedSemIm)9 NumberFormat (java.text.NumberFormat)7 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)6 ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix)5 GraphNode (edu.cmu.tetrad.graph.GraphNode)5 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)4 IndependenceTest (edu.cmu.tetrad.search.IndependenceTest)4 Parameters (edu.cmu.tetrad.util.Parameters)4