Search in sources :

Example 46 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TestDataWrapper method testDataModelList.

/**
 * Checks that a {@code DataModelList} holding two data sets reports both of
 * them as members, and that the selected model is still the one named
 * "second" after the list has been copied through a {@code MarshalledObject}.
 */
@Test
public void testDataModelList() {
    DataModelList modelList = new DataModelList();
    List<Node> variables1 = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        variables1.add(new ContinuousVariable("X" + i));
    }
    List<Node> variables2 = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        variables2.add(new ContinuousVariable("X" + i));
    }
    DataSet first = new ColtDataSet(10, variables1);
    first.setName("first");
    DataSet second = new ColtDataSet(10, variables2);
    second.setName("second");
    modelList.add(first);
    modelList.add(second);
    assertTrue(modelList.contains(first));
    assertTrue(modelList.contains(second));
    modelList.setSelectedModel(second);
    try {
        DataModelList modelList2 = new MarshalledObject<>(modelList).get();
        assertEquals("second", modelList2.getSelectedModel().getName());
    } catch (Exception e) {
        // A failed marshal/unmarshal must fail the test. Merely printing the
        // stack trace would let the test pass without the assertion above
        // ever having been evaluated.
        throw new AssertionError("Could not copy the DataModelList through MarshalledObject", e);
    }
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataModelList(edu.cmu.tetrad.data.DataModelList) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Example 47 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class MNLRLikelihood method getLik.

/**
 * Sums a per-cell regression score for one child variable given a set of
 * parent variables.
 *
 * <p>The parents are split into continuous and discrete ones. The discrete
 * parents are handed to {@code adTree.getCellLeaves}, which returns a list of
 * cells; each cell is a list of integers used here as the second index into
 * {@code continuousData} / {@code discreteData}. Cells with fewer than two
 * entries contribute nothing. For every other cell a design matrix is built
 * from the continuous parents: each parent column is centred on its cell mean
 * and divided by its cell spread, then raised to the powers
 * {@code 1..degree}; the last column is a constant 1. The child is then
 * scored with {@code multipleRegression} when it is a
 * {@code ContinuousVariable} and with {@code MultinomialLogisticRegression}
 * otherwise.
 *
 * @param child_index index of the child in {@code variables}; also used
 *                    directly as the first index into the data arrays
 * @param parents     indices into {@code variables} of the parent variables
 * @return the sum of the per-cell scores, or 0 if no cell has more than one entry
 */
public double getLik(int child_index, int[] parents) {
    double lik = 0;
    Node c = variables.get(child_index);
    List<ContinuousVariable> continuous_parents = new ArrayList<>();
    List<DiscreteVariable> discrete_parents = new ArrayList<>();
    for (int p : parents) {
        Node parent = variables.get(p);
        if (parent instanceof ContinuousVariable) {
            continuous_parents.add((ContinuousVariable) parent);
        } else {
            discrete_parents.add((DiscreteVariable) parent);
        }
    }
    int p = continuous_parents.size();
    List<List<Integer>> cells = adTree.getCellLeaves(discrete_parents);
    // Column of each continuous parent in continuousData, looked up by node.
    int[] continuousCols = new int[p];
    for (int j = 0; j < p; j++) {
        continuousCols[j] = nodesHash.get(continuous_parents.get(j));
    }
    for (List<Integer> cell : cells) {
        int r = cell.size();
        if (r > 1) {
            // Per-parent mean and spread over the entries of this cell. The
            // spread is sqrt(E[x^2] - mean^2), i.e. it is stored in "var" but
            // is a standard deviation once the square root has been taken.
            double[] mean = new double[p];
            double[] var = new double[p];
            for (int i = 0; i < p; i++) {
                for (int j = 0; j < r; j++) {
                    mean[i] += continuousData[continuousCols[i]][cell.get(j)];
                    var[i] += Math.pow(continuousData[continuousCols[i]][cell.get(j)], 2);
                }
                mean[i] /= r;
                var[i] /= r;
                var[i] -= Math.pow(mean[i], 2);
                var[i] = Math.sqrt(var[i]);
                // NOTE(review): a NaN here (square root of a slightly negative
                // difference) is only reported, and a spread of exactly 0 is
                // not checked at all; either one flows into the division
                // below. Confirm whether the regression helpers tolerate that.
                if (Double.isNaN(var[i])) {
                    System.out.println(var[i]);
                }
            }
            // A configured degree below 1 means "derive it from the cell size":
            // floor(ln r), which is 0 for r == 2 (intercept-only design matrix).
            int degree = fDegree;
            if (fDegree < 1) {
                degree = (int) Math.floor(Math.log(r));
            }
            // Layout: column p * d + j holds parent j raised to the power d + 1;
            // the final column (index p * degree) is the intercept.
            TetradMatrix subset = new TetradMatrix(r, p * degree + 1);
            for (int i = 0; i < r; i++) {
                subset.set(i, p * degree, 1);
                for (int j = 0; j < p; j++) {
                    for (int d = 0; d < degree; d++) {
                        subset.set(i, p * d + j, Math.pow((continuousData[continuousCols[j]][cell.get(i)] - mean[j]) / var[j], d + 1));
                    }
                }
            }
            if (c instanceof ContinuousVariable) {
                TetradVector target = new TetradVector(r);
                for (int i = 0; i < r; i++) {
                    target.set(i, continuousData[child_index][cell.get(i)]);
                }
                lik += multipleRegression(target, subset);
            } else {
                // One column per category: -1 everywhere except a 1 in the
                // column of the category observed for that entry.
                int numCategories = ((DiscreteVariable) c).getNumCategories();
                TetradMatrix target = new TetradMatrix(r, numCategories);
                for (int i = 0; i < r; i++) {
                    for (int j = 0; j < numCategories; j++) {
                        target.set(i, j, -1);
                    }
                    target.set(i, discreteData[child_index][cell.get(i)], 1);
                }
                lik += MultinomialLogisticRegression(target, subset);
            }
        }
    }
    return lik;
}
Also used : Node(edu.cmu.tetrad.graph.Node) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) TetradVector(edu.cmu.tetrad.util.TetradVector) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)

Example 48 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TestHistogram method test1.

/**
 * Assembles a ten-row data set over the two continuous variables "X1" and
 * "X2" from hard-coded 0/1 values. Nothing is asserted; the histogram
 * construction this was meant to feed is still commented out.
 */
public void test1() {
    Node firstVariable = new ContinuousVariable("X1");
    Node secondVariable = new ContinuousVariable("X2");
    List<Node> nodes = new LinkedList<>();
    nodes.add(firstVariable);
    nodes.add(secondVariable);

    // One inner array per column, top row first.
    double[][] columnValues = {
        {0, 0, 0, 0, 0, 1, 1, 1, 1, 1},
        {0, 1, 1, 1, 1, 0, 0, 0, 0, 1}
    };
    TetradMatrix dataMatrix = new TetradMatrix(10, 2);
    for (int column = 0; column < columnValues.length; column++) {
        double[] values = columnValues[column];
        for (int row = 0; row < values.length; row++) {
            dataMatrix.set(row, column, values[row]);
        }
    }

    DataSet dataSet = ColtDataSet.makeContinuousData(nodes, dataMatrix);
// Histogram histogram = new Histogram(dataSet, );
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DataSet(edu.cmu.tetrad.data.DataSet) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) Node(edu.cmu.tetrad.graph.Node) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) LinkedList(java.util.LinkedList)

Example 49 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TabularDataTable method addColumnsOutTo.

/**
 * Appends empty-named continuous variables to the data set, one for every
 * index from the current column count plus the number of leading columns up
 * to and including {@code col}, and then fires a "modelChanged" property
 * change. Each addition is echoed to standard output.
 *
 * @param col the last column to create, expressed as a JTable column index
 */
private void addColumnsOutTo(int col) {
    int nextIndex = dataSet.getNumColumns() + getNumLeadingCols();
    while (nextIndex <= col) {
        ContinuousVariable added = new ContinuousVariable("");
        dataSet.addVariable(added);
        System.out.println("Adding " + added + " col " + dataSet.getColumn(added));
        nextIndex++;
    }
    pcs.firePropertyChange("modelChanged", null, null);
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable)

Example 50 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TimeLagGraphEditor method createGraphMenu.

/**
 * Builds the "Graph" menu of this editor.
 *
 * <p>In order, the menu holds: the graph-properties and paths actions, the
 * "Correlate/Uncorrelate Exogenous Variables" items, a "Random Graph" item, a
 * "Random Multiple Indicator Model" item, and the select-bidirected and
 * select-undirected actions. The random-graph and random-MIM items each open
 * a parameter dialog and replace the workbench graph only when the dialog is
 * confirmed with OK.
 *
 * @return the fully populated menu
 */
private JMenu createGraphMenu() {
    JMenu graph = new JMenu("Graph");
    graph.add(new GraphPropertiesAction(getWorkbench()));
    graph.add(new PathsAction(getWorkbench()));
    graph.addSeparator();
    JMenuItem correlateExogenous = new JMenuItem("Correlate Exogenous Variables");
    JMenuItem uncorrelateExogenous = new JMenuItem("Uncorrelate Exogenous Variables");
    graph.add(correlateExogenous);
    graph.add(uncorrelateExogenous);
    graph.addSeparator();
    correlateExogenous.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            correlateExogenousVariables();
            getWorkbench().invalidate();
            getWorkbench().repaint();
        }
    });
    uncorrelateExogenous.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            uncorrelationExogenousVariables();
            getWorkbench().invalidate();
            getWorkbench().repaint();
        }
    });
    JMenuItem randomGraph = new JMenuItem("Random Graph");
    graph.add(randomGraph);
    randomGraph.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            RandomGraphEditor editor = new RandomGraphEditor(workbench.getGraph(), true, parameters);
            // The four-argument overload takes an option type, not a message
            // type, as its last argument; spell out both so the dialog offers
            // OK and Cancel, like the random-MIM dialog further down.
            int ret = JOptionPane.showConfirmDialog(TimeLagGraphEditor.this, editor, "Edit Random DAG Parameters", JOptionPane.OK_CANCEL_OPTION, JOptionPane.PLAIN_MESSAGE);
            if (ret != JOptionPane.OK_OPTION) {
                return;
            }
            Graph generated = null;
            Graph dag = new Dag();
            int numTrials = 0;
            while (generated == null && ++numTrials < 100) {
                if (editor.isRandomForward()) {
                    dag = GraphUtils.randomGraphRandomForwardEdges(getGraph().getNodes(), editor.getNumLatents(), editor.getMaxEdges(), 30, 15, 15, false, true);
                    GraphUtils.arrangeBySourceGraph(dag, getWorkbench().getGraph());
                    HashMap<String, PointXy> layout = GraphUtils.grabLayout(workbench.getGraph().getNodes());
                    GraphUtils.arrangeByLayout(dag, layout);
                } else if (editor.isUniformlySelected()) {
                    if (getGraph().getNumNodes() == editor.getNumNodes()) {
                        // Same node count: reuse the current nodes and keep their layout.
                        HashMap<String, PointXy> layout = GraphUtils.grabLayout(workbench.getGraph().getNodes());
                        dag = GraphUtils.randomGraph(getGraph().getNodes(), editor.getNumLatents(), editor.getMaxEdges(), editor.getMaxDegree(), editor.getMaxIndegree(), editor.getMaxOutdegree(), editor.isConnected());
                        GraphUtils.arrangeBySourceGraph(dag, getWorkbench().getGraph());
                        GraphUtils.arrangeByLayout(dag, layout);
                    } else {
                        // Different node count: start from fresh variables X1..Xn.
                        List<Node> nodes = new ArrayList<>();
                        for (int i = 0; i < editor.getNumNodes(); i++) {
                            nodes.add(new ContinuousVariable("X" + (i + 1)));
                        }
                        dag = GraphUtils.randomGraph(nodes, editor.getNumLatents(), editor.getMaxEdges(), editor.getMaxDegree(), editor.getMaxIndegree(), editor.getMaxOutdegree(), editor.isConnected());
                    }
                } else {
                    // Keep drawing until the graph has at least the requested number of edges.
                    do {
                        if (getGraph().getNumNodes() == editor.getNumNodes()) {
                            HashMap<String, PointXy> layout = GraphUtils.grabLayout(workbench.getGraph().getNodes());
                            dag = GraphUtils.randomDag(getGraph().getNodes(), editor.getNumLatents(), editor.getMaxEdges(), 30, 15, 15, editor.isConnected());
                            GraphUtils.arrangeByLayout(dag, layout);
                        } else {
                            List<Node> nodes = new ArrayList<>();
                            for (int i = 0; i < editor.getNumNodes(); i++) {
                                nodes.add(new ContinuousVariable("X" + (i + 1)));
                            }
                            dag = GraphUtils.randomGraph(nodes, editor.getNumLatents(), editor.getMaxEdges(), 30, 15, 15, editor.isConnected());
                        }
                    } while (dag.getNumEdges() < editor.getMaxEdges());
                }
                if (editor.isAddCycles()) {
                    // The cyclic graph is generated from the editor settings alone;
                    // the DAG built above is not used on this path.
                    generated = GraphUtils.cyclicGraph2(editor.getNumNodes(), editor.getMaxEdges(), 8);
                    GraphUtils.addTwoCycles(generated, editor.getMinNumCycles());
                } else {
                    generated = new EdgeListGraph(dag);
                }
            }
            if (generated == null) {
                JOptionPane.showMessageDialog(TimeLagGraphEditor.this, "Could not find a graph that fits those constraints.");
                getWorkbench().setGraph(new EdgeListGraph(dag));
            } else {
                getWorkbench().setGraph(generated);
            }
        }
    });
    JMenuItem randomIndicatorModel = new JMenuItem("Random Multiple Indicator Model");
    graph.add(randomIndicatorModel);
    randomIndicatorModel.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            RandomMimParamsEditor editor = new RandomMimParamsEditor(parameters);
            int ret = JOptionPane.showConfirmDialog(JOptionUtils.centeringComp(), editor, "Edit Random MIM Parameters", JOptionPane.OK_CANCEL_OPTION, JOptionPane.PLAIN_MESSAGE);
            if (ret == JOptionPane.OK_OPTION) {
                // The model parameters are read back from the user preferences
                // rather than from the editor instance.
                int numFactors = Preferences.userRoot().getInt("randomMimNumFactors", 1);
                int numStructuralNodes = Preferences.userRoot().getInt("numStructuralNodes", 3);
                int maxStructuralEdges = Preferences.userRoot().getInt("numStructuralEdges", 3);
                int measurementModelDegree = Preferences.userRoot().getInt("measurementModelDegree", 3);
                int numLatentMeasuredImpureParents = Preferences.userRoot().getInt("latentMeasuredImpureParents", 0);
                int numMeasuredMeasuredImpureParents = Preferences.userRoot().getInt("measuredMeasuredImpureParents", 0);
                int numMeasuredMeasuredImpureAssociations = Preferences.userRoot().getInt("measuredMeasuredImpureAssociations", 0);
                Graph mim;
                if (numFactors == 1) {
                    mim = DataGraphUtils.randomSingleFactorModel(numStructuralNodes, maxStructuralEdges, measurementModelDegree, numLatentMeasuredImpureParents, numMeasuredMeasuredImpureParents, numMeasuredMeasuredImpureAssociations);
                } else if (numFactors == 2) {
                    mim = DataGraphUtils.randomBifactorModel(numStructuralNodes, maxStructuralEdges, measurementModelDegree, numLatentMeasuredImpureParents, numMeasuredMeasuredImpureParents, numMeasuredMeasuredImpureAssociations);
                } else {
                    throw new IllegalArgumentException("Can only make random MIMs for 1 or 2 factors, " + "sorry dude.");
                }
                getWorkbench().setGraph(mim);
            }
        }
    });
    graph.addSeparator();
    graph.add(new JMenuItem(new SelectBidirectedAction(getWorkbench())));
    graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench())));
    return graph;
}
Also used : ActionEvent(java.awt.event.ActionEvent) PointXy(edu.cmu.tetrad.util.PointXy) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) ActionListener(java.awt.event.ActionListener)

Aggregations

ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 91; DataSet (edu.cmu.tetrad.data.DataSet): 48; Node (edu.cmu.tetrad.graph.Node): 46; Test (org.junit.Test): 42; ArrayList (java.util.ArrayList): 28; Graph (edu.cmu.tetrad.graph.Graph): 22; ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 19; SemPm (edu.cmu.tetrad.sem.SemPm): 18; SemIm (edu.cmu.tetrad.sem.SemIm): 16; DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 15; LinkedList (java.util.LinkedList): 13; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 12; TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 8; DMSearch (edu.cmu.tetrad.search.DMSearch): 7; Dag (edu.cmu.tetrad.graph.Dag): 6; CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 5; RandomUtil (edu.cmu.tetrad.util.RandomUtil): 5; ParseException (java.text.ParseException): 4; CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 3; Knowledge2 (edu.cmu.tetrad.data.Knowledge2): 3