Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
Class AbstractMBSearchRunner, method setSearchResults.
/**
 * Stores the outcome of a search: keeps a private copy of the resulting nodes and
 * replaces the data model with one built for exactly those nodes.
 *
 * @param nodes the nodes produced by the search; may be empty but not null.
 * @throws NullPointerException if {@code nodes} is null.
 */
void setSearchResults(List<Node> nodes) {
    if (nodes == null) {
        throw new NullPointerException("nodes were null.");
    }

    // Defensive copy, so later changes to the caller's list do not leak in here.
    this.variables = new ArrayList<>(nodes);

    // With no nodes, a fresh data set with the source's row count is created;
    // otherwise the data model comes from source.subsetColumns(nodes).
    this.dataModel = nodes.isEmpty()
            ? new ColtDataSet(this.source.getNumRows(), nodes)
            : this.source.subsetColumns(nodes);

    this.setDataModel(this.dataModel);
}
Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
Class TabularComparison, method newExecution.
/**
 * Starts a fresh comparison run: installs the list of statistics to report and creates
 * a zero-row data set with one continuous column per statistic (named by the
 * statistic's abbreviation), formatted with two decimal places.
 */
private void newExecution() {
    statistics = new ArrayList<>();

    // Adjacency accuracy.
    statistics.add(new AdjacencyPrecision());
    statistics.add(new AdjacencyRecall());

    // Arrowhead accuracy.
    statistics.add(new ArrowheadPrecision());
    statistics.add(new ArrowheadRecall());

    // Two-cycle accuracy.
    statistics.add(new TwoCyclePrecision());
    statistics.add(new TwoCycleRecall());
    statistics.add(new TwoCycleFalsePositive());

    // Currently switched off: ElapsedTime, F1Adj, F1Arrow, MathewsCorrAdj,
    // MathewsCorrArrow, SHD.

    // One column per statistic, in the same order as the statistics list.
    List<Node> columns = new ArrayList<>(statistics.size());
    for (Statistic stat : statistics) {
        columns.add(new ContinuousVariable(stat.getAbbreviation()));
    }

    dataSet = new ColtDataSet(0, columns);
    dataSet.setNumberFormat(new DecimalFormat("0.00"));
}
Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
Class RemoveMissingCasesDataFilter, method filter.
/**
 * Returns a new data set containing only those rows of {@code data} in which no cell
 * holds a missing value. Row order is preserved and the input is not modified.
 *
 * <p>The new data set is constructed over the variable list returned by
 * {@code data.getVariables()}. Every variable of {@code data} is cast to
 * {@code Variable} in order to ask whether a cell value counts as missing.
 *
 * @param data the data set to filter.
 * @return a data set with the complete rows of {@code data}, in their original order.
 */
public DataSet filter(DataSet data) {
    List<Node> variables = data.getVariables();
    int totalRows = data.getNumRows();
    int numColumns = data.getNumColumns();

    // Scan once and remember which rows are complete, so the missing-value test
    // is not repeated (and cannot drift) when the rows are copied below.
    boolean[] complete = new boolean[totalRows];
    int numComplete = 0;
    for (int row = 0; row < totalRows; row++) {
        if (!rowHasMissingValue(data, row, numColumns)) {
            complete[row] = true;
            numComplete++;
        }
    }

    DataSet newDataSet = new ColtDataSet(numComplete, variables);
    int newRow = 0;
    for (int row = 0; row < totalRows; row++) {
        if (!complete[row]) {
            continue;
        }
        for (int col = 0; col < numColumns; col++) {
            newDataSet.setObject(newRow, col, data.getObject(row, col));
        }
        newRow++;
    }
    return newDataSet;
}

/**
 * Tells whether any of the first {@code numColumns} cells in the given row of
 * {@code data} is a missing value according to that column's variable.
 */
private boolean rowHasMissingValue(DataSet data, int row, int numColumns) {
    for (int col = 0; col < numColumns; col++) {
        Node variable = data.getVariable(col);
        if (((Variable) variable).isMissingValue(data.getObject(row, col))) {
            return true;
        }
    }
    return false;
}
Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
Class EmBayesEstimator, method initialize.
/**
 * Sets up the estimator's working state. In order, this method:
 * <ol>
 * <li>estimates {@code observedIm} from {@code dataSet}, using the prior returned by
 *     {@code DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5)};</li>
 * <li>builds {@code mixedData}, a data set with one column per entry of {@code nodes}:
 *     latent nodes get a new discrete column filled with -99 in every row, all other
 *     nodes get a copy of the matching column of {@code dataSet};</li>
 * <li>calls {@code estimateIM(bayesPm, mixedData)};</li>
 * <li>allocates {@code estimatedCounts}, {@code estimatedCountsDenom} and
 *     {@code condProbs}, sized from the row and column counts of {@code estimatedIm}.</li>
 * </ol>
 */
private void initialize() {
    // Instantiated model over the measured variables only.
    observedIm = DirichletEstimator.estimate(
            DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5), dataSet);

    int rowCount = dataSet.getNumRows();

    // Column layout of the mixed data set, in the same order as the nodes array.
    List<Node> columns = new LinkedList<>();
    for (int j = 0; j < nodes.length; j++) {
        Node node = nodes[j];
        if (node.getNodeType() != NodeType.LATENT) {
            // Reuse the data set's own variable, looked up by the name bayesPm gives the node.
            columns.add(dataSet.getVariable(bayesPm.getVariable(node).getName()));
        } else {
            // Latent nodes have no column in the data; create a discrete variable for them.
            DiscreteVariable latent =
                    new DiscreteVariable(node.getName(), bayesPm.getNumCategories(node));
            latent.setNodeType(NodeType.LATENT);
            columns.add(latent);
        }
    }

    DataSet dsMixed = new ColtDataSet(rowCount, columns);

    for (int col = 0; col < nodes.length; col++) {
        Node node = nodes[col];
        if (node.getNodeType() != NodeType.LATENT) {
            // Copy the observed values from the corresponding column of the original data.
            Node measured = dataSet.getVariable(bayesPm.getVariable(node).getName());
            int sourceCol = dataSet.getColumn(measured);
            for (int row = 0; row < rowCount; row++) {
                dsMixed.setInt(row, col, dataSet.getInt(row, sourceCol));
            }
        } else {
            // Every case is unobserved for a latent column.
            // NOTE(review): -99 is presumably the discrete missing-value code - confirm.
            for (int row = 0; row < rowCount; row++) {
                dsMixed.setInt(row, col, -99);
            }
        }
    }

    mixedData = dsMixed;
    allVariables = mixedData.getVariables();

    // Per the original note: finds the Bayes net parameterized using mixedData,
    // or set randomly when that is not possible.
    estimateIM(bayesPm, mixedData);

    // A disabled debugging variant for embayes_l1x1x2x3V3.dat (a case specified by
    // P. Spirtes) used to overwrite estimatedIm at this point with
    //   P(L1 = 0) = P(L1 = 1) = 0.5
    // and, for each X in {X1, X2, X3},
    //   P(X = 0 | L1 = 0) = 0.6, P(X = 1 | L1 = 0) = 0.4,
    //   P(X = 0 | L1 = 1) = 0.4, P(X = 1 | L1 = 1) = 0.6.

    // Allocate the per-node tables with the same row/column shape as estimatedIm.
    int nodeCount = nodes.length;
    estimatedCounts = new double[nodeCount][][];
    estimatedCountsDenom = new double[nodeCount][];
    condProbs = new double[nodeCount][][];

    for (int n = 0; n < nodeCount; n++) {
        int numRows = estimatedIm.getNumRows(n);

        estimatedCounts[n] = new double[numRows][];
        estimatedCountsDenom[n] = new double[numRows];
        condProbs[n] = new double[numRows][];

        for (int r = 0; r < numRows; r++) {
            int numCols = estimatedIm.getNumColumns(n);
            estimatedCounts[n][r] = new double[numCols];
            condProbs[n][r] = new double[numCols];
        }
    }
}
Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
Class MlBayesImObs, method simulateDataHelper.
/**
 * Creates a discrete data set with {@code sampleSize} rows and hands it to
 * {@code constructSample} to be filled in.
 *
 * <p>Each included node contributes one discrete column carrying the node's name and
 * the categories {@code bayesPm} lists for it.
 *
 * @param sampleSize      the number of rows of the data set.
 * @param latentDataSaved if true, every node gets a column; if false, only nodes whose
 *                        type is {@code NodeType.MEASURED} do.
 * @return the data set after {@code constructSample} has been applied to it.
 */
private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved) {
    // map[c] is the index into nodes of the node behind column c.
    int[] map = new int[nodes.length];
    List<Node> variables = new LinkedList<>();
    int numColumns = 0;

    for (int j = 0; j < nodes.length; j++) {
        Node node = nodes[j];

        boolean included = latentDataSaved || node.getNodeType() == NodeType.MEASURED;
        if (!included) {
            continue;
        }

        List<String> categories = new LinkedList<>();
        int numCategories = bayesPm.getNumCategories(node);
        for (int k = 0; k < numCategories; k++) {
            categories.add(bayesPm.getCategory(node, k));
        }

        variables.add(new DiscreteVariable(node.getName(), categories));
        map[numColumns++] = j;
    }

    DataSet sample = new ColtDataSet(sampleSize, variables);
    constructSample(sampleSize, numColumns, sample, map);
    return sample;
}
Aggregations