Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class BayesImWrapper, method setBayesIm:
/**
 * Replaces the wrapped IM list with a fresh list holding a single new {@code MlBayesIm}.
 * The IM is built from the given PM, the previous IM and the initialization flag.
 *
 * @param bayesPm    the PM the new IM is built over.
 * @param oldBayesIm the previous IM handed to the {@code MlBayesIm} constructor.
 * @param manual     the initialization-method flag handed to the {@code MlBayesIm} constructor.
 */
private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, int manual) {
    // The list is reset before the IM is constructed, exactly as before.
    this.bayesIms = new ArrayList<>();
    MlBayesIm freshIm = new MlBayesIm(bayesPm, oldBayesIm, manual);
    this.bayesIms.add(freshIm);
}
Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class XdslXmlParser, method buildIM:
/**
 * Builds a Bayes IM from an XDSL element whose children must all be {@code cpt} elements.
 *
 * <p>For each {@code cpt}:
 * <ul>
 *   <li>the value of its first attribute is the variable id;</li>
 *   <li>its {@code state} children give the category names, in document order (the first
 *       attribute of each);</li>
 *   <li>its optional {@code parents} child holds a whitespace-separated list of parent ids;</li>
 *   <li>its {@code probabilities} child holds a whitespace-separated list of numbers. They are
 *       written into the node's table row by row and, within a row, column by column.</li>
 * </ul>
 *
 * @param element0     the element whose children are the {@code cpt} elements.
 * @param displayNames optional map from XDSL ids to the names used for graph nodes. When
 *                     {@code null}, the ids themselves are used as node names.
 * @return the constructed IM.
 * @throws IllegalArgumentException if a child is not a {@code cpt} element, or an id has no
 *                                  display name or names an unknown variable, or the number of
 *                                  probabilities does not match a node's table size.
 */
private BayesIm buildIM(Element element0, Map<String, String> displayNames) {
    Elements elements = element0.getChildElements();

    for (int i = 0; i < elements.size(); i++) {
        if (!"cpt".equals(elements.get(i).getQualifiedName())) {
            throw new IllegalArgumentException("Expecting cpt element.");
        }
    }

    Dag dag = new Dag();

    // Get the nodes.
    for (int i = 0; i < elements.size(); i++) {
        String name = resolveXdslName(elements.get(i).getAttribute(0).getValue(), displayNames);
        dag.addNode(new GraphNode(name));
    }

    // Get the edges. Every node was added above, so nodes are only looked up here
    // (the previous version redundantly re-added each node at this point).
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        edu.cmu.tetrad.graph.Node child = requireXdslNode(dag,
                resolveXdslName(cpt.getAttribute(0).getValue(), displayNames));
        Elements cptElements = cpt.getChildElements();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);

            if (cptElement.getQualifiedName().equals("parents")) {
                // An empty or whitespace-only list yields no parents instead of a bogus "" id.
                for (String parentId : splitXdslList(cptElement.getValue())) {
                    edu.cmu.tetrad.graph.Node parent = requireXdslNode(dag,
                            resolveXdslName(parentId, displayNames));
                    dag.addDirectedEdge(parent, child);
                }
            }
        }
    }

    // PM: each node's categories are the ids of its cpt's state children, in document order.
    BayesPm pm = new BayesPm(dag);

    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        edu.cmu.tetrad.graph.Node node = requireXdslNode(dag,
                resolveXdslName(cpt.getAttribute(0).getValue(), displayNames));
        Elements cptElements = cpt.getChildElements();
        List<String> stateNames = new ArrayList<>();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);

            if (cptElement.getQualifiedName().equals("state")) {
                stateNames.add(cptElement.getAttribute(0).getValue());
            }
        }

        pm.setCategories(node, stateNames);
    }

    // IM
    BayesIm im = new MlBayesIm(pm);

    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String name = resolveXdslName(cpt.getAttribute(0).getValue(), displayNames);

        // Look the node up by name rather than assuming the IM's node order matches the
        // order of the cpt elements in the file.
        int nodeIndex = im.getNodeIndex(im.getNode(name));

        if (nodeIndex < 0) {
            throw new IllegalArgumentException("Unknown variable in XDSL file: " + name);
        }

        Elements cptElements = cpt.getChildElements();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);

            if (cptElement.getQualifiedName().equals("probabilities")) {
                String[] probStrings = splitXdslList(cptElement.getValue());
                int numRows = im.getNumRows(nodeIndex);
                int numCols = im.getNumColumns(nodeIndex);

                // Fail with a clear message instead of an index error (too few values)
                // or silently dropping trailing values (too many).
                if (probStrings.length != numRows * numCols) {
                    throw new IllegalArgumentException("Expecting " + (numRows * numCols)
                            + " probabilities for node " + name + " but found "
                            + probStrings.length + ".");
                }

                // Values are consumed row by row and, within each row, column by column.
                int count = 0;

                for (int row = 0; row < numRows; row++) {
                    for (int col = 0; col < numCols; col++) {
                        im.setProbability(nodeIndex, row, col, Double.parseDouble(probStrings[count++]));
                    }
                }
            }
        }
    }

    return im;
}

/** Maps an XDSL id to its node name; the id itself is used when no display names are given. */
private static String resolveXdslName(String id, Map<String, String> displayNames) {
    if (displayNames == null) {
        return id;
    }

    String name = displayNames.get(id);

    if (name == null) {
        throw new IllegalArgumentException("No display name for XDSL id: " + id);
    }

    return name;
}

/** Returns the DAG node with the given name, failing clearly instead of returning null. */
private static edu.cmu.tetrad.graph.Node requireXdslNode(Dag dag, String name) {
    edu.cmu.tetrad.graph.Node node = dag.getNode(name);

    if (node == null) {
        throw new IllegalArgumentException("Unknown variable in XDSL file: " + name);
    }

    return node;
}

/** Splits a whitespace-separated XDSL list; blank text yields an empty array. */
private static String[] splitXdslList(String text) {
    String trimmed = text.trim();
    return trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
}
Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class TestEvidence, method sampleBayesIm2:
/**
 * Builds a small discrete Bayes IM over the DAG a -> b, a -> c, b -> c, with fixed conditional
 * probability tables.
 *
 * <p>Node b is given three categories. Nodes a and c keep the BayesPm default, and the tables
 * below are written as if that default is two.
 *
 * @return the parameterized IM.
 */
private static BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");

    Dag graph = new Dag();
    for (Node node : new Node[]{a, b, c}) {
        graph.addNode(node);
    }
    graph.addDirectedEdge(a, b);
    graph.addDirectedEdge(a, c);
    graph.addDirectedEdge(b, c);

    BayesPm bayesPm = new BayesPm(graph);
    bayesPm.setNumCategories(b, 3);

    // tables[nodeIndex][row][col]. Node indices 0, 1, 2 are taken to be a, b, c.
    double[][][] tables = {
            {{.3, .7}},
            {{.3, .4, .3}, {.6, .1, .3}},
            {{.9, .1}, {.1, .9}, {.5, .5}, {.2, .8}, {.6, .4}, {.7, .3}}
    };

    BayesIm im = new MlBayesIm(bayesPm);
    for (int nodeIndex = 0; nodeIndex < tables.length; nodeIndex++) {
        double[][] table = tables[nodeIndex];
        for (int row = 0; row < table.length; row++) {
            for (int col = 0; col < table[row].length; col++) {
                im.setProbability(nodeIndex, row, col, table[row][col]);
            }
        }
    }
    return im;
}
Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class TestFges, method explore2:
/**
 * Smoke test for FGES on discrete data, using a fixed random seed.
 *
 * <p>The test simulates data from a random DAG, runs FGES with a BDe score and compares the
 * estimated pattern to the true one.
 *
 * <p>NOTE(review): nothing is asserted. The earlier comparison against hard-coded expected
 * counts had been commented out. The test therefore only checks that the pipeline runs
 * without throwing.
 */
@Test
public void explore2() {
    RandomUtil.getInstance().setSeed(1457220623122L);

    int numVars = 20;
    double edgeFactor = 1.0;
    int numCases = 1000;
    double structurePrior = 1;
    double samplePrior = 1;

    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }

    Graph dag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, (int) (numVars * edgeFactor), 30, 15, 15, false, true);

    BayesPm pm = new BayesPm(dag, 2, 3);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(numCases, false);

    BDeScore score = new BDeScore(data);
    score.setSamplePrior(samplePrior);
    score.setStructurePrior(structurePrior);

    Fges ges = new Fges(score);
    ges.setVerbose(false);
    ges.setNumPatternsToStore(0);
    ges.setFaithfulnessAssumed(false);

    Graph estPattern = ges.search();
    Graph truePattern = SearchGraphUtils.patternForDag(dag);

    // Still exercises the comparison; its counts are not checked (see the NOTE above).
    SearchGraphUtils.graphComparison(estPattern, truePattern, null);
}
Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
In class TestGeneralBootstrapTest, method testFCId:
/**
 * Runs a bootstrapped FCI search, using a chi-square independence test, on data simulated from
 * a random discrete DAG with latent confounders.
 *
 * <p>The test prints the true PAG, the bootstrapped estimate, and the adjacency and edge-type
 * confusion matrices that compare the two.
 */
@Test
public void testFCId() {
    // Search settings.
    double structurePrior = 1, samplePrior = 1;
    int searchDepth = -1;
    int maxPathLen = -1;
    boolean verbose = true;

    // Simulation and bootstrap settings.
    int variableCount = 20;
    int edgesPerNode = 2;
    int latentCount = 4;
    int sampleSize = 50;
    int bootstrapCount = 5;
    long simulationSeed = 123;

    // Ground truth: a random discrete DAG and the PAG it implies.
    Graph dag = makeDiscreteDAG(variableCount, latentCount, edgesPerNode);
    Graph truePag = new DagToPag(dag).convert();
    System.out.println("Truth PAG_of_the_true_DAG Graph:");
    System.out.println(truePag.toString());

    // Simulate discrete data from a randomly parameterized IM over that DAG.
    BayesIm im = new MlBayesIm(new BayesPm(dag, 2, 3), MlBayesIm.RANDOM);
    DataSet data = im.simulateData(sampleSize, simulationSeed, false);

    Parameters parameters = new Parameters();
    parameters.set("structurePrior", structurePrior);
    parameters.set("samplePrior", samplePrior);
    parameters.set("depth", searchDepth);
    parameters.set("maxPathLength", maxPathLen);
    parameters.set("numPatternsToStore", 0);
    parameters.set("verbose", verbose);

    IndependenceWrapper chiSquare = new ChiSquare();
    Algorithm fci = new Fci(chiSquare);

    GeneralBootstrapTest bootstrap = new GeneralBootstrapTest(data, fci, bootstrapCount);
    bootstrap.setVerbose(verbose);
    bootstrap.setParameters(parameters);
    bootstrap.setEdgeEnsemble(BootstrapEdgeEnsemble.Highest);

    Graph resultGraph = bootstrap.search();
    System.out.println("Estimated Bootstrapped PAG_of_the_true_DAG Graph:");
    System.out.println(resultGraph.toString());

    // Adjacency confusion matrix, then edge-type confusion matrix.
    printAdjConfusionMatrix(GeneralBootstrapTest.getAdjConfusionMatrix(truePag, resultGraph));
    printEdgeTypeConfusionMatrix(GeneralBootstrapTest.getEdgeTypeConfusionMatrix(truePag, resultGraph));
}
Aggregations