Usage of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.
Class SaveBayesImXmlAction, method actionPerformed:
/**
 * Asks the user for a destination file and writes the editor's current Bayes IM to it as XML
 * (rendered by {@code BayesXmlRenderer}, two-space indented, "\n" line separators).
 *
 * @param e the triggering action event (unused).
 * @throws RuntimeException wrapping any {@link IOException} raised while writing the file.
 */
public void actionPerformed(ActionEvent e) {
    try {
        File outfile = EditorUtils.getSaveFile("bayesim", "xml", this.bayesImEditor, false, "Save Bayes IM as XML...");

        // NOTE(review): assumes getSaveFile returns null when no file was chosen (e.g. the dialog
        // was cancelled) — confirm. Without this guard, new FileOutputStream(null) throws an NPE.
        if (outfile == null) {
            return;
        }

        // Render before opening the file so that a rendering failure does not truncate an
        // existing file on disk.
        BayesIm bayesIm = this.bayesImEditor.getWizard().getBayesIm();
        Element element = BayesXmlRenderer.getElement(bayesIm);
        Document document = new Document(element);

        // try-with-resources guarantees the stream is closed even if serialization fails.
        try (FileOutputStream out = new FileOutputStream(outfile)) {
            Serializer serializer = new Serializer(out);
            serializer.setLineSeparator("\n");
            serializer.setIndent(2);
            serializer.write(document);
        }
    } catch (IOException e1) {
        throw new RuntimeException(e1);
    }
}
Usage of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.
Class ConditionalGaussianSimulation, method simulate:
/**
 * Simulates one mixed (discrete + continuous) data set from the given graph.
 *
 * <p>The first {@code percentDiscrete} percent of a shuffled node order (computed on the first
 * call and cached in {@code shuffledOrder}) are made discrete, each with a category count from
 * {@code pickNumCategories(minCategories, maxCategories)}; the remaining nodes get 0 categories.
 * Discrete nodes are sampled from a randomly parameterized Bayes IM in which every continuous
 * parent is replaced by a discretized "ersatz" discrete stand-in. Continuous nodes are sampled as
 * normal noise plus a weighted sum of continuous parents plus a mean, where every parameter is
 * looked up per combination of the discrete parents' values. Columns are filled in causal order.
 *
 * @param G          the generating graph.
 * @param parameters supplies percentDiscrete, minCategories, maxCategories, sampleSize and
 *                   saveLatentVars.
 * @return the simulated data; passed through {@code DataUtils.restrictToMeasured} unless
 *         saveLatentVars is true.
 */
private DataSet simulate(Graph G, Parameters parameters) {
// Node name -> number of categories; 0 is stored for nodes not chosen to be discrete.
HashMap<String, Integer> nd = new HashMap<>();
List<Node> nodes = G.getNodes();
// NOTE(review): 'nodes' is only used for its size below before being reassigned after
// makeMixedGraph, so this shuffle appears to have no effect on the output unless getNodes()
// returns the graph's backing list — confirm.
Collections.shuffle(nodes);
// The shuffled order is computed once and cached in a field, so later calls reuse the same
// discrete/continuous split. NOTE(review): this assumes later graphs have no more nodes than the
// first one and use the same node names — confirm.
if (this.shuffledOrder == null) {
List<Node> shuffledNodes = new ArrayList<>(nodes);
Collections.shuffle(shuffledNodes);
this.shuffledOrder = shuffledNodes;
}
// The first percentDiscrete% of the cached order become discrete; the rest get 0 categories.
for (int i = 0; i < nodes.size(); i++) {
if (i < nodes.size() * parameters.getDouble("percentDiscrete") * 0.01) {
final int minNumCategories = parameters.getInt("minCategories");
final int maxNumCategories = parameters.getInt("maxCategories");
final int value = pickNumCategories(minNumCategories, maxNumCategories);
nd.put(shuffledOrder.get(i).getName(), value);
} else {
nd.put(shuffledOrder.get(i).getName(), 0);
}
}
// Presumably rebuilds the graph with variable types taken from nd; its nodes become the columns.
G = makeMixedGraph(G, nd);
nodes = G.getNodes();
// One row per sample, one column per node of the mixed graph.
DataSet mixedData = new BoxDataSet(new MixedDataBox(nodes, parameters.getInt("sampleSize")), nodes);
// Partition nodes into continuous (X) and everything else (A).
List<Node> X = new ArrayList<>();
List<Node> A = new ArrayList<>();
for (Node node : G.getNodes()) {
if (node instanceof ContinuousVariable) {
X.add(node);
} else {
A.add(node);
}
}
Graph AG = G.subgraph(A);
Graph XG = G.subgraph(X);
// For every continuous parent of a discrete node, add one discrete "ersatz" copy (with
// RandomUtil.nextInt(3) + 2 categories) to the discrete subgraph as a stand-in parent. The
// reverse map recovers the original continuous node from the ersatz node's name.
Map<ContinuousVariable, DiscreteVariable> erstatzNodes = new HashMap<>();
Map<String, ContinuousVariable> erstatzNodesReverse = new HashMap<>();
for (Node y : A) {
for (Node x : G.getParents(y)) {
if (x instanceof ContinuousVariable) {
DiscreteVariable ersatz = erstatzNodes.get(x);
if (ersatz == null) {
ersatz = new DiscreteVariable("Ersatz_" + x.getName(), RandomUtil.getInstance().nextInt(3) + 2);
erstatzNodes.put((ContinuousVariable) x, ersatz);
erstatzNodesReverse.put(ersatz.getName(), (ContinuousVariable) x);
AG.addNode(ersatz);
}
AG.addDirectedEdge(ersatz, y);
}
}
}
// Randomly parameterized Bayes IM over the discrete subgraph (including ersatz nodes), and a SEM
// PM over the continuous subgraph.
BayesPm bayesPm = new BayesPm(AG);
BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
SemPm semPm = new SemPm(XG);
// Cache handed to getParamValue; presumably each parameter/discrete-value combination is drawn
// once and reused across rows — confirm.
Map<Combination, Double> paramValues = new HashMap<>();
// Column indices (into nodes) in causal order.
List<Node> tierOrdering = G.getCausalOrdering();
int[] tiers = new int[tierOrdering.size()];
for (int t = 0; t < tierOrdering.size(); t++) {
tiers[t] = nodes.indexOf(tierOrdering.get(t));
}
// Discretization breakpoints per continuous column, computed lazily the first time an ersatz
// parent needs them.
Map<Integer, double[]> breakpointsMap = new HashMap<>();
// The outer loop is over columns and the inner loop over rows, so every parent column is fully
// filled for all rows before any child column is simulated.
for (int mixedIndex : tiers) {
for (int i = 0; i < parameters.getInt("sampleSize"); i++) {
if (nodes.get(mixedIndex) instanceof DiscreteVariable) {
// Discrete node: collect the parent values in the Bayes IM's parent order.
int bayesIndex = bayesIm.getNodeIndex(nodes.get(mixedIndex));
int[] bayesParents = bayesIm.getParents(bayesIndex);
int[] parentValues = new int[bayesParents.length];
for (int k = 0; k < parentValues.length; k++) {
int bayesParentColumn = bayesParents[k];
Node bayesParent = bayesIm.getVariables().get(bayesParentColumn);
DiscreteVariable _parent = (DiscreteVariable) bayesParent;
int value;
ContinuousVariable orig = erstatzNodesReverse.get(_parent.getName());
if (orig != null) {
// Ersatz parent: discretize the original continuous value to the index of the first
// breakpoint it lies below, or breakpoints.length if it lies below none.
int mixedParentColumn = mixedData.getColumn(orig);
double d = mixedData.getDouble(i, mixedParentColumn);
double[] breakpoints = breakpointsMap.get(mixedParentColumn);
if (breakpoints == null) {
breakpoints = getBreakpoints(mixedData, _parent, mixedParentColumn);
breakpointsMap.put(mixedParentColumn, breakpoints);
}
value = breakpoints.length;
for (int j = 0; j < breakpoints.length; j++) {
if (d < breakpoints[j]) {
value = j;
break;
}
}
} else {
// Genuine discrete parent: read its already-simulated category.
int mixedColumn = mixedData.getColumn(bayesParent);
value = mixedData.getInt(i, mixedColumn);
}
parentValues[k] = value;
}
// Inverse-CDF draw from the selected CPT row. If rounding keeps the cumulative sum below r,
// the cell keeps the 0 written just before the loop.
int rowIndex = bayesIm.getRowIndex(bayesIndex, parentValues);
double sum = 0.0;
double r = RandomUtil.getInstance().nextDouble();
mixedData.setInt(i, mixedIndex, 0);
for (int k = 0; k < bayesIm.getNumColumns(bayesIndex); k++) {
double probability = bayesIm.getProbability(bayesIndex, rowIndex, k);
sum += probability;
if (sum >= r) {
mixedData.setInt(i, mixedIndex, k);
break;
}
}
} else {
// Continuous node: split its parents by type.
Node y = nodes.get(mixedIndex);
Set<DiscreteVariable> discreteParents = new HashSet<>();
Set<ContinuousVariable> continuousParents = new HashSet<>();
for (Node node : G.getParents(y)) {
if (node instanceof DiscreteVariable) {
discreteParents.add((DiscreteVariable) node);
} else {
continuousParents.add((ContinuousVariable) node);
}
}
// Error-variance and mean parameters are keyed by this row's discrete parent values.
Parameter varParam = semPm.getParameter(y, y);
Parameter muParam = semPm.getMeanParameter(y);
Combination varComb = new Combination(varParam);
Combination muComb = new Combination(muParam);
for (DiscreteVariable v : discreteParents) {
varComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
muComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
}
// NOTE(review): whether nextNormal's second argument is a variance or a standard deviation
// is not visible here — confirm against RandomUtil and getParamValue.
double value = RandomUtil.getInstance().nextNormal(0, getParamValue(varComb, paramValues));
// Add each continuous parent times a coefficient that is also keyed by discrete parent values.
for (Node x : continuousParents) {
Parameter coefParam = semPm.getParameter(x, y);
Combination coefComb = new Combination(coefParam);
for (DiscreteVariable v : discreteParents) {
coefComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
}
int parent = nodes.indexOf(x);
double parentValue = mixedData.getDouble(i, parent);
double parentCoef = getParamValue(coefComb, paramValues);
value += parentValue * parentCoef;
}
value += getParamValue(muComb, paramValues);
mixedData.setDouble(i, mixedIndex, value);
}
}
}
boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
return saveLatentVars ? mixedData : DataUtils.restrictToMeasured(mixedData);
}
Usage of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.
Class XdslXmlParser, method buildIM:
/**
 * Builds a Bayes IM from the {@code <nodes>} element of an XDSL document.
 *
 * <p>Every child of {@code element0} must be a {@code <cpt>} element. The first attribute of each
 * {@code <cpt>} is used as the node id. Its {@code <state>} children give the category names (the
 * first attribute of each), its {@code <parents>} child lists parent ids, and its
 * {@code <probabilities>} child lists the CPT entries, row by row. Lists may be separated by any
 * amount of whitespace.
 *
 * @param element0     the {@code <nodes>} element.
 * @param displayNames map from node id to the name used in the DAG, or {@code null} to use ids.
 * @return the instantiated Bayes IM.
 * @throws IllegalArgumentException if a child is not a {@code <cpt>} element, a parent id does not
 *                                  refer to a declared node, or a probability list is too short.
 */
private BayesIm buildIM(Element element0, Map<String, String> displayNames) {
    Elements elements = element0.getChildElements();

    for (int i = 0; i < elements.size(); i++) {
        if (!"cpt".equals(elements.get(i).getQualifiedName())) {
            throw new IllegalArgumentException("Expecting cpt element.");
        }
    }

    Dag dag = new Dag();

    // Get the nodes (in <cpt> order).
    for (int i = 0; i < elements.size(); i++) {
        String id = elements.get(i).getAttribute(0).getValue();
        dag.addNode(new GraphNode(xdslNodeName(id, displayNames)));
    }

    // Get the edges.
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String childName = xdslNodeName(cpt.getAttribute(0).getValue(), displayNames);
        edu.cmu.tetrad.graph.Node child = dag.getNode(childName);
        Elements cptElements = cpt.getChildElements();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (!cptElement.getQualifiedName().equals("parents")) {
                continue;
            }

            // An empty or whitespace-only <parents> yields no ids instead of a bogus "" id.
            for (String parentId : splitXdslList(cptElement.getValue())) {
                edu.cmu.tetrad.graph.Node parent = dag.getNode(xdslNodeName(parentId, displayNames));
                if (parent == null) {
                    throw new IllegalArgumentException(
                            "Unknown parent '" + parentId + "' for node '" + childName + "'.");
                }
                dag.addDirectedEdge(parent, child);
            }
        }
    }

    // PM
    BayesPm pm = new BayesPm(dag);

    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        edu.cmu.tetrad.graph.Node node = dag.getNode(xdslNodeName(cpt.getAttribute(0).getValue(), displayNames));
        Elements cptElements = cpt.getChildElements();
        List<String> stateNames = new ArrayList<>();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (cptElement.getQualifiedName().equals("state")) {
                stateNames.add(cptElement.getAttribute(0).getValue());
            }
        }

        pm.setCategories(node, stateNames);
    }

    // IM
    BayesIm im = new MlBayesIm(pm);

    // NOTE(review): node indices into the IM are taken to be the <cpt> element order, i.e. the
    // order in which nodes were added to the DAG above — confirm against MlBayesIm's node ordering.
    for (int nodeIndex = 0; nodeIndex < elements.size(); nodeIndex++) {
        Elements cptElements = elements.get(nodeIndex).getChildElements();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (!cptElement.getQualifiedName().equals("probabilities")) {
                continue;
            }

            List<Double> probs = new ArrayList<>();
            for (String probString : splitXdslList(cptElement.getValue())) {
                probs.add(Double.parseDouble(probString));
            }

            int numRows = im.getNumRows(nodeIndex);
            int numCols = im.getNumColumns(nodeIndex);

            // Fail with a descriptive message rather than an IndexOutOfBoundsException.
            if (probs.size() < numRows * numCols) {
                throw new IllegalArgumentException("Expecting " + (numRows * numCols)
                        + " probabilities for node '" + im.getNode(nodeIndex).getName()
                        + "' but found " + probs.size() + ".");
            }

            int count = 0;
            for (int row = 0; row < numRows; row++) {
                for (int col = 0; col < numCols; col++) {
                    im.setProbability(nodeIndex, row, col, probs.get(count++));
                }
            }
        }
    }

    return im;
}

/**
 * Maps an XDSL node id to the name used in the DAG: the display name when a map is given,
 * otherwise the id itself.
 */
private static String xdslNodeName(String id, Map<String, String> displayNames) {
    return displayNames == null ? id : displayNames.get(id);
}

/**
 * Splits a whitespace-separated XDSL list, ignoring leading, trailing and repeated whitespace.
 * Returns an empty list for an empty or blank string.
 */
private static List<String> splitXdslList(String list) {
    List<String> tokens = new ArrayList<>();
    String trimmed = list.trim();
    if (trimmed.isEmpty()) {
        return tokens;
    }
    for (String token : trimmed.split("\\s+")) {
        tokens.add(token);
    }
    return tokens;
}
Usage of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.
Class XdslXmlParser, method getBayesIm:
/**
 * Takes an XDSL (xml) representation of a Bayes IM and reinstantiates the IM.
 *
 * <p>The CPTs are read from the root's {@code <nodes>} child. If an {@code <extensions>} child is
 * present, its {@code <genie>} child is passed to {@code mapDisplayNames} together with the
 * {@code useDisplayNames} setting.
 *
 * @param element the root {@code <smile>} element of the IM's xml.
 * @return the BayesIM
 * @throws IllegalArgumentException if the root is not a {@code <smile>} element or has no
 *                                  {@code <nodes>} child.
 */
public BayesIm getBayesIm(Element element) {
    if (!"smile".equals(element.getQualifiedName())) {
        throw new IllegalArgumentException("Expecting " + "smile" + " element.");
    }

    Elements elements = element.getChildElements();
    Element nodesElement = null;
    Element genieElement = null;

    // If a child name repeats, the last occurrence wins.
    for (int i = 0; i < elements.size(); i++) {
        Element child = elements.get(i);
        String childName = child.getQualifiedName();

        if ("nodes".equals(childName)) {
            nodesElement = child;
        }

        if ("extensions".equals(childName)) {
            genieElement = child.getFirstChildElement("genie");
        }
    }

    // Without this check buildIM would fail with a NullPointerException.
    if (nodesElement == null) {
        throw new IllegalArgumentException("Expecting nodes element.");
    }

    Map<String, String> displayNames = mapDisplayNames(genieElement, useDisplayNames);
    return buildIM(nodesElement, displayNames);
}
Usage of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.
Class TestEvidence, method sampleBayesIm2:
/**
 * Builds a small fixed Bayes IM over the DAG a -> b, a -> c, b -> c, where b has three
 * categories and the other nodes keep the BayesPm default.
 *
 * @return the parameterized IM.
 */
private static BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");

    Dag dag = new Dag();
    for (Node n : new Node[]{a, b, c}) {
        dag.addNode(n);
    }
    dag.addDirectedEdge(a, b);
    dag.addDirectedEdge(a, c);
    dag.addDirectedEdge(b, c);

    BayesPm pm = new BayesPm(dag);
    pm.setNumCategories(b, 3);

    BayesIm im = new MlBayesIm(pm);

    // cpts[nodeIndex][rowIndex][columnIndex]; rows enumerate parent configurations.
    double[][][] cpts = {
            {{.3, .7}},
            {{.3, .4, .3}, {.6, .1, .3}},
            {{.9, .1}, {.1, .9}, {.5, .5}, {.2, .8}, {.6, .4}, {.7, .3}}
    };

    for (int node = 0; node < cpts.length; node++) {
        for (int row = 0; row < cpts[node].length; row++) {
            for (int col = 0; col < cpts[node][row].length; col++) {
                im.setProbability(node, row, col, cpts[node][row][col]);
            }
        }
    }

    return im;
}
Aggregations