Usage of `edu.cmu.tetrad.graph.EdgeListGraph` in the tetrad project (cmu-phil).
Class `GraphHistory`, method `add`:
/**
 * Records a graph as the newest entry of the history.
 * <p>
 * Every entry positioned after the current {@code index} is discarded first, then a new
 * {@link EdgeListGraph} built from {@code graph} is appended and {@code index} is advanced so
 * that it points at the appended entry.
 *
 * @param graph the graph to record; must not be null.
 * @throws NullPointerException if {@code graph} is null.
 */
public void add(Graph graph) {
    if (graph == null) {
        throw new NullPointerException();
    }

    // Trim from the tail until the entry at `index` is the last one kept.
    while (graphs.size() - 1 > index) {
        graphs.remove(graphs.size() - 1);
    }

    graphs.addLast(new EdgeListGraph(graph));
    index++;
}
Usage of `edu.cmu.tetrad.graph.EdgeListGraph` in the tetrad project (cmu-phil).
Class `PcStable`, method `search`:
/**
 * Runs PC starting with a complete graph over the given list of nodes, using the given independence test and
 * knowledge, and returns the resulting graph. The returned graph will be a pattern if the independence information
 * is consistent with the hypothesis that there are no latent common causes. It may, however, contain cycles or
 * bidirected edges if this assumption is not borne out, either due to the actual presence of latent common causes,
 * or due to statistical errors in conditional independence judgments.
 * <p>
 * All of the given nodes must be in the domain of the given conditional independence test.
 *
 * @param nodes the nodes to search over; each must be one of the independence test's variables.
 * @return the graph produced by the stable adjacency search followed by knowledge, collider and Meek-rule
 * orientation; the same graph is also stored in the {@code graph} field.
 * @throws NullPointerException     if no independence test has been set.
 * @throws IllegalArgumentException if some node is not a variable of the independence test.
 */
public Graph search(List<Node> nodes) {
    // Validate before logging so a missing test fails fast instead of logging "Independence test = null.".
    if (getIndependenceTest() == null) {
        throw new NullPointerException("Independence test was not set.");
    }

    this.logger.log("info", "Starting PC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    long startTime = System.currentTimeMillis();

    List<Node> allNodes = getIndependenceTest().getVariables();
    if (!allNodes.containsAll(nodes)) {
        throw new IllegalArgumentException("All of the given nodes must be in the domain of the independence test provided.");
    }

    // NOTE(review): this value is overwritten by fas.search() below, and FasStable is built from initialGraph
    // rather than from these nodes, so `nodes` only affects validation and pcOrientbk here. Confirm intended.
    graph = new EdgeListGraph(nodes);

    // Stable adjacency search; also records the separating sets used for collider orientation.
    IFas fas = new FasStable(initialGraph, getIndependenceTest());
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();
    sepsets = fas.getSepsets();

    // Apply background knowledge, then orient colliders from the recorded separating sets.
    SearchGraphUtils.pcOrientbk(knowledge, graph, nodes);
    SearchGraphUtils.orientCollidersUsingSepsets(this.sepsets, knowledge, graph, verbose, false);

    // Propagate the orientations implied so far with Meek's rules.
    MeekRules rules = new MeekRules();
    rules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
    rules.setKnowledge(knowledge);
    rules.orientImplied(graph);

    this.logger.log("graph", "\nReturning this graph: " + graph);

    this.elapsedTime = System.currentTimeMillis() - startTime;
    this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    this.logger.log("info", "Finishing PC Algorithm.");
    this.logger.flush();

    return graph;
}
Usage of `edu.cmu.tetrad.graph.EdgeListGraph` in the tetrad project (cmu-phil).
Class `TimeoutComparison`, method `getSubgraph`:
/**
 * Builds a graph over all of {@code graph}'s nodes that keeps only the edges whose endpoint variables, looked up by
 * name in {@code dataModel}, have the requested types:
 * <ul>
 * <li>{@code discrete1 && discrete2}: both endpoints are {@link DiscreteVariable}s;</li>
 * <li>{@code !discrete1 && !discrete2}: both endpoints are {@link ContinuousVariable}s;</li>
 * <li>otherwise: one endpoint is discrete and the other continuous, in either order.</li>
 * </ul>
 * Edges whose endpoint names are not resolved to one of those variable types are dropped.
 *
 * @param graph     the graph whose edges are filtered.
 * @param discrete1 first endpoint-type flag.
 * @param discrete2 second endpoint-type flag.
 * @param dataModel the data model used to resolve node names to variables.
 * @return a new graph with the same nodes and the matching edges.
 */
private Graph getSubgraph(Graph graph, boolean discrete1, boolean discrete2, DataModel dataModel) {
    boolean wantBothDiscrete = discrete1 && discrete2;
    boolean wantBothContinuous = !discrete1 && !discrete2;

    Graph subgraph = new EdgeListGraph(graph.getNodes());

    for (Edge edge : graph.getEdges()) {
        Node first = dataModel.getVariable(edge.getNode1().getName());
        Node second = dataModel.getVariable(edge.getNode2().getName());

        boolean firstDiscrete = first instanceof DiscreteVariable;
        boolean secondDiscrete = second instanceof DiscreteVariable;
        boolean firstContinuous = first instanceof ContinuousVariable;
        boolean secondContinuous = second instanceof ContinuousVariable;

        boolean keep;
        if (wantBothDiscrete) {
            keep = firstDiscrete && secondDiscrete;
        } else if (wantBothContinuous) {
            keep = firstContinuous && secondContinuous;
        } else {
            keep = (firstDiscrete && secondContinuous) || (firstContinuous && secondDiscrete);
        }

        if (keep) {
            subgraph.addEdge(edge);
        }
    }

    return subgraph;
}
Usage of `edu.cmu.tetrad.graph.EdgeListGraph` in the tetrad project (cmu-phil).
Class `TimeoutComparison`, method `doRun`:
/**
 * Executes one (algorithm, simulation, run) combination.
 * <p>
 * It runs the algorithm on the simulated data, saves the estimated graph, and records every requested statistic
 * into {@code allStats}. The estimate is compared with a graph derived from the simulation's true graph: the DAG
 * itself, its pattern, or its PAG, according to the {@code comparisonGraph} setting. If the simulation has no true
 * graph, the estimate is still saved but no statistics are recorded. For mixed data, statistics are also computed
 * on the discrete-discrete, discrete-continuous and continuous-continuous edge subgraphs (graph types 1-3), so
 * {@code numGraphTypes} must be at least 4 in that case.
 * <p>
 * If the algorithm throws, a message and stack trace are printed and nothing is saved or recorded.
 *
 * @param algorithmSimulationWrappers all algorithm/simulation pairs; indexed by {@code run.getAlgSimIndex()}.
 * @param algorithmWrappers           all algorithms; used to derive the 1-based algorithm index for saving.
 * @param simulationWrappers          all simulations; used to derive the 1-based simulation index for saving.
 * @param statistics                  the statistics to compute.
 * @param numGraphTypes               number of graph types (whole graph plus edge-type subgraphs).
 * @param allStats                    output array indexed {@code [graphType][algSim][statistic][run]}.
 * @param run                         identifies the algorithm/simulation pair and the run index.
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();

    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());

    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());

    long start = System.currentTimeMillis();
    Graph out;

    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();
        Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();

        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }

        // Configure the same algorithm and simulation instances that are used for the search below.
        if (algorithm instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithm;
            external.setSimulation(simulation);
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }

        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Search over a random subset of at most `randomSelectionSize` of the simulated data sets.
            int numDataModels = simulation.getNumDataModels();
            List<Integer> indices = new ArrayList<>();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);

            int randomSelectionSize = _params.getInt("randomSelectionSize");
            List<DataModel> dataModels = new ArrayList<>();
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulation.getDataModel(indices.get(i)));
            }

            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }

    // Stop the clock as soon as the search finishes so bookkeeping below is not counted.
    long elapsed = System.currentTimeMillis() - start;

    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);

    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }

    // For external algorithms the elapsed time is taken from the algorithm itself rather than the local clock.
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }

    Graph[] est = new Graph[numGraphTypes];

    // Reference graph. Without a true graph there is nothing to compare against; the null checks below then skip
    // statistics instead of crashing here on a null copy.
    Graph comparisonGraph;
    if (trueGraph == null) {
        comparisonGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparisonGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparisonGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        comparisonGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }

    est[0] = out;
    graphTypeUsed[0] = true;

    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);

        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }

    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;

    if (data.isMixed() && comparisonGraph != null) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }

    if (comparisonGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            if (!graphTypeUsed[u]) {
                continue;
            }

            int statIndex = -1;

            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;

                // Parameter columns are not computed from graphs.
                if (_stat instanceof ParameterColumn) {
                    continue;
                }

                double stat;

                if (_stat instanceof ElapsedTime) {
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }

                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Usage of `edu.cmu.tetrad.graph.EdgeListGraph` in the tetrad project (cmu-phil).
Class `PngWriter`, method `writePng`:
/**
 * Renders a graph with a {@link GraphWorkbench} and writes the picture to a file in PNG format.
 * <p>
 * Self-loops are removed from an {@link EdgeListGraph} copy of {@code graph}, not from the argument itself. The
 * workbench is laid out with {@code pack()} inside a temporary {@link JDialog} that is never made visible, then
 * painted into an indexed-color image. The graphics context and the dialog are released afterwards, even if
 * painting or writing fails.
 *
 * @param graph the graph to render.
 * @param file  the destination file.
 * @throws RuntimeException if the image cannot be written to {@code file}.
 */
public static void writePng(Graph graph, File file) {
    // Work on a copy so that removing self-loops does not touch the argument.
    Graph copy = new EdgeListGraph(graph);

    for (Node node : copy.getNodes()) {
        // Snapshot the loop edges before removing them from the graph being iterated.
        for (Edge edge : new ArrayList<>(copy.getEdges(node, node))) {
            copy.removeEdge(edge);
        }
    }

    GraphWorkbench workbench = new GraphWorkbench(copy);

    int maxx = 0;
    int maxy = 0;

    for (Node node : copy.getNodes()) {
        if (node.getCenterX() > maxx) {
            maxx = node.getCenterX();
        }

        if (node.getCenterY() > maxy) {
            maxy = node.getCenterY();
        }
    }

    workbench.setSize(new Dimension(maxx + 50, maxy + 50));

    // The dialog is only used to lay out the workbench via pack(); it is never shown and is disposed at the end.
    JDialog dialog = new JDialog();

    try {
        dialog.add(workbench);
        dialog.pack();

        Dimension size = workbench.getSize();

        // BufferedImage rejects non-positive dimensions.
        BufferedImage image = new BufferedImage(Math.max(1, size.width), Math.max(1, size.height),
                BufferedImage.TYPE_BYTE_INDEXED);

        Graphics2D graphics = image.createGraphics();
        try {
            workbench.paint(graphics);
        } finally {
            graphics.dispose();
        }

        image.flush();

        // Write the image to the result file.
        try {
            if (!ImageIO.write(image, "PNG", file)) {
                throw new RuntimeException("Could not write to " + file + ": no PNG writer available");
            }
        } catch (IOException e1) {
            throw new RuntimeException("Could not write to " + file, e1);
        }
    } finally {
        dialog.dispose();
    }
}
Aggregations