use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class BayesUpdaterClassifier method classify.
/**
 * Classifies each case of the test data by Bayesian updating and returns the
 * estimated value of the target variable for every case.
 * <p>
 * For each case, the observed values of the non-target variables are entered as
 * evidence in the instantiated Bayes net. The marginal probability of every target
 * category is then computed and stored in {@code marginals}. The category with the
 * highest marginal is taken as the estimate; on ties the later category wins.
 * Missing non-target values are left out of the evidence, and
 * {@code missingValueCaseFound} is reset per case, so after the call it reflects
 * only the last case.
 *
 * @return one entry per test case: the estimated target category index, or
 *         {@code DiscreteVariable.MISSING_VALUE} if none of the target marginals was
 *         a usable number (e.g. all NaN). The same array is stored in
 *         {@code classifications}.
 * @throws NullPointerException     if the target variable has not been set.
 * @throws IllegalArgumentException if a non-target variable of the Bayes net cannot be
 *                                  found among the variables of the test data.
 */
public int[] classify() {
    if (targetVariable == null) {
        throw new NullPointerException("Target not set.");
    }

    // Create an updater for the instantiated Bayes net.
    BayesUpdater bayesUpdater = new RowSummingExactUpdater(getBayesIm());

    int nvars = getBayesImVars().size();
    int ncases = testData.getNumRows();

    // Position of the target among the Bayes IM variables. It is computed once and
    // used by both loops below, so they agree on which variable is the target.
    int targetPosition = getBayesImVars().indexOf(targetVariable);

    // Map each Bayes IM variable to its column in the test data. The target's slot
    // keeps the placeholder column 0; that column of selectedData is never read.
    int[] varIndices = new int[nvars];
    List<Node> dataVars = testData.getVariables();
    for (int i = 0; i < nvars; i++) {
        if (i == targetPosition) {
            continue;
        }
        DiscreteVariable variable = (DiscreteVariable) getBayesImVars().get(i);
        varIndices[i] = dataVars.indexOf(variable);
        if (varIndices[i] == -1) {
            throw new IllegalArgumentException("Can't find the (non-target) variable " + variable + " in the data. Either it's not there, or else its " + "categories are in a different order.");
        }
    }

    // Column j of selectedData now holds the data for Bayes IM variable j.
    DataSet selectedData = testData.subsetColumns(varIndices);

    this.numCases = ncases;
    int[] estimatedValues = new int[ncases];
    int numTargetCategories = targetVariable.getNumCategories();
    double[][] probOfClassifiedValues = new double[numTargetCategories][ncases];
    Arrays.fill(estimatedValues, -1);

    // The target's index in the Bayes net does not depend on the case.
    Node targetNode = getBayesIm().getNode(targetVariable.getName());
    int indexTargetBN = getBayesIm().getNodeIndex(targetNode);

    for (int i = 0; i < ncases; i++) {
        // Fresh evidence per case, so values entered for a previous case never carry over.
        Evidence evidence = Evidence.tautology(getBayesIm());

        // Let the target variable range over all its values.
        int itarget = evidence.getNodeIndex(targetVariable.getName());
        evidence.getProposition().setVariable(itarget, true);

        this.missingValueCaseFound = false;

        // Enter this case's observed non-target values as evidence.
        for (int j = 0; j < nvars; j++) {
            if (j == targetPosition) {
                continue;
            }
            int observedValue = selectedData.getInt(i, j);
            if (observedValue == DiscreteVariable.MISSING_VALUE) {
                this.missingValueCaseFound = true;
                continue;
            }
            String jName = getBayesImVars().get(j).getName();
            int jIndex = evidence.getNodeIndex(jName);
            evidence.getProposition().setCategory(jIndex, observedValue);
        }

        // Update using those values.
        bayesUpdater.setEvidence(evidence);

        // Pick the target category with the highest updated marginal. The initial bound
        // below zero lets any probability win. ">=" sends ties to the later category, and
        // a NaN marginal never wins.
        int estimatedValue = -1;
        double highestProb = -0.1;
        for (int j = 0; j < numTargetCategories; j++) {
            double marginal = bayesUpdater.getMarginal(indexTargetBN, j);
            probOfClassifiedValues[j][i] = marginal;
            if (marginal >= highestProb) {
                highestProb = marginal;
                estimatedValue = j;
            }
        }

        if (estimatedValue < 0) {
            // No usable marginal. NOTE(review): presumably the evidence has zero probability
            // under the net (e.g. a value combination never seen in training); confirm.
            // Log the case and mark its estimate as missing.
            TetradLogger.getInstance().log("details", "Case " + i + " does not return valid marginal.");
            for (int m = 0; m < nvars; m++) {
                TetradLogger.getInstance().log("details", " " + selectedData.getDouble(i, m));
            }
            estimatedValues[i] = DiscreteVariable.MISSING_VALUE;
            continue;
        }

        estimatedValues[i] = estimatedValue;
    }

    this.classifications = estimatedValues;
    this.marginals = probOfClassifiedValues;
    return estimatedValues;
}
use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class LogisticRegressionRunner method execute.
// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Runs a logistic regression of the selected response on the selected predictors,
 * using the data set of the current model, and stores the outcome in {@code result}.
 * <p>
 * {@code outGraph} is always reset to an empty graph. If the response or predictors
 * are not set, the response is also a predictor, the response is missing from the
 * data or is not discrete, or a predictor is missing from the data, a message is
 * written to {@code report` and the method returns with {@code result} unchanged.
 */
public void execute() {
    outGraph = new EdgeListGraph();

    if (regressorNames == null || regressorNames.isEmpty() || targetName == null) {
        report = "Response and predictor variables not set.";
        return;
    }

    if (regressorNames.contains(targetName)) {
        report = "Response must not be a predictor.";
        return;
    }

    // Look the response and predictors up in the same data set the regression runs on.
    DataSet data = dataSets.get(getModelIndex());

    Node target = data.getVariable(targetName);
    if (target == null) {
        report = "Response variable " + targetName + " not found in the data.";
        return;
    }
    if (!(target instanceof DiscreteVariable)) {
        report = "Response variable " + targetName + " must be discrete.";
        return;
    }

    List<Node> regressorNodes = new ArrayList<>();
    for (String name : regressorNames) {
        Node regressor = data.getVariable(name);
        if (regressor == null) {
            report = "Predictor variable " + name + " not found in the data.";
            return;
        }
        regressorNodes.add(regressor);
    }

    LogisticRegression logRegression = new LogisticRegression(data);
    logRegression.setAlpha(alpha);
    this.result = logRegression.regress((DiscreteVariable) target, regressorNodes);
}
use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class EmBayesEstimator method expectation.
/**
 * Performs one expectation step of EM. Using the current parameters in
 * {@code inputBayesIm}, it computes the expected count of every cell of every
 * conditional probability table from {@code mixedData} (missing values are -99),
 * then normalizes those counts into {@code condProbs}.
 * <p>
 * For a variable with no parents, each case in which the variable is observed adds
 * 1.0 to the count of its observed value. A case in which it is missing adds, for
 * every category, that category's marginal probability given the values observed in
 * the case; a case with no observed values at all is skipped.
 * <p>
 * For a variable with parents, every row (parent-value combination) of its table is
 * processed in turn. A case contributes to a row only if each of its observed parent
 * values agrees with that row. If the variable and all of its parents are observed,
 * the count of the observed value and the row's denominator each get 1.0. Otherwise,
 * all values observed in the case are entered as evidence. The denominator then gets
 * the joint marginal of the row's parent values, and each category's count gets the
 * joint marginal of that category together with the row's parent values.
 * <p>
 * Results are left in {@code estimatedCounts[j][row][category]},
 * {@code estimatedCountsDenom[j][row]} (variables with parents only) and
 * {@code condProbs[j][row][category]}. A row whose denominator is 0 gets NaN.
 * <p>
 * NOTE(review): {@code j} is used at the same time as the column of {@code mixedData},
 * the position in {@code allVariables}, and the node index of the Bayes IM. Parent
 * indices returned by the IM are also used directly as data columns. This assumes the
 * three orderings coincide; confirm against how {@code nodes} and the IMs are built.
 *
 * @param inputBayesIm the current parameter estimates used to compute the marginals.
 */
private void expectation(BayesIm inputBayesIm) {
    int numCases = mixedData.getNumRows();
    int numVariables = allVariables.size();
    RowSummingExactUpdater rseu = new RowSummingExactUpdater(inputBayesIm);

    for (int j = 0; j < numVariables; j++) {
        DiscreteVariable var = (DiscreteVariable) allVariables.get(j);
        Node varNode = graph.getNode(var.getName());
        int varIndex = inputBayesIm.getNodeIndex(varNode);
        int[] parentVarIndices = inputBayesIm.getParents(varIndex);
        int numCategories = var.getNumCategories();

        if (parentVarIndices.length == 0) {
            expectedCountsNoParents(j, varIndex, numCategories, numCases, inputBayesIm, rseu);
        } else {
            expectedCountsWithParents(j, varIndex, parentVarIndices, numCategories, numCases, inputBayesIm, rseu);
        }
    }

    normalizeExpectedCounts(inputBayesIm);
}

/**
 * Fills row 0 of {@code estimatedCounts[j]} with expected counts for a variable that
 * has no parents.
 */
private void expectedCountsNoParents(int j, int varIndex, int numCategories, int numCases,
                                     BayesIm inputBayesIm, RowSummingExactUpdater rseu) {
    double[] counts = estimatedCounts[j][0];
    for (int col = 0; col < numCategories; col++) {
        counts[col] = 0.0;
    }

    for (int i = 0; i < numCases; i++) {
        int value = mixedData.getInt(i, j);
        if (value != -99) {
            // Observed in this case: a whole count for the observed value.
            counts[value] += 1.0;
            continue;
        }

        // Missing in this case: spread the case over the categories by their marginal
        // probabilities given the other values observed in the case.
        Evidence evidenceThisCase = caseEvidence(inputBayesIm, i, j);
        if (evidenceThisCase == null) {
            // No other variable contained useful data.
            continue;
        }
        rseu.setEvidence(evidenceThisCase);
        for (int m = 0; m < numCategories; m++) {
            counts[m] += rseu.getMarginal(varIndex, m);
        }
    }
}

/**
 * Fills {@code estimatedCounts[j]} and {@code estimatedCountsDenom[j]}, row by row,
 * for a variable that has parents.
 */
private void expectedCountsWithParents(int j, int varIndex, int[] parentVarIndices, int numCategories,
                                       int numCases, BayesIm inputBayesIm, RowSummingExactUpdater rseu) {
    int numRows = inputBayesIm.getNumRows(varIndex);

    for (int row = 0; row < numRows; row++) {
        int[] parValues = inputBayesIm.getParentValues(varIndex, row);
        double[] counts = estimatedCounts[j][row];
        estimatedCountsDenom[j][row] = 0.0;
        for (int col = 0; col < numCategories; col++) {
            counts[col] = 0.0;
        }

        // The variable followed by its parents. Slot 0 holds the variable itself and its
        // value is filled in per category below; the other slots hold this row's parent values.
        int[] parPlusChildIndices = new int[parentVarIndices.length + 1];
        int[] parPlusChildValues = new int[parentVarIndices.length + 1];
        parPlusChildIndices[0] = varIndex;
        for (int pc = 1; pc < parPlusChildIndices.length; pc++) {
            parPlusChildIndices[pc] = parentVarIndices[pc - 1];
            parPlusChildValues[pc] = parValues[pc - 1];
        }

        for (int i = 0; i < numCases; i++) {
            // Skip the case if an observed parent value disagrees with this row; note
            // whether any parent value is missing.
            boolean parentMismatch = false;
            boolean parentMissing = false;
            for (int p = 0; p < parentVarIndices.length; p++) {
                int parentValue = mixedData.getInt(i, parentVarIndices[p]);
                if (parentValue == -99) {
                    parentMissing = true;
                } else if (parentValue != parValues[p]) {
                    parentMismatch = true;
                    break;
                }
            }
            if (parentMismatch) {
                continue;
            }

            int value = mixedData.getInt(i, j);
            if (value != -99 && !parentMissing) {
                // Fully observed matching case: a whole count.
                counts[value] += 1.0;
                estimatedCountsDenom[j][row] += 1.0;
                continue;
            }

            // The variable or one of its parents is missing. Condition on everything
            // observed in this case (including the variable itself, when observed) and add
            // the joint marginals of this row's parent values, with and without each
            // category of the variable.
            Evidence evidenceThisCase = caseEvidence(inputBayesIm, i, -1);
            if (evidenceThisCase == null) {
                // Nothing observed in this case.
                continue;
            }
            rseu.setEvidence(evidenceThisCase);
            estimatedCountsDenom[j][row] += rseu.getJointMarginal(parentVarIndices, parValues);

            for (int m = 0; m < numCategories; m++) {
                parPlusChildValues[0] = m;
                counts[m] += rseu.getJointMarginal(parPlusChildIndices, parPlusChildValues);
            }
        }
    }
}

/**
 * Builds evidence for case {@code caseIndex} from every variable observed in that
 * case (value other than -99). Column {@code excludedColumn} is skipped; pass -1 to
 * skip none. Returns null if no variable contributed a value.
 */
private Evidence caseEvidence(BayesIm inputBayesIm, int caseIndex, int excludedColumn) {
    Evidence evidence = Evidence.tautology(inputBayesIm);
    boolean existsEvidence = false;

    for (int k = 0; k < allVariables.size(); k++) {
        if (k == excludedColumn) {
            continue;
        }
        int value = mixedData.getInt(caseIndex, k);
        if (value == -99) {
            continue;
        }
        existsEvidence = true;
        Node otherNode = graph.getNode(allVariables.get(k).getName());
        int otherIndex = inputBayesIm.getNodeIndex(otherNode);
        evidence.getProposition().setCategory(otherIndex, value);
    }

    return existsEvidence ? evidence : null;
}

/**
 * Converts the expected counts into conditional probabilities in {@code condProbs}.
 * Single-row tables are normalized by the sum of their counts; other rows are divided
 * by their denominator, or set to NaN when that denominator is 0.
 */
private void normalizeExpectedCounts(BayesIm inputBayesIm) {
    for (int j = 0; j < nodes.length; j++) {
        int numRows = inputBayesIm.getNumRows(j);
        int numCols = inputBayesIm.getNumColumns(j);

        if (numRows == 1) {
            double sum = 0.0;
            for (int m = 0; m < numCols; m++) {
                sum += estimatedCounts[j][0][m];
            }
            for (int m = 0; m < numCols; m++) {
                condProbs[j][0][m] = estimatedCounts[j][0][m] / sum;
            }
        } else {
            for (int row = 0; row < numRows; row++) {
                double denom = estimatedCountsDenom[j][row];
                for (int m = 0; m < numCols; m++) {
                    condProbs[j][row][m] = denom != 0.0 ? estimatedCounts[j][row][m] / denom : Double.NaN;
                }
            }
        }
    }
}
use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class EmBayesEstimator method initialize.
/**
 * Prepares the estimator for EM, in this order:
 * <ol>
 * <li>Estimates {@code observedIm} from {@code dataSet}, using a symmetric Dirichlet
 * prior over {@code bayesPmObs} with pseudocount 0.5.</li>
 * <li>Builds {@code mixedData}, with one column per entry of {@code nodes} in node
 * order. Latent nodes get a new discrete column filled entirely with -99. Measured
 * nodes get a copy of the matching column of {@code dataSet}.</li>
 * <li>Sets {@code allVariables} to the variables of {@code mixedData}.</li>
 * <li>Calls {@code estimateIM(bayesPm, mixedData)}.</li>
 * <li>Allocates {@code estimatedCounts}, {@code estimatedCountsDenom} and
 * {@code condProbs} with the table shapes of {@code estimatedIm}.</li>
 * </ol>
 */
private void initialize() {
    DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5);
    observedIm = DirichletEstimator.estimate(prior, dataSet);

    int numFullCases = dataSet.getNumRows();
    int numNodes = nodes.length;

    // Column variables of the mixed data set (one per node, in node order). For each
    // measured node, also record the dataSet column its values are copied from.
    List<Node> mixedVariables = new LinkedList<>();
    int[] sourceColumns = new int[numNodes];

    for (int n = 0; n < numNodes; n++) {
        if (nodes[n].getNodeType() == NodeType.LATENT) {
            DiscreteVariable latentVar =
                    new DiscreteVariable(nodes[n].getName(), bayesPm.getNumCategories(nodes[n]));
            latentVar.setNodeType(NodeType.LATENT);
            mixedVariables.add(latentVar);
        } else {
            Node measured = dataSet.getVariable(bayesPm.getVariable(nodes[n]).getName());
            mixedVariables.add(measured);
            sourceColumns[n] = dataSet.getColumn(measured);
        }
    }

    DataSet dsMixed = new ColtDataSet(numFullCases, mixedVariables);
    for (int n = 0; n < numNodes; n++) {
        boolean latent = nodes[n].getNodeType() == NodeType.LATENT;
        for (int row = 0; row < numFullCases; row++) {
            // Latent columns are entirely missing (-99); measured columns are copied over.
            dsMixed.setInt(row, n, latent ? -99 : dataSet.getInt(row, sourceColumns[n]));
        }
    }

    mixedData = dsMixed;
    allVariables = mixedData.getVariables();

    // Parameterize the Bayes net from mixedData, or randomly where that is not possible.
    estimateIM(bayesPm, mixedData);

    // Per-node work arrays, shaped like the conditional probability tables of estimatedIm.
    estimatedCounts = new double[numNodes][][];
    estimatedCountsDenom = new double[numNodes][];
    condProbs = new double[numNodes][][];

    for (int n = 0; n < numNodes; n++) {
        int numRows = estimatedIm.getNumRows(n);
        double[][] countsForNode = new double[numRows][];
        double[][] probsForNode = new double[numRows][];

        for (int r = 0; r < estimatedIm.getNumRows(n); r++) {
            int numCols = estimatedIm.getNumColumns(n);
            countsForNode[r] = new double[numCols];
            probsForNode[r] = new double[numCols];
        }

        estimatedCounts[n] = countsForNode;
        estimatedCountsDenom[n] = new double[numRows];
        condProbs[n] = probsForNode;
    }
}
use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class MlBayesImObs method simulateDataHelper.
/**
 * Simulates a sample of the given size.
 * <p>
 * The returned data set has one discrete column per node that is kept, in node
 * order. Only measured nodes are kept, unless {@code latentDataSaved} is true, in
 * which case every node is kept. Each column uses the node's categories from
 * {@code bayesPm}. The rows are filled by {@code constructSample}.
 *
 * @param sampleSize      the sample size (number of rows).
 * @param latentDataSaved whether nodes that are not measured also get columns.
 * @return the simulated sample as a DataSet.
 */
private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved) {
    List<Node> variables = new LinkedList<>();
    // map[c] is the index into nodes of the node simulated in column c.
    int[] map = new int[nodes.length];
    int numMeasured = 0;

    for (int j = 0; j < nodes.length; j++) {
        boolean keep = latentDataSaved || nodes[j].getNodeType() == NodeType.MEASURED;
        if (!keep) {
            continue;
        }

        int numCategories = bayesPm.getNumCategories(nodes[j]);
        List<String> categories = new LinkedList<>();
        for (int k = 0; k < numCategories; k++) {
            categories.add(bayesPm.getCategory(nodes[j], k));
        }

        variables.add(new DiscreteVariable(nodes[j].getName(), categories));
        map[numMeasured] = j;
        numMeasured++;
    }

    DataSet dataSet = new ColtDataSet(sampleSize, variables);
    constructSample(sampleSize, numMeasured, dataSet, map);
    return dataSet;
}
Aggregations