
Example 21 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

the class GraphHistory method add.

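public void add(Graph graph) {
    if (graph == null) {
        throw new NullPointerException();
    }
    // Discard any states after the current position (the "redo" tail).
    for (int i = graphs.size() - 1; i > index; i--) {
        graphs.remove(i);
    }
    // Store a defensive copy so later edits to the caller's graph do not change the history.
    graphs.addLast(new EdgeListGraph(graph));
    index++;
}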
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph)
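GraphHistory.add follows the usual undo/redo pattern: states after the current position are discarded, a defensive copy of the new state is appended, and the cursor advances. The generic sketch below illustrates that pattern outside Tetrad. The class UndoHistory, its back/forward/current/snapshot methods, and the UnaryOperator copier are hypothetical stand-ins, not Tetrad's API; it also assumes the cursor starts at -1 for an empty history, which the excerpt above does not show.

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.function.UnaryOperator;

// Hypothetical undo/redo history illustrating the pattern in GraphHistory.add;
// not Tetrad's class. T is the stored state type (a graph, in GraphHistory).
class UndoHistory<T> {
    private final LinkedList<T> states = new LinkedList<>();
    private final UnaryOperator<T> copier; // assumed copy function, e.g. g -> new EdgeListGraph(g)
    private int index = -1;                // assumption: -1 means "empty history"

    UndoHistory(UnaryOperator<T> copier) {
        this.copier = Objects.requireNonNull(copier);
    }

    void add(T state) {
        Objects.requireNonNull(state);
        // Discard every state after the cursor; they can no longer be redone.
        while (states.size() - 1 > index) {
            states.removeLast();
        }
        states.addLast(copier.apply(state)); // store a defensive copy
        index++;
    }

    // Steps back one state; returns null if already at the oldest state.
    T back() {
        if (index <= 0) {
            return null;
        }
        return states.get(--index);
    }

    // Steps forward one state; returns null if already at the newest state.
    T forward() {
        if (index >= states.size() - 1) {
            return null;
        }
        return states.get(++index);
    }

    // State at the cursor, or null if nothing has been added yet.
    T current() {
        return index >= 0 ? states.get(index) : null;
    }

    // Snapshot of all stored states, oldest first.
    List<T> snapshot() {
        return new ArrayList<>(states);
    }
}
```

Storing a copy rather than the caller's object is what keeps earlier entries stable when the caller keeps editing the same graph, which is why GraphHistory wraps its argument in new EdgeListGraph(graph).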

Example 22 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

the class PcStable method search.

/**
 * Runs PC starting with a complete graph over the given list of nodes, using the given independence test and
 * knowledge, and returns the resulting graph. The returned graph will be a pattern if the independence information
 * is consistent with the hypothesis that there are no latent common causes. It may, however, contain cycles or
 * bidirected edges if this assumption is not borne out, either due to the actual presence of latent common causes,
 * or due to statistical errors in conditional independence judgments.
 * <p>
 * All of the given nodes must be in the domain of the given conditional independence test.
 */
public Graph search(List<Node> nodes) {
    this.logger.log("info", "Starting PC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    // this.logger.log("info", "Variables " + independenceTest.getVariable());
    long startTime = System.currentTimeMillis();
    if (getIndependenceTest() == null) {
        throw new NullPointerException();
    }
    List<Node> allNodes = getIndependenceTest().getVariables();
    if (!allNodes.containsAll(nodes)) {
        throw new IllegalArgumentException("All of the given nodes must be in the domain of the independence test provided.");
    }
    }
    graph = new EdgeListGraph(nodes);
    IFas fas = new FasStable(initialGraph, getIndependenceTest());
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();
    sepsets = fas.getSepsets();
    SearchGraphUtils.pcOrientbk(knowledge, graph, nodes);
    // SearchGraphUtils.orientCollidersUsingSepsets(this.sepsets, knowledge, graph, initialGraph, verbose);
    // SearchGraphUtils.orientCollidersUsingSepsets(this.sepsets, knowledge, graph, verbose);
    // SearchGraphUtils.orientCollidersLocally(knowledge, graph, independenceTest, depth);
    SearchGraphUtils.orientCollidersUsingSepsets(this.sepsets, knowledge, graph, verbose, false);
    MeekRules rules = new MeekRules();
    rules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
    rules.setKnowledge(knowledge);
    rules.orientImplied(graph);
    this.logger.log("graph", "\nReturning this graph: " + graph);
    this.elapsedTime = System.currentTimeMillis() - startTime;
    this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    this.logger.log("info", "Finishing PC Algorithm.");
    this.logger.flush();
    return graph;
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) List(java.util.List)
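The search above builds an adjacency skeleton with FasStable, orients colliders from the recorded separating sets, and then propagates orientations with MeekRules. The sketch below illustrates only the collider step on a toy adjacency map: for each unshielded triple X - Z - Y (X and Y not adjacent), it orients X -> Z <- Y when Z is not in the separating set of X and Y. It is a simplified stand-in, not SearchGraphUtils.orientCollidersUsingSepsets; the class name, the string-keyed maps, and the "A->B" arrow encoding are assumptions, and it ignores background knowledge.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical collider-orientation sketch; not Tetrad code.
class ColliderSketch {
    // Order-independent key for a pair, so sepset lookups ignore argument order.
    static String pair(String a, String b) {
        return a.compareTo(b) < 0 ? a + "|" + b : b + "|" + a;
    }

    // adj: symmetric undirected adjacencies from the skeleton phase, with every node as a key.
    // sepsets: separating set recorded for each non-adjacent pair (keyed by pair()).
    // Returns arrowheads as "A->B" strings; if both "A->B" and "B->A" appear,
    // two colliders conflict, which corresponds to a bidirected edge.
    static Set<String> orientColliders(Map<String, Set<String>> adj,
                                       Map<String, Set<String>> sepsets) {
        Set<String> arrows = new TreeSet<>();
        for (String z : adj.keySet()) {
            List<String> nbrs = new ArrayList<>(adj.get(z));
            for (int i = 0; i < nbrs.size(); i++) {
                for (int j = i + 1; j < nbrs.size(); j++) {
                    String x = nbrs.get(i);
                    String y = nbrs.get(j);
                    if (adj.get(x).contains(y)) {
                        continue; // shielded triple: X and Y are adjacent, so no collider test
                    }
                    Set<String> sep = sepsets.get(pair(x, y));
                    if (sep != null && !sep.contains(z)) {
                        arrows.add(x + "->" + z); // Z did not separate X and Y: collider at Z
                        arrows.add(y + "->" + z);
                    }
                }
            }
        }
        return arrows;
    }
}
```

The conflicting-orientation case noted in the comment is one way the cycles or bidirected edges mentioned in the Javadoc can arise when the no-latent-confounder assumption or the independence judgments fail.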

Example 23 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

the class TimeoutComparison method getSubgraph.

private Graph getSubgraph(Graph graph, boolean discrete1, boolean discrete2, DataModel dataModel) {
    // Keep every node, but only the edges whose endpoint types match the requested combination:
    // (true, true) = discrete-discrete, (false, false) = continuous-continuous, otherwise mixed.
    Graph newGraph = new EdgeListGraph(graph.getNodes());
    for (Edge edge : graph.getEdges()) {
        Node node1 = dataModel.getVariable(edge.getNode1().getName());
        Node node2 = dataModel.getVariable(edge.getNode2().getName());
        boolean d1 = node1 instanceof DiscreteVariable;
        boolean d2 = node2 instanceof DiscreteVariable;
        boolean c1 = node1 instanceof ContinuousVariable;
        boolean c2 = node2 instanceof ContinuousVariable;
        boolean keep;
        if (discrete1 && discrete2) {
            keep = d1 && d2;
        } else if (!discrete1 && !discrete2) {
            keep = c1 && c2;
        } else {
            keep = (d1 && c2) || (c1 && d2);
        }
        if (keep) {
            newGraph.addEdge(edge);
        }
    }
    return newGraph;
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node) Edge(edu.cmu.tetrad.graph.Edge)
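getSubgraph keeps the full node set and retains only the edges whose endpoints fall into the requested type combination. The self-contained sketch below expresses the same classification with a hypothetical VarType enum and TypedEdge class in place of Tetrad's DiscreteVariable, ContinuousVariable, and Edge; it illustrates the filtering rule and is not Tetrad code. As in the original, an endpoint with no known type matches none of the three categories, so its edges are dropped.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical variable types standing in for DiscreteVariable / ContinuousVariable.
enum VarType { DISCRETE, CONTINUOUS }

// Hypothetical undirected edge between two named variables.
final class TypedEdge {
    final String a;
    final String b;

    TypedEdge(String a, String b) {
        this.a = a;
        this.b = b;
    }

    @Override
    public String toString() {
        return a + "--" + b;
    }
}

class EdgeFilter {
    // Mirrors the three categories of getSubgraph: (true, true) keeps discrete-discrete edges,
    // (false, false) keeps continuous-continuous edges, anything else keeps mixed edges.
    static List<TypedEdge> filter(List<TypedEdge> edges, Map<String, VarType> types,
                                  boolean discrete1, boolean discrete2) {
        List<TypedEdge> kept = new ArrayList<>();
        for (TypedEdge e : edges) {
            VarType t1 = types.get(e.a); // null if the variable is unknown
            VarType t2 = types.get(e.b);
            boolean d1 = t1 == VarType.DISCRETE, c1 = t1 == VarType.CONTINUOUS;
            boolean d2 = t2 == VarType.DISCRETE, c2 = t2 == VarType.CONTINUOUS;
            boolean keep;
            if (discrete1 && discrete2) {
                keep = d1 && d2;
            } else if (!discrete1 && !discrete2) {
                keep = c1 && c2;
            } else {
                keep = (d1 && c2) || (c1 && d2);
            }
            if (keep) {
                kept.add(e);
            }
        }
        return kept;
    }
}
```

With variables A and B discrete and C and D continuous, the edges A--B, B--C, and C--D land in the discrete, mixed, and continuous subgraphs respectively.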

Example 24 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

the class TimeoutComparison method doRun.

private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();
    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());
    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());
    long start = System.currentTimeMillis();
    Graph out;
    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();
        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }
        if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
            external.setSimulation(simulationWrapper.getSimulation());
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }
        if (algorithm instanceof MultiDataSetAlgorithm) {
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulationWrapper.getSimulation().getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);
            List<DataModel> dataModels = new ArrayList<>();
            int randomSelectionSize = algorithmWrapper.getAlgorithmSpecificParameters().getInt("randomSelectionSize");
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulationWrapper.getSimulation().getDataModel(indices.get(i)));
            }
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }
    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    long stop = System.currentTimeMillis();
    long elapsed = stop - start;
    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);
    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }
    Graph[] est = new Graph[numGraphTypes];
    Graph comparisonGraph;
    if (trueGraph == null) {
        // Without a true graph there is nothing to compare against; the null checks below then skip scoring.
        comparisonGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparisonGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparisonGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        comparisonGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }
    est[0] = out;
    graphTypeUsed[0] = true;
    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }
    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;
    if (data.isMixed() && comparisonGraph != null) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }
    if (comparisonGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            if (!graphTypeUsed[u]) {
                continue;
            }
            int statIndex = -1;
            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;
                if (_stat instanceof ParameterColumn) {
                    continue;
                }
                double stat;
                if (_stat instanceof ElapsedTime) {
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }
                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Also used : ArrayList(java.util.ArrayList) ElapsedTime(edu.cmu.tetrad.algcomparison.statistic.ElapsedTime) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) ParameterColumn(edu.cmu.tetrad.algcomparison.statistic.ParameterColumn) Parameters(edu.cmu.tetrad.util.Parameters) HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) TimeoutException(java.util.concurrent.TimeoutException) FileNotFoundException(java.io.FileNotFoundException) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) DagToPag(edu.cmu.tetrad.search.DagToPag) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation) DataModel(edu.cmu.tetrad.data.DataModel)
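For each graph type u, doRun stores one value per statistic in allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()], computed as _stat.getValue(truth[u], est[u]) (elapsed time excepted). As an illustration of what such a statistic can compute, the sketch below scores adjacency precision and recall of an estimated skeleton against a reference skeleton. The class and method names and the String[] edge encoding are hypothetical, and Tetrad's own Statistic implementations may define these quantities differently, for example in how they treat an empty graph (here the result is NaN, an assumed convention).

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical adjacency precision/recall sketch; not Tetrad's Statistic classes.
class AdjacencyScores {
    // Order-independent key for an adjacency, so {"A","B"} and {"B","A"} coincide.
    private static String key(String[] edge) {
        return edge[0].compareTo(edge[1]) < 0 ? edge[0] + "|" + edge[1] : edge[1] + "|" + edge[0];
    }

    private static Set<String> adjacencies(List<String[]> edges) {
        Set<String> result = new HashSet<>();
        for (String[] e : edges) {
            result.add(key(e));
        }
        return result;
    }

    // Returns {precision, recall}; NaN when a denominator is zero (assumed convention).
    static double[] precisionRecall(List<String[]> truth, List<String[]> estimate) {
        Set<String> t = adjacencies(truth);
        Set<String> e = adjacencies(estimate);
        Set<String> truePositives = new HashSet<>(e);
        truePositives.retainAll(t); // adjacencies present in both graphs
        double precision = e.isEmpty() ? Double.NaN : truePositives.size() / (double) e.size();
        double recall = t.isEmpty() ? Double.NaN : truePositives.size() / (double) t.size();
        return new double[] {precision, recall};
    }
}
```

With truth A-B, B-C and estimate B-A, A-C, one of two estimated adjacencies is correct and one of two true adjacencies is recovered, so both scores are 0.5.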

Example 25 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

the class PngWriter method writePng.

public static void writePng(Graph graph, File file) {
    // circleLayout(graph, 200, 200, 175);
    JPanel panel = new JPanel();
    panel.setLayout(new BorderLayout());
    // Remove self-loops.
    graph = new EdgeListGraph(graph);
    for (Node node : graph.getNodes()) {
        for (Edge edge : new ArrayList<>(graph.getEdges(node, node))) {
            graph.removeEdge(edge);
        }
    }
    final GraphWorkbench workbench = new GraphWorkbench(graph);
    int maxx = 0;
    int maxy = 0;
    for (Node node : graph.getNodes()) {
        if (node.getCenterX() > maxx) {
            maxx = node.getCenterX();
        }
        if (node.getCenterY() > maxy) {
            maxy = node.getCenterY();
        }
    }
    workbench.setSize(new Dimension(maxx + 50, maxy + 50));
    panel.add(workbench, BorderLayout.CENTER);
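    // Packing the dialog sizes the workbench to its preferred size before painting.
    JDialog dialog = new JDialog();
    dialog.add(workbench);
    dialog.pack();
    Dimension size = workbench.getSize();
    BufferedImage image = new BufferedImage(size.width, size.height, BufferedImage.TYPE_BYTE_INDEXED);
    Graphics2D graphics = image.createGraphics();
    workbench.paint(graphics);
    graphics.dispose();
    // The dialog was only needed for layout; release its native resources.
    dialog.dispose();
    image.flush();
    // Write the image to file.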
    try {
        ImageIO.write(image, "PNG", file);
    } catch (IOException e1) {
        throw new RuntimeException("Could not write to " + file, e1);
    }
}
Also used : Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ArrayList(java.util.ArrayList) IOException(java.io.IOException) BufferedImage(java.awt.image.BufferedImage) Edge(edu.cmu.tetrad.graph.Edge)
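writePng copies the graph, strips self-loops, lays the workbench out in a JDialog, paints it into a BufferedImage, and saves the image with ImageIO. Because it creates a JDialog, it needs a display: in a headless JVM the JDialog constructor throws HeadlessException. The JDK-only sketch below isolates the off-screen part, which does work headless: draw into a BufferedImage, dispose of the Graphics2D, and write a PNG. The class name PngSketch and the drawn placeholder shapes are assumptions, not Tetrad's rendering, and it uses TYPE_INT_RGB rather than the original's TYPE_BYTE_INDEXED.

```java
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;

// Hypothetical off-screen PNG rendering sketch; not Tetrad's PngWriter.
class PngSketch {
    // Draws two placeholder "nodes" joined by a placeholder "edge" and saves the result as a PNG.
    static void writeDemoPng(File file, int width, int height) throws IOException {
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = image.createGraphics();
        try {
            g.setColor(Color.WHITE);
            g.fillRect(0, 0, width, height);                     // white background
            g.setColor(Color.BLACK);
            g.setStroke(new BasicStroke(2f));
            g.drawLine(60, height / 2, width - 60, height / 2);  // placeholder edge
            g.drawOval(20, height / 2 - 20, 40, 40);             // placeholder node
            g.drawOval(width - 60, height / 2 - 20, 40, 40);     // placeholder node
        } finally {
            g.dispose();                                         // release the graphics context
        }
        // ImageIO.write returns false when no writer is registered for the format.
        if (!ImageIO.write(image, "png", file)) {
            throw new IOException("No PNG writer available");
        }
    }
}
```

Checking the boolean returned by ImageIO.write is a small addition over the original, which ignores it and would silently produce no file if no writer were found.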

Aggregations

EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 46
Graph (edu.cmu.tetrad.graph.Graph): 36
Node (edu.cmu.tetrad.graph.Node): 31
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 11
DataSet (edu.cmu.tetrad.data.DataSet): 10
Edge (edu.cmu.tetrad.graph.Edge): 8
GraphNode (edu.cmu.tetrad.graph.GraphNode): 8
ArrayList (java.util.ArrayList): 8
DMSearch (edu.cmu.tetrad.search.DMSearch): 6
Test (org.junit.Test): 6
SemIm (edu.cmu.tetrad.sem.SemIm): 5
SemPm (edu.cmu.tetrad.sem.SemPm): 5
List (java.util.List): 5
IOException (java.io.IOException): 4
RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 3
TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 3
DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D): 2
DenseDoubleMatrix2D (cern.colt.matrix.impl.DenseDoubleMatrix2D): 2
SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 2
TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph): 2