Search in sources:

Example 1 with JTreePanel

Use of dr.app.gui.tree.JTreePanel in the project beast-mcmc by beast-dev.

Class: BranchRatePlotter, method: main.

/**
 * Summarizes per-branch rates from paired time/mutation tree samples onto a target tree.
 *
 * <p>It prints summary statistics and two renderings of the annotated tree to stdout. It then
 * shows the tree in a window and offers a print dialog.
 *
 * <p>Usage: {@code BranchRatePlotter <controlFile> <targetTreeFile> [burnin]}
 *
 * <ul>
 *   <li>Each non-blank line of {@code controlFile} names two NEXUS tree files, separated by
 *       whitespace. The first is read as the time tree and the second as the mutation tree; the
 *       two files are read in lockstep.
 *   <li>Within each pair of files, the first {@code burnin} trees are skipped.
 *   <li>The target tree file is read as NEXUS if its first line starts with {@code #NEXUS}
 *       (case-insensitive), otherwise as Newick.
 * </ul>
 *
 * @param args {@code controlFile targetTreeFile [burnin]}; a usage message is printed to stderr
 *     and the method returns if fewer than two arguments are given
 * @throws java.io.IOException if a file cannot be read, the target tree file is empty, a control
 *     line does not name two files, or a mutation-tree file runs out of trees before its paired
 *     time-tree file
 * @throws Importer.ImportException if a tree cannot be parsed
 */
public static void main(String[] args) throws java.io.IOException, Importer.ImportException {
    if (args.length < 2) {
        System.err.println("Usage: BranchRatePlotter <controlFile> <targetTreeFile> [burnin]");
        return;
    }
    String controlFile = args[0];
    String targetTreeFile = args[1];
    int burnin = 0;
    if (args.length > 2) {
        burnin = Integer.parseInt(args[2]);
    }
    System.out.println("Ignoring first " + burnin + " trees as burnin.");

    // Peek at the first line only to choose between the NEXUS and Newick importers.
    String lineTarget;
    try (BufferedReader readerTarget = new BufferedReader(new FileReader(targetTreeFile))) {
        lineTarget = readerTarget.readLine();
    }
    if (lineTarget == null) {
        throw new java.io.IOException("Target tree file is empty: " + targetTreeFile);
    }
    MutableTree targetTree;
    try (FileReader targetReader = new FileReader(targetTreeFile)) {
        TreeImporter targetImporter;
        // Locale.ROOT keeps the check locale-independent (e.g. Turkish dotted/dotless i).
        if (lineTarget.toUpperCase(java.util.Locale.ROOT).startsWith("#NEXUS")) {
            targetImporter = new NexusImporter(targetReader);
        } else {
            targetImporter = new NewickImporter(targetReader);
        }
        targetTree = new FlexibleTree(targetImporter.importNextTree());
    }
    targetTree = TreeUtils.rotateTreeByComparator(targetTree, TreeUtils.createNodeDensityComparator(targetTree));

    // Accumulate each post-burnin (time tree, mutation tree) pair onto the target tree.
    int totalTrees = 0;
    int totalTreesUsed = 0;
    try (BufferedReader reader = new BufferedReader(new FileReader(controlFile))) {
        String line;
        while ((line = reader.readLine()) != null) {
            StringTokenizer tokens = new StringTokenizer(line);
            if (!tokens.hasMoreTokens()) {
                // Blank line: nothing to read.
                continue;
            }
            String timeTreeFile = tokens.nextToken();
            if (!tokens.hasMoreTokens()) {
                throw new java.io.IOException(
                        "Control file line must name a time tree file and a mutation tree file: " + line);
            }
            String mutationTreeFile = tokens.nextToken();
            try (FileReader timeReader = new FileReader(timeTreeFile);
                 FileReader mutationReader = new FileReader(mutationTreeFile)) {
                NexusImporter importer1 = new NexusImporter(timeReader);
                NexusImporter importer2 = new NexusImporter(mutationReader);
                int fileTotalTrees = 0;
                while (importer1.hasTree()) {
                    Tree timeTree = importer1.importNextTree();
                    if (!importer2.hasTree()) {
                        throw new java.io.IOException("Mutation tree file " + mutationTreeFile
                                + " has fewer trees than time tree file " + timeTreeFile);
                    }
                    Tree mutationTree = importer2.importNextTree();
                    if (fileTotalTrees >= burnin) {
                        annotateRates(targetTree, targetTree.getRoot(), timeTree, mutationTree);
                        totalTreesUsed += 1;
                    }
                    totalTrees += 1;
                    fileTotalTrees += 1;
                }
            }
        }
    }
    System.out.println("Total trees read: " + totalTrees);
    System.out.println("Total trees summarized: " + totalTreesUsed);

    // Collect one rate (totalMutations / totalTime) per non-root branch.
    // The "count", "totalMutations" and "totalTime" attributes are presumably written by
    // annotateRates; TODO confirm.
    double mutations = 0.0;
    double time = 0.0;
    double[] rates = new double[targetTree.getNodeCount() - 1];
    int index = 0;
    for (int i = 0; i < targetTree.getNodeCount(); i++) {
        NodeRef node = targetTree.getNode(i);
        if (!targetTree.isRoot(node)) {
            Integer count = (Integer) targetTree.getNodeAttribute(node, "count");
            if (count == null) {
                throw new RuntimeException("Count missing from node in target tree");
            }
            if (!targetTree.isExternal(node)) {
                // Fraction of summarized trees counted for this node; labelled when >= 0.5.
                double prob = count / (double) totalTreesUsed;
                if (prob >= 0.5) {
                    String label = String.valueOf(Math.round(prob * 100) / 100.0);
                    targetTree.setNodeAttribute(node, "label", label);
                }
            }
            Number totalMutations = (Number) targetTree.getNodeAttribute(node, "totalMutations");
            Number totalTime = (Number) targetTree.getNodeAttribute(node, "totalTime");
            if (totalMutations == null || totalTime == null) {
                throw new RuntimeException("totalMutations or totalTime missing from node in target tree");
            }
            mutations += totalMutations.doubleValue();
            time += totalTime.doubleValue();
            rates[index] = totalMutations.doubleValue() / totalTime.doubleValue();
            System.out.println(totalMutations.doubleValue() + " / " + totalTime.doubleValue() + " = " + rates[index]);
            targetTree.setNodeRate(node, rates[index]);
            index += 1;
        }
    }
    double minRate = DiscreteStatistics.min(rates);
    double maxRate = DiscreteStatistics.max(rates);
    double medianRate = DiscreteStatistics.median(rates);
    // Weighted mean over all branches: total mutations / total time.
    double meanRate = mutations / time;
    System.out.println(minRate + "\t" + maxRate + "\t" + medianRate + "\t" + meanRate);

    // Attach display attributes to branches and summarize sampled node heights.
    for (int i = 0; i < targetTree.getNodeCount(); i++) {
        NodeRef node = targetTree.getNode(i);
        if (!targetTree.isRoot(node)) {
            double rate = targetTree.getNodeRate(node);
            float relativeRate = (float) (rate / maxRate);
            // The circle's area scales with the rate relative to the maximum.
            float radius = (float) Math.sqrt(relativeRate * 36.0);
            if (rate > meanRate) {
                targetTree.setNodeAttribute(node, "color", new Color(1.0f, 0.5f, 0.5f));
            } else {
                targetTree.setNodeAttribute(node, "color", new Color(0.5f, 0.5f, 1.0f));
            }
            targetTree.setNodeAttribute(node, "line", new BasicStroke(1.0f));
            targetTree.setNodeAttribute(node, "shape", new java.awt.geom.Ellipse2D.Double(0, 0, radius * 2.0, radius * 2.0));
        }
        java.util.List heightList = (java.util.List) targetTree.getNodeAttribute(node, "heightList");
        if (heightList != null) {
            double[] heights = new double[heightList.size()];
            for (int j = 0; j < heights.length; j++) {
                heights[j] = ((Number) heightList.get(j)).doubleValue();
            }
            double meanHeight = DiscreteStatistics.mean(heights);
            targetTree.setNodeHeight(node, meanHeight);
            targetTree.setNodeAttribute(node, "nodeHeight.mean", meanHeight);
            // NOTE: these are equal-tailed 2.5%/97.5% quantiles, not a true HPD interval.
            targetTree.setNodeAttribute(node, "nodeHeight.hpdUpper", DiscreteStatistics.quantile(0.975, heights));
            targetTree.setNodeAttribute(node, "nodeHeight.hpdLower", DiscreteStatistics.quantile(0.025, heights));
        }
    }

    StringBuffer buffer = new StringBuffer();
    writeTree(targetTree, targetTree.getRoot(), buffer, true, false);
    buffer.append(";\n");
    writeTree(targetTree, targetTree.getRoot(), buffer, false, true);
    buffer.append(";\n");
    System.out.println(buffer.toString());

    // Only "color" and "line" are rendered; the "shape" and "label" attributes set above are
    // not wired into the painter.
    SquareTreePainter treePainter = new SquareTreePainter();
    treePainter.setColorAttribute("color");
    treePainter.setLineAttribute("line");
    JTreeDisplay treeDisplay = new JTreeDisplay(treePainter, targetTree);
    JTreePanel treePanel = new JTreePanel(treeDisplay);
    // NOTE(review): Swing components are created here on the main thread rather than the EDT.
    JFrame frame = new JFrame();
    frame.setSize(800, 600);
    frame.getContentPane().setLayout(new BorderLayout());
    frame.getContentPane().add(treePanel);
    frame.setVisible(true);
    PrinterJob printJob = PrinterJob.getPrinterJob();
    printJob.setPrintable(treeDisplay);
    if (printJob.printDialog()) {
        try {
            printJob.print();
        } catch (java.awt.print.PrinterException ex) {
            throw new RuntimeException(ex);
        }
    }
}
Also used: PrinterJob (java.awt.print.PrinterJob), NewickImporter (dr.evolution.io.NewickImporter), JTreeDisplay (dr.app.gui.tree.JTreeDisplay), FileReader (java.io.FileReader), ArrayList (java.util.ArrayList), NexusImporter (dr.evolution.io.NexusImporter), JTreePanel (dr.app.gui.tree.JTreePanel), SquareTreePainter (dr.app.gui.tree.SquareTreePainter), StringTokenizer (java.util.StringTokenizer), BufferedReader (java.io.BufferedReader), TreeImporter (dr.evolution.io.TreeImporter), java.awt (java.awt)

Aggregations

JTreeDisplay (dr.app.gui.tree.JTreeDisplay): 1, JTreePanel (dr.app.gui.tree.JTreePanel): 1, SquareTreePainter (dr.app.gui.tree.SquareTreePainter): 1, NewickImporter (dr.evolution.io.NewickImporter): 1, NexusImporter (dr.evolution.io.NexusImporter): 1, TreeImporter (dr.evolution.io.TreeImporter): 1, java.awt (java.awt): 1, PrinterJob (java.awt.print.PrinterJob): 1, BufferedReader (java.io.BufferedReader): 1, FileReader (java.io.FileReader): 1, ArrayList (java.util.ArrayList): 1, StringTokenizer (java.util.StringTokenizer): 1