use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.
the class GeneralAlgorithmEditor method initComponents.
/**
 * Configures the algorithm description area, registers the listeners that
 * keep the algorithm/test/score widgets in sync, and assembles the card
 * panels into this component.
 */
private void initComponents() {
    // The description area is display-only and wraps on word boundaries.
    algoDescTextArea.setWrapStyleWord(true);
    algoDescTextArea.setLineWrap(true);
    algoDescTextArea.setEditable(false);

    populateAlgoTypeOptions(algoTypeOpts);

    // Filter check boxes: each one refreshes the list it affects.
    knowledgeChkBox.addActionListener(event -> refreshAlgorithmList());
    linearVarChkBox.addActionListener(event -> refreshTestAndScoreList());
    gaussianVarChkBox.addActionListener(event -> refreshTestAndScoreList());

    // React only to a settled, non-empty algorithm selection.
    algorithmList.addListSelectionListener(event -> {
        if (event.getValueIsAdjusting() || algorithmList.isSelectionEmpty()) {
            return;
        }
        setAlgorithmDescription();
        refreshTestAndScoreList();
        validateAlgorithmOption();
    });

    // Move on to the parameter card once the current choices validate.
    paramSetFwdBtn.addActionListener(event -> {
        AlgorithmModel chosenAlgo = algorithmList.getSelectedValue();
        IndependenceTestModel chosenTest = indTestComboBox.getItemAt(indTestComboBox.getSelectedIndex());
        ScoreModel chosenScore = scoreComboBox.getItemAt(scoreComboBox.getSelectedIndex());
        if (!isValid(chosenAlgo, chosenTest, chosenScore)) {
            return;
        }
        setParameterPanel(chosenAlgo, chosenTest, chosenScore);
        changeCard(PARAMETER_CARD);
    });

    // Remember the chosen test per algorithm and data type, unless the
    // combo box is currently being repopulated or has no selection.
    indTestComboBox.addActionListener(event -> {
        if (updatingTestModels || indTestComboBox.getSelectedIndex() < 0) {
            return;
        }
        AlgorithmModel chosenAlgo = algorithmList.getSelectedValue();
        Map<DataType, IndependenceTestModel> testsByType = defaultIndTestModels.get(chosenAlgo);
        if (testsByType == null) {
            testsByType = new EnumMap<>(DataType.class);
            defaultIndTestModels.put(chosenAlgo, testsByType);
        }
        testsByType.put(dataType, indTestComboBox.getItemAt(indTestComboBox.getSelectedIndex()));
    });

    // Same bookkeeping for the score combo box.
    scoreComboBox.addActionListener(event -> {
        if (updatingScoreModels || scoreComboBox.getSelectedIndex() < 0) {
            return;
        }
        AlgorithmModel chosenAlgo = algorithmList.getSelectedValue();
        Map<DataType, ScoreModel> scoresByType = defaultScoreModels.get(chosenAlgo);
        if (scoresByType == null) {
            scoresByType = new EnumMap<>(DataType.class);
            defaultScoreModels.put(chosenAlgo, scoresByType);
        }
        scoresByType.put(dataType, scoreComboBox.getItemAt(scoreComboBox.getSelectedIndex()));
    });

    // One card per step, registered under its card name.
    mainPanel.add(new AlgorithmCard(), ALGORITHM_CARD);
    mainPanel.add(new ParameterCard(), PARAMETER_CARD);
    mainPanel.add(new GraphCard(), GRAPH_CARD);
    mainPanel.setPreferredSize(new Dimension(940, 640));

    setLayout(new BorderLayout());
    add(mainPanel, BorderLayout.CENTER);
}
use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.
the class GeneralAlgorithmEditor method refreshScoreList.
/**
 * Repopulates the score combo box for the currently selected algorithm and
 * then restores a selection: the score remembered for this algorithm and
 * data type, else the default score for the data type, else the first item.
 * The combo box is disabled when it ends up empty.
 */
private void refreshScoreList() {
    // Flag is raised while items are replaced; the combo box's action
    // listener checks it before recording a selection.
    updatingScoreModels = true;
    scoreComboBox.removeAllItems();

    AlgorithmModel algoModel = algorithmList.getSelectedValue();
    if (algoModel != null && algoModel.isRequiredScore()) {
        boolean linear = linearVarChkBox.isSelected();
        boolean gaussian = gaussianVarChkBox.isSelected();
        List<ScoreModel> models = ScoreModels.getInstance().getModels(dataType);
        List<ScoreModel> scoreModels = new LinkedList<>();

        if (linear || gaussian) {
            // With at least one box ticked, a score is kept only when its
            // Linear/Gaussian annotations match the two check boxes exactly.
            for (ScoreModel candidate : models) {
                boolean matches = candidate.getScore().getClazz().isAnnotationPresent(Linear.class) == linear
                        && candidate.getScore().getClazz().isAnnotationPresent(Gaussian.class) == gaussian;
                if (matches) {
                    scoreModels.add(candidate);
                }
            }
        } else {
            // Neither box ticked: no filtering.
            scoreModels.addAll(models);
        }

        if (TsImages.class.equals(algoModel.getAlgorithm().getClazz())) {
            // TsImages is offered only the SEM BIC score for continuous data
            // or the BDeu score for discrete data; nothing for other types.
            Class<?> requiredScore = null;
            switch (dataType) {
                case Continuous:
                    requiredScore = SemBicScore.class;
                    break;
                case Discrete:
                    requiredScore = BdeuScore.class;
                    break;
                default:
                    break;
            }
            if (requiredScore != null) {
                for (ScoreModel candidate : scoreModels) {
                    if (candidate.getScore().getClazz().equals(requiredScore)) {
                        scoreComboBox.addItem(candidate);
                    }
                }
            }
        } else {
            for (ScoreModel candidate : scoreModels) {
                scoreComboBox.addItem(candidate);
            }
        }
    }
    updatingScoreModels = false;

    if (scoreComboBox.getItemCount() <= 0) {
        scoreComboBox.setEnabled(false);
        return;
    }

    scoreComboBox.setEnabled(true);
    Map<DataType, ScoreModel> scoresByType = defaultScoreModels.get(algoModel);
    if (scoresByType == null) {
        scoresByType = new EnumMap<>(DataType.class);
        defaultScoreModels.put(algoModel, scoresByType);
    }

    // Fallback chain: remembered score, then type default, then first item.
    ScoreModel preferred = scoresByType.get(dataType);
    if (preferred == null) {
        preferred = ScoreModels.getInstance().getDefaultModel(dataType);
    }
    if (preferred == null) {
        preferred = scoreComboBox.getItemAt(0);
    }
    scoreComboBox.setSelectedItem(preferred);
}
use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.
the class IndependenceTestModels method initDefaultModelMap.
/**
 * Fills {@code defaultModelMap} with one default independence test model per
 * data type that has at least one model. The default is the model whose test
 * class name equals the value configured in {@link TetradProperties} for that
 * data type; when no property, no value, or no matching model exists, the
 * first model in the list is used.
 */
private void initDefaultModelMap() {
    for (DataType dataType : DataType.values()) {
        List<IndependenceTestModel> candidates = getModels(dataType);
        if (candidates.isEmpty()) {
            continue;
        }

        final IndependenceTestModel fallback = candidates.get(0);
        final String property = getProperty(dataType);
        // Only consult the properties when a property name exists for this type.
        final String configuredClass = (property == null)
                ? null
                : TetradProperties.getInstance().getValue(property);

        IndependenceTestModel chosen = fallback;
        if (configuredClass != null) {
            chosen = candidates.stream()
                    .filter(model -> model.getIndependenceTest().getClazz().getName().equals(configuredClass))
                    .findFirst()
                    .orElse(fallback);
        }
        defaultModelMap.put(dataType, chosen);
    }
}
use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.
the class GeneralAlgorithmRunner method execute.
// ============================PUBLIC METHODS==========================//
/**
 * Runs the configured algorithm and stores the resulting graphs in
 * {@code graphList}.
 * <ul>
 * <li>With no data models, the algorithm is run once with {@code null} data,
 * provided a source graph is available.</li>
 * <li>A {@link MultiDataSetAlgorithm} is run {@code numRuns} times, each time
 * on a random selection of {@code randomSelectionSize} data sets.</li>
 * <li>A {@link ClusterAlgorithm} is run {@code numRuns} times over every
 * covariance matrix or continuous data set.</li>
 * <li>Any other algorithm is run once per data model whose type is
 * compatible with the algorithm's data type.</li>
 * </ul>
 * Afterwards each graph is laid out by knowledge tiers when the knowledge
 * places at least one variable in a tier, and in a circle otherwise.
 *
 * @throws IllegalArgumentException if there is neither data nor a source
 * graph, if {@code randomSelectionSize} exceeds the number of data sets, if a
 * cluster algorithm receives a non-continuous data set, or if a data model's
 * type does not fit the algorithm.
 */
@Override
public void execute() {
    List<Graph> graphList = new ArrayList<>();

    if (getDataModelList().isEmpty()) {
        // No data available: the only thing to search over is a source graph.
        if (getSourceGraph() == null) {
            throw new IllegalArgumentException("The parent boxes did not include any datasets or graphs. Try opening\n" + "the editors for those boxes and loading or simulating them.");
        }
        Algorithm algo = getAlgorithm();
        if (algo instanceof HasKnowledge) {
            ((HasKnowledge) algo).setKnowledge(getKnowledge());
        }
        graphList.add(algo.search(null, parameters));
    } else if (getAlgorithm() instanceof MultiDataSetAlgorithm) {
        for (int k = 0; k < parameters.getInt("numRuns"); k++) {
            List<DataSet> dataSets = getDataModelList().stream().map(e -> (DataSet) e).collect(Collectors.toCollection(ArrayList::new));
            if (dataSets.size() < parameters.getInt("randomSelectionSize")) {
                throw new IllegalArgumentException("Sorry, the 'random selection size' is greater than " + "the number of data sets.");
            }

            // Random selection: shuffle, then take the leading elements.
            Collections.shuffle(dataSets);
            List<DataModel> sub = new ArrayList<>();
            for (int j = 0; j < parameters.getInt("randomSelectionSize"); j++) {
                sub.add(dataSets.get(j));
            }

            Algorithm algo = getAlgorithm();
            if (algo instanceof HasKnowledge) {
                ((HasKnowledge) algo).setKnowledge(getKnowledge());
            }
            graphList.add(((MultiDataSetAlgorithm) algo).search(sub, parameters));
        }
    } else if (getAlgorithm() instanceof ClusterAlgorithm) {
        for (int k = 0; k < parameters.getInt("numRuns"); k++) {
            getDataModelList().forEach(dataModel -> {
                if (dataModel instanceof ICovarianceMatrix) {
                    ICovarianceMatrix dataSet = (ICovarianceMatrix) dataModel;
                    graphList.add(algorithm.search(dataSet, parameters));
                } else if (dataModel instanceof DataSet) {
                    DataSet dataSet = (DataSet) dataModel;
                    if (!dataSet.isContinuous()) {
                        throw new IllegalArgumentException("Sorry, you need a continuous dataset for a cluster algorithm.");
                    }
                    graphList.add(algorithm.search(dataSet, parameters));
                }
                // Data models of any other kind are skipped.
            });
        }
    } else {
        getDataModelList().forEach(data -> {
            // Non-empty knowledge attached to the data replaces the runner's knowledge.
            IKnowledge knowledgeFromData = data.getKnowledge();
            if (!(knowledgeFromData == null || knowledgeFromData.getVariables().isEmpty())) {
                this.knowledge = knowledgeFromData;
            }

            Algorithm algo = getAlgorithm();
            if (algo instanceof HasKnowledge) {
                ((HasKnowledge) algo).setKnowledge(getKnowledge());
            }

            // A Mixed algorithm accepts every kind of data; otherwise the
            // data kind must equal the algorithm's data type.
            DataType algDataType = algo.getDataType();
            boolean compatible = (data.isContinuous() && (algDataType == DataType.Continuous || algDataType == DataType.Mixed))
                    || (data.isDiscrete() && (algDataType == DataType.Discrete || algDataType == DataType.Mixed))
                    || (data.isMixed() && algDataType == DataType.Mixed);
            if (!compatible) {
                throw new IllegalArgumentException("The type of data changed; try opening up the search editor and " + "running the algorithm there.");
            }
            graphList.add(algo.search(data, parameters));
        });
    }

    // Lay out by tiers only if at least one knowledge variable is in a tier.
    if (getKnowledge().getVariablesNotInTiers().size() < getKnowledge().getVariables().size()) {
        for (Graph graph : graphList) {
            SearchGraphUtils.arrangeByKnowledgeTiers(graph, getKnowledge());
        }
    } else {
        for (Graph graph : graphList) {
            GraphUtils.circleLayout(graph, 225, 200, 150);
        }
    }

    this.graphList = graphList;
}
use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.
the class TimeoutComparison method compareFromSimulations.
/**
 * Compares algorithms on simulated data and prints the comparison report.
 * The report is written to {@code outputFileName} inside the directory
 * {@code resultsPath}; the directory is created if it does not exist. The
 * output stream is closed when the method returns, whether normally or by
 * exception.
 *
 * @param resultsPath Path to the directory in which the output file is
 * created.
 * @param simulations The list of simulationWrapper that is used to generate
 * graphs and data for the comparison.
 * @param outputFileName Name of the output file created under
 * {@code resultsPath}.
 * @param algorithms The list of algorithms to be compared.
 * @param statistics The list of statistics on which to compare the
 * algorithm, and their utility weights.
 * @param parameters The parameter settings; parameters with more than one
 * value yield one algorithm wrapper per combination of values. Must define
 * {@code numRuns}.
 * @param timeout Passed through to {@code calcStats} together with
 * {@code unit}; presumably the time limit for each run (confirm there).
 * @param unit The time unit of {@code timeout}.
 * @throws RuntimeException if the output file cannot be created.
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
    this.resultsPath = resultsPath;

    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // Everything below writes to 'out'; the finally block guarantees the
    // file handle is released even if a later step throws.
    try {
        out.println(new Date());

        // Set up simulations--create data and graphs, read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");
        for (Simulation simulation : simulations.getSimulations()) {
            List<SimulationWrapper> wrappers = getSimulationWrappers(simulation, parameters);
            for (SimulationWrapper wrapper : wrappers) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }

        // Set up the algorithms.
        List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();
        for (Algorithm algorithm : algorithms.getAlgorithms()) {
            // Collect the parameters of this algorithm that take more than one value.
            List<Integer> _dims = new ArrayList<>();
            List<String> varyingParameters = new ArrayList<>();
            final List<String> parameters1 = algorithm.getParameters();
            for (String name : parameters1) {
                if (parameters.getNumValues(name) > 1) {
                    _dims.add(parameters.getNumValues(name));
                    varyingParameters.add(name);
                }
            }

            if (varyingParameters.isEmpty()) {
                algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
            } else {
                int[] dims = new int[_dims.size()];
                for (int i = 0; i < _dims.size(); i++) {
                    dims[i] = _dims.get(i);
                }

                // One wrapper per combination of values of the varying parameters.
                CombinationGenerator gen = new CombinationGenerator(dims);
                int[] choice;
                while ((choice = gen.next()) != null) {
                    AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);
                    for (int h = 0; h < dims.length; h++) {
                        String parameter = varyingParameters.get(h);
                        Object[] values = parameters.getValues(parameter);
                        Object value = values[choice[h]];
                        wrapper.setValue(parameter, value);
                    }
                    algorithmWrappers.add(wrapper);
                }
            }
        }

        // Create the algorithm-simulation wrappers for every combination of algorithm and
        // simulation.
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();
        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();

                // A type mismatch is only reported; the pair is still added below.
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }

                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    external.setSimIndex(simulationWrappers.indexOf(external.getSimulation()));
                }

                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }

        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, statistics, numRuns, timeout, unit);

        // Print out the preliminary information for statistics types, etc.
        if (allStats != null) {
            out.println();
            out.println("Statistics:");
            out.println();
            for (Statistic stat : statistics.getStatistics()) {
                out.println(stat.getAbbreviation() + " = " + stat.getDescription());
            }
        }

        out.println();

        if (allStats != null) {
            int numTables = allStats.length;
            // The last column of each row is reserved for the utility.
            int numStats = allStats[0][0].length - 1;

            double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables, algorithmSimulationWrappers, numStats, statistics);
            double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);

            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }

            int[] newOrder;
            if (isSortByUtility()) {
                newOrder = sort(algorithmSimulationWrappers, utilities);
            } else {
                // Identity order.
                newOrder = new int[algorithmSimulationWrappers.size()];
                for (int q = 0; q < algorithmSimulationWrappers.size(); q++) {
                    newOrder[q] = q;
                }
            }

            out.println("Simulations:");
            out.println();

            int i = 0;
            for (SimulationWrapper simulation : simulationWrappers) {
                out.print("Simulation " + (++i) + ": ");
                out.println(simulation.getDescription());
                out.println();
                printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);
                out.println();
            }

            out.println("Algorithms:");
            out.println();

            // List each algorithm once, using the pairs for the first simulation.
            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);
                if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                    out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
                }
            }

            if (isSortByUtility()) {
                out.println();
                out.println("Sorting by utility, high to low.");
            }

            if (isShowUtilities()) {
                out.println();
                out.println("Weighting of statistics:");
                out.println();
                out.println("U = ");
                for (Statistic stat : statistics.getStatistics()) {
                    String statName = stat.getAbbreviation();
                    double weight = statistics.getWeight(stat);
                    if (weight != 0.0) {
                        out.println(" " + weight + " * f(" + statName + ")");
                    }
                }
                out.println();
                out.println("...normed to range between 0 and 1.");
                out.println();
                out.println("Note that f for each statistic is a function that maps the statistic to the ");
                out.println("interval [0, 1], with higher being better.");
            }

            out.println();
            out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");
            out.println();

            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }

            // Print all of the tables.
            printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

            statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables, algorithmSimulationWrappers, numStats, statistics);
            printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

            statTables = calcStatTables(allStats, Mode.WorstCase, numTables, algorithmSimulationWrappers, numStats, statistics);

            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }

            printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
        }
    } finally {
        out.close();
    }
}
Aggregations