Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
Class Mgm, method search.
/**
 * Runs MGM on {@code ds}, or a bootstrapped MGM search when the "bootstrapSampleSize"
 * parameter is at least 1.
 *
 * @param ds         the data; must contain at least one continuous and one discrete variable
 * @param parameters supplies "mgmParam1".."mgmParam3" (the three MGM lambda values),
 *                   "bootstrapSampleSize", "bootstrapEnsemble" and "verbose"
 * @return the graph produced by the (possibly bootstrapped) search
 * @throws IllegalArgumentException if {@code ds} lacks a continuous or a discrete variable
 */
@Override
public Graph search(DataModel ds, Parameters parameters) {
    // MGM works on mixed data, so both variable types must be present.
    boolean sawContinuous = false;
    boolean sawDiscrete = false;
    List<Node> nodes = ds.getVariables();
    for (Node candidate : nodes) {
        sawContinuous |= candidate instanceof ContinuousVariable;
        sawDiscrete |= candidate instanceof DiscreteVariable;
    }
    if (!(sawContinuous && sawDiscrete)) {
        throw new IllegalArgumentException("You need at least one continuous and one discrete variable to run MGM.");
    }

    int sampleSize = parameters.getInt("bootstrapSampleSize");
    if (sampleSize < 1) {
        // Single MGM run using the three user-supplied lambda values.
        DataSet mixedData = DataUtils.getMixedDataSet(ds);
        double[] lambda = new double[] {
            parameters.getDouble("mgmParam1"),
            parameters.getDouble("mgmParam2"),
            parameters.getDouble("mgmParam3")
        };
        return new MGM(mixedData, lambda).search();
    }

    // Bootstrapped MGM: resample the data and combine the resulting edges.
    Mgm innerAlgorithm = new Mgm();
    DataSet dataSet = (DataSet) ds;
    GeneralBootstrapTest bootstrap = new GeneralBootstrapTest(dataSet, innerAlgorithm, sampleSize);

    int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
    BootstrapEdgeEnsemble edgeEnsemble;
    if (ensembleCode == 0) {
        edgeEnsemble = BootstrapEdgeEnsemble.Preserved;
    } else if (ensembleCode == 2) {
        edgeEnsemble = BootstrapEdgeEnsemble.Majority;
    } else {
        // Code 1 and any unrecognized code use Highest.
        edgeEnsemble = BootstrapEdgeEnsemble.Highest;
    }

    bootstrap.setEdgeEnsemble(edgeEnsemble);
    bootstrap.setParameters(parameters);
    bootstrap.setVerbose(parameters.getBoolean("verbose"));
    return bootstrap.search();
}
Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
Class IndTestMultinomialLogisticRegression, method expandVariable.
/**
 * Returns the variables that represent {@code node} in the regression.
 *
 * <p>Continuous variables, and discrete variables with fewer than three categories, are returned
 * unchanged as a singleton list. A discrete variable with three or more categories is expanded
 * into one binary (0/1) indicator column per category except the first. The first category gets
 * no column and acts as the baseline. Each indicator column is appended to {@code dataSet}.
 *
 * @param dataSet the data set that indicator columns are added to (mutated)
 * @param node    the variable to expand
 * @return the variables standing in for {@code node}
 * @throws IllegalArgumentException if {@code node} is neither continuous nor discrete
 */
private List<Node> expandVariable(DataSet dataSet, Node node) {
    if (node instanceof ContinuousVariable) {
        return Collections.singletonList(node);
    }

    if (node instanceof DiscreteVariable && ((DiscreteVariable) node).getNumCategories() < 3) {
        return Collections.singletonList(node);
    }

    if (!(node instanceof DiscreteVariable)) {
        throw new IllegalArgumentException(
                "Expected a continuous or discrete variable but got: " + node);
    }

    DiscreteVariable discrete = (DiscreteVariable) node;
    List<String> varCats = new ArrayList<>(discrete.getCategories());
    // The first category is the baseline and gets no indicator column.
    varCats.remove(0);

    List<Node> variables = new ArrayList<>();
    for (String cat : varCats) {
        // The original retried the same deterministic name, which looped forever on a collision.
        // Disambiguate with a numeric suffix instead.
        String newVarName = uniqueVariableName(dataSet, node.getName() + "MULTINOM" + "." + cat);
        Node newVar = new DiscreteVariable(newVarName, 2);
        variables.add(newVar);
        dataSet.addVariable(newVar);

        // These are constant for the whole row loop, so look them up once per category.
        int newVarIndex = dataSet.getColumn(newVar);
        int sourceColumn = dataSet.getColumn(node);
        int catIndex = discrete.getIndex(cat);
        int numCases = dataSet.getNumRows();

        for (int row = 0; row < numCases; row++) {
            // NOTE(review): assumes getObject never returns null here (e.g. for missing values) — confirm.
            Object dataCell = dataSet.getObject(row, sourceColumn);
            int dataCellIndex = discrete.getIndex(dataCell.toString());
            dataSet.setInt(row, newVarIndex, dataCellIndex == catIndex ? 1 : 0);
        }
    }
    return variables;
}

/**
 * Returns {@code baseName} if no variable in {@code dataSet} has that name. Otherwise returns
 * {@code baseName + "_" + k} for the smallest k >= 1 that is not already taken.
 */
private static String uniqueVariableName(DataSet dataSet, String baseName) {
    String candidate = baseName;
    for (int suffix = 1; dataSet.getVariable(candidate) != null; suffix++) {
        candidate = baseName + "_" + suffix;
    }
    return candidate;
}
Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
Class BayesPmEditorWizard, method copyCategories.
/**
 * Stores the category list of the variable currently selected in {@code variableChooser}
 * (as held by {@code bayesPm}) in {@code copiedCategories}.
 */
private void copyCategories() {
    Node selectedNode = (Node) variableChooser.getSelectedItem();
    copiedCategories = ((DiscreteVariable) bayesPm.getVariable(selectedNode)).getCategories();
}
Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
Class BayesUpdaterClassifierEditor, method doClassify.
/**
 * Sets the classifier's target to the variable and category selected in the dropdowns, then
 * runs the classifier.
 */
private void doClassify() {
    DiscreteVariable target = (DiscreteVariable) getVariableDropdown().getSelectedItem();
    String targetName = target.getName();
    int targetCategoryIndex = target.getIndex((String) getCategoryDropdown().getSelectedItem());

    getClassifier().setTarget(targetName, targetCategoryIndex);
    getClassifier().classify();
}
Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
Class BayesUpdaterClassifierEditor, method showConfusionMatrix.
/**
 * Rebuilds the "Confusion Matrix" tab from the classifier's current cross-tabulation.
 *
 * <p>Every existing tab titled "Confusion Matrix" is removed first. If the classifier returns no
 * cross-tabulation, the method stops there and no tab is added back. Otherwise a text table of
 * observed (rows) versus estimated (columns) target categories is shown. It goes at the position
 * of the removed tab if there was one, or is appended to the tabbed pane otherwise.
 */
private void showConfusionMatrix() {
    final String title = "Confusion Matrix";

    // Walk backwards so that removing a tab does not shift indices we have yet to visit.
    // tabIndex ends at the lowest removed position, which is still valid after all removals.
    int tabIndex = -1;
    for (int i = getTabbedPane().getTabCount() - 1; i >= 0; i--) {
        if (title.equals(getTabbedPane().getTitleAt(i))) {
            getTabbedPane().remove(i);
            tabIndex = i;
        }
    }

    int[][] crossTabs = getClassifier().crossTabulation();

    // With no cross-tabulation available, leave the confusion matrix tab removed.
    if (crossTabs == null) {
        return;
    }

    JTextArea label = new JTextArea(formatConfusionMatrix(crossTabs));
    label.setFont(new Font("SansSerif", Font.PLAIN, 14));

    JPanel panel = new JPanel();
    panel.setLayout(new BorderLayout());
    panel.setBackground(Color.WHITE);
    Box b1 = Box.createVerticalBox();
    Box b2 = Box.createHorizontalBox();
    b2.add(Box.createHorizontalStrut(5));
    b2.add(label);
    b2.add(Box.createHorizontalGlue());
    b1.add(b2);
    b1.add(Box.createVerticalGlue());
    b1.add(Box.createVerticalGlue());
    panel.add(b1, BorderLayout.CENTER);
    JScrollPane scroll = new JScrollPane(panel);

    if (tabIndex == -1) {
        getTabbedPane().add(title, scroll);
    } else {
        getTabbedPane().add(scroll, tabIndex);
        getTabbedPane().setTitleAt(tabIndex, title);
    }
}

/**
 * Formats {@code crossTabs} as a tab-separated report. Row i is the observed category i and
 * column j is the estimated category j of the classifier's target variable. The report also
 * includes case counts and the percentage correctly classified.
 */
private String formatConfusionMatrix(int[][] crossTabs) {
    DiscreteVariable targetVariable = getClassifier().getTargetVariable();
    int nvalues = targetVariable.getNumCategories();

    StringBuilder buf = new StringBuilder();
    buf.append("Total number of usable cases = ")
            .append(getClassifier().getTotalUsableCases())
            .append(" out of ")
            .append(getClassifier().getNumCases());
    buf.append("\n\nTarget Variable ").append(targetVariable);
    buf.append("\n\t\tEstimated\t");
    buf.append("\nObserved\t");

    // Header row: estimated category names, tab-separated, with no trailing tab.
    for (int j = 0; j < nvalues; j++) {
        if (j > 0) {
            buf.append('\t');
        }
        buf.append(targetVariable.getCategory(j));
    }

    // One row per observed category.
    for (int i = 0; i < nvalues; i++) {
        buf.append('\n').append(targetVariable.getCategory(i)).append('\t');
        for (int j = 0; j < nvalues; j++) {
            if (j > 0) {
                buf.append('\t');
            }
            buf.append(crossTabs[i][j]);
        }
    }

    buf.append("\n\nPercentage correctly classified: ");
    buf.append(getClassifier().getPercentCorrect());
    return buf.toString();
}
Aggregations