Usage of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class TestLingamPattern, method test1.
/**
 * Simulates non-Gaussian SEM data over a random 6-variable DAG (fixed seed), estimates a pattern
 * with FGES using a SEM BIC score, runs LiNGAM-pattern orientation on that pattern, and compares
 * the reported p-values against recorded regression values (tolerance 0.01).
 */
@Test
public void test1() {
    RandomUtil.getInstance().setSeed(4938492L);
    int sampleSize = 1000;

    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 6; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }

    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 6, 4, 4, 4, false));

    // Presumably one error distribution per variable (6 variables, 6 entries); index 3 is the
    // only non-Gaussian one. NOTE(review): confirm against simulateDataNonNormal.
    List<Distribution> variableDistributions = new ArrayList<>();
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Uniform(-1, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));

    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    DataSet dataSet = simulateDataNonNormal(semIm, sampleSize, variableDistributions);

    Score score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
    Graph estPattern = new Fges(score).search();

    LingamPattern lingam = new LingamPattern(estPattern, dataSet);
    lingam.search();

    double[] pvals = lingam.getPValues();
    // Regression values recorded for the seed above.
    double[] expectedPVals = {0.18, 0.29, 0.88, 0.00, 0.01, 0.58};

    // Check the length first. A mismatch then fails with a clear assertion instead of an
    // ArrayIndexOutOfBoundsException, and trailing expected values cannot be silently skipped.
    assertEquals(expectedPVals.length, pvals.length);

    for (int i = 0; i < pvals.length; i++) {
        assertEquals(expectedPVals[i], pvals[i], 0.01);
    }
}
Usage of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class TestMbfs, method testRandom.
/**
 * Runs Mbfs with a d-separation oracle on a random 10-node DAG (fixed seed). For every node, it
 * compares the returned Markov blanket graph with GraphUtils.markovBlanketDag for that node.
 *
 * <p>The node-name sets must match exactly. Each directed edge in the result must either appear in
 * the true Markov blanket with the same endpoints, or be one of the tolerated extra edges checked
 * in the loop below. Non-directed result edges are not checked.
 */
@Test
public void testRandom() {
    RandomUtil.getInstance().setSeed(8388428832L);

    List<Node> nodes1 = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        nodes1.add(new ContinuousVariable("X" + (i + 1)));
    }

    Dag dag = new Dag(GraphUtils.randomGraph(nodes1, 0, 10, 5, 5, 5, false));
    IndependenceTest test = new IndTestDSep(dag);
    Mbfs search = new Mbfs(test, -1);

    List<Node> nodes = dag.getNodes();

    for (Node node : nodes) {
        Graph resultMb = search.search(node.getName());
        Graph trueMb = GraphUtils.markovBlanketDag(node, dag);

        Set<String> resultNames = new HashSet<>();
        for (Node resultNode : resultMb.getNodes()) {
            resultNames.add(resultNode.getName());
        }

        Set<String> trueNames = new HashSet<>();
        for (Node v : trueMb.getNodes()) {
            trueNames.add(v.getName());
        }

        // assertEquals (expected first) reports both sets on failure; assertTrue(a.equals(b))
        // only reports "false".
        assertEquals(trueNames, resultNames);

        Set<Edge> resultEdges = resultMb.getEdges();

        for (Edge resultEdge : resultEdges) {
            // Only directed edges are compared against the true Markov blanket.
            if (!Edges.isDirectedEdge(resultEdge)) {
                continue;
            }

            String name1 = resultEdge.getNode1().getName();
            String name2 = resultEdge.getNode2().getName();

            Node node1 = trueMb.getNode(name1);
            Node node2 = trueMb.getNode(name2);

            if (node1 == null) {
                fail("Node " + name1 + " is not in the true graph.");
            }

            if (node2 == null) {
                fail("Node " + name2 + " is not in the true graph.");
            }

            Edge trueEdge = trueMb.getEdge(node1, node2);

            if (trueEdge == null) {
                // An edge missing from the true MB is tolerated in two cases. First, when either of
                // its endpoints has no edge to the target in the result. Second, when exactly one
                // of those two target edges is directed and the other undirected. The original
                // author's note here was "possibility that the node is actually a child."
                Node resultNode1 = resultMb.getNode(node1.getName());
                Node resultNode2 = resultMb.getNode(node2.getName());
                Node resultTarget = resultMb.getNode(node.getName());

                Edge a = resultMb.getEdge(resultNode1, resultTarget);
                Edge b = resultMb.getEdge(resultNode2, resultTarget);

                if (a == null || b == null) {
                    continue;
                }

                if ((Edges.isDirectedEdge(a) && Edges.isUndirectedEdge(b))
                        || (Edges.isUndirectedEdge(a) && Edges.isDirectedEdge(b))) {
                    continue;
                }

                fail("EXTRA EDGE: Edge in result MB but not true MB = " + resultEdge);
            }

            // JUnit's assertEquals takes (expected, actual); the true edge is the expected value.
            assertEquals(trueEdge.getEndpoint1(), resultEdge.getEndpoint1());
            assertEquals(trueEdge.getEndpoint2(), resultEdge.getEndpoint2());
        }
    }
}
Usage of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class TestSemIm, method testIntercepts.
/**
 * Builds a SEM over a random 5-node DAG and assigns intercepts to its first four variables. It
 * then checks, within 0.1, that each assigned intercept reads back and that the fifth (never
 * assigned) variable's intercept is 0.0.
 */
@Test
public void testIntercepts() {
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 8, 30, 15, 15, false));
    SemIm im = new SemIm(new SemPm(dag));

    double[] assigned = {1.0, 3.0, -1.0, 6.0};
    for (int k = 0; k < assigned.length; k++) {
        im.setIntercept(im.getVariableNodes().get(k), assigned[k]);
    }

    // The trailing 0.0 is for the variable that was never assigned an intercept above.
    double[] expected = {1.0, 3.0, -1.0, 6.0, 0.0};
    for (int k = 0; k < expected.length; k++) {
        assertEquals(expected[k], im.getIntercept(im.getVariableNodes().get(k)), 0.1);
    }
}
Usage of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class TestSemIm, method testCovariancesOfSimulated.
/**
 * Builds a SEM over a random 5-node DAG and computes its implied covariance matrix. It then
 * simulates 1000 rows from the SEM and computes their sample covariance matrix, and checks the
 * basic shape of everything produced.
 *
 * <p>Previously the implied covariance and the sample covariance were computed and discarded, so
 * the test could only fail by throwing. It now asserts on those results.
 */
@Test
public void testCovariancesOfSimulated() {
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 5; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }

    Graph randomGraph = new Dag(GraphUtils.randomGraph(nodes, 0, 8, 30, 15, 15, false));
    SemPm semPm1 = new SemPm(randomGraph);
    SemIm semIm1 = new SemIm(semPm1);

    // Any covariance matrix must be square.
    TetradMatrix implCovarC = semIm1.getImplCovar(true);
    double[][] implCovar = implCovarC.toArray();
    assertTrue(implCovar.length > 0);
    assertEquals(implCovar.length, implCovar[0].length);

    int sampleSize = 1000;
    DataSet dataSet = semIm1.simulateDataRecursive(sampleSize, false);
    assertEquals(sampleSize, dataSet.getNumRows());
    assertEquals(nodes.size(), dataSet.getNumColumns());

    // One row/column per simulated variable.
    CovarianceMatrix sampleCovar = new CovarianceMatrix(dataSet);
    assertEquals(nodes.size(), sampleCovar.getDimension());
}
Usage of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class TestMarkovBlanketSearches, method testRandom.
/**
 * For every node of a random 10-node DAG, checks that Mbfs.findMb (backed by a d-separation
 * oracle) returns exactly the nodes of GraphUtils.markovBlanketDag for that node, excluding the
 * node itself. Node order is ignored: both lists are sorted by name before comparison.
 *
 * <p>NOTE(review): no RNG seed is set before the random graph is generated, so the DAG may differ
 * between runs and a failure may not be reproducible.
 */
@Test
public void testRandom() {
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 10; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Dag dag = new Dag(GraphUtils.randomGraph(variables, 0, 10, 5, 5, 5, false));
    IndependenceTest oracle = new IndTestDSep(dag);
    Mbfs mbfs = new Mbfs(oracle, -1);

    // A single comparator shared by both sorts.
    Comparator<Node> byName = new Comparator<Node>() {
        @Override
        public int compare(Node left, Node right) {
            return left.getName().compareTo(right.getName());
        }
    };

    for (Node target : dag.getNodes()) {
        List<Node> found = mbfs.findMb(target.getName());

        List<Node> expected = GraphUtils.markovBlanketDag(target, dag).getNodes();
        // Drop the target itself (if present) before comparing.
        expected.remove(target);

        Collections.sort(expected, byName);
        Collections.sort(found, byName);

        assertEquals(expected, found);
    }
}
Aggregations