
Example 6 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class Mgm, method search.

@Override
public Graph search(DataModel ds, Parameters parameters) {
    // MGM needs at least one continuous and one discrete variable; reject data that lacks either.
    List<Node> variables = ds.getVariables();
    boolean hasContinuous = false;
    boolean hasDiscrete = false;
    for (Node node : variables) {
        if (node instanceof ContinuousVariable) {
            hasContinuous = true;
        }
        if (node instanceof DiscreteVariable) {
            hasDiscrete = true;
        }
    }
    if (!hasContinuous || !hasDiscrete) {
        throw new IllegalArgumentException("You need at least one continuous and one discrete variable to run MGM.");
    }
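    // A bootstrap sample size below 1 means no bootstrapping: run MGM once with the three lambda penalties.
    if (parameters.getInt("bootstrapSampleSize") < 1) {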
        DataSet _ds = DataUtils.getMixedDataSet(ds);
        double mgmParam1 = parameters.getDouble("mgmParam1");
        double mgmParam2 = parameters.getDouble("mgmParam2");
        double mgmParam3 = parameters.getDouble("mgmParam3");
        double[] lambda = { mgmParam1, mgmParam2, mgmParam3 };
        MGM m = new MGM(_ds, lambda);
        return m.search();
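    } else {
        // Otherwise, run MGM on bootstrap samples and combine the graphs with the chosen edge-ensemble rule.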
        Mgm algorithm = new Mgm();
        DataSet data = (DataSet) ds;
        GeneralBootstrapTest search = new GeneralBootstrapTest(data, algorithm, parameters.getInt("bootstrapSampleSize"));
        BootstrapEdgeEnsemble edgeEnsemble = BootstrapEdgeEnsemble.Highest;
        switch(parameters.getInt("bootstrapEnsemble", 1)) {
            case 0:
                edgeEnsemble = BootstrapEdgeEnsemble.Preserved;
                break;
            case 1:
                edgeEnsemble = BootstrapEdgeEnsemble.Highest;
                break;
            case 2:
                edgeEnsemble = BootstrapEdgeEnsemble.Majority;
                break;
            default:
                // Unknown codes keep the default, BootstrapEdgeEnsemble.Highest.
                break;
        }
        search.setEdgeEnsemble(edgeEnsemble);
        search.setParameters(parameters);
        search.setVerbose(parameters.getBoolean("verbose"));
        return search.search();
    }
}
Also used: GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest), BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble), DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable), MGM (edu.pitt.csb.mgm.MGM)
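The two decisions Mgm.search makes before any learning happens (is the data mixed, and which ensemble rule does the integer "bootstrapEnsemble" code select) can be sketched without Tetrad. This is a minimal sketch under stated assumptions: Kind and EnsembleChoice are hypothetical stand-ins for Tetrad's variable classes and for BootstrapEdgeEnsemble, and only the control flow of the method above is mirrored.

```java
import java.util.List;

// Hypothetical stand-in for BootstrapEdgeEnsemble (Preserved, Highest, Majority).
enum EnsembleChoice { PRESERVED, HIGHEST, MAJORITY }

class MgmPreconditions {

    // Hypothetical variable kinds standing in for ContinuousVariable / DiscreteVariable.
    enum Kind { CONTINUOUS, DISCRETE }

    // True only if at least one continuous and one discrete variable are present,
    // mirroring the check at the top of Mgm.search.
    static boolean isMixed(List<Kind> kinds) {
        return kinds.contains(Kind.CONTINUOUS) && kinds.contains(Kind.DISCRETE);
    }

    // Maps the integer ensemble code to a rule; unknown codes fall back to HIGHEST,
    // which is also the default used in Mgm.search.
    static EnsembleChoice ensembleFor(int code) {
        switch (code) {
            case 0:
                return EnsembleChoice.PRESERVED;
            case 2:
                return EnsembleChoice.MAJORITY;
            case 1:
            default:
                return EnsembleChoice.HIGHEST;
        }
    }

    public static void main(String[] args) {
        System.out.println(isMixed(List.of(Kind.CONTINUOUS, Kind.DISCRETE))); // true
        System.out.println(isMixed(List.of(Kind.DISCRETE)));                  // false
        System.out.println(ensembleFor(2));                                   // MAJORITY
        System.out.println(ensembleFor(7));                                   // HIGHEST
    }
}
```

Returning from each case, with an explicit default, avoids the fall-through risk that a missing break creates in the assignment-style switch.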

Example 7 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class IndTestMultinomialLogisticRegression, method expandVariable.

private List<Node> expandVariable(DataSet dataSet, Node node) {
    // Continuous variables and binary discrete variables are used as-is.
    if (node instanceof ContinuousVariable) {
        return Collections.singletonList(node);
    }
    if (node instanceof DiscreteVariable && ((DiscreteVariable) node).getNumCategories() < 3) {
        return Collections.singletonList(node);
    }
    if (!(node instanceof DiscreteVariable)) {
        throw new IllegalArgumentException("Unsupported variable type: " + node);
    }
    DiscreteVariable discreteNode = (DiscreteVariable) node;
    // One binary indicator per category except the first, which serves as the reference level.
    List<String> varCats = new ArrayList<>(discreteNode.getCategories());
    varCats.remove(0);
    List<Node> variables = new ArrayList<>();
    for (String cat : varCats) {
        // Append a numeric suffix until the name is not already used in the data set.
        String baseName = node.getName() + "MULTINOM" + "." + cat;
        String newVarName = baseName;
        int suffix = 1;
        while (dataSet.getVariable(newVarName) != null) {
            newVarName = baseName + "_" + suffix++;
        }
        Node newVar = new DiscreteVariable(newVarName, 2);
        variables.add(newVar);
        dataSet.addVariable(newVar);
        int newVarIndex = dataSet.getColumn(newVar);
        int nodeColumn = dataSet.getColumn(node);
        int catIndex = discreteNode.getIndex(cat);
        int numCases = dataSet.getNumRows();
        // Set the indicator to 1 where the row's category equals cat, else 0.
        for (int l = 0; l < numCases; l++) {
            Object dataCell = dataSet.getObject(l, nodeColumn);
            int dataCellIndex = discreteNode.getIndex(dataCell.toString());
            dataSet.setInt(l, newVarIndex, dataCellIndex == catIndex ? 1 : 0);
        }
    }
    return variables;
}
Also used: ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable), Node (edu.cmu.tetrad.graph.Node)
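The expansion in expandVariable is dummy (indicator) coding: a variable with k categories becomes k - 1 binary columns, with the first category as the all-zero reference. The sketch below shows the same idea on plain int arrays, assuming each row already holds a category index; DummyCoding and uniqueName are hypothetical names, and the "_1", "_2" suffix scheme for name collisions is an assumption, not Tetrad's convention.

```java
import java.util.Arrays;
import java.util.Set;

class DummyCoding {

    // Expands a column of category indices (0..numCategories-1) into
    // numCategories - 1 indicator columns; category 0 is the reference level.
    // Variables with fewer than 3 categories are returned unchanged.
    static int[][] expand(int[] column, int numCategories) {
        if (numCategories < 3) {
            return new int[][]{column.clone()};
        }
        int[][] indicators = new int[numCategories - 1][column.length];
        for (int row = 0; row < column.length; row++) {
            int value = column[row];
            if (value < 0 || value >= numCategories) {
                throw new IllegalArgumentException("Category index out of range: " + value);
            }
            if (value > 0) {
                indicators[value - 1][row] = 1; // reference-level rows stay all zero
            }
        }
        return indicators;
    }

    // Returns base, or base plus a numeric suffix, so the name is not in existing.
    static String uniqueName(String base, Set<String> existing) {
        String name = base;
        int suffix = 1;
        while (existing.contains(name)) {
            name = base + "_" + suffix++;
        }
        return name;
    }

    public static void main(String[] args) {
        int[][] indicators = expand(new int[]{0, 1, 2, 1}, 3);
        System.out.println(Arrays.deepToString(indicators)); // [[0, 1, 0, 1], [0, 0, 1, 0]]
        System.out.println(uniqueName("XMULTINOM.b", Set.of("XMULTINOM.b"))); // XMULTINOM.b_1
    }
}
```

Dropping one category keeps the indicator columns linearly independent of an intercept, which is why the regression-based test only needs k - 1 of them.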

Example 8 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class BayesPmEditorWizard, method copyCategories.

private void copyCategories() {
    Node node = (Node) variableChooser.getSelectedItem();
    DiscreteVariable variable = (DiscreteVariable) bayesPm.getVariable(node);
    this.copiedCategories = variable.getCategories();
}
Also used: DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable), Node (edu.cmu.tetrad.graph.Node)

Example 9 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class BayesUpdaterClassifierEditor, method doClassify.

private void doClassify() {
    DiscreteVariable variable = (DiscreteVariable) getVariableDropdown().getSelectedItem();
    String varName = variable.getName();
    String category = (String) getCategoryDropdown().getSelectedItem();
    int catIndex = variable.getIndex(category);
    getClassifier().setTarget(varName, catIndex);
    getClassifier().classify();
}
Also used: DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)
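doClassify turns the selected category label into a position in the variable's ordered category list before setting the classifier target. A minimal sketch of that lookup on a plain List<String> follows; CategoryLookup is a hypothetical helper, and because it is an assumption how DiscreteVariable.getIndex reports an unknown label, the sketch rejects null and unknown labels explicitly instead of passing on a sentinel index.

```java
import java.util.List;

class CategoryLookup {

    // Returns the index of category within the ordered categories list.
    // Rejects a null selection (e.g. nothing chosen in a dropdown) and unknown labels.
    static int indexOf(List<String> categories, String category) {
        if (category == null) {
            throw new IllegalArgumentException("No category selected");
        }
        int index = categories.indexOf(category);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown category: " + category);
        }
        return index;
    }

    public static void main(String[] args) {
        List<String> categories = List.of("low", "medium", "high");
        System.out.println(indexOf(categories, "high")); // 2
    }
}
```

The null check comes first because immutable lists created with List.of throw NullPointerException from indexOf(null).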

Example 10 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class BayesUpdaterClassifierEditor, method showConfusionMatrix.

private void showConfusionMatrix() {
    int tabIndex = -1;
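    for (int i = 0; i < getTabbedPane().getTabCount(); i++) {
        if ("Confusion Matrix".equals(getTabbedPane().getTitleAt(i))) {
            // Remove the old tab but remember its position so the refreshed matrix goes back in the same place.
            getTabbedPane().remove(i);
            tabIndex = i;
            break;
        }
    }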
    StringBuilder buf = new StringBuilder();
    int[][] crossTabs = getClassifier().crossTabulation();
    // No cross-tabulation is available; in this case, don't put the confusion matrix back in.
    if (crossTabs == null) {
        return;
    }
    DiscreteVariable targetVariable = getClassifier().getTargetVariable();
    int nvalues = targetVariable.getNumCategories();
    int ncases = getClassifier().getNumCases();
    int ntot = getClassifier().getTotalUsableCases();
    buf.append("Total number of usable cases = ");
    buf.append(ntot);
    buf.append(" out of ");
    buf.append(ncases);
    buf.append("\n\nTarget Variable ");
    buf.append(targetVariable);
    buf.append("\n\t\tEstimated\t");
    buf.append("\nObserved\t");
    for (int i = 0; i < nvalues - 1; i++) {
        buf.append(targetVariable.getCategory(i));
        buf.append("\t");
    }
    buf.append(targetVariable.getCategory(nvalues - 1));
    for (int i = 0; i < nvalues; i++) {
        buf.append("\n");
        buf.append(targetVariable.getCategory(i));
        buf.append("\t");
        for (int j = 0; j < nvalues - 1; j++) {
            buf.append(crossTabs[i][j]);
            buf.append("\t");
        }
        buf.append(crossTabs[i][nvalues - 1]);
    }
    buf.append("\n\nPercentage correctly classified:  ");
    buf.append(getClassifier().getPercentCorrect());
    JTextArea label = new JTextArea(buf.toString());
    label.setFont(new Font("SansSerif", Font.PLAIN, 14));
    JPanel panel = new JPanel();
    panel.setLayout(new BorderLayout());
    panel.setBackground(Color.WHITE);
    Box b1 = Box.createVerticalBox();
    Box b2 = Box.createHorizontalBox();
    b2.add(Box.createHorizontalStrut(5));
    b2.add(label);
    b2.add(Box.createHorizontalGlue());
    b1.add(b2);
    b1.add(Box.createVerticalGlue());
    b1.add(Box.createVerticalGlue());
    panel.add(b1, BorderLayout.CENTER);
    JScrollPane scroll = new JScrollPane(panel);
    if (tabIndex == -1) {
        getTabbedPane().add("Confusion Matrix", scroll);
    } else {
        getTabbedPane().add(scroll, tabIndex);
        getTabbedPane().setTitleAt(tabIndex, "Confusion Matrix");
    }
}
Also used: DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)
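The confusion matrix printed above has one row per observed category and one column per estimated category, so correctly classified cases lie on the diagonal. The sketch below computes that percentage and renders a simplified tab-separated table from a crossTabs array, assuming the crossTabs[observed][estimated] layout implied by the printed labels. ConfusionMatrixSummary is a hypothetical helper; Tetrad's own getPercentCorrect() may count cases differently (for example, with respect to unusable cases).

```java
class ConfusionMatrixSummary {

    // Percentage of cases on the diagonal (observed category == estimated category).
    // Returns 0.0 when the table is empty.
    static double percentCorrect(int[][] crossTabs) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < crossTabs.length; i++) {
            for (int j = 0; j < crossTabs[i].length; j++) {
                total += crossTabs[i][j];
                if (i == j) {
                    correct += crossTabs[i][j];
                }
            }
        }
        return total == 0 ? 0.0 : 100.0 * correct / total;
    }

    // Header row of estimated categories, then one row per observed category.
    // Prefixing each cell with a tab avoids a trailing tab at the end of a row.
    static String render(String[] categories, int[][] crossTabs) {
        StringBuilder buf = new StringBuilder("Observed");
        for (String category : categories) {
            buf.append('\t').append(category);
        }
        for (int i = 0; i < categories.length; i++) {
            buf.append('\n').append(categories[i]);
            for (int j = 0; j < categories.length; j++) {
                buf.append('\t').append(crossTabs[i][j]);
            }
        }
        return buf.toString();
    }

    public static void main(String[] args) {
        int[][] crossTabs = {{40, 10}, {5, 45}};
        System.out.println(render(new String[]{"no", "yes"}, crossTabs));
        System.out.println("Percentage correctly classified: " + percentCorrect(crossTabs)); // 85.0
    }
}
```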

Aggregations

DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 56
Node (edu.cmu.tetrad.graph.Node): 37
DataSet (edu.cmu.tetrad.data.DataSet): 18
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 16
ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 11
LinkedList (java.util.LinkedList): 9
Test (org.junit.Test): 5
ArrayList (java.util.ArrayList): 4
Dag (edu.cmu.tetrad.graph.Dag): 3
NumberFormat (java.text.NumberFormat): 3
Element (nu.xom.Element): 3
EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 2
Graph (edu.cmu.tetrad.graph.Graph): 2
LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression): 2
List (java.util.List): 2
Elements (nu.xom.Elements): 2
DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D): 1
TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph): 1
StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs): 1
BoxDataSet (edu.cmu.tetrad.data.BoxDataSet): 1