use of dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate in project beast-mcmc by beast-dev.
the class CheckPointTreeModifier method incorporateAdditionalTaxa.
/**
* Add the remaining taxa, which can be identified through the TreeDataLikelihood XML elements.
* <p>
* Steps performed by this method, in order:
* <ol>
* <li>collect the external nodes of treeModel whose taxon id occurs in newTaxaNames (the "new" nodes);</li>
* <li>collect the taxa of all other external nodes (the "current" taxa);</li>
* <li>gather the pattern lists of every TreeDataLikelihood found in Likelihood.CONNECTED_LIKELIHOOD_SET,
* merge them into a single Patterns object and hand it to choice via setPatterns;</li>
* <li>if NEW_APPROACH is set, insert the new nodes one at a time via addTaxonAlongBranch, each near its closest
* current taxon (as reported by choice), and call interpolateTraitValuesOneInsertion after each insertion.</li>
* </ol>
* Progress is reported on System.out throughout.
*
* @param choice receives the aggregated patterns and is queried for the closest taxon and the distance to it
* @param rateModel not referenced by the active code of this method (only by a commented-out line)
* @param traitModels the first element is passed to getBranchRate; the whole list is passed to
* interpolateTraitValuesOneInsertion
* @return the external nodes whose taxon id matched an entry of newTaxaNames
* @throws RuntimeException if no pattern lists could be collected, or if the candidate insertion height is closer
* than MIN_DIST to both ends of the selected branch
*/
public ArrayList<NodeRef> incorporateAdditionalTaxa(CheckPointUpdaterApp.UpdateChoice choice, BranchRates rateModel, ArrayList<TreeParameterModel> traitModels) {
// public ArrayList<NodeRef> incorporateAdditionalTaxa(CheckPointUpdaterApp.UpdateChoice choice, BranchRates rateModel) {
System.out.println("Tree before adding taxa:\n" + treeModel.toString() + "\n");
// collect the external nodes whose taxon id equals one of the names in newTaxaNames
ArrayList<NodeRef> newTaxaNodes = new ArrayList<NodeRef>();
for (String str : newTaxaNames) {
for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
if (treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId().equals(str)) {
newTaxaNodes.add(treeModel.getExternalNode(i));
// always take into account Taxon dates vs. dates set through a TreeModel
System.out.println(treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId() + " with height " + treeModel.getNodeHeight(treeModel.getExternalNode(i)) + " or " + treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getHeight());
}
}
}
System.out.println("newTaxaNodes length = " + newTaxaNodes.size());
// every external node whose taxon id is NOT listed in newTaxaNames counts as a current taxon
ArrayList<Taxon> currentTaxa = new ArrayList<Taxon>();
for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
boolean taxonFound = false;
for (String str : newTaxaNames) {
if (str.equals((treeModel.getNodeTaxon(treeModel.getExternalNode(i))).getId())) {
taxonFound = true;
}
}
if (!taxonFound) {
System.out.println("Adding " + treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId() + " to list of current taxa");
currentTaxa.add(treeModel.getNodeTaxon(treeModel.getExternalNode(i)));
}
}
System.out.println("Current taxa count = " + currentTaxa.size());
// iterate over both current taxa and to be added taxa
// originTaxon remains true only when no taxon (current or new) reports a height of exactly 0.0
boolean originTaxon = true;
for (Taxon taxon : currentTaxa) {
if (taxon.getHeight() == 0.0) {
originTaxon = false;
System.out.println("Current taxon " + taxon.getId() + " has node height 0.0");
}
}
for (NodeRef newTaxon : newTaxaNodes) {
if (treeModel.getNodeTaxon(newTaxon).getHeight() == 0.0) {
originTaxon = false;
System.out.println("New taxon " + treeModel.getNodeTaxon(newTaxon).getId() + " has node height 0.0");
}
}
// check the Tree(Data)Likelihoods in the connected set of likelihoods
// focus on TreeDataLikelihood, which has getTree() to get the tree for each likelihood
// also get the DataLikelihoodDelegate from TreeDataLikelihood
// NOTE(review): likelihoods and trees are filled here but not read again within this method
ArrayList<TreeDataLikelihood> likelihoods = new ArrayList<TreeDataLikelihood>();
ArrayList<Tree> trees = new ArrayList<Tree>();
ArrayList<DataLikelihoodDelegate> delegates = new ArrayList<DataLikelihoodDelegate>();
for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
if (likelihood instanceof TreeDataLikelihood) {
likelihoods.add((TreeDataLikelihood) likelihood);
trees.add(((TreeDataLikelihood) likelihood).getTree());
delegates.add(((TreeDataLikelihood) likelihood).getDataLikelihoodDelegate());
}
}
// suggested to go through TreeDataLikelihoodParser and give it an extra option to create a HashMap
// keyed by the tree; am currently not overly fond of this approach
// only Beagle and multi-partition delegates contribute pattern lists; other delegate types are skipped
ArrayList<PatternList> patternLists = new ArrayList<PatternList>();
for (DataLikelihoodDelegate del : delegates) {
if (del instanceof BeagleDataLikelihoodDelegate) {
patternLists.add(((BeagleDataLikelihoodDelegate) del).getPatternList());
} else if (del instanceof MultiPartitionDataLikelihoodDelegate) {
MultiPartitionDataLikelihoodDelegate mpdld = (MultiPartitionDataLikelihoodDelegate) del;
List<PatternList> list = mpdld.getPatternLists();
for (PatternList pList : list) {
patternLists.add(pList);
}
}
}
if (patternLists.size() == 0) {
throw new RuntimeException("No patterns detected. Please make sure the XML file is BEAST 1.9 compatible.");
}
// aggregate all patterns to create distance matrix
// TODO What about different trees for different partitions?
Patterns patterns = new Patterns(patternLists.get(0));
if (patternLists.size() > 1) {
for (int i = 1; i < patternLists.size(); i++) {
patterns.addPatterns(patternLists.get(i));
}
}
// set the patterns for the distance matrix computations
choice.setPatterns(patterns);
// add new taxa one at a time
System.out.println("Adding " + newTaxaNodes.size() + " taxa ...");
// NOTE(review): when NEW_APPROACH is false nothing is inserted and newTaxaNodes is returned as collected above
if (NEW_APPROACH) {
System.out.println("Branch rates are being updated after each new sequence is inserted");
// tip count excluding the nodes still to be inserted; passed to getBranchRate and incremented per insertion
int numTaxaSoFar = treeModel.getExternalNodeCount() - newTaxaNodes.size();
for (NodeRef newTaxon : newTaxaNodes) {
// check for zero-length and negative length branches and internal nodes that don't have 2 child nodes
// NOTE(review): these DEBUG checks end the JVM via System.exit(0), i.e. with a success exit status
if (DEBUG) {
for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
// walk from each tip towards the root, validating every branch on the way
NodeRef startingTip = treeModel.getExternalNode(i);
while (treeModel.getParent(startingTip) != null) {
if (treeModel.getChildCount(treeModel.getParent(startingTip)) != 2) {
System.out.println(treeModel.getChildCount(treeModel.getParent(startingTip)) + " children for node " + treeModel.getParent(startingTip));
System.out.println("Exiting ...");
System.exit(0);
}
double branchLength = treeModel.getNodeHeight(treeModel.getParent(startingTip)) - treeModel.getNodeHeight(startingTip);
if (branchLength == 0.0) {
System.out.println("Zero-length branch detected:");
System.out.println(" parent node: " + treeModel.getParent(startingTip));
System.out.println(" child node: " + startingTip);
System.out.println("Exiting ...");
System.exit(0);
} else if (branchLength < 0.0) {
System.out.println("Negative branch length detected:");
System.out.println(" parent node: " + treeModel.getParent(startingTip));
System.out.println(" child node: " + startingTip);
System.out.println("Exiting ...");
System.exit(0);
} else {
startingTip = treeModel.getParent(startingTip);
}
}
}
}
// initialise the node height of the new tip from the height stored on its Taxon
treeModel.setNodeHeight(newTaxon, treeModel.getNodeTaxon(newTaxon).getHeight());
System.out.println("\nadding Taxon: " + newTaxon + " (height = " + treeModel.getNodeHeight(newTaxon) + ")");
// check if this taxon has a more recent sampling date than all other nodes in the current TreeModel
double offset = checkCurrentTreeNodes(newTaxon, treeModel.getRoot());
System.out.println("Sampling date offset when adding " + newTaxon + " = " + offset);
// AND set its current node height to 0.0 IF no originTaxon has been found
if (offset < 0.0) {
if (!originTaxon) {
System.out.println("Updating all node heights with offset " + Math.abs(offset));
updateAllTreeNodes(Math.abs(offset), treeModel.getRoot());
treeModel.setNodeHeight(newTaxon, 0.0);
}
} else if (offset == 0.0) {
if (!originTaxon) {
treeModel.setNodeHeight(newTaxon, 0.0);
}
}
// get the closest Taxon to the Taxon that needs to be added
// take into account which taxa can currently be chosen
Taxon closest = choice.getClosestTaxon(treeModel.getNodeTaxon(newTaxon), currentTaxa);
System.out.println("\nclosest Taxon: " + closest + " with original height: " + closest.getHeight());
// get the distance between these two taxa
double distance = choice.getDistance(treeModel.getNodeTaxon(newTaxon), closest);
if (distance == 0.0) {
// employ minimum insertion distance but add in a random factor to avoid identical insertion heights
// this to avoid multifurcations in the case of (many) identical sequences
distance = MIN_DIST * MathUtils.nextDouble();
System.out.println("Sequences are identical, setting minimum distance to " + distance);
}
System.out.println("at distance: " + distance);
// find the NodeRef for the closest Taxon (do not rely on node numbering)
// NOTE(review): Taxon objects are compared by identity (==); closestRef stays null when no external node
// matches and is not null-checked before being used below
NodeRef closestRef = null;
// careful with node numbering and subtract number of new taxa
for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
if (treeModel.getNodeTaxon(treeModel.getExternalNode(i)) == closest) {
closestRef = treeModel.getExternalNode(i);
System.out.println(" closest external nodeRef: " + closestRef);
}
}
System.out.println("closest node : " + closestRef + " with height " + treeModel.getNodeHeight(closestRef));
System.out.println("parent node: " + treeModel.getParent(closestRef));
// begin change
// TODO: only for Sam, revert back to the line below !!!
// double timeForDistance = distance / rateModel.getBranchRate(treeModel, closestRef);
// convert the distance into time using the rate obtained from the first trait model only
double timeForDistance = distance / getBranchRate(traitModels.get(0), closestRef, numTaxaSoFar);
// end change
System.out.println("timeForDistance = " + timeForDistance);
// get parent node of branch that will be split
NodeRef parent = treeModel.getParent(closestRef);
// child node of branch that will be split (will be assigned value later)
NodeRef splitBranchChild;
/*if ((treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(closestRef)) == 0.0) {
System.out.println("Zero-length branch:");
System.out.println(parent);
System.out.println(closestRef);
System.out.println("Exiting ...");
System.exit(0);
}*/
// determine height of new node
double insertHeight;
// case 1: exact (==) equality of the two tip heights
if (treeModel.getNodeHeight(closestRef) == treeModel.getNodeHeight(newTaxon)) {
// if the sequences have the same sampling date/time, then simply split the distance/time between the two sequences in half
// both the closest sequence and the new sequence are equidistant from the newly inserted internal node
insertHeight = treeModel.getNodeHeight(closestRef) + timeForDistance / 2.0;
System.out.println("equal sampling times (" + treeModel.getNodeHeight(closestRef) + ") ; insertHeight = " + insertHeight);
splitBranchChild = closestRef;
// climb towards the root until the insertion height lies strictly below the parent of the chosen branch
if (insertHeight >= treeModel.getNodeHeight(parent)) {
while (insertHeight >= treeModel.getNodeHeight(parent)) {
if (treeModel.getParent(parent) == null) {
// Use this insertHeight value in case parent doesn't have parent
// Otherwise, move up tree
insertHeight = treeModel.getNodeHeight(splitBranchChild) + EPSILON * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(splitBranchChild));
break;
} else {
splitBranchChild = parent;
parent = treeModel.getParent(splitBranchChild);
}
}
}
// now that a suitable branch has been found, perform additional checks
// parent of branch = parent; child of branch = splitBranchChild
System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
if (((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) && ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST)) {
throw new RuntimeException("No suitable branch found for sequence insertion (all branches < minimum branch length).");
} else if ((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) {
System.out.println(" insertion height too close to parent height");
// draw a random height between the child and the original insertion height
double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
// NOTE(review): the draw is repeated only while BOTH gaps are below MIN_DIST (&&) - confirm that || was not intended
while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
}
System.out.println(" new insertion height = " + newInsertHeight);
insertHeight = newInsertHeight;
} else if ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST) {
System.out.println(" insertion height too close to child height");
// draw a random height above the child, scaled by the gap between the parent and the original insertion height
double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
// NOTE(review): same && condition as in the parent-side case above
while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
}
System.out.println(" new insertion height = " + newInsertHeight);
insertHeight = newInsertHeight;
}
} else {
// case 2: the two tips have different heights
// first calculate if the new internal node is older than both the new sequence and its closest sequence already present in the tree
double remainder = (timeForDistance - Math.abs(treeModel.getNodeHeight(closestRef) - treeModel.getNodeHeight(newTaxon))) / 2.0;
if (remainder > 0) {
// place the new internal node 'remainder' above the older (higher) of the two tips
insertHeight = Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon)) + remainder;
System.out.println("remainder > 0 (" + remainder + "): " + insertHeight);
splitBranchChild = closestRef;
if (insertHeight >= treeModel.getNodeHeight(parent)) {
while (insertHeight >= treeModel.getNodeHeight(parent)) {
if (treeModel.getParent(parent) == null) {
// use this insertHeight value in case parent doesn't have parent (and we can't move up the tree)
insertHeight = treeModel.getNodeHeight(splitBranchChild) + EPSILON * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(splitBranchChild));
break;
} else {
// otherwise move up in the tree
splitBranchChild = parent;
parent = treeModel.getParent(splitBranchChild);
}
}
}
// now that a suitable branch has been found, perform additional checks
// parent of branch = parent; child of branch = splitBranchChild
// (the checks below repeat those of the equal-sampling-time case)
System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
if (((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) && ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST)) {
throw new RuntimeException("No suitable branch found for sequence insertion (all branches < minimum branch length).");
} else if ((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) {
System.out.println(" insertion height too close to parent height");
double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
}
System.out.println(" new insertion height = " + newInsertHeight);
insertHeight = newInsertHeight;
} else if ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST) {
System.out.println(" insertion height too close to child height");
double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
}
System.out.println(" new insertion height = " + newInsertHeight);
insertHeight = newInsertHeight;
}
} else {
// remainder <= 0: start from an EPSILON fraction of the gap between parent and the older of the two tips
// Come up with better way to handle this?
insertHeight = EPSILON * (treeModel.getNodeHeight(parent) - Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon)));
System.out.println("insertHeight after EPSILON: " + insertHeight);
insertHeight += Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon));
System.out.println("remainder <= 0: " + insertHeight);
double insertDifference = treeModel.getNodeHeight(parent) - insertHeight;
System.out.println("height difference = " + insertDifference);
splitBranchChild = closestRef;
if (insertDifference < MIN_DIST) {
System.out.println("branch too short ...");
boolean suitableBranch = false;
// 2. no negative branch length can be introduced so the parent node needs to be older than the new taxon and its closest reference
while (!suitableBranch) {
// NOTE(review): treeModel.getParent(parent) is not null-checked here - confirm behaviour once parent is the root
double parentBranchLength = treeModel.getNodeHeight(treeModel.getParent(parent)) - treeModel.getNodeHeight(parent);
if ((parentBranchLength < MIN_DIST) || ((treeModel.getNodeHeight(parent) - Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon))) < 0.0)) {
// find another branch by moving upwards in the tree
splitBranchChild = parent;
parent = treeModel.getParent(splitBranchChild);
} else {
// node needs to be inserted along the branch from parent to splitBranchChild
double positiveRandom = Math.abs(MathUtils.nextDouble());
// insertion height needs to be higher than taxon being inserted
double minimumHeight = Math.max(Math.max(treeModel.getNodeHeight(splitBranchChild), treeModel.getNodeHeight(newTaxon)), treeModel.getNodeHeight(closestRef));
insertHeight = treeModel.getNodeHeight(parent) - (treeModel.getNodeHeight(parent) - minimumHeight) * positiveRandom;
suitableBranch = true;
}
}
}
}
}
// report the final choice of branch (parent -> splitBranchChild) and insertion height
System.out.println("insert at height: " + insertHeight);
System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
System.out.printf("height difference = %.25f\n", (insertHeight - treeModel.getNodeHeight(parent)));
if (treeModel.getParent(parent) != null) {
System.out.printf("grandparent height: %.25f \n", treeModel.getNodeHeight(treeModel.getParent(parent)));
} else {
System.out.println("parent == root");
}
System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
System.out.printf("insert at height: %.25f \n", insertHeight);
// pass on all the necessary variables to a method that adds the new taxon to the tree
addTaxonAlongBranch(newTaxon, parent, splitBranchChild, insertHeight);
// option to print tree after each taxon addition
System.out.println("\nTree after adding taxon " + newTaxon + ":\n" + treeModel.toString());
// the tree is printed a second time, on a line prefixed with ">>"
System.out.println(">>" + treeModel.toString());
// add newly added Taxon to list of current taxa
// (so that it can be selected as the closest taxon for the remaining insertions)
currentTaxa.add(treeModel.getNodeTaxon(newTaxon));
// Update rate categories here
interpolateTraitValuesOneInsertion(traitModels, newTaxon);
numTaxaSoFar++;
}
}
return newTaxaNodes;
}
use of dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate in project beast-mcmc by beast-dev.
the class BranchSubstitutionParameterGradientParser method parseXMLObject.
/**
 * Builds a BranchSubstitutionParameterGradient from its XML element.
 * <p>
 * Reads the trait name and the Hessian flag from the element's attributes, fetches the
 * TreeDataLikelihood, BranchModel and BranchRateModel children, and asks the branch model for its
 * branch-specific parameters.
 *
 * @param xo the XML element being parsed
 * @return a new BranchSubstitutionParameterGradient
 * @throws XMLParseException if the tree data likelihood is not backed by a BeagleDataLikelihoodDelegate
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    boolean useHessian = xo.getAttribute(USE_HESSIAN, false);

    final TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);
    BranchSpecificSubstitutionParameterBranchModel branchModel = (BranchSpecificSubstitutionParameterBranchModel) xo.getChild(BranchModel.class);

    // Report an unsupported delegate as a parse error instead of letting the downcast below fail
    // with a ClassCastException that this method does not declare.
    if (!(treeDataLikelihood.getDataLikelihoodDelegate() instanceof BeagleDataLikelihoodDelegate)) {
        throw new XMLParseException("Unknown likelihood delegate type");
    }
    BeagleDataLikelihoodDelegate beagleData = (BeagleDataLikelihoodDelegate) treeDataLikelihood.getDataLikelihoodDelegate();

    BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);
    CompoundParameter branchParameter = branchModel.getBranchSpecificParameters(branchRateModel);

    return new BranchSubstitutionParameterGradient(traitName, treeDataLikelihood, beagleData, branchParameter, branchRateModel, useHessian);
}
use of dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate in project beast-mcmc by beast-dev.
the class BranchRateGradientParser method parseTreeDataLikelihood.
/**
 * Selects the branch-rate gradient implementation that matches the given tree data likelihood.
 * <p>
 * The branch rate model must be a DefaultBranchRateModel or an ArbitraryBranchRates; only the latter
 * supplies a rate parameter, otherwise {@code null} is handed to the gradient. The concrete gradient
 * class then depends on the type of the likelihood delegate and, for Beagle delegates, on whether the
 * rate model is a LocalBranchRates.
 *
 * @param treeDataLikelihood likelihood providing the branch rate model and the delegate
 * @param traitName trait name forwarded to the gradient
 * @param useHessian forwarded to the discrete-trait gradients only
 * @return the gradient provider for this likelihood
 * @throws XMLParseException if the branch rate model or the delegate type is not supported
 */
private GradientWrtParameterProvider parseTreeDataLikelihood(TreeDataLikelihood treeDataLikelihood, String traitName, boolean useHessian) throws XMLParseException {
    final BranchRateModel rateModel = treeDataLikelihood.getBranchRateModel();
    final boolean hasArbitraryRates = rateModel instanceof ArbitraryBranchRates;

    // Guard: reject every rate model other than the two supported kinds.
    if (!hasArbitraryRates && !(rateModel instanceof DefaultBranchRateModel)) {
        throw new XMLParseException("Only implemented for an arbitrary rates model");
    }

    // Only an arbitrary rates model exposes a rate parameter.
    Parameter rateParameter = null;
    if (hasArbitraryRates) {
        rateParameter = ((ArbitraryBranchRates) rateModel).getRateParameter();
    }

    final DataLikelihoodDelegate likelihoodDelegate = treeDataLikelihood.getDataLikelihoodDelegate();

    if (likelihoodDelegate instanceof ContinuousDataLikelihoodDelegate) {
        return new BranchRateGradient(traitName, treeDataLikelihood,
                (ContinuousDataLikelihoodDelegate) likelihoodDelegate, rateParameter);
    }

    if (likelihoodDelegate instanceof BeagleDataLikelihoodDelegate) {
        final BeagleDataLikelihoodDelegate beagleDelegate = (BeagleDataLikelihoodDelegate) likelihoodDelegate;
        if (rateModel instanceof LocalBranchRates) {
            return new LocalBranchRateGradientForDiscreteTrait(traitName, treeDataLikelihood, beagleDelegate, rateParameter, useHessian);
        }
        return new BranchRateGradientForDiscreteTrait(traitName, treeDataLikelihood, beagleDelegate, rateParameter, useHessian);
    }

    throw new XMLParseException("Unknown likelihood delegate type");
}
Aggregations