use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class TestColtDataSet method testContinuous.
/**
 * Fills a 10 x 5 data set of continuous variables with random values, extracts the
 * subset made of the variables at positions 2 and 4, and checks that the first-row
 * entries of the subset's two columns match the first-row entries of the source
 * columns they were taken from.
 */
@Test
public final void testContinuous() {
    final int numRows = 10;
    final int numCols = 5;

    // Variables are named X0..X4.
    List<Node> nodes = new LinkedList<>();
    int index = 0;
    while (index < numCols) {
        nodes.add(new ContinuousVariable("X" + index));
        index++;
    }

    DataSet full = new ColtDataSet(numRows, nodes);
    RandomUtil random = RandomUtil.getInstance();

    for (int row = 0; row < numRows; row++) {
        for (int col = 0; col < numCols; col++) {
            full.setDouble(row, col, random.nextDouble());
        }
    }

    // Select the variables at positions 2 and 4; they become columns 0 and 1 of the subset.
    List<Node> allVars = full.getVariables();
    List<Node> selected = new LinkedList<>();
    selected.add(allVars.get(2));
    selected.add(allVars.get(4));

    DataSet subset = full.subsetColumns(selected);

    double expectedFirst = full.getDoubleData().getColumn(2).get(0);
    double actualFirst = subset.getDoubleData().getColumn(0).get(0);
    assertEquals(expectedFirst, actualFirst, .001);

    double expectedSecond = full.getDoubleData().getColumn(4).get(0);
    double actualSecond = subset.getDoubleData().getColumn(1).get(0);
    assertEquals(expectedSecond, actualSecond, .001);
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class TestColtDataSet method testDiscreteFromScratch.
/**
 * Builds a discrete data set starting from zero rows and zero variables: adds two
 * discrete variables one at a time, writes three integer values into each column,
 * copies the data set through the ColtDataSet copy constructor, and checks that the
 * copy equals the source and that a written value can be read back.
 */
@Test
public void testDiscreteFromScratch() {
    // Typed empty list instead of the raw Collections.EMPTY_LIST (avoids an unchecked conversion).
    DataSet dataSet = new ColtDataSet(0, Collections.<Node>emptyList());

    DiscreteVariable x1 = new DiscreteVariable("X1");
    dataSet.addVariable(x1);
    dataSet.setInt(0, 0, 0);
    dataSet.setInt(1, 0, 2);
    dataSet.setInt(2, 0, 1);

    DiscreteVariable x2 = new DiscreteVariable("X2");
    dataSet.addVariable(x2);
    dataSet.setInt(0, 1, 0);
    dataSet.setInt(1, 1, 2);
    dataSet.setInt(2, 1, 1);

    ColtDataSet copy = new ColtDataSet((ColtDataSet) dataSet);

    assertEquals(dataSet, copy);
    // Expected value first, actual second, so a failure message reports them the right way round.
    assertEquals(2, dataSet.getInt(1, 1));
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class TestColtDataSet method testRemoveRows.
/**
 * Fills a 10 x 5 data set of continuous variables with random values, removes rows
 * 1 and 2, and checks that the row count dropped by two and that the value originally
 * stored at (row 3, column 0) is now found at (row 1, column 0).
 */
@Test
public void testRemoveRows() {
    final int rowCount = 10;
    final int colCount = 5;

    // Variables are named X0..X4.
    List<Node> nodes = new LinkedList<>();
    for (int c = 0; c < colCount; c++) {
        nodes.add(new ContinuousVariable("X" + c));
    }

    DataSet data = new ColtDataSet(rowCount, nodes);
    RandomUtil random = RandomUtil.getInstance();

    for (int r = 0; r < rowCount; r++) {
        for (int c = 0; c < colCount; c++) {
            data.setDouble(r, c, random.nextDouble());
        }
    }

    // Capture the state before removal so it can be compared afterwards.
    int rowsBefore = data.getNumRows();
    double valueAtRowThree = data.getDouble(3, 0);

    int[] rowsToRemove = {1, 2};
    data.removeRows(rowsToRemove);

    assertEquals(rowsBefore - 2, data.getNumRows());
    assertEquals(valueAtRowThree, data.getDouble(1, 0), 0.001);
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class TestCpc method test7.
/**
 * Smoke test: generates a random graph over six continuous variables (X1..X6),
 * simulates 1000 rows from a SEM built on it, and runs a Cpc search using a
 * Fisher Z independence test at 0.05. The method makes no assertions on the
 * search result; it only exercises the pipeline.
 */
@Test
public void test7() {
    final int numVars = 6;
    final int numEdges = 6;

    List<Node> nodes = new ArrayList<>();
    for (int index = 1; index <= numVars; index++) {
        nodes.add(new ContinuousVariable("X" + index));
    }

    Dag trueGraph = new Dag(GraphUtils.randomGraph(nodes, 0, numEdges, 7, 5, 5, false));

    SemIm im = new SemIm(new SemPm(trueGraph));
    DataSet simulated = im.simulateData(1000, false);

    IndependenceTest fisherZ = new IndTestFisherZ(simulated, 0.05);
    Cpc cpc = new Cpc(fisherZ);

    // The returned graph is not inspected.
    cpc.search();
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class TestGeneralizedSem method test9.
/**
 * Builds a generalized SEM over the graph X1 -> X2 -> X3 -> X4 plus X1 -> X4 with
 * tangent-based node expressions and Beta-distributed error terms, simulates 1000
 * rows from it, then replaces the fixed Beta arguments with free parameters c1..c8,
 * runs the estimator on the simulated data, and checks the estimator's reported
 * aSquaredStar value against 0.62 (tolerance 0.01).
 *
 * <p>The random seed is fixed at the start of the method. A ParseException from any
 * of the expression setters fails the test instead of being printed and ignored.
 */
@Test
public void test9() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
        Node x1 = new GraphNode("X1");
        Node x2 = new GraphNode("X2");
        Node x3 = new GraphNode("X3");
        Node x4 = new GraphNode("X4");

        Graph g = new EdgeListGraphSingleConnections();
        g.addNode(x1);
        g.addNode(x2);
        g.addNode(x3);
        g.addNode(x4);

        g.addDirectedEdge(x1, x2);
        g.addDirectedEdge(x2, x3);
        g.addDirectedEdge(x3, x4);
        g.addDirectedEdge(x1, x4);

        GeneralizedSemPm pm = new GeneralizedSemPm(g);

        pm.setNodeExpression(x1, "E_X1");
        pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
        pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
        pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

        // Disabled alternative (quadratic) node expressions:
        // pm.setNodeExpression(x1, "E_X1");
        // pm.setNodeExpression(x2, "a1 * X1^2 + E_X2");
        // pm.setNodeExpression(x3, "a2 * X2^2 + E_X3");
        // pm.setNodeExpression(x4, "a3 * X1^2 + a4 * X3 ^ 2 + E_X4");

        // Error distributions used to generate the data.
        pm.setNodeExpression(pm.getErrorNode(x1), "Beta(5, 2)");
        pm.setNodeExpression(pm.getErrorNode(x2), "Beta(2, 5)");
        pm.setNodeExpression(pm.getErrorNode(x3), "Beta(1, 3)");
        pm.setNodeExpression(pm.getErrorNode(x4), "Beta(1, 7)");

        // Every free parameter c1..c8 gets the same estimation-initialization expression.
        for (int i = 1; i <= 8; i++) {
            pm.setParameterEstimationInitializationExpression("c" + i, "U(1, 3)");
        }

        GeneralizedSemIm im = new GeneralizedSemIm(pm);

        print("True model: ");
        print(im);

        DataSet data = im.simulateDataRecursive(1000, false);

        // Replace the fixed Beta arguments with the parameters c1..c8 before estimating.
        pm.setNodeExpression(pm.getErrorNode(x1), "Beta(c1, c2)");
        pm.setNodeExpression(pm.getErrorNode(x2), "Beta(c3, c4)");
        pm.setNodeExpression(pm.getErrorNode(x3), "Beta(c5, c6)");
        pm.setNodeExpression(pm.getErrorNode(x4), "Beta(c7, c8)");

        GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
        GeneralizedSemIm estIm = estimator.estimate(pm, data);

        print("\n\n\nEstimated model: ");
        print(estIm);
        print(estimator.getReport());

        double aSquaredStar = estimator.getaSquaredStar();

        assertEquals(0.62, aSquaredStar, 0.01);
    } catch (ParseException e) {
        // Previously the exception was only printed, which let the test pass without
        // ever reaching the assertion. Fail explicitly and keep the cause.
        throw new AssertionError("Could not parse a model expression: " + e.getMessage(), e);
    }
}
Aggregations