Search in sources :

Example 46 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class TestGeneralBootstrapTest method makeDiscreteDAG.

/**
 * Creates a random graph over freshly created discrete variables.
 *
 * <p>The variables are named "0", "1", ..., one per index below {@code numVars}. The number of
 * edges requested is {@code numVars * edgesPerNode}, truncated toward zero.
 *
 * @param numVars              how many discrete variables to create
 * @param numLatentConfounders passed through unchanged to {@code GraphUtils.randomGraph}
 * @param edgesPerNode         multiplier applied to {@code numVars} to obtain the edge count
 * @return whatever {@code GraphUtils.randomGraph} produces for these arguments
 */
private static Graph makeDiscreteDAG(int numVars, int numLatentConfounders, double edgesPerNode) {
    final List<Node> nodes = new ArrayList<>();
    int index = 0;
    while (index < numVars) {
        nodes.add(new DiscreteVariable(Integer.toString(index)));
        index++;
    }

    // The cast truncates, so fractional edge counts are rounded toward zero.
    final int edgeCount = (int) (numVars * edgesPerNode);

    // NOTE(review): the literals 30, 15, 15, false are presumably degree limits and a
    // "connected" flag - confirm against the GraphUtils.randomGraph signature.
    return GraphUtils.randomGraph(nodes, numLatentConfounders, edgeCount, 30, 15, 15, false);
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList)

Example 47 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class BayesUpdaterClassifierEditor method showRocCurve.

/**
 * Recomputes the ROC plot for the category currently selected in the category dropdown and
 * shows it in a tab titled "ROC Plot".
 *
 * <p>An existing "ROC Plot" tab is removed first; the new plot is inserted at the same position
 * if there was one, otherwise it is appended. If the classifier's target variable cannot be
 * found in the test data, the method returns after removing the old tab and adds nothing.
 */
private void showRocCurve() {
    // Remove any stale plot tab and remember where it was so the new one can take its place.
    int previousPosition = -1;
    int tab = 0;
    while (tab < getTabbedPane().getTabCount()) {
        if ("ROC Plot".equals(getTabbedPane().getTitleAt(tab))) {
            getTabbedPane().remove(tab);
            previousPosition = tab;
            this.rocPlot = null;
            this.saveRoc.setEnabled(false);
        }
        tab++;
    }

    final double[][] marginals = getClassifier().getMarginals();
    final boolean[] inCategory = new boolean[getClassifier().getNumCases()];
    final DataSet testData = getClassifier().getTestData();

    // Locate the target variable's column in the test data by name.
    final String targetName = classifier.getTargetVariable().getName();
    final Node targetInData = testData.getVariable(targetName);
    final int targetColumn = testData.getVariables().indexOf(targetInData);
    if (targetColumn == -1) {
        // Without the target column there is nothing to compare the scores against.
        return;
    }

    final String category = (String) getCategoryDropdown().getSelectedItem();
    final int catIndex = ((DiscreteVariable) targetInData).getIndex(category);

    // Flag each case whose recorded target value is the selected category.
    for (int row = 0; row < inCategory.length; row++) {
        inCategory[row] = catIndex == testData.getInt(row, targetColumn);
    }

    // The scores are the marginals row belonging to the selected category.
    final RocCalculator calculator =
            new RocCalculator(marginals[catIndex], inCategory, RocCalculator.ASCENDING);
    final double[][] points = calculator.getScaledRocPlot();
    final double area = calculator.getAuc();

    final NumberFormat format = NumberFormatUtil.getInstance().getNumberFormat();
    final String info = "AUC = " + format.format(area);
    final String title = "ROC Plot, " + classifier.getTargetVariable() + " = " + category;

    final RocPlot plot = new RocPlot(points, title, info);
    this.rocPlot = plot;
    this.saveRoc.setEnabled(true);

    if (previousPosition != -1) {
        getTabbedPane().add(plot, previousPosition);
        getTabbedPane().setTitleAt(previousPosition, "ROC Plot");
    } else {
        getTabbedPane().add("ROC Plot", plot);
    }
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) RocCalculator(edu.cmu.tetrad.util.RocCalculator) NumberFormat(java.text.NumberFormat)

Example 48 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class BayesUpdaterClassifierEditor method showClassification.

/**
 * Rebuilds the "Classification" tab from the classifier's current results.
 *
 * <p>The displayed data set has one row per classification. Column 0 is a discrete variable
 * named "Result" holding the classifications; each following column holds one row of the
 * classifier's marginals and is named {@code P(<target>=<index>)}. An existing "Classification"
 * tab is replaced in place; otherwise the new tab is appended.
 */
private void showClassification() {
    // Remove any stale tab and remember where it was so the new one can take its place.
    int previousPosition = -1;
    int tab = 0;
    while (tab < getTabbedPane().getTabCount()) {
        if ("Classification".equals(getTabbedPane().getTitleAt(tab))) {
            getTabbedPane().remove(tab);
            previousPosition = tab;
        }
        tab++;
    }

    final int[] classifications = getClassifier().getClassifications();
    final double[][] marginals = getClassifier().getMarginals();

    // The result variable needs enough categories to hold the largest classification seen.
    int largest = 0;
    for (int value : classifications) {
        largest = Math.max(largest, value);
    }

    final List<Node> columns = new LinkedList<>();
    final DiscreteVariable targetVariable = classifier.getTargetVariable();
    final DiscreteVariable resultVar = new DiscreteVariable(targetVariable.getName(), largest + 1);
    columns.add(resultVar);

    // One continuous score column per marginals row.
    for (int category = 0; category < marginals.length; category++) {
        columns.add(new ContinuousVariable("P(" + targetVariable + "=" + category + ")"));
    }
    resultVar.setName("Result");

    final DataSet table = new ColtDataSet(classifications.length, columns);
    for (int row = 0; row < classifications.length; row++) {
        table.setInt(row, 0, classifications[row]);
        // marginals is indexed [category][case], so it is transposed into the table here.
        for (int category = 0; category < marginals.length; category++) {
            table.setDouble(row, category + 1, marginals[category][row]);
        }
    }

    final JScrollPane scroll = new JScrollPane(new DataDisplay(table));
    if (previousPosition != -1) {
        getTabbedPane().add(scroll, previousPosition);
        getTabbedPane().setTitleAt(previousPosition, "Classification");
    } else {
        getTabbedPane().add("Classification", scroll);
    }
}
Also used : ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) LinkedList(java.util.LinkedList) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet)

Example 49 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class ModeInterpolator method filter.

/**
 * Returns a copy of the given data set in which missing values have been filled in column by
 * column. The input data set is only read, never modified.
 *
 * <ul>
 *   <li>Discrete columns: cells equal to {@code DiscreteVariable.MISSING_VALUE} are replaced by
 *       the most frequent observed category; ties go to the lowest category index. If no value
 *       is observed at all, category 0 is used when the variable has at least one category.</li>
 *   <li>Continuous columns: NaN cells are replaced by the median of the observed (non-NaN)
 *       values. If every value is NaN the column is left as NaN.</li>
 *   <li>Columns of any other variable type are copied unchanged.</li>
 * </ul>
 *
 * @param dataSet the data set to read from
 * @return a new data set with the missing cells replaced as described above
 */
public DataSet filter(DataSet dataSet) {
    DataSet newDataSet = dataSet.copy();
    final int numRows = dataSet.getNumRows();

    for (int j = 0; j < dataSet.getNumColumns(); j++) {
        Node var = dataSet.getVariable(j);

        if (var instanceof DiscreteVariable) {
            DiscreteVariable variable = (DiscreteVariable) var;
            int numCategories = variable.getNumCategories();

            // Tally how often each category occurs among the observed cells.
            int[] categoryCounts = new int[numCategories];
            for (int i = 0; i < numRows; i++) {
                int value = dataSet.getInt(i, j);
                if (value != DiscreteVariable.MISSING_VALUE) {
                    categoryCounts[value]++;
                }
            }

            // Strict '>' keeps the first (lowest-index) category on ties.
            int mode = -1;
            int max = -1;
            for (int i = 0; i < numCategories; i++) {
                if (categoryCounts[i] > max) {
                    max = categoryCounts[i];
                    mode = i;
                }
            }

            for (int i = 0; i < numRows; i++) {
                if (dataSet.getInt(i, j) == DiscreteVariable.MISSING_VALUE) {
                    newDataSet.setInt(i, j, mode);
                }
            }
        } else if (var instanceof ContinuousVariable) {
            // Pack the observed values into the front of the buffer.
            double[] observed = new double[numRows];
            int count = 0;
            for (int i = 0; i < numRows; i++) {
                double value = dataSet.getDouble(i, j);
                if (!Double.isNaN(value)) {
                    observed[count++] = value;
                }
            }

            // Sort only the filled prefix. Sorting the whole buffer would mix the unused
            // trailing 0.0 slots into the observed values and corrupt the median.
            Arrays.sort(observed, 0, count);

            // Median: the middle element, or the mean of the two middle elements.
            double median = Double.NaN;
            if (count > 0) {
                median = (observed[(count - 1) / 2] + observed[count / 2]) / 2.d;
            }

            for (int i = 0; i < numRows; i++) {
                if (Double.isNaN(dataSet.getDouble(i, j))) {
                    newDataSet.setDouble(i, j, median);
                }
            }
        }
    }

    return newDataSet;
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node)

Example 50 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class BayesPm method setNumCategories.

/**
 * Changes how many categories the variable for the given node has.
 *
 * <p>The leading existing category names are kept, up to {@code numCategories} of them. If more
 * categories are requested than currently exist, the remainder are named with
 * {@code DataUtils.defaultCategory(index)}.
 *
 * @param node          a node of this BayesPm
 * @param numCategories the desired number of categories; must be at least 1
 * @throws IllegalArgumentException if the node is not in this BayesPm, if
 *                                  {@code numCategories} is less than 1, or if a generated
 *                                  default name is already among the categories
 */
public void setNumCategories(Node node, int numCategories) {
    if (!nodesToVariables.containsKey(node)) {
        throw new IllegalArgumentException("Node not in BayesPm: " + node);
    }
    if (numCategories <= 0) {
        throw new IllegalArgumentException("Number of categories must be >= 1: " + numCategories);
    }

    final List<String> existing = nodesToVariables.get(node).getCategories();
    final int kept = Math.min(numCategories, existing.size());

    // Start from the surviving prefix of the current category names.
    final List<String> categories = new LinkedList<>();
    categories.addAll(existing.subList(0, kept));

    // Top up with default names for any additional categories.
    int next = kept;
    while (next < numCategories) {
        final String candidate = DataUtils.defaultCategory(next);
        if (categories.contains(candidate)) {
            throw new IllegalArgumentException("Default name already in list of categories: " + candidate);
        }
        categories.add(candidate);
        next++;
    }

    mapNodeToVariable(node, categories);
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)

Aggregations

DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 56, Node (edu.cmu.tetrad.graph.Node): 37, DataSet (edu.cmu.tetrad.data.DataSet): 18, ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 16, ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 11, LinkedList (java.util.LinkedList): 9, Test (org.junit.Test): 5, ArrayList (java.util.ArrayList): 4, Dag (edu.cmu.tetrad.graph.Dag): 3, NumberFormat (java.text.NumberFormat): 3, Element (nu.xom.Element): 3, EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 2, Graph (edu.cmu.tetrad.graph.Graph): 2, LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression): 2, List (java.util.List): 2, Elements (nu.xom.Elements): 2, DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D): 1, TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph): 1, StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs): 1, BoxDataSet (edu.cmu.tetrad.data.BoxDataSet): 1