Use of edu.cmu.tetrad.algcomparison.statistic.ParameterColumn in the tetrad project by cmu-phil.
Class TimeoutComparison, method calcStatTables.
/**
 * Collapses the per-run values in {@code allStats} into a single value per
 * (table, wrapper, statistic) cell.
 *
 * <p>For a statistic that is a {@link ParameterColumn}, the cell holds the value of the
 * parameter whose name equals the statistic's abbreviation (see
 * {@link #parameterColumnValue}). For every other statistic the runs stored in
 * {@code allStats[table][wrapper][statistic]} are reduced according to {@code mode}:
 * mean for {@code Mode.Average}, minimum for {@code Mode.WorstCase}, standard deviation
 * for {@code Mode.StandardDeviation}.
 *
 * @param allStats   raw values indexed as [table][wrapper][statistic][run]
 * @param mode       how the runs of a non-parameter statistic are aggregated
 * @param numTables  number of tables to fill
 * @param wrappers   algorithm/simulation pairs; one row per element
 * @param numStats   number of statistic columns to fill
 * @param statistics supplies the statistic for each column index
 * @return an array of shape [numTables][wrappers.size()][numStats + 1]; the final column
 *         of each row is not written here and keeps the default 0.0
 * @throws IllegalStateException if a non-parameter statistic is encountered and
 *                               {@code mode} is none of the three modes handled above
 */
private double[][][] calcStatTables(double[][][][] allStats, Mode mode, int numTables, List<AlgorithmSimulationWrapper> wrappers, int numStats, Statistics statistics) {
    double[][][] statTables = new double[numTables][wrappers.size()][numStats + 1];

    for (int u = 0; u < numTables; u++) {
        for (int i = 0; i < wrappers.size(); i++) {
            for (int j = 0; j < numStats; j++) {
                // Fetch the statistic once per cell instead of once per use.
                Statistic statistic = statistics.getStatistics().get(j);

                if (statistic instanceof ParameterColumn) {
                    statTables[u][i][j] = parameterColumnValue(wrappers.get(i), statistic.getAbbreviation());
                } else if (mode == Mode.Average) {
                    statTables[u][i][j] = StatUtils.mean(allStats[u][i][j]);
                } else if (mode == Mode.WorstCase) {
                    statTables[u][i][j] = StatUtils.min(allStats[u][i][j]);
                } else if (mode == Mode.StandardDeviation) {
                    statTables[u][i][j] = StatUtils.sd(allStats[u][i][j]);
                } else {
                    throw new IllegalStateException("Unsupported mode: " + mode);
                }
            }
        }
    }

    return statTables;
}

/**
 * Looks up the parameter named {@code statName}, first among the simulation's parameters
 * and, if that yields NaN (including "not listed"), among the algorithm's parameters.
 *
 * @return the parameter as a double, or NaN when neither wrapper lists the name
 */
private double parameterColumnValue(AlgorithmSimulationWrapper wrapper, String statName) {
    SimulationWrapper simulationWrapper = wrapper.getSimulationWrapper();
    AlgorithmWrapper algorithmWrapper = wrapper.getAlgorithmWrapper();

    double stat = Double.NaN;

    List<String> simulationParameterNames = simulationWrapper.getParameters();
    Parameters simulationParameters = simulationWrapper.getSimulationSpecificParameters();
    if (simulationParameterNames.contains(statName)) {
        stat = parameterAsDouble(simulationParameters, statName);
    }

    // Only consult the algorithm when the simulation produced no usable number.
    if (Double.isNaN(stat)) {
        List<String> algorithmParameterNames = algorithmWrapper.getParameters();
        Parameters algorithmParameters = algorithmWrapper.parameters;
        if (algorithmParameterNames.contains(statName)) {
            stat = parameterAsDouble(algorithmParameters, statName);
        }
    }

    return stat;
}

/**
 * Reads one parameter as a double. Boolean parameters are encoded as positive infinity
 * (true) or negative infinity (false); the type is checked up front rather than
 * discovered by catching the failure of {@code getDouble}.
 */
private double parameterAsDouble(Parameters source, String name) {
    if (source.get(name) instanceof Boolean) {
        return source.getBoolean(name) ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
    }
    return source.getDouble(name);
}
Use of edu.cmu.tetrad.algcomparison.statistic.ParameterColumn in the tetrad project by cmu-phil.
Class Comparison, method calcStatTables.
/**
 * Reduces the raw per-run statistics to one number per (table, wrapper, statistic) cell.
 *
 * <p>Columns whose statistic is a {@link ParameterColumn} report the parameter whose name
 * matches the statistic's abbreviation (see {@link #parameterColumnValue}). All other
 * columns aggregate the runs in {@code allStats[table][wrapper][statistic]}: the mean for
 * {@code Mode.Average}, the minimum for {@code Mode.WorstCase}, and the standard deviation
 * for {@code Mode.StandardDeviation}.
 *
 * @param allStats   raw values indexed as [table][wrapper][statistic][run]
 * @param mode       aggregation applied to non-parameter statistics
 * @param numTables  number of tables to fill
 * @param wrappers   algorithm/simulation pairs; one row per element
 * @param numStats   number of statistic columns to fill
 * @param statistics supplies the statistic for each column index
 * @return an array of shape [numTables][wrappers.size()][numStats + 1]; the last column of
 *         every row is left at its default 0.0 by this method
 * @throws IllegalStateException if a non-parameter statistic is encountered and
 *                               {@code mode} is not one of the three handled modes
 */
private double[][][] calcStatTables(double[][][][] allStats, Mode mode, int numTables, List<AlgorithmSimulationWrapper> wrappers, int numStats, Statistics statistics) {
    double[][][] statTables = new double[numTables][wrappers.size()][numStats + 1];

    for (int u = 0; u < numTables; u++) {
        for (int i = 0; i < wrappers.size(); i++) {
            for (int j = 0; j < numStats; j++) {
                // Resolve the column's statistic a single time for this cell.
                Statistic statistic = statistics.getStatistics().get(j);
                double cell;

                if (statistic instanceof ParameterColumn) {
                    cell = parameterColumnValue(wrappers.get(i), statistic.getAbbreviation());
                } else if (mode == Mode.Average) {
                    cell = StatUtils.mean(allStats[u][i][j]);
                } else if (mode == Mode.WorstCase) {
                    cell = StatUtils.min(allStats[u][i][j]);
                } else if (mode == Mode.StandardDeviation) {
                    cell = StatUtils.sd(allStats[u][i][j]);
                } else {
                    throw new IllegalStateException("Unsupported mode: " + mode);
                }

                statTables[u][i][j] = cell;
            }
        }
    }

    return statTables;
}

/**
 * Finds the value of the parameter named {@code statName}. The simulation's parameters
 * are tried first; the algorithm's parameters are tried only if the result so far is NaN.
 *
 * @return the parameter as a double, or NaN when the name is listed by neither wrapper
 */
private double parameterColumnValue(AlgorithmSimulationWrapper wrapper, String statName) {
    SimulationWrapper simulationWrapper = wrapper.getSimulationWrapper();
    AlgorithmWrapper algorithmWrapper = wrapper.getAlgorithmWrapper();

    double stat = Double.NaN;

    List<String> simulationParameterNames = simulationWrapper.getParameters();
    Parameters simulationParameters = simulationWrapper.getSimulationSpecificParameters();
    if (simulationParameterNames.contains(statName)) {
        stat = parameterAsDouble(simulationParameters, statName);
    }

    if (Double.isNaN(stat)) {
        List<String> algorithmParameterNames = algorithmWrapper.getParameters();
        Parameters algorithmParameters = algorithmWrapper.parameters;
        if (algorithmParameterNames.contains(statName)) {
            stat = parameterAsDouble(algorithmParameters, statName);
        }
    }

    return stat;
}

/**
 * Converts a single parameter to a double. A Boolean is mapped to positive infinity when
 * true and negative infinity when false; this is decided by inspecting the stored value
 * instead of catching an exception from {@code getDouble}.
 */
private double parameterAsDouble(Parameters source, String name) {
    if (source.get(name) instanceof Boolean) {
        return source.getBoolean(name) ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
    }
    return source.getDouble(name);
}
Use of edu.cmu.tetrad.algcomparison.statistic.ParameterColumn in the tetrad project by cmu-phil.
Class TimeoutComparison, method doRun.
/**
 * Executes one run: searches with the algorithm selected by {@code run}, saves the
 * resulting graph, and records every non-parameter statistic into {@code allStats}.
 *
 * <p>If the search throws, a message and the stack trace are printed and the method
 * returns without saving a graph or recording statistics. If the simulation provides no
 * true graph, the result is still saved but no statistics are recorded.
 *
 * @param algorithmSimulationWrappers all algorithm/simulation pairs; indexed by
 *                                    {@code run.getAlgSimIndex()}
 * @param algorithmWrappers           used to derive the 1-based algorithm index passed to
 *                                    {@code saveGraph}
 * @param simulationWrappers          used to derive the simulation index
 * @param statistics                  the statistics to evaluate
 * @param numGraphTypes               length of the estimated/true graph arrays; indices
 *                                    1 to 3 are written when the data is mixed
 * @param allStats                    output, written at
 *                                    [graphType][algSimIndex][statistic][runIndex]
 * @param run                         identifies the pair and the run index
 * @throws IllegalArgumentException if the configured comparison graph type is not one of
 *                                  the three recognised values
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();

    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());

    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());

    // Zero-based position of this simulation; looked up once and reused below.
    int simulationPosition = simulationWrappers.indexOf(simulationWrapper);

    long start = System.currentTimeMillis();
    Graph out;

    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();

        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }

        if (algorithm instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithm;
            external.setSimulation(simulation);
            external.setPath(resultsPath);
            external.setSimIndex(simulationPosition);
        }

        Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();

        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Pick up to "randomSelectionSize" data models in random order.
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulation.getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);

            int selectionSize = Math.min(numDataModels, _params.getInt("randomSelectionSize"));
            List<DataModel> dataModels = new ArrayList<>();
            for (int i = 0; i < selectionSize; i++) {
                dataModels.add(simulation.getDataModel(indices.get(i)));
            }

            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }

    int simIndex = simulationPosition + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    long stop = System.currentTimeMillis();
    long elapsed = stop - start;

    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);

    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }

    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        // External algorithms report their own elapsed time, replacing the measured one.
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationPosition);
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }

    Graph[] est = new Graph[numGraphTypes];

    // The configured type is validated first so an unsupported setting is always reported.
    // A missing true graph then yields a null comparison graph, which makes the null
    // checks further down effective instead of handing null to the graph constructors.
    Graph comparisonGraph;
    if (this.comparisonGraph != ComparisonGraph.true_DAG
            && this.comparisonGraph != ComparisonGraph.Pattern_of_the_true_DAG
            && this.comparisonGraph != ComparisonGraph.PAG_of_the_true_DAG) {
        throw new IllegalArgumentException("Unrecognized graph type.");
    } else if (trueGraph == null) {
        comparisonGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparisonGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparisonGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else {
        comparisonGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    }

    est[0] = out;
    graphTypeUsed[0] = true;

    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }

    if (comparisonGraph == null) {
        // Nothing to compare against, so no statistics can be recorded for this run.
        return;
    }

    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;

    if (data.isMixed()) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }

    List<Statistic> statList = statistics.getStatistics();

    for (int u = 0; u < numGraphTypes; u++) {
        if (!graphTypeUsed[u]) {
            continue;
        }

        for (int statIndex = 0; statIndex < statList.size(); statIndex++) {
            Statistic _stat = statList.get(statIndex);

            // Parameter columns are not computed per run; their slot is left untouched.
            if (_stat instanceof ParameterColumn) {
                continue;
            }

            double stat;
            if (_stat instanceof ElapsedTime) {
                stat = elapsed / 1000.0; // milliseconds to seconds
            } else {
                stat = _stat.getValue(truth[u], est[u]);
            }

            allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
        }
    }
}
Use of edu.cmu.tetrad.algcomparison.statistic.ParameterColumn in the tetrad project by cmu-phil.
Class Comparison, method doRun.
/**
 * Performs a single run: runs the search for the algorithm/simulation pair named by
 * {@code run}, saves the output graph, and stores each non-parameter statistic in
 * {@code allStats}.
 *
 * <p>When the search throws, the failure is printed (message plus stack trace) and the
 * method returns before anything is saved or recorded. When the simulation has no true
 * graph, the output graph is still saved but no statistics are recorded.
 *
 * @param algorithmSimulationWrappers all algorithm/simulation pairs; indexed by
 *                                    {@code run.getAlgSimIndex()}
 * @param algorithmWrappers           used to derive the 1-based algorithm index passed to
 *                                    {@code saveGraph}
 * @param simulationWrappers          used to derive the simulation index
 * @param statistics                  the statistics to evaluate
 * @param numGraphTypes               length of the estimated/true graph arrays; indices
 *                                    1 to 3 are written when the data is mixed
 * @param allStats                    output, written at
 *                                    [graphType][algSimIndex][statistic][runIndex]
 * @param run                         identifies the pair and the run index
 * @throws IllegalArgumentException if the configured comparison graph type is not one of
 *                                  the three recognised values
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();

    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());

    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());

    // Zero-based index of the simulation, resolved once rather than at every use.
    int simulationPosition = simulationWrappers.indexOf(simulationWrapper);

    long start = System.currentTimeMillis();
    Graph out;

    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();

        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }

        if (algorithm instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithm;
            external.setSimulation(simulation);
            external.setPath(resultsPath);
            external.setSimIndex(simulationPosition);
        }

        Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();

        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Shuffle all data-model indices and keep the first "randomSelectionSize".
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulation.getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);

            int selectionSize = Math.min(numDataModels, _params.getInt("randomSelectionSize"));
            List<DataModel> dataModels = new ArrayList<>();
            for (int i = 0; i < selectionSize; i++) {
                dataModels.add(simulation.getDataModel(indices.get(i)));
            }

            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }

    int simIndex = simulationPosition + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    long stop = System.currentTimeMillis();
    long elapsed = stop - start;

    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);

    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }

    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        // The externally reported time overrides the locally measured one.
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationPosition);
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }

    Graph[] est = new Graph[numGraphTypes];

    // Reject an unsupported comparison type before anything else. After that, a missing
    // true graph produces a null comparison graph so the null checks below take effect
    // rather than null being passed into the graph constructors.
    Graph comparisonGraph;
    if (this.comparisonGraph != ComparisonGraph.true_DAG
            && this.comparisonGraph != ComparisonGraph.Pattern_of_the_true_DAG
            && this.comparisonGraph != ComparisonGraph.PAG_of_the_true_DAG) {
        throw new IllegalArgumentException("Unrecognized graph type.");
    } else if (trueGraph == null) {
        comparisonGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparisonGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparisonGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else {
        comparisonGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    }

    est[0] = out;
    graphTypeUsed[0] = true;

    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }

    if (comparisonGraph == null) {
        // Without a comparison graph there is nothing to score.
        return;
    }

    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;

    if (data.isMixed()) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }

    List<Statistic> statList = statistics.getStatistics();

    for (int u = 0; u < numGraphTypes; u++) {
        if (!graphTypeUsed[u]) {
            continue;
        }

        for (int statIndex = 0; statIndex < statList.size(); statIndex++) {
            Statistic _stat = statList.get(statIndex);

            // Parameter columns are skipped here; their slot in allStats is not written.
            if (_stat instanceof ParameterColumn) {
                continue;
            }

            double stat;
            if (_stat instanceof ElapsedTime) {
                stat = elapsed / 1000.0; // milliseconds to seconds
            } else {
                stat = _stat.getValue(truth[u], est[u]);
            }

            allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
        }
    }
}
Aggregations