Search in sources :

Example 26 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestSearchGraph, method testDSeparation.

/**
 * Verifies that d-separation is symmetric. A random DAG over seven continuous variables is
 * built; then, for every unordered pair of distinct nodes (x, y) and every conditioning set z
 * drawn from the remaining nodes, isDSeparatedFrom(x, y, z) must equal isDSeparatedFrom(y, x, z).
 */
@Test
public void testDSeparation() {
    List<Node> variables = new ArrayList<>();
    for (int n = 1; n <= 7; n++) {
        variables.add(new ContinuousVariable("X" + n));
    }
    Dag dag = new Dag(GraphUtils.randomGraph(variables, 0, 7, 30, 15, 15, true));
    EdgeListGraphSingleConnections graph = new EdgeListGraphSingleConnections(dag);
    List<Node> nodes = graph.getNodes();

    // -1 is handed straight to DepthChoiceGenerator; presumably it means "no cap on the
    // conditioning-set size" -- confirm in DepthChoiceGenerator.
    final int maxDepth = -1;

    for (int a = 0; a < nodes.size(); a++) {
        Node x = nodes.get(a);
        for (int b = a + 1; b < nodes.size(); b++) {
            Node y = nodes.get(b);

            // Candidate conditioning variables: everything except x and y.
            List<Node> others = new ArrayList<>(nodes);
            others.remove(x);
            others.remove(y);

            DepthChoiceGenerator generator = new DepthChoiceGenerator(others.size(), maxDepth);
            for (int[] subset = generator.next(); subset != null; subset = generator.next()) {
                List<Node> z = new LinkedList<>();
                for (int index : subset) {
                    z.add(others.get(index));
                }
                boolean xGivenZ = graph.isDSeparatedFrom(x, y, z);
                boolean yGivenZ = graph.isDSeparatedFrom(y, x, z);
                if (xGivenZ != yGivenZ) {
                    fail(SearchLogUtils.independenceFact(x, y, z) + " should have same d-sep result as " + SearchLogUtils.independenceFact(y, x, z));
                }
            }
        }
    }
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DepthChoiceGenerator(edu.cmu.tetrad.util.DepthChoiceGenerator) Test(org.junit.Test)

Example 27 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestSimulatedFmri, method testClark.

// @Test
public void testClark() {
    // Errors are Uniform(0, 1) raised to this power (see the error expression in the helper).
    double f = .1;
    int sampleSize = 512;
    double alpha = 1.0;
    double penaltyDiscount = 1.0;

    for (int i = 0; i < 100; i++) {
        // Model 1: X -> Y, Z -> X, Z -> Y.
        Graph out1 = searchClarkModel(
                new String[][]{{"X", "Y"}, {"Z", "X"}, {"Z", "Y"}},
                new String[]{"0.5 * Z + E_X", "0.5 * X + 0.5 * Z + E_Y", "E_Z"},
                f, sampleSize, alpha, penaltyDiscount);
        System.out.println(out1);

        // Model 2: X -> Y, X -> Z, Y -> Z.
        Graph out2 = searchClarkModel(
                new String[][]{{"X", "Y"}, {"X", "Z"}, {"Y", "Z"}},
                new String[]{"E_X", "0.4 * X + E_Y", "0.4 * X + 0.4 * Y + E_Z"},
                f, sampleSize, alpha, penaltyDiscount);
        System.out.println(out2);
    }
}

/**
 * Builds a generalized SEM over the continuous variables X, Y and Z, simulates data from it,
 * and runs FASK (scored with SemBicScore) on the simulated data.
 *
 * @param edges           directed edges as {from, to} variable-name pairs, added in order
 * @param expressions     node expressions for X, Y and Z, in that order
 * @param f               exponent in the error expression {@code pow(Uniform(0, 1), f)}
 * @param sampleSize      number of rows to simulate
 * @param alpha           value passed to {@code Fask.setAlpha}
 * @param penaltyDiscount penalty discount for both the score and FASK
 * @return the graph returned by {@code Fask.search()}
 * @throws IllegalStateException if any expression fails to parse
 */
private static Graph searchClarkModel(String[][] edges, String[] expressions, double f,
                                      int sampleSize, double alpha, double penaltyDiscount) {
    String[] names = {"X", "Y", "Z"};
    Graph g = new EdgeListGraph();
    for (String name : names) {
        g.addNode(new ContinuousVariable(name));
    }
    for (String[] edge : edges) {
        g.addDirectedEdge(g.getNode(edge[0]), g.getNode(edge[1]));
    }

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    String error = "pow(Uniform(0, 1), " + f + ")";
    try {
        for (int k = 0; k < names.length; k++) {
            pm.setNodeExpression(g.getNode(names[k]), expressions[k]);
        }
        for (String name : names) {
            pm.setNodeExpression(pm.getErrorNode(g.getNode(name)), error);
        }
    } catch (ParseException e) {
        // Previously only printed, after which the model was simulated with whatever
        // expressions happened to be set; fail loudly instead, keeping the cause.
        throw new IllegalStateException("Could not parse generalized SEM expression", e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(sampleSize, false);
    // Fully qualified: the search-package score is intended here, not the algcomparison one.
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    score.setPenaltyDiscount(penaltyDiscount);
    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(penaltyDiscount);
    fask.setAlpha(alpha);
    return fask.search();
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Fask(edu.cmu.tetrad.search.Fask) Graph(edu.cmu.tetrad.graph.Graph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) ParseException(java.text.ParseException) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore)

Example 28 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestIndTestWaldLR, method testIsIndependent.

/**
 * Compares the multinomial logistic regression Wald independence test against a d-separation
 * oracle. Each of 10 trials simulates 1000 rows from a random 10-variable linear SEM,
 * discretizes the columns at indices 0 and 3 into two equal-count categories, and checks
 * X2 _||_ X1 | {X3, X4, X5} with both tests. All trials must agree.
 */
@Test
public void testIsIndependent() {
    RandomUtil.getInstance().setSeed(1450705713157L);
    final int numTrials = 10;
    int agreements = 0;

    for (int trial = 0; trial < numTrials; trial++) {
        List<Node> variables = new ArrayList<>();
        for (int v = 1; v <= 10; v++) {
            variables.add(new ContinuousVariable("X" + v));
        }
        Graph trueGraph = GraphUtils.randomGraph(variables, 0, 10, 3, 3, 3, false);
        SemIm im = new SemIm(new SemPm(trueGraph));
        DataSet continuous = im.simulateData(1000, false);

        Discretizer discretizer = new Discretizer(continuous);
        discretizer.setVariablesCopied(true);
        discretizer.equalCounts(continuous.getVariable(0), 2);
        discretizer.equalCounts(continuous.getVariable(3), 2);
        DataSet mixed = discretizer.discretize();

        Node x = mixed.getVariable("X1");
        Node y = mixed.getVariable("X2");
        List<Node> z = new ArrayList<>();
        for (String name : new String[]{"X3", "X4", "X5"}) {
            z.add(mixed.getVariable(name));
        }

        // Oracle side: the same variables looked up by name in the generating graph.
        Node xTrue = trueGraph.getNode(x.getName());
        Node yTrue = trueGraph.getNode(y.getName());
        List<Node> zTrue = new ArrayList<>();
        for (Node node : z) {
            zTrue.add(trueGraph.getNode(node.getName()));
        }

        // The Wald LR variant is used because it is the most up-to-date implementation.
        IndependenceTest test = new IndTestMultinomialLogisticRegressionWald(mixed, 0.05, false);
        IndTestDSep oracle = new IndTestDSep(trueGraph);
        if (test.isIndependent(y, x, z) == oracle.isIndependent(yTrue, xTrue, zTrue)) {
            agreements++;
        }
    }

    // 10/10 agreement holds for the fixed seed above; other seeds may not reach it.
    assertEquals(numTrials, agreements);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) IndTestMultinomialLogisticRegressionWald(edu.pitt.csb.mgm.IndTestMultinomialLogisticRegressionWald) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 29 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestKnowledge, method test1.

/**
 * Exercises Knowledge2 on a random 10-node graph in which X1 and X2 are renamed to "X1.1" and
 * "X2-1". It covers tier assignment (including the pattern "X1.*1"), forbidden and required
 * edges, copying, and toggling forbidden-within-tier for tier 0.
 */
@Test
public final void test1() {
    List<Node> nodes1 = new ArrayList<>();
    for (int i1 = 0; i1 < 10; i1++) {
        nodes1.add(new ContinuousVariable("X" + (i1 + 1)));
    }
    Graph g = GraphUtils.randomGraph(nodes1, 0, 10, 3, 3, 3, false);
    // Rename two variables so that names containing '.' and '-' are exercised.
    g.getNode("X1").setName("X1.1");
    g.getNode("X2").setName("X2-1");
    List<Node> nodes = g.getNodes();
    List<String> varNames = new ArrayList<>();
    for (Node node : nodes) {
        varNames.add(node.getName());
    }
    IKnowledge knowledge = new Knowledge2(varNames);
    // The tier-0 assertions on X1.1 below rely on the pattern "X1.*1" placing X1.1 in tier 0.
    knowledge.addToTier(0, "X1.*1");
    knowledge.addToTier(0, "X2-1");
    knowledge.addToTier(1, "X3");
    knowledge.setForbidden("X4", "X5");
    knowledge.setRequired("X6", "X7");
    knowledge.setRequired("X7", "X8");
    assertTrue(knowledge.isForbidden("X4", "X5"));
    assertFalse(knowledge.isForbidden("X1.1", "X2-1"));
    // X3 (tier 1) -> X2-1 (tier 0) is expected to be forbidden.
    assertTrue(knowledge.isForbidden("X3", "X2-1"));
    assertTrue(knowledge.isRequired("X6", "X7"));

    IKnowledge copy = knowledge.copy();
    assertTrue(copy.isForbidden("X4", "X5"));
    // Mirrors the assertion on the original above. This previously used "X1", which no longer
    // names any variable after the rename, so the check on the copy was vacuous.
    assertFalse(copy.isForbidden("X1.1", "X2-1"));
    assertTrue(copy.isForbidden("X3", "X2-1"));

    knowledge.setTierForbiddenWithin(0, true);
    assertTrue(knowledge.isForbidden("X1.1", "X2-1"));
    assertTrue(knowledge.isForbidden("X2-1", "X1.1"));
    assertFalse(knowledge.isForbidden("X1.1", "X1.1"));

    // The within-tier prohibition should also be visible through the forbidden-edge iterator.
    boolean found = false;
    Iterator<?> edges = knowledge.forbiddenEdgesIterator();
    while (!found && edges.hasNext()) {
        KnowledgeEdge edge = (KnowledgeEdge) edges.next();
        found = edge.getFrom().equals("X1.1") && edge.getTo().equals("X2-1");
    }
    assertTrue(found);

    knowledge.setTierForbiddenWithin(0, false);
    assertFalse(knowledge.isForbidden("X1.1", "X2-1"));
    assertFalse(knowledge.isForbidden("X2-1", "X1.1"));
    assertFalse(knowledge.isForbidden("X1.1", "X1.1"));
}
Also used : KnowledgeEdge(edu.cmu.tetrad.data.KnowledgeEdge) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Knowledge2(edu.cmu.tetrad.data.Knowledge2) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IKnowledge(edu.cmu.tetrad.data.IKnowledge) Graph(edu.cmu.tetrad.graph.Graph) Iterator(java.util.Iterator) Test(org.junit.Test)

Example 30 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

From the class TestLargeSemSimulator, method test1.

/**
 * Smoke test for LargeScaleSimulation: generates a random graph over X1..X10, simulates with
 * simulateDataFisher, and checks that the requested number of rows comes back.
 */
@Test
public void test1() {
    List<Node> variables = new ArrayList<>();
    int index = 1;
    while (index <= 10) {
        variables.add(new ContinuousVariable("X" + index));
        index++;
    }
    LargeScaleSimulation simulator =
            new LargeScaleSimulation(GraphUtils.randomGraph(variables, 0, 10, 5, 5, 5, false));
    final int sampleSize = 1000;
    DataSet dataset = simulator.simulateDataFisher(sampleSize);
    assertEquals(sampleSize, dataset.getNumRows());
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Aggregations

ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)91 DataSet (edu.cmu.tetrad.data.DataSet)48 Node (edu.cmu.tetrad.graph.Node)46 Test (org.junit.Test)42 ArrayList (java.util.ArrayList)28 Graph (edu.cmu.tetrad.graph.Graph)22 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)19 SemPm (edu.cmu.tetrad.sem.SemPm)18 SemIm (edu.cmu.tetrad.sem.SemIm)16 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)15 LinkedList (java.util.LinkedList)13 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)12 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)8 DMSearch (edu.cmu.tetrad.search.DMSearch)7 Dag (edu.cmu.tetrad.graph.Dag)6 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)5 RandomUtil (edu.cmu.tetrad.util.RandomUtil)5 ParseException (java.text.ParseException)4 CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly)3 Knowledge2 (edu.cmu.tetrad.data.Knowledge2)3