use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
the class GeneralAlgorithmEditor method restorePreviousState.
/**
 * Re-applies a previously saved editor state to the widgets of this editor.
 *
 * <p>Each entry is looked up in {@code models} and applied only when it is present and has
 * the expected type. For the algorithm, independence-test and score entries the saved
 * object is matched against the currently available options by its {@code toString()}
 * value; on a match the map entry is replaced by the matching option and that option is
 * selected.
 *
 * @param models saved state keyed by the {@code *_PARAM} constants; matching entries for
 *               the algorithm, test and score are overwritten with the live option objects
 */
private void restorePreviousState(Map<String, Object> models) {
    // --- check boxes: only a stored Boolean is applied (instanceof is false for null) ---
    Object linearFlag = models.get(LINEAR_PARAM);
    if (linearFlag instanceof Boolean) {
        linearVarChkBox.setSelected((Boolean) linearFlag);
    }

    Object gaussianFlag = models.get(GAUSSIAN_PARAM);
    if (gaussianFlag instanceof Boolean) {
        gaussianVarChkBox.setSelected((Boolean) gaussianFlag);
    }

    Object knowledgeFlag = models.get(KNOWLEDGE_PARAM);
    if (knowledgeFlag instanceof Boolean) {
        knowledgeChkBox.setSelected((Boolean) knowledgeFlag);
    }

    // --- algorithm type: select the first radio button whose action command matches ---
    Object algoType = models.get(ALGO_TYPE_PARAM);
    if (algoType instanceof String) {
        String actionCommand = (String) algoType;
        algoTypeOpts.stream()
                .filter(button -> button.getActionCommand().equals(actionCommand))
                .findFirst()
                .ifPresent(button -> button.setSelected(true));
    }

    // The option lists are refreshed before the saved selections below are matched
    // against them.
    refreshAlgorithmList();
    refreshTestAndScoreList();

    // --- algorithm: match by toString(), then select it and update the title ---
    Object savedAlgo = models.get(ALGO_PARAM);
    if (savedAlgo instanceof AlgorithmModel) {
        String savedAlgoName = savedAlgo.toString();
        for (Enumeration<AlgorithmModel> it = algoModels.elements(); it.hasMoreElements();) {
            AlgorithmModel candidate = it.nextElement();
            if (!candidate.toString().equals(savedAlgoName)) {
                continue;
            }
            models.put(ALGO_PARAM, candidate);
            algorithmList.setSelectedValue(candidate, true);
            algorithmGraphTitle.setText(
                    String.format("Algorithm: %s", candidate.getAlgorithm().getAnnotation().name()));
            break;
        }
    }

    // --- independence test: match by toString() among the combo box entries ---
    Object savedTest = models.get(IND_TEST_PARAM);
    if (savedTest instanceof IndependenceTestModel) {
        String savedTestName = savedTest.toString();
        ComboBoxModel<IndependenceTestModel> testOptions = indTestComboBox.getModel();
        for (int i = 0, n = testOptions.getSize(); i < n; i++) {
            IndependenceTestModel candidate = testOptions.getElementAt(i);
            if (candidate.toString().equals(savedTestName)) {
                models.put(IND_TEST_PARAM, candidate);
                indTestComboBox.getModel().setSelectedItem(candidate);
                break;
            }
        }
    }

    // --- score: match by toString() among the combo box entries ---
    Object savedScore = models.get(SCORE_PARAM);
    if (savedScore instanceof ScoreModel) {
        String savedScoreName = savedScore.toString();
        ComboBoxModel<ScoreModel> scoreOptions = scoreComboBox.getModel();
        for (int i = 0, n = scoreOptions.getSize(); i < n; i++) {
            ScoreModel candidate = scoreOptions.getElementAt(i);
            if (candidate.toString().equals(savedScoreName)) {
                models.put(SCORE_PARAM, candidate);
                scoreComboBox.getModel().setSelectedItem(candidate);
                break;
            }
        }
    }
}
use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
the class GeneralAlgorithmEditor method validateAlgorithmOption.
/**
 * Validates the algorithm currently selected in {@code algorithmList} and enables or
 * disables the forward button ({@code paramSetFwdBtn}) accordingly.
 *
 * <ul>
 *   <li>If nothing is selected, the forward button is disabled and the method returns.</li>
 *   <li>If the algorithm class is annotated with {@code Nonexecutable}, its
 *       {@code getDescription()} result is shown in a dialog and the forward button is
 *       disabled.</li>
 *   <li>Otherwise, if the class implements {@code TakesInitialGraph} and the runner has no
 *       source graph or no data models, {@code setInitialGraph(null)} is invoked
 *       reflectively on a fresh instance; if that call fails, the failure message is shown
 *       and the forward button is disabled.</li>
 * </ul>
 */
private void validateAlgorithmOption() {
    paramSetFwdBtn.setEnabled(true);

    AlgorithmModel algoModel = algorithmList.getSelectedValue();
    if (algoModel == null) {
        // No selection (e.g. an empty or just-refreshed list): there is nothing to
        // validate and nothing to move forward with, so avoid the NPE below.
        paramSetFwdBtn.setEnabled(false);
        return;
    }

    Class<?> algoClass = algoModel.getAlgorithm().getClazz();
    if (algoClass.isAnnotationPresent(Nonexecutable.class)) {
        String msg;
        try {
            Object algo = algoClass.newInstance();
            Method m = algoClass.getDeclaredMethod("getDescription");
            m.setAccessible(true);
            try {
                msg = String.valueOf(m.invoke(algo));
            } catch (InvocationTargetException exception) {
                // getDescription() itself threw; record it instead of dropping it silently.
                LOGGER.error("", exception);
                msg = "";
            }
        } catch (IllegalAccessException | InstantiationException | NoSuchMethodException exception) {
            LOGGER.error("", exception);
            msg = "";
        }

        paramSetFwdBtn.setEnabled(false);
        JOptionPane.showMessageDialog(desktop, msg, "Please Note", JOptionPane.INFORMATION_MESSAGE);
    } else {
        // Check if initial graph is provided for those pairwise algorithms
        if (TakesInitialGraph.class.isAssignableFrom(algoClass)) {
            if (runner.getSourceGraph() == null || runner.getDataModelList().isEmpty()) {
                try {
                    Object algo = algoClass.newInstance();
                    Method m = algoClass.getDeclaredMethod("setInitialGraph", Algorithm.class);
                    m.setAccessible(true);
                    try {
                        // Passing null probes the algorithm's own "initial graph required"
                        // check; the resulting error message is shown to the user.
                        Algorithm algorithm = null;
                        m.invoke(algo, algorithm);
                    } catch (InvocationTargetException | IllegalArgumentException exception) {
                        // An IllegalArgumentException raised by invoke() normally carries no
                        // cause, so fall back to the exception's own message.
                        Throwable cause = exception.getCause();
                        String msg = (cause != null) ? cause.getMessage() : exception.getMessage();

                        paramSetFwdBtn.setEnabled(false);
                        JOptionPane.showMessageDialog(desktop, msg, "Please Note", JOptionPane.INFORMATION_MESSAGE);
                    }
                } catch (IllegalAccessException | InstantiationException | NoSuchMethodException exception) {
                    LOGGER.error("", exception);
                }
            }
        }
    }
    // Check dataset data type for those algorithms take mixed data?
}
use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
the class ExampleFirstInflection method main.
/**
 * Runs an algorithm comparison from simulations using a {@code FirstInflection} wrapper
 * around FGES with a SEM BIC score, writing results under {@code "first.inflection"}.
 *
 * <p>Parameters given more than one value (e.g. {@code numMeasures}, {@code sampleSize})
 * are passed as multiple values to {@code Parameters.set}.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String... args) {
    // Each parameter is set exactly once. The original listed "differentGraphs",
    // "verbose" and "maxDegree" twice with identical values; the repeats were dropped.
    // NOTE(review): this assumes Parameters.set replaces any earlier value for the same
    // name — confirm against Parameters.
    Parameters parameters = new Parameters();

    parameters.set("numMeasures", 40, 100);
    parameters.set("avgDegree", 2);
    parameters.set("sampleSize", 400, 800);
    parameters.set("numRuns", 10);

    parameters.set("differentGraphs", true);
    parameters.set("numLatents", 0);
    parameters.set("maxDegree", 100);
    parameters.set("maxIndegree", 100);
    parameters.set("maxOutdegree", 100);
    parameters.set("connected", false);

    parameters.set("coefLow", 0.2);
    parameters.set("coefHigh", 0.9);
    parameters.set("varLow", 1);
    parameters.set("varHigh", 3);
    parameters.set("verbose", false);
    parameters.set("coefSymmetric", true);
    parameters.set("percentDiscrete", 0);
    parameters.set("numCategories", 3);
    parameters.set("intervalBetweenShocks", 10);
    parameters.set("intervalBetweenRecordings", 10);

    parameters.set("fisherEpsilon", 0.001);
    parameters.set("randomizeColumns", true);

    parameters.set("alpha", 1e-8);
    parameters.set("depth", -1);
    parameters.set("penaltyDiscount", 4);

    parameters.set("useMaxPOrientationHeuristic", false);
    parameters.set("maxPOrientationMaxPathLength", 3);

    parameters.set("scaleFreeAlpha", 0.00001);
    parameters.set("scaleFreeBeta", 0.4);
    parameters.set("scaleFreeDeltaIn", .1);
    parameters.set("scaleFreeDeltaOut", 3);

    parameters.set("symmetricFirstStep", false);
    parameters.set("faithfulnessAssumed", true);
    // Optional: parameters.set("logScale", true);

    // Columns reported for every algorithm/simulation combination.
    Statistics statistics = new Statistics();
    statistics.add(new ParameterColumn("numMeasures"));
    statistics.add(new ParameterColumn("avgDegree"));
    statistics.add(new ParameterColumn("sampleSize"));
    statistics.add(new AdjacencyPrecision());
    statistics.add(new AdjacencyRecall());
    statistics.add(new ArrowheadPrecision());
    statistics.add(new ArrowheadRecall());
    statistics.add(new ElapsedTime());

    // Equal utility weights for the four precision/recall statistics.
    statistics.setWeight("AP", 0.25);
    statistics.setWeight("AR", 0.25);
    statistics.setWeight("AHP", 0.25);
    statistics.setWeight("AHR", 0.25);

    Algorithms algorithms = new Algorithms();
    Algorithm fges = new Fges(new SemBicScore());
    // Alternative: algorithms.add(new FirstInflection(fges, "alpha", -7, -2, -.5));
    algorithms.add(new FirstInflection(fges, "penaltyDiscount", 0.7, 5, 1));

    Simulations simulations = new Simulations();
    simulations.add(new LinearFisherModel(new RandomForward()));

    Comparison comparison = new Comparison();
    comparison.setShowAlgorithmIndices(true);
    comparison.setShowSimulationIndices(true);
    comparison.setSortByUtility(false);
    comparison.setShowUtilities(false);
    comparison.setParallelized(true);
    comparison.setComparisonGraph(Comparison.ComparisonGraph.Pattern_of_the_true_DAG);

    comparison.compareFromSimulations("first.inflection", simulations, algorithms, statistics, parameters);
}
use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
the class TimeoutComparison method printStats.
/**
 * Prints one text table per used graph type to {@code out}, preceded by a banner naming
 * the aggregation mode.
 *
 * <p>Each table has a header row followed by one row per algorithm/simulation pairing,
 * visited in the order given by {@code newOrder}. Optional leading columns hold the
 * 1-based simulation and algorithm indices, and an optional trailing column holds the
 * utility. A statistic cell shows a parameter value when {@code parameters} holds exactly
 * one String value under the statistic's abbreviation; otherwise it shows the number from
 * {@code statTables}, with "-" for 0, "Yes"/"No" for +/- infinity and "*" for NaN.
 *
 * @param statTables                  statistics indexed as [table][pairing][statistic]
 * @param statistics                  the statistics that define the columns
 * @param mode                        aggregation mode; selects the banner text
 * @param newOrder                    row order as indices into {@code algorithmSimulationWrappers}
 * @param algorithmSimulationWrappers the algorithm/simulation pairings, one per row
 * @param algorithmWrappers           algorithm wrappers, used to compute the "Alg" index
 * @param simulationWrappers          simulation wrappers, used to compute the "Sim" index
 * @param utilities                   utility per pairing, printed when utilities are shown
 * @param parameters                  parameters; updated with the parameter values of each
 *                                    algorithm/simulation that implements {@code HasParameterValues}
 * @throws IllegalStateException if {@code mode} is not Average, StandardDeviation or WorstCase
 */
private void printStats(double[][][] statTables, Statistics statistics, Mode mode, int[] newOrder, List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, double[] utilities, Parameters parameters) {
    final String banner;
    if (mode == Mode.Average) {
        banner = "AVERAGE STATISTICS";
    } else if (mode == Mode.StandardDeviation) {
        banner = "STANDARD DEVIATIONS";
    } else if (mode == Mode.WorstCase) {
        banner = "WORST CASE";
    } else {
        throw new IllegalStateException();
    }
    out.println(banner);

    final int tableCount = statTables.length;
    final int statCount = statistics.size();
    final NumberFormat plainFormat = new DecimalFormat("0.00");
    final NumberFormat scientificFormat = new DecimalFormat("0.00E0");

    out.println();

    for (int tableIndex = 0; tableIndex < tableCount; tableIndex++) {
        if (graphTypeUsed[tableIndex]) {
            // One header row plus one row per pairing; columns are the optional index
            // columns, the statistics, and the optional utility column.
            int rowCount = algorithmSimulationWrappers.size() + 1;
            int columnCount = 0;
            if (isShowSimulationIndices()) {
                columnCount++;
            }
            if (isShowAlgorithmIndices()) {
                columnCount++;
            }
            columnCount += statCount;
            if (isShowUtilities()) {
                columnCount++;
            }

            TextTable table = new TextTable(rowCount, columnCount);
            table.setTabDelimited(isTabDelimitedTables());

            // Column at which the statistic columns begin.
            int firstStatColumn = 0;

            if (isShowSimulationIndices()) {
                table.setToken(0, firstStatColumn, "Sim");
                for (int row = 0; row < algorithmSimulationWrappers.size(); row++) {
                    Simulation simulation = algorithmSimulationWrappers.get(newOrder[row]).getSimulationWrapper();
                    int simNumber = simulationWrappers.indexOf(simulation) + 1;
                    table.setToken(row + 1, firstStatColumn, "" + simNumber);
                }
                firstStatColumn++;
            }

            if (isShowAlgorithmIndices()) {
                table.setToken(0, firstStatColumn, "Alg");
                for (int row = 0; row < algorithmSimulationWrappers.size(); row++) {
                    AlgorithmWrapper algorithm = algorithmSimulationWrappers.get(newOrder[row]).getAlgorithmWrapper();
                    int algNumber = algorithmWrappers.indexOf(algorithm) + 1;
                    table.setToken(row + 1, firstStatColumn, "" + algNumber);
                }
                firstStatColumn++;
            }

            // Header labels.
            for (int statIndex = 0; statIndex < statCount; statIndex++) {
                table.setToken(0, firstStatColumn + statIndex,
                        statistics.getStatistics().get(statIndex).getAbbreviation());
            }
            if (isShowUtilities()) {
                table.setToken(0, firstStatColumn + statCount, "U");
            }

            // Body rows.
            for (int row = 0; row < algorithmSimulationWrappers.size(); row++) {
                final int tableRow = row + 1;

                for (int statIndex = 0; statIndex < statCount; statIndex++) {
                    final int column = firstStatColumn + statIndex;
                    Statistic statistic = statistics.getStatistics().get(statIndex);

                    final AlgorithmWrapper algWrapper = algorithmSimulationWrappers.get(newOrder[row]).getAlgorithmWrapper();
                    final SimulationWrapper simWrapper = algorithmSimulationWrappers.get(newOrder[row]).getSimulationWrapper();
                    Algorithm algorithm = algWrapper.getAlgorithm();
                    Simulation simulation = simWrapper.getSimulation();

                    // Merge in the pairing's own parameter values before the lookup below.
                    if (algorithm instanceof HasParameterValues) {
                        parameters.putAll(((HasParameterValues) algorithm).getParameterValues());
                    }
                    if (simulation instanceof HasParameterValues) {
                        parameters.putAll(((HasParameterValues) simulation).getParameterValues());
                    }

                    // A single String value under the statistic's abbreviation is printed verbatim.
                    Object[] paramValues = parameters.getValues(statistic.getAbbreviation());
                    if (paramValues.length == 1 && paramValues[0] instanceof String) {
                        table.setToken(tableRow, column, (String) paramValues[0]);
                        continue;
                    }

                    double value = statTables[tableIndex][newOrder[row]][statIndex];
                    final String cell;
                    if (value == 0.0) {
                        cell = "-";
                    } else if (value == Double.POSITIVE_INFINITY) {
                        cell = "Yes";
                    } else if (value == Double.NEGATIVE_INFINITY) {
                        cell = "No";
                    } else if (Double.isNaN(value)) {
                        cell = "*";
                    } else {
                        // Very small magnitudes switch to scientific notation.
                        boolean tiny = Math.abs(value) < Math.pow(10, -scientificFormat.getMaximumFractionDigits()) && value != 0;
                        cell = tiny ? scientificFormat.format(value) : plainFormat.format(value);
                    }
                    table.setToken(tableRow, column, cell);
                }

                if (isShowUtilities()) {
                    table.setToken(tableRow, firstStatColumn + statCount, plainFormat.format(utilities[newOrder[row]]));
                }
            }

            out.println(getHeader(tableIndex));
            out.println();
            out.println(table);
        }
    }
}
use of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in project tetrad by cmu-phil.
the class Comparison method compareFromSimulations.
/**
 * Compares algorithms on data generated from the given simulations and writes a report.
 *
 * <p>The report is written to the file {@code outputFileName} inside the directory
 * {@code resultsPath} (created if necessary). The output stream is always closed when
 * this method returns, including when the comparison fails with an exception.
 *
 * @param resultsPath    Path to the directory in which the output file is created.
 * @param simulations    The list of simulationWrapper that is used to generate graphs and data for the comparison.
 * @param outputFileName Name of the output file, created inside {@code resultsPath}.
 * @param algorithms     The list of algorithms to be compared.
 * @param statistics     The list of statistics on which to compare the algorithm, and their utility weights.
 * @param parameters     The parameter settings; a parameter with more than one value makes
 *                       each algorithm that lists it run once per combination of values.
 * @throws RuntimeException if the output file cannot be created
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters) {
    this.resultsPath = resultsPath;

    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // From here on the stream is open; the finally block guarantees it is closed even if
    // simulation, statistics calculation or printing throws.
    try {
        out.println(new Date());

        // Set up simulations--create data and graphs, read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");

        for (Simulation simulation : simulations.getSimulations()) {
            List<SimulationWrapper> wrappers = getSimulationWrappers(simulation, parameters);
            for (SimulationWrapper wrapper : wrappers) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }

        // Set up the algorithms.
        List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();

        for (Algorithm algorithm : algorithms.getAlgorithms()) {
            // Collect the parameters of this algorithm that have more than one value.
            List<Integer> _dims = new ArrayList<>();
            List<String> varyingParameters = new ArrayList<>();
            final List<String> parameters1 = algorithm.getParameters();

            for (String name : parameters1) {
                if (parameters.getNumValues(name) > 1) {
                    _dims.add(parameters.getNumValues(name));
                    varyingParameters.add(name);
                }
            }

            if (varyingParameters.isEmpty()) {
                algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
            } else {
                int[] dims = new int[_dims.size()];
                for (int i = 0; i < _dims.size(); i++) {
                    dims[i] = _dims.get(i);
                }

                // One wrapper per combination of values of the varying parameters.
                CombinationGenerator gen = new CombinationGenerator(dims);
                int[] choice;

                while ((choice = gen.next()) != null) {
                    AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);

                    for (int h = 0; h < dims.length; h++) {
                        String parameter = varyingParameters.get(h);
                        Object[] values = parameters.getValues(parameter);
                        Object value = values[choice[h]];
                        wrapper.setValue(parameter, value);
                    }

                    algorithmWrappers.add(wrapper);
                }
            }
        }

        // Create the algorithm-simulation wrappers for every combination of algorithm and
        // simulation.
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();

        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();

                // A mismatch is only reported; the combination is still added below.
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }

                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    // NOTE(review): the index is looked up from the external algorithm's own
                    // simulation, not from simulationWrapper — confirm this is intended.
                    external.setSimIndex(simulationWrappers.indexOf(external.getSimulation()));
                }

                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }

        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, statistics, numRuns);

        // Print out the preliminary information for statistics types, etc.
        if (allStats != null) {
            out.println();
            out.println("Statistics:");
            out.println();

            for (Statistic stat : statistics.getStatistics()) {
                out.println(stat.getAbbreviation() + " = " + stat.getDescription());
            }
        }

        out.println();

        if (allStats != null) {
            int numTables = allStats.length;
            // The last slot of each statistics row is reserved for the utility.
            int numStats = allStats[0][0].length - 1;

            double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables, algorithmSimulationWrappers, numStats, statistics);
            double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);

            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }

            int[] newOrder;

            if (isSortByUtility()) {
                newOrder = sort(algorithmSimulationWrappers, utilities);
            } else {
                // Identity order.
                newOrder = new int[algorithmSimulationWrappers.size()];
                for (int q = 0; q < algorithmSimulationWrappers.size(); q++) {
                    newOrder[q] = q;
                }
            }

            out.println("Simulations:");
            out.println();

            int i = 0;

            for (SimulationWrapper simulation : simulationWrappers) {
                out.print("Simulation " + (++i) + ": ");
                out.println(simulation.getDescription());
                out.println();
                printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);
                out.println();
            }

            out.println("Algorithms:");
            out.println();

            // List each algorithm once, using its pairing with the first simulation.
            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);

                if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                    out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
                }
            }

            if (isSortByUtility()) {
                out.println();
                out.println("Sorting by utility, high to low.");
            }

            if (isShowUtilities()) {
                out.println();
                out.println("Weighting of statistics:");
                out.println();
                out.println("U = ");

                for (Statistic stat : statistics.getStatistics()) {
                    String statName = stat.getAbbreviation();
                    double weight = statistics.getWeight(stat);

                    if (weight != 0.0) {
                        out.println(" " + weight + " * f(" + statName + ")");
                    }
                }

                out.println();
                out.println("...normed to range between 0 and 1.");
                out.println();
                out.println("Note that f for each statistic is a function that maps the statistic to the ");
                out.println("interval [0, 1], with higher being better.");
            }

            out.println();
            out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");
            out.println();

            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }

            // Print all of the tables.
            printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

            statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables, algorithmSimulationWrappers, numStats, statistics);
            printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

            statTables = calcStatTables(allStats, Mode.WorstCase, numTables, algorithmSimulationWrappers, numStats, statistics);

            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }

            printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
        }
    } finally {
        out.close();
    }
}
Aggregations