Search in sources:

Example 1 with Regression

use of dr.stats.Regression in project beast-mcmc by beast-dev.

the class TemporalRooting method findLocalRoot.

/**
 * Numerically optimises where the root sits on the branch joining the root's two children.
 * <p>
 * The combined length of the two root branches is kept fixed; a fraction in [0, 1] of it is
 * assigned to the first child and the remainder to the second. The fraction is chosen by
 * passing a one-dimensional objective to {@link UnivariateMinimum}, and the resulting
 * branch lengths are written back to {@code tree}.
 * <p>
 * {@link RootingFunction#RESIDUAL_MEAN_SQUARED} is not handled here; it is delegated to
 * {@code findAnalyticalLocalRoot}.
 *
 * @param tree              the tree whose two root branches are re-balanced (modified in place)
 * @param dates             tip dates, indexed by tip number, used as the regression x-values
 * @param rootingFunction   the criterion used to score a root position
 * @param forcePositiveRate if true, the score is negated whenever the regression gradient is negative
 * @return the objective value reported by the minimiser at the chosen root position
 */
private double findLocalRoot(final FlexibleTree tree, final double[] dates, final RootingFunction rootingFunction, final boolean forcePositiveRate) {
    if (rootingFunction == RootingFunction.RESIDUAL_MEAN_SQUARED) {
        return findAnalyticalLocalRoot(tree, dates, rootingFunction);
    }

    final NodeRef root = tree.getRoot();
    final NodeRef firstChild = tree.getChild(root, 0);
    final NodeRef secondChild = tree.getChild(root, 1);
    final double firstLength = tree.getBranchLength(firstChild);
    final double secondLength = tree.getBranchLength(secondChild);
    final double totalLength = firstLength + secondLength;
    final Set<NodeRef> firstTips = TreeUtils.getExternalNodes(tree, firstChild);
    final Set<NodeRef> secondTips = TreeUtils.getExternalNodes(tree, secondChild);
    // Scratch buffer of root-to-tip distances, rewritten on every objective evaluation.
    final double[] distances = new double[tree.getExternalNodeCount()];

    final UnivariateFunction objective = new UnivariateFunction() {

        public double evaluate(final double argument) {
            // Replace each side's current root-branch length by its trial length.
            final double trialFirst = argument * totalLength;
            for (NodeRef tip : firstTips) {
                distances[tip.getNumber()] = getRootToTipDistance(tree, tip) - firstLength + trialFirst;
            }
            final double trialSecond = (1.0 - argument) * totalLength;
            for (NodeRef tip : secondTips) {
                distances[tip.getNumber()] = getRootToTipDistance(tree, tip) - secondLength + trialSecond;
            }

            // Without date information the only available criterion is the spread of the distances.
            if (contemporaneous) {
                return DiscreteStatistics.variance(distances);
            }

            final Regression regression = new Regression(dates, distances);
            final double rawScore;
            switch (rootingFunction) {
                case CORRELATION:
                    // negated so that minimising the score maximises the correlation
                    rawScore = -regression.getCorrelationCoefficient();
                    break;
                case R_SQUARED:
                    // negated so that minimising the score maximises R^2
                    rawScore = -regression.getRSquared();
                    break;
                case HEURISTIC_RESIDUAL_MEAN_SQUARED:
                case RESIDUAL_MEAN_SQUARED:
                    rawScore = regression.getResidualMeanSquared();
                    break;
                default:
                    throw new RuntimeException("Unknown enum value");
            }
            if (forcePositiveRate && regression.getGradient() < 0.0) {
                return -rawScore;
            }
            return rawScore;
        }

        public int getNumArguments() {
            return 1;
        }

        public double getLowerBound() {
            return 0;
        }

        public double getUpperBound() {
            return 1.0;
        }
    };

    final UnivariateMinimum minimiser = new UnivariateMinimum();
    final double bestFraction = minimiser.findMinimum(objective);
    final double bestScore = minimiser.fminx;

    // Commit the optimised split of the root branch to the tree.
    tree.setBranchLength(firstChild, bestFraction * totalLength);
    tree.setBranchLength(secondChild, (1.0 - bestFraction) * totalLength);
    return bestScore;
}
Also used : Regression(dr.stats.Regression)

Example 2 with Regression

use of dr.stats.Regression in project beast-mcmc by beast-dev.

the class TemporalRooting method findAnalyticalLocalRoot.

/**
 * Places the root on the branch joining the root's two children using a closed-form
 * expression, then writes the resulting branch lengths back to {@code tree}.
 * <p>
 * The combined length of the two root branches is preserved. A fraction {@code x}, clamped
 * to [0, 1], is assigned to the second child and {@code 1 - x} to the first child.
 *
 * @param tree            the tree whose two root branches are re-balanced (modified in place)
 * @param t               tip dates, indexed by tip number
 * @param rootingFunction must be {@link RootingFunction#RESIDUAL_MEAN_SQUARED}
 * @return the residual mean squared of regressing the updated root-to-tip distances on {@code t}
 * @throws UnsupportedOperationException if any other rooting function is requested
 */
private double findAnalyticalLocalRoot(final FlexibleTree tree, final double[] t, final RootingFunction rootingFunction) {
    if (rootingFunction != RootingFunction.RESIDUAL_MEAN_SQUARED) {
        throw new UnsupportedOperationException("Analytical local root solution only for residual mean squared");
    }

    final NodeRef root = tree.getRoot();
    final NodeRef firstChild = tree.getChild(root, 0);
    final NodeRef secondChild = tree.getChild(root, 1);
    final double firstLength = tree.getBranchLength(firstChild);
    final double secondLength = tree.getBranchLength(secondChild);
    final double totalLength = firstLength + secondLength;

    final Set<NodeRef> firstTips = TreeUtils.getExternalNodes(tree, firstChild);
    final Set<NodeRef> secondTips = TreeUtils.getExternalNodes(tree, secondChild);
    final int tipCount = firstTips.size() + secondTips.size();
    // Counts held as doubles so the divisions below are floating-point.
    final double tipTotal = tipCount;
    final double secondTotal = secondTips.size();

    // Indicator vector: 1 for tips below the second child, 0 for tips below the first.
    final double[] onSecondSide = new double[tipCount];
    for (NodeRef tip : secondTips) {
        onSecondSide[tip.getNumber()] = 1;
    }

    // Re-express the divergences as if the whole root branch belonged to the first child
    // (first-side tips gain totalLength - firstLength, second-side tips lose it).
    final double[] divergence = getRootToTipDistances(tree);
    final double shift = totalLength - firstLength;
    for (int j = 0; j < divergence.length; j++) {
        divergence[j] = divergence[j] + (1 - onSecondSide[j]) * shift - onSecondSide[j] * shift;
    }

    double sumTT = 0.0;
    double sumT = 0.0;
    double sumY = 0.0;
    double sumTY = 0.0;
    double sumTC = 0.0;
    for (int i = 0; i < tipCount; i++) {
        final double ti = t[i];
        sumTT += ti * ti;
        sumT += ti;
        sumY += divergence[i];
        sumTY += ti * divergence[i];
        sumTC += ti * onSecondSide[i];
    }

    final double meanY = sumY / tipTotal;
    final double meanT = sumT / tipTotal;
    final double spreadT = sumTT - (sumT * sumT / tipTotal);

    // Terms that do not depend on the tip index.
    final double sideBalance = (2 * secondTotal - tipTotal) / tipTotal;
    final double scaledSpread = spreadT * tipTotal;
    final double crossTC = tipTotal * sumTC - secondTotal * sumT;
    final double crossTY = (tipTotal * sumTY) - (sumT * sumY);

    double sumAB = 0.0;
    double sumAA = 0.0;
    for (int i = 0; i < tipCount; i++) {
        final double centredT = meanT - t[i];
        final double a = 2 * onSecondSide[i] - sideBalance + (2 * centredT / scaledSpread) * crossTC - 1;
        final double b = (divergence[i] - meanY) + (centredT / scaledSpread) * crossTY;
        sumAB += a * b;
        sumAA += a * a;
    }

    double x = -sumAB / (totalLength * sumAA);
    // Keep the root on the branch.
    x = Math.min(Math.max(x, 0.0), 1.0);

    tree.setBranchLength(firstChild, (1.0 - x) * totalLength);
    tree.setBranchLength(secondChild, x * totalLength);

    // Score the re-rooted tree using freshly computed root-to-tip distances.
    final Regression regression = new Regression(t, getRootToTipDistances(tree));
    return regression.getResidualMeanSquared();
}
Also used : Regression(dr.stats.Regression)

Example 3 with Regression

use of dr.stats.Regression in project beast-mcmc by beast-dev.

the class TempestPanel method selectMRCA.

/**
 * Updates the MRCA plot to reflect the currently selected points.
 * <p>
 * Does nothing when there is no MRCA plot. When nothing is selected the plot's selection is
 * cleared. Otherwise the common ancestor of the selected taxa is located in
 * {@code currentTree} (using its parent instead when the ancestor is itself a tip), its
 * root-to-tip distance is converted to a time with the current root-to-tip regression, and
 * both values are passed to the plot. The panel is repainted in either case.
 */
private void selectMRCA() {
    if (mrcaPlot == null) {
        return;
    }

    if (selectedPoints == null || selectedPoints.size() == 0) {
        mrcaPlot.clearSelection();
        repaint();
        return;
    }

    Set<String> selectedTaxa = new HashSet<String>();
    for (Integer i : selectedPoints) {
        selectedTaxa.add(tree.getTaxon(i).toString());
    }

    Regression r = temporalRooting.getRootToTipRegression(currentTree);
    NodeRef mrca = TreeUtils.getCommonAncestorNode(currentTree, selectedTaxa);

    // The node comes from currentTree, so it must be queried through currentTree as well;
    // 'tree' can be a different object when the best-fitting root is being displayed.
    if (currentTree.isExternal(mrca)) {
        // A single selected tip is its own common ancestor; use its parent instead.
        mrca = currentTree.getParent(mrca);
    }

    double mrcaDistance = temporalRooting.getRootToTipDistance(currentTree, mrca);
    double mrcaTime = r.getX(mrcaDistance);
    mrcaPlot.setSelectedPoints(selectedPoints, mrcaTime, mrcaDistance);
    repaint();
}
Also used : Regression(dr.stats.Regression)

Example 4 with Regression

use of dr.stats.Regression in project beast-mcmc by beast-dev.

the class TempestPanel method setupPanel.

/**
 * Rebuilds the charts, the tree view and the summary text from the current state of the panel.
 * <p>
 * When no tree is loaded the tree panel and the root-to-tip chart are cleared and the text
 * area shows "No trees loaded". Otherwise a new {@code TemporalRooting} is created for
 * {@code tree}, the tree to display is chosen (user root or best-fitting root), and the
 * plots are built in one of two modes depending on
 * {@code temporalRooting.isContemporaneous()}:
 * a density/scatter view of root-to-tip distances, or a root-to-tip regression view with
 * residual (and optionally node-density) plots. A one-line summary of statistics is written
 * to the text area in both modes.
 */
public void setupPanel() {
    StringBuilder sb = new StringBuilder();
    NumberFormatter nf = new NumberFormatter(6);
    if (tree != null) {
        temporalRooting = new TemporalRooting(tree);
        // Choose which tree to display: the best-fitting root if requested and available,
        // otherwise the tree as supplied.
        currentTree = this.tree;
        if (bestFittingRoot && bestFittingRootTree != null) {
            currentTree = bestFittingRootTree;
            sb.append("Best-fitting root");
        } else {
            sb.append("User root");
        }
        // Tab index 2 is disabled for contemporaneous tips; if it was showing, switch to tab 1 first.
        if (temporalRooting.isContemporaneous()) {
            if (tabbedPane.getSelectedIndex() == 2) {
                tabbedPane.setSelectedIndex(1);
            }
            tabbedPane.setEnabledAt(2, false);
        } else {
            tabbedPane.setEnabledAt(2, true);
        }
        RootedTree jtree = dr.evolution.tree.TreeUtils.asJeblTree(currentTree);
        // One colour per external node of the JEBL tree, read from the taxon's "!color" attribute.
        // NOTE(review): the plots below apply these colours by point index; this presumably relies on
        // jtree's external-node order matching the plotted data order - confirm.
        List<Color> colours = new ArrayList<Color>();
        for (Node tip : jtree.getExternalNodes()) {
            Taxon taxon = jtree.getTaxon(tip);
            colours.add((Color) taxon.getAttribute("!color"));
        }
        if (temporalRooting.isContemporaneous()) {
            // ---- Contemporaneous tips: distribution of root-to-tip distances ----
            double[] dv = temporalRooting.getRootToTipDistances(currentTree);
            List<Double> values = new ArrayList<Double>();
            for (double d : dv) {
                values.add(d);
            }
            rootToTipChart.removeAllPlots();
            NumericalDensityPlot dp = new NumericalDensityPlot(values, 20, null);
            dp.setLineColor(new Color(9, 70, 15));
            // Points are drawn at half the maximum of the density plot's y data, plus jitter.
            double yOffset = (Double) dp.getYData().getMax() / 2;
            List<Double> dummyValues = new ArrayList<Double>();
            for (int i = 0; i < values.size(); i++) {
                // add a random y offset to give some visual spread
                double y = MathUtils.nextGaussian() * ((Double) dp.getYData().getMax() * 0.05);
                dummyValues.add(yOffset + y);
            }
            rootToTipPlot = new ScatterPlot(values, dummyValues);
            rootToTipPlot.setColours(colours);
            rootToTipPlot.setMarkStyle(Plot.CIRCLE_MARK, 8, new BasicStroke(0.0F), new Color(44, 44, 44), new Color(129, 149, 149));
            rootToTipPlot.setHilightedMarkStyle(new BasicStroke(0.5F), new Color(44, 44, 44), UIManager.getColor("List.selectionBackground"));
            rootToTipPlot.addListener(new Plot.Adaptor() {

                // Clicking a mark selects that point in the plot.
                @Override
                public void markClicked(int index, double x, double y, boolean isShiftDown) {
                    rootToTipPlot.selectPoint(index, isShiftDown);
                }

                // Selection changes are forwarded to the panel.
                public void selectionChanged(final Set<Integer> selectedPoints) {
                    plotSelectionChanged(selectedPoints);
                }
            });
            rootToTipChart.addPlot(rootToTipPlot);
            rootToTipChart.addPlot(dp);
            rootToTipPanel.setXAxisTitle("root-to-tip divergence");
            rootToTipPanel.setYAxisTitle("proportion");
            // No residuals are plotted in this mode.
            residualChart.removeAllPlots();
            sb.append(", contemporaneous tips");
            sb.append(", mean root-tip distance: " + nf.format(DiscreteStatistics.mean(dv)));
            sb.append(", coefficient of variation: " + nf.format(DiscreteStatistics.stdev(dv) / DiscreteStatistics.mean(dv)));
            sb.append(", stdev: " + nf.format(DiscreteStatistics.stdev(dv)));
            sb.append(", variance: " + nf.format(DiscreteStatistics.variance(dv)));
            showMRCACheck.setVisible(false);
        } else {
            // ---- Dated tips: root-to-tip regression against time ----
            Regression r = temporalRooting.getRootToTipRegression(currentTree);
            double[] residuals = temporalRooting.getRootToTipResiduals(currentTree, r);
            // Attach each tip's residual to the matching JEBL node and record node -> tip index.
            pointMap.clear();
            for (int i = 0; i < currentTree.getExternalNodeCount(); i++) {
                NodeRef tip = currentTree.getExternalNode(i);
                Node node = jtree.getNode(Taxon.getTaxon(currentTree.getNodeTaxon(tip).getId()));
                node.setAttribute("residual", residuals[i]);
                pointMap.put(node, i);
            }
            rootToTipChart.removeAllPlots();
            if (showMRCACheck.isSelected()) {
                // Optional overlay built from each tip's parent: the parent's root-to-tip distance
                // and the time obtained for that distance from the regression.
                double[] dv = temporalRooting.getParentRootToTipDistances(currentTree);
                List<Double> parentDistances = new ArrayList<Double>();
                for (int i = 0; i < dv.length; i++) {
                    parentDistances.add(i, dv[i]);
                }
                List<Double> parentTimes = new ArrayList<Double>();
                for (int i = 0; i < parentDistances.size(); i++) {
                    parentTimes.add(i, r.getX(parentDistances.get(i)));
                }
                mrcaPlot = new ParentPlot(r.getXData(), r.getYData(), parentTimes, parentDistances);
                mrcaPlot.setLineColor(new Color(105, 202, 105));
                mrcaPlot.setLineStroke(new BasicStroke(0.5F));
                rootToTipChart.addPlot(mrcaPlot);
            }
            // Always executed: horizontal error bars built from the tip date precisions.
            if (true) {
                double[] datePrecisions = temporalRooting.getTipDatePrecisions(currentTree);
                Variate.D ed = new Variate.D();
                for (int i = 0; i < datePrecisions.length; i++) {
                    ed.add(datePrecisions[i]);
                }
                errorBarPlot = new ErrorBarPlot(ErrorBarPlot.Orientation.HORIZONTAL, r.getXData(), r.getYData(), ed);
                errorBarPlot.setLineColor(new Color(44, 44, 44));
                errorBarPlot.setLineStroke(new BasicStroke(1.0F));
                rootToTipChart.addPlot(errorBarPlot);
            }
            // Scatter of the regression's x/y data, followed by the fitted regression line.
            rootToTipPlot = new ScatterPlot(r.getXData(), r.getYData());
            rootToTipPlot.addListener(new Plot.Adaptor() {

                // Selection changes are forwarded to the panel.
                public void selectionChanged(final Set<Integer> selectedPoints) {
                    plotSelectionChanged(selectedPoints);
                }
            });
            rootToTipPlot.setColours(colours);
            rootToTipPlot.setMarkStyle(Plot.CIRCLE_MARK, 8, new BasicStroke(0.0F), new Color(44, 44, 44), new Color(129, 149, 149));
            rootToTipPlot.setHilightedMarkStyle(new BasicStroke(0.5F), new Color(44, 44, 44), UIManager.getColor("List.selectionBackground"));
            rootToTipChart.addPlot(rootToTipPlot);
            rootToTipChart.addPlot(new RegressionPlot(r));
            // Extend the x axis so that it spans from the regression's x-intercept to the largest x value.
            rootToTipChart.getXAxis().addRange(r.getXIntercept(), (Double) r.getXData().getMax());
            rootToTipPanel.setXAxisTitle("time");
            rootToTipPanel.setYAxisTitle("root-to-tip divergence");
            // Residual chart: density of the regression's y residuals plus a jittered scatter of the same values.
            residualChart.removeAllPlots();
            Variate.D values = (Variate.D) r.getYResidualData();
            NumericalDensityPlot dp = new NumericalDensityPlot(values, 20);
            dp.setLineColor(new Color(103, 128, 144));
            // Points are drawn at half the maximum of the density plot's y data, plus jitter.
            double yOffset = (Double) dp.getYData().getMax() / 2;
            Double[] dummyValues = new Double[values.getCount()];
            for (int i = 0; i < dummyValues.length; i++) {
                // add a random y offset to give some visual spread
                double y = MathUtils.nextGaussian() * ((Double) dp.getYData().getMax() * 0.05);
                dummyValues[i] = yOffset + y;
            }
            Variate.D yOffsetValues = new Variate.D(dummyValues);
            residualPlot = new ScatterPlot(values, yOffsetValues);
            residualPlot.addListener(new Plot.Adaptor() {

                // A click on a residual mark selects the point with the same index in the root-to-tip plot.
                @Override
                public void markClicked(int index, double x, double y, boolean isShiftDown) {
                    rootToTipPlot.selectPoint(index, isShiftDown);
                }

                // Selection changes are forwarded to the panel.
                @Override
                public void selectionChanged(final Set<Integer> selectedPoints) {
                    plotSelectionChanged(selectedPoints);
                }
            });
            residualPlot.setColours(colours);
            residualPlot.setMarkStyle(Plot.CIRCLE_MARK, 8, new BasicStroke(0.0F), new Color(44, 44, 44), new Color(129, 149, 149));
            residualPlot.setHilightedMarkStyle(new BasicStroke(0.5F), new Color(44, 44, 44), UIManager.getColor("List.selectionBackground"));
            residualChart.addPlot(residualPlot);
            residualChart.addPlot(dp);
            residualPanel.setXAxisTitle("residual");
            residualPanel.setYAxisTitle("proportion");
            // Optional node-density chart, controlled by the SHOW_NODE_DENSITY flag.
            if (SHOW_NODE_DENSITY) {
                Regression r2 = temporalRooting.getNodeDensityRegression(currentTree);
                nodeDensityChart.removeAllPlots();
                nodeDensityPlot = new ScatterPlot(r2.getXData(), r2.getYData());
                nodeDensityPlot.addListener(new Plot.Adaptor() {

                    // Selection changes are forwarded to the panel.
                    public void selectionChanged(final Set<Integer> selectedPoints) {
                        plotSelectionChanged(selectedPoints);
                    }
                });
                nodeDensityPlot.setColours(colours);
                nodeDensityPlot.setMarkStyle(Plot.CIRCLE_MARK, 8, new BasicStroke(0.0F), new Color(44, 44, 44), new Color(129, 149, 149));
                nodeDensityPlot.setHilightedMarkStyle(new BasicStroke(0.5F), new Color(44, 44, 44), UIManager.getColor("List.selectionBackground"));
                nodeDensityChart.addPlot(nodeDensityPlot);
                nodeDensityChart.addPlot(new RegressionPlot(r2));
                nodeDensityChart.getXAxis().addRange(r2.getXIntercept(), (Double) r2.getXData().getMax());
                nodeDensityPanel.setXAxisTitle("time");
                nodeDensityPanel.setYAxisTitle("node density");
            }
            sb.append(", dated tips");
            sb.append(", date range: " + nf.format(temporalRooting.getDateRange()));
            sb.append(", slope (rate): " + nf.format(r.getGradient()));
            sb.append(", x-intercept (TMRCA): " + nf.format(r.getXIntercept()));
            sb.append(", corr. coeff: " + nf.format(r.getCorrelationCoefficient()));
            sb.append(", R^2: " + nf.format(r.getRSquared()));
            showMRCACheck.setVisible(true);
        }
        treePanel.setTree(jtree);
        // NOTE(review): the "residual" attribute is only set on nodes in the dated-tips branch above;
        // confirm that colouring by it is intended for contemporaneous tips as well.
        treePanel.setColourBy("residual");
    } else {
        // No tree loaded: clear the tree view and the root-to-tip chart.
        treePanel.setTree(null);
        rootToTipChart.removeAllPlots();
        sb.append("No trees loaded");
    }
    textArea.setText(sb.toString());
    statisticsModel.fireTableStructureChanged();
    repaint();
}
Also used : Variate(dr.stats.Variate) Node(jebl.evolution.graphs.Node) Taxon(jebl.evolution.taxa.Taxon) Regression(dr.stats.Regression) RootedTree(jebl.evolution.trees.RootedTree) NumberFormatter(dr.util.NumberFormatter)

Example 5 with Regression

use of dr.stats.Regression in project beast-mcmc by beast-dev.

the class TemporalRooting method findAnalyticalLocalRoot.

/**
 * Computes, in closed form, the split of the root branch between the root's two children
 * and applies it to {@code tree}.
 * <p>
 * The sum of the two root branch lengths is left unchanged. The solved fraction {@code x}
 * is clamped to [0, 1]; the second child receives {@code x} of the total and the first
 * child the remaining {@code 1 - x}.
 *
 * @param tree            the tree to adjust (its two root branch lengths are overwritten)
 * @param t               tip dates, indexed by tip number
 * @param rootingFunction must be {@link RootingFunction#RESIDUAL_MEAN_SQUARED}
 * @return the residual mean squared of the regression of the updated root-to-tip distances on {@code t}
 * @throws UnsupportedOperationException if {@code rootingFunction} is anything else
 */
private double findAnalyticalLocalRoot(final FlexibleTree tree, final double[] t, final RootingFunction rootingFunction) {
    if (rootingFunction != RootingFunction.RESIDUAL_MEAN_SQUARED) {
        throw new UnsupportedOperationException("Analytical local root solution only for residual mean squared");
    }

    final NodeRef rootNode = tree.getRoot();
    final NodeRef childA = tree.getChild(rootNode, 0);
    final NodeRef childB = tree.getChild(rootNode, 1);
    final double lengthA = tree.getBranchLength(childA);
    final double lengthB = tree.getBranchLength(childB);
    final double rootBranch = lengthA + lengthB;

    final Set<NodeRef> tipsA = TreeUtils.getExternalNodes(tree, childA);
    final Set<NodeRef> tipsB = TreeUtils.getExternalNodes(tree, childB);
    final int total = tipsA.size() + tipsB.size();
    // Double-valued copies of the counts, so that the arithmetic below is floating-point.
    final double totalD = total;
    final double countB = tipsB.size();

    // belowB[i] is 1 when tip i descends from the second child, 0 otherwise.
    final double[] belowB = new double[total];
    for (NodeRef tip : tipsB) {
        belowB[tip.getNumber()] = 1;
    }

    // Shift the divergences to the configuration in which the first child carries the whole
    // root branch: tips under the first child gain the offset, tips under the second lose it.
    final double[] dist = getRootToTipDistances(tree);
    final double offset = rootBranch - lengthA;
    for (int j = 0; j < dist.length; j++) {
        dist[j] = dist[j] + (1 - belowB[j]) * offset - belowB[j] * offset;
    }

    double accTT = 0.0;
    double accT = 0.0;
    double accY = 0.0;
    double accTY = 0.0;
    double accTC = 0.0;
    for (int i = 0; i < total; i++) {
        final double date = t[i];
        accTT += date * date;
        accT += date;
        accY += dist[i];
        accTY += date * dist[i];
        accTC += date * belowB[i];
    }

    final double avgY = accY / totalD;
    final double avgT = accT / totalD;
    final double dateSpread = accTT - (accT * accT / totalD);

    // Index-independent pieces of the per-tip coefficients.
    final double balance = (2 * countB - totalD) / totalD;
    final double denominator = dateSpread * totalD;
    final double dateSideTerm = totalD * accTC - countB * accT;
    final double dateDistTerm = (totalD * accTY) - (accT * accY);

    double accAB = 0.0;
    double accAA = 0.0;
    for (int i = 0; i < total; i++) {
        final double dateDev = avgT - t[i];
        final double coeffA = 2 * belowB[i] - balance + (2 * dateDev / denominator) * dateSideTerm - 1;
        final double coeffB = (dist[i] - avgY) + (dateDev / denominator) * dateDistTerm;
        accAB += coeffA * coeffB;
        accAA += coeffA * coeffA;
    }

    double x = -accAB / (rootBranch * accAA);
    // Constrain the solution to lie on the root branch.
    x = Math.min(Math.max(x, 0.0), 1.0);

    tree.setBranchLength(childA, (1.0 - x) * rootBranch);
    tree.setBranchLength(childB, x * rootBranch);

    // Distances are recomputed from the modified tree before scoring.
    final Regression fit = new Regression(t, getRootToTipDistances(tree));
    return fit.getResidualMeanSquared();
}
Also used : Regression(dr.stats.Regression)

Aggregations

Regression (dr.stats.Regression): 17, NodeRef (dr.evolution.tree.NodeRef): 3, NexusExporter (dr.app.tools.NexusExporter): 2, FlexibleTree (dr.evolution.tree.FlexibleTree): 2, Variate (dr.stats.Variate): 2, NumberFormatter (dr.util.NumberFormatter): 2, PrintWriter (java.io.PrintWriter): 2, Node (jebl.evolution.graphs.Node): 2, Taxon (jebl.evolution.taxa.Taxon): 2, RootedTree (jebl.evolution.trees.RootedTree): 2, MultivariateTraitTree (dr.evolution.tree.MultivariateTraitTree): 1, TreeUtils (dr.evolution.tree.TreeUtils): 1, BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel): 1, SphericalPolarCoordinates (dr.geo.math.SphericalPolarCoordinates): 1