Search in sources:

Example 21 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

the class TestSemVarMeans method testMeansRecursive.

/**
 * Checks that variable means assigned to a {@code SemIm} are recovered when the model is
 * re-estimated from data produced by {@code simulateDataRecursive}.
 *
 * <p>Every parameter of the parametric model is flagged as not randomly initialized before the
 * instantiated model is built. The means 5.0, 4.0, 3.0, 2.0, 1.0 are then assigned to the
 * variables in the order returned by {@code getVariableNodes()}; the graph from
 * {@code constructGraph1()} therefore must not have more than five variable nodes, or the
 * array lookup fails. Each estimated mean has to lie within 0.5 of the assigned one.
 */
@Test
public void testMeansRecursive() {
    Graph graph = constructGraph1();
    SemPm pm = new SemPm(graph);

    List<Parameter> parameters = pm.getParameters();
    for (Parameter parameter : parameters) {
        parameter.setInitializedRandomly(false);
    }

    SemIm trueIm = new SemIm(pm);
    double[] assignedMeans = { 5.0, 4.0, 3.0, 2.0, 1.0 };

    // The seed is fixed only after the model has been built, right before the means are set
    // and the data are simulated.
    RandomUtil.getInstance().setSeed(-379467L);

    int position = 0;
    for (Node variable : trueIm.getVariableNodes()) {
        trueIm.setMean(variable, assignedMeans[position]);
        position++;
    }

    // 1000 is presumably the number of simulated rows; the meaning of the boolean flag is not
    // visible here -- confirm against SemIm.simulateDataRecursive.
    DataSet simulated = trueIm.simulateDataRecursive(1000, false);

    SemEstimator estimator = new SemEstimator(simulated, pm);
    estimator.estimate();
    SemIm estimatedIm = estimator.getEstimatedSem();

    for (Node variable : pm.getVariableNodes()) {
        double expected = trueIm.getMean(variable);
        double actual = estimatedIm.getMean(variable);
        assertEquals(expected, actual, 0.5);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 22 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

the class TestPurify method test1b.

/**
 * Purifies a two-cluster measurement model built by hand and checks the purified cluster sizes.
 *
 * <p>The graph has two latents, L1 and L2, and twelve measured nodes X1..X12. L1 points to
 * X1..X6 and L2 points to X6..X12, so X6 has both latents as parents; X6 is left out of both
 * input clusters. Two extra edges, X3 -> X4 and X9 -> X10, connect measured nodes directly.
 * After purification the first cluster is expected to hold 4 nodes and the second 5.
 */
@Test
public void test1b() {
    RandomUtil.getInstance().setSeed(48290483L);
    SemGraph graph = new SemGraph();

    Node l1 = new GraphNode("L1");
    l1.setNodeType(NodeType.LATENT);
    Node l2 = new GraphNode("L2");
    l2.setNodeType(NodeType.LATENT);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");
    Node x5 = new GraphNode("X5");
    Node x6 = new GraphNode("X6");
    Node x7 = new GraphNode("X7");
    Node x8 = new GraphNode("X8");
    Node x9 = new GraphNode("X9");
    Node x10 = new GraphNode("X10");
    Node x11 = new GraphNode("X11");
    Node x12 = new GraphNode("X12");

    // Nodes are registered latents first, then X1..X12.
    Node[] allNodes = { l1, l2, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 };
    for (Node node : allNodes) {
        graph.addNode(node);
    }

    // NOTE(review): x5 is listed twice, so the L1 -> X5 edge is requested twice. The second
    // request looks redundant -- confirm how the graph treats a repeated edge before removing it.
    Node[] childrenOfL1 = { x1, x2, x3, x4, x5, x5, x6 };
    for (Node child : childrenOfL1) {
        graph.addDirectedEdge(l1, child);
    }

    Node[] childrenOfL2 = { x6, x7, x8, x9, x10, x11, x12 };
    for (Node child : childrenOfL2) {
        graph.addDirectedEdge(l2, child);
    }

    // Direct edges between measured nodes.
    graph.addDirectedEdge(x3, x4);
    graph.addDirectedEdge(x9, x10);

    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(3000, false);

    List<Node> cluster1 = new ArrayList<>();
    for (Node member : new Node[] { x1, x2, x3, x4, x5 }) {
        cluster1.add(member);
    }

    List<Node> cluster2 = new ArrayList<>();
    for (Node member : new Node[] { x7, x8, x9, x10, x11, x12 }) {
        cluster2.add(member);
    }

    List<List<Node>> partition = new ArrayList<>();
    partition.add(cluster1);
    partition.add(cluster2);

    // The last constructor argument (0.0001) is presumably the significance level -- confirm.
    TetradTest test = new ContinuousTetradTest(data, TestType.TETRAD_WISHART, 0.0001);
    IPurify purify = new PurifyTetradBased2(test);
    purify.setTrueGraph(graph);

    List<List<Node>> clustering = purify.purify(partition);

    assertEquals(4, clustering.get(0).size());
    assertEquals(5, clustering.get(1).size());
}
Also used : ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) List(java.util.List) ArrayList(java.util.ArrayList) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 23 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

the class TestPurify method test2.

/**
 * Runs purification followed by {@code Mimbuild2} on data simulated from a random single-factor
 * model and checks the number of edges recovered among the latents.
 *
 * <p>The input clustering consists, for each latent of the generating graph, of that latent's
 * adjacent nodes with all latents removed. The graph returned by the search is restricted to
 * its latent nodes, and that subgraph is expected to contain exactly 2 edges.
 */
@Test
public void test2() {
    RandomUtil.getInstance().setSeed(48290483L);

    Graph graph = new EdgeListGraph(DataGraphUtils.randomSingleFactorModel(3, 3, 5, 0, 0, 0));
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);

    List<Node> latents = new ArrayList<>();
    for (Node candidate : graph.getNodes()) {
        if (candidate.getNodeType() != NodeType.LATENT) {
            continue;
        }
        latents.add(candidate);
    }

    // NOTE(review): this subgraph is never read afterwards; the call is kept in case it has
    // side effects -- confirm and drop if it does not.
    Graph structuralGraph = graph.subgraph(latents);

    // One cluster per latent: its neighbours minus any latents.
    List<List<Node>> clustering = new ArrayList<>();
    for (Node latent : latents) {
        List<Node> indicators = graph.getAdjacentNodes(latent);
        indicators.removeAll(latents);
        clustering.add(indicators);
    }

    ContinuousTetradTest test = new ContinuousTetradTest(data, TestType.TETRAD_WISHART, 0.001);
    IPurify purify = new PurifyTetradBased2(test);
    List<List<Node>> purifiedClustering = purify.purify(clustering);

    List<String> latentsNames = new ArrayList<>();
    for (Node latent : latents) {
        latentsNames.add(latent.getName());
    }

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.0001);
    Graph learnedGraph = mimbuild.search(purifiedClustering, latentsNames, new CovarianceMatrix(data));

    List<Node> learnedLatents = new ArrayList<>();
    for (Node candidate : learnedGraph.getNodes()) {
        if (candidate.getNodeType() != NodeType.LATENT) {
            continue;
        }
        learnedLatents.add(candidate);
    }

    Graph learnedStructure = learnedGraph.subgraph(learnedLatents);
    assertEquals(2, learnedStructure.getNumEdges());
}
Also used : Mimbuild2(edu.cmu.tetrad.search.Mimbuild2) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) List(java.util.List) ArrayList(java.util.ArrayList) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 24 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

the class TestPurify method test1.

/**
 * Purifies a three-cluster measurement model built by hand and checks the purified cluster
 * sizes.
 *
 * <p>The graph has three latents (L1, L2, L3) and fifteen measured nodes. X5 is a child of both
 * L1 and L2 and appears in both the first and the second input cluster; X1 -> X4 is a direct
 * edge between two measured children of L1. After purification the clusters are expected to
 * hold 3, 2 and 5 nodes respectively.
 */
@Test
public void test1() {
    RandomUtil.getInstance().setSeed(48290483L);
    SemGraph graph = new SemGraph();

    Node l1 = new GraphNode("L1");
    l1.setNodeType(NodeType.LATENT);
    Node l2 = new GraphNode("L2");
    l2.setNodeType(NodeType.LATENT);
    Node l3 = new GraphNode("L3");
    l3.setNodeType(NodeType.LATENT);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");
    Node x4b = new GraphNode("X4b");
    Node x5 = new GraphNode("X5");
    Node x6 = new GraphNode("X6");
    Node x7 = new GraphNode("X7");
    Node x8 = new GraphNode("X8");
    Node x8b = new GraphNode("X8b");
    Node x9 = new GraphNode("X9");
    Node x10 = new GraphNode("X10");
    Node x11 = new GraphNode("X11");
    Node x12 = new GraphNode("X12");
    Node x12b = new GraphNode("X12b");

    // Nodes are registered latents first, then the measured nodes in declaration order.
    Node[] allNodes = {
        l1, l2, l3,
        x1, x2, x3, x4, x4b, x5, x6, x7, x8, x8b, x9, x10, x11, x12, x12b
    };
    for (Node node : allNodes) {
        graph.addNode(node);
    }

    // Each latent's children double as the members of the matching input cluster below.
    Node[] indicatorsOfL1 = { x1, x2, x3, x4, x4b, x5 };
    Node[] indicatorsOfL2 = { x5, x6, x7, x8, x8b };
    Node[] indicatorsOfL3 = { x9, x10, x11, x12, x12b };

    for (Node child : indicatorsOfL1) {
        graph.addDirectedEdge(l1, child);
    }
    for (Node child : indicatorsOfL2) {
        graph.addDirectedEdge(l2, child);
    }
    for (Node child : indicatorsOfL3) {
        graph.addDirectedEdge(l3, child);
    }

    // Direct edge between two measured nodes.
    graph.addDirectedEdge(x1, x4);

    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);

    List<Node> cluster1 = new ArrayList<>();
    for (Node member : indicatorsOfL1) {
        cluster1.add(member);
    }

    List<Node> cluster2 = new ArrayList<>();
    for (Node member : indicatorsOfL2) {
        cluster2.add(member);
    }

    List<Node> cluster3 = new ArrayList<>();
    for (Node member : indicatorsOfL3) {
        cluster3.add(member);
    }

    List<List<Node>> partition = new ArrayList<>();
    partition.add(cluster1);
    partition.add(cluster2);
    partition.add(cluster3);

    // The last constructor argument (0.05) is presumably the significance level -- confirm.
    TetradTest test = new ContinuousTetradTest(data, TestType.TETRAD_WISHART, 0.05);
    IPurify purify = new PurifyTetradBased2(test);
    purify.setTrueGraph(graph);

    List<List<Node>> purified = purify.purify(partition);

    assertEquals(3, purified.get(0).size());
    assertEquals(2, purified.get(1).size());
    assertEquals(5, purified.get(2).size());
}
Also used : ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) List(java.util.List) ArrayList(java.util.ArrayList) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 25 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

the class TestRicf method test2.

/**
 * Smoke test for {@code Ricf}: runs {@code ricf} on the covariance matrix of data simulated
 * from a random DAG over ten continuous variables X1..X10 and reads {@code getBhat()} from the
 * result.
 *
 * <p>No assertion is made on the result, and no random seed is set in this method.
 */
@Test
public void test2() {
    List<Node> nodes = new ArrayList<>();
    for (int number = 1; number <= 10; number++) {
        nodes.add(new ContinuousVariable("X" + number));
    }

    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 10, 30, 15, 15, false));

    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);
    CovarianceMatrix cov = new CovarianceMatrix(data);

    Ricf ricf = new Ricf();
    // The last argument (0.001) is presumably a convergence tolerance -- confirm against Ricf.ricf.
    Ricf.RicfResult result = ricf.ricf(new SemGraph(graph), cov, 0.001);

    // The return value is deliberately unused; the call only has to complete.
    result.getBhat();
}
Also used : ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) Ricf(edu.cmu.tetrad.sem.Ricf) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Aggregations

SemIm (edu.cmu.tetrad.sem.SemIm): 81; SemPm (edu.cmu.tetrad.sem.SemPm): 71; Test (org.junit.Test): 46; DataSet (edu.cmu.tetrad.data.DataSet): 28; ArrayList (java.util.ArrayList): 28; Graph (edu.cmu.tetrad.graph.Graph): 26; Node (edu.cmu.tetrad.graph.Node): 19; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 16; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 16; SemEstimator (edu.cmu.tetrad.sem.SemEstimator): 15; Dag (edu.cmu.tetrad.graph.Dag): 10; DMSearch (edu.cmu.tetrad.search.DMSearch): 9; StandardizedSemIm (edu.cmu.tetrad.sem.StandardizedSemIm): 9; TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 7; NumberFormat (java.text.NumberFormat): 7; GraphNode (edu.cmu.tetrad.graph.GraphNode): 5; IndependenceTest (edu.cmu.tetrad.search.IndependenceTest): 4; DecimalFormat (java.text.DecimalFormat): 4; List (java.util.List): 4; CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 3