Usage of edu.cmu.tetrad.data.DiscreteVariable in the tetrad project (cmu-phil).
Class TestGeneralBootstrapTest, method makeDiscreteDAG:
/**
 * Builds a random graph over {@code numVars} discrete variables named "0", "1", ..., "numVars-1".
 *
 * <p>The edge budget is {@code numVars * edgesPerNode}, truncated to an int. The remaining
 * arguments to {@code GraphUtils.randomGraph} (30, 15, 15, false) are fixed generator settings
 * passed through unchanged; their meaning is defined by that helper.
 *
 * @param numVars              number of discrete variables to create
 * @param numLatentConfounders passed through to {@code GraphUtils.randomGraph}
 * @param edgesPerNode         average number of edges per variable
 * @return the generated graph
 */
private static Graph makeDiscreteDAG(int numVars, int numLatentConfounders, double edgesPerNode) {
    int edgeCount = (int) (numVars * edgesPerNode);

    List<Node> nodes = new ArrayList<>();
    int index = 0;
    while (index < numVars) {
        nodes.add(new DiscreteVariable(String.valueOf(index)));
        index++;
    }

    return GraphUtils.randomGraph(nodes, numLatentConfounders, edgeCount, 30, 15, 15, false);
}
Usage of edu.cmu.tetrad.data.DiscreteVariable in the tetrad project (cmu-phil).
Class BayesUpdaterClassifierEditor, method showRocCurve:
/**
 * Rebuilds the "ROC Plot" tab for the category currently selected in the category dropdown.
 *
 * <p>Any existing "ROC Plot" tab is removed first, and the ROC save action is disabled. The new
 * plot is inserted where the removed tab was, or appended if there was none. No plot is added
 * when any of these holds:
 * <ul>
 *   <li>the target variable is absent from the test data;</li>
 *   <li>the target variable is not discrete in the test data;</li>
 *   <li>the selected category does not map to a row of the classifier's marginals.</li>
 * </ul>
 */
private void showRocCurve() {
    int tabIndex = -1;

    // Walk backwards so removing a tab does not shift indices still to be visited; a forward
    // walk with remove(i) skips the tab that slides into position i, leaving a stale duplicate.
    for (int i = getTabbedPane().getTabCount() - 1; i >= 0; i--) {
        if ("ROC Plot".equals(getTabbedPane().getTitleAt(i))) {
            getTabbedPane().remove(i);
            tabIndex = i;
            this.rocPlot = null;
            this.saveRoc.setEnabled(false);
        }
    }

    double[][] marginals = getClassifier().getMarginals();
    int ncases = getClassifier().getNumCases();
    DataSet testData = getClassifier().getTestData();

    DiscreteVariable targetVariable = classifier.getTargetVariable();
    String targetName = targetVariable.getName();
    Node variable2 = testData.getVariable(targetName);
    int varIndex = testData.getVariables().indexOf(variable2);

    // If the target is not in the data set, or is not discrete there, don't compute a ROC plot
    // (the cast below would otherwise throw ClassCastException).
    if (varIndex == -1 || !(variable2 instanceof DiscreteVariable)) {
        return;
    }

    String category = (String) getCategoryDropdown().getSelectedItem();
    DiscreteVariable variable = (DiscreteVariable) variable2;
    int catIndex = variable.getIndex(category);

    // marginals[catIndex] is read below; an index outside its rows (e.g. an unresolved category)
    // would otherwise fail with ArrayIndexOutOfBoundsException.
    if (catIndex < 0 || catIndex >= marginals.length) {
        return;
    }

    // Mark which test cases actually belong to the selected category.
    boolean[] inCategory = new boolean[ncases];
    for (int i = 0; i < inCategory.length; i++) {
        inCategory[i] = (testData.getInt(i, varIndex) == catIndex);
    }

    double[] scores = marginals[catIndex];
    RocCalculator rocc = new RocCalculator(scores, inCategory, RocCalculator.ASCENDING);
    double[][] points = rocc.getScaledRocPlot();
    double area = rocc.getAuc();

    NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
    String info = "AUC = " + nf.format(area);
    String title = "ROC Plot, " + targetVariable + " = " + category;

    RocPlot plot = new RocPlot(points, title, info);
    this.rocPlot = plot;
    this.saveRoc.setEnabled(true);

    if (tabIndex == -1) {
        getTabbedPane().add("ROC Plot", plot);
    } else {
        getTabbedPane().add(plot, tabIndex);
        getTabbedPane().setTitleAt(tabIndex, "ROC Plot");
    }
}
Usage of edu.cmu.tetrad.data.DiscreteVariable in the tetrad project (cmu-phil).
Class BayesUpdaterClassifierEditor, method showClassification:
/**
 * Rebuilds the "Classification" tab as a table with one row per classified case.
 *
 * <p>Column 0 ("Result") holds the value from {@code getClassifications()} for that case. Column
 * {@code k + 1} holds {@code marginals[k][case]}, labelled "P(target=k)". Any existing
 * "Classification" tab is removed first. The new table is inserted where the removed tab was,
 * or appended if there was none.
 */
private void showClassification() {
    int tabIndex = -1;

    // Walk backwards so removing a tab does not shift indices still to be visited; a forward
    // walk with remove(i) skips the tab that slides into position i, leaving a stale duplicate.
    for (int i = getTabbedPane().getTabCount() - 1; i >= 0; i--) {
        if ("Classification".equals(getTabbedPane().getTitleAt(i))) {
            getTabbedPane().remove(i);
            tabIndex = i;
        }
    }

    // Put the class information into a DataSet.
    int[] classifications = getClassifier().getClassifications();
    double[][] marginals = getClassifier().getMarginals();

    // The result column needs enough categories to hold the largest classification value.
    int maxCategory = 0;
    for (int classification : classifications) {
        if (classification > maxCategory) {
            maxCategory = classification;
        }
    }

    List<Node> variables = new LinkedList<>();
    DiscreteVariable targetVariable = classifier.getTargetVariable();
    DiscreteVariable classVar = new DiscreteVariable(targetVariable.getName(), maxCategory + 1);
    variables.add(classVar);

    // One continuous score column per row of marginals.
    for (int i = 0; i < marginals.length; i++) {
        String name = "P(" + targetVariable + "=" + i + ")";
        ContinuousVariable scoreVar = new ContinuousVariable(name);
        variables.add(scoreVar);
    }

    classVar.setName("Result");

    DataSet dataSet = new ColtDataSet(classifications.length, variables);
    for (int i = 0; i < classifications.length; i++) {
        dataSet.setInt(i, 0, classifications[i]);
        for (int j = 0; j < marginals.length; j++) {
            dataSet.setDouble(i, j + 1, marginals[j][i]);
        }
    }

    DataDisplay jTable = new DataDisplay(dataSet);
    JScrollPane scroll = new JScrollPane(jTable);

    if (tabIndex == -1) {
        getTabbedPane().add("Classification", scroll);
    } else {
        getTabbedPane().add(scroll, tabIndex);
        getTabbedPane().setTitleAt(tabIndex, "Classification");
    }
}
Usage of edu.cmu.tetrad.data.DiscreteVariable in the tetrad project (cmu-phil).
Class ModeInterpolator, method filter:
/**
 * Returns a copy of {@code dataSet} with missing values filled in column by column.
 *
 * <p>Discrete columns: each {@code DiscreteVariable.MISSING_VALUE} cell is replaced by the most
 * frequent category index among non-missing cells. Ties go to the lowest index, and an
 * all-missing column is filled with index 0.
 *
 * <p>Continuous columns: each NaN cell is replaced by the median of the non-NaN cells. An
 * all-NaN column stays NaN.
 *
 * <p>Columns of any other variable type are copied unchanged. The input data set is not modified.
 *
 * @param dataSet the data set to impute
 * @return an imputed copy of {@code dataSet}
 */
public DataSet filter(DataSet dataSet) {
    DataSet newDataSet = dataSet.copy();
    int numRows = dataSet.getNumRows();

    for (int j = 0; j < dataSet.getNumColumns(); j++) {
        Node var = dataSet.getVariable(j);

        if (var instanceof DiscreteVariable) {
            DiscreteVariable variable = (DiscreteVariable) var;
            int numCategories = variable.getNumCategories();

            // Tally each non-missing category value in this column.
            int[] categoryCounts = new int[numCategories];
            for (int i = 0; i < numRows; i++) {
                int value = dataSet.getInt(i, j);
                if (value != DiscreteVariable.MISSING_VALUE) {
                    categoryCounts[value]++;
                }
            }

            // Strict '>' keeps the first (lowest-index) category on ties.
            int mode = -1;
            int max = -1;
            for (int c = 0; c < numCategories; c++) {
                if (categoryCounts[c] > max) {
                    max = categoryCounts[c];
                    mode = c;
                }
            }

            for (int i = 0; i < numRows; i++) {
                if (dataSet.getInt(i, j) == DiscreteVariable.MISSING_VALUE) {
                    newDataSet.setInt(i, j, mode);
                }
            }
        } else if (var instanceof ContinuousVariable) {
            // Collect the non-NaN values into the prefix data[0 .. count-1].
            double[] data = new double[numRows];
            int count = 0;
            for (int i = 0; i < numRows; i++) {
                double value = dataSet.getDouble(i, j);
                if (!Double.isNaN(value)) {
                    data[count++] = value;
                }
            }

            // Sort only the filled prefix. The unused tail of the buffer holds zeros, which must
            // not take part in the median.
            Arrays.sort(data, 0, count);

            double median = Double.NaN;
            if (count > 0) {
                // For odd counts both indices coincide; for even counts this averages the two
                // middle values.
                median = (data[count / 2] + data[(count - 1) / 2]) / 2.d;
            }

            for (int i = 0; i < numRows; i++) {
                if (Double.isNaN(dataSet.getDouble(i, j))) {
                    newDataSet.setDouble(i, j, median);
                }
            }
        }
    }

    return newDataSet;
}
Usage of edu.cmu.tetrad.data.DiscreteVariable in the tetrad project (cmu-phil).
Class BayesPm, method setNumCategories:
/**
 * Resizes the category list of the variable mapped to {@code node} to {@code numCategories}.
 *
 * <p>Existing category names are kept, in order, up to the new size. Any additional slots are
 * filled with {@code DataUtils.defaultCategory(index)` names. The resulting list is handed to
 * {@code mapNodeToVariable}.
 *
 * @param node          a node already present in this BayesPm
 * @param numCategories the new number of categories; must be at least 1
 * @throws IllegalArgumentException if {@code node} is not in this BayesPm, if
 *         {@code numCategories < 1}, or if a generated default name already appears among the
 *         kept categories
 */
public void setNumCategories(Node node, int numCategories) {
    if (!nodesToVariables.containsKey(node)) {
        throw new IllegalArgumentException("Node not in BayesPm: " + node);
    }
    if (numCategories < 1) {
        throw new IllegalArgumentException("Number of categories must be >= 1: " + numCategories);
    }

    List<String> existing = nodesToVariables.get(node).getCategories();
    int kept = Math.min(numCategories, existing.size());

    // Retain the leading names that still fit, then pad with default names.
    List<String> categories = new LinkedList<>(existing.subList(0, kept));
    for (int index = kept; index < numCategories; index++) {
        String defaultName = DataUtils.defaultCategory(index);
        if (categories.contains(defaultName)) {
            throw new IllegalArgumentException("Default name already in " + "list of categories: " + defaultName);
        }
        categories.add(defaultName);
    }

    mapNodeToVariable(node, categories);
}
Aggregations