Use of edu.cmu.tetrad.data.Discretizer in project tetrad by cmu-phil.
From the class TestLogisticRegression, method test1.
/**
 * Smoke test for {@code LogisticRegression} on simulated data with a discretized target.
 *
 * <p>Creates five continuous variables X1..X5, wraps a random graph over them in a
 * {@code Dag}, simulates a data set from a {@code SemIm}, registers X1 with the
 * {@code Discretizer} via {@code equalCounts(x1, 2)}, and regresses the discretized X1
 * on X2..X5. The graph and the regression object are printed; nothing is asserted.
 */
@Test
public void test1() {
    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= 5; index++) {
        variables.add(new ContinuousVariable("X" + index));
    }

    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 3, 3, 3, false));
    System.out.println(dag);

    SemIm semIm = new SemIm(new SemPm(dag));
    DataSet continuous = semIm.simulateDataRecursive(1000, false);

    // Target and predictors are looked up on the continuous data set, before discretizing.
    Node target = continuous.getVariable("X1");
    List<Node> predictors = new ArrayList<>();
    for (int index = 2; index <= 5; index++) {
        predictors.add(continuous.getVariable("X" + index));
    }

    Discretizer binner = new Discretizer(continuous);
    binner.equalCounts(target, 2);
    DataSet discretized = binner.discretize();

    LogisticRegression logistic = new LogisticRegression(discretized);
    // The regression target is the X1 column of the discretized data set.
    DiscreteVariable discreteTarget = (DiscreteVariable) discretized.getVariable("X1");
    logistic.regress(discreteTarget, predictors);
    System.out.println(logistic);
}
Use of edu.cmu.tetrad.data.Discretizer in project tetrad by cmu-phil.
From the class TestIndTestWaldLR, method testIsIndependent.
/**
 * Checks the multinomial-logistic-regression Wald independence test against d-separation.
 *
 * <p>With the {@code RandomUtil} seed fixed, runs 10 rounds. Each round draws a random
 * graph over X1..X10, simulates data from a {@code SemIm}, registers the variables at
 * indices 0 and 3 with the {@code Discretizer} via {@code equalCounts(..., 2)}, and
 * compares {@code isIndependent(X2, X1, [X3, X4, X5])} computed from the data with the
 * answer {@code IndTestDSep} gives on the generating graph. The test requires agreement
 * in all 10 rounds.
 */
@Test
public void testIsIndependent() {
    RandomUtil.getInstance().setSeed(1450705713157L);

    final int rounds = 10;
    int agreements = 0;

    for (int round = 0; round < rounds; round++) {
        List<Node> variables = new ArrayList<>();
        for (int k = 1; k <= 10; k++) {
            variables.add(new ContinuousVariable("X" + k));
        }

        Graph trueGraph = GraphUtils.randomGraph(variables, 0, 10, 3, 3, 3, false);
        SemIm semIm = new SemIm(new SemPm(trueGraph));
        DataSet simulated = semIm.simulateData(1000, false);

        Discretizer binner = new Discretizer(simulated);
        binner.setVariablesCopied(true);
        binner.equalCounts(simulated.getVariable(0), 2);
        binner.equalCounts(simulated.getVariable(3), 2);
        DataSet mixed = binner.discretize();

        Node dataX1 = mixed.getVariable("X1");
        Node dataX2 = mixed.getVariable("X2");

        // Conditioning set, once as data-set variables and once as the same-named graph nodes.
        List<Node> dataCond = new ArrayList<>();
        List<Node> graphCond = new ArrayList<>();
        for (String name : new String[] {"X3", "X4", "X5"}) {
            Node conditioning = mixed.getVariable(name);
            dataCond.add(conditioning);
            graphCond.add(trueGraph.getNode(conditioning.getName()));
        }

        Node graphX1 = trueGraph.getNode(dataX1.getName());
        Node graphX2 = trueGraph.getNode(dataX2.getName());

        // The Wald LR test is used because the original author considered it the most up to date.
        IndependenceTest waldTest = new IndTestMultinomialLogisticRegressionWald(mixed, 0.05, false);
        IndTestDSep dsepOracle = new IndTestDSep(trueGraph);

        boolean independentInData = waldTest.isIndependent(dataX2, dataX1, dataCond);
        boolean independentInGraph = dsepOracle.isIndependent(graphX2, graphX1, graphCond);
        if (independentInData == independentInGraph) {
            agreements++;
        }
    }

    // The original author noted that all 10 rounds do not always agree; the seed is fixed
    // at the top of this test.
    assertEquals(rounds, agreements);
}
Use of edu.cmu.tetrad.data.Discretizer in project tetrad by cmu-phil.
From the class SemThenDiscretize, method simulate.
/**
 * Simulates a data set from a SEM built on {@code graph} and registers a share of its
 * variables for equal-interval discretization.
 *
 * <p>On the first call (while {@code shuffledOrder} is {@code null}) the simulated
 * variables are shuffled once and the order is cached in {@code shuffledOrder}; later
 * calls reuse that cached order. Variables are then taken from the front of that order
 * while the index is below {@code shuffledOrder.size() * percentDiscrete * 0.01}.
 *
 * <p>The bound is computed once and capped at the number of cached variables, so a
 * {@code percentDiscrete} above 100 selects every variable instead of indexing past the
 * end of the list.
 *
 * @param graph      the graph the {@code SemPm} is built from
 * @param parameters read for {@code "sampleSize"}, {@code "percentDiscrete"} and
 *                   {@code "numCategories"}
 * @return the result of {@code Discretizer.discretize()} on the simulated data
 */
private DataSet simulate(Graph graph, Parameters parameters) {
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet continuousData = im.simulateData(parameters.getInt("sampleSize"), false);

    if (this.shuffledOrder == null) {
        List<Node> shuffledNodes = new ArrayList<>(continuousData.getVariables());
        Collections.shuffle(shuffledNodes);
        this.shuffledOrder = shuffledNodes;
    }

    Discretizer discretizer = new Discretizer(continuousData);

    // Loop-invariant bound, evaluated once. Kept as a double so the comparison below is
    // the same "i < size * percent * 0.01" test as before, only capped at the list size.
    final int nodeCount = this.shuffledOrder.size();
    final double requested = nodeCount * parameters.getDouble("percentDiscrete") * 0.01;
    final double limit = Math.min(requested, nodeCount);

    for (int i = 0; i < limit; i++) {
        // Look the variable up by name in this call's data set.
        Node variable = continuousData.getVariable(this.shuffledOrder.get(i).getName());
        // "numCategories" stays inside the loop so it is only read when a variable is discretized.
        discretizer.equalIntervals(variable, parameters.getInt("numCategories"));
    }

    return discretizer.discretize();
}
Aggregations