Example usage of dr.app.gui.tree.SquareTreePainter in the beast-mcmc project (beast-dev).
From the class BranchRatePlotter, method main.
/**
 * Summarizes per-branch rates from pairs of sampled time/mutation trees onto a target tree,
 * prints the annotated tree twice via {@code writeTree}, shows it in a window and offers to
 * print it.
 *
 * <p>Usage: {@code BranchRatePlotter <controlFile> <targetTreeFile> [burnin]}. Each non-blank
 * line of the control file names a time-tree NEXUS file and a mutation-tree NEXUS file,
 * separated by whitespace; the two files are read in lockstep. The first {@code burnin} tree
 * pairs of each control line are discarded.
 *
 * @param args control file, target tree file (NEXUS or Newick), optional burnin count
 * @throws java.io.IOException if a file cannot be read, the target file is empty, a control
 *     line is malformed, or a mutation-tree file has fewer trees than its time-tree file
 * @throws Importer.ImportException if a tree file cannot be parsed
 * @throws IllegalStateException if burnin discards every sampled tree
 */
public static void main(String[] args) throws java.io.IOException, Importer.ImportException {
    if (args.length < 2) {
        System.err.println("Usage: BranchRatePlotter <controlFile> <targetTreeFile> [burnin]");
        return;
    }
    String controlFile = args[0];
    String targetTreeFile = args[1];
    int burnin = args.length > 2 ? Integer.parseInt(args[2]) : 0;
    System.out.println("Ignoring first " + burnin + " trees as burnin.");

    MutableTree targetTree = readTargetTree(targetTreeFile);
    int totalTreesUsed = accumulateSampledTrees(targetTree, controlFile, burnin);
    if (totalTreesUsed == 0) {
        // Without this check every node would lack its "count" attribute (or the support
        // fraction below would divide by zero), producing a misleading error.
        throw new IllegalStateException(
                "No trees remained after discarding " + burnin + " burnin trees per control line");
    }

    // Collect the per-branch rates (every node except the root owns one branch).
    double mutations = 0.0;
    double time = 0.0;
    double[] rates = new double[targetTree.getNodeCount() - 1];
    int index = 0;
    for (int i = 0; i < targetTree.getNodeCount(); i++) {
        NodeRef node = targetTree.getNode(i);
        if (targetTree.isRoot(node)) {
            continue;
        }
        Number count = (Number) targetTree.getNodeAttribute(node, "count");
        if (count == null) {
            throw new RuntimeException("Count missing from node in target tree");
        }
        if (!targetTree.isExternal(node)) {
            // NOTE(review): assumes annotateRates increments "count" once per summarized tree
            // containing this clade, making prob the clade's posterior support — confirm.
            double prob = count.doubleValue() / totalTreesUsed;
            if (prob >= 0.5) {
                String label = "" + (Math.round(prob * 100) / 100.0);
                targetTree.setNodeAttribute(node, "label", label);
            }
        }
        Number totalMutations = (Number) targetTree.getNodeAttribute(node, "totalMutations");
        Number totalTime = (Number) targetTree.getNodeAttribute(node, "totalTime");
        if (totalMutations == null || totalTime == null) {
            throw new RuntimeException("Rate totals missing from node in target tree");
        }
        mutations += totalMutations.doubleValue();
        time += totalTime.doubleValue();
        rates[index] = totalMutations.doubleValue() / totalTime.doubleValue();
        System.out.println(totalMutations.doubleValue() + " / " + totalTime.doubleValue() + " = " + rates[index]);
        targetTree.setNodeRate(node, rates[index]);
        index += 1;
    }

    double minRate = DiscreteStatistics.min(rates);
    double maxRate = DiscreteStatistics.max(rates);
    double medianRate = DiscreteStatistics.median(rates);
    // Mean rate weighted by branch time (total mutations over total time), not the
    // unweighted mean of the per-branch rates.
    double meanRate = mutations / time;
    System.out.println(minRate + "\t" + maxRate + "\t" + medianRate + "\t" + meanRate);

    for (int i = 0; i < targetTree.getNodeCount(); i++) {
        NodeRef node = targetTree.getNode(i);
        if (!targetTree.isRoot(node)) {
            styleBranch(targetTree, node, maxRate, meanRate);
        }
        // Heights are summarized for every node, including the root.
        summarizeNodeHeights(targetTree, node);
    }

    StringBuffer buffer = new StringBuffer();
    writeTree(targetTree, targetTree.getRoot(), buffer, true, false);
    buffer.append(";\n");
    writeTree(targetTree, targetTree.getRoot(), buffer, false, true);
    buffer.append(";\n");
    System.out.println(buffer.toString());

    showAndPrint(targetTree);
}

/**
 * Reads the first tree from a NEXUS or Newick file and rotates it by node density.
 * The format is chosen by whether the first line starts with "#NEXUS" (case-insensitive).
 */
private static MutableTree readTargetTree(String targetTreeFile)
        throws java.io.IOException, Importer.ImportException {
    String firstLine;
    try (BufferedReader probe = new BufferedReader(new FileReader(targetTreeFile))) {
        firstLine = probe.readLine();
    }
    if (firstLine == null) {
        throw new java.io.IOException("Target tree file is empty: " + targetTreeFile);
    }
    boolean isNexus = firstLine.regionMatches(true, 0, "#NEXUS", 0, "#NEXUS".length());

    MutableTree targetTree;
    try (FileReader treeReader = new FileReader(targetTreeFile)) {
        TreeImporter importer;
        if (isNexus) {
            importer = new NexusImporter(treeReader);
        } else {
            importer = new NewickImporter(treeReader);
        }
        // Import while the reader is still open; the copy does not need the file afterwards.
        targetTree = new FlexibleTree(importer.importNextTree());
    }
    return TreeUtils.rotateTreeByComparator(targetTree, TreeUtils.createNodeDensityComparator(targetTree));
}

/**
 * Passes every post-burnin (time tree, mutation tree) pair listed in the control file to
 * annotateRates, prints the read/used totals, and returns the number of pairs used.
 */
private static int accumulateSampledTrees(MutableTree targetTree, String controlFile, int burnin)
        throws java.io.IOException, Importer.ImportException {
    int totalTrees = 0;
    int totalTreesUsed = 0;
    try (BufferedReader control = new BufferedReader(new FileReader(controlFile))) {
        for (String line = control.readLine(); line != null; line = control.readLine()) {
            StringTokenizer tokens = new StringTokenizer(line);
            if (!tokens.hasMoreTokens()) {
                continue; // tolerate blank lines
            }
            if (tokens.countTokens() < 2) {
                throw new java.io.IOException(
                        "Control file line must name a time-tree file and a mutation-tree file: \"" + line + "\"");
            }
            String timeTreeFile = tokens.nextToken();
            String mutationTreeFile = tokens.nextToken();
            try (FileReader timeReader = new FileReader(timeTreeFile);
                 FileReader mutationReader = new FileReader(mutationTreeFile)) {
                NexusImporter timeImporter = new NexusImporter(timeReader);
                NexusImporter mutationImporter = new NexusImporter(mutationReader);
                int fileTotalTrees = 0;
                while (timeImporter.hasTree()) {
                    if (!mutationImporter.hasTree()) {
                        throw new java.io.IOException("Mutation-tree file " + mutationTreeFile
                                + " ran out of trees after " + fileTotalTrees
                                + " trees, but time-tree file " + timeTreeFile + " has more");
                    }
                    Tree timeTree = timeImporter.importNextTree();
                    Tree mutationTree = mutationImporter.importNextTree();
                    // Burnin is applied per control line, not globally.
                    if (fileTotalTrees >= burnin) {
                        annotateRates(targetTree, targetTree.getRoot(), timeTree, mutationTree);
                        totalTreesUsed += 1;
                    }
                    totalTrees += 1;
                    fileTotalTrees += 1;
                }
            }
        }
    }
    System.out.println("Total trees read: " + totalTrees);
    System.out.println("Total trees summarized: " + totalTreesUsed);
    return totalTreesUsed;
}

/**
 * Sets the "color", "line" and "shape" display attributes of a non-root node's branch.
 * Branches faster than the mean rate are reddish, the rest bluish. The circle's area is
 * proportional to the branch rate relative to the maximum rate.
 */
private static void styleBranch(MutableTree tree, NodeRef node, double maxRate, double meanRate) {
    double rate = tree.getNodeRate(node);
    // Guard against a zero maximum so the shape size never becomes NaN.
    float relativeRate = maxRate > 0.0 ? (float) (rate / maxRate) : 0.0f;
    float radius = (float) Math.sqrt(relativeRate * 36.0);
    Color color = rate > meanRate ? new Color(1.0f, 0.5f, 0.5f) : new Color(0.5f, 0.5f, 1.0f);
    tree.setNodeAttribute(node, "color", color);
    tree.setNodeAttribute(node, "line", new BasicStroke(1.0f));
    tree.setNodeAttribute(node, "shape", new java.awt.geom.Ellipse2D.Double(0, 0, radius * 2.0, radius * 2.0));
}

/**
 * If the node carries a non-empty "heightList" of sampled heights, sets its height to their
 * mean and records the mean plus the 2.5% / 97.5% quantiles. The quantiles are stored under
 * the "hpdLower" / "hpdUpper" keys, although they form an equal-tailed interval, not an HPD.
 */
private static void summarizeNodeHeights(MutableTree tree, NodeRef node) {
    java.util.List<?> heightList = (java.util.List<?>) tree.getNodeAttribute(node, "heightList");
    if (heightList == null || heightList.isEmpty()) {
        return;
    }
    double[] heights = new double[heightList.size()];
    for (int j = 0; j < heights.length; j++) {
        heights[j] = ((Number) heightList.get(j)).doubleValue();
    }
    double meanHeight = DiscreteStatistics.mean(heights);
    tree.setNodeHeight(node, meanHeight);
    tree.setNodeAttribute(node, "nodeHeight.mean", meanHeight);
    tree.setNodeAttribute(node, "nodeHeight.hpdUpper", DiscreteStatistics.quantile(0.975, heights));
    tree.setNodeAttribute(node, "nodeHeight.hpdLower", DiscreteStatistics.quantile(0.025, heights));
}

/** Shows the tree in an 800x600 window and, if the user confirms the dialog, prints it. */
private static void showAndPrint(MutableTree targetTree) {
    SquareTreePainter treePainter = new SquareTreePainter();
    treePainter.setColorAttribute("color");
    treePainter.setLineAttribute("line");
    // The "shape" and "label" attributes are computed but not rendered; enable them with
    // treePainter.setShapeAttribute("shape") / treePainter.setLabelAttribute("label").
    JTreeDisplay treeDisplay = new JTreeDisplay(treePainter, targetTree);
    JTreePanel treePanel = new JTreePanel(treeDisplay);
    JFrame frame = new JFrame();
    frame.setSize(800, 600);
    frame.getContentPane().setLayout(new BorderLayout());
    frame.getContentPane().add(treePanel);
    frame.setVisible(true);

    PrinterJob printJob = PrinterJob.getPrinterJob();
    printJob.setPrintable(treeDisplay);
    if (printJob.printDialog()) {
        try {
            printJob.print();
        } catch (java.awt.print.PrinterException ex) {
            throw new RuntimeException(ex);
        }
    }
}
Aggregations