Search in sources:

Example 1 with Fask

use of edu.cmu.tetrad.search.Fask in project tetrad by cmu-phil.

In the class TestSimulatedFmri, method testClark.

// @Test
/**
 * Repeatedly simulates data from two three-variable models over X, Y and Z and prints the
 * graph that FASK returns for each simulated data set.
 *
 * <p>Each of the 100 iterations runs two scenarios:
 * <ol>
 *   <li>edges X-&gt;Y, Z-&gt;X, Z-&gt;Y with coefficients of 0.5;</li>
 *   <li>edges X-&gt;Y, X-&gt;Z, Y-&gt;Z with coefficients of 0.4.</li>
 * </ol>
 * Every error node uses the expression {@code pow(Uniform(0, 1), f)} with {@code f = 0.1}.
 * The method makes no assertions; results are printed for manual inspection.
 *
 * @throws IllegalStateException if one of the node expressions cannot be parsed
 */
public void testClark() {
    double f = .1;
    int N = 512;
    double alpha = 1.0;
    double penaltyDiscount = 1.0;

    // Both scenarios share the same error-term expression.
    String error = "pow(Uniform(0, 1), " + f + ")";

    for (int i = 0; i < 100; i++) {
        // Scenario 1: Z is a parent of both X and Y, and X is a parent of Y.
        System.out.println(simulateAndSearchClark(
                new String[][] { { "X", "Y" }, { "Z", "X" }, { "Z", "Y" } },
                "0.5 * Z + E_X",
                "0.5 * X + 0.5 * Z + E_Y",
                "E_Z",
                error, N, alpha, penaltyDiscount));

        // Scenario 2: X is a parent of Y and Z, and Y is a parent of Z.
        System.out.println(simulateAndSearchClark(
                new String[][] { { "X", "Y" }, { "X", "Z" }, { "Y", "Z" } },
                "E_X",
                "0.4 * X + E_Y",
                "0.4 * X + 0.4 * Y + E_Z",
                error, N, alpha, penaltyDiscount));
    }
}

/**
 * Builds a generalized SEM over the nodes X, Y and Z, simulates one data set from it and
 * runs FASK on that data set.
 *
 * @param edges           directed edges as {@code {from, to}} node-name pairs, added in order
 * @param xExpression     expression assigned to node X
 * @param yExpression     expression assigned to node Y
 * @param zExpression     expression assigned to node Z
 * @param error           expression assigned to the error node of each of X, Y and Z
 * @param sampleSize      number of rows to simulate
 * @param alpha           value passed to {@code Fask.setAlpha}
 * @param penaltyDiscount value passed to both the score and the FASK search
 * @return the graph returned by {@code Fask.search()}
 * @throws IllegalStateException if any expression cannot be parsed
 */
private static Graph simulateAndSearchClark(String[][] edges, String xExpression, String yExpression,
        String zExpression, String error, int sampleSize, double alpha, double penaltyDiscount) {
    Graph g = new EdgeListGraph();
    g.addNode(new ContinuousVariable("X"));
    g.addNode(new ContinuousVariable("Y"));
    g.addNode(new ContinuousVariable("Z"));
    for (String[] edge : edges) {
        g.addDirectedEdge(g.getNode(edge[0]), g.getNode(edge[1]));
    }

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        pm.setNodeExpression(g.getNode("X"), xExpression);
        pm.setNodeExpression(g.getNode("Y"), yExpression);
        pm.setNodeExpression(g.getNode("Z"), zExpression);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("X")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Y")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Z")), error);
    } catch (ParseException e) {
        // Continuing would simulate from a partially configured model and print
        // results for the wrong scenario, so stop here and keep the cause.
        throw new IllegalStateException("Could not parse a node expression for the simulated model", e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(sampleSize, false);

    // Fully qualified: the "Also used" listing shows a different SemBicScore imported by this file.
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    score.setPenaltyDiscount(penaltyDiscount);

    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(penaltyDiscount);
    fask.setAlpha(alpha);
    return fask.search();
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Fask(edu.cmu.tetrad.search.Fask) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) ParseException(java.text.ParseException) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore)

Example 2 with Fask

use of edu.cmu.tetrad.search.Fask in project tetrad by cmu-phil.

In the class FaskConcatenated, method search.

/**
 * Runs FASK over the given data sets, either directly or through a bootstrap wrapper.
 *
 * <p>When the {@code bootstrapSampleSize} parameter is below 1, every data set is passed
 * through {@code DataUtils.standardizeData}, the results are concatenated, and a single
 * {@link Fask} search configured from {@code parameters} is run on the combined data.
 * Otherwise a new {@code FaskConcatenated} built on the same score is handed to a
 * {@code GeneralBootstrapTest} together with the unmodified data sets.
 *
 * @param dataSets   the input data models; each element is cast to {@code DataSet}
 * @param parameters the search parameters read by this method
 * @return the graph produced by the direct or bootstrapped search
 */
@Override
public Graph search(List<DataModel> dataSets, Parameters parameters) {
    final int bootstrapSampleSize = parameters.getInt("bootstrapSampleSize");

    if (bootstrapSampleSize >= 1) {
        FaskConcatenated delegate = new FaskConcatenated(score);
        delegate.setKnowledge(knowledge);

        List<DataSet> rawDataSets = new ArrayList<>();
        for (DataModel model : dataSets) {
            rawDataSets.add((DataSet) model);
        }

        GeneralBootstrapTest bootstrap = new GeneralBootstrapTest(rawDataSets, delegate, bootstrapSampleSize);

        // Codes 0 and 2 select Preserved and Majority; every other value keeps Highest.
        final int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        final BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }

        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    // Direct path: standardize each data set, then search the concatenation.
    List<DataSet> standardized = new ArrayList<>();
    for (DataModel model : dataSets) {
        DataSet asDataSet = (DataSet) model;
        standardized.add(DataUtils.standardizeData(asDataSet));
    }

    DataSet combined = DataUtils.concatenate(standardized);
    combined.setNumberFormat(new DecimalFormat("0.000000000000000000"));

    Fask fask = new Fask(combined, score.getScore(combined, parameters));
    fask.setDepth(parameters.getInt("depth"));
    fask.setPenaltyDiscount(parameters.getDouble("penaltyDiscount"));
    fask.setExtraEdgeThreshold(parameters.getDouble("extraEdgeThreshold"));
    fask.setDelta(parameters.getDouble("faskDelta"));
    fask.setAlpha(parameters.getDouble("twoCycleAlpha"));
    fask.setKnowledge(knowledge);
    return fask.search();
}
Also used : Fask(edu.cmu.tetrad.search.Fask) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) DecimalFormat(java.text.DecimalFormat) ArrayList(java.util.ArrayList)

Example 3 with Fask

use of edu.cmu.tetrad.search.Fask in project tetrad by cmu-phil.

In the class TestSimulatedFmri, method testClark2.

// @Test
/**
 * Simulates 1000 rows from a three-variable model with edges X-&gt;Y, X-&gt;Z and Y-&gt;Z
 * (coefficients of 0.4), runs FASK on the simulated data and prints the resulting graph.
 *
 * <p>Every error node uses the expression {@code pow(Uniform(0, 1), 1.5)}. The method makes
 * no assertions; the result is printed for manual inspection.
 *
 * @throws IllegalStateException if one of the node expressions cannot be parsed
 */
public void testClark2() {
    Node x = new ContinuousVariable("X");
    Node y = new ContinuousVariable("Y");
    Node z = new ContinuousVariable("Z");

    Graph g = new EdgeListGraph();
    g.addNode(x);
    g.addNode(y);
    g.addNode(z);
    g.addDirectedEdge(x, y);
    g.addDirectedEdge(x, z);
    g.addDirectedEdge(y, z);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        pm.setNodeExpression(g.getNode("X"), "E_X");
        pm.setNodeExpression(g.getNode("Y"), "0.4 * X + E_Y");
        pm.setNodeExpression(g.getNode("Z"), "0.4 * X + 0.4 * Y + E_Z");

        // All three error nodes share the same expression.
        String error = "pow(Uniform(0, 1), 1.5)";
        pm.setNodeExpression(pm.getErrorNode(g.getNode("X")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Y")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Z")), error);
    } catch (ParseException e) {
        // Continuing would simulate from a partially configured model and print a
        // result for the wrong setup, so stop here and keep the cause.
        throw new IllegalStateException("Could not parse a node expression for the simulated model", e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(1000, false);

    // Fully qualified: the "Also used" listing shows a different SemBicScore imported by this file.
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));

    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(1);
    fask.setAlpha(0.5);

    Graph out = fask.search();
    System.out.println(out);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Fask(edu.cmu.tetrad.search.Fask) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) ParseException(java.text.ParseException) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore)

Aggregations

Fask (edu.cmu.tetrad.search.Fask)3 SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore)2 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)2 CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly)2 DataSet (edu.cmu.tetrad.data.DataSet)2 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)2 Graph (edu.cmu.tetrad.graph.Graph)2 Node (edu.cmu.tetrad.graph.Node)2 GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm)2 GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm)2 ParseException (java.text.ParseException)2 BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble)1 GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest)1 DecimalFormat (java.text.DecimalFormat)1 ArrayList (java.util.ArrayList)1