Use of edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm in project tetrad by cmu-phil.
In class Comparison, method doRun:
/**
 * Executes one algorithm on one simulated dataset and records the resulting statistics.
 *
 * <p>The algorithm/simulation pair is selected by {@code run.getAlgSimIndex()} and the dataset by
 * {@code run.getRunIndex()}. A {@code MultiDataSetAlgorithm} is instead given a random selection of
 * at most {@code randomSelectionSize} of the simulation's data models. The estimated graph is saved
 * via {@code saveGraph}. If the simulation provides a true graph, the estimate is scored against the
 * configured comparison graph (the true DAG, its pattern, or its PAG), and each statistic is written
 * into {@code allStats}. If the search throws, the failure is printed and the run is skipped.
 *
 * @param algorithmSimulationWrappers all algorithm/simulation pairs; indexed by {@code run.getAlgSimIndex()}
 * @param algorithmWrappers all algorithm wrappers; used to derive the 1-based algorithm index for saving
 * @param simulationWrappers all simulation wrappers; used to derive the simulation index
 * @param statistics the statistics to compute for the estimated graph
 * @param numGraphTypes length of the per-graph-type arrays; indices 1-3 hold subgraphs for mixed data
 * @param allStats output array indexed as [graphType][algSimIndex][statIndex][runIndex]
 * @param run identifies the algorithm/simulation pair and the run index to execute
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers,
                   List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes,
                   double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();

    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());

    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription()
            + " simulationWrapper: " + simulationWrapper.getDescription());

    long start = System.currentTimeMillis();
    Graph out;

    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();

        // Hand simulation-provided knowledge to algorithms that can use it.
        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }

        // External algorithms are told the simulation, results path and simulation index
        // (presumably so they can locate their saved output).
        if (algorithm instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithm;
            external.setSimulation(simulation);
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }

        Parameters algParams = algorithmWrapper.getAlgorithmSpecificParameters();

        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Give the algorithm a random subset of at most randomSelectionSize data models.
            int numDataModels = simulation.getNumDataModels();
            List<Integer> indices = new ArrayList<>();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);

            int randomSelectionSize = algParams.getInt("randomSelectionSize");
            List<DataModel> dataModels = new ArrayList<>();
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulation.getDataModel(indices.get(i)));
            }

            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, algParams);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            out = algorithm.search(dataModel, algParams);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }

    // Stop the clock as soon as the search returns so the bookkeeping below is not timed.
    long elapsed = System.currentTimeMillis() - start;

    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);

    if (trueGraph != null) {
        // Re-express the estimate over the true graph's node objects (presumably so that
        // edges compare equal across the two graphs).
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }

    // For external algorithms the elapsed time is reported by the algorithm itself rather
    // than taken from the measurement above.
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }

    // The graph the estimate is scored against. It stays null when the simulation supplies no
    // true graph; scoring is then skipped instead of failing on a null graph.
    Graph comparisonGraph = null;
    if (trueGraph != null) {
        if (this.comparisonGraph == ComparisonGraph.true_DAG) {
            comparisonGraph = new EdgeListGraph(trueGraph);
        } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
            comparisonGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
        } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
            comparisonGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
        } else {
            throw new IllegalArgumentException("Unrecognized graph type.");
        }
    }

    // Index 0 is the whole graph; for mixed data, indices 1-3 are subgraphs selected by getSubgraph.
    Graph[] est = new Graph[numGraphTypes];
    est[0] = out;
    graphTypeUsed[0] = true;

    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }

    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;

    if (data.isMixed() && comparisonGraph != null) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }

    if (comparisonGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            if (!graphTypeUsed[u]) {
                continue;
            }

            int statIndex = -1;
            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;

                // Parameter columns are not computed here; their slot is left untouched.
                if (_stat instanceof ParameterColumn) {
                    continue;
                }

                double stat;
                if (_stat instanceof ElapsedTime) {
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }

                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Use of edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm in project tetrad by cmu-phil.
In class Comparison, method generateReportFromExternalAlgorithms:
/**
 * Generates a comparison report for externally run algorithms from previously saved simulations.
 *
 * <p>Every algorithm must be an {@code ExternalAlgorithm}. Simulations are loaded from the numbered
 * subdirectories {@code dataPath/save/1}, {@code dataPath/save/2}, ..., and the comparison is
 * delegated to {@code compareFromSimulations}. Graph saving is switched off for this comparison.
 *
 * @param dataPath       directory containing the {@code save} directory of simulated data and graphs
 * @param resultsPath    results location, passed on to {@code compareFromSimulations}
 * @param outputFileName name of the report file, passed on to {@code compareFromSimulations}
 * @param algorithms     the algorithms to compare; all must be {@code ExternalAlgorithm}s
 * @param statistics     the statistics on which to compare the algorithms
 * @param parameters     the parameters and their values
 * @throws IllegalArgumentException if any algorithm is not an {@code ExternalAlgorithm}
 * @throws NullPointerException     if {@code dataPath/save} cannot be listed
 */
public void generateReportFromExternalAlgorithms(String dataPath, String resultsPath, String outputFileName,
                                                 Algorithms algorithms, Statistics statistics, Parameters parameters) {
    // Validate before touching any state, so a rejected call leaves this object unchanged.
    for (Algorithm algorithm : algorithms.getAlgorithms()) {
        if (!(algorithm instanceof ExternalAlgorithm)) {
            throw new IllegalArgumentException("Expecting all algorithms to implement ExternalAlgorithm.");
        }
    }

    this.saveGraphs = false;
    this.dataPath = dataPath;
    this.resultsPath = resultsPath;

    File saveDir = new File(this.dataPath, "save");
    File[] entries = saveDir.listFiles();

    if (entries == null) {
        throw new NullPointerException("No files in " + saveDir.getAbsolutePath());
    }

    // Simulations are stored in subdirectories named 1..N. Count only directories, so stray
    // files (e.g. macOS .DS_Store) do not inflate N and send the loader to a missing directory.
    int count = 0;
    for (File entry : entries) {
        if (entry.isDirectory() && !entry.getName().contains("DS_Store")) {
            count++;
        }
    }

    Simulations simulations = new Simulations();
    this.dirs = new ArrayList<>();

    for (int i = 1; i <= count; i++) {
        File simDir = new File(dataPath, "save/" + i);
        simulations.add(new LoadDataAndGraphs(simDir.getAbsolutePath()));
        this.dirs.add(simDir.getAbsolutePath());
    }

    compareFromSimulations(this.resultsPath, simulations, outputFileName, algorithms, statistics, parameters);
}
Use of edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm in project tetrad by cmu-phil.
In class Comparison, method compareFromFiles:
/**
 * Compares algorithms on simulations previously saved to disk.
 *
 * <p>Simulations are loaded from the numbered subdirectories {@code dataPath/save/1},
 * {@code dataPath/save/2}, ..., and the comparison is delegated to {@code compareFromSimulations}.
 *
 * @param dataPath    Path to the directory where data and graph files have been saved.
 * @param resultsPath Path where the results should be stored.
 * @param algorithms  The list of algorithms to be compared; none may be an {@code ExternalAlgorithm}.
 * @param statistics  The list of statistics on which to compare the algorithms, and their utility weights.
 * @param parameters  The list of parameters and their values.
 * @throws IllegalArgumentException if any algorithm is an {@code ExternalAlgorithm}
 * @throws NullPointerException     if {@code dataPath/save} cannot be listed
 */
public void compareFromFiles(String dataPath, String resultsPath, Algorithms algorithms, Statistics statistics,
                             Parameters parameters) {
    for (Algorithm algorithm : algorithms.getAlgorithms()) {
        if (algorithm instanceof ExternalAlgorithm) {
            throw new IllegalArgumentException("Not expecting any implementations of ExternalAlgorithm here.");
        }
    }

    this.dataPath = dataPath;
    this.resultsPath = resultsPath;

    File saveDir = new File(this.dataPath, "save");
    File[] entries = saveDir.listFiles();

    if (entries == null) {
        throw new NullPointerException("No files in " + saveDir.getAbsolutePath());
    }

    // Simulations are stored in subdirectories named 1..N. Count only directories, so stray
    // files (e.g. macOS .DS_Store) do not inflate N and send the loader to a missing directory.
    int count = 0;
    for (File entry : entries) {
        if (entry.isDirectory() && !entry.getName().contains("DS_Store")) {
            count++;
        }
    }

    Simulations simulations = new Simulations();
    this.dirs = new ArrayList<>();

    for (int i = 1; i <= count; i++) {
        File simDir = new File(dataPath, "save/" + i);
        simulations.add(new LoadDataAndGraphs(simDir.getAbsolutePath()));
        this.dirs.add(simDir.getAbsolutePath());
    }

    compareFromSimulations(this.resultsPath, simulations, algorithms, statistics, parameters);
}
Use of edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm in project tetrad by cmu-phil.
In class ExternalAlgorithmIntersection, method search:
/**
 * Returns the intersection of the graphs produced by the wrapped external algorithms.
 *
 * <p>Each wrapped algorithm is first configured with this object's path, simulation index and
 * simulation, and its elapsed time is accumulated into {@code elapsed}. The returned graph contains
 * every node of the first algorithm's graph and only those edges present in every algorithm's graph.
 *
 * @param dataSet    the data model; it is cast to {@code DataSet} when querying elapsed times
 * @param parameters the parameters passed to each wrapped algorithm
 * @return a new graph holding the edges common to all wrapped algorithms' results
 * @throws IllegalStateException if there are no wrapped algorithms to intersect
 */
public Graph search(DataModel dataSet, Parameters parameters) {
    if (algorithms.length == 0) {
        throw new IllegalStateException("No external algorithms to intersect.");
    }

    this.elapsed = 0;

    for (ExternalAlgorithm algorithm : algorithms) {
        algorithm.setPath(this.path);
        algorithm.setSimIndex(this.simIndex);
        algorithm.setSimulation(this.simulation);
        elapsed += algorithm.getElapsedTime((DataSet) dataSet, parameters);
    }

    Graph graph0 = algorithms[0].search(dataSet, parameters);

    // Work on a private copy: retainAll must not mutate (or fail on) a set owned by graph0.
    // LinkedHashSet keeps the source set's iteration order.
    Set<Edge> edges = new java.util.LinkedHashSet<>(graph0.getEdges());

    for (int i = 1; i < algorithms.length; i++) {
        edges.retainAll(algorithms[i].search(dataSet, parameters).getEdges());
    }

    EdgeListGraph intersection = new EdgeListGraph(graph0.getNodes());

    for (Edge edge : edges) {
        intersection.addEdge(edge);
    }

    return intersection;
}
Aggregations