use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class AdLeafTree method getCellLeaves.
/**
 * Finds the sets of indices into the leaves of the tree for the given variables.
 * Counts are the sizes of the index sets.
 *
 * @param A A list of discrete variables.
 * @param B A discrete variable to be varied within each cell of A.
 * @return The list of index sets of the first variable varied by the second variable,
 * and so on, to the last variable.
 */
public List<List<List<Integer>>> getCellLeaves(List<DiscreteVariable> A, DiscreteVariable B) {
    // Sort the conditioning variables by their column indices in the data.
    A.sort(Comparator.comparingInt(nodesHash::get));

    if (baseCase == null) {
        Vary vary = new Vary();
        this.baseCase = new ArrayList<>();
        baseCase.add(vary);
    }

    List<Vary> varies = baseCase;

    // Walk down the tree, refining the leaves by each variable in A in turn.
    for (DiscreteVariable v : A) {
        varies = getVaries(varies, nodesHash.get(v));
    }

    List<List<List<Integer>>> rows = new ArrayList<>();

    // Within each leaf cell, vary B and collect the row-index sets.
    for (Vary vary : varies) {
        for (int i = 0; i < vary.getNumCategories(); i++) {
            Vary subvary = vary.getSubvary(nodesHash.get(B), i);
            rows.add(subvary.getRows());
        }
    }
    return rows;
}
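A minimal usage sketch, with hypothetical variable names; the single-argument AdLeafTree construction over a discrete data set is an assumption:

// dataSet: a discrete DataSet (assumed available).
AdLeafTree tree = new AdLeafTree(dataSet); // construction assumed

DiscreteVariable x1 = (DiscreteVariable) dataSet.getVariable("X1");
DiscreteVariable x2 = (DiscreteVariable) dataSet.getVariable("X2");
DiscreteVariable x3 = (DiscreteVariable) dataSet.getVariable("X3");

// One entry per cell of {X1, X2} and category of that cell, holding the
// row-index sets for X3 varied within it; counts are the set sizes.
List<List<List<Integer>>> cellRows =
        tree.getCellLeaves(new ArrayList<>(Arrays.asList(x1, x2)), x3);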
use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class BayesUpdaterClassifier method setTarget.
// ==========================PUBLIC METHODS========================//
public void setTarget(String target, int targetCategory) {
    // Find the target variable using its name.
    DiscreteVariable targetVariable = null;

    for (int j = 0; j < getBayesImVars().size(); j++) {
        DiscreteVariable dv = (DiscreteVariable) getBayesImVars().get(j);
        if (dv.getName().equals(target)) {
            targetVariable = dv;
            break;
        }
    }

    if (targetVariable == null) {
        throw new IllegalArgumentException("Not an available target: " + target);
    }

    this.targetVariable = targetVariable;
}
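A brief usage sketch, assuming an already-constructed BayesUpdaterClassifier named classifier and a variable named "LungCancer" among its Bayes IM variables (both names are illustrative):

// Selects "LungCancer" as the target; throws IllegalArgumentException
// if no variable of that name exists among the Bayes IM variables.
classifier.setTarget("LungCancer", 0);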
use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class IndTestMultinomialLogisticRegressionWald method expandVariable.
private List<Node> expandVariable(DataSet dataSet, Node node) {
    if (node instanceof ContinuousVariable) {
        return Collections.singletonList(node);
    }

    if (node instanceof DiscreteVariable && ((DiscreteVariable) node).getNumCategories() < 3) {
        return Collections.singletonList(node);
    }

    if (!(node instanceof DiscreteVariable)) {
        throw new IllegalArgumentException();
    }

    // Drop the first category; it serves as the reference level, so a
    // variable with m categories expands to m - 1 binary indicators.
    List<String> varCats = new ArrayList<>(((DiscreteVariable) node).getCategories());
    varCats.remove(0);

    List<Node> variables = new ArrayList<>();
    int nodeColumn = dataSet.getColumn(node);

    for (String cat : varCats) {
        Node newVar;
        int suffix = 0;

        // Append a numeric suffix if needed so the generated name is
        // unique in the data set.
        do {
            String newVarName = node.getName() + "MULTINOM" + "." + cat
                    + (suffix > 0 ? "." + suffix : "");
            newVar = new DiscreteVariable(newVarName, 2);
            suffix++;
        } while (dataSet.getVariable(newVar.getName()) != null);

        variables.add(newVar);
        dataSet.addVariable(newVar);
        int newVarIndex = dataSet.getColumn(newVar);
        int numCases = dataSet.getNumRows();

        // Fill the indicator column: 1 where the original variable takes
        // this category, 0 elsewhere.
        for (int l = 0; l < numCases; l++) {
            Object dataCell = dataSet.getObject(l, nodeColumn);
            int dataCellIndex = ((DiscreteVariable) node).getIndex(dataCell.toString());
            dataSet.setInt(l, newVarIndex,
                    dataCellIndex == ((DiscreteVariable) node).getIndex(cat) ? 1 : 0);
        }
    }

    return variables;
}
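For illustration, a discrete variable X with categories {low, medium, high} (hypothetical names) expands to two binary indicators, the first category serving as the dropped reference level:

    X = low     ->  XMULTINOM.medium = 0, XMULTINOM.high = 0
    X = medium  ->  XMULTINOM.medium = 1, XMULTINOM.high = 0
    X = high    ->  XMULTINOM.medium = 0, XMULTINOM.high = 1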
use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class MNLRLikelihood method getLik.
public double getLik(int childIndex, int[] parents) {
    double lik = 0;
    Node c = variables.get(childIndex);
    List<ContinuousVariable> continuousParents = new ArrayList<>();
    List<DiscreteVariable> discreteParents = new ArrayList<>();

    for (int p : parents) {
        Node parent = variables.get(p);
        if (parent instanceof ContinuousVariable) {
            continuousParents.add((ContinuousVariable) parent);
        } else {
            discreteParents.add((DiscreteVariable) parent);
        }
    }

    int p = continuousParents.size();

    // One cell per joint setting of the discrete parents; each cell holds
    // the indices of the rows falling in that setting.
    List<List<Integer>> cells = adTree.getCellLeaves(discreteParents);
    // List<List<Integer>> cells = partition(discreteParents);

    int[] continuousCols = new int[p];
    for (int j = 0; j < p; j++) continuousCols[j] = nodesHash.get(continuousParents.get(j));

    for (List<Integer> cell : cells) {
        int r = cell.size();
        if (r > 1) {
            // Per-cell mean and standard deviation of each continuous
            // parent, used to standardize before building the basis.
            double[] mean = new double[p];
            double[] var = new double[p];
            for (int i = 0; i < p; i++) {
                for (int j = 0; j < r; j++) {
                    mean[i] += continuousData[continuousCols[i]][cell.get(j)];
                    var[i] += Math.pow(continuousData[continuousCols[i]][cell.get(j)], 2);
                }
                mean[i] /= r;
                var[i] /= r;
                var[i] -= Math.pow(mean[i], 2);
                // var[i] now holds the standard deviation, not the variance.
                var[i] = Math.sqrt(var[i]);
                // Numerical cancellation can make the variance slightly
                // negative, yielding NaN after the square root.
                if (Double.isNaN(var[i])) {
                    System.out.println(var[i]);
                }
            }

            int degree = fDegree;
            if (fDegree < 1) {
                degree = (int) Math.floor(Math.log(r));
            }

            // Design matrix: for each continuous parent, powers 1..degree of
            // its standardized values, plus a final intercept column.
            TetradMatrix subset = new TetradMatrix(r, p * degree + 1);
            for (int i = 0; i < r; i++) {
                subset.set(i, p * degree, 1);
                for (int j = 0; j < p; j++) {
                    for (int d = 0; d < degree; d++) {
                        subset.set(i, p * d + j, Math.pow((continuousData[continuousCols[j]][cell.get(i)] - mean[j]) / var[j], d + 1));
                    }
                }
            }

            if (c instanceof ContinuousVariable) {
                TetradVector target = new TetradVector(r);
                for (int i = 0; i < r; i++) {
                    target.set(i, continuousData[childIndex][cell.get(i)]);
                }
                lik += multipleRegression(target, subset);
            } else {
                // One column per category of the discrete child: -1
                // everywhere except a 1 in the observed category's column.
                TetradMatrix target = new TetradMatrix(r, ((DiscreteVariable) c).getNumCategories());
                for (int i = 0; i < r; i++) {
                    for (int j = 0; j < ((DiscreteVariable) c).getNumCategories(); j++) {
                        target.set(i, j, -1);
                    }
                    target.set(i, discreteData[childIndex][cell.get(i)], 1);
                }
                lik += MultinomialLogisticRegression(target, subset);
            }
        }
    }
    return lik;
}
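For concreteness, with p = 2 continuous parents and degree = 2, each row of the design matrix subset has the column layout

    [ z1, z2, z1^2, z2^2, 1 ]

since column p * d + j holds power d + 1 of parent j's standardized value z_j = (x_j - mean_j) / sd_j, and the final column is the intercept.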
use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.
the class MbClassify method classify.
// ============================PUBLIC METHODS=========================//
/**
 * Classifies the test data by Bayesian updating. The procedure is as follows. First, MBFS is run on the training
 * data to estimate an MB pattern. Bidirected edges are removed; an MB DAG G is selected from the pattern that
 * remains. Second, a Bayes model B is estimated using this G and the training data. Third, for each case in the
 * test data, the marginal for the target variable in B is calculated conditioning on values of the other variables
 * in B in the test data; these are reported as classifications. Estimation of B is done using a Dirichlet
 * estimator with a symmetric prior, with the given alpha value. Updating is done using a row-summing exact
 * updater.
 * <p>
 * One consequence of using the row-summing exact updater is that classification will be fast except for cases
 * with many missing values, since for such cases the number of rows to be summed over is exponential in the
 * number of missing values. Hence the parameter for the maximum number of missing values; a good default is 5.
 * Any test case with more than that number of missing values will be skipped.
 *
 * @return The classifications.
 */
public int[] classify() {
    IndependenceTest indTest = new IndTestChiSquare(train, alpha);
    Mbfs search = new Mbfs(indTest, depth);
    search.setDepth(depth);
    // Hiton search = new Hiton(indTest, depth);
    // Mmmb search = new Mmmb(indTest, depth);
    List<Node> mbPlusTarget = search.findMb(target);
    mbPlusTarget.add(train.getVariable(target));
    DataSet subset = train.subsetColumns(mbPlusTarget);
    System.out.println("subset vars = " + subset.getVariables());
    Pc patternSearch = new Pc(new IndTestChiSquare(subset, 0.05));
    // patternSearch.setMaxIndegree(depth);
    Graph mbPattern = patternSearch.search();
    // MbFanSearch search = new MbFanSearch(indTest, depth);
    // Graph mbPattern = search.search(target);
    TetradLogger.getInstance().log("details", "Pattern = " + mbPattern);
    MbUtils.trimToMbNodes(mbPattern, train.getVariable(target), true);
    TetradLogger.getInstance().log("details", "Trimmed pattern = " + mbPattern);

    // Remove bidirected edges from the pattern before selecting a DAG.
    for (Edge edge : mbPattern.getEdges()) {
        if (Edges.isBidirectedEdge(edge)) {
            mbPattern.removeEdge(edge);
        }
    }

    Graph selectedDag = MbUtils.getOneMbDag(mbPattern);
    TetradLogger.getInstance().log("details", "Selected DAG = " + selectedDag);
    TetradLogger.getInstance().log("details", "Vars = " + selectedDag.getNodes());
    TetradLogger.getInstance().log("details", "\nClassification using selected MB DAG:");
    NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
    List<Node> mbNodes = selectedDag.getNodes();

    // The Markov blanket nodes will correspond to a subset of the variables
    // in the training dataset. Find the subset dataset.
    DataSet trainDataSubset = train.subsetColumns(mbNodes);

    // To create a Bayes net for the Markov blanket we need the DAG.
    BayesPm bayesPm = new BayesPm(selectedDag);

    // To parameterize the Bayes net we need the number of values
    // of each variable.
    List<Node> varsTrain = trainDataSubset.getVariables();
    for (int i1 = 0; i1 < varsTrain.size(); i1++) {
        DiscreteVariable trainingVar = (DiscreteVariable) varsTrain.get(i1);
        bayesPm.setCategories(mbNodes.get(i1), trainingVar.getCategories());
    }

    // Create an updater for the instantiated Bayes net.
    TetradLogger.getInstance().log("info", "Estimating Bayes net; please wait...");
    DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPm, this.prior);
    BayesIm bayesIm = DirichletEstimator.estimate(prior, trainDataSubset);
    RowSummingExactUpdater updater = new RowSummingExactUpdater(bayesIm);

    // The subset dataset of the dataset to be classified containing
    // the variables in the Markov blanket.
    DataSet testSubset = this.test.subsetColumns(mbNodes);

    // Get the raw data from the dataset to be classified, the number
    // of variables, and the number of cases.
    int numCases = testSubset.getNumRows();
    int[] estimatedCategories = new int[numCases];
    Arrays.fill(estimatedCategories, -1);

    // The variables in the dataset.
    List<Node> varsClassify = testSubset.getVariables();

    // For each case in the test data, compute the estimated value of the
    // target variable and increment the appropriate element
    // of the crosstabulation array.
    for (int k = 0; k < numCases; k++) {
        // Create an Evidence instance for the instantiated Bayes net
        // which will allow that updating.
        Proposition proposition = Proposition.tautology(bayesIm);

        // Restrict all other variables to their observed values in
        // this case.
        int numMissing = 0;
        for (int testIndex = 0; testIndex < varsClassify.size(); testIndex++) {
            DiscreteVariable var = (DiscreteVariable) varsClassify.get(testIndex);
            // If it's the target, ignore it.
            if (var.equals(targetVariable)) {
                continue;
            }
            int trainIndex = proposition.getNodeIndex(var.getName());
            // If it's not in the train subset, ignore it.
            if (trainIndex == -99) {
                continue;
            }
            int testValue = testSubset.getInt(k, testIndex);
            if (testValue == -99) {
                numMissing++;
            } else {
                proposition.setCategory(trainIndex, testValue);
            }
        }

        if (numMissing > this.maxMissing) {
            TetradLogger.getInstance().log("details", "classification(" + k + ") = " + "not done since number of missing values too high " + "(" + numMissing + ").");
            continue;
        }

        Evidence evidence = Evidence.tautology(bayesIm);
        evidence.getProposition().restrictToProposition(proposition);
        updater.setEvidence(evidence);

        // For each possible value of the target, compute its probability in
        // the updated Bayes net. Select the value with the highest
        // probability as the estimated value.
        int targetIndex = proposition.getNodeIndex(targetVariable.getName());

        // Straw man values--to be replaced.
        double highestProb = -0.1;
        int _category = -1;
        for (int category = 0; category < targetVariable.getNumCategories(); category++) {
            double marginal = updater.getMarginal(targetIndex, category);
            if (marginal > highestProb) {
                highestProb = marginal;
                _category = category;
            }
        }

        // Sometimes the marginal cannot be computed because certain
        // combinations of values are not represented in the
        // training dataset. If that happens, skip the case.
        if (_category < 0) {
            System.out.println("classification(" + k + ") is undefined " + "(undefined marginals).");
            continue;
        }

        String estimatedCategory = targetVariable.getCategories().get(_category);
        TetradLogger.getInstance().log("details", "classification(" + k + ") = " + estimatedCategory);
        estimatedCategories[k] = _category;
    }

    // Create a crosstabulation table to store the counts of observed
    // versus estimated occurrences of each value of the target variable.
    int targetIndex = varsClassify.indexOf(targetVariable);
    int numCategories = targetVariable.getNumCategories();
    int[][] crossTabs = new int[numCategories][numCategories];

    // Will count the number of cases where the target variable
    // is correctly classified.
    int numberCorrect = 0;
    int numberCounted = 0;
    for (int k = 0; k < numCases; k++) {
        int estimatedCategory = estimatedCategories[k];
        int observedValue = testSubset.getInt(k, targetIndex);
        if (estimatedCategory < 0) {
            continue;
        }
        crossTabs[observedValue][estimatedCategory]++;
        numberCounted++;
        if (observedValue == estimatedCategory) {
            numberCorrect++;
        }
    }

    double percentCorrect1 = 100.0 * ((double) numberCorrect) / ((double) numberCounted);

    // Print the cross classification.
    TetradLogger.getInstance().log("details", "");
    TetradLogger.getInstance().log("details", "\t\t\tEstimated\t");
    TetradLogger.getInstance().log("details", "Observed\t");
    StringBuilder buf0 = new StringBuilder();
    buf0.append("\t");
    for (int m = 0; m < numCategories; m++) {
        buf0.append(targetVariable.getCategory(m)).append("\t");
    }
    TetradLogger.getInstance().log("details", buf0.toString());
    for (int k = 0; k < numCategories; k++) {
        StringBuilder buf = new StringBuilder();
        buf.append(targetVariable.getCategory(k)).append("\t");
        for (int m = 0; m < numCategories; m++) buf.append(crossTabs[k][m]).append("\t");
        TetradLogger.getInstance().log("details", buf.toString());
    }
    TetradLogger.getInstance().log("details", "");
    TetradLogger.getInstance().log("details", "Number correct = " + numberCorrect);
    TetradLogger.getInstance().log("details", "Number counted = " + numberCounted);
    TetradLogger.getInstance().log("details", "Percent correct = " + nf.format(percentCorrect1) + "%");
    crossTabulation = crossTabs;
    percentCorrect = percentCorrect1;
    return estimatedCategories;
}
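A hedged usage sketch; the constructor arguments shown are assumptions inferred from the fields the method reads (train, test, target, alpha, depth, prior, maxMissing), and the data-set and variable names are illustrative:

// train and test are discrete DataSets over the same variables (assumed).
MbClassify mbClassify = new MbClassify(train, test, "Disease", 0.05, 3, 1.0, 5);
int[] estimated = mbClassify.classify();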