Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
The class BayesUpdaterClassifierEditor, method getToolbar.
/**
 * Builds the toolbar for the classifier editor: a "Target" variable dropdown, a
 * "Category for ROC" dropdown whose entries follow the selected target, and a "Classify" button.
 *
 * @return the assembled toolbar component.
 */
private Component getToolbar() {
    JButton classifyButton = new JButton("Classify");
    classifyButton.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Window owner = (Window) getTopLevelAncestor();

            // Classification and the refresh of every result view are wrapped in a
            // WatchedProcess owned by the top-level window.
            new WatchedProcess(owner) {
                public void watch() {
                    doClassify();
                    showClassification();
                    showRocCurve();
                    showConfusionMatrix();
                }
            };
        }
    });

    List<Node> nodes = getClassifier().getBayesImVars();
    Node[] variables = nodes.toArray(new Node[0]);

    this.variableDropdown = new JComboBox(variables);
    getVariableDropdown().setBackground(Color.WHITE);
    getVariableDropdown().setMaximumSize(new Dimension(200, 50));

    // Seed the category dropdown from the initially selected target. The selection is null when
    // there are no variables, so fall back to an empty category list instead of throwing an NPE.
    DiscreteVariable initialTarget = (DiscreteVariable) getVariableDropdown().getSelectedItem();
    String[] initialCategories = initialTarget == null
            ? new String[0]
            : initialTarget.getCategories().toArray(new String[0]);
    this.categoryDropdown = new JComboBox(initialCategories);
    getCategoryDropdown().setBackground(Color.WHITE);
    getCategoryDropdown().setMaximumSize(new Dimension(200, 50));

    this.variableDropdown.addItemListener(new ItemListener() {
        public void itemStateChanged(ItemEvent e) {
            // A selection change fires DESELECTED for the old item and SELECTED for the new one;
            // rebuild the category list only once, for the newly selected target.
            if (e.getStateChange() != ItemEvent.SELECTED) {
                return;
            }
            JComboBox comboBox = (JComboBox) e.getSource();
            DiscreteVariable selected = (DiscreteVariable) comboBox.getSelectedItem();
            List<String> categories = selected.getCategories();
            getCategoryDropdown().setModel(
                    new DefaultComboBoxModel<>(categories.toArray(new String[0])));
        }
    });

    this.categoryDropdown.addItemListener(new ItemListener() {
        public void itemStateChanged(ItemEvent e) {
            // Same DESELECTED/SELECTED pairing as above: redraw the ROC curve once per change.
            if (e.getStateChange() == ItemEvent.SELECTED) {
                showRocCurve();
            }
        }
    });

    Box row1 = Box.createHorizontalBox();
    row1.add(Box.createHorizontalStrut(5));
    row1.add(new JLabel("Target = "));
    row1.add(getVariableDropdown());
    row1.add(Box.createHorizontalStrut(5));
    row1.add(new JLabel("Category for ROC ="));
    row1.add(getCategoryDropdown());
    row1.add(Box.createHorizontalStrut(10));
    row1.add(classifyButton);
    row1.add(Box.createHorizontalGlue());

    Box toolbar = Box.createVerticalBox();
    toolbar.add(row1);
    toolbar.add(Box.createVerticalStrut(5));
    toolbar.setBorder(new EmptyBorder(2, 2, 2, 2));
    return toolbar;
}
Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
The class MGM, method runTests1.
/**
 * Ad-hoc smoke test and micro-benchmark for {@link MGM}.
 *
 * <p>Loads the continuous and discrete test matrices and builds an MGM model over 24 continuous
 * variables (X0..X23) and 24 discrete variables (Y0..Y23). It times two ways of refreshing a
 * matrix copy, then learns edges and prints the objective terms, parameters and adjacency matrix.
 *
 * <p>The data directory defaults to the original author's checkout. Set the system property
 * {@code mgm.testDataPath} to point it at another location.
 */
private static void runTests1() {
    try {
        String path = System.getProperty("mgm.testDataPath",
                "/Users/ajsedgewick/tetrad_master/tetrad/tetrad-lib/src/main/java/edu/pitt/csb/mgm/test_data");
        System.out.println(path);
        DoubleMatrix2D xIn = DoubleFactory2D.dense.make(
                MixedUtils.loadDelim(path, "med_test_C.txt").getDoubleData().toArray());
        DoubleMatrix2D yIn = DoubleFactory2D.dense.make(
                MixedUtils.loadDelim(path, "med_test_D.txt").getDoubleData().toArray());

        // vars[0..23] are continuous, vars[24..47] discrete; L holds 2 for each discrete
        // variable (presumably its number of categories — confirm against the MGM constructor).
        int[] L = new int[24];
        Node[] vars = new Node[48];
        for (int i = 0; i < 24; i++) {
            L[i] = 2;
            vars[i] = new ContinuousVariable("X" + i);
            vars[i + 24] = new DiscreteVariable("Y" + i);
        }

        double lam = .2;
        MGM model = new MGM(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), L,
                new double[]{lam, lam, lam});
        System.out.println("Weights: " + Arrays.toString(model.weights.toArray()));

        // Benchmark 1: copy the matrix, then assign the copy into an existing matrix.
        DoubleMatrix2D test = xIn.copy();
        DoubleMatrix2D test2;
        long t = System.currentTimeMillis();
        for (int i = 0; i < 50000; i++) {
            test2 = xIn.copy();
            test.assign(test2);
        }
        System.out.println("assign Time: " + (System.currentTimeMillis() - t));

        // Benchmark 2: copy the matrix and just rebind the reference (stops early if interrupted).
        t = System.currentTimeMillis();
        for (int i = 0; i < 50000; i++) {
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
            test2 = xIn.copy();
            test = test2;
        }
        System.out.println("equals Time: " + (System.currentTimeMillis() - t));

        System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));

        t = System.currentTimeMillis();
        model.learnEdges(700);
        System.out.println("Orig Time: " + (System.currentTimeMillis() - t));
        System.out.println("nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        System.out.println("params:\n" + model.params);
        System.out.println("adjMat:\n" + model.adjMatFromMGM());
    } catch (IOException ex) {
        ex.printStackTrace();
    }
}
Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
The class DataConvertUtils, method toMixedDataBox.
/**
 * Converts a {@link MixedTabularDataset} into a {@link BoxDataSet} backed by a
 * {@link MixedDataBox}.
 *
 * <p>One node is created per entry of {@code getMixedVarInfos()}, in that order:
 * <ul>
 *   <li>a {@link ContinuousVariable} for continuous columns;</li>
 *   <li>otherwise, a {@link DiscreteVariable} carrying the column's categories.</li>
 * </ul>
 * The dataset's continuous and discrete arrays are passed to the data box as-is.
 *
 * @param mixedTabularDataset the parsed mixed dataset to convert.
 * @return a data set over the created nodes.
 */
public static DataModel toMixedDataBox(MixedTabularDataset mixedTabularDataset) {
    int numOfRows = mixedTabularDataset.getNumOfRows();
    MixedVarInfo[] mixedVarInfos = mixedTabularDataset.getMixedVarInfos();
    double[][] continuousData = mixedTabularDataset.getContinuousData();
    int[][] discreteData = mixedTabularDataset.getDiscreteData();

    // Use a presized random-access list. The node list is handed to the data box and the data
    // set, which presumably look variables up by column index (confirm); with a LinkedList each
    // such lookup would be linear.
    List<Node> nodes = new java.util.ArrayList<>(mixedVarInfos.length);
    for (MixedVarInfo mixedVarInfo : mixedVarInfos) {
        if (mixedVarInfo.isContinuous()) {
            nodes.add(new ContinuousVariable(mixedVarInfo.getName()));
        } else {
            nodes.add(new DiscreteVariable(mixedVarInfo.getName(), mixedVarInfo.getCategories()));
        }
    }

    return new BoxDataSet(new MixedDataBox(nodes, numOfRows, continuousData, discreteData), nodes);
}
Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
The class TestLogisticRegression, method test1.
/**
 * Smoke test for logistic regression: simulate continuous data from a random SEM and discretize
 * X1 into two equal-count bins. Then regress the discretized X1 on X2..X5 and print the result.
 * Nothing is asserted; the test only checks that the pipeline runs.
 */
@Test
public void test1() {
    // Five continuous variables named X1..X5.
    List<Node> nodes = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        nodes.add(new ContinuousVariable("X" + k));
    }

    // Random DAG over those variables; simulate 1000 rows from a SEM instantiated on it.
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 5, 3, 3, 3, false));
    System.out.println(graph);
    SemIm im = new SemIm(new SemPm(graph));
    DataSet data = im.simulateDataRecursive(1000, false);

    // Only X1 is discretized (two equal-count bins); the other columns are left alone.
    Discretizer discretizer = new Discretizer(data);
    discretizer.equalCounts(data.getVariable("X1"), 2);
    DataSet discretized = discretizer.discretize();

    // Regressors are the X2..X5 nodes of the original (undiscretized) data set.
    List<Node> regressors = new ArrayList<>();
    for (int k = 2; k <= 5; k++) {
        regressors.add(data.getVariable("X" + k));
    }

    LogisticRegression regression = new LogisticRegression(discretized);
    DiscreteVariable target = (DiscreteVariable) discretized.getVariable("X1");
    regression.regress(target, regressors);
    System.out.println(regression);
}
Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
The class AdLeafTree, method getCellLeaves.
/**
 * Finds the set of indices into the leaves of the tree for the given variables.
 * Counts are the sizes of the index sets.
 *
 * <p><b>Side effect:</b> {@code A} is sorted in place by each variable's index in
 * {@code nodesHash}. The returned cells follow that sorted order, so callers can read the
 * variable order of the result from {@code A} after the call. {@code A} must therefore be
 * modifiable, and every variable in it is expected to have an entry in {@code nodesHash}.
 *
 * @param A A list of discrete variables; reordered in place.
 * @return The list of index sets of the first variable varied by the second variable,
 * and so on, to the last variable.
 */
public List<List<Integer>> getCellLeaves(List<DiscreteVariable> A) {
    Comparator<DiscreteVariable> byTreeIndex = new Comparator<DiscreteVariable>() {
        @Override
        public int compare(DiscreteVariable left, DiscreteVariable right) {
            return Integer.compare(nodesHash.get(left), nodesHash.get(right));
        }
    };
    Collections.sort(A, byTreeIndex);

    // Lazily initialize baseCase on first use with a list holding a single fresh Vary.
    if (baseCase == null) {
        Vary root = new Vary();
        this.baseCase = new ArrayList<>();
        baseCase.add(root);
    }

    // Step through the variables in sorted order. Each step replaces the current Vary list
    // with getVaries(...) for that variable's index.
    List<Vary> frontier = baseCase;
    for (int i = 0; i < A.size(); i++) {
        frontier = getVaries(frontier, nodesHash.get(A.get(i)));
    }

    // Concatenate the row lists of every final Vary, in order.
    List<List<Integer>> cells = new ArrayList<>();
    for (int i = 0; i < frontier.size(); i++) {
        cells.addAll(frontier.get(i).getRows());
    }
    return cells;
}
Aggregations