Search in sources :

Example 1 with IndTestDSep

use of edu.cmu.tetrad.search.IndTestDSep in project tetrad by cmu-phil.

the class IndependenceFactsAction method actionPerformed.

// ========================PUBLIC METHODS==========================//
/**
 * Builds and shows the "Independence Facts" editor window.
 *
 * <p>The window lets the user assemble a query by picking variables (or the wildcards "?" and
 * "+") from a combo box, remove the most recent entry with the Delete button or the backspace
 * key, and press LIST to call {@code generateResults()}. The results returned by
 * {@code getResults()} are shown in a table whose header can be clicked to sort by a column.
 *
 * @param e the event that triggered this action; it is not read by this method.
 */
public void actionPerformed(ActionEvent e) {
    this.independenceTest = getIndTestProducer().getIndependenceTest();

    // Combo box entries: the "VAR" placeholder, the data variables, then the two wildcards.
    final List<String> varNames = new ArrayList<>();
    varNames.add("VAR");
    varNames.addAll(getDataVars());
    varNames.add("?");
    varNames.add("+");
    final JComboBox variableBox = new JComboBox();
    DefaultComboBoxModel aModel1 = new DefaultComboBoxModel(varNames.toArray(new String[varNames.size()]));
    aModel1.setSelectedItem("VAR");
    variableBox.setModel(aModel1);
    variableBox.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            JComboBox box = (JComboBox) e.getSource();
            String var = (String) box.getSelectedItem();
            LinkedList<String> vars = getVars();
            int size = vars.size();
            // "VAR" is only the placeholder entry; selecting it changes nothing.
            if ("VAR".equals(var)) {
                return;
            }
            // If wildcard(i) holds for any index from 2 up to (but excluding) the last entry,
            // only another wildcard may be appended.
            for (int i = 2; i < getVars().size() - 1; i++) {
                if (wildcard(i)) {
                    if (!("?".equals(var) || "+".equals(var))) {
                        JOptionPane.showMessageDialog(centeringComp, "Please specify wildcards after other variables (e.g. X _||_ ? | Y, +)");
                        return;
                    }
                }
            }
            if ("?".equals(var)) {
                // "?" may be appended at any length, but never once "+" is present.
                if (!vars.contains("+")) {
                    vars.addLast(var);
                }
            } else if ("+".equals(var)) {
                // "+" needs at least two entries before it.
                if (size >= 2) {
                    vars.addLast(var);
                }
            } else if ((vars.indexOf("?") < 2) && !(vars.contains("+")) && !(vars.contains(var))) {
                // A named variable is appended only if it is not already listed, no "+" is
                // present, and any "?" sits at index 0 or 1 (or is absent).
                vars.add(var);
            }
            if (wildcard(0) && vars.size() >= 2 && !wildcard(1)) {
                JOptionPane.showMessageDialog(centeringComp, "Please specify wildcards after other variables (e.g. X _||_ ? | Y, +)");
                return;
            }
            resetText();
            // This is a workaround to an introduced bug in the JDK whereby
            // repeated selections of the same item send out just one
            // action event.
            DefaultComboBoxModel aModel = new DefaultComboBoxModel(varNames.toArray(new String[varNames.size()]));
            aModel.setSelectedItem("VAR");
            variableBox.setModel(aModel);
        }
    });
    final JButton delete = new JButton("Delete");
    delete.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            if (!getVars().isEmpty()) {
                getVars().removeLast();
                resetText();
            }
        }
    });
    textField.addKeyListener(new KeyAdapter() {

        @Override
        public void keyTyped(KeyEvent e) {
            if ('?' == e.getKeyChar()) {
                variableBox.setSelectedItem("?");
            } else if ('+' == e.getKeyChar()) {
                variableBox.setSelectedItem("+");
            } else if ('\b' == e.getKeyChar()) {
                // removeLast() throws NoSuchElementException on an empty list, so guard it
                // the same way the Delete button does.
                if (!vars.isEmpty()) {
                    vars.removeLast();
                    resetText();
                }
            }
            // Every typed character is consumed, including ones not handled above.
            e.consume();
        }
    });
    delete.addKeyListener(new KeyAdapter() {

        @Override
        public void keyTyped(KeyEvent e) {
            if ('?' == e.getKeyChar()) {
                variableBox.setSelectedItem("?");
            } else if ('+' == e.getKeyChar()) {
                variableBox.setSelectedItem("+");
            } else if ('\b' == e.getKeyChar()) {
                // Backspace with nothing to delete is ignored rather than throwing.
                if (!vars.isEmpty()) {
                    vars.removeLast();
                    resetText();
                }
            }
        }
    });
    variableBox.addKeyListener(new KeyAdapter() {

        @Override
        public void keyTyped(KeyEvent e) {
            super.keyTyped(e);
            if ('\b' == e.getKeyChar()) {
                // Backspace with nothing to delete is ignored rather than throwing.
                if (!vars.isEmpty()) {
                    vars.removeLast();
                    resetText();
                }
            }
        }
    });
    JButton list = new JButton("LIST");
    list.setFont(new Font("Dialog", Font.BOLD, 14));
    list.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            generateResults();
        }
    });

    // Top row: the name of the independence test in use.
    Box b1 = Box.createVerticalBox();
    Box b2 = Box.createHorizontalBox();
    b2.add(new JLabel("Test: "));
    b2.add(new JLabel(getIndependenceTest().toString()));
    b2.add(Box.createHorizontalGlue());
    b1.add(b2);
    b1.add(Box.createVerticalStrut(10));

    // Second row: the query text field, the variable picker and the Delete button.
    Box b3 = Box.createHorizontalBox();
    b3.add(getTextField());
    b3.add(variableBox);
    b3.add(delete);
    b1.add(b3);
    b1.add(Box.createVerticalStrut(10));

    // Table model over getResults(): index, fact, judgment and (unless usesDSeparation()
    // is true) a fourth "P Value" column.
    tableModel = new AbstractTableModel() {

        @Override
        public String getColumnName(int column) {
            if (column == 0) {
                return "Index";
            }
            if (column == 1) {
                if (independenceTest instanceof IndTestDSep) {
                    return "D-Separation Relation";
                } else {
                    return "Independence Relation";
                }
            } else if (column == 2) {
                return "Judgment";
            } else if (column == 3) {
                return "P Value";
            }
            return null;
        }

        @Override
        public int getColumnCount() {
            if (usesDSeparation()) {
                return 3;
            } else {
                return 4;
            }
        }

        @Override
        public int getRowCount() {
            return getResults().size();
        }

        @Override
        public Object getValueAt(int rowIndex, int columnIndex) {
            Result result = getResults().get(rowIndex);
            if (columnIndex == 0) {
                // Displayed index is one more than the stored index.
                return result.getIndex() + 1;
            }
            if (columnIndex == 1) {
                return result.getFact();
            } else if (columnIndex == 2) {
                if (independenceTest instanceof IndTestDSep) {
                    if (result.getType() == Result.Type.INDEPENDENT) {
                        return "D-Separated";
                    } else if (result.getType() == Result.Type.DEPENDENT) {
                        return "D-Connected";
                    } else if (result.getType() == Result.Type.UNDETERMINED) {
                        return "*";
                    }
                } else {
                    if (result.getType() == Result.Type.INDEPENDENT) {
                        return "Independent";
                    } else if (result.getType() == Result.Type.DEPENDENT) {
                        return "Dependent";
                    } else if (result.getType() == Result.Type.UNDETERMINED) {
                        return "*";
                    }
                }
            } else if (columnIndex == 3) {
                return nf.format(result.getpValue());
            }
            // Reached for an unknown column or a result type not listed above.
            return null;
        }

        @Override
        public Class<?> getColumnClass(int columnIndex) {
            // NOTE(review): column 2 is declared Number although getValueAt returns text
            // for it; possibly intentional to pick up the numeric renderer - confirm.
            if (columnIndex == 0) {
                return Number.class;
            }
            if (columnIndex == 1) {
                return String.class;
            } else if (columnIndex == 2) {
                return Number.class;
            } else if (columnIndex == 3) {
                return Number.class;
            }
            return null;
        }
    };
    JTable table = new JTable(tableModel);
    table.getColumnModel().getColumn(0).setMinWidth(40);
    table.getColumnModel().getColumn(0).setMaxWidth(40);
    table.getColumnModel().getColumn(1).setMinWidth(200);
    table.getColumnModel().getColumn(2).setMinWidth(100);
    table.getColumnModel().getColumn(2).setMaxWidth(100);
    // The fourth column only exists when the model reports four columns.
    if (!(usesDSeparation())) {
        table.getColumnModel().getColumn(3).setMinWidth(80);
        table.getColumnModel().getColumn(3).setMaxWidth(80);
    }

    // Clicking a header cell sorts by the corresponding model column.
    JTableHeader header = table.getTableHeader();
    header.addMouseListener(new MouseAdapter() {

        @Override
        public void mouseClicked(MouseEvent e) {
            JTableHeader header = (JTableHeader) e.getSource();
            Point point = e.getPoint();
            int col = header.columnAtPoint(point);
            int sortCol = header.getTable().convertColumnIndexToModel(col);
            sortByColumn(sortCol, true);
        }
    });
    JScrollPane scroll = new JScrollPane(table);
    scroll.setPreferredSize(new Dimension(400, 400));
    b1.add(scroll);

    // Bottom row: the list-limit field and the LIST button.
    Box b4 = Box.createHorizontalBox();
    b4.add(new JLabel("Limit list to "));
    IntTextField field = new IntTextField(getListLimit(), 7);
    field.setFilter(new IntTextField.Filter() {

        public int filter(int value, int oldValue) {
            // Keep the previous value if setListLimit rejects the new one.
            try {
                setListLimit(value);
                return value;
            } catch (Exception e) {
                return oldValue;
            }
        }
    });
    b4.add(field);
    b4.add(new JLabel(" items."));
    b4.add(Box.createHorizontalGlue());
    b4.add(list);
    b1.add(b4);
    b1.add(Box.createVerticalStrut(10));
    JPanel panel = new JPanel();
    panel.setLayout(new BorderLayout());
    panel.add(b1, BorderLayout.CENTER);
    panel.setBorder(new EmptyBorder(10, 10, 10, 10));
    EditorWindow editorWindow = new EditorWindow(panel, "Independence Facts", "Save", false, centeringComp);
    DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER);
    editorWindow.setVisible(true);
    // Set the ok button so that pressing enter activates it.
    // jdramsey 5/5/02
    JRootPane root = SwingUtilities.getRootPane(editorWindow);
    if (root != null) {
        root.setDefaultButton(list);
    }
}
Also used : JTableHeader(javax.swing.table.JTableHeader) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) IntTextField(edu.cmu.tetradapp.util.IntTextField) EmptyBorder(javax.swing.border.EmptyBorder) AbstractTableModel(javax.swing.table.AbstractTableModel)

Example 2 with IndTestDSep

use of edu.cmu.tetrad.search.IndTestDSep in project tetrad by cmu-phil.

the class TestIndTestWaldLR method testIsIndependent.

/**
 * Over ten seeded trials, compares the verdict of the Wald multinomial logistic regression test
 * on simulated, partly discretized data with the d-separation verdict on the generating graph,
 * for X2 versus X1 given {X3, X4, X5}. All ten trials must agree.
 */
@Test
public void testIsIndependent() {
    RandomUtil.getInstance().setSeed(1450705713157L);
    int agreements = 0;
    for (int trial = 0; trial < 10; trial++) {
        // Ten continuous variables named X1..X10.
        List<Node> variables = new ArrayList<>();
        for (int k = 1; k <= 10; k++) {
            variables.add(new ContinuousVariable("X" + k));
        }

        // Random graph -> SEM -> 1000 simulated rows.
        Graph trueGraph = GraphUtils.randomGraph(variables, 0, 10, 3, 3, 3, false);
        SemPm semPm = new SemPm(trueGraph);
        SemIm semIm = new SemIm(semPm);
        DataSet simulated = semIm.simulateData(1000, false);

        // Discretize the variables at columns 0 and 3 into two equal-count bins each.
        Discretizer binner = new Discretizer(simulated);
        binner.setVariablesCopied(true);
        binner.equalCounts(simulated.getVariable(0), 2);
        binner.equalCounts(simulated.getVariable(3), 2);
        DataSet discretized = binner.discretize();

        Node dataX1 = discretized.getVariable("X1");
        Node dataX2 = discretized.getVariable("X2");
        Node dataX3 = discretized.getVariable("X3");
        Node dataX4 = discretized.getVariable("X4");
        Node dataX5 = discretized.getVariable("X5");
        List<Node> dataCond = new ArrayList<>();
        dataCond.add(dataX3);
        dataCond.add(dataX4);
        dataCond.add(dataX5);

        // Look up the same variables, by name, in the generating graph.
        Node graphX1 = trueGraph.getNode(dataX1.getName());
        Node graphX2 = trueGraph.getNode(dataX2.getName());
        List<Node> graphCond = new ArrayList<>();
        for (Node conditioned : dataCond) {
            graphCond.add(trueGraph.getNode(conditioned.getName()));
        }

        // Using the Wald LR test since it's most up to date.
        IndependenceTest waldTest = new IndTestMultinomialLogisticRegressionWald(discretized, 0.05, false);
        IndTestDSep dsepTest = new IndTestDSep(trueGraph);
        if (waldTest.isIndependent(dataX2, dataX1, dataCond) == dsepTest.isIndependent(graphX2, graphX1, graphCond)) {
            agreements++;
        }
    }
    // Original note: do not always get all 10 (presumably why the seed above is fixed).
    assertEquals(10, agreements);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) IndTestMultinomialLogisticRegressionWald(edu.pitt.csb.mgm.IndTestMultinomialLogisticRegressionWald) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest)

Example 3 with IndTestDSep

use of edu.cmu.tetrad.search.IndTestDSep in project tetrad by cmu-phil.

the class TestMbfs method testGenerateDaglist.

/**
 * Checks MB DAG generation on one tricky case: the Mbfs search for target "T" is run with
 * d-separation facts from a fixed graph, MB DAGs are generated from the search result, and the
 * test expects exactly nine of them, one of which equals the original graph.
 */
@Test
public void testGenerateDaglist() {
    Graph trueDag = GraphConverter.convert("T-->X1,T-->X2,X1-->X2,T-->X3,X4-->T");
    // Depth -1 is passed straight through to Mbfs.
    Mbfs mbfs = new Mbfs(new IndTestDSep(trueDag), -1);
    Graph searchResult = mbfs.search("T");
    List<?> generated = MbUtils.generateMbDags(
            searchResult, true, mbfs.getTest(), mbfs.getDepth(), mbfs.getTarget());
    assertTrue(generated.size() == 9);
    assertTrue(generated.contains(trueDag));
}
Also used : IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) Mbfs(edu.cmu.tetrad.search.Mbfs) ArrayList(java.util.ArrayList) List(java.util.List) Test(org.junit.Test) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest)

Example 4 with IndTestDSep

use of edu.cmu.tetrad.search.IndTestDSep in project tetrad by cmu-phil.

the class TestMbfs method testRandom.

/**
 * For every node of a seeded random DAG, runs the Mbfs search with d-separation facts and
 * compares the result with {@code GraphUtils.markovBlanketDag}: the node names must match, and
 * each directed edge of the result must either match the true edge's endpoints or fall under one
 * of the tolerated cases checked below.
 */
@Test
public void testRandom() {
    RandomUtil.getInstance().setSeed(8388428832L);
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 10; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }
    Dag trueDag = new Dag(GraphUtils.randomGraph(variables, 0, 10, 5, 5, 5, false));
    IndependenceTest dsep = new IndTestDSep(trueDag);
    Mbfs mbfs = new Mbfs(dsep, -1);

    for (Node target : trueDag.getNodes()) {
        Graph foundMb = mbfs.search(target.getName());
        Graph expectedMb = GraphUtils.markovBlanketDag(target, trueDag);

        // Both graphs must contain nodes with exactly the same names.
        Set<String> foundNames = new HashSet<>();
        for (Node found : foundMb.getNodes()) {
            foundNames.add(found.getName());
        }
        Set<String> expectedNames = new HashSet<>();
        for (Node expected : expectedMb.getNodes()) {
            expectedNames.add(expected.getName());
        }
        assertTrue(foundNames.equals(expectedNames));

        for (Edge foundEdge : foundMb.getEdges()) {
            // Only directed edges of the result are checked.
            if (!Edges.isDirectedEdge(foundEdge)) {
                continue;
            }
            String firstName = foundEdge.getNode1().getName();
            String secondName = foundEdge.getNode2().getName();
            Node expectedFirst = expectedMb.getNode(firstName);
            Node expectedSecond = expectedMb.getNode(secondName);
            if (expectedFirst == null) {
                fail("Node " + firstName + " is not in the true graph.");
            }
            if (expectedSecond == null) {
                fail("Node " + secondName + " is not in the true graph.");
            }

            Edge expectedEdge = expectedMb.getEdge(expectedFirst, expectedSecond);
            if (expectedEdge == null) {
                // The edge is missing from the true MB. It is tolerated when either endpoint
                // has no edge to the target in the result, or when one of those two edges is
                // directed and the other undirected (the original note mentions the
                // possibility that the node is actually a child).
                Node target1 = foundMb.getNode(expectedFirst.getName());
                Node target2 = foundMb.getNode(expectedSecond.getName());
                Node foundTarget = foundMb.getNode(target.getName());
                Edge toTarget1 = foundMb.getEdge(target1, foundTarget);
                Edge toTarget2 = foundMb.getEdge(target2, foundTarget);
                if (toTarget1 == null
                        || toTarget2 == null
                        || (Edges.isDirectedEdge(toTarget1) && Edges.isUndirectedEdge(toTarget2))
                        || (Edges.isUndirectedEdge(toTarget1) && Edges.isDirectedEdge(toTarget2))) {
                    continue;
                }
                fail("EXTRA EDGE: Edge in result MB but not true MB = " + foundEdge);
            }
            assertEquals(foundEdge.getEndpoint1(), expectedEdge.getEndpoint1());
            assertEquals(foundEdge.getEndpoint2(), expectedEdge.getEndpoint2());
        }
    }
}
Also used : Mbfs(edu.cmu.tetrad.search.Mbfs) ArrayList(java.util.ArrayList) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) HashSet(java.util.HashSet) Test(org.junit.Test) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest)

Example 5 with IndTestDSep

use of edu.cmu.tetrad.search.IndTestDSep in project tetrad by cmu-phil.

the class TestPcd method checkSearch.

/**
 * Runs the Pcd search on the input graph, using an {@code IndTestDSep} built from that graph as
 * the independence test, and asserts that the result equals the expected output graph.
 *
 * @param inputGraph  graph specification passed to {@code GraphConverter.convert}; the search
 *                    runs against it.
 * @param outputGraph graph specification of the expected search result.
 */
private void checkSearch(String inputGraph, String outputGraph) {
    // Independence facts come from d-separation on the input graph.
    Graph source = GraphConverter.convert(inputGraph);
    IndependenceTest dsep = new IndTestDSep(source);
    Pcd search = new Pcd(dsep);
    Graph found = search.search();

    // Swap in the expected graph's nodes before comparing the two graphs.
    Graph expected = GraphConverter.convert(outputGraph);
    Graph aligned = GraphUtils.replaceNodes(found, expected.getNodes());
    assertTrue(aligned.equals(expected));
}
Also used : Pcd(edu.cmu.tetrad.search.Pcd) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) Graph(edu.cmu.tetrad.graph.Graph)

Aggregations

IndTestDSep (edu.cmu.tetrad.search.IndTestDSep)7 IndependenceTest (edu.cmu.tetrad.search.IndependenceTest)5 Graph (edu.cmu.tetrad.graph.Graph)3 ArrayList (java.util.ArrayList)3 Test (org.junit.Test)3 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)2 Mbfs (edu.cmu.tetrad.search.Mbfs)2 IntTextField (edu.cmu.tetradapp.util.IntTextField)2 EmptyBorder (javax.swing.border.EmptyBorder)2 AbstractTableModel (javax.swing.table.AbstractTableModel)2 JTableHeader (javax.swing.table.JTableHeader)2 DataSet (edu.cmu.tetrad.data.DataSet)1 Discretizer (edu.cmu.tetrad.data.Discretizer)1 Node (edu.cmu.tetrad.graph.Node)1 Pc (edu.cmu.tetrad.search.Pc)1 Pcd (edu.cmu.tetrad.search.Pcd)1 SemIm (edu.cmu.tetrad.sem.SemIm)1 SemPm (edu.cmu.tetrad.sem.SemPm)1 IndependenceResult (edu.cmu.tetradapp.model.IndependenceResult)1 IndTestMultinomialLogisticRegressionWald (edu.pitt.csb.mgm.IndTestMultinomialLogisticRegressionWald)1