Use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
Class TestFges, method main:
/**
 * Command-line entry point.
 * <p>
 * With no arguments, runs {@code test9()}. With at least two integer arguments
 * {@code <numMeasures> <avgDegree>}, simulates one data set with
 * {@code LinearFisherModel} over a {@code RandomForward} graph and runs FGES with a
 * {@code FisherZScore} (alpha = 1e-8) on it five times, printing each resulting graph.
 * Arguments beyond the second are ignored.
 *
 * @param args either empty, or {@code numMeasures avgDegree}
 * @throws IllegalArgumentException if exactly one argument is given
 * @throws NumberFormatException    if either argument is not an integer
 */
public static void main(String... args) {
    if (args.length == 0) {
        new TestFges().test9();
        return;
    }

    // Both values are required; fail with a usage message instead of an index error.
    if (args.length < 2) {
        throw new IllegalArgumentException("Usage: TestFges <numMeasures> <avgDegree>");
    }

    int numMeasures = Integer.parseInt(args[0]);
    int avgDegree = Integer.parseInt(args[1]);

    Parameters parameters = new Parameters();
    parameters.set("numMeasures", numMeasures);
    parameters.set("numLatents", 0);
    parameters.set("avgDegree", avgDegree);
    parameters.set("maxDegree", 20);
    parameters.set("maxIndegree", 20);
    parameters.set("maxOutdegree", 20);
    parameters.set("connected", false);
    parameters.set("coefLow", 0.2);
    parameters.set("coefHigh", 0.9);
    parameters.set("varLow", 1);
    parameters.set("varHigh", 3);
    parameters.set("verbose", false);
    parameters.set("coefSymmetric", true);
    parameters.set("numRuns", 1);
    parameters.set("percentDiscrete", 0);
    parameters.set("numCategories", 3);
    parameters.set("differentGraphs", true);
    parameters.set("sampleSize", 1000);
    parameters.set("intervalBetweenShocks", 10);
    parameters.set("intervalBetweenRecordings", 10);
    parameters.set("fisherEpsilon", 0.001);
    parameters.set("randomizeColumns", true);

    RandomGraph graph = new RandomForward();
    LinearFisherModel sim = new LinearFisherModel(graph);
    sim.createData(parameters);

    ScoreWrapper score = new FisherZScore();
    Algorithm alg = new edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges(score);

    // Set after data creation, just before the searches.
    parameters.set("alpha", 1e-8);

    // Repeat the search on the same data set and print each result.
    for (int run = 0; run < 5; run++) {
        Graph out1 = alg.search(sim.getDataModel(0), parameters);
        System.out.println(out1);
    }
}
Use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
Class GeneralAlgorithmRunner, method execute:
// ============================PUBLIC METHODS==========================//

/**
 * Runs the selected algorithm and stores the resulting graphs in {@code this.graphList}.
 * <ul>
 *   <li>With no data models, which is allowed only when a source graph is present, the
 *       algorithm is run once with a {@code null} data model.</li>
 *   <li>A {@code MultiDataSetAlgorithm} is run {@code numRuns} times, each time on a random
 *       subset of {@code randomSelectionSize} data sets.</li>
 *   <li>A {@code ClusterAlgorithm} is run {@code numRuns} times on every covariance matrix
 *       or continuous data set.</li>
 *   <li>Any other algorithm is run once per data model whose type it accepts. Non-empty
 *       knowledge attached to a data model replaces the runner's knowledge.</li>
 * </ul>
 * Result graphs are laid out by knowledge tiers when at least one knowledge variable is
 * assigned to a tier. Otherwise they are laid out in a circle.
 *
 * @throws IllegalArgumentException if there is neither data nor a source graph; if a
 *     multi-data-set algorithm gets a non-{@code DataSet} model or fewer data sets than
 *     {@code randomSelectionSize}; if a cluster algorithm gets a non-continuous data set;
 *     or if a data model's type is not accepted by the algorithm
 */
@Override
public void execute() {
    List<Graph> graphList = new ArrayList<>();

    if (getDataModelList().isEmpty()) {
        if (getSourceGraph() == null) {
            throw new IllegalArgumentException("The parent boxes did not include any datasets or graphs. Try opening\n" + "the editors for those boxes and loading or simulating them.");
        }
        Algorithm algo = getAlgorithm();
        setKnowledgeIfSupported(algo);
        graphList.add(algo.search(null, parameters));
    } else if (getAlgorithm() instanceof MultiDataSetAlgorithm) {
        for (int k = 0; k < parameters.getInt("numRuns"); k++) {
            int selectionSize = parameters.getInt("randomSelectionSize");

            // Validate explicitly so that, for example, a covariance matrix gives a clear
            // message instead of a ClassCastException.
            List<DataSet> dataSets = new ArrayList<>();
            for (DataModel model : getDataModelList()) {
                if (!(model instanceof DataSet)) {
                    throw new IllegalArgumentException("Sorry, multi-data-set algorithms require tabular data sets; "
                            + "at least one of the data models is not a data set.");
                }
                dataSets.add((DataSet) model);
            }

            if (dataSets.size() < selectionSize) {
                throw new IllegalArgumentException("Sorry, the 'random selection size' is greater than " + "the number of data sets.");
            }

            // Random subset: the first selectionSize data sets after a shuffle.
            Collections.shuffle(dataSets);
            List<DataModel> sub = new ArrayList<>();
            for (int j = 0; j < selectionSize; j++) {
                sub.add(dataSets.get(j));
            }

            Algorithm algo = getAlgorithm();
            setKnowledgeIfSupported(algo);
            graphList.add(((MultiDataSetAlgorithm) algo).search(sub, parameters));
        }
    } else if (getAlgorithm() instanceof ClusterAlgorithm) {
        // NOTE(review): unlike the other branches, this one searches with the `algorithm`
        // field directly and never passes knowledge to it. Data models that are neither
        // covariance matrices nor data sets are silently skipped. Confirm this is intended.
        for (int k = 0; k < parameters.getInt("numRuns"); k++) {
            for (DataModel dataModel : getDataModelList()) {
                if (dataModel instanceof ICovarianceMatrix) {
                    ICovarianceMatrix covariances = (ICovarianceMatrix) dataModel;
                    graphList.add(algorithm.search(covariances, parameters));
                } else if (dataModel instanceof DataSet) {
                    DataSet dataSet = (DataSet) dataModel;
                    if (!dataSet.isContinuous()) {
                        throw new IllegalArgumentException("Sorry, you need a continuous dataset for a cluster algorithm.");
                    }
                    graphList.add(algorithm.search(dataSet, parameters));
                }
            }
        }
    } else {
        for (DataModel data : getDataModelList()) {
            // Non-empty knowledge stored with the data replaces the runner's knowledge.
            IKnowledge knowledgeFromData = data.getKnowledge();
            if (!(knowledgeFromData == null || knowledgeFromData.getVariables().isEmpty())) {
                this.knowledge = knowledgeFromData;
            }

            Algorithm algo = getAlgorithm();
            setKnowledgeIfSupported(algo);

            // A Mixed algorithm accepts any data. Otherwise, the data type must match.
            DataType algDataType = algo.getDataType();
            boolean accepted = (data.isContinuous() && (algDataType == DataType.Continuous || algDataType == DataType.Mixed))
                    || (data.isDiscrete() && (algDataType == DataType.Discrete || algDataType == DataType.Mixed))
                    || (data.isMixed() && algDataType == DataType.Mixed);
            if (!accepted) {
                throw new IllegalArgumentException("The type of data changed; try opening up the search editor and " + "running the algorithm there.");
            }
            graphList.add(algo.search(data, parameters));
        }
    }

    if (getKnowledge().getVariablesNotInTiers().size() < getKnowledge().getVariables().size()) {
        for (Graph graph : graphList) {
            SearchGraphUtils.arrangeByKnowledgeTiers(graph, getKnowledge());
        }
    } else {
        for (Graph graph : graphList) {
            GraphUtils.circleLayout(graph, 225, 200, 150);
        }
    }

    this.graphList = graphList;
}

/** Passes the runner's current knowledge to {@code algo} if it implements {@code HasKnowledge}. */
private void setKnowledgeIfSupported(Algorithm algo) {
    if (algo instanceof HasKnowledge) {
        ((HasKnowledge) algo).setKnowledge(getKnowledge());
    }
}
Use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
Class TimeoutComparison, method compareFromSimulations:
/**
 * Compares algorithms on freshly simulated data and writes a text report.
 * <p>
 * The report is written to {@code outputFileName} inside the {@code resultsPath} directory,
 * which is created if needed. It starts with the current date. If statistics were computed,
 * it then lists:
 * <ul>
 *   <li>the statistic legend;</li>
 *   <li>the simulations and algorithms;</li>
 *   <li>the utility weights, if enabled;</li>
 *   <li>the average, standard-deviation and worst-case statistic tables.</li>
 * </ul>
 * Each algorithm is run once per combination of values of those of its own parameters
 * that have more than one value in {@code parameters}.
 *
 * @param resultsPath    Path to the directory in which the report file is created.
 * @param simulations    The simulations used to generate graphs and data.
 * @param outputFileName Name of the report file inside {@code resultsPath}.
 * @param algorithms     The list of algorithms to be compared.
 * @param statistics     The statistics on which to compare the algorithms, and their
 *                       utility weights.
 * @param parameters     The parameters and their values.
 * @param timeout        Time limit passed through to {@code calcStats} (presumably per
 *                       algorithm run; confirm there).
 * @param unit           Unit of {@code timeout}.
 * @throws RuntimeException if the report file cannot be created.
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
    this.resultsPath = resultsPath;

    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // Close the report even if a simulation, algorithm or statistic throws.
    try {
        out.println(new Date());

        // Set up simulations: create data and graphs and read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");
        for (Simulation simulation : simulations.getSimulations()) {
            List<SimulationWrapper> wrappers = getSimulationWrappers(simulation, parameters);
            for (SimulationWrapper wrapper : wrappers) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }

        List<AlgorithmWrapper> algorithmWrappers = expandAlgorithmWrappers(algorithms, parameters);

        // Create the algorithm-simulation wrappers for every combination of algorithm and simulation.
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();
        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }
                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    // NOTE(review): this looks up external.getSimulation() in the list of
                    // SimulationWrappers, which yields -1 unless it returns one of them.
                    // The same algorithm object is presumably shared across simulations, so
                    // only the last assignment would survive. Perhaps
                    // simulationWrappers.indexOf(simulationWrapper) was meant; confirm.
                    external.setSimIndex(simulationWrappers.indexOf(external.getSimulation()));
                }
                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }

        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, statistics, numRuns, timeout, unit);

        if (allStats == null) {
            out.println();
            return;
        }

        // Print the preliminary information for statistics types, etc.
        out.println();
        out.println("Statistics:");
        out.println();
        for (Statistic stat : statistics.getStatistics()) {
            out.println(stat.getAbbreviation() + " = " + stat.getDescription());
        }
        out.println();

        int numTables = allStats.length;
        int numStats = allStats[0][0].length - 1;
        double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables, algorithmSimulationWrappers, numStats, statistics);
        double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);
        fillUtilityColumn(statTables, utilities, numTables, numStats, algorithmSimulationWrappers.size());

        int[] newOrder;
        if (isSortByUtility()) {
            newOrder = sort(algorithmSimulationWrappers, utilities);
        } else {
            newOrder = new int[algorithmSimulationWrappers.size()];
            for (int q = 0; q < newOrder.length; q++) {
                newOrder[q] = q;
            }
        }

        out.println("Simulations:");
        out.println();
        int i = 0;
        for (SimulationWrapper simulation : simulationWrappers) {
            out.print("Simulation " + (++i) + ": ");
            out.println(simulation.getDescription());
            out.println();
            printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);
            out.println();
        }

        // Every algorithm is paired with every simulation, so listing the pairs for the
        // first simulation lists each algorithm configuration once.
        out.println("Algorithms:");
        out.println();
        for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
            AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);
            if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
            }
        }

        if (isSortByUtility()) {
            out.println();
            out.println("Sorting by utility, high to low.");
        }

        if (isShowUtilities()) {
            out.println();
            out.println("Weighting of statistics:");
            out.println();
            out.println("U = ");
            for (Statistic stat : statistics.getStatistics()) {
                String statName = stat.getAbbreviation();
                double weight = statistics.getWeight(stat);
                if (weight != 0.0) {
                    out.println(" " + weight + " * f(" + statName + ")");
                }
            }
            out.println();
            out.println("...normed to range between 0 and 1.");
            out.println();
            out.println("Note that f for each statistic is a function that maps the statistic to the ");
            out.println("interval [0, 1], with higher being better.");
        }

        out.println();
        out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");
        out.println();

        // Print all of the tables. The Average tables already carry the utility column.
        printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

        statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables, algorithmSimulationWrappers, numStats, statistics);
        // NOTE(review): unlike the Average and WorstCase tables, the utility column is not
        // filled in here; confirm whether printStats ignores it for StandardDeviation.
        printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

        statTables = calcStatTables(allStats, Mode.WorstCase, numTables, algorithmSimulationWrappers, numStats, statistics);
        fillUtilityColumn(statTables, utilities, numTables, numStats, algorithmSimulationWrappers.size());
        printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
    } finally {
        out.close();
    }
}

/**
 * Wraps each algorithm with {@code parameters}.
 * <p>
 * An algorithm gets one wrapper, or one wrapper per combination of values when some of
 * its own parameters have several values.
 */
private List<AlgorithmWrapper> expandAlgorithmWrappers(Algorithms algorithms, Parameters parameters) {
    List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();

    for (Algorithm algorithm : algorithms.getAlgorithms()) {
        List<Integer> numValuesPerParameter = new ArrayList<>();
        List<String> varyingParameters = new ArrayList<>();
        for (String name : algorithm.getParameters()) {
            if (parameters.getNumValues(name) > 1) {
                numValuesPerParameter.add(parameters.getNumValues(name));
                varyingParameters.add(name);
            }
        }

        if (varyingParameters.isEmpty()) {
            algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
            continue;
        }

        int[] dims = new int[numValuesPerParameter.size()];
        for (int d = 0; d < dims.length; d++) {
            dims[d] = numValuesPerParameter.get(d);
        }

        // Each choice selects one value index for every varying parameter.
        CombinationGenerator gen = new CombinationGenerator(dims);
        int[] choice;
        while ((choice = gen.next()) != null) {
            AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);
            for (int h = 0; h < dims.length; h++) {
                String parameter = varyingParameters.get(h);
                Object[] values = parameters.getValues(parameter);
                wrapper.setValue(parameter, values[choice[h]]);
            }
            algorithmWrappers.add(wrapper);
        }
    }

    return algorithmWrappers;
}

/** Writes {@code utilities[t]} into column {@code numStats} of row {@code t} of every table. */
private static void fillUtilityColumn(double[][][] statTables, double[] utilities, int numTables, int numStats, int numRows) {
    for (int u = 0; u < numTables; u++) {
        for (int t = 0; t < numRows; t++) {
            statTables[u][t][numStats] = utilities[t];
        }
    }
}
Use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
Class TimeoutComparison, method generateReportFromExternalAlgorithms:
/**
 * Builds a comparison report for algorithms that were run externally on previously saved data.
 * <p>
 * Graph saving is switched off and every algorithm must be an {@code ExternalAlgorithm}.
 * The numbered directories {@code save/1 .. save/N} under {@code dataPath} are loaded with
 * {@code LoadDataAndGraphs}. N is the number of entries in {@code save} whose names do not
 * contain "DS_Store". The report itself is produced by {@code compareFromSimulations}.
 *
 * @param dataPath       directory containing the {@code save} folder
 * @param resultsPath    directory in which the report is written
 * @param outputFileName name of the report file
 * @param algorithms     algorithms to report on; all must be ExternalAlgorithms
 * @param statistics     statistics to compute, with their utility weights
 * @param parameters     parameter values
 * @param timeout        time limit forwarded to compareFromSimulations
 * @param unit           unit of {@code timeout}
 * @throws IllegalArgumentException if some algorithm is not an ExternalAlgorithm
 * @throws NullPointerException     if {@code dataPath/save} cannot be listed
 */
public void generateReportFromExternalAlgorithms(String dataPath, String resultsPath, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
    this.saveGraphs = false;
    this.dataPath = dataPath;
    this.resultsPath = resultsPath;

    for (Algorithm candidate : algorithms.getAlgorithms()) {
        boolean isExternal = candidate instanceof ExternalAlgorithm;
        if (!isExternal) {
            throw new IllegalArgumentException("Expecting all algorithms to implement ExternalAlgorithm.");
        }
    }

    Simulations simulations = new Simulations();

    File saveDir = new File(this.dataPath, "save");
    File[] entries = saveDir.listFiles();
    if (entries == null) {
        throw new NullPointerException("No files in " + saveDir.getAbsolutePath());
    }
    this.dirs = new ArrayList<String>();

    // Each non-DS_Store entry is assumed to be one numbered simulation directory.
    int numSaved = 0;
    for (File entry : entries) {
        numSaved += entry.getName().contains("DS_Store") ? 0 : 1;
    }

    for (int index = 1; index <= numSaved; index++) {
        String simulationDir = new File(dataPath, "save/" + index).getAbsolutePath();
        simulations.add(new LoadDataAndGraphs(simulationDir));
        this.dirs.add(simulationDir);
    }

    compareFromSimulations(this.resultsPath, simulations, outputFileName, algorithms, statistics, parameters, timeout, unit);
}
Use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
Class TimeoutComparison, method compareFromFiles:
/**
 * Compares algorithms on data and graphs previously saved under {@code dataPath}.
 * <p>
 * Loads the numbered directories {@code save/1 .. save/N} under {@code dataPath} with
 * {@code LoadDataAndGraphs}. N is the number of entries in {@code save} whose names do not
 * contain "DS_Store". The comparison is delegated to {@code compareFromSimulations}.
 *
 * @param dataPath    Path to the directory where data and graph files have been saved.
 * @param resultsPath Path where the results should be stored.
 * @param algorithms  The algorithms to compare; none may be an ExternalAlgorithm.
 * @param statistics  The statistics on which to compare the algorithms, and their utility weights.
 * @param parameters  The parameters and their values.
 * @param timeout     Time limit forwarded to compareFromSimulations.
 * @param unit        Unit of {@code timeout}.
 * @throws IllegalArgumentException if any algorithm is an ExternalAlgorithm.
 * @throws NullPointerException     if {@code dataPath/save} cannot be listed.
 */
public void compareFromFiles(String dataPath, String resultsPath, Algorithms algorithms, Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
    for (Algorithm candidate : algorithms.getAlgorithms()) {
        if (!(candidate instanceof ExternalAlgorithm)) {
            continue;
        }
        throw new IllegalArgumentException("Not expecting any implementations of ExternalAlgorithm here.");
    }

    this.dataPath = dataPath;
    this.resultsPath = resultsPath;
    Simulations simulations = new Simulations();

    File saveRoot = new File(this.dataPath, "save");
    File[] children = saveRoot.listFiles();
    if (children == null) {
        throw new NullPointerException("No files in " + saveRoot.getAbsolutePath());
    }
    this.dirs = new ArrayList<String>();

    // Start from the total count and discount macOS metadata entries.
    int numSimulationDirs = children.length;
    for (File child : children) {
        if (child.getName().contains("DS_Store")) {
            numSimulationDirs--;
        }
    }

    int index = 1;
    while (index <= numSimulationDirs) {
        String absolutePath = new File(dataPath, "save/" + index).getAbsolutePath();
        simulations.add(new LoadDataAndGraphs(absolutePath));
        this.dirs.add(absolutePath);
        index++;
    }

    compareFromSimulations(this.resultsPath, simulations, algorithms, statistics, parameters, timeout, unit);
}
Aggregations