Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
Class EvidenceWizardMultiple, method appendJoint.
/**
 * Appends a table of the joint distribution over the selected nodes to the
 * given text area.
 *
 * <p>The table has one column per selected node, followed by the joint
 * probability and its log odds. It has one row per combination of category
 * values. Combinations are enumerated by {@code getCategories(row, dims)}.
 * If the current updater reports that joint marginals are unsupported, only a
 * short notice is appended.
 *
 * @param selectedNodes the nodes to include, in column order
 * @param marginalsArea the text area the table is appended to
 * @param manipulatedIm the Bayes IM supplying node indices and category names
 * @param nf            formatter for the probability and log-odds columns
 */
private void appendJoint(List<Node> selectedNodes, JTextArea marginalsArea, BayesIm manipulatedIm, NumberFormat nf) {
    if (!getUpdaterWrapper().getBayesUpdater().isJointMarginalSupported()) {
        marginalsArea.append("\n\n(Calculation of joint not supported " + "for this updater.)");
        return;
    }

    BayesPm pm = manipulatedIm.getBayesPm();
    int columnCount = selectedNodes.size();
    int[] categoryCounts = new int[columnCount];
    int[] nodeIndices = new int[columnCount];

    // The number of table rows is the product of every node's category count.
    int totalRows = 1;
    int col = 0;
    for (Node node : selectedNodes) {
        int categories = pm.getNumCategories(node);
        nodeIndices[col] = manipulatedIm.getNodeIndex(node);
        categoryCounts[col] = categories;
        totalRows *= categories;
        col++;
    }

    marginalsArea.append("\n\nJOINT OVER SELECTED VARIABLES:\n\n");
    for (Node node : selectedNodes) {
        marginalsArea.append(node + "\t");
    }
    marginalsArea.append("Joint\tLog odds\n");

    for (int r = 0; r < totalRows; r++) {
        // Decode row index r into one category index per selected node.
        int[] combination = getCategories(r, categoryCounts);
        double p = getUpdaterWrapper().getBayesUpdater().getJointMarginal(nodeIndices, combination);
        double odds = Math.log(p / (1. - p));

        marginalsArea.append("\n");
        for (int c = 0; c < columnCount; c++) {
            marginalsArea.append(pm.getCategory(selectedNodes.get(c), combination[c]));
            marginalsArea.append("\t");
        }
        marginalsArea.append(nf.format(p) + "\t");
        marginalsArea.append(nf.format(odds));
    }
}
Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
Class EvidenceWizardMultiple, method appendMarginals.
/**
 * Appends the updated marginal distribution of each selected node to the
 * given text area.
 *
 * <p>For every category of every node, one line is written with the marginal
 * probability reported by the current updater and its log odds.
 *
 * @param selectedNodes the nodes whose marginals are reported
 * @param marginalsArea the text area the report is appended to
 * @param manipulatedIm the Bayes IM supplying node indices and category names
 * @param nf            formatter for the probability and log-odds values
 */
private void appendMarginals(List<Node> selectedNodes, JTextArea marginalsArea, BayesIm manipulatedIm, NumberFormat nf) {
    BayesPm pm = manipulatedIm.getBayesPm();
    marginalsArea.append("MARGINALS FOR SELECTED VARIABLES:\n");

    for (Node variable : selectedNodes) {
        marginalsArea.append("\nVariable " + variable.getName() + ":\n");
        int index = manipulatedIm.getNodeIndex(variable);

        for (int cat = 0; cat < pm.getNumCategories(variable); cat++) {
            double p = getUpdaterWrapper().getBayesUpdater().getMarginal(index, cat);
            double odds = Math.log(p / (1. - p));
            String line = "Category " + pm.getCategory(variable, cat)
                    + ": p = " + nf.format(p)
                    + ", log odds = " + nf.format(odds) + "\n";
            marginalsArea.append(line);
        }
    }
}
Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
Class EvidenceWizardMultipleObs, method appendMarginals.
/**
 * Appends the updated marginal distribution of each selected node to the
 * given text area.
 *
 * <p>For every category of every node, one line is written with the marginal
 * probability and its log odds. A negative probability from the updater marks
 * the value as unidentifiable. It is then reported as "Unidentifiable", and
 * the log odds are shown as "*".
 *
 * @param selectedNodes the nodes whose marginals are reported
 * @param marginalsArea the text area the report is appended to
 * @param manipulatedIm the Bayes IM supplying node indices and category names
 * @param nf            formatter for the probability and log-odds values
 */
private void appendMarginals(List<Node> selectedNodes, JTextArea marginalsArea, BayesIm manipulatedIm, NumberFormat nf) {
    BayesPm pm = manipulatedIm.getBayesPm();
    marginalsArea.append("MARGINALS FOR SELECTED VARIABLES:\n");

    for (Node variable : selectedNodes) {
        marginalsArea.append("\nVariable " + variable.getName() + ":\n");
        int index = manipulatedIm.getNodeIndex(variable);

        for (int cat = 0; cat < pm.getNumCategories(variable); cat++) {
            double p = getUpdaterWrapper().getBayesUpdater().getMarginal(index, cat);
            String categoryName = pm.getCategory(variable, cat);

            String probText;
            String oddsText;
            // The identifiability updater returns -1 when the requested probability is unidentifiable.
            if (p < 0.0) {
                probText = "Unidentifiable";
                oddsText = "*";
            } else {
                probText = nf.format(p);
                oddsText = nf.format(Math.log(p / (1. - p)));
            }

            marginalsArea.append("Category " + categoryName + ": p = " + probText + ", log odds = " + oddsText + "\n");
        }
    }
}
Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
Class TestEvidence, method sampleBayesIm2.
/**
 * Builds a small fixed Bayes IM over the DAG a -> b, a -> c, b -> c.
 *
 * <p>Node b is given three categories. Nodes a and c keep the BayesPm default
 * number of categories; the table below assumes two, so confirm this if the
 * default changes. Every conditional probability table entry is set
 * explicitly.
 *
 * @return the populated Bayes IM
 */
private static BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");

    Dag graph = new Dag();
    for (Node n : new Node[]{a, b, c}) {
        graph.addNode(n);
    }
    graph.addDirectedEdge(a, b);
    graph.addDirectedEdge(a, c);
    graph.addDirectedEdge(b, c);

    BayesPm bayesPm = new BayesPm(graph);
    bayesPm.setNumCategories(b, 3);

    // Indexed as cpts[nodeIndex][parentRow][category].
    double[][][] cpts = {
            {{.3, .7}},
            {{.3, .4, .3}, {.6, .1, .3}},
            {{.9, .1}, {.1, .9}, {.5, .5}, {.2, .8}, {.6, .4}, {.7, .3}}
    };

    BayesIm im = new MlBayesIm(bayesPm);
    for (int node = 0; node < cpts.length; node++) {
        for (int row = 0; row < cpts[node].length; row++) {
            for (int cat = 0; cat < cpts[node][row].length; cat++) {
                im.setProbability(node, row, cat, cpts[node][row][cat]);
            }
        }
    }
    return im;
}
Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
Class TestFges, method explore2.
/**
 * Exploratory run of FGES with a BDe score on discrete data.
 *
 * <p>The data are simulated from a random 20-variable DAG under a fixed seed.
 * The estimated pattern is compared with the true pattern of the DAG. The
 * comparison counts are computed but are not currently asserted.
 */
@Test
public void explore2() {
    RandomUtil.getInstance().setSeed(1457220623122L);

    int variableCount = 20;
    double edgeFactor = 1.0;
    int sampleSize = 1000;
    double structurePrior = 1;
    double samplePrior = 1;

    List<Node> variables = new ArrayList<>();
    for (int k = 0; k < variableCount; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }
    int edgeCount = (int) (variableCount * edgeFactor);
    Graph trueDag = GraphUtils.randomGraphRandomForwardEdges(variables, 0, edgeCount, 30, 15, 15, false, true);

    BayesPm pm = new BayesPm(trueDag, 2, 3);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(sampleSize, false);

    BDeScore score = new BDeScore(data);
    score.setSamplePrior(samplePrior);
    score.setStructurePrior(structurePrior);

    Fges search = new Fges(score);
    search.setVerbose(false);
    search.setNumPatternsToStore(0);
    search.setFaithfulnessAssumed(false);
    Graph estimated = search.search();

    Graph truth = SearchGraphUtils.patternForDag(trueDag);
    int[][] counts = SearchGraphUtils.graphComparison(estimated, truth, null);

    // Reference comparison counts, presumably recorded for this seed. The row-by-row
    // check of `counts` against these values is currently disabled.
    int[][] expectedCounts = { { 2, 0, 0, 0, 0, 1 }, { 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0 }, { 2, 0, 0, 13, 0, 3 }, { 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0 } };
}
Aggregations