Example use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class TestSearchGraph, method rtestDSeparation4.
/**
 * Times RFCI twice over the same random DAG built on 100 continuous variables.
 *
 * <p>The first run uses a d-separation oracle ({@link IndTestDSep}) on the graph itself. The
 * second uses a Fisher Z test (alpha = 0.001) on 1000 rows simulated from a {@link SemIm}
 * parameterised on that graph. For each run it prints the number of independence checks made
 * by that run's {@link Fas}, the elapsed milliseconds and checks per millisecond.
 */
public void rtestDSeparation4() {
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 20, 100, 5, 5, 5, false));
    long start, stop;
    int depth = -1;

    // Run 1: d-separation oracle.
    IndependenceTest test = new IndTestDSep(graph);
    Rfci fci = new Rfci(test);
    Fas fas = new Fas(test);
    start = System.currentTimeMillis();
    fci.setDepth(depth);
    fci.setVerbose(false);
    fci.search(fas, fas.getNodes());
    stop = System.currentTimeMillis();
    System.out.println("DSEP RFCI");
    System.out.println("# dsep checks = " + fas.getNumIndependenceTests());
    System.out.println("Elapsed " + (stop - start));
    System.out.println("Per " + fas.getNumIndependenceTests() / (double) (stop - start));

    // Run 2: Fisher Z on data simulated from the same graph.
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);
    IndependenceTest test2 = new IndTestFisherZ(data, 0.001);
    Rfci fci3 = new Rfci(test2);
    Fas fas2 = new Fas(test2);
    start = System.currentTimeMillis();
    fci3.setDepth(depth);
    // Mirror the first run's configuration so the two timings are comparable.
    fci3.setVerbose(false);
    fci3.search(fas2, fas2.getNodes());
    stop = System.currentTimeMillis();
    System.out.println("FISHER Z RFCI");
    // Report this run's own counter (fas2); the original printed fas, i.e. the DSEP run's count.
    System.out.println("# indep checks = " + fas2.getNumIndependenceTests());
    System.out.println("Elapsed " + (stop - start));
    System.out.println("Per " + fas2.getNumIndependenceTests() / (double) (stop - start));
}
Example use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class TestSimulatedFmri, method testClark2.
// @Test
public void testClark2() {
Node x = new ContinuousVariable("X");
Node y = new ContinuousVariable("Y");
Node z = new ContinuousVariable("Z");
Graph g = new EdgeListGraph();
g.addNode(x);
g.addNode(y);
g.addNode(z);
g.addDirectedEdge(x, y);
g.addDirectedEdge(x, z);
g.addDirectedEdge(y, z);
GeneralizedSemPm pm = new GeneralizedSemPm(g);
try {
pm.setNodeExpression(g.getNode("X"), "E_X");
pm.setNodeExpression(g.getNode("Y"), "0.4 * X + E_Y");
pm.setNodeExpression(g.getNode("Z"), "0.4 * X + 0.4 * Y + E_Z");
String error = "pow(Uniform(0, 1), 1.5)";
pm.setNodeExpression(pm.getErrorNode(g.getNode("X")), error);
pm.setNodeExpression(pm.getErrorNode(g.getNode("Y")), error);
pm.setNodeExpression(pm.getErrorNode(g.getNode("Z")), error);
} catch (ParseException e) {
System.out.println(e);
}
GeneralizedSemIm im = new GeneralizedSemIm(pm);
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.SemBicScore score = new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
Fask fask = new Fask(data, score);
fask.setPenaltyDiscount(1);
fask.setAlpha(0.5);
Graph out = fask.search();
System.out.println(out);
}
Example use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class TestKnowledge, method test2.
/**
 * Checks wildcard tier and required-edge handling in {@link Knowledge2}.
 *
 * <p>The variable names come from a random graph over X1..X100. The test puts X1* in tier 0 and
 * X2* in tier 1, and declares X4*,X6* and X6* as required parents of X5*. It then asserts that
 * X20 -> X10 is forbidden and X6 -> X5 is required.
 */
@Test
public final void test2() {
    List<Node> variables = new ArrayList<>();
    for (int idx = 1; idx <= 100; idx++) {
        variables.add(new ContinuousVariable("X" + idx));
    }

    // Only the node names of this graph are used below.
    Graph graph = GraphUtils.randomGraph(variables, 0, 100, 3, 3, 3, false);

    List<String> varNames = new ArrayList<>();
    for (Node n : graph.getNodes()) {
        varNames.add(n.getName());
    }

    Knowledge2 kn = new Knowledge2(varNames);
    kn.addToTier(0, "X1*");
    kn.addToTier(1, "X2*");
    kn.setRequired("X4*,X6*", "X5*");
    kn.setRequired("X6*", "X5*");

    // X20 falls under tier 1 and X10 under tier 0; the test expects this edge to be forbidden.
    assertTrue(kn.isForbidden("X20", "X10"));
    assertTrue(kn.isRequired("X6", "X5"));
}
Example use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class BayesUpdaterClassifierEditor, method showClassification.
/**
 * Shows the classifier's results in a tab titled "Classification".
 *
 * <p>The results are a data set with one row per case. Column 0 ("Result") holds the predicted
 * category from {@code getClassifications()}. The columns that follow hold the marginals from
 * {@code getMarginals()}, one column per first index of that array. Any existing
 * "Classification" tab is removed, and the new view is inserted where the first such tab was.
 * If no such tab existed, the view is appended.
 */
private void showClassification() {
    // Remove every existing "Classification" tab and remember where the first one was.
    // Only advance the index when nothing is removed: remove(i) shifts the later tabs left,
    // so incrementing after a removal would skip the tab that moved into slot i.
    int tabIndex = -1;
    int i = 0;
    while (i < getTabbedPane().getTabCount()) {
        if ("Classification".equals(getTabbedPane().getTitleAt(i))) {
            getTabbedPane().remove(i);
            if (tabIndex == -1) {
                tabIndex = i;
            }
        } else {
            i++;
        }
    }

    // Put the class information into a DataSet.
    int[] classifications = getClassifier().getClassifications();
    double[][] marginals = getClassifier().getMarginals();

    // The result variable gets one category per value up to the largest predicted category.
    int maxCategory = 0;
    for (int classification : classifications) {
        if (classification > maxCategory) {
            maxCategory = classification;
        }
    }

    List<Node> variables = new LinkedList<>();
    DiscreteVariable targetVariable = getClassifier().getTargetVariable();
    DiscreteVariable classVar = new DiscreteVariable(targetVariable.getName(), maxCategory + 1);
    variables.add(classVar);
    for (int k = 0; k < marginals.length; k++) {
        String name = "P(" + targetVariable + "=" + k + ")";
        ContinuousVariable scoreVar = new ContinuousVariable(name);
        variables.add(scoreVar);
    }
    classVar.setName("Result");

    // marginals is indexed [category][case]; column j + 1 holds category j.
    DataSet dataSet = new ColtDataSet(classifications.length, variables);
    for (int row = 0; row < classifications.length; row++) {
        dataSet.setInt(row, 0, classifications[row]);
        for (int j = 0; j < marginals.length; j++) {
            dataSet.setDouble(row, j + 1, marginals[j][row]);
        }
    }

    DataDisplay jTable = new DataDisplay(dataSet);
    JScrollPane scroll = new JScrollPane(jTable);
    if (tabIndex == -1) {
        getTabbedPane().add("Classification", scroll);
    } else {
        getTabbedPane().add(scroll, tabIndex);
        getTabbedPane().setTitleAt(tabIndex, "Classification");
    }
}
Example use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
From class ModeInterpolator, method filter.
/**
 * Returns a copy of {@code dataSet} with missing values filled in column by column.
 *
 * <ul>
 * <li>Discrete columns: every cell equal to {@link DiscreteVariable#MISSING_VALUE} gets the
 * column's most frequent category. Ties go to the lowest category index. If the column has no
 * observed values, category 0 is used.</li>
 * <li>Continuous columns: every NaN cell gets the median of the column's non-NaN values, or NaN
 * if there are none. Despite the class name, continuous columns use the median, not the
 * mode.</li>
 * </ul>
 *
 * <p>Columns of any other variable type are left as produced by {@code dataSet.copy()}. All
 * writes go to the copy; {@code dataSet} is only read.
 *
 * @param dataSet the data set whose missing values should be imputed.
 * @return an imputed copy of {@code dataSet}.
 */
public DataSet filter(DataSet dataSet) {
    DataSet newDataSet = dataSet.copy();
    int numRows = dataSet.getNumRows();

    for (int j = 0; j < dataSet.getNumColumns(); j++) {
        Node var = dataSet.getVariable(j);

        if (var instanceof DiscreteVariable) {
            DiscreteVariable variable = (DiscreteVariable) var;
            int numCategories = variable.getNumCategories();
            int[] categoryCounts = new int[numCategories];

            for (int i = 0; i < numRows; i++) {
                int value = dataSet.getInt(i, j);
                if (value != DiscreteVariable.MISSING_VALUE) {
                    categoryCounts[value]++;
                }
            }

            // Strict '>' keeps the first (lowest-index) category on ties.
            int mode = -1;
            int max = -1;
            for (int c = 0; c < numCategories; c++) {
                if (categoryCounts[c] > max) {
                    max = categoryCounts[c];
                    mode = c;
                }
            }

            for (int i = 0; i < numRows; i++) {
                if (dataSet.getInt(i, j) == DiscreteVariable.MISSING_VALUE) {
                    newDataSet.setInt(i, j, mode);
                }
            }
        } else if (var instanceof ContinuousVariable) {
            // Pack the observed (non-NaN) values into the front of the buffer.
            double[] observed = new double[numRows];
            int count = 0;
            for (int i = 0; i < numRows; i++) {
                double value = dataSet.getDouble(i, j);
                if (!Double.isNaN(value)) {
                    observed[count++] = value;
                }
            }

            // Sort only the filled prefix. Sorting the whole buffer would mix the unused
            // zero-initialised slots (one per missing row) into the median.
            Arrays.sort(observed, 0, count);

            // Average of the two middle elements; for an odd count both indices coincide.
            double median = Double.NaN;
            if (count > 0) {
                median = (observed[count / 2] + observed[(count - 1) / 2]) / 2.d;
            }

            for (int i = 0; i < numRows; i++) {
                if (Double.isNaN(dataSet.getDouble(i, j))) {
                    newDataSet.setDouble(i, j, median);
                }
            }
        }
    }

    return newDataSet;
}
Aggregations