Search in sources :

Example 31 with ContinuousVariable

Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestLingamPattern, method test1.

/**
 * Checks LiNGAM p-values on a seeded, partly non-Gaussian simulation.
 *
 * <p>A random 6-variable DAG is built, data is simulated with one uniform and five normal
 * error distributions, a pattern is estimated with FGES over a SEM BIC score, and
 * {@code LingamPattern} is run over that pattern. The p-values it reports must match the
 * expected values for this seed to within 0.01, and there must be exactly as many of them.
 */
@Test
public void test1() {
    RandomUtil.getInstance().setSeed(4938492L);
    int sampleSize = 1000;
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 6; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 6, 4, 4, 4, false));
    // One distribution per variable; presumably consumed in the SEM's variable order by
    // simulateDataNonNormal (assumption, confirm against that helper).
    List<Distribution> variableDistributions = new ArrayList<>();
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Uniform(-1, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    DataSet dataSet = simulateDataNonNormal(semIm, sampleSize, variableDistributions);
    Score score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
    Graph estPattern = new Fges(score).search();
    LingamPattern lingam = new LingamPattern(estPattern, dataSet);
    lingam.search();
    double[] pvals = lingam.getPValues();
    double[] expectedPVals = { 0.18, 0.29, 0.88, 0.00, 0.01, 0.58 };
    // Check the count first. Otherwise a short result would silently skip expected values
    // and a long one would fail with an ArrayIndexOutOfBoundsException instead of a message.
    assertEquals("number of p-values", expectedPVals.length, pvals.length);
    for (int i = 0; i < pvals.length; i++) {
        assertEquals("p-value at index " + i, expectedPVals[i], pvals[i], 0.01);
    }
}
Also used : ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Uniform(edu.cmu.tetrad.util.dist.Uniform) Normal(edu.cmu.tetrad.util.dist.Normal) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Distribution(edu.cmu.tetrad.util.dist.Distribution) SemPm(edu.cmu.tetrad.sem.SemPm) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 32 with ContinuousVariable

Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestMbfs, method testRandom.

/**
 * Checks Mbfs against a d-separation oracle on a seeded random 10-node DAG.
 *
 * <p>For every node, the graph returned by {@code search.search(name)} must contain exactly
 * the same node names as the true Markov blanket DAG. Every directed edge it reports must
 * either match the endpoints of the corresponding true edge, or fall into the tolerated
 * "extra edge" case described in the loop below.
 */
@Test
public void testRandom() {
    RandomUtil.getInstance().setSeed(8388428832L);
    List<Node> nodes1 = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        nodes1.add(new ContinuousVariable("X" + (i + 1)));
    }
    Dag dag = new Dag(GraphUtils.randomGraph(nodes1, 0, 10, 5, 5, 5, false));
    IndependenceTest test = new IndTestDSep(dag);
    Mbfs search = new Mbfs(test, -1);
    List<Node> nodes = dag.getNodes();
    for (Node node : nodes) {
        Graph resultMb = search.search(node.getName());
        Graph trueMb = GraphUtils.markovBlanketDag(node, dag);
        List<Node> resultNodes = resultMb.getNodes();
        List<Node> trueNodes = trueMb.getNodes();
        // Compare the two node sets by name.
        Set<String> resultNames = new HashSet<>();
        for (Node resultNode : resultNodes) {
            resultNames.add(resultNode.getName());
        }
        Set<String> trueNames = new HashSet<>();
        for (Node v : trueNodes) {
            trueNames.add(v.getName());
        }
        // assertEquals (rather than assertTrue(equals)) reports both sets on failure.
        assertEquals("Markov blanket nodes of " + node.getName(), trueNames, resultNames);
        Set<Edge> resultEdges = resultMb.getEdges();
        for (Edge resultEdge : resultEdges) {
            if (Edges.isDirectedEdge(resultEdge)) {
                String name1 = resultEdge.getNode1().getName();
                String name2 = resultEdge.getNode2().getName();
                Node node1 = trueMb.getNode(name1);
                Node node2 = trueMb.getNode(name2);
                if (node1 == null) {
                    fail("Node " + name1 + " is not in the true graph.");
                }
                if (node2 == null) {
                    fail("Node " + name2 + " is not in the true graph.");
                }
                Edge trueEdge = trueMb.getEdge(node1, node2);
                if (trueEdge == null) {
                    // An edge absent from the true MB is tolerated when either endpoint is
                    // not adjacent to the target in the result, or when exactly one of the
                    // two endpoint-target edges is directed and the other is undirected.
                    Node resultNode1 = resultMb.getNode(node1.getName());
                    Node resultNode2 = resultMb.getNode(node2.getName());
                    Node resultTarget = resultMb.getNode(node.getName());
                    Edge a = resultMb.getEdge(resultNode1, resultTarget);
                    Edge b = resultMb.getEdge(resultNode2, resultTarget);
                    if (a == null || b == null) {
                        continue;
                    }
                    if ((Edges.isDirectedEdge(a) && Edges.isUndirectedEdge(b)) || (Edges.isUndirectedEdge(a) && Edges.isDirectedEdge(b))) {
                        continue;
                    }
                    fail("EXTRA EDGE: Edge in result MB but not true MB = " + resultEdge);
                }
                assertEquals(resultEdge.getEndpoint1(), trueEdge.getEndpoint1());
                assertEquals(resultEdge.getEndpoint2(), trueEdge.getEndpoint2());
            }
        }
    }
}
Also used : Mbfs(edu.cmu.tetrad.search.Mbfs) ArrayList(java.util.ArrayList) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) HashSet(java.util.HashSet) Test(org.junit.Test)

Example 33 with ContinuousVariable

Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestSemIm, method testIntercepts.

/**
 * Verifies that intercepts assigned to the first four SEM variables read back (within 0.1)
 * and that the fifth variable, whose intercept is never assigned, reads back as 0.
 */
@Test
public void testIntercepts() {
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }
    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 8, 30, 15, 15, false));
    SemIm im = new SemIm(new SemPm(dag));
    List<Node> semNodes = im.getVariableNodes();

    // Intercepts assigned to variables 0..3; variable 4 is left at its default.
    double[] assigned = { 1.0, 3.0, -1.0, 6.0 };
    for (int k = 0; k < assigned.length; k++) {
        im.setIntercept(semNodes.get(k), assigned[k]);
    }

    for (int k = 0; k < assigned.length; k++) {
        assertEquals(assigned[k], im.getIntercept(semNodes.get(k)), 0.1);
    }
    assertEquals(0.0, im.getIntercept(semNodes.get(4)), 0.1);
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DoubleArrayList(cern.colt.list.DoubleArrayList) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Example 34 with ContinuousVariable

Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestSemIm, method testCovariancesOfSimulated.

/**
 * Simulates data from a random 5-variable SEM and sanity-checks the covariances involved.
 *
 * <p>The checks are structural so they hold for any random graph:
 * <ul>
 *   <li>the implied covariance matrix is square and symmetric;</li>
 *   <li>the simulated data set has the requested number of rows and one column per variable;</li>
 *   <li>the sample covariance matrix has one dimension per data column.</li>
 * </ul>
 */
@Test
public void testCovariancesOfSimulated() {
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 5; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph randomGraph = new Dag(GraphUtils.randomGraph(nodes, 0, 8, 30, 15, 15, false));
    SemPm semPm1 = new SemPm(randomGraph);
    SemIm semIm1 = new SemIm(semPm1);

    TetradMatrix implCovarC = semIm1.getImplCovar(true);
    double[][] implCovar = implCovarC.toArray();
    for (double[] row : implCovar) {
        assertEquals(implCovar.length, row.length);
    }
    for (int i = 0; i < implCovar.length; i++) {
        for (int j = i + 1; j < implCovar.length; j++) {
            assertEquals("implied covariance asymmetric at (" + i + ", " + j + ")",
                    implCovar[i][j], implCovar[j][i], 1e-8);
        }
    }

    int sampleSize = 1000;
    // false: latent columns are not requested; this graph has no latents anyway.
    DataSet dataSet = semIm1.simulateDataRecursive(sampleSize, false);
    assertEquals(sampleSize, dataSet.getNumRows());
    assertEquals(nodes.size(), dataSet.getNumColumns());

    CovarianceMatrix sampleCovar = new CovarianceMatrix(dataSet);
    assertEquals(dataSet.getNumColumns(), sampleCovar.getDimension());
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DataSet(edu.cmu.tetrad.data.DataSet) DoubleArrayList(cern.colt.list.DoubleArrayList) ArrayList(java.util.ArrayList) CovarianceMatrix(edu.cmu.tetrad.data.CovarianceMatrix) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) Test(org.junit.Test)

Example 35 with ContinuousVariable

Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestMarkovBlanketSearches, method testRandom.

/**
 * Checks, for every node of a random 10-node DAG, that {@code Mbfs.findMb} run against a
 * d-separation oracle returns exactly the nodes of the true Markov blanket DAG, excluding the
 * target node itself. Both lists are compared in name order.
 */
@Test
public void testRandom() {
    List<Node> nodes1 = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        nodes1.add(new ContinuousVariable("X" + (i + 1)));
    }
    Dag dag = new Dag(GraphUtils.randomGraph(nodes1, 0, 10, 5, 5, 5, false));
    IndependenceTest test = new IndTestDSep(dag);
    Mbfs search = new Mbfs(test, -1);

    // Single shared ordering (previously duplicated as two identical anonymous classes).
    Comparator<Node> byName = new Comparator<Node>() {

        @Override
        public int compare(Node n1, Node n2) {
            return n1.getName().compareTo(n2.getName());
        }
    };

    List<Node> nodes = dag.getNodes();
    for (Node node : nodes) {
        List<Node> resultNodes = search.findMb(node.getName());
        Graph trueMb = GraphUtils.markovBlanketDag(node, dag);
        List<Node> trueNodes = trueMb.getNodes();
        // Presumably findMb omits the target itself, so drop it from the true node list.
        trueNodes.remove(node);
        Collections.sort(trueNodes, byName);
        Collections.sort(resultNodes, byName);
        assertEquals("Markov blanket of " + node.getName(), trueNodes, resultNodes);
    }
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Test(org.junit.Test)

Aggregations

ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 91; DataSet (edu.cmu.tetrad.data.DataSet): 48; Node (edu.cmu.tetrad.graph.Node): 46; Test (org.junit.Test): 42; ArrayList (java.util.ArrayList): 28; Graph (edu.cmu.tetrad.graph.Graph): 22; ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 19; SemPm (edu.cmu.tetrad.sem.SemPm): 18; SemIm (edu.cmu.tetrad.sem.SemIm): 16; DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 15; LinkedList (java.util.LinkedList): 13; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 12; TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 8; DMSearch (edu.cmu.tetrad.search.DMSearch): 7; Dag (edu.cmu.tetrad.graph.Dag): 6; CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 5; RandomUtil (edu.cmu.tetrad.util.RandomUtil): 5; ParseException (java.text.ParseException): 4; CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 3; Knowledge2 (edu.cmu.tetrad.data.Knowledge2): 3