Search in sources :

Example 11 with DataLikelihoodDelegate

Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.

From class CheckPointTreeModifier, method incorporateAdditionalTaxa.

/**
 * Add the remaining taxa, which can be identified through the TreeDataLikelihood XML elements.
 * <p>
 * The taxa to add are the external nodes of {@code treeModel} whose taxon id appears in
 * {@code newTaxaNames}; every other tip is treated as already placed. The site patterns of every
 * {@code TreeDataLikelihood} in {@code Likelihood.CONNECTED_LIKELIHOOD_SET} (BEAGLE and multi-partition
 * delegates only) are merged into one {@code Patterns} object and handed to {@code choice}, which is then
 * used for closest-taxon and distance queries. When {@code NEW_APPROACH} is set, each new taxon is inserted
 * one at a time on a branch at or above its closest already-placed taxon, and the trait values are
 * interpolated after every insertion. When {@code NEW_APPROACH} is false the patterns are still set on
 * {@code choice}, but no insertion is performed.
 *
 * @param choice      strategy that receives the merged patterns and answers closest-taxon / distance queries
 * @param rateModel   not read by this body (it only appears in the commented-out {@code timeForDistance}
 *                    line); the branch rate is obtained via {@code getBranchRate(traitModels.get(0), ...)}
 * @param traitModels trait models; the first is used to obtain branch rates, and the whole list is passed to
 *                    {@code interpolateTraitValuesOneInsertion} after each insertion
 * @return the external nodes of the new taxa, ordered by their position in {@code newTaxaNames}
 * @throws RuntimeException if no pattern lists are found, or if no branch is long enough to host an insertion
 */
public ArrayList<NodeRef> incorporateAdditionalTaxa(CheckPointUpdaterApp.UpdateChoice choice, BranchRates rateModel, ArrayList<TreeParameterModel> traitModels) {
    // public ArrayList<NodeRef> incorporateAdditionalTaxa(CheckPointUpdaterApp.UpdateChoice choice, BranchRates rateModel) {
    System.out.println("Tree before adding taxa:\n" + treeModel.toString() + "\n");
    // collect the external nodes whose taxon id is listed in newTaxaNames (the taxa still to be inserted)
    ArrayList<NodeRef> newTaxaNodes = new ArrayList<NodeRef>();
    for (String str : newTaxaNames) {
        for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
            if (treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId().equals(str)) {
                newTaxaNodes.add(treeModel.getExternalNode(i));
                // always take into account Taxon dates vs. dates set through a TreeModel
                System.out.println(treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId() + " with height " + treeModel.getNodeHeight(treeModel.getExternalNode(i)) + " or " + treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getHeight());
            }
        }
    }
    System.out.println("newTaxaNodes length = " + newTaxaNodes.size());
    // every tip NOT listed in newTaxaNames is already placed; these are the candidates for "closest taxon"
    ArrayList<Taxon> currentTaxa = new ArrayList<Taxon>();
    for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
        boolean taxonFound = false;
        for (String str : newTaxaNames) {
            if (str.equals((treeModel.getNodeTaxon(treeModel.getExternalNode(i))).getId())) {
                taxonFound = true;
            }
        }
        if (!taxonFound) {
            System.out.println("Adding " + treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId() + " to list of current taxa");
            currentTaxa.add(treeModel.getNodeTaxon(treeModel.getExternalNode(i)));
        }
    }
    System.out.println("Current taxa count = " + currentTaxa.size());
    // iterate over both current taxa and to be added taxa
    // originTaxon stays true only if no taxon (current or new) has a Taxon height of exactly 0.0.
    // When it is false, a non-positive sampling-date offset below causes the new taxon's height to be set to 0.0.
    boolean originTaxon = true;
    for (Taxon taxon : currentTaxa) {
        if (taxon.getHeight() == 0.0) {
            originTaxon = false;
            System.out.println("Current taxon " + taxon.getId() + " has node height 0.0");
        }
    }
    for (NodeRef newTaxon : newTaxaNodes) {
        if (treeModel.getNodeTaxon(newTaxon).getHeight() == 0.0) {
            originTaxon = false;
            System.out.println("New taxon " + treeModel.getNodeTaxon(newTaxon).getId() + " has node height 0.0");
        }
    }
    // check the Tree(Data)Likelihoods in the connected set of likelihoods
    // focus on TreeDataLikelihood, which has getTree() to get the tree for each likelihood
    // also get the DataLikelihoodDelegate from TreeDataLikelihood
    // NOTE(review): `likelihoods` and `trees` are filled here but never read later in this method.
    ArrayList<TreeDataLikelihood> likelihoods = new ArrayList<TreeDataLikelihood>();
    ArrayList<Tree> trees = new ArrayList<Tree>();
    ArrayList<DataLikelihoodDelegate> delegates = new ArrayList<DataLikelihoodDelegate>();
    for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
        if (likelihood instanceof TreeDataLikelihood) {
            likelihoods.add((TreeDataLikelihood) likelihood);
            trees.add(((TreeDataLikelihood) likelihood).getTree());
            delegates.add(((TreeDataLikelihood) likelihood).getDataLikelihoodDelegate());
        }
    }
    // suggested to go through TreeDataLikelihoodParser and give it an extra option to create a HashMap
    // keyed by the tree; am currently not overly fond of this approach
    // Only BEAGLE and multi-partition delegates contribute pattern lists; other delegate types are skipped.
    ArrayList<PatternList> patternLists = new ArrayList<PatternList>();
    for (DataLikelihoodDelegate del : delegates) {
        if (del instanceof BeagleDataLikelihoodDelegate) {
            patternLists.add(((BeagleDataLikelihoodDelegate) del).getPatternList());
        } else if (del instanceof MultiPartitionDataLikelihoodDelegate) {
            MultiPartitionDataLikelihoodDelegate mpdld = (MultiPartitionDataLikelihoodDelegate) del;
            List<PatternList> list = mpdld.getPatternLists();
            for (PatternList pList : list) {
                patternLists.add(pList);
            }
        }
    }
    if (patternLists.size() == 0) {
        throw new RuntimeException("No patterns detected. Please make sure the XML file is BEAST 1.9 compatible.");
    }
    // aggregate all patterns to create distance matrix
    // TODO What about different trees for different partitions?
    // the first pattern list seeds the Patterns object; every further list is appended to it
    Patterns patterns = new Patterns(patternLists.get(0));
    if (patternLists.size() > 1) {
        for (int i = 1; i < patternLists.size(); i++) {
            patterns.addPatterns(patternLists.get(i));
        }
    }
    // set the patterns for the distance matrix computations
    choice.setPatterns(patterns);
    // add new taxa one at a time
    System.out.println("Adding " + newTaxaNodes.size() + " taxa ...");
    // NEW_APPROACH is a flag defined outside this method; when it is false, no taxa are inserted here
    if (NEW_APPROACH) {
        System.out.println("Branch rates are being updated after each new sequence is inserted");
        // number of tips already placed; passed to getBranchRate and incremented after each insertion
        int numTaxaSoFar = treeModel.getExternalNodeCount() - newTaxaNodes.size();
        for (NodeRef newTaxon : newTaxaNodes) {
            // check for zero-length and negative length branches and internal nodes that don't have 2 child nodes
            // (walks from every tip to the root; the first problem found terminates the JVM)
            // NOTE(review): System.exit(0) reports a successful exit status even though a malformed tree was found.
            if (DEBUG) {
                for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
                    NodeRef startingTip = treeModel.getExternalNode(i);
                    while (treeModel.getParent(startingTip) != null) {
                        if (treeModel.getChildCount(treeModel.getParent(startingTip)) != 2) {
                            System.out.println(treeModel.getChildCount(treeModel.getParent(startingTip)) + " children for node " + treeModel.getParent(startingTip));
                            System.out.println("Exiting ...");
                            System.exit(0);
                        }
                        double branchLength = treeModel.getNodeHeight(treeModel.getParent(startingTip)) - treeModel.getNodeHeight(startingTip);
                        if (branchLength == 0.0) {
                            System.out.println("Zero-length branch detected:");
                            System.out.println("  parent node: " + treeModel.getParent(startingTip));
                            System.out.println("  child node: " + startingTip);
                            System.out.println("Exiting ...");
                            System.exit(0);
                        } else if (branchLength < 0.0) {
                            System.out.println("Negative branch length detected:");
                            System.out.println("  parent node: " + treeModel.getParent(startingTip));
                            System.out.println("  child node: " + startingTip);
                            System.out.println("Exiting ...");
                            System.exit(0);
                        } else {
                            startingTip = treeModel.getParent(startingTip);
                        }
                    }
                }
            }
            // start the new tip at the height recorded on its Taxon
            treeModel.setNodeHeight(newTaxon, treeModel.getNodeTaxon(newTaxon).getHeight());
            System.out.println("\nadding Taxon: " + newTaxon + " (height = " + treeModel.getNodeHeight(newTaxon) + ")");
            // check if this taxon has a more recent sampling date than all other nodes in the current TreeModel
            // (checkCurrentTreeNodes is defined elsewhere; a negative offset presumably means the new tip is
            // younger than every current node -- TODO confirm against its implementation)
            double offset = checkCurrentTreeNodes(newTaxon, treeModel.getRoot());
            System.out.println("Sampling date offset when adding " + newTaxon + " = " + offset);
            // AND set its current node height to 0.0 IF no originTaxon has been found
            // for a negative offset, all existing node heights are first shifted by |offset| via updateAllTreeNodes
            if (offset < 0.0) {
                if (!originTaxon) {
                    System.out.println("Updating all node heights with offset " + Math.abs(offset));
                    updateAllTreeNodes(Math.abs(offset), treeModel.getRoot());
                    treeModel.setNodeHeight(newTaxon, 0.0);
                }
            } else if (offset == 0.0) {
                if (!originTaxon) {
                    treeModel.setNodeHeight(newTaxon, 0.0);
                }
            }
            // get the closest Taxon to the Taxon that needs to be added
            // take into account which taxa can currently be chosen
            // NOTE(review): `closest` is dereferenced on the next statement without a null check.
            Taxon closest = choice.getClosestTaxon(treeModel.getNodeTaxon(newTaxon), currentTaxa);
            System.out.println("\nclosest Taxon: " + closest + " with original height: " + closest.getHeight());
            // get the distance between these two taxa
            double distance = choice.getDistance(treeModel.getNodeTaxon(newTaxon), closest);
            if (distance == 0.0) {
                // employ minimum insertion distance but add in a random factor to avoid identical insertion heights
                // this to avoid multifurcations in the case of (many) identical sequences
                distance = MIN_DIST * MathUtils.nextDouble();
                System.out.println("Sequences are identical, setting minimum distance to " + distance);
            }
            System.out.println("at distance: " + distance);
            // find the NodeRef for the closest Taxon (do not rely on node numbering)
            // the match is by object identity (==) on the Taxon instance
            // NOTE(review): if no tip matches, closestRef stays null and the getNodeHeight/getParent calls below fail.
            NodeRef closestRef = null;
            // careful with node numbering and subtract number of new taxa
            for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
                if (treeModel.getNodeTaxon(treeModel.getExternalNode(i)) == closest) {
                    closestRef = treeModel.getExternalNode(i);
                    System.out.println("  closest external nodeRef: " + closestRef);
                }
            }
            System.out.println("closest node : " + closestRef + " with height " + treeModel.getNodeHeight(closestRef));
            System.out.println("parent node: " + treeModel.getParent(closestRef));
            // begin change
            // TODO: only for Sam, revert back to the line below !!!
            // double timeForDistance = distance / rateModel.getBranchRate(treeModel, closestRef);
            // converts the genetic distance into time using the branch rate above closestRef
            double timeForDistance = distance / getBranchRate(traitModels.get(0), closestRef, numTaxaSoFar);
            // end change
            System.out.println("timeForDistance = " + timeForDistance);
            // get parent node of branch that will be split
            NodeRef parent = treeModel.getParent(closestRef);
            // child node of branch that will be split (will be assigned value later)
            NodeRef splitBranchChild;
            /*if ((treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(closestRef)) == 0.0) {
                    System.out.println("Zero-length branch:");
                    System.out.println(parent);
                    System.out.println(closestRef);
                    System.out.println("Exiting ...");
                    System.exit(0);
                }*/
            // determine height of new node
            double insertHeight;
            if (treeModel.getNodeHeight(closestRef) == treeModel.getNodeHeight(newTaxon)) {
                // if the sequences have the same sampling date/time, then simply split the distance/time between the two sequences in half
                // both the closest sequence and the new sequence are equidistant from the newly inserted internal node
                insertHeight = treeModel.getNodeHeight(closestRef) + timeForDistance / 2.0;
                System.out.println("equal sampling times (" + treeModel.getNodeHeight(closestRef) + ") ; insertHeight = " + insertHeight);
                splitBranchChild = closestRef;
                // walk up until the insertion height lies below the parent; at the root, fall back to a point
                // an EPSILON fraction of the way up the current branch
                if (insertHeight >= treeModel.getNodeHeight(parent)) {
                    while (insertHeight >= treeModel.getNodeHeight(parent)) {
                        if (treeModel.getParent(parent) == null) {
                            // Use this insertHeight value in case parent doesn't have parent
                            // Otherwise, move up tree
                            insertHeight = treeModel.getNodeHeight(splitBranchChild) + EPSILON * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(splitBranchChild));
                            break;
                        } else {
                            splitBranchChild = parent;
                            parent = treeModel.getParent(splitBranchChild);
                        }
                    }
                }
                // now that a suitable branch has been found, perform additional checks
                // parent of branch = parent; child of branch = splitBranchChild
                // NOTE(review): the retry loops below only resample while the candidate is within MIN_DIST of BOTH
                // ends of the branch; a candidate close to just one end is accepted. Confirm whether `||` was intended.
                System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
                System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
                if (((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) && ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST)) {
                    throw new RuntimeException("No suitable branch found for sequence insertion (all branches < minimum branch length).");
                } else if ((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) {
                    System.out.println("  insertion height too close to parent height");
                    double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
                    while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
                        newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
                    }
                    System.out.println("  new insertion height = " + newInsertHeight);
                    insertHeight = newInsertHeight;
                } else if ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST) {
                    System.out.println("  insertion height too close to child height");
                    // NOTE(review): the sampling range here is (parent height - insertHeight) above the child, not the
                    // full branch length; confirm this is the intended range.
                    double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
                    while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
                        newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
                    }
                    System.out.println("  new insertion height = " + newInsertHeight);
                    insertHeight = newInsertHeight;
                }
            } else {
                // first calculate if the new internal node is older than both the new sequence and its closest sequence already present in the tree
                // remainder = half of the time left after bridging the sampling-time gap between the two tips
                double remainder = (timeForDistance - Math.abs(treeModel.getNodeHeight(closestRef) - treeModel.getNodeHeight(newTaxon))) / 2.0;
                if (remainder > 0) {
                    insertHeight = Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon)) + remainder;
                    System.out.println("remainder > 0 (" + remainder + "): " + insertHeight);
                    splitBranchChild = closestRef;
                    if (insertHeight >= treeModel.getNodeHeight(parent)) {
                        while (insertHeight >= treeModel.getNodeHeight(parent)) {
                            if (treeModel.getParent(parent) == null) {
                                // use this insertHeight value in case parent doesn't have parent (and we can't move up the tree)
                                insertHeight = treeModel.getNodeHeight(splitBranchChild) + EPSILON * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(splitBranchChild));
                                break;
                            } else {
                                // otherwise move up in the tree
                                splitBranchChild = parent;
                                parent = treeModel.getParent(splitBranchChild);
                            }
                        }
                    }
                    // now that a suitable branch has been found, perform additional checks
                    // parent of branch = parent; child of branch = splitBranchChild
                    // (same MIN_DIST adjustment as in the equal-sampling-time branch above)
                    System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
                    System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
                    if (((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) && ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST)) {
                        throw new RuntimeException("No suitable branch found for sequence insertion (all branches < minimum branch length).");
                    } else if ((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) {
                        System.out.println("  insertion height too close to parent height");
                        double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
                        while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
                            newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
                        }
                        System.out.println("  new insertion height = " + newInsertHeight);
                        insertHeight = newInsertHeight;
                    } else if ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST) {
                        System.out.println("  insertion height too close to child height");
                        double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
                        while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
                            newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
                        }
                        System.out.println("  new insertion height = " + newInsertHeight);
                        insertHeight = newInsertHeight;
                    }
                } else {
                    // Come up with better way to handle this?
                    // place the node an EPSILON fraction of the way from the older of the two tips up to the parent
                    insertHeight = EPSILON * (treeModel.getNodeHeight(parent) - Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon)));
                    System.out.println("insertHeight after EPSILON: " + insertHeight);
                    insertHeight += Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon));
                    System.out.println("remainder <= 0: " + insertHeight);
                    double insertDifference = treeModel.getNodeHeight(parent) - insertHeight;
                    System.out.println("height difference = " + insertDifference);
                    splitBranchChild = closestRef;
                    if (insertDifference < MIN_DIST) {
                        System.out.println("branch too short ...");
                        boolean suitableBranch = false;
                        // 2. no negative branch length can be introduced so the parent node needs to be older than the new taxon and its closest reference
                        // NOTE(review): if `parent` becomes the root, treeModel.getParent(parent) is null and the
                        // getNodeHeight call below may fail; there is no explicit stop condition at the root.
                        while (!suitableBranch) {
                            double parentBranchLength = treeModel.getNodeHeight(treeModel.getParent(parent)) - treeModel.getNodeHeight(parent);
                            if ((parentBranchLength < MIN_DIST) || ((treeModel.getNodeHeight(parent) - Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon))) < 0.0)) {
                                // find another branch by moving upwards in the tree
                                splitBranchChild = parent;
                                parent = treeModel.getParent(splitBranchChild);
                            } else {
                                // node needs to be inserted along the branch from parent to splitBranchChild
                                // (assumes MathUtils.nextDouble() lies in [0, 1) -- TODO confirm)
                                double positiveRandom = Math.abs(MathUtils.nextDouble());
                                // insertion height needs to be higher than taxon being inserted
                                double minimumHeight = Math.max(Math.max(treeModel.getNodeHeight(splitBranchChild), treeModel.getNodeHeight(newTaxon)), treeModel.getNodeHeight(closestRef));
                                insertHeight = treeModel.getNodeHeight(parent) - (treeModel.getNodeHeight(parent) - minimumHeight) * positiveRandom;
                                suitableBranch = true;
                            }
                        }
                    }
                }
            }
            // diagnostic output of the chosen branch and insertion height
            System.out.println("insert at height: " + insertHeight);
            System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
            System.out.printf("height difference = %.25f\n", (insertHeight - treeModel.getNodeHeight(parent)));
            if (treeModel.getParent(parent) != null) {
                System.out.printf("grandparent height: %.25f \n", treeModel.getNodeHeight(treeModel.getParent(parent)));
            } else {
                System.out.println("parent == root");
            }
            System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
            System.out.printf("insert at height: %.25f \n", insertHeight);
            // pass on all the necessary variables to a method that adds the new taxon to the tree
            addTaxonAlongBranch(newTaxon, parent, splitBranchChild, insertHeight);
            // option to print tree after each taxon addition
            System.out.println("\nTree after adding taxon " + newTaxon + ":\n" + treeModel.toString());
            System.out.println(">>" + treeModel.toString());
            // add newly added Taxon to list of current taxa
            currentTaxa.add(treeModel.getNodeTaxon(newTaxon));
            // Update rate categories here
            interpolateTraitValuesOneInsertion(traitModels, newTaxon);
            numTaxaSoFar++;
        }
    }
    return newTaxaNodes;
}
Also used : Likelihood(dr.inference.model.Likelihood) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) Taxon(dr.evolution.util.Taxon) ArrayList(java.util.ArrayList) PatternList(dr.evolution.alignment.PatternList) BeagleDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate) NodeRef(dr.evolution.tree.NodeRef) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) BeagleDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate) MultiPartitionDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate) DataLikelihoodDelegate(dr.evomodel.treedatalikelihood.DataLikelihoodDelegate) Tree(dr.evolution.tree.Tree) PatternList(dr.evolution.alignment.PatternList) ArrayList(java.util.ArrayList) List(java.util.List) MultiPartitionDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate) Patterns(dr.evolution.alignment.Patterns)

Example 12 with DataLikelihoodDelegate

Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.

From class BranchRateGradientParser, method parseTreeDataLikelihood.

/**
 * Builds the branch-rate gradient provider that matches the likelihood's data delegate.
 * <p>
 * Only {@link DefaultBranchRateModel} and {@link ArbitraryBranchRates} are accepted. For arbitrary rates
 * the model's rate parameter is forwarded to the gradient; for the default model {@code null} is forwarded.
 *
 * @param treeDataLikelihood likelihood whose branch-rate model and data delegate are inspected
 * @param traitName          name of the trait the gradient is computed for
 * @param useHessian         forwarded to the discrete-trait gradient constructors
 * @return a {@link BranchRateGradient} for a continuous delegate; for a BEAGLE delegate, a
 *         {@link LocalBranchRateGradientForDiscreteTrait} when the rate model is a {@link LocalBranchRates},
 *         otherwise a {@link BranchRateGradientForDiscreteTrait}
 * @throws XMLParseException if the rate model or the delegate type is not supported
 */
private GradientWrtParameterProvider parseTreeDataLikelihood(TreeDataLikelihood treeDataLikelihood, String traitName, boolean useHessian) throws XMLParseException {
    final BranchRateModel rateModel = treeDataLikelihood.getBranchRateModel();
    final boolean arbitraryRates = rateModel instanceof ArbitraryBranchRates;

    // Guard: only default or arbitrary branch rates are supported.
    if (!arbitraryRates && !(rateModel instanceof DefaultBranchRateModel)) {
        throw new XMLParseException("Only implemented for an arbitrary rates model");
    }

    final Parameter rateParameter = arbitraryRates ? ((ArbitraryBranchRates) rateModel).getRateParameter() : null;
    final DataLikelihoodDelegate likelihoodDelegate = treeDataLikelihood.getDataLikelihoodDelegate();

    if (likelihoodDelegate instanceof ContinuousDataLikelihoodDelegate) {
        final ContinuousDataLikelihoodDelegate continuousDelegate = (ContinuousDataLikelihoodDelegate) likelihoodDelegate;
        return new BranchRateGradient(traitName, treeDataLikelihood, continuousDelegate, rateParameter);
    }

    // Guard: anything that is neither continuous nor BEAGLE-backed is unsupported.
    if (!(likelihoodDelegate instanceof BeagleDataLikelihoodDelegate)) {
        throw new XMLParseException("Unknown likelihood delegate type");
    }

    final BeagleDataLikelihoodDelegate beagleDelegate = (BeagleDataLikelihoodDelegate) likelihoodDelegate;
    if (rateModel instanceof LocalBranchRates) {
        return new LocalBranchRateGradientForDiscreteTrait(traitName, treeDataLikelihood, beagleDelegate, rateParameter, useHessian);
    }
    return new BranchRateGradientForDiscreteTrait(traitName, treeDataLikelihood, beagleDelegate, rateParameter, useHessian);
}
Also used : ArbitraryBranchRates(dr.evomodel.branchratemodel.ArbitraryBranchRates) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) DefaultBranchRateModel(dr.evomodel.branchratemodel.DefaultBranchRateModel) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate) BeagleDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate) DataLikelihoodDelegate(dr.evomodel.treedatalikelihood.DataLikelihoodDelegate) BranchRateGradient(dr.evomodel.treedatalikelihood.continuous.BranchRateGradient) BranchRateGradientForDiscreteTrait(dr.evomodel.treedatalikelihood.discrete.BranchRateGradientForDiscreteTrait) LocalBranchRateGradientForDiscreteTrait(dr.evomodel.treedatalikelihood.discrete.LocalBranchRateGradientForDiscreteTrait) LocalBranchRateGradientForDiscreteTrait(dr.evomodel.treedatalikelihood.discrete.LocalBranchRateGradientForDiscreteTrait) Parameter(dr.inference.model.Parameter) BeagleDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate) DefaultBranchRateModel(dr.evomodel.branchratemodel.DefaultBranchRateModel) LocalBranchRates(dr.evomodel.branchratemodel.LocalBranchRates)

Example 13 with DataLikelihoodDelegate

Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.

From class DiffusionGradientParser, method parseXMLObject.

/**
 * Parses a diffusion-parameters gradient from its child {@code AbstractDiffusionGradient} elements.
 * <p>
 * Each child gradient contributes its derivation parameter, its raw parameter (collected into one
 * {@link CompoundParameter}) and itself (collected into one {@link CompoundDerivative}). The tree,
 * likelihood and data delegate are taken from the FIRST child gradient; the children are assumed to
 * share the same likelihood (this is not checked here).
 *
 * @param xo the XML element being parsed
 * @return a {@link DiffusionParametersGradient} combining the branch-specific gradient and the child gradients
 * @throws XMLParseException if no diffusion gradient is given, if the first gradient's likelihood is not a
 *                           {@link TreeDataLikelihood}, or if its delegate is not a
 *                           {@link ContinuousDataLikelihoodDelegate}
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);

    List<AbstractDiffusionGradient> diffGradients = xo.getAllChildren(AbstractDiffusionGradient.class);
    // The likelihood, tree and delegate all come from the first gradient, so at least one is required.
    if (diffGradients == null || diffGradients.isEmpty()) {
        throw new XMLParseException("At least one diffusion gradient must be provided");
    }

    List<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter> derivationParametersList = new ArrayList<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter>();
    CompoundParameter compoundParameter = new CompoundParameter(null);
    List<GradientWrtParameterProvider> derivativeList = new ArrayList<GradientWrtParameterProvider>();
    for (AbstractDiffusionGradient grad : diffGradients) {
        derivationParametersList.add(grad.getDerivationParameter());
        compoundParameter.addParameter(grad.getRawParameter());
        derivativeList.add(grad);
    }
    CompoundGradient parametersGradients = new CompoundDerivative(derivativeList);

    // Validate the types up front so a misconfigured XML yields a parse error rather than a ClassCastException.
    if (!(diffGradients.get(0).getLikelihood() instanceof TreeDataLikelihood)) {
        throw new XMLParseException("Diffusion gradients require a tree data likelihood");
    }
    TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) diffGradients.get(0).getLikelihood();

    DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
    if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
        throw new XMLParseException("Diffusion gradients require a continuous data likelihood delegate");
    }
    ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate;

    int dim = delegate.getTraitDim();
    Tree tree = treeDataLikelihood.getTree();
    ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient traitGradient = new ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient(dim, tree, continuousData, derivationParametersList);
    BranchSpecificGradient branchSpecificGradient = new BranchSpecificGradient(traitName, treeDataLikelihood, continuousData, traitGradient, compoundParameter);
    return new DiffusionParametersGradient(branchSpecificGradient, parametersGradients);
}
Also used : CompoundGradient(dr.inference.hmc.CompoundGradient) BranchSpecificGradient(dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient) AbstractDiffusionGradient(dr.evomodel.treedatalikelihood.hmc.AbstractDiffusionGradient) ArrayList(java.util.ArrayList) CompoundParameter(dr.inference.model.CompoundParameter) DiffusionParametersGradient(dr.evomodel.treedatalikelihood.hmc.DiffusionParametersGradient) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate) DataLikelihoodDelegate(dr.evomodel.treedatalikelihood.DataLikelihoodDelegate) CompoundDerivative(dr.inference.hmc.CompoundDerivative) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) Tree(dr.evolution.tree.Tree) ContinuousTraitGradientForBranch(dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate)

Example 14 with DataLikelihoodDelegate

Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.

From class TreePrecisionDataProductProviderParser, method parseXMLObject.

/**
 * Parses this element by handing its {@link TreeDataLikelihood} child to {@code parseComputeMode}.
 * <p>
 * The trait name is read from the {@code TRAIT_NAME} attribute and defaults to {@code DEFAULT_TRAIT_NAME}.
 * The likelihood's data delegate must be a {@link ContinuousDataLikelihoodDelegate}.
 *
 * @param xo the XML element being parsed
 * @return whatever {@code parseComputeMode} builds for this likelihood and trait
 * @throws XMLParseException if the likelihood's delegate is not continuous
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final String name = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    final TreeDataLikelihood likelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);
    final DataLikelihoodDelegate likelihoodDelegate = likelihood.getDataLikelihoodDelegate();

    if (likelihoodDelegate instanceof ContinuousDataLikelihoodDelegate) {
        final ContinuousDataLikelihoodDelegate continuousDelegate = (ContinuousDataLikelihoodDelegate) likelihoodDelegate;
        return parseComputeMode(xo, likelihood, continuousDelegate, name);
    }
    // Sequence (non-continuous) likelihoods cannot provide tip trait gradients.
    throw new XMLParseException("May not provide a sequence data likelihood to compute tip trait gradient");
}
Also used : TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate) DataLikelihoodDelegate(dr.evomodel.treedatalikelihood.DataLikelihoodDelegate) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate)

Aggregations

DataLikelihoodDelegate (dr.evomodel.treedatalikelihood.DataLikelihoodDelegate): 14; TreeDataLikelihood (dr.evomodel.treedatalikelihood.TreeDataLikelihood): 12; ContinuousDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate): 8; BeagleDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate): 7; Tree (dr.evolution.tree.Tree): 6; ArrayList (java.util.ArrayList): 6; Parameter (dr.inference.model.Parameter): 5; Taxon (dr.evolution.util.Taxon): 4; MultiPartitionDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate): 4; Likelihood (dr.inference.model.Likelihood): 4; ArbitraryBranchRates (dr.evomodel.branchratemodel.ArbitraryBranchRates): 3; BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel): 3; DefaultBranchRateModel (dr.evomodel.branchratemodel.DefaultBranchRateModel): 3; PatternList (dr.evolution.alignment.PatternList): 2; Patterns (dr.evolution.alignment.Patterns): 2; NodeRef (dr.evolution.tree.NodeRef): 2; BranchSpecificGradient (dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient): 2; ContinuousTraitGradientForBranch (dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch): 2; CompoundLikelihood (dr.inference.model.CompoundLikelihood): 2; MatrixParameterInterface (dr.inference.model.MatrixParameterInterface): 2