Search in sources:

Example 21 with SemPm

Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

Class TestSemVarMeans, method testMeansCholesky.

/**
 * Fixes the variable means of the SEM built from {@code constructGraph1()} to 5, 4, 3, 2, 1,
 * simulates 1000 rows with {@code simulateDataCholesky}, re-estimates the model, and asserts
 * that every estimated mean lies within 0.6 of its true value.
 */
@Test
public void testMeansCholesky() {
    SemPm truePm = new SemPm(constructGraph1());

    // Turn off random initialization for every parameter before the IM is built.
    for (Parameter parameter : truePm.getParameters()) {
        parameter.setInitializedRandomly(false);
    }

    SemIm trueIm = new SemIm(truePm);
    double[] targetMeans = {5.0, 4.0, 3.0, 2.0, 1.0};

    // Fix the RandomUtil seed before simulating (presumably for a reproducible sample).
    RandomUtil.getInstance().setSeed(-379467L);

    // NOTE: indexes targetMeans by node position, so this assumes constructGraph1()
    // yields at most five variable nodes.
    List<Node> imNodes = trueIm.getVariableNodes();
    for (int k = 0; k < imNodes.size(); k++) {
        trueIm.setMean(imNodes.get(k), targetMeans[k]);
    }

    SemEstimator estimator = new SemEstimator(trueIm.simulateDataCholesky(1000, false), truePm);
    estimator.estimate();
    SemIm fittedIm = estimator.getEstimatedSem();

    final double tolerance = 0.6;
    for (Node variable : truePm.getVariableNodes()) {
        assertEquals(trueIm.getMean(variable), fittedIm.getMean(variable), tolerance);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 22 with SemPm

Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

Class TestSemVarMeans, method testMeansReducedForm.

/**
 * Fixes the variable means of the SEM built from {@code constructGraph1()} to 5, 4, 3, 2, 1,
 * simulates 1000 rows with {@code simulateDataReducedForm}, re-estimates the model, and
 * asserts that every estimated mean lies within 0.5 of its true value.
 */
@Test
public void testMeansReducedForm() {
    Graph generatingGraph = constructGraph1();
    SemPm model = new SemPm(generatingGraph);

    // Every parameter starts from a non-random initial value.
    List<Parameter> freeParams = model.getParameters();
    for (int p = 0; p < freeParams.size(); p++) {
        freeParams.get(p).setInitializedRandomly(false);
    }

    SemIm generator = new SemIm(model);
    double[] expectedMeans = new double[] {5.0, 4.0, 3.0, 2.0, 1.0};

    // Seed the RandomUtil singleton before simulation.
    RandomUtil.getInstance().setSeed(-379467L);

    // NOTE: assumes constructGraph1() yields at most expectedMeans.length variable nodes.
    int position = 0;
    for (Node variable : generator.getVariableNodes()) {
        generator.setMean(variable, expectedMeans[position]);
        position++;
    }

    DataSet sample = generator.simulateDataReducedForm(1000, false);
    SemEstimator fitter = new SemEstimator(sample, model);
    fitter.estimate();
    SemIm estimate = fitter.getEstimatedSem();

    List<Node> measured = model.getVariableNodes();
    for (int j = 0; j < measured.size(); j++) {
        Node variable = measured.get(j);
        double truth = generator.getMean(variable);
        assertEquals(truth, estimate.getMean(variable), 0.5);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 23 with SemPm

Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

Class TestSemVarMeans, method testMeansRecursive.

/**
 * Fixes the variable means of the SEM built from {@code constructGraph1()} to 5, 4, 3, 2, 1,
 * simulates 1000 rows with {@code simulateDataRecursive}, re-estimates the model, and asserts
 * that every estimated mean lies within 0.5 of its true value.
 */
@Test
public void testMeansRecursive() {
    SemPm pm = new SemPm(constructGraph1());
    for (Parameter param : pm.getParameters()) {
        // Use fixed rather than random starting values.
        param.setInitializedRandomly(false);
    }

    SemIm original = new SemIm(pm);
    double[] wanted = {5.0, 4.0, 3.0, 2.0, 1.0};

    RandomUtil.getInstance().setSeed(-379467L);

    // NOTE: wanted[] is indexed by node position; assumes at most five variable nodes.
    List<Node> variables = original.getVariableNodes();
    int idx = 0;
    while (idx < variables.size()) {
        original.setMean(variables.get(idx), wanted[idx]);
        idx++;
    }

    DataSet simulated = original.simulateDataRecursive(1000, false);
    SemEstimator semEstimator = new SemEstimator(simulated, pm);
    semEstimator.estimate();
    SemIm recovered = semEstimator.getEstimatedSem();

    for (Node v : pm.getVariableNodes()) {
        double expected = original.getMean(v);
        double actual = recovered.getMean(v);
        assertEquals(expected, actual, 0.5);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 24 with SemPm

Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

Class TestPurify, method test1b.

/**
 * Builds a two-latent measurement model with deliberate impurities, simulates 3000 rows,
 * and checks that {@link PurifyTetradBased2} trims each supplied cluster by one indicator.
 *
 * <p>Impurities in the true graph: X6 loads on both L1 and L2 (and is left out of both
 * clusters), and the measured-to-measured edges X3 -> X4 and X9 -> X10 each make one member
 * of the corresponding cluster impure. Hence cluster 1 (X1..X5) is expected to shrink to 4
 * and cluster 2 (X7..X12) to 5.
 */
@Test
public void test1b() {
    RandomUtil.getInstance().setSeed(48290483L);

    SemGraph graph = new SemGraph();

    Node l1 = new GraphNode("L1");
    l1.setNodeType(NodeType.LATENT);
    Node l2 = new GraphNode("L2");
    l2.setNodeType(NodeType.LATENT);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");
    Node x5 = new GraphNode("X5");
    Node x6 = new GraphNode("X6");
    Node x7 = new GraphNode("X7");
    Node x8 = new GraphNode("X8");
    Node x9 = new GraphNode("X9");
    Node x10 = new GraphNode("X10");
    Node x11 = new GraphNode("X11");
    Node x12 = new GraphNode("X12");

    graph.addNode(l1);
    graph.addNode(l2);
    graph.addNode(x1);
    graph.addNode(x2);
    graph.addNode(x3);
    graph.addNode(x4);
    graph.addNode(x5);
    graph.addNode(x6);
    graph.addNode(x7);
    graph.addNode(x8);
    graph.addNode(x9);
    graph.addNode(x10);
    graph.addNode(x11);
    graph.addNode(x12);

    // L1 measures X1..X6. The original listing added L1 -> X5 twice; once is enough.
    graph.addDirectedEdge(l1, x1);
    graph.addDirectedEdge(l1, x2);
    graph.addDirectedEdge(l1, x3);
    graph.addDirectedEdge(l1, x4);
    graph.addDirectedEdge(l1, x5);
    graph.addDirectedEdge(l1, x6);

    // L2 measures X6..X12 (X6 is cross-loaded on both latents).
    graph.addDirectedEdge(l2, x6);
    graph.addDirectedEdge(l2, x7);
    graph.addDirectedEdge(l2, x8);
    graph.addDirectedEdge(l2, x9);
    graph.addDirectedEdge(l2, x10);
    graph.addDirectedEdge(l2, x11);
    graph.addDirectedEdge(l2, x12);

    // Direct edges between indicators, making one indicator in each cluster impure.
    graph.addDirectedEdge(x3, x4);
    graph.addDirectedEdge(x9, x10);

    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(3000, false);

    List<Node> cluster1 = new ArrayList<>();
    cluster1.add(x1);
    cluster1.add(x2);
    cluster1.add(x3);
    cluster1.add(x4);
    cluster1.add(x5);

    List<Node> cluster2 = new ArrayList<>();
    cluster2.add(x7);
    cluster2.add(x8);
    cluster2.add(x9);
    cluster2.add(x10);
    cluster2.add(x11);
    cluster2.add(x12);

    List<List<Node>> partition = new ArrayList<>();
    partition.add(cluster1);
    partition.add(cluster2);

    TetradTest test = new ContinuousTetradTest(data, TestType.TETRAD_WISHART, 0.0001);
    IPurify purify = new PurifyTetradBased2(test);
    purify.setTrueGraph(graph);
    List<List<Node>> clustering = purify.purify(partition);

    assertEquals(4, clustering.get(0).size());
    assertEquals(5, clustering.get(1).size());
}
Also used: ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) List(java.util.List) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 25 with SemPm

Use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

Class TestPurify, method test2.

/**
 * Simulates data from a random single-factor model with three latents, purifies the
 * per-latent indicator clusters, runs {@link Mimbuild2} on the purified clusters, and checks
 * that the recovered structural graph over the latents has exactly two edges.
 */
@Test
public void test2() {
    RandomUtil.getInstance().setSeed(48290483L);
    Graph graph = new EdgeListGraph(DataGraphUtils.randomSingleFactorModel(3, 3, 5, 0, 0, 0));
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);

    // Latent nodes of the generating model.
    List<Node> latents = new ArrayList<>();
    for (Node node : graph.getNodes()) {
        if (node.getNodeType() == NodeType.LATENT) {
            latents.add(node);
        }
    }

    // One cluster per latent: its adjacent nodes minus the other latents. Copy the adjacency
    // list first so removeAll cannot modify whatever list getAdjacentNodes returns (it may
    // be backed by the graph).
    List<List<Node>> clustering = new ArrayList<>();
    for (Node latent : latents) {
        List<Node> indicators = new ArrayList<>(graph.getAdjacentNodes(latent));
        indicators.removeAll(latents);
        clustering.add(indicators);
    }

    ContinuousTetradTest test = new ContinuousTetradTest(data, TestType.TETRAD_WISHART, 0.001);
    IPurify purify = new PurifyTetradBased2(test);
    List<List<Node>> purifiedClustering = purify.purify(clustering);

    List<String> latentsNames = new ArrayList<>();
    for (Node latent : latents) {
        latentsNames.add(latent.getName());
    }

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.0001);
    Graph _graph = mimbuild.search(purifiedClustering, latentsNames, new CovarianceMatrix(data));

    List<Node> _latents = new ArrayList<>();
    for (Node node : _graph.getNodes()) {
        if (node.getNodeType() == NodeType.LATENT) {
            _latents.add(node);
        }
    }

    Graph _structuralGraph = _graph.subgraph(_latents);
    assertEquals(2, _structuralGraph.getNumEdges());
}
Also used: Mimbuild2(edu.cmu.tetrad.search.Mimbuild2) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) List(java.util.List) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Aggregations

SemPm (edu.cmu.tetrad.sem.SemPm): 77; SemIm (edu.cmu.tetrad.sem.SemIm): 71; Test (org.junit.Test): 44; ArrayList (java.util.ArrayList): 29; DataSet (edu.cmu.tetrad.data.DataSet): 28; Graph (edu.cmu.tetrad.graph.Graph): 25; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 18; Node (edu.cmu.tetrad.graph.Node): 18; SemEstimator (edu.cmu.tetrad.sem.SemEstimator): 16; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 15; Dag (edu.cmu.tetrad.graph.Dag): 10; DMSearch (edu.cmu.tetrad.search.DMSearch): 9; StandardizedSemIm (edu.cmu.tetrad.sem.StandardizedSemIm): 9; NumberFormat (java.text.NumberFormat): 7; TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 6; ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 5; GraphNode (edu.cmu.tetrad.graph.GraphNode): 5; CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 4; IndependenceTest (edu.cmu.tetrad.search.IndependenceTest): 4; Parameters (edu.cmu.tetrad.util.Parameters): 4