Search in sources:

Example 6 with EdgeListGraph

Use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

The class Mmhc, method search.

/**
 * Builds an undirected skeleton from the sets returned by {@link Mmmb#getPc} (presumably each
 * variable's parents-and-children set) and then orients it with {@link FgesOrienter}.
 *
 * <p>Two variables are joined by a single undirected edge whenever either one appears in the
 * other's set. The skeleton is then passed to {@code FgesOrienter.orient}, and that same graph
 * object is returned. The orienter is presumably expected to modify it in place; confirm against
 * FgesOrienter.
 *
 * @return the graph over the variables of the independence test after orientation
 */
public Graph search() {
    final List<Node> nodes = independenceTest.getVariables();
    final Mmmb mbSearch = new Mmmb(independenceTest, getDepth(), true);

    // First pass: look up the PC set of every variable before any graph is built.
    final Map<Node, List<Node>> pcSets = new HashMap<>();
    for (Node node : nodes) {
        final List<Node> pcOfNode = mbSearch.getPc(node);
        pcSets.put(node, pcOfNode);
    }

    // Second pass: an edgeless graph containing every variable.
    final Graph skeleton = new EdgeListGraph();
    for (Node node : nodes) {
        skeleton.addNode(node);
    }

    // Third pass: link each variable to its PC members, skipping pairs that are already joined.
    for (Node node : nodes) {
        for (Node neighbor : pcSets.get(node)) {
            if (skeleton.isAdjacentTo(node, neighbor)) {
                continue;
            }
            skeleton.addUndirectedEdge(node, neighbor);
        }
    }

    final FgesOrienter fgesOrienter = new FgesOrienter(data);
    fgesOrienter.orient(skeleton);
    return skeleton;
}
Also used: FgesOrienter (edu.cmu.tetrad.search.FgesOrienter), EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph), Graph (edu.cmu.tetrad.graph.Graph), HashMap (java.util.HashMap), Node (edu.cmu.tetrad.graph.Node), List (java.util.List)

Example 7 with EdgeListGraph

Use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

The class LingDisplay, method resetDisplay.

/**
 * Rebuilds the graph browser for the subset chosen in {@code subsetCombo}.
 *
 * <p>"Show All", "Show Stable" and "Show Unstable" select the stored-graph indices returned by
 * {@code getAllIndices}, {@code getStableIndices} and {@code getUnstableIndices} respectively.
 * Any other selection leaves the display unchanged.
 */
private void resetDisplay() {
    String option = (String) subsetCombo.getSelectedItem();
    if ("Show All".equals(option)) {
        showSubset(getAllIndices(getStoredGraphs()));
    } else if ("Show Stable".equals(option)) {
        showSubset(getStableIndices(getStoredGraphs()));
    } else if ("Show Unstable".equals(option)) {
        showSubset(getUnstableIndices(getStoredGraphs()));
    }
}

/**
 * Makes {@code newIndices} the current subset and refreshes the display for it.
 *
 * <p>This rebuilds the spinner so it steps 1..n through the subset (or is pinned at 0 when the
 * subset is empty) and updates the "of n" label. It then shows the first graph of the subset, or
 * an empty graph when the subset has no members.
 *
 * @param newIndices indices into the stored graphs that make up the new subset
 */
private void showSubset(List<Integer> newIndices) {
    subsetIndices.clear();
    subsetIndices.addAll(newIndices);

    // Spinner positions are 1-based; an empty subset gets the degenerate range [0, 0].
    int min = subsetIndices.isEmpty() ? 0 : 1;
    final SpinnerNumberModel model = new SpinnerNumberModel(min, min, subsetIndices.size(), 1);
    model.addChangeListener(new ChangeListener() {

        @Override
        public void stateChanged(ChangeEvent e) {
            int index = model.getNumber().intValue();
            workbench.setGraph(storedGraphs.getGraph(subsetIndices.get(index - 1)));
        }
    });
    spinner.setModel(model);
    totalLabel.setText(" of " + newIndices.size());

    if (subsetIndices.isEmpty()) {
        workbench.setGraph(new EdgeListGraph());
    } else {
        workbench.setGraph(storedGraphs.getGraph(subsetIndices.get(0)));
    }
}
Also used : ChangeEvent(javax.swing.event.ChangeEvent) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ChangeListener(javax.swing.event.ChangeListener) ArrayList(java.util.ArrayList) List(java.util.List)

Example 8 with EdgeListGraph

Use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

The class IonDisplay, method resetDisplay.

/**
 * Reloads every stored-graph index into {@code indices} and rebuilds the spinner that pages
 * through them.
 *
 * <p>Spinner position k (1-based) shows stored graph {@code indices.get(k - 1)}. The first graph
 * is displayed immediately, or an empty graph if nothing is stored.
 */
private void resetDisplay() {
    final List<Integer> allIndices = getAllIndices(getStoredGraphs());
    indices.clear();
    indices.addAll(allIndices);

    final int count = indices.size();
    // With no graphs the spinner is pinned at 0; otherwise it runs from 1 to count.
    final int start = (count == 0) ? 0 : 1;
    final SpinnerNumberModel spinnerModel = new SpinnerNumberModel(start, start, count, 1);
    spinnerModel.addChangeListener(new ChangeListener() {

        @Override
        public void stateChanged(ChangeEvent event) {
            final int position = spinnerModel.getNumber().intValue();
            workbench.setGraph(storedGraphs.get(indices.get(position - 1)));
        }
    });
    spinner.setModel(spinnerModel);
    totalLabel.setText(" of " + allIndices.size());

    if (!indices.isEmpty()) {
        workbench.setGraph(storedGraphs.get(indices.get(0)));
    } else {
        workbench.setGraph(new EdgeListGraph());
    }
}
Also used : ChangeEvent(javax.swing.event.ChangeEvent) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ChangeListener(javax.swing.event.ChangeListener)

Example 9 with EdgeListGraph

Use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

The class Comparison, method compare.

/**
 * Runs a single comparison experiment and collects the outcome in a {@link ComparisonResult}.
 *
 * <p>The method first obtains a data set and its true DAG. If a data file is set, both are loaded
 * from the data and graph files. Otherwise a random DAG over {@code numVars} variables with
 * {@code numEdges} edges is generated, and {@code sampleSize} rows are simulated from it: via
 * {@code LargeScaleSimulation} for continuous data, or via an {@code MlBayesIm} built with
 * {@code MlBayesIm.RANDOM} for discrete data. It then builds the configured independence test
 * and/or score and runs the configured search algorithm on the data.
 *
 * <p>Side effect: when the chosen test or score implies a data type, that type is written back
 * into {@code params}.
 *
 * @param params the comparison settings
 * @return the result holding the estimated graph, the reference graph it should be compared with
 *         (the pattern of the true DAG for PC-style and FGES searches, its PAG for FCI and GFCI),
 *         the elapsed search time in milliseconds, and the true DAG
 * @throws IllegalArgumentException if a required setting is missing, the settings are
 *         inconsistent with each other, or the data type or algorithm is unrecognized
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet;
    Graph trueDag;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);

    // ---- Obtain the data and the true DAG ----
    if (params.getDataFile() != null) {
        dataSet = loadDataFile(params.getDataFile());
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        trueDag = loadGraphFile(params.getGraphFile());
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }
        // The data type decides which kind of variables to simulate, so it must be known up front.
        if (params.getDataType() == null) {
            throw new IllegalArgumentException("Data type not set or inferred.");
        }
        if (params.getSampleSize() == -1) {
            throw new IllegalArgumentException("Sample size not set.");
        }

        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            dataSet = sim.simulateDataFisher(params.getSampleSize());
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
    }

    // Applies to loaded and simulated data alike.
    if (dataSet == null) {
        throw new IllegalArgumentException("No data set.");
    }

    // ---- Build the independence test (if any); it fixes the data type ----
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // ---- Build the score (if any); it must agree with the data type ----
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bDeuScore = new BDeuScore(dataSet);
        bDeuScore.setSamplePrior(params.getSamplePrior());
        bDeuScore.setStructurePrior(params.getStructurePrior());
        score = bDeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // ---- Run the search and time it ----
    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }
    long time1 = System.currentTimeMillis();
    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES
            || params.getAlgorithm() == ComparisonParameters.Algorithm.FGES2) {
        // FGES2 currently runs the same Fges implementation as FGES.
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
        // GFCI uses both an independence test and a score.
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long time2 = System.currentTimeMillis();
    long elapsed = time2 - time1;
    result.setElapsed(elapsed);
    result.setTrueDag(trueDag);
    return result;
}
Also used: MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm), Node (edu.cmu.tetrad.graph.Node), ArrayList (java.util.ArrayList), LargeScaleSimulation (edu.cmu.tetrad.sem.LargeScaleSimulation), List (java.util.List), EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph), Graph (edu.cmu.tetrad.graph.Graph), BayesPm (edu.cmu.tetrad.bayes.BayesPm)

Example 10 with EdgeListGraph

Use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

The class TestSemVarMeans, method constructGraph1.

/**
 * Builds the five-node test DAG X1 -> X2 -> X3 -> X4 -> X5, plus the extra edge X1 -> X4.
 *
 * @return a new graph containing nodes X1..X5 and those five directed edges
 */
private Graph constructGraph1() {
    final Graph dag = new EdgeListGraph();

    // x[0]..x[4] hold nodes named X1..X5; they are added to the graph in that order.
    // No node types are set explicitly.
    final Node[] x = new Node[5];
    for (int i = 0; i < x.length; i++) {
        x[i] = new GraphNode("X" + (i + 1));
        dag.addNode(x[i]);
    }

    dag.addDirectedEdge(x[0], x[1]); // X1 -> X2
    dag.addDirectedEdge(x[1], x[2]); // X2 -> X3
    dag.addDirectedEdge(x[2], x[3]); // X3 -> X4
    dag.addDirectedEdge(x[0], x[3]); // X1 -> X4
    dag.addDirectedEdge(x[3], x[4]); // X4 -> X5
    return dag;
}
Also used: EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph), Graph (edu.cmu.tetrad.graph.Graph), GraphNode (edu.cmu.tetrad.graph.GraphNode), Node (edu.cmu.tetrad.graph.Node)

Aggregations

EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 46; Graph (edu.cmu.tetrad.graph.Graph): 36; Node (edu.cmu.tetrad.graph.Node): 31; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 11; DataSet (edu.cmu.tetrad.data.DataSet): 10; Edge (edu.cmu.tetrad.graph.Edge): 8; GraphNode (edu.cmu.tetrad.graph.GraphNode): 8; ArrayList (java.util.ArrayList): 8; DMSearch (edu.cmu.tetrad.search.DMSearch): 6; Test (org.junit.Test): 6; SemIm (edu.cmu.tetrad.sem.SemIm): 5; SemPm (edu.cmu.tetrad.sem.SemPm): 5; List (java.util.List): 5; IOException (java.io.IOException): 4; RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 3; TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 3; DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D): 2; DenseDoubleMatrix2D (cern.colt.matrix.impl.DenseDoubleMatrix2D): 2; SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 2; TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph): 2