Search in sources:

Example 36 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

From class TestRowSummingUpdater, method sampleBayesIm1.

/**
 * Builds a small Bayes IM over the two-node DAG {@code x -> z} with hand-set
 * conditional probability tables.
 *
 * <p>Node indices are looked up with {@code getNodeIndex} rather than assumed
 * to be 0 and 1, so the CPTs are attached to the intended nodes regardless of
 * how the IM orders its nodes.
 *
 * @return a {@link BayesIm} with P(x) = (.3, .7), P(z | x = 0) = (.8, .2)
 *     and P(z | x = 1) = (.4, .6)
 */
private BayesIm sampleBayesIm1() {
    Node x = new GraphNode("x");
    Node z = new GraphNode("z");
    Dag graph = new Dag();
    graph.addNode(x);
    graph.addNode(z);
    graph.addDirectedEdge(x, z);
    BayesPm bayesPm = new BayesPm(graph);
    BayesIm bayesIm1 = new MlBayesIm(bayesPm);

    int xIndex = bayesIm1.getNodeIndex(x);
    int zIndex = bayesIm1.getNodeIndex(z);

    // x has no parents, so its table has a single row (row 0).
    bayesIm1.setProbability(xIndex, 0, 0, .3);
    bayesIm1.setProbability(xIndex, 0, 1, .7);

    // z has parent x; the row selects the value of x.
    bayesIm1.setProbability(zIndex, 0, 0, .8);
    bayesIm1.setProbability(zIndex, 0, 1, .2);
    bayesIm1.setProbability(zIndex, 1, 0, .4);
    bayesIm1.setProbability(zIndex, 1, 1, .6);
    return bayesIm1;
}
Also used: GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag)

Example 37 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

From class TestRowSummingUpdater, method testUpdate5.

/**
 * Checks that {@code RowSummingExactUpdater} and {@code CptInvariantUpdater}
 * agree on the marginal P(X3 = 0 | X1 = 1, X2 = 0) for a five-node DAG.
 * The DAG's parameters are initialized with {@code MlBayesIm.RANDOM}.
 */
@Test
public void testUpdate5() {
    Node x0Node = new GraphNode("X0");
    Node x1Node = new GraphNode("X1");
    Node x2Node = new GraphNode("X2");
    Node x3Node = new GraphNode("X3");
    Node x4Node = new GraphNode("X4");
    Dag graph = new Dag();
    graph.addNode(x0Node);
    graph.addNode(x1Node);
    graph.addNode(x2Node);
    graph.addNode(x3Node);
    graph.addNode(x4Node);
    graph.addDirectedEdge(x0Node, x1Node);
    graph.addDirectedEdge(x0Node, x2Node);
    graph.addDirectedEdge(x1Node, x3Node);
    graph.addDirectedEdge(x2Node, x3Node);
    graph.addDirectedEdge(x4Node, x0Node);
    graph.addDirectedEdge(x4Node, x2Node);
    BayesPm bayesPm = new BayesPm(graph);
    MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);

    int x1 = bayesIm.getNodeIndex(x1Node);
    int x2 = bayesIm.getNodeIndex(x2Node);
    int x3 = bayesIm.getNodeIndex(x3Node);

    // Evidence: X1 = 1 and X2 = 0; all other variables unconstrained.
    Evidence evidence = Evidence.tautology(bayesIm);
    evidence.getProposition().setCategory(x1, 1);
    evidence.getProposition().setCategory(x2, 0);

    BayesUpdater updater1 = new CptInvariantUpdater(bayesIm);
    updater1.setEvidence(evidence);
    BayesUpdater updater2 = new RowSummingExactUpdater(bayesIm);
    updater2.setEvidence(evidence);

    double marginal1 = updater1.getMarginal(x3, 0);
    double marginal2 = updater2.getMarginal(x3, 0);
    assertEquals(marginal1, marginal2, 0.000001);
}
Also used: GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag) Test(org.junit.Test)

Example 38 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

From class TestSemXml, method sampleSemIm1.

/**
 * Creates a SEM instantiated model over five continuous variables named
 * X1 through X5. The structure is a DAG built from
 * {@code GraphUtils.randomGraph}.
 *
 * @return a freshly constructed {@link SemIm} for that random DAG
 */
private static SemIm sampleSemIm1() {
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }
    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 30, 15, 15, true));
    return new SemIm(new SemPm(dag));
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) Dag(edu.cmu.tetrad.graph.Dag) SemIm(edu.cmu.tetrad.sem.SemIm)

Example 39 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

From class TestStatUtils, method testConditionalCorrelation.

/**
 * Checks that the StatUtils partial-statistic helpers agree with the direct
 * statistics when called without any conditioning variables.
 *
 * <p>The compared statistics are correlation, covariance, variance and
 * standard deviation; agreement is required to within 0.1. Data (1000 rows)
 * are simulated from a SEM over a random five-variable DAG, with the random
 * seed fixed so the run is reproducible.
 */
@Test
public void testConditionalCorrelation() {
    RandomUtil.getInstance().setSeed(30299533L);

    List<Node> nodes1 = new ArrayList<>();
    for (int i = 0; i < 5; i++) {
        nodes1.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = new Dag(GraphUtils.randomGraph(nodes1, 0, 5, 3, 3, 3, false));
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet dataSet = im.simulateData(1000, false);

    // Convert the data set to a matrix once; both columns are read from it.
    TetradMatrix data = dataSet.getDoubleData();
    double[] x = data.getColumn(0).toArray();
    double[] y = data.getColumn(1).toArray();

    // Direct statistics computed from the raw columns.
    double r1 = StatUtils.correlation(x, y);
    double s1 = StatUtils.covariance(x, y);
    double v1 = StatUtils.variance(x);
    double sd1 = StatUtils.sd(x);

    // The same statistics via the partial-* helpers with no conditioning set.
    ICovarianceMatrix cov = new CovarianceMatrix(dataSet);
    TetradMatrix covMatrix = cov.getMatrix();
    double r2 = StatUtils.partialCorrelation(covMatrix, 0, 1);
    double s2 = StatUtils.partialCovariance(covMatrix, 0, 1);
    double v2 = StatUtils.partialVariance(covMatrix, 0);
    double sd2 = StatUtils.partialStandardDeviation(covMatrix, 0);

    assertEquals(r1, r2, 0.1);
    assertEquals(s1, s2, 0.1);
    assertEquals(v1, v2, 0.1);
    assertEquals(sd1, sd2, 0.1);
}
Also used : Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) Dag(edu.cmu.tetrad.graph.Dag) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 40 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

From class TestUpdaterJointMarginal, method testEstimate1.

/**
 * Checks {@code RowSummingExactUpdater.getJointMarginal} on the binary network
 * L1 -> {X1, X2, X3}.
 *
 * <p>The parameters are:
 * <ul>
 *   <li>P(L1 = 0) = 0.5.</li>
 *   <li>P(Xi = 0 | L1 = 0) ~ 1/3.</li>
 *   <li>P(Xi = 0 | L1 = 1) ~ 2/3.</li>
 * </ul>
 *
 * <p>Observing X1 = X2 = X3 = 0, Bayes' rule gives:
 * <ul>
 *   <li>P(L1 = 0 | e) = (1/3)^3 / ((1/3)^3 + (2/3)^3) = 1/9.</li>
 *   <li>P(L1 = 1 | e) = 8/9.</li>
 * </ul>
 * Because X1 = 0 is part of the evidence, the joint marginals that also fix
 * X1 = 0 equal those posteriors.
 *
 * <p>Expected values are asserted as exact fractions. The earlier truncated
 * literals (0.1111 and 0.8888) left under 10% of the tolerance as slack.
 */
@Test
public void testEstimate1() {
    Dag graph = new Dag();
    Node l1 = new GraphNode("L1");
    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    l1.setNodeType(NodeType.MEASURED);
    x1.setNodeType(NodeType.MEASURED);
    x2.setNodeType(NodeType.MEASURED);
    x3.setNodeType(NodeType.MEASURED);
    graph.addNode(l1);
    graph.addNode(x1);
    graph.addNode(x2);
    graph.addNode(x3);
    graph.addDirectedEdge(l1, x1);
    graph.addDirectedEdge(l1, x2);
    graph.addDirectedEdge(l1, x3);

    BayesPm bayesPm = new BayesPm(graph);
    bayesPm.setNumCategories(l1, 2);
    bayesPm.setNumCategories(x1, 2);
    bayesPm.setNumCategories(x2, 2);
    bayesPm.setNumCategories(x3, 2);

    BayesIm estimatedIm = new MlBayesIm(bayesPm);
    int l1Index = estimatedIm.getNodeIndex(graph.getNode("L1"));
    int x1Index = estimatedIm.getNodeIndex(graph.getNode("X1"));
    int x2Index = estimatedIm.getNodeIndex(graph.getNode("X2"));
    int x3Index = estimatedIm.getNodeIndex(graph.getNode("X3"));

    // L1 has no parents, so its table has a single row (row 0).
    estimatedIm.setProbability(l1Index, 0, 0, 0.5);
    estimatedIm.setProbability(l1Index, 0, 1, 0.5);

    // Every child shares the same CPT: row = value of L1, column = value of child.
    for (int childIndex : new int[] {x1Index, x2Index, x3Index}) {
        estimatedIm.setProbability(childIndex, 0, 0, 0.33333); // p(xi = 0 | l1 = 0)
        estimatedIm.setProbability(childIndex, 0, 1, 0.66667); // p(xi = 1 | l1 = 0)
        estimatedIm.setProbability(childIndex, 1, 0, 0.66667); // p(xi = 0 | l1 = 1)
        estimatedIm.setProbability(childIndex, 1, 1, 0.33333); // p(xi = 1 | l1 = 1)
    }

    // Evidence: X1 = X2 = X3 = 0.
    Evidence evidence = Evidence.tautology(estimatedIm);
    evidence.getProposition().setCategory(x1Index, 0);
    evidence.getProposition().setCategory(x2Index, 0);
    evidence.getProposition().setCategory(x3Index, 0);

    RowSummingExactUpdater rseu = new RowSummingExactUpdater(estimatedIm);
    rseu.setEvidence(evidence);

    // P(L1 = 0 | e)
    int[] vars1 = {l1Index};
    int[] vals1 = {0};
    double p1 = rseu.getJointMarginal(vars1, vals1);
    assertEquals(1.0 / 9, p1, 0.0001);

    // P(L1 = 0, X1 = 0 | e): X1 = 0 is already observed, so this equals p1.
    int[] vars2 = {l1Index, x1Index};
    int[] vals2 = {0, 0};
    double p2 = rseu.getJointMarginal(vars2, vals2);
    assertEquals(1.0 / 9, p2, 0.0001);

    // P(L1 = 1, X1 = 0 | e) = P(L1 = 1 | e).
    int[] vals3 = {1, 0};
    double p3 = rseu.getJointMarginal(vars2, vals3);
    assertEquals(8.0 / 9, p3, 0.0001);
}
Also used: GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag) Test(org.junit.Test)

Aggregations

Dag (edu.cmu.tetrad.graph.Dag)53 Node (edu.cmu.tetrad.graph.Node)41 Graph (edu.cmu.tetrad.graph.Graph)21 GraphNode (edu.cmu.tetrad.graph.GraphNode)21 Test (org.junit.Test)18 ArrayList (java.util.ArrayList)12 SemIm (edu.cmu.tetrad.sem.SemIm)10 SemPm (edu.cmu.tetrad.sem.SemPm)10 DataSet (edu.cmu.tetrad.data.DataSet)7 BayesPm (edu.cmu.tetrad.bayes.BayesPm)6 MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm)6 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)6 BayesIm (edu.cmu.tetrad.bayes.BayesIm)5 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)3 FruchtermanReingoldLayout (edu.cmu.tetrad.graph.FruchtermanReingoldLayout)2 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)2 LinkedList (java.util.LinkedList)2 StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs)1 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)1 CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly)1