Search in sources:

Example 26 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestRicf, method test4.

@Test
public void test4() {
    // First graph: five continuous variables X1..X5, drawn with GraphUtils.randomGraph.
    List<Node> firstVars = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        firstVars.add(new ContinuousVariable("X" + k));
    }
    Graph first = GraphUtils.randomGraph(firstVars, 0, 5, 0, 0, 0, false);

    // Second graph: fresh variables with the same names, drawn with the same arguments.
    List<Node> secondVars = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        secondVars.add(new ContinuousVariable("X" + k));
    }
    Graph second = GraphUtils.randomGraph(secondVars, 0, 5, 0, 0, 0, false);

    // Data come from a SEM built on the first graph only.
    SemIm generatingIm = new SemIm(new SemPm(first));
    DataSet sample = generatingIm.simulateData(1000, false);
    ICovarianceMatrix covariance = new CovarianceMatrix(sample);

    // Run RICF on that covariance for each graph. There are no assertions, so the test
    // passes as long as neither call throws.
    new Ricf().ricf(new SemGraph(first), covariance, 0.001);
    new Ricf().ricf(new SemGraph(second), covariance, 0.001);
}
Also used : ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) Ricf(edu.cmu.tetrad.sem.Ricf) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 27 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestIndTestWaldLR, method testIsIndependent.

@Test
public void testIsIndependent() {
    // Fixed seed: the outcome is reproducible, and every trial must agree with the oracle.
    RandomUtil.getInstance().setSeed(1450705713157L);
    final int trials = 10;
    int agreements = 0;

    for (int trial = 0; trial < trials; trial++) {
        // Random graph over X1..X10 and 1000 rows simulated from a SEM on it.
        List<Node> variables = new ArrayList<>();
        for (int v = 1; v <= 10; v++) {
            variables.add(new ContinuousVariable("X" + v));
        }
        Graph trueGraph = GraphUtils.randomGraph(variables, 0, 10, 3, 3, 3, false);
        SemIm generatingIm = new SemIm(new SemPm(trueGraph));
        DataSet data = generatingIm.simulateData(1000, false);

        // Discretize the columns at indices 0 and 3 with equalCounts(..., 2), producing mixed data.
        Discretizer discretizer = new Discretizer(data);
        discretizer.setVariablesCopied(true);
        discretizer.equalCounts(data.getVariable(0), 2);
        discretizer.equalCounts(data.getVariable(3), 2);
        data = discretizer.discretize();

        // Query: is X2 independent of X1 given {X3, X4, X5}? Resolve the nodes in the data set.
        Node dataX1 = data.getVariable("X1");
        Node dataX2 = data.getVariable("X2");
        List<Node> dataCond = new ArrayList<>();
        for (String name : new String[] {"X3", "X4", "X5"}) {
            dataCond.add(data.getVariable(name));
        }

        // The same query expressed with the true graph's nodes (matched by name).
        Node graphX1 = trueGraph.getNode(dataX1.getName());
        Node graphX2 = trueGraph.getNode(dataX2.getName());
        List<Node> graphCond = new ArrayList<>();
        for (Node node : dataCond) {
            graphCond.add(trueGraph.getNode(node.getName()));
        }

        // Compare the Wald LR test's verdict with d-separation in the true graph.
        IndependenceTest waldTest = new IndTestMultinomialLogisticRegressionWald(data, 0.05, false);
        IndTestDSep oracle = new IndTestDSep(trueGraph);
        boolean waldSaysIndependent = waldTest.isIndependent(dataX2, dataX1, dataCond);
        boolean oracleSaysIndependent = oracle.isIndependent(graphX2, graphX1, graphCond);
        if (waldSaysIndependent == oracleSaysIndependent) {
            agreements++;
        }
    }

    assertEquals(trials, agreements);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) IndTestMultinomialLogisticRegressionWald(edu.pitt.csb.mgm.IndTestMultinomialLogisticRegressionWald) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 28 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestLingamPattern, method test1.

@Test
public void test1() {
    RandomUtil.getInstance().setSeed(4938492L);
    final int rows = 1000;

    // Random DAG over six continuous variables X1..X6.
    List<Node> variables = new ArrayList<>();
    for (int v = 1; v <= 6; v++) {
        variables.add(new ContinuousVariable("X" + v));
    }
    Graph trueDag = new Dag(GraphUtils.randomGraph(variables, 0, 6, 4, 4, 4, false));

    // Per-variable distributions for simulateDataNonNormal: Normal(0, 1) everywhere except
    // Uniform(-1, 1) at index 3.
    List<Distribution> perVariable = new ArrayList<>();
    for (int v = 0; v < 6; v++) {
        if (v == 3) {
            perVariable.add(new Uniform(-1, 1));
        } else {
            perVariable.add(new Normal(0, 1));
        }
    }

    SemIm semIm = new SemIm(new SemPm(trueDag));
    DataSet dataSet = simulateDataNonNormal(semIm, rows, perVariable);

    // Estimate a pattern with FGES, then run LiNGAM-pattern on it.
    Score score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
    Graph estPattern = new Fges(score).search();
    LingamPattern lingam = new LingamPattern(estPattern, dataSet);
    lingam.search();

    // Each p-value must match its recorded value to within 0.01.
    double[] actual = lingam.getPValues();
    double[] expected = {0.18, 0.29, 0.88, 0.00, 0.01, 0.58};
    for (int k = 0; k < actual.length; k++) {
        assertEquals(expected[k], actual[k], 0.01);
    }
}
Also used : ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Uniform(edu.cmu.tetrad.util.dist.Uniform) Normal(edu.cmu.tetrad.util.dist.Normal) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Distribution(edu.cmu.tetrad.util.dist.Distribution) SemPm(edu.cmu.tetrad.sem.SemPm) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 29 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestPc, helper method printStats.

/**
 * Benchmarks one search algorithm over {@code numRuns} random simulations. Prints summary
 * statistics labelled with {@code algorithms[t]} and returns them.
 *
 * <p>Each run does the following:
 * <ul>
 *   <li>Generates a graph over {@code numMeasures + numLatents} continuous variables named
 *       X1, X2, ... via {@code GraphUtils.randomGraphRandomForwardEdges}.</li>
 *   <li>Simulates 1000 rows from a {@code SemIm} on that graph.</li>
 *   <li>Runs the search selected by {@code t}.</li>
 *   <li>Scores every output endpoint against the true graph {@code dag}.</li>
 * </ul>
 *
 * <p>An endpoint on edge A-B is counted correct as follows:
 * <ul>
 *   <li>Arrow at A: {@code dag.isAncestorOf(A, B)} is false and {@code dag.existsTrek(A, B)}
 *       is true.</li>
 *   <li>Tail at A: {@code dag.existsDirectedPathFromTo(A, B)} is true.</li>
 *   <li>Bidirected edge: neither {@code isAncestorOf(A, B)} nor {@code isAncestorOf(B, A)}
 *       holds, and a trek connects A and B.</li>
 * </ul>
 * If an endpoint type never occurs in a run, that run's precision for the type is NaN and the
 * run is left out of that type's average.
 *
 * @param algorithms      display names indexed by {@code t}; used only in the printed labels
 * @param t               algorithm index: 0 PC, 1 CPC, 2 FGES, 3 FCI, 4 GFCI, 5 RFCI, 6 CFCI
 * @param directed        when true, only output edges that are directed or bidirected are scored
 * @param numRuns         number of simulate-and-search repetitions
 * @param alpha           alpha passed to {@code IndTestFisherZ}
 * @param penaltyDiscount penalty discount set on the {@code SemBicScore}
 * @param numMeasures     number of measured variables
 * @param numLatents      number of latent variables, passed to the graph generator
 * @param edgeFactor      the generator is asked for
 *                        {@code (int) (edgeFactor * (numMeasures + numLatents))} edges
 * @return {@code [avgArrowPrecision, avgTailPrecision, avgBidirectedPrecision, avgNumArrows,
 *         avgNumTails, avgNumBidirected, -avgElapsedMillis]}; elapsed time is negated
 *         (presumably so a maximizing caller prefers faster runs; confirm with callers)
 * @throws IllegalStateException if {@code t} is not in 0..6
 */
private double[] printStats(String[] algorithms, int t, boolean directed, int numRuns, double alpha, double penaltyDiscount, int numMeasures, int numLatents, double edgeFactor) {
    NumberFormat nf = new DecimalFormat("0.00");
    double sumArrowPrecision = 0.0;
    double sumTailPrecision = 0.0;
    double sumBidirectedPrecision = 0.0;
    int numArrows = 0;
    int numTails = 0;
    int numBidirected = 0;
    int count = 0;
    // long, not int: elapsed times are long millis, and an int "+=" would silently narrow them.
    long totalElapsed = 0;
    int countAP = 0;
    int countTP = 0;
    int countBP = 0;
    // Loop-invariant graph size.
    int numVars = numMeasures + numLatents;
    int numEdges = (int) (edgeFactor * numVars);
    for (int i = 0; i < numRuns; i++) {
        List<Node> nodes = new ArrayList<>();
        for (int r = 0; r < numVars; r++) {
            nodes.add(new ContinuousVariable("X" + (r + 1)));
        }
        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);
        IndTestFisherZ test = new IndTestFisherZ(data, alpha);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data));
        score.setPenaltyDiscount(penaltyDiscount);
        GraphSearch search = searchForAlgorithmIndex(t, test, score);
        long start = System.currentTimeMillis();
        Graph out = search.search();
        totalElapsed += System.currentTimeMillis() - start;
        // Map output nodes onto the true graph's node objects so the DAG queries below apply.
        out = GraphUtils.replaceNodes(out, dag.getNodes());
        int arrowsTp = 0;
        int arrowsFp = 0;
        int tailsTp = 0;
        int tailsFp = 0;
        int bidirectedTp = 0;
        int bidirectedFp = 0;
        for (Edge edge : out.getEdges()) {
            if (directed && !(edge.isDirected() || Edges.isBidirectedEdge(edge))) {
                continue;
            }
            Node n1 = edge.getNode1();
            Node n2 = edge.getNode2();
            if (edge.getEndpoint1() == Endpoint.ARROW) {
                if (!dag.isAncestorOf(n1, n2) && dag.existsTrek(n1, n2)) {
                    arrowsTp++;
                } else {
                    arrowsFp++;
                }
                numArrows++;
            }
            if (edge.getEndpoint2() == Endpoint.ARROW) {
                if (!dag.isAncestorOf(n2, n1) && dag.existsTrek(n1, n2)) {
                    arrowsTp++;
                } else {
                    arrowsFp++;
                }
                numArrows++;
            }
            if (edge.getEndpoint1() == Endpoint.TAIL) {
                if (dag.existsDirectedPathFromTo(n1, n2)) {
                    tailsTp++;
                } else {
                    tailsFp++;
                }
                numTails++;
            }
            if (edge.getEndpoint2() == Endpoint.TAIL) {
                if (dag.existsDirectedPathFromTo(n2, n1)) {
                    tailsTp++;
                } else {
                    tailsFp++;
                }
                numTails++;
            }
            if (Edges.isBidirectedEdge(edge)) {
                if (!dag.isAncestorOf(n1, n2) && !dag.isAncestorOf(n2, n1) && dag.existsTrek(n1, n2)) {
                    bidirectedTp++;
                } else {
                    bidirectedFp++;
                }
                numBidirected++;
            }
        }
        // 0/0 gives NaN when a run has no endpoints of a type; such runs are skipped below.
        double arrowPrecision = arrowsTp / (double) (arrowsTp + arrowsFp);
        double tailPrecision = tailsTp / (double) (tailsTp + tailsFp);
        double bidirectedPrecision = bidirectedTp / (double) (bidirectedTp + bidirectedFp);
        if (!Double.isNaN(arrowPrecision)) {
            sumArrowPrecision += arrowPrecision;
            countAP++;
        }
        if (!Double.isNaN(tailPrecision)) {
            sumTailPrecision += tailPrecision;
            countTP++;
        }
        if (!Double.isNaN(bidirectedPrecision)) {
            sumBidirectedPrecision += bidirectedPrecision;
            countBP++;
        }
        count++;
    }
    double avgArrowPrecision = sumArrowPrecision / (double) countAP;
    double avgTailPrecision = sumTailPrecision / (double) countTP;
    double avgBidirectedPrecision = sumBidirectedPrecision / (double) countBP;
    double avgNumArrows = numArrows / (double) count;
    double avgNumTails = numTails / (double) count;
    double avgNumBidirected = numBidirected / (double) count;
    double avgElapsed = totalElapsed / (double) numRuns;
    System.out.println();
    System.out.println(algorithms[t] + " arrow precision " + nf.format(avgArrowPrecision));
    System.out.println(algorithms[t] + " tail precision " + nf.format(avgTailPrecision));
    System.out.println(algorithms[t] + " bidirected precision " + nf.format(avgBidirectedPrecision));
    System.out.println(algorithms[t] + " avg num arrow " + nf.format(avgNumArrows));
    System.out.println(algorithms[t] + " avg num tails " + nf.format(avgNumTails));
    System.out.println(algorithms[t] + " avg num bidirected " + nf.format(avgNumBidirected));
    System.out.println(algorithms[t] + " avg elapsed " + nf.format(avgElapsed));
    return new double[] {
        avgArrowPrecision, avgTailPrecision, avgBidirectedPrecision,
        avgNumArrows, avgNumTails, avgNumBidirected,
        -avgElapsed
    };
}

/**
 * Builds the search algorithm that {@code printStats} selects by index {@code t}.
 *
 * @throws IllegalStateException if {@code t} is not in 0..6
 */
private GraphSearch searchForAlgorithmIndex(int t, IndTestFisherZ test, SemBicScore score) {
    switch (t) {
        case 0:
            return new Pc(test);
        case 1:
            return new Cpc(test);
        case 2:
            return new Fges(score);
        case 3:
            return new Fci(test);
        case 4:
            return new GFci(test, score);
        case 5:
            return new Rfci(test);
        case 6:
            return new Cfci(test);
        default:
            throw new IllegalStateException("Unknown algorithm index: " + t);
    }
}
Also used : DecimalFormat(java.text.DecimalFormat) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) NumberFormat(java.text.NumberFormat)

Example 30 with SemIm

Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

From class TestStandardizedSem, method rtest8.

public void rtest8() {
    // Seeding left disabled: RandomUtil.getInstance().setSeed(2958442283L);

    // Graph: X -> Y, X <-> Y, X -> Z, Y -> Z, with error terms shown.
    SemGraph graph = new SemGraph();
    Node x = new ContinuousVariable("X");
    Node y = new ContinuousVariable("Y");
    Node z = new ContinuousVariable("Z");
    for (Node node : new Node[] {x, y, z}) {
        graph.addNode(node);
    }
    graph.addDirectedEdge(x, y);
    graph.addBidirectedEdge(x, y);
    graph.addDirectedEdge(x, z);
    graph.addDirectedEdge(y, z);
    graph.setShowErrorTerms(true);

    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    StandardizedSemIm sem = new StandardizedSemIm(semIm, StandardizedSemIm.Initialization.CALCULATE_FROM_SEM);

    // Simulate, standardize, re-estimate, then simulate from both models. The results are never
    // inspected, but the calls are kept in this order because they may advance the shared random
    // generator used for the draw below.
    DataSet simulated = semIm.simulateDataCholesky(1000, false);
    DataSet standardized = ColtDataSet.makeContinuousData(simulated.getVariables(), DataUtils.standardizeData(simulated.getDoubleData()));
    SemIm estimatedIm = new SemEstimator(standardized, semPm).estimate();
    estimatedIm.simulateDataReducedForm(1000, false);
    sem.simulateDataReducedForm(1000, false);

    // Draw an X<->Y error covariance from the allowed range, replacing an infinite end with +/-1000.
    StandardizedSemIm.ParameterRange range = sem.getCovarianceRange(x, y);
    double upper = range.getHigh();
    double lower = range.getLow();
    upper = (upper == Double.POSITIVE_INFINITY) ? 1000 : upper;
    lower = (lower == Double.NEGATIVE_INFINITY) ? -1000 : lower;
    double coef = lower + RandomUtil.getInstance().nextDouble() * (upper - lower);

    // Setting the drawn value must succeed and leave the model standardized.
    assertTrue(sem.setErrorCovariance(x, y, coef));
    assertTrue(isStandardized(sem));
}
Also used : SemPm(edu.cmu.tetrad.sem.SemPm) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm)

Aggregations

SemIm (edu.cmu.tetrad.sem.SemIm)81 SemPm (edu.cmu.tetrad.sem.SemPm)71 Test (org.junit.Test)46 DataSet (edu.cmu.tetrad.data.DataSet)28 ArrayList (java.util.ArrayList)28 Graph (edu.cmu.tetrad.graph.Graph)26 Node (edu.cmu.tetrad.graph.Node)19 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)16 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)16 SemEstimator (edu.cmu.tetrad.sem.SemEstimator)15 Dag (edu.cmu.tetrad.graph.Dag)10 DMSearch (edu.cmu.tetrad.search.DMSearch)9 StandardizedSemIm (edu.cmu.tetrad.sem.StandardizedSemIm)9 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)7 NumberFormat (java.text.NumberFormat)7 GraphNode (edu.cmu.tetrad.graph.GraphNode)5 IndependenceTest (edu.cmu.tetrad.search.IndependenceTest)4 DecimalFormat (java.text.DecimalFormat)4 List (java.util.List)4 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)3