Usage of edu.cmu.tetrad.data.DiscreteVariable in project tetrad (by cmu-phil).
Class BayesXmlParser, method makeBayesPm.
/**
 * Builds a {@link BayesPm} from the {@code <parents>} element of a Bayes net XML document.
 *
 * <p>Every variable in {@code variables} becomes a node of a fresh {@link Dag}. Each
 * {@code <parentsFor name="X">} child then contributes one directed edge P --&gt; X for every
 * nested {@code <parent name="P">} element. Finally the categories of each variable are copied
 * into the resulting PM.
 *
 * <p>Names are resolved through the {@code namesToVars} field, which {@code getVariables} in this
 * class assigns; that method must therefore have run first.
 *
 * @param variables the variables of the network; each is cast to {@link DiscreteVariable}.
 * @param element1  the {@code <parents>} element.
 * @return the Bayes PM whose DAG and categories reflect the XML.
 * @throws IllegalArgumentException if an element has an unexpected name, or if an edge refers
 *                                  to a variable name that is not known.
 */
private BayesPm makeBayesPm(List<Node> variables, Element element1) {
    if (!"parents".equals(element1.getQualifiedName())) {
        throw new IllegalArgumentException("Expecting 'parents' element.");
    }

    Dag graph = new Dag();
    for (Node variable : variables) {
        graph.addNode(variable);
    }

    Elements parentsForElements = element1.getChildElements();
    for (int i = 0; i < parentsForElements.size(); i++) {
        Element e1 = parentsForElements.get(i);
        if (!"parentsFor".equals(e1.getQualifiedName())) {
            throw new IllegalArgumentException("Expecting 'parentsFor' element.");
        }

        String varName = e1.getAttributeValue("name");
        Node var = namesToVars.get(varName);

        Elements parentElements = e1.getChildElements();
        for (int j = 0; j < parentElements.size(); j++) {
            Element e2 = parentElements.get(j);
            if (!"parent".equals(e2.getQualifiedName())) {
                throw new IllegalArgumentException("Expecting 'parent' element.");
            }

            String parentName = e2.getAttributeValue("name");
            Node parent = namesToVars.get(parentName);

            // Reject dangling references explicitly instead of passing null into the graph.
            if (var == null) {
                throw new IllegalArgumentException(
                        "Unknown variable in 'parentsFor' element: " + varName);
            }
            if (parent == null) {
                throw new IllegalArgumentException(
                        "Unknown variable in 'parent' element: " + parentName);
            }

            graph.addDirectedEdge(parent, var);
        }
    }

    BayesPm bayesPm = new BayesPm(graph);
    for (Node variable1 : variables) {
        DiscreteVariable graphVariable = (DiscreteVariable) variable1;
        bayesPm.setCategories(graphVariable, graphVariable.getCategories());
    }
    return bayesPm;
}
Usage of edu.cmu.tetrad.data.DiscreteVariable in project tetrad (by cmu-phil).
Class BayesXmlParser, method getVariables.
/**
 * Parses the {@code <bnVariables>} element into a list of {@link DiscreteVariable}s.
 *
 * <p>Each child must be a {@code <discreteVariable>} element carrying {@code name}, {@code x}
 * and {@code y} attributes, an optional {@code latent} attribute (the value {@code "yes"} marks
 * the variable LATENT), and one {@code <category name="...">} child per category, in order.
 *
 * <p>Side effect: replaces the {@code namesToVars} field with a fresh name-to-variable map
 * built from the returned list.
 *
 * @param element0 the {@code <bnVariables>} element.
 * @return the parsed variables in document order.
 * @throws IllegalArgumentException if an element has an unexpected name.
 * @throws NumberFormatException    if {@code x} or {@code y} is missing or not an integer.
 */
private List<Node> getVariables(Element element0) {
    if (!"bnVariables".equals(element0.getQualifiedName())) {
        throw new IllegalArgumentException("Expecting 'bnVariables' element.");
    }

    List<Node> variables = new LinkedList<>();
    Elements elements = element0.getChildElements();

    for (int i = 0; i < elements.size(); i++) {
        Element e1 = elements.get(i);
        if (!"discreteVariable".equals(e1.getQualifiedName())) {
            throw new IllegalArgumentException("Expecting 'discreteVariable' " + "element.");
        }

        String name = e1.getAttributeValue("name");
        boolean isLatent = "yes".equals(e1.getAttributeValue("latent"));

        // Integer.parseInt replaces the deprecated Integer(String) constructor; both throw
        // NumberFormatException on missing or malformed input.
        int x = Integer.parseInt(e1.getAttributeValue("x"));
        int y = Integer.parseInt(e1.getAttributeValue("y"));

        Elements e2Elements = e1.getChildElements();
        List<String> categories = new LinkedList<>();
        for (int j = 0; j < e2Elements.size(); j++) {
            Element e2 = e2Elements.get(j);
            if (!"category".equals(e2.getQualifiedName())) {
                throw new IllegalArgumentException("Expecting 'category' " + "element.");
            }
            categories.add(e2.getAttributeValue("name"));
        }

        DiscreteVariable var = new DiscreteVariable(name, categories);
        if (isLatent) {
            var.setNodeType(NodeType.LATENT);
        }
        var.setCenterX(x);
        var.setCenterY(y);
        variables.add(var);
    }

    namesToVars = new HashMap<>();
    for (Node v : variables) {
        namesToVars.put(v.getName(), v);
    }
    return variables;
}
Usage of edu.cmu.tetrad.data.DiscreteVariable in project tetrad (by cmu-phil).
Class BdeMetric, method computeObservedCounts (preceded by a commented-out scoreLnGam method).
// public double scoreLnGam() {
//
// double[][][] priorProbs;
// double[][] priorProbsRowSum;
//
// Graph graph = bayesPm.getDag();
//
// int n = graph.getNumNodes();
//
// observedCounts = new int[n][][];
// priorProbs = new double[n][][];
//
// int[][] observedCountsRowSum = new int[n][];
// priorProbsRowSum = new double[n][];
//
// bayesIm = new MlBayesIm(bayesPm);
//
// for (int i = 0; i < n; i++) {
// //int numRows = bayesImMixed.getNumRows(i);
// int numRows = bayesIm.getNumRows(i);
// observedCounts[i] = new int[numRows][];
// priorProbs[i] = new double[numRows][];
//
// observedCountsRowSum[i] = new int[numRows];
// priorProbsRowSum[i] = new double[numRows];
//
// //for(int j = 0; j < bayesImMixed.getNumRows(i); j++) {
// for (int j = 0; j < numRows; j++) {
//
// observedCountsRowSum[i][j] = 0;
// priorProbsRowSum[i][j] = 0;
//
// //int numCols = bayesImMixed.getNumColumns(i);
// int numCols = bayesIm.getNumColumns(i);
// observedCounts[i][j] = new int[numCols];
// priorProbs[i][j] = new double[numCols];
// }
// }
//
// //At this point set values in both observedCounts and priorProbs
// computeObservedCounts();
// //Set all priorProbs (i.e. estimated counts) to 1.0. Eventually they may be
// //supplied as a parameter of the constructor of this class.
// for (int i = 0; i < n; i++) {
// for (int j = 0; j < bayesIm.getNumRows(i); j++) {
// for (int k = 0; k < bayesIm.getNumColumns(i); k++) {
// priorProbs[i][j][k] = 1.0;
// }
// }
// }
//
//
// for (int i = 0; i < n; i++) {
// for (int j = 0; j < bayesIm.getNumRows(i); j++) {
// for (int k = 0; k < bayesIm.getNumColumns(i); k++) {
// observedCountsRowSum[i][j] += observedCounts[i][j][k];
// priorProbsRowSum[i][j] += priorProbs[i][j][k];
// }
// }
// }
//
// //double outerProduct = 1.0;
// double sum = 0.0;
//
// //Debug print
// //System.out.println("counts and priors");
// //for(int i = 0; i < n; i++)
// // for(int j = 0; j < bayesIm.getNumRows(i); j++) {
// // System.out.println(observedCountsRowSum[i][j] + " " + priorProbsRowSum[i][j]);
// // }
//
// for (int i = 0; i < n; i++) {
//
// int qi = bayesIm.getNumRows(i);
// //double prodj = 1.0;
// double sumj = 0.0;
// for (int j = 0; j < qi; j++) {
//
// try {
// double numerator =
// ProbUtils.lngamma(priorProbsRowSum[i][j]);
// double denom = ProbUtils.lngamma(priorProbsRowSum[i][j] +
// observedCountsRowSum[i][j]);
// //System.out.println("num = " + numerator + " denom = " + denom);
// sumj += (numerator - denom);
// } catch (Exception e) {
// e.printStackTrace();
// }
//
// int ri = bayesIm.getNumColumns(i);
//
// //double prodk = 1.0;
// double sumk = 0.0;
// for (int k = 0; k < ri; k++) {
// try {
// sumk += ProbUtils.lngamma(
// priorProbs[i][j][k] + observedCounts[i][j][k]) -
// ProbUtils.lngamma(priorProbs[i][j][k]);
// } catch (Exception e) {
// e.printStackTrace();
// }
// }
//
// sumj += sumk;
// }
// sum += sumj;
// }
//
// return sum;
// }
/**
 * Tallies, for every variable in {@code dataSet}, how many cases fall into each
 * (parent-configuration row, category) cell of that variable's conditional table.
 *
 * <p>Counts are written to {@code observedCounts[nodeIndex][row][category]}. Here
 * {@code nodeIndex} is the variable's index in {@code bayesIm}, and {@code row} is the
 * {@code bayesIm} row whose parent values ({@code getParentValues}) match the case. Every
 * cell that is written is first reset to zero.
 *
 * <p>NOTE(review): assumes {@code observedCounts} was already allocated per bayesIm node with
 * {@code getNumRows} x {@code getNumColumns} entries, as the commented-out scoreLnGam above
 * does. It also assumes every data column is a {@link DiscreteVariable} whose name names a node
 * of {@code bayesPm.getDag()}. Confirm both at call sites.
 *
 * @throws IllegalStateException if a parent of some variable has no column in the data set.
 */
private void computeObservedCounts() {
    List<Node> dataVariables = dataSet.getVariables();
    int numCases = dataSet.getNumRows();

    for (int j = 0; j < dataSet.getNumColumns(); j++) {
        DiscreteVariable var = (DiscreteVariable) dataVariables.get(j);
        String varName = var.getName();
        Node varNode = bayesPm.getDag().getNode(varName);
        int varIndex = bayesIm.getNodeIndex(varNode);
        int[] parentVarIndices = bayesIm.getParents(varIndex);
        int numCategories = var.getNumCategories();

        if (parentVarIndices.length == 0) {
            // No parents: a single row holding the marginal counts of this variable.
            for (int col = 0; col < numCategories; col++) {
                observedCounts[varIndex][0][col] = 0;
            }
            for (int i = 0; i < numCases; i++) {
                observedCounts[varIndex][0][dataSet.getInt(i, j)]++;
            }
            continue;
        }

        // getParents() yields bayesIm node indices, which need not coincide with data-set
        // column indices; translate each parent to its data column by name.
        int[] parentColumns = new int[parentVarIndices.length];
        for (int p = 0; p < parentVarIndices.length; p++) {
            String parentName = bayesIm.getNode(parentVarIndices[p]).getName();
            Node parentVariable = dataSet.getVariable(parentName);
            if (parentVariable == null) {
                throw new IllegalStateException("Parent '" + parentName + "' of '" + varName
                        + "' is not in the data set.");
            }
            parentColumns[p] = dataSet.getColumn(parentVariable);
        }

        int numRows = bayesIm.getNumRows(varIndex);
        for (int row = 0; row < numRows; row++) {
            int[] parValues = bayesIm.getParentValues(varIndex, row);
            for (int col = 0; col < numCategories; col++) {
                observedCounts[varIndex][row][col] = 0;
            }

            // Count only the cases whose parent values equal this row's configuration.
            for (int i = 0; i < numCases; i++) {
                boolean parentMatch = true;
                for (int p = 0; p < parentColumns.length && parentMatch; p++) {
                    parentMatch = parValues[p] == dataSet.getInt(i, parentColumns[p]);
                }
                if (parentMatch) {
                    observedCounts[varIndex][row][dataSet.getInt(i, j)]++;
                }
            }
        }
    }
}
Usage of edu.cmu.tetrad.data.DiscreteVariable in project tetrad (by cmu-phil).
Class EmBayesProperties, method setGraph.
/**
 * Sets the graph and rebuilds the Bayes PM and blank IM derived from it.
 *
 * <p>A {@link BayesPm} is created over {@code new Dag(graph)}. Every DAG node whose name matches
 * a data-set variable is given that variable's categories. The graph, the PM and a fresh
 * {@link MlBayesIm} over the PM are then stored in this object.
 *
 * @param graph the graph to use; must not be null.
 * @throws NullPointerException if {@code graph} is null.
 * @throws ClassCastException   if a data-set column is not a {@link DiscreteVariable}.
 */
public final void setGraph(Graph graph) {
    if (graph == null) {
        throw new NullPointerException("graph must not be null");
    }

    // DAG nodes are matched to data-set variables by name, not by object identity.
    List<Node> vars = dataSet.getVariables();
    Map<String, DiscreteVariable> namesToVars = new HashMap<>();
    for (int i = 0; i < dataSet.getNumColumns(); i++) {
        DiscreteVariable var = (DiscreteVariable) vars.get(i);
        namesToVars.put(var.getName(), var);
    }

    Dag dag = new Dag(graph);
    BayesPm bayesPm = new BayesPm(dag);
    for (Node node : bayesPm.getDag().getNodes()) {
        DiscreteVariable var = namesToVars.get(node.getName());
        if (var != null) {
            bayesPm.setCategories(node, var.getCategories());
        }
    }

    this.graph = graph;
    this.bayesPm = bayesPm;
    this.blankBayesIm = new MlBayesIm(bayesPm);
}
Usage of edu.cmu.tetrad.data.DiscreteVariable in project tetrad (by cmu-phil).
Class DirichletBayesIm, method simulateDataHelper.
/**
 * Generates a simulated sample of the requested size from this IM.
 *
 * <p>One {@link DiscreteVariable} column is created per retained node, in node order, using the
 * category names stored in {@code bayesPm}. Nodes whose type is not MEASURED are retained only
 * when {@code latentDataSaved} is true. The filling of rows is delegated to
 * {@code constructSample}, together with a map from column index to node index.
 *
 * @param sampleSize      the number of cases (rows) to generate.
 * @param randomUtil      optional random number generator to use when creating the data.
 * @param latentDataSaved true iff data for latent variables should be saved.
 * @return the simulated sample as a DataSet.
 */
private DataSet simulateDataHelper(int sampleSize, RandomUtil randomUtil, boolean latentDataSaved) {
    List<Node> columns = new LinkedList<>();
    int[] columnToNode = new int[nodes.length];
    int numColumns = 0;

    for (int nodeIndex = 0; nodeIndex < nodes.length; nodeIndex++) {
        Node node = nodes[nodeIndex];
        boolean retained = latentDataSaved || node.getNodeType() == NodeType.MEASURED;
        if (retained) {
            int categoryCount = bayesPm.getNumCategories(node);
            List<String> categoryNames = new LinkedList<>();
            for (int c = 0; c < categoryCount; c++) {
                categoryNames.add(bayesPm.getCategory(node, c));
            }
            columns.add(new DiscreteVariable(node.getName(), categoryNames));
            columnToNode[numColumns++] = nodeIndex;
        }
    }

    DataSet sample = new ColtDataSet(sampleSize, columns);
    constructSample(sampleSize, randomUtil, numColumns, sample, columnToNode);
    return sample;
}
Aggregations