Search in sources:

Example 16 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestFci, method testSearch15.

/**
 * Smoke test for {@link Cfci}. It builds a random DAG over 80 continuous variables, instantiates a
 * SEM on it and simulates 1000 rows, then runs CFCI with a Fisher Z independence test at
 * alpha = 0.05. Nothing is asserted, so the test passes as long as the search completes without
 * throwing.
 */
@Test
public void testSearch15() {
    final int numVars = 80;
    final int numEdges = 80;
    final int numLatents = 40;
    final int sampleSize = 1000;
    final boolean latentDataSaved = false;

    // Variables are named X1..X{numVars}.
    List<Node> variables = new ArrayList<>(numVars);
    for (int index = 1; index <= numVars; index++) {
        variables.add(new ContinuousVariable("X" + index));
    }

    Dag dag = new Dag(GraphUtils.randomGraph(variables, numLatents, numEdges, 7, 5, 5, false));
    SemPm semPm = new SemPm(dag);
    SemIm semIm = new SemIm(semPm);
    DataSet sample = semIm.simulateData(sampleSize, latentDataSaved);

    IndependenceTest fisherZ = new IndTestFisherZ(sample, 0.05);
    Cfci cfci = new Cfci(fisherZ);
    cfci.search();
}
Also used : ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 17 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestFci, method testFciAnc.

/**
 * Measures how ancestral judgments read from a PAG change when some measured variables are
 * marginalized out. It prints one summary table per run and asserts nothing.
 * <p>
 * Each of {@code numRuns} runs does the following:
 * <ol>
 * <li>Builds a random graph over {@code numMeasures + numLatents} continuous variables and
 * simulates 1000 rows from a SEM on it.</li>
 * <li>Learns a PAG from the full data with {@code getPag}.</li>
 * <li>Removes {@code numVarsToMarginalize} randomly chosen columns and learns a second PAG from
 * the remaining data.</li>
 * <li>Takes every ordered pair of distinct remaining variables that the marginal PAG classifies as
 * ancestral or nonancestral. Each such pair is tallied against its classification in the
 * full-data PAG: ancestral, nonancestral, or neither ("Ambiguous").</li>
 * </ol>
 * In the printed table, columns are the marginal-PAG classification and rows are the full-PAG
 * classification. Each column is normalized by its total, and a column with a zero total prints
 * "-" instead of a NaN ratio.
 * <p>
 * The test is disabled: its {@code @Test} annotation is commented out.
 */
// @Test
public void testFciAnc() {
    int numMeasures = 50;
    double edgeFactor = 2.0;
    int numRuns = 10;
    double alpha = 0.01;
    double penaltyDiscount = 4.0;
    int numVarsToMarginalize = 5;
    int numLatents = 10;
    System.out.println("num measures = " + numMeasures);
    System.out.println("edge factor = " + edgeFactor);
    System.out.println("alpha = " + alpha);
    System.out.println("penaltyDiscount = " + penaltyDiscount);
    System.out.println("num runs = " + numRuns);
    System.out.println("num vars to marginalize = " + numVarsToMarginalize);
    System.out.println("num latents = " + numLatents);
    System.out.println();

    // These do not depend on the run, so compute them once.
    int numVars = numMeasures + numLatents;
    int numEdges = (int) (edgeFactor * numVars);
    NumberFormat nf = new DecimalFormat("0.00");

    for (int i = 0; i < numRuns; i++) {
        List<Node> nodes = new ArrayList<>(numVars);
        for (int r = 0; r < numVars; r++) {
            nodes.add(new ContinuousVariable("X" + (r + 1)));
        }
        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);
        Graph pag = getPag(alpha, penaltyDiscount, data);

        DataSet marginalData = data.copy();
        // Shuffle a private copy. That keeps the data set's own variable list from being reordered in
        // place, and keeps the list stable while columns are removed below.
        List<Node> variables = new ArrayList<>(marginalData.getVariables());
        Collections.shuffle(variables);
        int numToRemove = Math.min(numVarsToMarginalize, variables.size());
        for (int m = 0; m < numToRemove; m++) {
            marginalData.removeColumn(marginalData.getColumn(variables.get(m)));
        }
        Graph margPag = getPag(alpha, penaltyDiscount, marginalData);

        // Naming is <full-PAG class><marginal-PAG class>: e.g. nancAnc counts pairs that are
        // ancestral in the marginal PAG but nonancestral in the full PAG.
        int ancAnc = 0;
        int ancNanc = 0;
        int nancAnc = 0;
        int nancNanc = 0;
        int ambAnc = 0;
        int ambNanc = 0;
        int totalAncMarg = 0;
        int totalNancMarg = 0;

        List<Node> marginalVars = marginalData.getVariables();
        for (Node n1 : marginalVars) {
            for (Node n2 : marginalVars) {
                if (n1 == n2) {
                    continue;
                }
                if (ancestral(n1, n2, margPag)) {
                    if (ancestral(n1, n2, pag)) {
                        ancAnc++;
                    } else if (nonAncestral(n1, n2, pag)) {
                        nancAnc++;
                    } else {
                        ambAnc++;
                    }
                    totalAncMarg++;
                } else if (nonAncestral(n1, n2, margPag)) {
                    if (ancestral(n1, n2, pag)) {
                        ancNanc++;
                    } else if (nonAncestral(n1, n2, pag)) {
                        nancNanc++;
                    } else {
                        ambNanc++;
                    }
                    totalNancMarg++;
                }
            }
        }

        TextTable table = new TextTable(5, 3);
        table.setToken(0, 1, "Ancestral");
        table.setToken(0, 2, "Nonancestral");
        table.setToken(1, 0, "Ancestral");
        table.setToken(2, 0, "Nonancestral");
        table.setToken(3, 0, "Ambiguous");
        table.setToken(4, 0, "Total");
        table.setToken(1, 1, formatAncestralFraction(nf, ancAnc, totalAncMarg));
        table.setToken(2, 1, formatAncestralFraction(nf, nancAnc, totalAncMarg));
        table.setToken(3, 1, formatAncestralFraction(nf, ambAnc, totalAncMarg));
        table.setToken(1, 2, formatAncestralFraction(nf, ancNanc, totalNancMarg));
        table.setToken(2, 2, formatAncestralFraction(nf, nancNanc, totalNancMarg));
        table.setToken(3, 2, formatAncestralFraction(nf, ambNanc, totalNancMarg));
        table.setToken(4, 1, Integer.toString(totalAncMarg));
        table.setToken(4, 2, Integer.toString(totalNancMarg));
        System.out.println(table);
    }
}

/**
 * Formats {@code count / total} with {@code nf}. Returns "-" when {@code total} is zero, so an
 * empty table column does not print a NaN ratio.
 */
private static String formatAncestralFraction(NumberFormat nf, int count, int total) {
    return total == 0 ? "-" : nf.format(count / (double) total);
}
Also used : DecimalFormat(java.text.DecimalFormat) ArrayList(java.util.ArrayList) TextTable(edu.cmu.tetrad.util.TextTable) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) NumberFormat(java.text.NumberFormat)

Example 18 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestSemProposition, method testEvidence.

/**
 * Checks value handling on a {@link SemProposition}. A tautology built from {@code semIm} is
 * expected to report NaN for every variable. Values can then be set and read back both by variable
 * index and by {@link Node}.
 */
@Test
public void testEvidence() {
    Graph graph = constructGraph1();
    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    // getVariableNodes() is already typed List<Node>, so no raw type or cast is needed.
    List<Node> nodes = semIm.getVariableNodes();
    SemProposition proposition = SemProposition.tautology(semIm);

    // Initially every variable's value is unset (NaN).
    for (int i = 0; i < nodes.size(); i++) {
        assertTrue(Double.isNaN(proposition.getValue(i)));
    }

    // Set and read back by index.
    proposition.setValue(1, 0.5);
    assertEquals(0.5, proposition.getValue(1), 0.0);

    // Set and read back by node.
    Node node4 = nodes.get(3);
    proposition.setValue(node4, 0.7);
    assertEquals(0.7, proposition.getValue(node4), 0.0);
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) SemProposition(edu.cmu.tetrad.sem.SemProposition) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) List(java.util.List) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 19 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestSemVarMeans, method testMeansCholesky.

/**
 * Checks that {@link SemEstimator} recovers variable means from data produced by
 * {@code simulateDataCholesky}. Each estimated mean must lie within 0.6 of the mean set on the
 * generating model. The {@code means} array is indexed in the order returned by
 * {@code getVariableNodes()}.
 */
@Test
public void testMeansCholesky() {
    final double[] means = {5.0, 4.0, 3.0, 2.0, 1.0};

    SemPm pm = new SemPm(constructGraph1());
    // Turn off random initialization on every parameter of the model.
    for (Parameter parameter : pm.getParameters()) {
        parameter.setInitializedRandomly(false);
    }
    SemIm trueIm = new SemIm(pm);

    // Seed the shared RandomUtil generator before the simulation draws from it.
    RandomUtil.getInstance().setSeed(-379467L);

    List<Node> imNodes = trueIm.getVariableNodes();
    for (int k = 0; k < imNodes.size(); k++) {
        trueIm.setMean(imNodes.get(k), means[k]);
    }

    DataSet sample = trueIm.simulateDataCholesky(1000, false);
    SemEstimator estimator = new SemEstimator(sample, pm);
    estimator.estimate();
    SemIm fittedIm = estimator.getEstimatedSem();

    for (Node node : pm.getVariableNodes()) {
        assertEquals(trueIm.getMean(node), fittedIm.getMean(node), 0.6);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 20 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestSemVarMeans, method testMeansReducedForm.

/**
 * Checks that {@link SemEstimator} recovers variable means from data produced by
 * {@code simulateDataReducedForm}. Each estimated mean must lie within 0.5 of the mean set on the
 * generating model. The {@code means} array is indexed in the order returned by
 * {@code getVariableNodes()}.
 */
@Test
public void testMeansReducedForm() {
    final double[] means = {5.0, 4.0, 3.0, 2.0, 1.0};

    SemPm pm = new SemPm(constructGraph1());
    // Turn off random initialization on every parameter of the model.
    for (Parameter parameter : pm.getParameters()) {
        parameter.setInitializedRandomly(false);
    }
    SemIm trueIm = new SemIm(pm);

    // Seed the shared RandomUtil generator before the simulation draws from it.
    RandomUtil.getInstance().setSeed(-379467L);

    List<Node> imNodes = trueIm.getVariableNodes();
    for (int k = 0; k < imNodes.size(); k++) {
        trueIm.setMean(imNodes.get(k), means[k]);
    }

    DataSet sample = trueIm.simulateDataReducedForm(1000, false);
    SemEstimator estimator = new SemEstimator(sample, pm);
    estimator.estimate();
    SemIm fittedIm = estimator.getEstimatedSem();

    for (Node node : pm.getVariableNodes()) {
        assertEquals(trueIm.getMean(node), fittedIm.getMean(node), 0.5);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Aggregations

SemIm (edu.cmu.tetrad.sem.SemIm)81 SemPm (edu.cmu.tetrad.sem.SemPm)71 Test (org.junit.Test)46 DataSet (edu.cmu.tetrad.data.DataSet)28 ArrayList (java.util.ArrayList)28 Graph (edu.cmu.tetrad.graph.Graph)26 Node (edu.cmu.tetrad.graph.Node)19 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)16 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)16 SemEstimator (edu.cmu.tetrad.sem.SemEstimator)15 Dag (edu.cmu.tetrad.graph.Dag)10 DMSearch (edu.cmu.tetrad.search.DMSearch)9 StandardizedSemIm (edu.cmu.tetrad.sem.StandardizedSemIm)9 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)7 NumberFormat (java.text.NumberFormat)7 GraphNode (edu.cmu.tetrad.graph.GraphNode)5 IndependenceTest (edu.cmu.tetrad.search.IndependenceTest)4 DecimalFormat (java.text.DecimalFormat)4 List (java.util.List)4 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)3