Search in sources :

Example 76 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TestSearchGraph method rtestDSeparation4.

/**
 * Timing harness that runs RFCI twice over the same random DAG and prints, for each run,
 * the number of independence checks performed, the elapsed wall-clock milliseconds and
 * the checks-per-millisecond rate.
 *
 * <p>The first run uses a d-separation test built directly from the graph; the second
 * uses a Fisher Z test on 1000 rows simulated from a SEM parameterized on that graph.
 * Each run has its own {@code Fas} instance, and the statistics printed for a run are
 * read from the instance that was passed to that run's search.
 */
public void rtestDSeparation4() {
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 20, 100, 5, 5, 5, false));
    long start, stop;
    // Passed to setDepth on both searches; presumably -1 means "unlimited" -- TODO confirm.
    int depth = -1;

    // Run 1: RFCI driven by the d-separation test on the true graph.
    IndependenceTest test = new IndTestDSep(graph);
    Rfci fci = new Rfci(test);
    Fas fas = new Fas(test);
    start = System.currentTimeMillis();
    fci.setDepth(depth);
    fci.setVerbose(false);
    fci.search(fas, fas.getNodes());
    stop = System.currentTimeMillis();
    System.out.println("DSEP RFCI");
    System.out.println("# dsep checks = " + fas.getNumIndependenceTests());
    System.out.println("Elapsed " + (stop - start));
    System.out.println("Per " + fas.getNumIndependenceTests() / (double) (stop - start));

    // Run 2: RFCI driven by Fisher Z on data simulated from the same graph.
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);
    IndependenceTest test2 = new IndTestFisherZ(data, 0.001);
    Rfci fci3 = new Rfci(test2);
    Fas fas2 = new Fas(test2);
    start = System.currentTimeMillis();
    fci3.setDepth(depth);
    fci3.search(fas2, fas2.getNodes());
    stop = System.currentTimeMillis();
    System.out.println("FISHER Z RFCI");
    // Report the counter of the Fas used by this run (fas2), not the one from run 1.
    System.out.println("# indep checks = " + fas2.getNumIndependenceTests());
    System.out.println("Elapsed " + (stop - start));
    System.out.println("Per " + fas2.getNumIndependenceTests() / (double) (stop - start));
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm)

Example 77 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TestSimulatedFmri method testClark2.

// @Test
/**
 * Simulates 1000 rows from a three-variable generalized SEM (X -> Y, X -> Z, Y -> Z)
 * whose error terms all use the expression {@code pow(Uniform(0, 1), 1.5)}, runs FASK
 * on the simulated data with a SEM BIC score, and prints the resulting graph.
 */
public void testClark2() {
    Node x = new ContinuousVariable("X");
    Node y = new ContinuousVariable("Y");
    Node z = new ContinuousVariable("Z");

    Graph g = new EdgeListGraph();
    for (Node node : new Node[] { x, y, z }) {
        g.addNode(node);
    }
    g.addDirectedEdge(x, y);
    g.addDirectedEdge(x, z);
    g.addDirectedEdge(y, z);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        // Structural equations first, in the order X, Y, Z.
        pm.setNodeExpression(g.getNode("X"), "E_X");
        pm.setNodeExpression(g.getNode("Y"), "0.4 * X + E_Y");
        pm.setNodeExpression(g.getNode("Z"), "0.4 * X + 0.4 * Y + E_Z");

        // Then the shared error expression, again in the order X, Y, Z.
        String errorExpression = "pow(Uniform(0, 1), 1.5)";
        for (String name : new String[] { "X", "Y", "Z" }) {
            pm.setNodeExpression(pm.getErrorNode(g.getNode(name)), errorExpression);
        }
    } catch (ParseException e) {
        // A parse failure is only reported; the simulation below still runs.
        System.out.println(e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(1000, false);

    CovarianceMatrixOnTheFly covariances = new CovarianceMatrixOnTheFly(data, false);
    edu.cmu.tetrad.search.SemBicScore score = new edu.cmu.tetrad.search.SemBicScore(covariances);

    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(1);
    fask.setAlpha(0.5);
    Graph result = fask.search();
    System.out.println(result);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Fask(edu.cmu.tetrad.search.Fask) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) ParseException(java.text.ParseException) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore)

Example 78 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TestKnowledge method test2.

/**
 * Builds a {@code Knowledge2} over the variable names X1..X100, assigns tier and
 * required-edge specifications using name patterns, and asserts that X20 -> X10 is
 * reported as forbidden and X6 -> X5 as required.
 */
@Test
public final void test2() {
    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= 100; index++) {
        variables.add(new ContinuousVariable("X" + index));
    }

    Graph randomGraph = GraphUtils.randomGraph(variables, 0, 100, 3, 3, 3, false);

    // Collect the names of the graph's nodes, in the order the graph returns them.
    List<String> variableNames = new ArrayList<>();
    for (Node graphNode : randomGraph.getNodes()) {
        variableNames.add(graphNode.getName());
    }

    Knowledge2 knowledge = new Knowledge2(variableNames);
    knowledge.addToTier(0, "X1*");
    knowledge.addToTier(1, "X2*");
    knowledge.setRequired("X4*,X6*", "X5*");
    knowledge.setRequired("X6*", "X5*");

    assertTrue(knowledge.isForbidden("X20", "X10"));
    assertTrue(knowledge.isRequired("X6", "X5"));
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Knowledge2(edu.cmu.tetrad.data.Knowledge2) Test(org.junit.Test)

Example 79 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class BayesUpdaterClassifierEditor method showClassification.

/**
 * Rebuilds the "Classification" tab of the tabbed pane from the classifier's current
 * results.
 *
 * <p>Any existing tab titled "Classification" is removed. A data set is then assembled
 * whose first column holds the classifications (as a discrete variable renamed
 * "Result") and whose remaining columns hold the marginals, one column per entry of
 * the marginals array. The table is shown at the index of the removed tab, or appended
 * as a new tab when none was removed.
 */
private void showClassification() {
    // Remove any tab already titled "Classification", remembering its index.
    int previousIndex = -1;
    int tab = 0;
    while (tab < getTabbedPane().getTabCount()) {
        if ("Classification".equals(getTabbedPane().getTitleAt(tab))) {
            getTabbedPane().remove(tab);
            previousIndex = tab;
        }
        tab++;
    }

    int[] classifications = getClassifier().getClassifications();
    double[][] marginals = getClassifier().getMarginals();

    // The result variable needs one category per index up to the largest one seen.
    int largestCategory = 0;
    for (int category : classifications) {
        largestCategory = Math.max(largestCategory, category);
    }

    List<Node> columns = new LinkedList<>();
    DiscreteVariable targetVariable = classifier.getTargetVariable();
    DiscreteVariable resultVariable = new DiscreteVariable(targetVariable.getName(), largestCategory + 1);
    columns.add(resultVariable);
    for (int category = 0; category < marginals.length; category++) {
        columns.add(new ContinuousVariable("P(" + targetVariable + "=" + category + ")"));
    }
    resultVariable.setName("Result");

    // Column 0 is the classification; column j + 1 is marginals[j] for that row.
    DataSet table = new ColtDataSet(classifications.length, columns);
    for (int row = 0; row < classifications.length; row++) {
        table.setInt(row, 0, classifications[row]);
        for (int category = 0; category < marginals.length; category++) {
            table.setDouble(row, category + 1, marginals[category][row]);
        }
    }

    JScrollPane scroll = new JScrollPane(new DataDisplay(table));
    if (previousIndex != -1) {
        getTabbedPane().add(scroll, previousIndex);
        getTabbedPane().setTitleAt(previousIndex, "Classification");
    } else {
        getTabbedPane().add("Classification", scroll);
    }
}
Also used : ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) LinkedList(java.util.LinkedList) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet)

Example 80 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class ModeInterpolator method filter.

/**
 * Returns a copy of the given data set in which missing values have been filled in,
 * column by column. The input data set is only read, never written.
 *
 * <ul>
 *   <li>Discrete columns: each {@code DiscreteVariable.MISSING_VALUE} entry is replaced
 *       by the most frequent category index among the non-missing entries (the lowest
 *       index wins ties).</li>
 *   <li>Continuous columns: each NaN entry is replaced by the median of the non-NaN
 *       entries, i.e. the middle value, or the mean of the two middle values when the
 *       count is even. If every entry is NaN the column is left as NaN.</li>
 * </ul>
 *
 * Columns whose variable is neither discrete nor continuous are copied unchanged.
 *
 * @param dataSet the data set to interpolate
 * @return a new data set with missing values filled in
 */
public DataSet filter(DataSet dataSet) {
    DataSet newDataSet = dataSet.copy();
    for (int j = 0; j < dataSet.getNumColumns(); j++) {
        Node var = dataSet.getVariable(j);
        if (var instanceof DiscreteVariable) {
            DiscreteVariable variable = (DiscreteVariable) var;
            int numCategories = variable.getNumCategories();

            // Tally how often each category occurs among the non-missing rows.
            int[] categoryCounts = new int[numCategories];
            for (int i = 0; i < dataSet.getNumRows(); i++) {
                int value = dataSet.getInt(i, j);
                if (value == DiscreteVariable.MISSING_VALUE) {
                    continue;
                }
                categoryCounts[value]++;
            }

            // Strict '>' keeps the first (lowest-index) category on ties.
            int mode = -1;
            int max = -1;
            for (int i = 0; i < numCategories; i++) {
                if (categoryCounts[i] > max) {
                    max = categoryCounts[i];
                    mode = i;
                }
            }

            for (int i = 0; i < dataSet.getNumRows(); i++) {
                if (dataSet.getInt(i, j) == DiscreteVariable.MISSING_VALUE) {
                    newDataSet.setInt(i, j, mode);
                }
            }
        } else if (var instanceof ContinuousVariable) {
            // Pack the non-missing values into the front of the array; k is the index
            // of the last packed value (-1 when the column has none).
            double[] data = new double[dataSet.getNumRows()];
            int k = -1;
            for (int i = 0; i < dataSet.getNumRows(); i++) {
                double value = dataSet.getDouble(i, j);
                if (!Double.isNaN(value)) {
                    data[++k] = value;
                }
            }

            // Sort only the packed prefix. The unused tail still holds the array's
            // default 0.0 entries, which must not be mixed in with the real values.
            Arrays.sort(data, 0, k + 1);

            double mode = Double.NaN;
            if (k >= 0) {
                // Median of data[0..k]: both indices coincide when the count is odd.
                mode = (data[(k + 1) / 2] + data[k / 2]) / 2.d;
            }

            for (int i = 0; i < dataSet.getNumRows(); i++) {
                if (Double.isNaN(dataSet.getDouble(i, j))) {
                    newDataSet.setDouble(i, j, mode);
                }
            }
        }
    }
    return newDataSet;
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node)

Aggregations

ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 91, DataSet (edu.cmu.tetrad.data.DataSet): 48, Node (edu.cmu.tetrad.graph.Node): 46, Test (org.junit.Test): 42, ArrayList (java.util.ArrayList): 28, Graph (edu.cmu.tetrad.graph.Graph): 22, ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 19, SemPm (edu.cmu.tetrad.sem.SemPm): 18, SemIm (edu.cmu.tetrad.sem.SemIm): 16, DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 15, LinkedList (java.util.LinkedList): 13, EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 12, TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 8, DMSearch (edu.cmu.tetrad.search.DMSearch): 7, Dag (edu.cmu.tetrad.graph.Dag): 6, CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 5, RandomUtil (edu.cmu.tetrad.util.RandomUtil): 5, ParseException (java.text.ParseException): 4, CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 3, Knowledge2 (edu.cmu.tetrad.data.Knowledge2): 3