
Example 1 with IndependenceTest

Use of edu.cmu.tetrad.search.IndependenceTest in project tetrad by cmu-phil.

From class TsImages, method search.

@Override
public Graph search(List<DataModel> dataSets, Parameters parameters) {
    // Copy the input list so the search works on its own list of data models.
    List<DataModel> dataModels = new ArrayList<>(dataSets);
    TsGFci search;
    if (score instanceof SemBicScore) {
        // Continuous data: a SEM BIC score computed across all data sets.
        SemBicScoreImages gesScore = new SemBicScoreImages(dataModels);
        gesScore.setPenaltyDiscount(parameters.getDouble("penaltyDiscount"));
        IndependenceTest test = new IndTestScore(gesScore);
        search = new TsGFci(test, gesScore);
    } else if (score instanceof BdeuScore) {
        // Discrete data: a BDeu score computed across all data sets.
        double samplePrior = parameters.getDouble("samplePrior", 1);
        double structurePrior = parameters.getDouble("structurePrior", 1);
        // Named bdeuScore so it does not shadow the "score" field tested above.
        BdeuScoreImages bdeuScore = new BdeuScoreImages(dataModels);
        bdeuScore.setSamplePrior(samplePrior);
        bdeuScore.setStructurePrior(structurePrior);
        IndependenceTest test = new IndTestScore(bdeuScore);
        search = new TsGFci(test, bdeuScore);
    } else {
        throw new IllegalStateException("Sorry, data must either be all continuous or all discrete.");
    }
    // Background knowledge is read from the first data model.
    IKnowledge knowledge = dataModels.get(0).getKnowledge();
    search.setKnowledge(knowledge);
    return search.search();
}
Also used : BdeuScoreImages(edu.cmu.tetrad.search.BdeuScoreImages) IKnowledge(edu.cmu.tetrad.data.IKnowledge) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) SemBicScoreImages(edu.cmu.tetrad.search.SemBicScoreImages) DataModel(edu.cmu.tetrad.data.DataModel) IndTestScore(edu.cmu.tetrad.search.IndTestScore) BdeuScore(edu.cmu.tetrad.algcomparison.score.BdeuScore) ArrayList(java.util.ArrayList) TsGFci(edu.cmu.tetrad.search.TsGFci) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore)
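Example 1 wraps a score (SemBicScoreImages for continuous data, BdeuScoreImages for discrete data) in IndTestScore, so that TsGFci can ask independence questions of a score. The minimal sketch below illustrates that general idea in plain Java; it is not Tetrad's IndTestScore. The LocalScore interface, the toy score and the decision rule (x is independent of y given z when adding x as a parent of y does not raise y's score) are assumptions made for illustration.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Sketch only: a score-based independence test in plain Java, not Tetrad's IndTestScore.
class ScoreIndTestSketch {

    // Hypothetical stand-in for a decomposable local score (higher is better).
    interface LocalScore {
        double localScore(int node, Set<Integer> parents);
    }

    // Assumed decision rule: x is independent of y given z when adding x as an
    // extra parent of y (alongside z) does not increase y's local score.
    static boolean isIndependent(LocalScore score, int x, int y, List<Integer> z) {
        Set<Integer> withoutX = new HashSet<>(z);
        Set<Integer> withX = new HashSet<>(z);
        withX.add(x);
        double gain = score.localScore(y, withX) - score.localScore(y, withoutX);
        return gain <= 0.0;
    }

    // Hypothetical toy score: node 1 is explained by node 0 (fit bonus 10),
    // and every parent costs a penalty of 1.
    static final LocalScore TOY = (node, parents) ->
            (node == 1 && parents.contains(0) ? 10.0 : 0.0) - parents.size();

    public static void main(String[] args) {
        System.out.println(isIndependent(TOY, 0, 1, List.of()));  // false: node 0 helps explain node 1
        System.out.println(isIndependent(TOY, 2, 1, List.of(0))); // true: node 2 only adds penalty
    }
}
```

With a real score the fit term would come from the data; the toy score only makes the decision rule easy to trace by hand.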

Example 2 with IndependenceTest

Use of edu.cmu.tetrad.search.IndependenceTest in project tetrad by cmu-phil.

From class CcdMax, method search.

@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    if (parameters.getInt("bootstrapSampleSize") < 1) {
        // No bootstrapping: run a single CCD-Max search on the full data set.
        IndependenceTest test = this.test.getTest(dataSet, parameters);
        edu.cmu.tetrad.search.CcdMax search = new edu.cmu.tetrad.search.CcdMax(test);
        search.setDoColliderOrientations(parameters.getBoolean("doColliderOrientation"));
        search.setUseHeuristic(parameters.getBoolean("useMaxPOrientationHeuristic"));
        search.setMaxPathLength(parameters.getInt("maxPOrientationMaxPathLength"));
        search.setKnowledge(knowledge);
        search.setDepth(parameters.getInt("depth"));
        search.setApplyOrientAwayFromCollider(parameters.getBoolean("applyR1"));
        search.setUseOrientTowardDConnections(parameters.getBoolean("orientTowardDConnections"));
        return search.search();
    } else {
        // Bootstrapping: wrap this algorithm and aggregate the graphs found on resampled data.
        CcdMax algorithm = new CcdMax(test);
        DataSet data = (DataSet) dataSet;
        GeneralBootstrapTest search = new GeneralBootstrapTest(data, algorithm, parameters.getInt("bootstrapSampleSize"));
        search.setKnowledge(knowledge);
        // Map the integer "bootstrapEnsemble" parameter to an edge-ensemble rule;
        // any other value keeps the default, Highest.
        BootstrapEdgeEnsemble edgeEnsemble = BootstrapEdgeEnsemble.Highest;
        switch (parameters.getInt("bootstrapEnsemble", 1)) {
            case 0:
                edgeEnsemble = BootstrapEdgeEnsemble.Preserved;
                break;
            case 1:
                edgeEnsemble = BootstrapEdgeEnsemble.Highest;
                break;
            case 2:
                edgeEnsemble = BootstrapEdgeEnsemble.Majority;
                break;
        }
        search.setEdgeEnsemble(edgeEnsemble);
        search.setParameters(parameters);
        search.setVerbose(parameters.getBoolean("verbose"));
        return search.search();
    }
}
Also used : IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble)
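The switch in Example 2 maps the integer bootstrapEnsemble parameter onto a BootstrapEdgeEnsemble constant, and any code other than 0, 1 or 2 silently keeps the default, Highest. As a design alternative, the sketch below rejects unknown codes instead. The EdgeEnsemble enum and its fromCode method are hypothetical stand-ins that only mirror the three constant names; they are not the real edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble.

```java
// Sketch only: integer-code lookup for an enum, failing fast on unknown codes.
class EnsembleCodeSketch {

    // Hypothetical local mirror of the three constants used in Example 2;
    // not the real edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble.
    enum EdgeEnsemble {
        PRESERVED(0), HIGHEST(1), MAJORITY(2);

        private final int code;

        EdgeEnsemble(int code) {
            this.code = code;
        }

        // Look up a constant by its integer code, rejecting unknown codes
        // instead of silently keeping a default.
        static EdgeEnsemble fromCode(int code) {
            for (EdgeEnsemble e : values()) {
                if (e.code == code) {
                    return e;
                }
            }
            throw new IllegalArgumentException("Unknown bootstrapEnsemble code: " + code);
        }
    }

    public static void main(String[] args) {
        System.out.println(EdgeEnsemble.fromCode(2)); // MAJORITY
        try {
            EdgeEnsemble.fromCode(7);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // Unknown bootstrapEnsemble code: 7
        }
    }
}
```

Failing fast surfaces a mistyped parameter immediately, at the cost of rejecting configurations that previously ran with the default.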

Example 3 with IndependenceTest

Use of edu.cmu.tetrad.search.IndependenceTest in project tetrad by cmu-phil.

From class TestIndTestWaldLR, method testIsIndependent.

@Test
public void testIsIndependent() {
    RandomUtil.getInstance().setSeed(1450705713157L);
    int numPassed = 0;
    for (int i = 0; i < 10; i++) {
        List<Node> nodes = new ArrayList<>();
        for (int i1 = 0; i1 < 10; i1++) {
            nodes.add(new ContinuousVariable("X" + (i1 + 1)));
        }
        Graph graph = GraphUtils.randomGraph(nodes, 0, 10, 3, 3, 3, false);
        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);
        Discretizer discretizer = new Discretizer(data);
        discretizer.setVariablesCopied(true);
        discretizer.equalCounts(data.getVariable(0), 2);
        discretizer.equalCounts(data.getVariable(3), 2);
        data = discretizer.discretize();
        Node x1 = data.getVariable("X1");
        Node x2 = data.getVariable("X2");
        Node x3 = data.getVariable("X3");
        Node x4 = data.getVariable("X4");
        Node x5 = data.getVariable("X5");
        List<Node> cond = new ArrayList<>();
        cond.add(x3);
        cond.add(x4);
        cond.add(x5);
        Node x1Graph = graph.getNode(x1.getName());
        Node x2Graph = graph.getNode(x2.getName());
        List<Node> condGraph = new ArrayList<>();
        for (Node node : cond) {
            condGraph.add(graph.getNode(node.getName()));
        }
        // Using the Wald LR test since it's most up to date.
        IndependenceTest test = new IndTestMultinomialLogisticRegressionWald(data, 0.05, false);
        // Compare the statistical test's decision with the d-separation oracle on the true graph.
        IndTestDSep dsep = new IndTestDSep(graph);
        boolean correct = test.isIndependent(x2, x1, cond) == dsep.isIndependent(x2Graph, x1Graph, condGraph);
        if (correct) {
            numPassed++;
        }
    }
    // Not every random seed yields agreement on all 10 trials; the seed is fixed at the top of the test.
    assertEquals(10, numPassed);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) IndTestMultinomialLogisticRegressionWald(edu.pitt.csb.mgm.IndTestMultinomialLogisticRegressionWald) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)
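Example 3 calls discretizer.equalCounts(variable, 2) to turn X1 and X4 into binary variables before running the multinomial logistic regression Wald test. The sketch below shows one way to do equal-count (equal-frequency) binning: cutpoints are taken at order statistics so that each category receives roughly the same number of rows. It illustrates the technique under that assumption and is not Tetrad's Discretizer; how Tetrad places cutpoints and handles ties may differ.

```java
import java.util.Arrays;

// Sketch only: equal-frequency binning, not Tetrad's Discretizer.
class EqualCountsSketch {

    // Sort the values, take cutpoints at the (i * n / k)-th order statistics,
    // and give each value the number of cutpoints it is greater than or equal to.
    // Tied values always share a category, so ties can make the bins uneven.
    static int[] equalCounts(double[] values, int numCategories) {
        if (numCategories < 1) {
            throw new IllegalArgumentException("numCategories must be at least 1");
        }
        int n = values.length;
        if (n == 0) {
            return new int[0];
        }
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        double[] cutpoints = new double[numCategories - 1];
        for (int i = 1; i < numCategories; i++) {
            cutpoints[i - 1] = sorted[i * n / numCategories];
        }
        int[] categories = new int[n];
        for (int j = 0; j < n; j++) {
            int category = 0;
            while (category < cutpoints.length && values[j] >= cutpoints[category]) {
                category++;
            }
            categories[j] = category;
        }
        return categories;
    }

    public static void main(String[] args) {
        double[] x = {0.3, -1.2, 2.5, 0.9, -0.4, 1.7};
        System.out.println(Arrays.toString(equalCounts(x, 2))); // [0, 0, 1, 1, 0, 1]
    }
}
```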

Example 4 with IndependenceTest

Use of edu.cmu.tetrad.search.IndependenceTest in project tetrad by cmu-phil.

From class TestMbfs, method testRandom.

@Test
public void testRandom() {
    RandomUtil.getInstance().setSeed(8388428832L);
    List<Node> nodes1 = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        nodes1.add(new ContinuousVariable("X" + (i + 1)));
    }
    Dag dag = new Dag(GraphUtils.randomGraph(nodes1, 0, 10, 5, 5, 5, false));
    IndependenceTest test = new IndTestDSep(dag);
    Mbfs search = new Mbfs(test, -1);
    List<Node> nodes = dag.getNodes();
    for (Node node : nodes) {
        Graph resultMb = search.search(node.getName());
        Graph trueMb = GraphUtils.markovBlanketDag(node, dag);
        List<Node> resultNodes = resultMb.getNodes();
        List<Node> trueNodes = trueMb.getNodes();
        Set<String> resultNames = new HashSet<>();
        for (Node resultNode : resultNodes) {
            resultNames.add(resultNode.getName());
        }
        Set<String> trueNames = new HashSet<>();
        for (Node v : trueNodes) {
            trueNames.add(v.getName());
        }
        assertEquals(trueNames, resultNames);
        Set<Edge> resultEdges = resultMb.getEdges();
        for (Edge resultEdge : resultEdges) {
            if (Edges.isDirectedEdge(resultEdge)) {
                String name1 = resultEdge.getNode1().getName();
                String name2 = resultEdge.getNode2().getName();
                Node node1 = trueMb.getNode(name1);
                Node node2 = trueMb.getNode(name2);
                // possibility that the node is actually a child.
                if (node1 == null) {
                    fail("Node " + name1 + " is not in the true graph.");
                }
                if (node2 == null) {
                    fail("Node " + name2 + " is not in the true graph.");
                }
                Edge trueEdge = trueMb.getEdge(node1, node2);
                if (trueEdge == null) {
                    Node resultNode1 = resultMb.getNode(node1.getName());
                    Node resultNode2 = resultMb.getNode(node2.getName());
                    Node resultTarget = resultMb.getNode(node.getName());
                    Edge a = resultMb.getEdge(resultNode1, resultTarget);
                    Edge b = resultMb.getEdge(resultNode2, resultTarget);
                    if (a == null || b == null) {
                        continue;
                    }
                    if ((Edges.isDirectedEdge(a) && Edges.isUndirectedEdge(b))
                            || (Edges.isUndirectedEdge(a) && Edges.isDirectedEdge(b))) {
                        continue;
                    }
                    fail("EXTRA EDGE: Edge in result MB but not true MB = " + resultEdge);
                }
                assertEquals(resultEdge.getEndpoint1(), trueEdge.getEndpoint1());
                assertEquals(resultEdge.getEndpoint2(), trueEdge.getEndpoint2());
            }
        }
    }
}
Also used : Mbfs(edu.cmu.tetrad.search.Mbfs) ArrayList(java.util.ArrayList) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) HashSet(java.util.HashSet) Test(org.junit.Test)
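In Example 4, GraphUtils.markovBlanketDag supplies the true Markov blanket against which the Mbfs result is checked. In a DAG, a node's Markov blanket consists of its parents, its children, and the other parents of its children. The sketch below computes that set for a DAG stored as a hypothetical parent-to-children map of node names; it is not Tetrad's implementation.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Sketch only: Markov blanket of a node in a DAG, not Tetrad's GraphUtils.markovBlanketDag.
class MarkovBlanketSketch {

    // `children` is a hypothetical representation mapping each node to its children.
    // The blanket is: parents of target, children of target, other parents of those children.
    static Set<String> markovBlanket(Map<String, List<String>> children, String target) {
        List<String> targetChildren = children.getOrDefault(target, List.of());
        Set<String> blanket = new TreeSet<>(targetChildren);
        for (Map.Entry<String, List<String>> entry : children.entrySet()) {
            List<String> kids = entry.getValue();
            boolean isParent = kids.contains(target);
            boolean isCoParent = targetChildren.stream().anyMatch(kids::contains);
            if (isParent || isCoParent) {
                blanket.add(entry.getKey());
            }
        }
        blanket.remove(target); // the target is itself a parent of its own children
        return blanket;
    }

    public static void main(String[] args) {
        // X1 -> X3, X2 -> X3, X3 -> X4, X5 -> X4, X4 -> X6
        Map<String, List<String>> dag = new HashMap<>();
        dag.put("X1", List.of("X3"));
        dag.put("X2", List.of("X3"));
        dag.put("X3", List.of("X4"));
        dag.put("X5", List.of("X4"));
        dag.put("X4", List.of("X6"));
        System.out.println(markovBlanket(dag, "X3")); // [X1, X2, X4, X5]
    }
}
```

X6 is excluded from the blanket of X3 because it is a grandchild, not a child or a co-parent.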

Example 5 with IndependenceTest

Use of edu.cmu.tetrad.search.IndependenceTest in project tetrad by cmu-phil.

From class TestPcd, method checkSearch.

/**
 * Presents the input graph to PCD (through a d-separation oracle) and checks that the output of PCD is equivalent
 * to the given output graph.
 */
private void checkSearch(String inputGraph, String outputGraph) {
    // Set up graph and node objects.
    Graph graph = GraphConverter.convert(inputGraph);
    // Set up search.
    IndependenceTest independence = new IndTestDSep(graph);
    Pcd pc = new Pcd(independence);
    // Run search
    Graph resultGraph = pc.search();
    // Build comparison graph.
    Graph trueGraph = GraphConverter.convert(outputGraph);
    resultGraph = GraphUtils.replaceNodes(resultGraph, trueGraph.getNodes());
    // Do test.
    assertTrue(resultGraph.equals(trueGraph));
}
Also used : Pcd(edu.cmu.tetrad.search.Pcd) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) Graph(edu.cmu.tetrad.graph.Graph)
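Examples 3 to 5 all use IndTestDSep as an oracle that answers independence questions from a known graph by d-separation. One standard way to decide d-separation is the ancestral moral graph criterion: x and y are d-separated by z exactly when they are disconnected after restricting the DAG to the ancestors of {x, y} ∪ z, moralizing it (linking co-parents and dropping edge directions), and deleting z. The sketch below implements that criterion over a hypothetical node-to-parents map; Tetrad's IndTestDSep may use a different algorithm, and the sketch assumes x and y are not themselves in z.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Sketch only: a d-separation oracle via the ancestral moral graph criterion,
// not Tetrad's IndTestDSep. Assumes x and y are not in z.
class DSeparationSketch {

    // `parents` is a hypothetical representation mapping each node to its parents.
    static boolean dSeparated(Map<String, List<String>> parents, String x, String y, Set<String> z) {
        // 1. Collect the ancestors of x, y and z (each node is its own ancestor).
        Set<String> ancestral = new HashSet<>();
        Deque<String> stack = new ArrayDeque<>();
        stack.push(x);
        stack.push(y);
        z.forEach(stack::push);
        while (!stack.isEmpty()) {
            String v = stack.pop();
            if (ancestral.add(v)) {
                parents.getOrDefault(v, List.of()).forEach(stack::push);
            }
        }
        // 2. Moralize the ancestral subgraph: link each node to its parents,
        //    link every pair of parents that share a child, and ignore directions.
        Map<String, Set<String>> adjacent = new HashMap<>();
        for (String v : ancestral) {
            List<String> ps = parents.getOrDefault(v, List.of());
            for (int i = 0; i < ps.size(); i++) {
                link(adjacent, v, ps.get(i));
                for (int j = i + 1; j < ps.size(); j++) {
                    link(adjacent, ps.get(i), ps.get(j));
                }
            }
        }
        // 3. Breadth-first search from x that never enters z.
        Set<String> seen = new HashSet<>(List.of(x));
        Deque<String> queue = new ArrayDeque<>(List.of(x));
        while (!queue.isEmpty()) {
            String v = queue.poll();
            if (v.equals(y)) {
                return false; // y is reachable, so x and y are d-connected given z
            }
            for (String w : adjacent.getOrDefault(v, Set.of())) {
                if (!z.contains(w) && seen.add(w)) {
                    queue.add(w);
                }
            }
        }
        return true;
    }

    private static void link(Map<String, Set<String>> adjacent, String a, String b) {
        adjacent.computeIfAbsent(a, k -> new HashSet<>()).add(b);
        adjacent.computeIfAbsent(b, k -> new HashSet<>()).add(a);
    }

    public static void main(String[] args) {
        // Collider X1 -> X3 <- X2, plus X3 -> X4.
        Map<String, List<String>> parents = new HashMap<>();
        parents.put("X3", List.of("X1", "X2"));
        parents.put("X4", List.of("X3"));
        System.out.println(dSeparated(parents, "X1", "X2", Set.of()));     // true
        System.out.println(dSeparated(parents, "X1", "X2", Set.of("X3"))); // false
        System.out.println(dSeparated(parents, "X1", "X2", Set.of("X4"))); // false
    }
}
```

The three calls show the collider behaviour that makes such an oracle a useful ground truth: X1 and X2 are marginally separated, but conditioning on the collider X3 or on its descendant X4 connects them.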

Aggregations

IndependenceTest (edu.cmu.tetrad.search.IndependenceTest): 12
DataSet (edu.cmu.tetrad.data.DataSet): 4
IndTestDSep (edu.cmu.tetrad.search.IndTestDSep): 4
BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble): 4
GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 4
ArrayList (java.util.ArrayList): 4
Test (org.junit.Test): 4
Graph (edu.cmu.tetrad.graph.Graph): 3
Node (edu.cmu.tetrad.graph.Node): 3
IndTestScore (edu.cmu.tetrad.search.IndTestScore): 3
SemIm (edu.cmu.tetrad.sem.SemIm): 3
SemPm (edu.cmu.tetrad.sem.SemPm): 3
BdeuScore (edu.cmu.tetrad.algcomparison.score.BdeuScore): 2
SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 2
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 2
TsGFci (edu.cmu.tetrad.search.TsGFci): 2
DataModel (edu.cmu.tetrad.data.DataModel): 1
Discretizer (edu.cmu.tetrad.data.Discretizer): 1
IKnowledge (edu.cmu.tetrad.data.IKnowledge): 1
BdeuScoreImages (edu.cmu.tetrad.search.BdeuScoreImages): 1