Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil:
the compareFromSimulations method of class TimeoutComparison.
/**
 * Compares algorithms on data generated from the given simulations and writes a
 * report of the results to {@code outputFileName} inside {@code resultsPath}.
 *
 * <p>Every (possibly parameter-expanded) algorithm is paired with every simulation.
 * An algorithm with a parameter that has more than one value in {@code parameters}
 * is expanded into one variant per combination of those values. The report lists
 * the statistics, the simulations, the algorithms, the utility weighting (if
 * enabled), and then the average, standard-deviation and worst-case tables.
 *
 * @param resultsPath    Directory where the report is written; created if missing.
 * @param simulations    The simulations used to generate graphs and data for the
 *                       comparison.
 * @param outputFileName Name of the report file inside {@code resultsPath}.
 * @param algorithms     The algorithms to be compared.
 * @param statistics     The statistics on which to compare the algorithms, and
 *                       their utility weights.
 * @param parameters     Parameter values; {@code numRuns} is read from here and
 *                       multi-valued parameters drive algorithm expansion.
 * @param timeout        Forwarded to {@code calcStats} (presumably a per-run time
 *                       limit -- confirm there).
 * @param unit           Time unit of {@code timeout}.
 * @throws RuntimeException if the report file cannot be opened.
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
    this.resultsPath = resultsPath;

    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    try {
        out.println(new Date());

        // Set up simulations--create data and graphs, read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");

        for (Simulation simulation : simulations.getSimulations()) {
            List<SimulationWrapper> wrappers = getSimulationWrappers(simulation, parameters);

            for (SimulationWrapper wrapper : wrappers) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }

        // Set up the algorithms. Each multi-valued parameter an algorithm uses becomes one
        // dimension of a combination grid; one wrapper is created per grid point.
        List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();

        for (Algorithm algorithm : algorithms.getAlgorithms()) {
            List<Integer> _dims = new ArrayList<>();
            List<String> varyingParameters = new ArrayList<>();

            final List<String> parameters1 = algorithm.getParameters();

            for (String name : parameters1) {
                if (parameters.getNumValues(name) > 1) {
                    _dims.add(parameters.getNumValues(name));
                    varyingParameters.add(name);
                }
            }

            if (varyingParameters.isEmpty()) {
                algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
            } else {
                int[] dims = new int[_dims.size()];

                for (int i = 0; i < _dims.size(); i++) {
                    dims[i] = _dims.get(i);
                }

                CombinationGenerator gen = new CombinationGenerator(dims);
                int[] choice;

                while ((choice = gen.next()) != null) {
                    AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);

                    for (int h = 0; h < dims.length; h++) {
                        String parameter = varyingParameters.get(h);
                        Object[] values = parameters.getValues(parameter);
                        Object value = values[choice[h]];
                        wrapper.setValue(parameter, value);
                    }

                    algorithmWrappers.add(wrapper);
                }
            }
        }

        // Create the algorithm-simulation wrappers for every combination of algorithm and
        // simulation.
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();

        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();

                // Mismatches are only reported; the pairing is still created.
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }

                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    // Index of the simulation this pairing uses. Searching for
                    // external.getSimulation() here would not work: it is a Simulation
                    // that has not been set yet, not one of the wrappers in this list.
                    external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
                }

                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }

        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, statistics, numRuns, timeout, unit);

        // Print out the preliminary information for statistics types, etc.
        if (allStats != null) {
            out.println();
            out.println("Statistics:");
            out.println();

            for (Statistic stat : statistics.getStatistics()) {
                out.println(stat.getAbbreviation() + " = " + stat.getDescription());
            }
        }

        out.println();

        if (allStats != null) {
            int numTables = allStats.length;
            // The last column of each table is reserved for the utility.
            int numStats = allStats[0][0].length - 1;

            double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables, algorithmSimulationWrappers, numStats, statistics);
            double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);

            // Add utilities to table as the last column.
            fillUtilityColumn(statTables, utilities, numTables, algorithmSimulationWrappers.size(), numStats);

            int[] newOrder;

            if (isSortByUtility()) {
                newOrder = sort(algorithmSimulationWrappers, utilities);
            } else {
                newOrder = new int[algorithmSimulationWrappers.size()];

                for (int q = 0; q < algorithmSimulationWrappers.size(); q++) {
                    newOrder[q] = q;
                }
            }

            out.println("Simulations:");
            out.println();

            int i = 0;

            for (SimulationWrapper simulation : simulationWrappers) {
                out.print("Simulation " + (++i) + ": ");
                out.println(simulation.getDescription());
                out.println();

                printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);

                out.println();
            }

            out.println("Algorithms:");
            out.println();

            // List each algorithm once, using its pairing with the first simulation.
            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);

                if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                    out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
                }
            }

            if (isSortByUtility()) {
                out.println();
                out.println("Sorting by utility, high to low.");
            }

            if (isShowUtilities()) {
                out.println();
                out.println("Weighting of statistics:");
                out.println();
                out.println("U = ");

                for (Statistic stat : statistics.getStatistics()) {
                    String statName = stat.getAbbreviation();
                    double weight = statistics.getWeight(stat);
                    if (weight != 0.0) {
                        out.println("    " + weight + " * f(" + statName + ")");
                    }
                }

                out.println();
                out.println("...normed to range between 0 and 1.");

                out.println();
                out.println("Note that f for each statistic is a function that maps the statistic to the ");
                out.println("interval [0, 1], with higher being better.");
            }

            out.println();
            out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");

            out.println();

            // Re-apply the utility column to the average tables before printing them.
            fillUtilityColumn(statTables, utilities, numTables, algorithmSimulationWrappers.size(), numStats);

            // Print all of the tables.
            printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

            // Note: unlike the other two modes, the utility column is not filled in here.
            statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables, algorithmSimulationWrappers, numStats, statistics);

            printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

            statTables = calcStatTables(allStats, Mode.WorstCase, numTables, algorithmSimulationWrappers, numStats, statistics);

            // Add utilities to table as the last column.
            fillUtilityColumn(statTables, utilities, numTables, algorithmSimulationWrappers.size(), numStats);

            printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
        }
    } finally {
        // Close the report even if a simulation, algorithm run or printing step throws.
        out.close();
    }
}

/**
 * Writes {@code utilities[t]} into column {@code numStats} of row {@code t} of the
 * first {@code numTables} tables in {@code statTables}, for every row
 * {@code t < numRows}.
 */
private static void fillUtilityColumn(double[][][] statTables, double[] utilities, int numTables, int numRows, int numStats) {
    for (int u = 0; u < numTables; u++) {
        for (int t = 0; t < numRows; t++) {
            statTables[u][t][numStats] = utilities[t];
        }
    }
}
Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil:
the doRun method of class TimeoutComparison.
/**
 * Runs one algorithm on one simulated data set and stores the resulting
 * statistics in {@code allStats}.
 *
 * <p>The estimated graph is saved via {@code saveGraph} and then scored against a
 * reference graph derived from the simulation's true graph, according to the
 * {@code comparisonGraph} setting. For mixed data the estimate and the reference
 * are also scored on three subgraphs (graph-type indices 1-3).
 *
 * <p>If the algorithm throws, the error is printed and nothing is recorded. If the
 * simulation supplies no true graph, the estimate is still saved but no
 * statistics are recorded.
 *
 * @param algorithmSimulationWrappers All algorithm/simulation pairings; the one at
 *                                    {@code run.getAlgSimIndex()} is executed.
 * @param algorithmWrappers           All algorithm wrappers (used for the saved-file index).
 * @param simulationWrappers          All simulation wrappers (used for the saved-file index
 *                                    and the external-algorithm simulation index).
 * @param statistics                  The statistics to compute.
 * @param numGraphTypes               Number of graph types to score.
 * @param allStats                    Output, indexed {@code [graphType][algSim][stat][run]}.
 * @param run                         Which pairing and which run index to execute.
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();

    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());

    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());

    long start = System.currentTimeMillis();
    Graph out;

    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();

        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }

        if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
            external.setSimulation(simulationWrapper.getSimulation());
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }

        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Hand the algorithm a random subset of at most randomSelectionSize data models.
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulationWrapper.getSimulation().getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);

            List<DataModel> dataModels = new ArrayList<>();
            int randomSelectionSize = algorithmWrapper.getAlgorithmSpecificParameters().getInt("randomSelectionSize");
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulationWrapper.getSimulation().getDataModel(indices.get(i)));
            }

            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }

    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;

    long stop = System.currentTimeMillis();

    long elapsed = stop - start;

    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);

    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }

    // External algorithms report their own elapsed time, replacing the wall-clock measurement.
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }

    Graph[] est = new Graph[numGraphTypes];

    // Reference graph the estimate is scored against. Without a true graph there is
    // nothing to compare to, so it stays null and statistics are skipped below.
    Graph referenceGraph;

    if (trueGraph == null) {
        referenceGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        referenceGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        referenceGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        referenceGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }

    est[0] = out;
    graphTypeUsed[0] = true;

    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);

        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }

    Graph[] truth = new Graph[numGraphTypes];

    truth[0] = referenceGraph;

    if (data.isMixed() && referenceGraph != null) {
        truth[1] = getSubgraph(referenceGraph, true, true, data);
        truth[2] = getSubgraph(referenceGraph, true, false, data);
        truth[3] = getSubgraph(referenceGraph, false, false, data);
    }

    if (referenceGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            if (!graphTypeUsed[u]) {
                continue;
            }

            int statIndex = -1;

            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;

                // Parameter columns are not computed here.
                if (_stat instanceof ParameterColumn) {
                    continue;
                }

                double stat;

                if (_stat instanceof ElapsedTime) {
                    // Milliseconds to seconds.
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }

                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil:
the configuration method of class TimeoutComparison.
/**
 * Writes a catalog of the available, non-experimental algorithms, statistics,
 * independence tests, scores and simulations, with their descriptions and
 * parameters, to {@code Configuration.txt} in the given directory.
 *
 * <p>Algorithms are listed in three groups according to their public
 * constructors: those taking an {@code IndependenceWrapper} (instantiated with an
 * example FisherZ test), those taking a {@code ScoreWrapper} (instantiated with an
 * example BDeu score), and those with a no-argument constructor.
 *
 * <p>This is best-effort: any failure is printed as a stack trace, not thrown.
 *
 * @param path Directory in which Configuration.txt is written; created if missing.
 */
public void configuration(String path) {
    try {
        new File(path).mkdirs();

        // try-with-resources so the file is closed even if a reflective instantiation fails.
        try (PrintStream out = new PrintStream(new FileOutputStream(new File(path, "Configuration.txt")))) {
            Parameters allParams = new Parameters();

            List<Class> algorithms = new ArrayList<>();
            List<Class> statistics = new ArrayList<>();
            List<Class> independenceWrappers = new ArrayList<>();
            List<Class> scoreWrappers = new ArrayList<>();
            List<Class> simulations = new ArrayList<>();

            algorithms.addAll(getClasses(Algorithm.class));
            statistics.addAll(getClasses(Statistic.class));
            independenceWrappers.addAll(getClasses(IndependenceWrapper.class));
            scoreWrappers.addAll(getClasses(ScoreWrapper.class));
            simulations.addAll(getClasses(Simulation.class));

            out.println("Available Algorithms:");
            out.println();
            out.println("Algorithms that take an independence test (using an example independence test):");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 1 && constructor.getParameterTypes()[0] == IndependenceWrapper.class) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new FisherZ());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }

                        if (TakesInitialGraph.class.isAssignableFrom(clazz)) {
                            out.println("\t" + clazz.getSimpleName() + " can take an initial graph from some other algorithm as input");
                        }
                    }
                }
            }

            out.println();
            out.println("Algorithms that take a score (using an example score):");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 1 && constructor.getParameterTypes()[0] == ScoreWrapper.class) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new BdeuScore());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Algorithms with blank constructor:");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Statistics:");
            out.println();

            for (Class clazz : statistics) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        Statistic statistic = (Statistic) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + statistic.getDescription());
                    }
                }
            }

            out.println();
            out.println("Available Independence Tests:");
            out.println();

            for (Class clazz : independenceWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        IndependenceWrapper independence = (IndependenceWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + independence.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(independence.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Scores:");
            out.println();

            for (Class clazz : scoreWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        ScoreWrapper score = (ScoreWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + score.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(score.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Simulations:");
            out.println();

            for (Class clazz : simulations) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        Simulation simulation = (Simulation) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + simulation.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(simulation.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
        }
    } catch (Exception e) {
        // Best-effort report: failures are printed rather than propagated.
        e.printStackTrace();
    }
}
Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil:
the configuration method of class Comparison.
/**
 * Writes a catalog of the available, non-experimental algorithms, statistics,
 * independence tests, scores and simulations, with their descriptions and
 * parameters, to {@code Configuration.txt} in the given directory.
 *
 * <p>Algorithms are listed in three groups according to their public
 * constructors: those taking an {@code IndependenceWrapper} (instantiated with an
 * example FisherZ test), those taking a {@code ScoreWrapper} (instantiated with an
 * example BDeu score), and those with a no-argument constructor.
 *
 * <p>This is best-effort: any failure is printed as a stack trace, not thrown.
 *
 * @param path Directory in which Configuration.txt is written; created if missing.
 */
public void configuration(String path) {
    try {
        new File(path).mkdirs();

        // try-with-resources so the file is closed even if a reflective instantiation fails.
        try (PrintStream out = new PrintStream(new FileOutputStream(new File(path, "Configuration.txt")))) {
            Parameters allParams = new Parameters();

            List<Class> algorithms = new ArrayList<>();
            List<Class> statistics = new ArrayList<>();
            List<Class> independenceWrappers = new ArrayList<>();
            List<Class> scoreWrappers = new ArrayList<>();
            List<Class> simulations = new ArrayList<>();

            algorithms.addAll(getClasses(Algorithm.class));
            statistics.addAll(getClasses(Statistic.class));
            independenceWrappers.addAll(getClasses(IndependenceWrapper.class));
            scoreWrappers.addAll(getClasses(ScoreWrapper.class));
            simulations.addAll(getClasses(Simulation.class));

            out.println("Available Algorithms:");
            out.println();
            out.println("Algorithms that take an independence test (using an example independence test):");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 1 && constructor.getParameterTypes()[0] == IndependenceWrapper.class) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new FisherZ());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }

                        if (TakesInitialGraph.class.isAssignableFrom(clazz)) {
                            out.println("\t" + clazz.getSimpleName() + " can take an initial graph from some other algorithm as input");
                        }
                    }
                }
            }

            out.println();
            out.println("Algorithms that take a score (using an example score):");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 1 && constructor.getParameterTypes()[0] == ScoreWrapper.class) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new BdeuScore());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Algorithms with blank constructor:");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Statistics:");
            out.println();

            for (Class clazz : statistics) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        Statistic statistic = (Statistic) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + statistic.getDescription());
                    }
                }
            }

            out.println();
            out.println("Available Independence Tests:");
            out.println();

            for (Class clazz : independenceWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        IndependenceWrapper independence = (IndependenceWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + independence.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(independence.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Scores:");
            out.println();

            for (Class clazz : scoreWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        ScoreWrapper score = (ScoreWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + score.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(score.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Simulations:");
            out.println();

            for (Class clazz : simulations) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                for (Constructor constructor : clazz.getConstructors()) {
                    if (constructor.getParameterTypes().length == 0) {
                        Simulation simulation = (Simulation) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + simulation.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(simulation.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
        }
    } catch (Exception e) {
        // Best-effort report: failures are printed rather than propagated.
        e.printStackTrace();
    }
}
Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil:
the doRun method of class Comparison.
/**
 * Runs one algorithm on one simulated data set and stores the resulting
 * statistics in {@code allStats}.
 *
 * <p>The estimated graph is saved via {@code saveGraph} and then scored against a
 * reference graph derived from the simulation's true graph, according to the
 * {@code comparisonGraph} setting. For mixed data the estimate and the reference
 * are also scored on three subgraphs (graph-type indices 1-3).
 *
 * <p>If the algorithm throws, the error is printed and nothing is recorded. If the
 * simulation supplies no true graph, the estimate is still saved but no
 * statistics are recorded.
 *
 * @param algorithmSimulationWrappers All algorithm/simulation pairings; the one at
 *                                    {@code run.getAlgSimIndex()} is executed.
 * @param algorithmWrappers           All algorithm wrappers (used for the saved-file index).
 * @param simulationWrappers          All simulation wrappers (used for the saved-file index
 *                                    and the external-algorithm simulation index).
 * @param statistics                  The statistics to compute.
 * @param numGraphTypes               Number of graph types to score.
 * @param allStats                    Output, indexed {@code [graphType][algSim][stat][run]}.
 * @param run                         Which pairing and which run index to execute.
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();

    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());

    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());

    long start = System.currentTimeMillis();
    Graph out;

    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();

        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }

        if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
            external.setSimulation(simulationWrapper.getSimulation());
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }

        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Hand the algorithm a random subset of at most randomSelectionSize data models.
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulationWrapper.getSimulation().getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);

            List<DataModel> dataModels = new ArrayList<>();
            int randomSelectionSize = algorithmWrapper.getAlgorithmSpecificParameters().getInt("randomSelectionSize");
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulationWrapper.getSimulation().getDataModel(indices.get(i)));
            }

            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }

    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;

    long stop = System.currentTimeMillis();

    long elapsed = stop - start;

    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);

    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }

    // External algorithms report their own elapsed time, replacing the wall-clock measurement.
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }

    Graph[] est = new Graph[numGraphTypes];

    // Reference graph the estimate is scored against. Without a true graph there is
    // nothing to compare to, so it stays null and statistics are skipped below.
    Graph referenceGraph;

    if (trueGraph == null) {
        referenceGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        referenceGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        referenceGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        referenceGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }

    est[0] = out;
    graphTypeUsed[0] = true;

    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);

        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }

    Graph[] truth = new Graph[numGraphTypes];

    truth[0] = referenceGraph;

    if (data.isMixed() && referenceGraph != null) {
        truth[1] = getSubgraph(referenceGraph, true, true, data);
        truth[2] = getSubgraph(referenceGraph, true, false, data);
        truth[3] = getSubgraph(referenceGraph, false, false, data);
    }

    if (referenceGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            if (!graphTypeUsed[u]) {
                continue;
            }

            int statIndex = -1;

            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;

                // Parameter columns are not computed here.
                if (_stat instanceof ParameterColumn) {
                    continue;
                }

                double stat;

                if (_stat instanceof ElapsedTime) {
                    // Milliseconds to seconds.
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }

                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Aggregations