Example 1 with Discretizer

Use of edu.cmu.tetrad.data.Discretizer in project tetrad by cmu-phil.

Class TestLogisticRegression, method test1.

@Test
public void test1() {
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 5; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 5, 3, 3, 3, false));
    System.out.println(graph);
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateDataRecursive(1000, false);
    Node x1 = data.getVariable("X1");
    Node x2 = data.getVariable("X2");
    Node x3 = data.getVariable("X3");
    Node x4 = data.getVariable("X4");
    Node x5 = data.getVariable("X5");
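    // Discretize X1 into 2 equal-count categories so it can serve as the
    // discrete target of the logistic regression.
    Discretizer discretizer = new Discretizer(data);
    discretizer.equalCounts(x1, 2);
    DataSet d2 = discretizer.discretize();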
    LogisticRegression regression = new LogisticRegression(d2);
    List<Node> regressors = new ArrayList<>();
    regressors.add(x2);
    regressors.add(x3);
    regressors.add(x4);
    regressors.add(x5);
    DiscreteVariable x1b = (DiscreteVariable) d2.getVariable("X1");
    regression.regress(x1b, regressors);
    System.out.println(regression);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Dag(edu.cmu.tetrad.graph.Dag) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) LogisticRegression(edu.cmu.tetrad.regression.LogisticRegression) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)
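
Example 1 relies on equalCounts to turn the continuous X1 into a two-category target. As a rough illustration of what equal-count (equal-frequency) binning means, the plain-Java sketch below derives cutoffs from the sorted values and maps each value to a bin. It is not Tetrad's implementation: the class EqualCountSketch and its methods cutoffs and category are hypothetical names introduced here, and the real Discretizer may handle ties and boundaries differently.

```java
import java.util.Arrays;

/** Illustrative only: plain-Java equal-count binning, not Tetrad's Discretizer. */
final class EqualCountSketch {

    /** Returns numCategories - 1 cutoffs so that each bin holds roughly the same number of values. */
    static double[] cutoffs(double[] values, int numCategories) {
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        double[] cutoffs = new double[numCategories - 1];
        for (int k = 1; k < numCategories; k++) {
            // Index of the first sorted value that belongs to bin k.
            int idx = (int) ((long) k * sorted.length / numCategories);
            cutoffs[k - 1] = sorted[idx];
        }
        return cutoffs;
    }

    /** Maps a value to its bin index: the number of cutoffs that are <= value. */
    static int category(double value, double[] cutoffs) {
        int c = 0;
        while (c < cutoffs.length && value >= cutoffs[c]) {
            c++;
        }
        return c;
    }

    public static void main(String[] args) {
        double[] x = {0.3, -1.2, 2.5, 0.9, -0.4, 1.7};
        double[] cut = cutoffs(x, 2);
        int[] counts = new int[2];
        for (double v : x) {
            counts[category(v, cut)]++;
        }
        System.out.println("cutoffs=" + Arrays.toString(cut) + " counts=" + Arrays.toString(counts));
    }
}
```

With heavily tied data the bins produced by this sketch can be uneven, because every value equal to a cutoff lands in the upper bin.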

Example 2 with Discretizer

Use of edu.cmu.tetrad.data.Discretizer in project tetrad by cmu-phil.

Class TestIndTestWaldLR, method testIsIndependent.

@Test
public void testIsIndependent() {
    RandomUtil.getInstance().setSeed(1450705713157L);
    int numPassed = 0;
    for (int i = 0; i < 10; i++) {
        List<Node> nodes = new ArrayList<>();
        for (int i1 = 0; i1 < 10; i1++) {
            nodes.add(new ContinuousVariable("X" + (i1 + 1)));
        }
        Graph graph = GraphUtils.randomGraph(nodes, 0, 10, 3, 3, 3, false);
        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);
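        // Discretize the variables at indices 0 and 3 into 2 equal-count
        // categories each, producing a mixed continuous/discrete data set.
        Discretizer discretizer = new Discretizer(data);
        discretizer.setVariablesCopied(true);
        discretizer.equalCounts(data.getVariable(0), 2);
        discretizer.equalCounts(data.getVariable(3), 2);
        data = discretizer.discretize();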
        Node x1 = data.getVariable("X1");
        Node x2 = data.getVariable("X2");
        Node x3 = data.getVariable("X3");
        Node x4 = data.getVariable("X4");
        Node x5 = data.getVariable("X5");
        List<Node> cond = new ArrayList<>();
        cond.add(x3);
        cond.add(x4);
        cond.add(x5);
        Node x1Graph = graph.getNode(x1.getName());
        Node x2Graph = graph.getNode(x2.getName());
        List<Node> condGraph = new ArrayList<>();
        for (Node node : cond) {
            condGraph.add(graph.getNode(node.getName()));
        }
        // Using the Wald LR test since it's most up to date.
        IndependenceTest test = new IndTestMultinomialLogisticRegressionWald(data, 0.05, false);
        IndTestDSep dsep = new IndTestDSep(graph);
        boolean correct = test.isIndependent(x2, x1, cond) == dsep.isIndependent(x2Graph, x1Graph, condGraph);
        if (correct) {
            numPassed++;
        }
    }
    // System.out.println(RandomUtil.getInstance().getSeed());
    // Do not always get all 10.
    assertEquals(10, numPassed);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) IndTestDSep(edu.cmu.tetrad.search.IndTestDSep) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) IndTestMultinomialLogisticRegressionWald(edu.pitt.csb.mgm.IndTestMultinomialLogisticRegressionWald) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 3 with Discretizer

Use of edu.cmu.tetrad.data.Discretizer in project tetrad by cmu-phil.

Class SemThenDiscretize, method simulate.

private DataSet simulate(Graph graph, Parameters parameters) {
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet continuousData = im.simulateData(parameters.getInt("sampleSize"), false);
    if (this.shuffledOrder == null) {
        List<Node> shuffledNodes = new ArrayList<>(continuousData.getVariables());
        Collections.shuffle(shuffledNodes);
        this.shuffledOrder = shuffledNodes;
    }
    // Discretize the first percentDiscrete percent of the shuffled variables
    // into numCategories equal-width intervals; the rest stay continuous.
    Discretizer discretizer = new Discretizer(continuousData);
    for (int i = 0; i < shuffledOrder.size() * parameters.getDouble("percentDiscrete") * 0.01; i++) {
        discretizer.equalIntervals(continuousData.getVariable(shuffledOrder.get(i).getName()), parameters.getInt("numCategories"));
    }
    return discretizer.discretize();
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) Discretizer(edu.cmu.tetrad.data.Discretizer)
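
Two details of Example 3 are easy to misread: how many variables the loop discretizes, and what equal-interval binning does to a value. The plain-Java sketch below reproduces the loop bound from the example and shows one common way to compute an equal-width bin index. It is an illustration under stated assumptions, not Tetrad's code: the class EqualIntervalSketch and its methods numDiscretized and category are hypothetical, and the real equalIntervals may treat boundary values differently.

```java
/** Illustrative only: loop-bound arithmetic and equal-width binning, not Tetrad's Discretizer. */
final class EqualIntervalSketch {

    /**
     * Counts the iterations of a loop shaped like the one in the example:
     * for (int i = 0; i < numVariables * percentDiscrete * 0.01; i++).
     * A fractional bound is effectively rounded up.
     */
    static int numDiscretized(int numVariables, double percentDiscrete) {
        int count = 0;
        for (int i = 0; i < numVariables * percentDiscrete * 0.01; i++) {
            count++;
        }
        return count;
    }

    /** Returns the equal-width bin index of value, in the range 0 .. numCategories - 1. */
    static int category(double value, double min, double max, int numCategories) {
        if (max == min) {
            return 0; // Degenerate range: everything falls into one bin.
        }
        double width = (max - min) / numCategories;
        int c = (int) ((value - min) / width);
        // Clamp so that value == max lands in the last bin, not one past it.
        return Math.min(Math.max(c, 0), numCategories - 1);
    }

    public static void main(String[] args) {
        System.out.println("discretized variables: " + numDiscretized(10, 25));
        System.out.println("bin of 4.5 in [0, 9] with 3 bins: " + category(4.5, 0.0, 9.0, 3));
    }
}
```

Because the loop compares an int against a double bound, 25 percent of 10 variables yields three discretized variables rather than two.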

Aggregations

DataSet (edu.cmu.tetrad.data.DataSet): 3
Discretizer (edu.cmu.tetrad.data.Discretizer): 3
Node (edu.cmu.tetrad.graph.Node): 3
SemIm (edu.cmu.tetrad.sem.SemIm): 3
SemPm (edu.cmu.tetrad.sem.SemPm): 3
ArrayList (java.util.ArrayList): 3
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 2
Graph (edu.cmu.tetrad.graph.Graph): 2
Test (org.junit.Test): 2
DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 1
Dag (edu.cmu.tetrad.graph.Dag): 1
LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression): 1
IndTestDSep (edu.cmu.tetrad.search.IndTestDSep): 1
IndependenceTest (edu.cmu.tetrad.search.IndependenceTest): 1
IndTestMultinomialLogisticRegressionWald (edu.pitt.csb.mgm.IndTestMultinomialLogisticRegressionWald): 1