Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
Class TestSemXml, method sampleSemIm1.
/**
 * Builds a sample SEM instantiated model over five continuous variables named X1..X5,
 * connected by a random DAG from {@code GraphUtils.randomGraph} (argument meanings per that
 * method; presumably 0 latent nodes and up to 5 edges — confirm).
 *
 * @return a freshly parameterized {@link SemIm} for the random DAG.
 */
private static SemIm sampleSemIm1() {
    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= 5; index++) {
        variables.add(new ContinuousVariable("X" + index));
    }
    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 30, 15, 15, true));
    return new SemIm(new SemPm(dag));
}
Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
Class TestLingamPattern, method simulateDataNonNormal.
/**
 * Simulates data by drawing a random value for each variable's exogenous term from that
 * variable's distribution and percolating these values down through the SEM in causal order,
 * assuming the model is acyclic. Fast for large simulations but hangs for cyclic models.
 *
 * @param semIm         the instantiated SEM supplying edge coefficients and variable means.
 * @param sampleSize    the number of rows to simulate; expected to be > 0.
 * @param distributions one distribution per variable, indexed in the same order as
 *                      {@code semIm.getSemPm().getVariableNodes()}.
 * @return the simulated data set, with one continuous column per variable node.
 */
private DataSet simulateDataNonNormal(SemIm semIm, int sampleSize, List<Distribution> distributions) {
    List<Node> variableNodes = semIm.getSemPm().getVariableNodes();
    List<Node> variables = new LinkedList<>();
    for (Node node : variableNodes) {
        variables.add(new ContinuousVariable(node.getName()));
    }
    DataSet dataSet = new ColtDataSet(sampleSize, variables);

    // Map each position in the causal ordering to its column index so the per-row loop
    // does no list searches.
    SemGraph graph = semIm.getSemPm().getGraph();
    List<Node> tierOrdering = graph.getCausalOrdering();
    int[] tierIndices = new int[variableNodes.size()];
    for (int i = 0; i < tierIndices.length; i++) {
        tierIndices[i] = variableNodes.indexOf(tierOrdering.get(i));
    }

    // For each variable: the column indices of its non-error parents and the matching edge
    // coefficients. Both are loop-invariant, so they are computed once here instead of once
    // per simulated row.
    int[][] parentIndices = new int[variableNodes.size()][];
    double[][] parentCoefs = new double[variableNodes.size()][];
    for (int i = 0; i < variableNodes.size(); i++) {
        // Build a filtered copy rather than removing from the list the graph handed back,
        // so the graph's own collection is never mutated.
        List<Node> parents = new LinkedList<>();
        for (Node parent : graph.getParents(variableNodes.get(i))) {
            if (parent.getNodeType() != NodeType.ERROR) {
                parents.add(parent);
            }
        }

        parentIndices[i] = new int[parents.size()];
        parentCoefs[i] = new double[parents.size()];
        int k = 0;
        for (Node parent : parents) {
            int parentIndex = variableNodes.indexOf(parent);
            parentIndices[i][k] = parentIndex;
            parentCoefs[i][k] = semIm.getEdgeCoef().get(parentIndex, i);
            k++;
        }
    }
    double[] means = semIm.getMeans();

    // Do the simulation: visit columns in causal order so every parent is filled in before
    // its children read it.
    for (int row = 0; row < sampleSize; row++) {
        for (int col : tierIndices) {
            double value = distributions.get(col).nextRandom();
            int[] parents = parentIndices[col];
            double[] coefs = parentCoefs[col];
            for (int k = 0; k < parents.length; k++) {
                value += dataSet.getDouble(row, parents[k]) * coefs[k];
            }
            value += means[col];
            dataSet.setDouble(row, col, value);
        }
    }
    return dataSet;
}
Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
Class TestHistogram, method testHistogram.
/**
 * Builds a seeded random 5-node DAG over X1..X5, simulates continuous data from a SEM on it,
 * and checks the {@code Histogram}'s max, min, and N for target X1, with and without
 * conditioning variables. Then simulates discrete data from a Bayes net on the same DAG and
 * computes conditional frequencies (whose assertions are currently disabled; see below).
 */
@Test
public void testHistogram() {
// The hard-coded expected values below are tied to this seed and to the order of the
// random draws that follow; reordering statements will presumably break them.
RandomUtil.getInstance().setSeed(4829384L);
List<Node> nodes = new ArrayList<>();
for (int i = 0; i < 5; i++) {
nodes.add(new ContinuousVariable("X" + (i + 1)));
}
Dag trueGraph = new Dag(GraphUtils.randomGraph(nodes, 0, 5, 30, 15, 15, false));
int sampleSize = 1000;
// Continuous
SemPm semPm = new SemPm(trueGraph);
SemIm semIm = new SemIm(semPm);
DataSet data = semIm.simulateData(sampleSize, false);
Histogram histogram = new Histogram(data);
histogram.setTarget("X1");
histogram.setNumBins(20);
assertEquals(3.76, histogram.getMax(), 0.01);
assertEquals(-3.83, histogram.getMin(), 0.01);
assertEquals(1000, histogram.getN());
histogram.setTarget("X1");
histogram.setNumBins(10);
// X3 is added and then removed again, so only the X4 condition remains in effect.
// NOTE(review): the (0, 1) arguments look like a low/high range for the condition — confirm.
histogram.addConditioningVariable("X3", 0, 1);
histogram.addConditioningVariable("X4", 0, 1);
histogram.removeConditioningVariable("X3");
// Max and min are expected to stay the same under conditioning, while N drops to the
// conditioned subsample size.
assertEquals(3.76, histogram.getMax(), 0.01);
assertEquals(-3.83, histogram.getMin(), 0.01);
assertEquals(188, histogram.getN());
// Adds an X2 condition spanning [min(X2), mean(X2)]; nothing is asserted afterwards.
double[] arr = histogram.getContinuousData("X2");
histogram.addConditioningVariable("X2", StatUtils.min(arr), StatUtils.mean(arr));
// Discrete
BayesPm bayesPm = new BayesPm(trueGraph);
BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
DataSet data2 = bayesIm.simulateData(sampleSize, false);
// For some reason these give different values when
// all of the unit tests are run at once. TODO They
// produce stable values when this particular test is
// run repeatedly. The frequency assertions below are
// disabled until that is resolved.
Histogram histogram2 = new Histogram(data2);
histogram2.setTarget("X1");
int[] frequencies1 = histogram2.getFrequencies();
// assertEquals(928, frequencies1[0]);
// assertEquals(72, frequencies1[1]);
histogram2.setTarget("X1");
histogram2.addConditioningVariable("X2", 0);
histogram2.addConditioningVariable("X3", 1);
int[] frequencies = histogram2.getFrequencies();
// assertEquals(377, frequencies[0]);
// assertEquals(28, frequencies[1]);
}
Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
Class TestKernelGaussian, method testMedianBandwidth.
/**
 * Checks that a {@code KernelGaussian} built from a single-column data set uses the median
 * pairwise distance between sample points as its bandwidth. For the points 1, 2, 3, 4, 5 the
 * ten pairwise distances are 1,1,1,1,2,2,2,3,3,4, whose median is 2.
 */
@Test
public void testMedianBandwidth() {
    Node variable = new ContinuousVariable("X");
    DataSet dataset = new ColtDataSet(5, Arrays.asList(variable));
    // Column 0 receives the values 1..5 in rows 0..4.
    for (int row = 0; row < 5; row++) {
        dataset.setDouble(row, 0, row + 1);
    }
    KernelGaussian kernel = new KernelGaussian(dataset, variable);
    assertTrue(kernel.getBandwidth() == 2);
}
Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
Class TestSearchGraph, method testDSeparation2.
/**
 * Tests that d-separation facts are symmetric: for every pair of nodes (x, y), including
 * x == y, and every conditioning set z enumerated from the graph's nodes, x is d-connected
 * to y given z exactly when y is d-connected to x given z.
 */
@Test
public void testDSeparation2() {
    List<Node> nodes1 = new ArrayList<>();
    for (int i1 = 0; i1 < 7; i1++) {
        nodes1.add(new ContinuousVariable("X" + (i1 + 1)));
    }
    EdgeListGraphSingleConnections graph = new EdgeListGraphSingleConnections(
            new Dag(GraphUtils.randomGraph(nodes1, 0, 14, 30, 15, 15, true)));
    List<Node> nodes = graph.getNodes();

    // Passed to DepthChoiceGenerator as the subset-size limit; -1 presumably means
    // "no limit" — confirm against DepthChoiceGenerator.
    int depth = -1;

    // Conditioning sets are drawn from all nodes, deliberately including x and y themselves.
    // This list never changes, so it is built once.
    List<Node> candidates = new ArrayList<>(nodes);

    for (int i = 0; i < nodes.size(); i++) {
        for (int j = i; j < nodes.size(); j++) {
            Node x = nodes.get(i);
            Node y = nodes.get(j);
            DepthChoiceGenerator gen = new DepthChoiceGenerator(candidates.size(), depth);
            int[] choice;
            while ((choice = gen.next()) != null) {
                List<Node> z = new LinkedList<>();
                for (int index : choice) {
                    z.add(candidates.get(index));
                }
                boolean xConnectedToY = graph.isDConnectedTo(x, y, z);
                boolean yConnectedToX = graph.isDConnectedTo(y, x, z);
                if (xConnectedToY != yConnectedToX) {
                    // Put the diagnostics in the failure message so they appear in the
                    // test report instead of only on stdout.
                    fail("d-connection is not symmetric for " + x + " and " + y
                            + " given " + z
                            + ": isDConnectedTo(x, y, z) = " + xConnectedToY
                            + ", isDConnectedTo(y, x, z) = " + yConnectedToX
                            + "\nGraph:\n" + graph);
                }
            }
        }
    }
}
Aggregations