use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class MGM method runTests2.
/**
 * Exercises MGM in the unpenalized setting: every entry of the lambda array is zero.
 *
 * <p>Builds the graph X1 --> X2 <-- X3, X4 --> X5, attaches a per-variable level map to it,
 * simulates 1000 rows from a Gaussian/categorical generalized SEM, fits MGM, and prints the
 * smooth (nll) and non-smooth (regularization) objective terms before and after learning,
 * followed by the learned parameters and the adjacency matrix.
 */
private static void runTests2() {
    Graph mixedGraph = GraphConverter.convert("X1-->X2,X3-->X2,X4-->X5");

    // Level count per variable; 0 presumably marks a continuous variable (TODO confirm in MixedUtils).
    HashMap<String, Integer> levels = new HashMap<>();
    levels.put("X1", 0);
    levels.put("X2", 0);
    levels.put("X3", 4);
    levels.put("X4", 4);
    levels.put("X5", 4);
    mixedGraph = MixedUtils.makeMixedGraph(mixedGraph, levels);

    GeneralizedSemPm semPm = MixedUtils.GaussianCategoricalPm(mixedGraph, "Split(-1.5,-.5,.5,1.5)");
    System.out.println(semPm);
    GeneralizedSemIm semIm = MixedUtils.GaussianCategoricalIm(semPm);
    System.out.println(semIm);

    final int sampleSize = 1000;
    DataSet data = MixedUtils.makeMixedData(semIm.simulateDataFisher(sampleSize), levels);

    // Zero penalty everywhere: this is the "non penalty" use case.
    final double penalty = 0;
    double[] penalties = {penalty, penalty, penalty};
    MGM mgm = new MGM(data, penalties);

    System.out.println("Init nll: " + mgm.smoothValue(mgm.params.toMatrix1D()));
    System.out.println("Init reg term: " + mgm.nonSmoothValue(mgm.params.toMatrix1D()));

    mgm.learn(1e-8, 1000);

    System.out.println("Learned nll: " + mgm.smoothValue(mgm.params.toMatrix1D()));
    System.out.println("Learned reg term: " + mgm.nonSmoothValue(mgm.params.toMatrix1D()));
    System.out.println("params:\n" + mgm.params);
    System.out.println("adjMat:\n" + mgm.adjMatFromMGM());
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class GeneralBootstrapSearchAction method compute.
/**
 * Performs this task's share of the bootstrap work.
 *
 * <p>A work load of 2 or more is split into two subtasks: the first covers {@code workLoad / 2}
 * units starting at {@code dataSetId}; the second covers the remainder, starting right after
 * the first. Both are run through {@code invokeAll}. A smaller work load is handled directly:
 * draw a bootstrap resample with as many rows as the search's data, run the algorithm on it,
 * and pass the resulting graph to {@code generalBootstrapSearch.addPAG}.
 */
@Override
public void compute() {
    if (workLoad >= 2) {
        List<GeneralBootstrapSearchAction> subtasks = new ArrayList<>();
        subtasks.add(new GeneralBootstrapSearchAction(dataSetId, workLoad / 2, algorithm, parameters, generalBootstrapSearch, verbose));
        subtasks.add(new GeneralBootstrapSearchAction(dataSetId + workLoad / 2, workLoad - workLoad / 2, algorithm, parameters, generalBootstrapSearch, verbose));
        invokeAll(subtasks);
        return;
    }

    long startMillis = System.currentTimeMillis();
    if (verbose) {
        out.println("thread started ... ");
    }

    DataSet source = generalBootstrapSearch.getData();
    Graph result = algorithm.search(DataUtils.getBootstrapSample(source, source.getNumRows()), parameters);

    // Measured before the verbose print so logging time is not counted.
    long elapsedMillis = System.currentTimeMillis() - startMillis;
    if (verbose) {
        out.println("processing time of bootstrap for thread id : " + dataSetId + " was: " + elapsedMillis / 1000.0 + " sec");
    }

    generalBootstrapSearch.addPAG(result);
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class ADTreeTest method main.
/**
 * Small manual benchmark for the AD-tree.
 *
 * <p>Simulates 500 rows from a random Bayes net over 40 variables and copies the values into a
 * {@code DataTableImpl}. It then builds an {@code ADTree} over that table and times two count
 * queries. Timings use {@link System#nanoTime()}, which is monotonic and intended for elapsed
 * time, unlike the wall-clock {@code currentTimeMillis()}.
 *
 * @param args ignored
 * @throws Exception declared so any checked exception raised by the demo propagates
 */
public static void main(String[] args) throws Exception {
    int columns = 40;
    int numEdges = 40;
    int rows = 500;

    List<Node> variables = new ArrayList<>(columns);
    for (int i = 0; i < columns; i++) {
        variables.add(new ContinuousVariable("X" + (i + 1)));
    }

    Graph graph = GraphUtils.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true);
    BayesPm pm = new BayesPm(graph);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(rows, false);

    // This implementation uses a DataTable to represent the data.
    // The first type parameter is the type for the variables;
    // the second type parameter is the type for the values of the variables.
    DataTableImpl<Node, Short> dataTable = new DataTableImpl<>(variables);
    for (int i = 0; i < rows; i++) {
        ArrayList<Short> row = new ArrayList<>(columns);
        for (int j = 0; j < columns; j++) {
            row.add((short) data.getInt(i, j));
        }
        dataTable.addRow(row);
    }

    // Create the tree.
    long start = System.nanoTime();
    ADTree<Node, Short> adTree = new ADTree<>(dataTable);
    System.out.println(String.format("Generated tree in %s millis", elapsedMillis(start)));

    // The query is an arbitrary map of vars and their values.
    TreeMap<Node, Short> query = new TreeMap<>();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    start = System.nanoTime();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", elapsedMillis(start)));

    query.clear();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X2"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    query.put(node(pm, "X10"), (short) 1);
    start = System.nanoTime();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", elapsedMillis(start)));
}

/** Returns whole milliseconds elapsed since {@code startNanos}, a value read from System.nanoTime(). */
private static long elapsedMillis(long startNanos) {
    return (System.nanoTime() - startNanos) / 1_000_000L;
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class HistogramAction method actionPerformed.
/**
 * Opens histogram windows for the data set currently selected in {@code dataEditor}.
 *
 * <p>One window is opened per selected column. If more than 10 columns are selected, the user
 * is asked to confirm first, and anything other than an explicit "Yes" cancels. With no column
 * selection, a single window is opened from {@code createHistogramPanel(null)}.
 *
 * @param e the triggering action event (not used)
 */
public void actionPerformed(ActionEvent e) {
    Object selectedModel = dataEditor.getSelectedDataModel();

    // getSelectedDataModel() is not typed as DataSet (hence the cast). Check the type first so a
    // non-tabular model produces a message instead of a ClassCastException.
    if (selectedModel != null && !(selectedModel instanceof DataSet)) {
        JOptionPane.showMessageDialog(findOwner(), "Histograms can only be displayed for tabular data sets.");
        return;
    }

    DataSet dataSet = (DataSet) selectedModel;
    if (dataSet == null || dataSet.getNumColumns() == 0) {
        JOptionPane.showMessageDialog(findOwner(), "Cannot display a histogram for an empty data set.");
        return;
    }

    int[] selected = dataSet.getSelectedIndices();

    if (selected == null || selected.length == 0) {
        // No selected variable: createHistogramPanel(null) presumably falls back to the first
        // variable (TODO confirm in createHistogramPanel).
        JPanel component = createHistogramPanel(null);
        EditorWindow editorWindow = new EditorWindow(component, "Histogram", "Close", false, dataEditor);
        DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER);
        editorWindow.pack();
        editorWindow.setVisible(true);
        return;
    }

    // If more than one column is selected, open one histogram per column.
    // Warn the user if they selected more than 10.
    if (selected.length > 10) {
        int option = JOptionPane.showConfirmDialog(findOwner(), "You are about to open " + selected.length + " histograms, are you sure you want to proceed?", "Histogram Warning", JOptionPane.YES_NO_OPTION);
        // Only an explicit "Yes" proceeds. Closing the dialog returns CLOSED_OPTION, which must
        // cancel rather than open every histogram.
        if (option != JOptionPane.YES_OPTION) {
            return;
        }
    }

    for (int index : selected) {
        JPanel component = createHistogramPanel(dataSet.getVariable(index));
        EditorWindow editorWindow = new EditorWindow(component, "Histogram", "Close", false, dataEditor);
        DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER);
        editorWindow.pack();
        setLocation(editorWindow, index);
        editorWindow.setVisible(true);
    }
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class GdistanceApply method main.
/**
 * Computes Gdistance between two motion-corrected fMRI graphs and writes the result to disk.
 *
 * <p>Steps:
 * <ol>
 *   <li>Loads both graphs from {@code Motion_Corrected_Graphs/} and removes the nodes
 *       Motion_1 through Motion_6 from each.</li>
 *   <li>Reads the comma-delimited location map {@code coords.txt}, resolved relative to the
 *       working directory.</li>
 *   <li>Runs {@code Gdistance} with the xdist/ydist/zdist values below and prints the
 *       distances.</li>
 *   <li>Writes the distances to {@code Gdistances.txt} in UTF-8.</li>
 * </ol>
 *
 * <p>Elapsed times are printed in fractional seconds. Failures while reading the map,
 * computing, or writing are reported with a stack trace.
 *
 * @param args ignored
 */
public static void main(String... args) {
    double xdist = 2.4;
    double ydist = 2.4;
    double zdist = 2;

    long timestart = System.nanoTime();
    System.out.println("Loading first graph");
    Graph graph1 = GraphUtils.loadGraphTxt(new File("Motion_Corrected_Graphs/singlesub_motion_graph_025_04.txt"));
    long timegraph1 = System.nanoTime();
    System.out.println("Done loading first graph. Elapsed time: " + secondsBetween(timestart, timegraph1) + "s");

    System.out.println("Loading second graph");
    Graph graph2 = GraphUtils.loadGraphTxt(new File("Motion_Corrected_Graphs/singlesub_motion_graph_027_04.txt"));
    long timegraph2 = System.nanoTime();
    System.out.println("Done loading second graph. Elapsed time: " + secondsBetween(timegraph1, timegraph2) + "s");

    // These steps are specific to the motion-corrected fMRI graphs.
    removeMotionNodes(graph1);
    removeMotionNodes(graph2);

    // Load the location map.
    String workingDirectory = System.getProperty("user.dir");
    System.out.println(workingDirectory);
    Path mapPath = Paths.get("coords.txt");
    System.out.println(mapPath);
    TabularDataReader dataReaderMap = new ContinuousTabularDataFileReader(mapPath.toFile(), Delimiter.COMMA);
    try {
        DataSet locationMap = (DataSet) DataConvertUtils.toDataModel(dataReaderMap.readInData());
        long timegraph3 = System.nanoTime();
        System.out.println("Done loading location map. Elapsed time: " + secondsBetween(timegraph2, timegraph3) + "s");

        System.out.println("Running Gdistance");
        Gdistance gdist = new Gdistance(locationMap, xdist, ydist, zdist);
        List<Double> distance = gdist.distances(graph1, graph2);
        System.out.println(distance);
        System.out.println("Done running Distance. Elapsed time: " + secondsBetween(timegraph3, System.nanoTime()) + "s");
        System.out.println("Total elapsed time: " + secondsBetween(timestart, System.nanoTime()) + "s");

        // try-with-resources closes the writer even if println fails.
        try (PrintWriter writer = new PrintWriter("Gdistances.txt", "UTF-8")) {
            writer.println(distance);
        }
    } catch (Exception e) {
        // Covers I/O failures as well as errors from data conversion or Gdistance.
        e.printStackTrace();
    }
}

/** Removes the nodes named Motion_1 through Motion_6 from {@code graph}. */
private static void removeMotionNodes(Graph graph) {
    for (int i = 1; i <= 6; i++) {
        graph.removeNode(graph.getNode("Motion_" + i));
    }
}

/** Converts a System.nanoTime() interval to fractional seconds (no integer truncation). */
private static double secondsBetween(long startNanos, long endNanos) {
    return (endNanos - startNanos) / 1e9;
}
Aggregations