Use of edu.cmu.tetrad.algcomparison.statistic.Statistic in project tetrad by cmu-phil.
Class TimeoutComparison, method configuration.
/**
 * Writes a {@code Configuration.txt} file into the directory {@code path}, creating the
 * directory first if necessary.
 *
 * <p>The file lists, in this order: algorithms with a single-argument
 * {@code IndependenceWrapper} constructor, algorithms with a single-argument
 * {@code ScoreWrapper} constructor, algorithms with a no-argument constructor, statistics,
 * independence tests, scores and simulations. Classes assignable to {@code Experimental}
 * are skipped. For each listed class an instance is created reflectively and its simple
 * name and description are printed; for classes assignable to {@code HasParameters} the
 * instance's parameters are additionally passed to {@code printParameters}.
 *
 * <p>Any exception is printed to standard error and otherwise ignored (best effort); the
 * output stream is closed in every case.
 *
 * @param path directory in which {@code Configuration.txt} is written
 */
public void configuration(String path) {
    try {
        new File(path).mkdirs();

        // try-with-resources: the stream is closed even if reflection fails part-way through.
        try (PrintStream out = new PrintStream(new FileOutputStream(new File(path, "Configuration.txt")))) {
            Parameters allParams = new Parameters();

            List<Class> algorithms = new ArrayList<>();
            List<Class> statistics = new ArrayList<>();
            List<Class> independenceWrappers = new ArrayList<>();
            List<Class> scoreWrappers = new ArrayList<>();
            List<Class> simulations = new ArrayList<>();

            algorithms.addAll(getClasses(Algorithm.class));
            statistics.addAll(getClasses(Statistic.class));
            independenceWrappers.addAll(getClasses(IndependenceWrapper.class));
            scoreWrappers.addAll(getClasses(ScoreWrapper.class));
            simulations.addAll(getClasses(Simulation.class));

            out.println("Available Algorithms:");
            out.println();
            out.println("Algorithms that take an independence test (using an example independence test):");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 1
                            && constructor.getParameterTypes()[0] == IndependenceWrapper.class) {
                        // FisherZ serves only as an example test so the algorithm can be instantiated.
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new FisherZ());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }

                        if (TakesInitialGraph.class.isAssignableFrom(clazz)) {
                            out.println("\t" + clazz.getSimpleName() + " can take an initial graph from some other algorithm as input");
                        }
                    }
                }
            }

            out.println();
            out.println("Algorithms that take a score (using an example score):");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 1
                            && constructor.getParameterTypes()[0] == ScoreWrapper.class) {
                        // BdeuScore serves only as an example score so the algorithm can be instantiated.
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new BdeuScore());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Algorithms with blank constructor:");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Statistics:");
            out.println();

            for (Class clazz : statistics) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Statistic statistic = (Statistic) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + statistic.getDescription());
                    }
                }
            }

            out.println();
            out.println("Available Independence Tests:");
            out.println();

            for (Class clazz : independenceWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        IndependenceWrapper independence = (IndependenceWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + independence.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(independence.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Scores:");
            out.println();

            for (Class clazz : scoreWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        ScoreWrapper score = (ScoreWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + score.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(score.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Simulations:");
            out.println();

            for (Class clazz : simulations) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Simulation simulation = (Simulation) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + simulation.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(simulation.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
        }
    } catch (Exception e) {
        // Best effort: a failure to produce the listing is reported but not propagated.
        e.printStackTrace();
    }
}
Use of edu.cmu.tetrad.algcomparison.statistic.Statistic in project tetrad by cmu-phil.
Class Comparison, method configuration.
/**
 * Writes a {@code Configuration.txt} file into the directory {@code path}, creating the
 * directory first if necessary.
 *
 * <p>The file lists, in this order: algorithms with a single-argument
 * {@code IndependenceWrapper} constructor, algorithms with a single-argument
 * {@code ScoreWrapper} constructor, algorithms with a no-argument constructor, statistics,
 * independence tests, scores and simulations. Classes assignable to {@code Experimental}
 * are skipped. For each listed class an instance is created reflectively and its simple
 * name and description are printed; for classes assignable to {@code HasParameters} the
 * instance's parameters are additionally passed to {@code printParameters}.
 *
 * <p>Any exception is printed to standard error and otherwise ignored (best effort); the
 * output stream is closed in every case.
 *
 * @param path directory in which {@code Configuration.txt} is written
 */
public void configuration(String path) {
    try {
        new File(path).mkdirs();

        // try-with-resources: the stream is closed even if reflection fails part-way through.
        try (PrintStream out = new PrintStream(new FileOutputStream(new File(path, "Configuration.txt")))) {
            Parameters allParams = new Parameters();

            List<Class> algorithms = new ArrayList<>();
            List<Class> statistics = new ArrayList<>();
            List<Class> independenceWrappers = new ArrayList<>();
            List<Class> scoreWrappers = new ArrayList<>();
            List<Class> simulations = new ArrayList<>();

            algorithms.addAll(getClasses(Algorithm.class));
            statistics.addAll(getClasses(Statistic.class));
            independenceWrappers.addAll(getClasses(IndependenceWrapper.class));
            scoreWrappers.addAll(getClasses(ScoreWrapper.class));
            simulations.addAll(getClasses(Simulation.class));

            out.println("Available Algorithms:");
            out.println();
            out.println("Algorithms that take an independence test (using an example independence test):");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 1
                            && constructor.getParameterTypes()[0] == IndependenceWrapper.class) {
                        // FisherZ serves only as an example test so the algorithm can be instantiated.
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new FisherZ());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }

                        if (TakesInitialGraph.class.isAssignableFrom(clazz)) {
                            out.println("\t" + clazz.getSimpleName() + " can take an initial graph from some other algorithm as input");
                        }
                    }
                }
            }

            out.println();
            out.println("Algorithms that take a score (using an example score):");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 1
                            && constructor.getParameterTypes()[0] == ScoreWrapper.class) {
                        // BdeuScore serves only as an example score so the algorithm can be instantiated.
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new BdeuScore());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Algorithms with blank constructor:");
            out.println();

            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Statistics:");
            out.println();

            for (Class clazz : statistics) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Statistic statistic = (Statistic) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + statistic.getDescription());
                    }
                }
            }

            out.println();
            out.println("Available Independence Tests:");
            out.println();

            for (Class clazz : independenceWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        IndependenceWrapper independence = (IndependenceWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + independence.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(independence.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Scores:");
            out.println();

            for (Class clazz : scoreWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        ScoreWrapper score = (ScoreWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + score.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(score.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
            out.println("Available Simulations:");
            out.println();

            for (Class clazz : simulations) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }

                Constructor[] constructors = clazz.getConstructors();

                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Simulation simulation = (Simulation) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + simulation.getDescription());

                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(simulation.getParameters(), allParams, out);
                        }
                    }
                }
            }

            out.println();
        }
    } catch (Exception e) {
        // Best effort: a failure to produce the listing is reported but not propagated.
        e.printStackTrace();
    }
}
Use of edu.cmu.tetrad.algcomparison.statistic.Statistic in project tetrad by cmu-phil.
Class Comparison, method doRun.
/**
 * Executes a single run: picks the algorithm/simulation pair at {@code run.getAlgSimIndex()},
 * searches on the data model at {@code run.getRunIndex()}, saves the resulting graph via
 * {@code saveGraph}, and stores one value per statistic in {@code allStats}.
 *
 * <p>If the search throws, a message and the stack trace are printed and the method returns
 * without recording anything for this run. If the simulation supplies no true graph, no
 * comparison graph is built and no statistics are recorded.
 *
 * @param algorithmSimulationWrappers algorithm/simulation pairs, indexed by {@code run.getAlgSimIndex()}
 * @param algorithmWrappers           used to derive the 1-based algorithm index passed to {@code saveGraph}
 * @param simulationWrappers          used to derive the simulation index
 * @param statistics                  the statistics to evaluate
 * @param numGraphTypes               length of the estimated/truth graph arrays
 * @param allStats                    output array, written at
 *                                    {@code [graphType][algSimIndex][statIndex][runIndex]}
 * @param run                         identifies the run index and the algorithm/simulation pair index
 * @throws IllegalArgumentException if a true graph exists and the configured comparison graph
 *                                  type is not recognized
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();

    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());

    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());

    long start = System.currentTimeMillis();
    Graph out;

    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();

        // Hand the simulation's knowledge to the algorithm when both sides support it.
        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }

        if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
            external.setSimulation(simulationWrapper.getSimulation());
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }

        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Search over a random selection of at most "randomSelectionSize" data models.
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulationWrapper.getSimulation().getNumDataModels();

            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }

            Collections.shuffle(indices);

            List<DataModel> dataModels = new ArrayList<>();
            int randomSelectionSize = algorithmWrapper.getAlgorithmSpecificParameters().getInt("randomSelectionSize");

            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulationWrapper.getSimulation().getDataModel(indices.get(i)));
            }

            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }

    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;

    long stop = System.currentTimeMillis();
    long elapsed = stop - start;

    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);

    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }

    // For external algorithms the elapsed time is taken from the algorithm itself
    // rather than from the wall-clock measurement above.
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }

    Graph[] est = new Graph[numGraphTypes];

    // The surrounding code treats a missing true graph as legal (see the null checks above and
    // below), so never hand a null graph to the graph constructors; leave comparisonGraph null
    // instead, which makes the statistics section below a no-op.
    final boolean hasTrueGraph = trueGraph != null;
    Graph comparisonGraph;

    if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparisonGraph = hasTrueGraph ? new EdgeListGraph(trueGraph) : null;
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparisonGraph = hasTrueGraph ? SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph)) : null;
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        comparisonGraph = hasTrueGraph ? new DagToPag(new EdgeListGraph(trueGraph)).convert() : null;
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }

    est[0] = out;
    graphTypeUsed[0] = true;

    // Mixed data additionally gets three subgraph views, stored at indices 1..3.
    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);

        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }

    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;

    if (data.isMixed() && comparisonGraph != null) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }

    if (comparisonGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            if (!graphTypeUsed[u]) {
                continue;
            }

            int statIndex = -1;

            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;

                // Parameter columns carry no computed value; their slot in allStats is left untouched.
                if (_stat instanceof ParameterColumn) {
                    continue;
                }

                double stat;

                if (_stat instanceof ElapsedTime) {
                    // elapsed is in milliseconds; report seconds.
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }

                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Use of edu.cmu.tetrad.algcomparison.statistic.Statistic in project tetrad by cmu-phil.
Class Comparison, method calcUtilities.
/**
 * Computes one utility per entry of {@code wrappers} as the weighted mean of the normalized
 * statistic values: {@code sum(weight * normValue) / sum(weight)}, taken over all statistics
 * whose weight is not zero.
 *
 * <p>If every statistic has weight {@code 0.0} for a row, the result for that row is
 * {@code 0.0 / 0.0}, i.e. {@code NaN}.
 *
 * @param statistics the statistics and their weights
 * @param wrappers   the algorithm/simulation pairs; only its size is used
 * @param stats      statistic values, indexed {@code [wrapper index][statistic index]}
 * @return the utilities, one per wrapper, in the order of {@code wrappers}
 */
private double[] calcUtilities(Statistics statistics, List<AlgorithmSimulationWrapper> wrappers, double[][] stats) {
    double[] utilities = new double[wrappers.size()];

    for (int t = 0; t < wrappers.size(); t++) {
        double sum = 0.0;
        double max = 0.0;
        int j = 0; // position of the current statistic within stats[t]

        for (Statistic stat : statistics.getStatistics()) {
            double weight = statistics.getWeight(stat);

            if (weight != 0.0) {
                sum += weight * stat.getNormValue(stats[t][j]);
                max += weight;
            }

            j++;
        }

        utilities[t] = sum / max;
    }

    return utilities;
}
Use of edu.cmu.tetrad.algcomparison.statistic.Statistic in project tetrad by cmu-phil.
Class Comparison, method printStats.
/**
 * Prints one text table per used graph type to {@code out}, preceded by a title that
 * depends on {@code mode}.
 *
 * <p>Each table has a header row followed by one row per algorithm/simulation pair, in the
 * order given by {@code newOrder}. Optional leading columns show the 1-based simulation and
 * algorithm indices, followed by one column per statistic and an optional trailing utility
 * column. A statistic cell shows the parameter value when {@code parameters} holds exactly
 * one String value under the statistic's abbreviation; otherwise it shows the numeric value,
 * with {@code 0} rendered as "-", positive/negative infinity as "Yes"/"No" and NaN as "*".
 *
 * @param statTables                  values indexed {@code [graph type][pair index][statistic index]}
 * @param statistics                  the statistics that define the value columns
 * @param mode                        selects the title line
 * @param newOrder                    row order: row {@code t} shows pair {@code newOrder[t]}
 * @param algorithmSimulationWrappers the algorithm/simulation pairs
 * @param algorithmWrappers           used to look up the algorithm index of a pair
 * @param simulationWrappers          used to look up the simulation index of a pair
 * @param utilities                   utility per pair, shown when utilities are enabled
 * @param parameters                  parameter store; updated with the parameter values of
 *                                    each row's algorithm and simulation while printing
 * @throws IllegalStateException if {@code mode} is none of the handled modes
 */
private void printStats(double[][][] statTables, Statistics statistics, Mode mode, int[] newOrder, List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, double[] utilities, Parameters parameters) {
    if (mode == Mode.Average) {
        out.println("AVERAGE STATISTICS");
    } else if (mode == Mode.StandardDeviation) {
        out.println("STANDARD DEVIATIONS");
    } else if (mode == Mode.WorstCase) {
        out.println("WORST CASE");
    } else {
        throw new IllegalStateException();
    }

    final int tableCount = statTables.length;
    final int statCount = statistics.size();
    final NumberFormat plainFormat = new DecimalFormat("0.00");
    final NumberFormat scientificFormat = new DecimalFormat("0.00E0");

    out.println();

    for (int u = 0; u < tableCount; u++) {
        if (!graphTypeUsed[u]) {
            continue;
        }

        final int pairCount = algorithmSimulationWrappers.size();

        // Column count: optional index columns + statistics + optional utility column.
        int columnCount = 0;
        if (isShowSimulationIndices()) {
            columnCount++;
        }
        if (isShowAlgorithmIndices()) {
            columnCount++;
        }
        columnCount += statCount;
        if (isShowUtilities()) {
            columnCount++;
        }

        TextTable table = new TextTable(pairCount + 1, columnCount);
        table.setTabDelimited(isTabDelimitedTables());

        // Index of the first statistic column; advanced past each index column that is shown.
        int firstStatColumn = 0;

        if (isShowSimulationIndices()) {
            table.setToken(0, firstStatColumn, "Sim");

            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                int position = simulationWrappers.indexOf(
                        algorithmSimulationWrappers.get(newOrder[t]).getSimulationWrapper());
                table.setToken(t + 1, firstStatColumn, String.valueOf(position + 1));
            }

            firstStatColumn++;
        }

        if (isShowAlgorithmIndices()) {
            table.setToken(0, firstStatColumn, "Alg");

            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                int position = algorithmWrappers.indexOf(
                        algorithmSimulationWrappers.get(newOrder[t]).getAlgorithmWrapper());
                table.setToken(t + 1, firstStatColumn, String.valueOf(position + 1));
            }

            firstStatColumn++;
        }

        // Header row: statistic abbreviations, then the utility label.
        for (int statIndex = 0; statIndex < statCount; statIndex++) {
            table.setToken(0, firstStatColumn + statIndex,
                    statistics.getStatistics().get(statIndex).getAbbreviation());
        }

        if (isShowUtilities()) {
            table.setToken(0, firstStatColumn + statCount, "U");
        }

        // Body rows.
        for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
            final int row = t + 1;

            for (int statIndex = 0; statIndex < statCount; statIndex++) {
                final int column = firstStatColumn + statIndex;

                Statistic currentStat = statistics.getStatistics().get(statIndex);

                AlgorithmSimulationWrapper pair = algorithmSimulationWrappers.get(newOrder[t]);
                Algorithm rowAlgorithm = pair.getAlgorithmWrapper().getAlgorithm();
                Simulation rowSimulation = pair.getSimulationWrapper().getSimulation();

                if (rowAlgorithm instanceof HasParameterValues) {
                    parameters.putAll(((HasParameterValues) rowAlgorithm).getParameterValues());
                }

                if (rowSimulation instanceof HasParameterValues) {
                    parameters.putAll(((HasParameterValues) rowSimulation).getParameterValues());
                }

                // A single String value under the statistic's abbreviation is shown verbatim.
                Object[] paramValues = parameters.getValues(currentStat.getAbbreviation());

                if (paramValues.length == 1 && paramValues[0] instanceof String) {
                    table.setToken(row, column, (String) paramValues[0]);
                    continue;
                }

                double value = statTables[u][newOrder[t]][statIndex];
                String cell;

                if (value == 0.0) {
                    cell = "-";
                } else if (value == Double.POSITIVE_INFINITY) {
                    cell = "Yes";
                } else if (value == Double.NEGATIVE_INFINITY) {
                    cell = "No";
                } else if (Double.isNaN(value)) {
                    cell = "*";
                } else {
                    // Magnitudes too small for the plain format are shown in scientific notation.
                    double threshold = Math.pow(10, -scientificFormat.getMaximumFractionDigits());

                    if (Math.abs(value) < threshold && value != 0) {
                        cell = scientificFormat.format(value);
                    } else {
                        cell = plainFormat.format(value);
                    }
                }

                table.setToken(row, column, cell);
            }

            if (isShowUtilities()) {
                table.setToken(row, firstStatColumn + statCount, plainFormat.format(utilities[newOrder[t]]));
            }
        }

        out.println(getHeader(u));
        out.println();
        out.println(table);
    }
}
Aggregations