Search in sources :

Example 31 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

the class TestPc method printStats.

/**
 * Runs one search algorithm {@code numRuns} times on freshly simulated data and prints and
 * returns averaged orientation statistics.
 *
 * <p>Each run builds a random graph over {@code numMeasures + numLatents} continuous variables,
 * simulates 1000 rows from a SEM over it, runs the selected search, and scores every arrow,
 * tail and bidirected edge of the result against ancestor/trek relations in the generating
 * graph.
 *
 * @param algorithms      display names, indexed by {@code t}; used only as print labels
 * @param t               algorithm selector: 0 Pc, 1 Cpc, 2 Fges, 3 Fci, 4 GFci, 5 Rfci, 6 Cfci
 * @param directed        if true, output edges that are neither directed nor bidirected are skipped
 * @param numRuns         number of simulate-and-search repetitions
 * @param alpha           significance level handed to the Fisher Z independence test
 * @param penaltyDiscount penalty discount set on the SEM BIC score
 * @param numMeasures     number of measured variables
 * @param numLatents      number of latent variables
 * @param edgeFactor      edges per node; the edge count is {@code (int) (edgeFactor * nodes)}
 * @return seven values: average arrow, tail and bidirected precision, average number of arrows,
 *         tails and bidirected edges per run, and the negated average elapsed search time in
 *         milliseconds. A precision average is NaN when no run produced a defined precision.
 * @throws IllegalStateException if {@code t} is not in the range 0..6
 */
private double[] printStats(String[] algorithms, int t, boolean directed, int numRuns, double alpha, double penaltyDiscount, int numMeasures, int numLatents, double edgeFactor) {
    NumberFormat nf = new DecimalFormat("0.00");
    double sumArrowPrecision = 0.0;
    double sumTailPrecision = 0.0;
    double sumBidirectedPrecision = 0.0;
    int numArrows = 0;
    int numTails = 0;
    int numBidirected = 0;
    int count = 0;
    // Accumulated as long: System.currentTimeMillis() differences are long, and "int += long"
    // would narrow silently.
    long totalElapsed = 0L;
    int countAP = 0;
    int countTP = 0;
    int countBP = 0;
    // Loop-invariant sizes of the random graph generated in each run.
    int numNodes = numMeasures + numLatents;
    int numEdges = (int) (edgeFactor * numNodes);
    for (int i = 0; i < numRuns; i++) {
        List<Node> nodes = new ArrayList<>();
        for (int r = 0; r < numNodes; r++) {
            nodes.add(new ContinuousVariable("X" + (r + 1)));
        }
        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);
        IndTestFisherZ test = new IndTestFisherZ(data, alpha);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data));
        score.setPenaltyDiscount(penaltyDiscount);
        GraphSearch search;
        switch(t) {
            case 0:
                search = new Pc(test);
                break;
            case 1:
                search = new Cpc(test);
                break;
            case 2:
                search = new Fges(score);
                break;
            case 3:
                search = new Fci(test);
                break;
            case 4:
                search = new GFci(test, score);
                break;
            case 5:
                search = new Rfci(test);
                break;
            case 6:
                search = new Cfci(test);
                break;
            default:
                throw new IllegalStateException();
        }
        // Only the search itself is timed.
        long start = System.currentTimeMillis();
        Graph out = search.search();
        long stop = System.currentTimeMillis();
        totalElapsed += stop - start;
        out = GraphUtils.replaceNodes(out, dag.getNodes());
        int arrowsTp = 0;
        int arrowsFp = 0;
        int tailsTp = 0;
        int tailsFp = 0;
        int bidirectedTp = 0;
        int bidirectedFp = 0;
        for (Edge edge : out.getEdges()) {
            if (directed && !(edge.isDirected() || Edges.isBidirectedEdge(edge))) {
                continue;
            }
            Node node1 = edge.getNode1();
            Node node2 = edge.getNode2();
            // An arrowhead at a node counts as correct when that node is not an ancestor of the
            // other endpoint and a trek connects the two in the generating graph.
            if (edge.getEndpoint1() == Endpoint.ARROW) {
                if (!dag.isAncestorOf(node1, node2) && dag.existsTrek(node1, node2)) {
                    arrowsTp++;
                } else {
                    arrowsFp++;
                }
                numArrows++;
            }
            if (edge.getEndpoint2() == Endpoint.ARROW) {
                if (!dag.isAncestorOf(node2, node1) && dag.existsTrek(node1, node2)) {
                    arrowsTp++;
                } else {
                    arrowsFp++;
                }
                numArrows++;
            }
            // A tail at a node counts as correct when a directed path leads from that node to
            // the other endpoint in the generating graph.
            if (edge.getEndpoint1() == Endpoint.TAIL) {
                if (dag.existsDirectedPathFromTo(node1, node2)) {
                    tailsTp++;
                } else {
                    tailsFp++;
                }
                numTails++;
            }
            if (edge.getEndpoint2() == Endpoint.TAIL) {
                if (dag.existsDirectedPathFromTo(node2, node1)) {
                    tailsTp++;
                } else {
                    tailsFp++;
                }
                numTails++;
            }
            // A bidirected edge counts as correct when neither endpoint is an ancestor of the
            // other and a trek connects them.
            if (Edges.isBidirectedEdge(edge)) {
                if (!dag.isAncestorOf(node1, node2) && !dag.isAncestorOf(node2, node1) && dag.existsTrek(node1, node2)) {
                    bidirectedTp++;
                } else {
                    bidirectedFp++;
                }
                numBidirected++;
            }
        }
        // With no marks of a given kind the ratio is 0 / 0.0 == NaN; such runs are left out of
        // the corresponding average.
        double arrowPrecision = arrowsTp / (double) (arrowsTp + arrowsFp);
        double tailPrecision = tailsTp / (double) (tailsTp + tailsFp);
        double bidirectedPrecision = bidirectedTp / (double) (bidirectedTp + bidirectedFp);
        if (!Double.isNaN(arrowPrecision)) {
            sumArrowPrecision += arrowPrecision;
            countAP++;
        }
        if (!Double.isNaN(tailPrecision)) {
            sumTailPrecision += tailPrecision;
            countTP++;
        }
        if (!Double.isNaN(bidirectedPrecision)) {
            sumBidirectedPrecision += bidirectedPrecision;
            countBP++;
        }
        count++;
    }
    double avgArrowPrecision = sumArrowPrecision / (double) countAP;
    double avgTailPrecision = sumTailPrecision / (double) countTP;
    double avgBidirectedPrecision = sumBidirectedPrecision / (double) countBP;
    double avgNumArrows = numArrows / (double) count;
    double avgNumTails = numTails / (double) count;
    double avgNumBidirected = numBidirected / (double) count;
    double avgElapsed = totalElapsed / (double) numRuns;
    // Elapsed time is negated in the returned array (the original marks this slot "minimize").
    double[] ret = new double[] {
        avgArrowPrecision,
        avgTailPrecision,
        avgBidirectedPrecision,
        avgNumArrows,
        avgNumTails,
        avgNumBidirected,
        -avgElapsed
    };
    System.out.println();
    System.out.println(algorithms[t] + " arrow precision " + nf.format(avgArrowPrecision));
    System.out.println(algorithms[t] + " tail precision " + nf.format(avgTailPrecision));
    System.out.println(algorithms[t] + " bidirected precision " + nf.format(avgBidirectedPrecision));
    System.out.println(algorithms[t] + " avg num arrow " + nf.format(avgNumArrows));
    System.out.println(algorithms[t] + " avg num tails " + nf.format(avgNumTails));
    System.out.println(algorithms[t] + " avg num bidirected " + nf.format(avgNumBidirected));
    System.out.println(algorithms[t] + " avg elapsed " + nf.format(avgElapsed));
    return ret;
}
Also used : DecimalFormat(java.text.DecimalFormat) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) NumberFormat(java.text.NumberFormat)

Example 32 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

the class TestStandardizedSem method rtest8.

/**
 * Builds a three-variable SEM (X to Y, X to Z, Y to Z, plus a bidirected X-Y edge), derives a
 * standardized SEM from it, sets the X/Y error covariance to a random value inside the range
 * reported by {@code getCovarianceRange}, and asserts that the setter succeeds and that
 * {@code isStandardized} still holds.
 *
 * <p>The seeding call is disabled, so no fixed seed is set by this method.
 */
public void rtest8() {
    SemGraph semGraph = new SemGraph();
    Node nodeX = new ContinuousVariable("X");
    Node nodeY = new ContinuousVariable("Y");
    Node nodeZ = new ContinuousVariable("Z");
    for (Node node : new Node[] { nodeX, nodeY, nodeZ }) {
        semGraph.addNode(node);
    }
    semGraph.addDirectedEdge(nodeX, nodeY);
    semGraph.addBidirectedEdge(nodeX, nodeY);
    semGraph.addDirectedEdge(nodeX, nodeZ);
    semGraph.addDirectedEdge(nodeY, nodeZ);
    semGraph.setShowErrorTerms(true);

    SemPm pm = new SemPm(semGraph);
    SemIm originalIm = new SemIm(pm);
    StandardizedSemIm standardized = new StandardizedSemIm(originalIm, StandardizedSemIm.Initialization.CALCULATE_FROM_SEM);

    // Simulate, standardize the simulated data, and re-estimate the model from it.
    DataSet simulated = originalIm.simulateDataCholesky(1000, false);
    DataSet standardizedData = ColtDataSet.makeContinuousData(simulated.getVariables(), DataUtils.standardizeData(simulated.getDoubleData()));
    SemIm estimatedIm = new SemEstimator(standardizedData, pm).estimate();

    // Results are not inspected; the calls are kept in their original order because they may
    // draw from the shared random source used for the coefficient below.
    estimatedIm.simulateDataReducedForm(1000, false);
    standardized.simulateDataReducedForm(1000, false);

    StandardizedSemIm.ParameterRange covarianceRange = standardized.getCovarianceRange(nodeX, nodeY);
    double upper = covarianceRange.getHigh();
    double lower = covarianceRange.getLow();
    // Replace unbounded limits with finite ones so a value can be sampled.
    if (upper == Double.POSITIVE_INFINITY) {
        upper = 1000;
    }
    if (lower == Double.NEGATIVE_INFINITY) {
        lower = -1000;
    }
    double coef = lower + RandomUtil.getInstance().nextDouble() * (upper - lower);
    assertTrue(standardized.setErrorCovariance(nodeX, nodeY, coef));
    assertTrue(isStandardized(standardized));
}
Also used : SemPm(edu.cmu.tetrad.sem.SemPm) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm)

Example 33 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

the class TestStandardizedSem method test5.

/**
 * With a fixed seed, builds the collider X1 to X3 from X2 with error terms shown, simulates and
 * standardizes data, runs an estimation on the standardized data, and asserts that a
 * {@code StandardizedSemIm} built from the original instantiated model passes
 * {@code isStandardized}.
 */
@Test
public void test5() {
    RandomUtil.getInstance().setSeed(582374923L);
    SemGraph semGraph = new SemGraph();
    semGraph.setShowErrorTerms(true);
    Node x1 = new ContinuousVariable("X1");
    Node x2 = new ContinuousVariable("X2");
    Node x3 = new ContinuousVariable("X3");
    Node[] variables = { x1, x2, x3 };
    for (Node variable : variables) {
        semGraph.addNode(variable);
    }
    semGraph.setShowErrorTerms(true);
    // The exogenous terms are looked up but not used further; the lookups are kept so the
    // sequence of calls on the graph is unchanged.
    for (Node variable : variables) {
        semGraph.getExogenous(variable);
    }
    semGraph.addDirectedEdge(x1, x3);
    semGraph.addDirectedEdge(x2, x3);

    SemPm semPm = new SemPm(semGraph);
    SemIm semIm = new SemIm(semPm);
    DataSet simulated = semIm.simulateDataRecursive(1000, false);
    TetradMatrix standardizedMatrix = DataUtils.standardizeData(simulated.getDoubleData());
    DataSet standardizedData = ColtDataSet.makeData(simulated.getVariables(), standardizedMatrix);

    // The estimate is computed but not inspected; the assertion below uses the original model.
    new SemEstimator(standardizedData, semIm.getSemPm()).estimate();

    StandardizedSemIm standardized = new StandardizedSemIm(semIm);
    assertTrue(isStandardized(standardized));
}
Also used : SemPm(edu.cmu.tetrad.sem.SemPm) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm) Test(org.junit.Test)

Example 34 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

the class TestStandardizedSem method test2.

/**
 * With a fixed seed, builds a five-variable SEM graph with six directed edges and asserts that
 * the {@code StandardizedSemIm} derived from its instantiated model passes
 * {@code isStandardized}.
 */
@Test
public void test2() {
    RandomUtil.getInstance().setSeed(5729384723L);
    SemGraph semGraph = new SemGraph();
    Node x1 = new ContinuousVariable("X1");
    Node x2 = new ContinuousVariable("X2");
    Node x3 = new ContinuousVariable("X3");
    Node x4 = new ContinuousVariable("X4");
    Node x5 = new ContinuousVariable("X5");
    for (Node variable : new Node[] { x1, x2, x3, x4, x5 }) {
        semGraph.addNode(variable);
    }
    semGraph.setShowErrorTerms(true);

    // Each row is {from, to}; edges are added in the listed order.
    Node[][] directedEdges = {
        { x1, x2 },
        { x2, x3 },
        { x4, x3 },
        { x2, x4 },
        { x1, x4 },
        { x5, x4 }
    };
    for (Node[] fromTo : directedEdges) {
        semGraph.addDirectedEdge(fromTo[0], fromTo[1]);
    }

    SemIm semIm = new SemIm(new SemPm(semGraph));
    StandardizedSemIm standardized = new StandardizedSemIm(semIm);
    assertTrue(isStandardized(standardized));
}
Also used : SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm) Test(org.junit.Test)

Example 35 with SemPm

use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.

the class TestStandardizedSem method test6.

/**
 * Builds a three-variable SEM (X1 to X3, X2 to X3, X1 to X2) with a bidirected edge between the
 * exogenous terms of X1 and X2, simulates and standardizes data, runs an estimation on the
 * standardized data, and asserts that a {@code StandardizedSemIm} built from the original
 * instantiated model passes {@code isStandardized}.
 *
 * <p>The seeding call is disabled, so no fixed seed is set by this method.
 */
@Test
public void test6() {
    SemGraph semGraph = new SemGraph();
    semGraph.setShowErrorTerms(true);
    Node x1 = new ContinuousVariable("X1");
    Node x2 = new ContinuousVariable("X2");
    Node x3 = new ContinuousVariable("X3");
    for (Node variable : new Node[] { x1, x2, x3 }) {
        semGraph.addNode(variable);
    }
    semGraph.setShowErrorTerms(true);
    Node errorX1 = semGraph.getExogenous(x1);
    Node errorX2 = semGraph.getExogenous(x2);
    // The X3 error term is looked up but not used; the call is kept to preserve the sequence.
    semGraph.getExogenous(x3);
    semGraph.addDirectedEdge(x1, x3);
    semGraph.addDirectedEdge(x2, x3);
    semGraph.addDirectedEdge(x1, x2);
    semGraph.addBidirectedEdge(errorX1, errorX2);

    SemPm semPm = new SemPm(semGraph);
    SemIm semIm = new SemIm(semPm);
    DataSet simulated = semIm.simulateDataRecursive(1000, false);
    TetradMatrix standardizedMatrix = DataUtils.standardizeData(simulated.getDoubleData());
    DataSet standardizedData = ColtDataSet.makeData(simulated.getVariables(), standardizedMatrix);

    // The estimate is computed but not inspected; the assertion below uses the original model.
    new SemEstimator(standardizedData, semIm.getSemPm()).estimate();

    StandardizedSemIm standardized = new StandardizedSemIm(semIm);
    assertTrue(isStandardized(standardized));
}
Also used : SemPm(edu.cmu.tetrad.sem.SemPm) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm) StandardizedSemIm(edu.cmu.tetrad.sem.StandardizedSemIm) Test(org.junit.Test)

Aggregations

SemPm (edu.cmu.tetrad.sem.SemPm): 77, SemIm (edu.cmu.tetrad.sem.SemIm): 71, Test (org.junit.Test): 44, ArrayList (java.util.ArrayList): 29, DataSet (edu.cmu.tetrad.data.DataSet): 28, Graph (edu.cmu.tetrad.graph.Graph): 25, ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 18, Node (edu.cmu.tetrad.graph.Node): 18, SemEstimator (edu.cmu.tetrad.sem.SemEstimator): 16, EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 15, Dag (edu.cmu.tetrad.graph.Dag): 10, DMSearch (edu.cmu.tetrad.search.DMSearch): 9, StandardizedSemIm (edu.cmu.tetrad.sem.StandardizedSemIm): 9, NumberFormat (java.text.NumberFormat): 7, TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 6, ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 5, GraphNode (edu.cmu.tetrad.graph.GraphNode): 5, CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 4, IndependenceTest (edu.cmu.tetrad.search.IndependenceTest): 4, Parameters (edu.cmu.tetrad.util.Parameters): 4