use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class PerformanceTests method testFgesMb.
/**
 * Benchmarks FGES-MB on a single randomly generated DAG.
 *
 * <p>The DAG is created with {@code makeDag}, its pattern is computed, and one data set is simulated
 * from it (continuous data via {@code LargeScaleSimulation}, discrete data via {@code MlBayesIm}).
 * Then, {@code numRuns} times, a target variable is drawn at random, {@code FgesMb.search(target)} is
 * run, and the estimate is compared with the subgraph of the true pattern over the target, its
 * adjacents and the parents of its children. Per-run results and averages are written to {@code out}.
 *
 * @param numVars    number of variables handed to {@code makeDag}
 * @param edgeFactor edge factor handed to {@code makeDag}; {@code numVars * edgeFactor} is reported
 *                   as the number of edges
 * @param numCases   number of cases to simulate
 * @param numRuns    number of random targets to search for
 * @param continuous {@code true} to simulate continuous data and score with {@code SemBicScore};
 *                   {@code false} to simulate discrete data and score with {@code BDeuScore}
 */
private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) {
    double penaltyDiscount = 4.0;
    int structurePrior = 10;
    int samplePrior = 10;
    int maxIndegree = -1;

    List<int[][]> allCounts = new ArrayList<>();
    List<double[]> comparisons = new ArrayList<>();
    // Never filled in by this method; handed to printAverageStatistics unchanged.
    List<Double> degrees = new ArrayList<>();
    List<Long> elapsedTimes = new ArrayList<>();

    System.out.println("Making dag");
    Graph dag = makeDag(numVars, edgeFactor);
    System.out.println(new Date());

    System.out.println("Calculating pattern for DAG");
    Graph pattern = SearchGraphUtils.patternForDag(dag);

    // Tier i holds index i.
    int[] tiers = new int[dag.getNumNodes()];
    for (int i = 0; i < dag.getNumNodes(); i++) {
        tiers[i] = i;
    }

    System.out.println("Graph done");
    long time1 = System.currentTimeMillis();
    // NOTE(review): this write happens before init(...) is called below, so it goes to whatever
    // 'out' referred to on entry — confirm that is intended.
    out.println("Graph done");
    System.out.println(new Date());
    System.out.println("Starting simulation");

    FgesMb fges;
    List<Node> vars;

    if (continuous) {
        init(new File("FgesMb.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Penalty discount = " + penaltyDiscount);
        out.println("Depth = " + maxIndegree);
        out.println();
        out.println(new Date());

        vars = dag.getNodes();

        LargeScaleSimulation simulator = new LargeScaleSimulation(dag, vars, tiers);
        simulator.setVerbose(false);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);

        System.out.println("Finishing simulation");
        System.out.println(new Date());
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

        System.out.println(new Date());
        System.out.println("Making covariance matrix");

        // Bracket the constructor call with both timestamps so the reported value is the time spent
        // building the covariance matrix (the end stamp used to be taken before the call).
        long covStart = System.currentTimeMillis();
        ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data, true);
        long covEnd = System.currentTimeMillis();

        System.out.println("Covariance matrix done");
        out.println("Elapsed (calculating cov): " + (covEnd - covStart) + " ms\n");

        SemBicScore score = new SemBicScore(cov);
        score.setPenaltyDiscount(penaltyDiscount);

        System.out.println(new Date());
        System.out.println("\nStarting FGES-MB");

        fges = new FgesMb(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setOut(System.out);
        fges.setMaxIndegree(maxIndegree);
        fges.setCycleBound(-1);
    } else {
        init(new File("FgesMb.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Sample prior = " + samplePrior);
        out.println("Structure prior = " + structurePrior);
        out.println("Depth = " + maxIndegree);
        out.println();
        out.println(new Date());

        BayesPm pm = new BayesPm(dag, 3, 3);
        MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
        DataSet data = im.simulateData(numCases, false, tiers);

        // Targets are drawn from the data set's variables, so re-express the true pattern over them.
        vars = data.getVariables();
        pattern = GraphUtils.replaceNodes(pattern, vars);

        System.out.println("Finishing simulation");
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

        long time3 = System.currentTimeMillis();
        BDeuScore score = new BDeuScore(data);
        score.setStructurePrior(structurePrior);
        score.setSamplePrior(samplePrior);

        System.out.println(new Date());
        System.out.println("\nStarting FGES");

        long time4 = System.currentTimeMillis();
        fges = new FgesMb(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setOut(System.out);
        fges.setMaxIndegree(maxIndegree);
        fges.setCycleBound(-1);
        long timeb = System.currentTimeMillis();

        out.println("Time consructing BDeu score " + (time4 - time3) + " ms");
        out.println("Time for FGES-MB constructor " + (timeb - time4) + " ms");
        out.println();
    }

    // Runs with undefined (NaN) accuracies are not skipped, so this counter is always reported as 0.
    int numSkipped = 0;

    for (int run = 0; run < numRuns; run++) {
        out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");

        Node target = vars.get(RandomUtil.getInstance().nextInt(vars.size()));
        System.out.println("Target = " + target);

        long searchStart = System.currentTimeMillis();
        Graph estPattern = fges.search(target);
        long searchEnd = System.currentTimeMillis();
        long elapsed = searchEnd - searchStart;

        // Reference graph: the true pattern restricted to the target, its adjacents, and the parents
        // of its children.
        Set<Node> mb = new HashSet<>();
        mb.add(target);
        mb.addAll(pattern.getAdjacentNodes(target));
        for (Node child : pattern.getChildren(target)) {
            mb.addAll(pattern.getParents(child));
        }
        Graph trueMbGraph = pattern.subgraph(new ArrayList<>(mb));

        // Report the search interval only; the end stamp used to be taken after the reference graph
        // above had been built, which inflated the figure.
        out.println("Time for FGES-MB search " + elapsed + " ms");
        out.println();

        System.out.println("Done with FGES");
        System.out.println(new Date());

        System.out.println("Counting misclassifications.");
        int[][] counts = GraphUtils.edgeMisclassificationCounts(trueMbGraph, estPattern, false);
        allCounts.add(counts);
        System.out.println(new Date());

        // NOTE(review): the fixed indices below depend on the table layout produced by
        // GraphUtils.edgeMisclassificationCounts, which is not visible here — confirm against it.
        int sumRow = counts[4][0] + counts[4][3] + counts[4][5];
        int sumCol = counts[0][3] + counts[4][3] + counts[5][3] + counts[7][3];
        int trueArrow = counts[4][3];

        int sumTrueAdjacencies = 0;
        for (int i = 0; i < 7; i++) {
            for (int j = 0; j < 5; j++) {
                sumTrueAdjacencies += counts[i][j];
            }
        }

        int falsePositiveAdjacencies = 0;
        for (int j = 0; j < 5; j++) {
            falsePositiveAdjacencies += counts[7][j];
        }

        int falseNegativeAdjacencies = 0;
        for (int i = 0; i < 5; i++) {
            falseNegativeAdjacencies += counts[i][5];
        }

        // Floating-point division: a zero denominator produces NaN/Infinity, not an exception.
        double[] comparison = new double[4];
        comparison[0] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falsePositiveAdjacencies);
        comparison[1] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falseNegativeAdjacencies);
        comparison[2] = trueArrow / (double) sumCol;
        comparison[3] = trueArrow / (double) sumRow;
        comparisons.add(comparison);

        out.println(GraphUtils.edgeMisclassifications(counts));
        out.println(precisionRecall(comparison));

        elapsedTimes.add(elapsed);
        out.println("\nElapsed: " + elapsed + " ms");
    }

    printAverageConfusion("Average", allCounts, new DecimalFormat("0.0"));
    printAveragePrecisionRecall(comparisons);
    out.println("Number of runs skipped because of undefined accuracies: " + numSkipped);
    printAverageStatistics(elapsedTimes, degrees);
    out.close();
}
use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class PerformanceTests method testFges.
/**
 * Benchmarks FGES over {@code numRuns} independently generated DAGs.
 *
 * <p>Each run creates a DAG with {@code makeDag}, computes its pattern, simulates one data set
 * (continuous via {@code LargeScaleSimulation}, discrete via {@code MlBayesIm}), runs
 * {@code Fges.search()}, and compares the estimate with the true pattern. Per-run results and
 * averages are written to {@code out}.
 *
 * @param numVars    number of variables handed to {@code makeDag}
 * @param edgeFactor edge factor handed to {@code makeDag}; {@code numVars * edgeFactor} is reported
 *                   as the number of edges
 * @param numCases   number of cases to simulate per run
 * @param numRuns    number of DAG/simulate/search repetitions
 * @param continuous {@code true} to simulate continuous data and score with {@code SemBicScore};
 *                   {@code false} to simulate discrete data and score with {@code BDeuScore}
 */
private void testFges(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) {
    // NOTE(review): written before init(...) below, so it goes to whatever 'out' referred to on entry.
    out.println(new Date());

    double penaltyDiscount = 4.0;
    // NOTE(review): only printed in the continuous header; never passed to the Fges instances below.
    int maxIndegree = 5;
    boolean faithfulness = true;

    List<int[][]> allCounts = new ArrayList<>();
    List<double[]> comparisons = new ArrayList<>();
    // Never filled in by this method; handed to printAverageStatistics unchanged.
    List<Double> degrees = new ArrayList<>();
    List<Long> elapsedTimes = new ArrayList<>();

    if (continuous) {
        init(new File("fges.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Penalty discount = " + penaltyDiscount);
        out.println("Depth = " + maxIndegree);
        out.println();
    } else {
        init(new File("fges.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Sample prior = " + 1);
        out.println("Structure prior = " + 1);
        out.println("Depth = " + 1);
        out.println();
    }

    for (int run = 0; run < numRuns; run++) {
        out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");

        System.out.println("Making dag");
        out.println(new Date());
        Graph dag = makeDag(numVars, edgeFactor);
        System.out.println(new Date());

        System.out.println("Calculating pattern for DAG");
        Graph pattern = SearchGraphUtils.patternForDag(dag);

        List<Node> vars = dag.getNodes();

        // Tier i holds index i.
        int[] tiers = new int[vars.size()];
        for (int i = 0; i < vars.size(); i++) {
            tiers[i] = i;
        }

        System.out.println("Graph done");
        long time1 = System.currentTimeMillis();
        out.println("Graph done");
        System.out.println(new Date());
        System.out.println("Starting simulation");

        Graph estPattern;
        long elapsed;

        if (continuous) {
            LargeScaleSimulation simulator = new LargeScaleSimulation(dag, vars, tiers);
            simulator.setVerbose(false);
            simulator.setOut(out);
            DataSet data = simulator.simulateDataFisher(numCases);

            System.out.println("Finishing simulation");
            System.out.println(new Date());
            long time2 = System.currentTimeMillis();
            out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

            System.out.println(new Date());
            System.out.println("Making covariance matrix");

            // Bracket the constructor call with both timestamps so the reported value is the time
            // spent building the covariance matrix (the end stamp used to be taken before the call).
            long covStart = System.currentTimeMillis();
            ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data, true);
            long covEnd = System.currentTimeMillis();

            System.out.println("Covariance matrix done");
            out.println("Elapsed (calculating cov): " + (covEnd - covStart) + " ms\n");

            SemBicScore score = new SemBicScore(cov);
            score.setPenaltyDiscount(penaltyDiscount);

            System.out.println(new Date());
            System.out.println("\nStarting FGES");

            long timea = System.currentTimeMillis();
            Fges fges = new Fges(score);
            fges.setNumPatternsToStore(0);
            fges.setOut(System.out);
            fges.setFaithfulnessAssumed(faithfulness);
            fges.setCycleBound(-1);
            long timeb = System.currentTimeMillis();

            estPattern = fges.search();
            long timec = System.currentTimeMillis();

            out.println("Time for FGES constructor " + (timeb - timea) + " ms");
            // Search time starts after construction; the constructor time is reported separately above.
            out.println("Time for FGES search " + (timec - timeb) + " ms");
            out.println();
            out.flush();

            // Total of construction plus search, as recorded in the per-run statistics.
            elapsed = timec - timea;
        } else {
            BayesPm pm = new BayesPm(dag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            DataSet data = im.simulateData(numCases, false, tiers);

            System.out.println("Finishing simulation");
            long time2 = System.currentTimeMillis();
            out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

            long time3 = System.currentTimeMillis();
            BDeuScore score = new BDeuScore(data);
            score.setStructurePrior(1);
            score.setSamplePrior(1);

            System.out.println(new Date());
            System.out.println("\nStarting FGES");

            long timea = System.currentTimeMillis();
            Fges fges = new Fges(score);
            fges.setNumPatternsToStore(0);
            fges.setOut(System.out);
            fges.setFaithfulnessAssumed(faithfulness);
            fges.setCycleBound(-1);
            long timeb = System.currentTimeMillis();

            estPattern = fges.search();
            long timec = System.currentTimeMillis();

            out.println("Time consructing BDeu score " + (timea - time3) + " ms");
            out.println("Time for FGES constructor " + (timeb - timea) + " ms");
            // Search time starts after construction; the constructor time is reported separately above.
            out.println("Time for FGES search " + (timec - timeb) + " ms");
            out.println();

            // Total of construction plus search, as recorded in the per-run statistics.
            elapsed = timec - timea;
        }

        System.out.println("Done with FGES");
        System.out.println(new Date());

        System.out.println("Counting misclassifications.");

        // Re-express the estimate over the true pattern's node objects before comparing the graphs.
        estPattern = GraphUtils.replaceNodes(estPattern, pattern.getNodes());
        int[][] counts = GraphUtils.edgeMisclassificationCounts(pattern, estPattern, false);
        allCounts.add(counts);
        System.out.println(new Date());

        // NOTE(review): the fixed indices below depend on the table layout produced by
        // GraphUtils.edgeMisclassificationCounts, which is not visible here — confirm against it.
        int sumRow = counts[4][0] + counts[4][3] + counts[4][5];
        int sumCol = counts[0][3] + counts[4][3] + counts[5][3] + counts[7][3];
        int trueArrow = counts[4][3];

        int sumTrueAdjacencies = 0;
        for (int i = 0; i < 7; i++) {
            for (int j = 0; j < 5; j++) {
                sumTrueAdjacencies += counts[i][j];
            }
        }

        int falsePositiveAdjacencies = 0;
        for (int j = 0; j < 5; j++) {
            falsePositiveAdjacencies += counts[7][j];
        }

        int falseNegativeAdjacencies = 0;
        for (int i = 0; i < 5; i++) {
            falseNegativeAdjacencies += counts[i][5];
        }

        // Floating-point division: a zero denominator produces NaN/Infinity, not an exception.
        double[] comparison = new double[4];
        comparison[0] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falsePositiveAdjacencies);
        comparison[1] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falseNegativeAdjacencies);
        comparison[2] = trueArrow / (double) sumCol;
        comparison[3] = trueArrow / (double) sumRow;
        comparisons.add(comparison);

        out.println(GraphUtils.edgeMisclassifications(counts));
        out.println(precisionRecall(comparison));

        elapsedTimes.add(elapsed);
        out.println("\nElapsed: " + elapsed + " ms");
    }

    printAverageConfusion("Average", allCounts);
    printAveragePrecisionRecall(comparisons);
    printAverageStatistics(elapsedTimes, degrees);
    out.close();
}
use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class BayesPmWrapper method setBayesPm.
/**
 * Builds a {@code BayesPm} for the given graph and hands it to the single-argument
 * {@code setBayesPm} overload.
 *
 * @param graph      graph passed to the {@code BayesPm} constructor
 * @param lowerBound second constructor argument (presumably the minimum number of categories per
 *                   variable — confirm against {@code BayesPm})
 * @param upperBound third constructor argument (presumably the maximum number of categories per
 *                   variable — confirm against {@code BayesPm})
 */
private void setBayesPm(Graph graph, int lowerBound, int upperBound) {
    setBayesPm(new BayesPm(graph, lowerBound, upperBound));
}
use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class XdslXmlParser method buildIM.
/**
 * Builds a Bayes instantiated model from an XML element whose children are all {@code cpt} elements.
 *
 * <p>Three passes are made over the {@code cpt} children: the first adds one graph node per element,
 * the second adds a directed edge from every name listed in a {@code parents} child to the element's
 * own node, and the third reads the {@code state} children as the node's categories. Finally the
 * numbers in each {@code probabilities} child are written into the instantiated model row by row.
 *
 * <p>The value of each element's first attribute is used as its identifier. Name lists and
 * probability lists are split on any run of whitespace, so padding, line breaks and repeated spaces
 * are tolerated.
 *
 * @param element0     parent element holding the {@code cpt} elements
 * @param displayNames optional map from identifier to the node name to use; when {@code null} the
 *                     identifiers themselves are used as node names
 * @return the instantiated model populated from the {@code probabilities} elements
 * @throws IllegalArgumentException if a child is not a {@code cpt} element, if a probability token is
 *                                  not a number, or if a {@code probabilities} element holds fewer
 *                                  values than the node's table needs
 */
private BayesIm buildIM(Element element0, Map<String, String> displayNames) {
    Elements elements = element0.getChildElements();

    for (int i = 0; i < elements.size(); i++) {
        if (!"cpt".equals(elements.get(i).getQualifiedName())) {
            throw new IllegalArgumentException("Expecting cpt element.");
        }
    }

    Dag dag = new Dag();

    // Get the nodes.
    for (int i = 0; i < elements.size(); i++) {
        String id = elements.get(i).getAttribute(0).getValue();
        dag.addNode(new GraphNode(displayNames == null ? id : displayNames.get(id)));
    }

    // Get the edges.
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String childId = cpt.getAttribute(0).getValue();
        String childName = displayNames == null ? childId : displayNames.get(childId);
        Elements cptElements = cpt.getChildElements();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);

            if (!cptElement.getQualifiedName().equals("parents")) {
                continue;
            }

            // Split on whitespace runs and skip blanks so an empty or padded list yields no bogus
            // parent names (a plain split(" ") produced "" tokens and hence null parents).
            for (String parentId : cptElement.getValue().trim().split("\\s+")) {
                if (parentId.isEmpty()) {
                    continue;
                }

                String parentName = displayNames == null ? parentId : displayNames.get(parentId);
                edu.cmu.tetrad.graph.Node parent = dag.getNode(parentName);
                edu.cmu.tetrad.graph.Node child = dag.getNode(childName);
                dag.addDirectedEdge(parent, child);
            }
        }

        // NOTE(review): a node with this name was already added in the first pass; this second add
        // is kept as-is because its effect depends on Dag.addNode — confirm whether it is needed.
        dag.addNode(new GraphNode(childName));
    }

    // PM
    BayesPm pm = new BayesPm(dag);

    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String varName = cpt.getAttribute(0).getValue();

        // Fully qualified, like the declarations above, so it cannot be confused with another Node type.
        edu.cmu.tetrad.graph.Node node =
                dag.getNode(displayNames == null ? varName : displayNames.get(varName));

        Elements cptElements = cpt.getChildElements();
        List<String> stateNames = new ArrayList<>();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);

            if (cptElement.getQualifiedName().equals("state")) {
                stateNames.add(cptElement.getAttribute(0).getValue());
            }
        }

        pm.setCategories(node, stateNames);
    }

    // IM
    BayesIm im = new MlBayesIm(pm);

    // The position of a cpt element is used directly as the node index in the instantiated model.
    for (int nodeIndex = 0; nodeIndex < elements.size(); nodeIndex++) {
        Element cpt = elements.get(nodeIndex);
        Elements cptElements = cpt.getChildElements();

        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);

            if (!cptElement.getQualifiedName().equals("probabilities")) {
                continue;
            }

            List<Double> probs = new ArrayList<>();

            for (String token : cptElement.getValue().trim().split("\\s+")) {
                if (!token.isEmpty()) {
                    probs.add(Double.parseDouble(token));
                }
            }

            int numRows = im.getNumRows(nodeIndex);
            int numColumns = im.getNumColumns(nodeIndex);

            // Fail with a descriptive message instead of an index error part-way through the table.
            if (probs.size() < numRows * numColumns) {
                throw new IllegalArgumentException("Expecting " + (numRows * numColumns)
                        + " probabilities for cpt '" + cpt.getAttribute(0).getValue()
                        + "' but found " + probs.size() + ".");
            }

            // Values are consumed in row-major order; any surplus values are ignored.
            int next = 0;

            for (int row = 0; row < numRows; row++) {
                for (int col = 0; col < numColumns; col++) {
                    im.setProbability(nodeIndex, row, col, probs.get(next++));
                }
            }
        }
    }

    return im;
}
use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class FgesSearchEditor method reportIfDiscrete.
/**
 * Produces a short text report of fit statistics for a DAG against a discrete data set.
 *
 * <p>A {@code BayesPm} over a copy of the DAG is given the categories of the same-named data
 * variables, and the likelihood-ratio p-value, degrees of freedom, chi-square and BIC obtained from
 * {@code BayesProperties} are appended to the report.
 *
 * @param dag     the graph to evaluate
 * @param dataSet the data set; every variable is cast to {@code DiscreteVariable}
 * @return the report text, ending with the line {@code H0: Complete DAG.}
 * @throws ClassCastException if a variable of {@code dataSet} is not a {@code DiscreteVariable}
 */
private String reportIfDiscrete(Graph dag, DataSet dataSet) {
    List<Node> vars = dataSet.getVariables();

    // Index the data variables by name so graph nodes can be matched to them.
    Map<String, DiscreteVariable> nodesToVars = new HashMap<>();
    for (int i = 0; i < dataSet.getNumColumns(); i++) {
        DiscreteVariable var = (DiscreteVariable) vars.get(i);
        nodesToVars.put(var.getName(), var);
    }

    // NOTE(review): the parametric model built here is not read by the statistics below; it is kept
    // because constructing Dag/BayesPm may validate the graph — confirm whether it is still needed.
    BayesPm bayesPm = new BayesPm(new Dag(dag));

    for (Node node : bayesPm.getDag().getNodes()) {
        DiscreteVariable variable = nodesToVars.get(node.getName());

        // Graph nodes with no same-named data variable keep the model's existing categories.
        if (variable == null) {
            continue;
        }

        int numCategories = variable.getNumCategories();
        List<String> categories = new ArrayList<>();
        for (int j = 0; j < numCategories; j++) {
            categories.add(variable.getCategory(j));
        }

        bayesPm.setCategories(node, categories);
    }

    NumberFormat nf = NumberFormat.getInstance();
    nf.setMaximumFractionDigits(4);

    // getLikelihoodRatioP is called first, as in the original ordering; the other getters follow it.
    BayesProperties properties = new BayesProperties(dataSet);
    double p = properties.getLikelihoodRatioP(dag);
    double chisq = properties.getChisq();
    double bic = properties.getBic();
    double dof = properties.getDof();

    StringBuilder buf = new StringBuilder();
    buf.append("\nP = ").append(p);
    buf.append("\nDOF = ").append(dof);
    buf.append("\nChiSq = ").append(nf.format(chisq));
    buf.append("\nBIC = ").append(nf.format(bic));
    buf.append("\n\nH0: Complete DAG.");
    return buf.toString();
}
Aggregations