Search in sources:

Example 11 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class BayesUpdaterClassifierEditor, method getToolbar.

/**
 * Builds the classifier editor's toolbar. It contains a target-variable dropdown, a dropdown of the
 * selected target's categories (the category used for the ROC curve) and a "Classify" button.
 *
 * <p>Pressing "Classify" runs {@code doClassify()}, {@code showClassification()},
 * {@code showRocCurve()} and {@code showConfusionMatrix()} inside a {@code WatchedProcess}.
 * Selecting a different target repopulates the category dropdown. Selecting a different category
 * redraws the ROC curve.
 *
 * @return the toolbar component.
 */
private Component getToolbar() {
    JButton classifyButton = new JButton("Classify");
    classifyButton.addActionListener(new ActionListener() {

        public void actionPerformed(ActionEvent e) {
            Window owner = (Window) getTopLevelAncestor();
            new WatchedProcess(owner) {

                public void watch() {
                    doClassify();
                    showClassification();
                    showRocCurve();
                    showConfusionMatrix();
                }
            };
        }
    });

    List<Node> nodes = getClassifier().getBayesImVars();
    Node[] variables = nodes.toArray(new Node[0]);
    this.variableDropdown = new JComboBox(variables);
    getVariableDropdown().setBackground(Color.WHITE);
    getVariableDropdown().setMaximumSize(new Dimension(200, 50));

    // An empty variable list leaves the dropdown with no selection (null). Start with no
    // categories instead of failing with a NullPointerException.
    DiscreteVariable variable = (DiscreteVariable) getVariableDropdown().getSelectedItem();
    String[] initialCategories = variable == null
            ? new String[0]
            : variable.getCategories().toArray(new String[0]);
    this.categoryDropdown = new JComboBox(initialCategories);
    getCategoryDropdown().setBackground(Color.WHITE);
    getCategoryDropdown().setMaximumSize(new Dimension(200, 50));

    this.variableDropdown.addItemListener(new ItemListener() {

        public void itemStateChanged(ItemEvent e) {
            // Swing fires DESELECTED for the old item and SELECTED for the new one.
            // Rebuild the category list once, for the new selection only.
            if (e.getStateChange() != ItemEvent.SELECTED) {
                return;
            }
            JComboBox comboBox = (JComboBox) e.getSource();
            DiscreteVariable selected = (DiscreteVariable) comboBox.getSelectedItem();
            List<String> categories = selected.getCategories();
            DefaultComboBoxModel newModel = new DefaultComboBoxModel(categories.toArray(new String[0]));
            getCategoryDropdown().setModel(newModel);
        }
    });

    this.categoryDropdown.addItemListener(new ItemListener() {

        public void itemStateChanged(ItemEvent e) {
            // Redraw once per category change, not once for deselect and again for select.
            if (e.getStateChange() == ItemEvent.SELECTED) {
                showRocCurve();
            }
        }
    });

    Box toolbar = Box.createVerticalBox();
    Box row1 = Box.createHorizontalBox();
    row1.add(Box.createHorizontalStrut(5));
    row1.add(new JLabel("Target = "));
    row1.add(getVariableDropdown());
    row1.add(Box.createHorizontalStrut(5));
    row1.add(new JLabel("Category for ROC ="));
    row1.add(getCategoryDropdown());
    row1.add(Box.createHorizontalStrut(10));
    row1.add(classifyButton);
    row1.add(Box.createHorizontalGlue());
    toolbar.add(row1);
    toolbar.add(Box.createVerticalStrut(5));
    toolbar.setBorder(new EmptyBorder(2, 2, 2, 2));
    return toolbar;
}
Also used: ItemEvent(java.awt.event.ItemEvent) ActionEvent(java.awt.event.ActionEvent) WatchedProcess(edu.cmu.tetradapp.util.WatchedProcess) Node(edu.cmu.tetrad.graph.Node) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ActionListener(java.awt.event.ActionListener) ItemListener(java.awt.event.ItemListener) List(java.util.List) LinkedList(java.util.LinkedList) EmptyBorder(javax.swing.border.EmptyBorder)

Example 12 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class MGM, method runTests1.

/**
 * Ad-hoc timing/smoke run for {@code MGM}. It loads {@code med_test_C.txt} and
 * {@code med_test_D.txt} and fits a model over 24 continuous ({@code X0..X23}) and 24 discrete
 * ({@code Y0..Y23}, 2 levels each) variables with all three lambdas set to 0.2. It prints
 * copy/assign timings, the initial and final objective terms, the learned parameters and the
 * adjacency matrix.
 *
 * <p>The data directory defaults to the original developer-specific path. Override it with the
 * {@code mgm.testData} system property to run this anywhere else. I/O failures are printed and
 * the run is abandoned.
 */
private static void runTests1() {
    try {
        // The default is a hard-coded, machine-specific location; allow overriding it.
        String path = System.getProperty(
                "mgm.testData",
                "/Users/ajsedgewick/tetrad_master/tetrad/tetrad-lib/src/main/java/edu/pitt/csb/mgm/test_data");
        System.out.println(path);
        DoubleMatrix2D xIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_C.txt").getDoubleData().toArray());
        DoubleMatrix2D yIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_D.txt").getDoubleData().toArray());
        int[] L = new int[24];
        Node[] vars = new Node[48];
        for (int i = 0; i < 24; i++) {
            L[i] = 2;
            vars[i] = new ContinuousVariable("X" + i);
            vars[i + 24] = new DiscreteVariable("Y" + i);
        }
        double lam = .2;
        MGM model = new MGM(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), L, new double[] { lam, lam, lam });
        System.out.println("Weights: " + Arrays.toString(model.weights.toArray()));

        // Micro-benchmark: copy + assign into an existing matrix ...
        DoubleMatrix2D test = xIn.copy();
        DoubleMatrix2D test2 = xIn.copy();
        long t = System.currentTimeMillis();
        for (int i = 0; i < 50000; i++) {
            test2 = xIn.copy();
            test.assign(test2);
        }
        System.out.println("assign Time: " + (System.currentTimeMillis() - t));

        // ... versus copy + reference reassignment.
        t = System.currentTimeMillis();
        for (int i = 0; i < 50000; i++) {
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
            test2 = xIn.copy();
            test = test2;
        }
        System.out.println("equals Time: " + (System.currentTimeMillis() - t));

        System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        t = System.currentTimeMillis();
        model.learnEdges(700);
        System.out.println("Orig Time: " + (System.currentTimeMillis() - t));
        System.out.println("nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        System.out.println("params:\n" + model.params);
        System.out.println("adjMat:\n" + model.adjMatFromMGM());
    } catch (IOException ex) {
        ex.printStackTrace();
    }
}
Also used: IOException(java.io.IOException) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D)

Example 13 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class DataConvertUtils, method toMixedDataBox.

/**
 * Converts a mixed (continuous and discrete) tabular dataset into a {@code BoxDataSet} backed by
 * a {@code MixedDataBox}.
 *
 * <p>Each entry of {@code getMixedVarInfos()} becomes one variable, in the same order. Continuous
 * columns become a {@code ContinuousVariable}; all others become a {@code DiscreteVariable}
 * carrying that column's categories. The dataset's row count and its continuous and discrete data
 * arrays are handed to the {@code MixedDataBox} unchanged.
 *
 * @param mixedTabularDataset the dataset to convert
 * @return a data set over the newly created variables
 */
public static DataModel toMixedDataBox(MixedTabularDataset mixedTabularDataset) {
    int rowCount = mixedTabularDataset.getNumOfRows();
    MixedVarInfo[] columnInfos = mixedTabularDataset.getMixedVarInfos();
    double[][] continuousColumns = mixedTabularDataset.getContinuousData();
    int[][] discreteColumns = mixedTabularDataset.getDiscreteData();

    List<Node> variables = new LinkedList<>();
    for (MixedVarInfo info : columnInfos) {
        String name = info.getName();
        Node variable = info.isContinuous()
                ? new ContinuousVariable(name)
                : new DiscreteVariable(name, info.getCategories());
        variables.add(variable);
    }

    MixedDataBox dataBox = new MixedDataBox(variables, rowCount, continuousColumns, discreteColumns);
    return new BoxDataSet(dataBox, variables);
}
Also used: ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Node(edu.cmu.tetrad.graph.Node) BoxDataSet(edu.cmu.tetrad.data.BoxDataSet) MixedDataBox(edu.cmu.tetrad.data.MixedDataBox) LinkedList(java.util.LinkedList) MixedVarInfo(edu.pitt.dbmi.data.reader.tabular.MixedVarInfo)

Example 14 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class TestLogisticRegression, method test1.

/**
 * Smoke test for {@code LogisticRegression}. It builds a random DAG over X1..X5, simulates data
 * from a {@code SemIm} on it (1000 rows requested), and discretizes X1 into two equal-count
 * categories. It then regresses the discretized X1 on X2..X5 and prints the graph and the
 * regression.
 */
@Test
public void test1() {
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Graph graph = new Dag(GraphUtils.randomGraph(variables, 0, 5, 3, 3, 3, false));
    System.out.println(graph);

    SemIm im = new SemIm(new SemPm(graph));
    DataSet data = im.simulateDataRecursive(1000, false);

    // Regressor nodes are looked up in the simulated (undiscretized) data set.
    Node target = data.getVariable("X1");
    List<Node> regressors = new ArrayList<>();
    for (String name : new String[] {"X2", "X3", "X4", "X5"}) {
        regressors.add(data.getVariable(name));
    }

    Discretizer discretizer = new Discretizer(data);
    discretizer.equalCounts(target, 2);
    DataSet discretized = discretizer.discretize();

    LogisticRegression regression = new LogisticRegression(discretized);
    regression.regress((DiscreteVariable) discretized.getVariable("X1"), regressors);
    System.out.println(regression);
}
Also used: DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Dag(edu.cmu.tetrad.graph.Dag) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) LogisticRegression(edu.cmu.tetrad.regression.LogisticRegression) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 15 with DiscreteVariable

Use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

From class AdLeafTree, method getCellLeaves.

/**
 * Finds the sets of indices into the leaves of the tree for the given variables. Counts are the
 * sizes of the index sets.
 *
 * <p>Side effect: {@code A} is sorted in place by each variable's position in {@code nodesHash}.
 * The returned index sets are produced by varying the variables in that sorted order, so callers
 * read "first variable", "second variable" and so on from {@code A} after the call.
 *
 * @param A A list of discrete variables. It must be mutable and is reordered by this call.
 * @return The list of index sets of the first variable varied by the second variable, and so on,
 * to the last variable.
 */
public List<List<Integer>> getCellLeaves(List<DiscreteVariable> A) {
    Collections.sort(A, (left, right) -> Integer.compare(nodesHash.get(left), nodesHash.get(right)));

    // Lazily create the single root Vary that every traversal starts from.
    if (this.baseCase == null) {
        this.baseCase = new ArrayList<>();
        this.baseCase.add(new Vary());
    }

    List<Vary> frontier = this.baseCase;
    for (DiscreteVariable variable : A) {
        frontier = getVaries(frontier, nodesHash.get(variable));
    }

    List<List<Integer>> cells = new ArrayList<>();
    for (Vary leaf : frontier) {
        cells.addAll(leaf.getRows());
    }
    return cells;
}
Also used: DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)

Aggregations

DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 56; Node (edu.cmu.tetrad.graph.Node): 37; DataSet (edu.cmu.tetrad.data.DataSet): 18; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 16; ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 11; LinkedList (java.util.LinkedList): 9; Test (org.junit.Test): 5; ArrayList (java.util.ArrayList): 4; Dag (edu.cmu.tetrad.graph.Dag): 3; NumberFormat (java.text.NumberFormat): 3; Element (nu.xom.Element): 3; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 2; Graph (edu.cmu.tetrad.graph.Graph): 2; LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression): 2; List (java.util.List): 2; Elements (nu.xom.Elements): 2; DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D): 1; TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph): 1; StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs): 1; BoxDataSet (edu.cmu.tetrad.data.BoxDataSet): 1