use of dr.inference.model.Likelihood in project beast-mcmc by beast-dev.
the class TreeLoggerParser method parseXMLParameters.
/**
 * Reads the tree logger settings from {@code xo} and stores them in this parser's fields.
 * <p>
 * Every child of {@code xo} is inspected in document order. The {@code instanceof} checks are
 * independent of each other, so a single child may contribute in more than one way:
 * <ul>
 * <li>a {@code Likelihood} adds a tree attribute labelled {@code lnP};</li>
 * <li>a {@code TreeAttributeProvider} is used as is;</li>
 * <li>a {@code TreeTraitProvider} given directly contributes its traits, optionally restricted
 *     by the {@code FILTER_TRAITS} attribute of {@code xo};</li>
 * <li>a {@code TREE_TRAIT} element selects one named trait (optionally re-tagged), a filtered
 *     subset, or all traits of the provider it wraps;</li>
 * <li>a {@code Loggable} adds one tree attribute per log column.</li>
 * </ul>
 * Providers coming from {@code TREE_TRAIT} elements are placed before providers given directly.
 *
 * @param xo the XML element being parsed
 * @throws XMLParseException if a named trait cannot be found, if a {@code TREE_TRAIT} element
 *         holds no {@code TreeTraitProvider}, or if substitution branch lengths are requested
 *         without a {@code BranchRates} child
 */
protected void parseXMLParameters(XMLObject xo) throws XMLParseException {
    // reset this every time...
    branchRates = null;

    tree = (Tree) xo.getChild(Tree.class);
    title = xo.getAttribute(TITLE, "");
    nexusFormat = xo.getAttribute(NEXUS_FORMAT, false);
    sortTranslationTable = xo.getAttribute(SORT_TRANSLATION_TABLE, true);

    boolean substitutions = xo.getAttribute(BRANCH_LENGTHS, "").equals(SUBSTITUTIONS);

    List<TreeAttributeProvider> taps = new ArrayList<TreeAttributeProvider>();
    // providers declared inside a TREE_TRAIT element
    List<TreeTraitProvider> ttps = new ArrayList<TreeTraitProvider>();
    // providers given directly as children of xo; appended after the TREE_TRAIT ones below.
    // NOTE(review): no de-duplication against ttps is performed here, so a provider that is
    // both wrapped in a TREE_TRAIT element and given directly ends up in the list twice.
    List<TreeTraitProvider> ttps2 = new ArrayList<TreeTraitProvider>();

    for (int i = 0; i < xo.getChildCount(); i++) {
        Object cxo = xo.getChild(i);

        if (cxo instanceof Likelihood) {
            final Likelihood likelihood = (Likelihood) cxo;
            taps.add(new TreeAttributeProvider() {
                public String[] getTreeAttributeLabel() {
                    return new String[] { "lnP" };
                }

                public String[] getAttributeForTree(Tree tree) {
                    return new String[] { Double.toString(likelihood.getLogLikelihood()) };
                }
            });
        }

        if (cxo instanceof TreeAttributeProvider) {
            taps.add((TreeAttributeProvider) cxo);
        }

        if (cxo instanceof TreeTraitProvider) {
            TreeTraitProvider direct = (TreeTraitProvider) cxo;
            if (xo.hasAttribute(FILTER_TRAITS)) {
                TreeTraitProvider filtered = filterTraitsByPrefix(direct, (String) xo.getAttribute(FILTER_TRAITS));
                if (filtered != null) {
                    ttps2.add(filtered);
                }
            } else {
                // Add all of them
                ttps2.add(direct);
            }
        }

        if (cxo instanceof XMLObject) {
            XMLObject xco = (XMLObject) cxo;
            if (xco.getName().equals(TREE_TRAIT)) {
                TreeTraitProvider ttp = (TreeTraitProvider) xco.getChild(TreeTraitProvider.class);
                if (ttp == null) {
                    // fail here with a parse error rather than with a NullPointerException
                    // below (or, for the "add all" case, much later while logging)
                    throw new XMLParseException("The " + TREE_TRAIT + " element must contain a TreeTraitProvider");
                }

                if (xco.hasAttribute(NAME)) {
                    // a specific named trait is required (optionally with a tag to name it in the tree file)
                    String name = xco.getStringAttribute(NAME);
                    final TreeTrait trait = ttp.getTreeTrait(name);
                    if (trait == null) {
                        String childName = "TreeTraitProvider";
                        if (ttp instanceof Likelihood) {
                            childName = ((Likelihood) ttp).prettyName();
                        } else if (ttp instanceof Model) {
                            childName = ((Model) ttp).getModelName();
                        }
                        throw new XMLParseException("Trait named, " + name + ", not found for " + childName);
                    }

                    final String tag = xco.hasAttribute(TAG) ? xco.getStringAttribute(TAG) : name;

                    // delegate everything to the underlying trait except the reported name
                    ttps.add(new TreeTraitProvider.Helper(tag, new TreeTrait() {
                        public String getTraitName() {
                            return tag;
                        }

                        public Intent getIntent() {
                            return trait.getIntent();
                        }

                        public Class getTraitClass() {
                            return trait.getTraitClass();
                        }

                        public Object getTrait(Tree tree, NodeRef node) {
                            return trait.getTrait(tree, node);
                        }

                        public String getTraitString(Tree tree, NodeRef node) {
                            return trait.getTraitString(tree, node);
                        }

                        public boolean getLoggable() {
                            return trait.getLoggable();
                        }
                    }));
                } else if (xo.hasAttribute(FILTER_TRAITS)) {
                    // else a filter attribute (read from the enclosing element xo) asks for all
                    // traits whose name starts with one of the given strings
                    TreeTraitProvider filtered = filterTraitsByPrefix(ttp, (String) xo.getAttribute(FILTER_TRAITS));
                    if (filtered != null) {
                        ttps.add(filtered);
                    }
                } else {
                    // neither named or filtered traits so just add them all
                    ttps.add(ttp);
                }
            }
        }

        // be able to put arbitrary statistics in as tree attributes
        if (cxo instanceof Loggable) {
            final Loggable loggable = (Loggable) cxo;
            taps.add(new TreeAttributeProvider() {
                public String[] getTreeAttributeLabel() {
                    String[] labels = new String[loggable.getColumns().length];
                    for (int c = 0; c < loggable.getColumns().length; c++) {
                        labels[c] = loggable.getColumns()[c].getLabel();
                    }
                    return labels;
                }

                public String[] getAttributeForTree(Tree tree) {
                    String[] values = new String[loggable.getColumns().length];
                    for (int c = 0; c < loggable.getColumns().length; c++) {
                        values[c] = loggable.getColumns()[c].getFormatted();
                    }
                    return values;
                }
            });
        }
    }

    // directly supplied providers (e.g. the codon partitioned robust counting TTP) go last
    ttps.addAll(ttps2);

    if (substitutions) {
        branchRates = (BranchRates) xo.getChild(BranchRates.class);
        if (branchRates == null) {
            throw new XMLParseException("To log trees in units of substitutions a BranchRateModel must be provided");
        }
    }

    // logEvery of zero only displays at the end
    logEvery = xo.getAttribute(LOG_EVERY, 0);

    // decimal places; -1 means "not specified", in which case format is left untouched
    final int dp = xo.getAttribute(DECIMAL_PLACES, -1);
    if (dp != -1) {
        format = NumberFormat.getNumberInstance(Locale.ENGLISH);
        format.setMaximumFractionDigits(dp);
    }

    final PrintWriter pw = getLogFile(xo, getParserName());
    formatter = new TabDelimitedFormatter(pw);

    treeAttributeProviders = taps.toArray(new TreeAttributeProvider[taps.size()]);
    treeTraitProviders = ttps.toArray(new TreeTraitProvider[ttps.size()]);

    // I think the default should be to have names rather than numbers, thus the false default - AJD
    // I think the default should be numbers - using names results in larger files and end user never
    // sees the numbers anyway as any software loading the nexus files does the translation - JH
    mapNames = xo.getAttribute(MAP_NAMES, true);

    // a LogUpon condition is only looked up when no logging interval was given
    condition = logEvery == 0 ? (TreeLogger.LogUpon) xo.getChild(TreeLogger.LogUpon.class) : null;
}

/**
 * Wraps the traits of {@code provider} whose names start with any of the prefixes listed in
 * {@code filterSpec}.
 * <p>
 * Matches are collected prefix by prefix, so a trait whose name starts with several of the
 * prefixes is listed once per matching prefix.
 *
 * @param provider the source of candidate traits
 * @param filterSpec prefixes separated by whitespace and/or commas
 * @return a provider exposing the matching traits, or {@code null} when nothing matches
 */
private static TreeTraitProvider filterTraitsByPrefix(TreeTraitProvider provider, String filterSpec) {
    String[] prefixes = filterSpec.split("[\\s,]+");
    TreeTrait[] traits = provider.getTreeTraits();
    List<TreeTrait> matching = new ArrayList<TreeTrait>();
    for (String prefix : prefixes) {
        for (TreeTrait trait : traits) {
            if (trait.getTraitName().startsWith(prefix)) {
                matching.add(trait);
            }
        }
    }
    if (matching.isEmpty()) {
        return null;
    }
    return new TreeTraitProvider.Helper(matching);
}
use of dr.inference.model.Likelihood in project beast-mcmc by beast-dev.
the class CheckPointTreeModifier method incorporateAdditionalTaxa.
/**
 * Inserts the taxa listed in {@code newTaxaNames} into {@code treeModel}, one at a time.
 * <p>
 * For every new taxon the closest already-placed taxon is requested from {@code choice}, the
 * distance between the two is converted to a time using the branch rate of the closest taxon's
 * node, and the new taxon is attached along the branch above that node via
 * {@code addTaxonAlongBranch}. Progress is reported on {@code System.out} throughout.
 * <p>
 * The patterns needed for the distance computations are gathered from every
 * {@code TreeDataLikelihood} found in {@code Likelihood.CONNECTED_LIKELIHOOD_SET} and handed to
 * {@code choice} before any taxon is added.
 *
 * @param choice supplies the closest taxon and the pairwise distance; receives the aggregated patterns
 * @param rateModel supplies the branch rate used to turn a distance into a time
 * @return the external nodes of {@code treeModel} that correspond to the newly added taxa
 * @throws RuntimeException if no pattern list could be collected from the connected likelihoods
 */
public ArrayList<NodeRef> incorporateAdditionalTaxa(CheckPointUpdaterApp.UpdateChoice choice, BranchRates rateModel) {
System.out.println("Tree before adding taxa:\n" + treeModel.toString() + "\n");
// collect the external nodes whose taxon id matches one of the new taxon names
ArrayList<NodeRef> newTaxaNodes = new ArrayList<NodeRef>();
for (String str : newTaxaNames) {
for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
if (treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId().equals(str)) {
newTaxaNodes.add(treeModel.getExternalNode(i));
//always take into account Taxon dates vs. dates set through a TreeModel
System.out.println(treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId() + " with height " + treeModel.getNodeHeight(treeModel.getExternalNode(i)) + " or " + treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getHeight());
}
}
}
System.out.println("newTaxaNodes length = " + newTaxaNodes.size());
// every external node whose taxon id is NOT among the new names counts as a current taxon
ArrayList<Taxon> currentTaxa = new ArrayList<Taxon>();
for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
boolean taxonFound = false;
for (String str : newTaxaNames) {
if (str.equals((treeModel.getNodeTaxon(treeModel.getExternalNode(i))).getId())) {
taxonFound = true;
}
}
if (!taxonFound) {
System.out.println("Adding " + treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId() + " to list of current taxa");
currentTaxa.add(treeModel.getNodeTaxon(treeModel.getExternalNode(i)));
}
}
System.out.println("Current taxa count = " + currentTaxa.size());
//iterate over both current taxa and to be added taxa
// originTaxon stays true only if no taxon (current or new) has a Taxon height of exactly 0.0
boolean originTaxon = true;
for (Taxon taxon : currentTaxa) {
if (taxon.getHeight() == 0.0) {
originTaxon = false;
System.out.println("Current taxon " + taxon.getId() + " has node height 0.0");
}
}
for (NodeRef newTaxon : newTaxaNodes) {
if (treeModel.getNodeTaxon(newTaxon).getHeight() == 0.0) {
originTaxon = false;
System.out.println("New taxon " + treeModel.getNodeTaxon(newTaxon).getId() + " has node height 0.0");
}
}
//check the Tree(Data)Likelihoods in the connected set of likelihoods
//focus on TreeDataLikelihood, which has getTree() to get the tree for each likelihood
//also get the DataLikelihoodDelegate from TreeDataLikelihood
// note: likelihoods and trees are filled here but not read again in this method; only
// delegates is used below
ArrayList<TreeDataLikelihood> likelihoods = new ArrayList<TreeDataLikelihood>();
ArrayList<Tree> trees = new ArrayList<Tree>();
ArrayList<DataLikelihoodDelegate> delegates = new ArrayList<DataLikelihoodDelegate>();
for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
if (likelihood instanceof TreeDataLikelihood) {
likelihoods.add((TreeDataLikelihood) likelihood);
trees.add(((TreeDataLikelihood) likelihood).getTree());
delegates.add(((TreeDataLikelihood) likelihood).getDataLikelihoodDelegate());
}
}
//suggested to go through TreeDataLikelihoodParser and give it an extra option to create a HashMap
//keyed by the tree; am currently not overly fond of this approach
// delegates of any type other than the two handled below contribute no patterns
ArrayList<PatternList> patternLists = new ArrayList<PatternList>();
for (DataLikelihoodDelegate del : delegates) {
if (del instanceof BeagleDataLikelihoodDelegate) {
patternLists.add(((BeagleDataLikelihoodDelegate) del).getPatternList());
} else if (del instanceof MultiPartitionDataLikelihoodDelegate) {
MultiPartitionDataLikelihoodDelegate mpdld = (MultiPartitionDataLikelihoodDelegate) del;
List<PatternList> list = mpdld.getPatternLists();
for (PatternList pList : list) {
patternLists.add(pList);
}
}
}
if (patternLists.size() == 0) {
throw new RuntimeException("No patterns detected. Please make sure the XML file is BEAST 1.9 compatible.");
}
//aggregate all patterns to create distance matrix
//TODO What about different trees for different partitions?
Patterns patterns = new Patterns(patternLists.get(0));
if (patternLists.size() > 1) {
for (int i = 1; i < patternLists.size(); i++) {
patterns.addPatterns(patternLists.get(i));
}
}
//set the patterns for the distance matrix computations
choice.setPatterns(patterns);
//add new taxa one at a time
System.out.println("Adding " + newTaxaNodes.size() + " taxa ...");
for (NodeRef newTaxon : newTaxaNodes) {
// start from the height stored on the Taxon itself
treeModel.setNodeHeight(newTaxon, treeModel.getNodeTaxon(newTaxon).getHeight());
System.out.println("\nadding Taxon: " + newTaxon + " (height = " + treeModel.getNodeHeight(newTaxon) + ")");
//check if this taxon has a more recent sampling date than all other nodes in the current TreeModel
// NOTE(review): checkCurrentTreeNodes is defined elsewhere; a negative offset is treated
// below as "new taxon is more recent than the existing nodes" - confirm against its definition
double offset = checkCurrentTreeNodes(newTaxon, treeModel.getRoot());
System.out.println("Sampling date offset when adding " + newTaxon + " = " + offset);
//AND set its current node height to 0.0 IF no originTaxon has been found
if (offset < 0.0) {
if (!originTaxon) {
// shift the existing nodes by |offset| and place the new taxon at height 0.0
System.out.println("Updating all node heights with offset " + Math.abs(offset));
updateAllTreeNodes(Math.abs(offset), treeModel.getRoot());
treeModel.setNodeHeight(newTaxon, 0.0);
}
} else if (offset == 0.0) {
if (!originTaxon) {
treeModel.setNodeHeight(newTaxon, 0.0);
}
}
//get the closest Taxon to the Taxon that needs to be added
//take into account which taxa can currently be chosen
Taxon closest = choice.getClosestTaxon(treeModel.getNodeTaxon(newTaxon), currentTaxa);
System.out.println("\nclosest Taxon: " + closest + " with original height: " + closest.getHeight());
//get the distance between these two taxa
double distance = choice.getDistance(treeModel.getNodeTaxon(newTaxon), closest);
System.out.println("at distance: " + distance);
//TODO what if distance == 0.0 ??? how to choose closest taxon then (in absence of geo info)?
//find the NodeRef for the closest Taxon (do not rely on node numbering)
NodeRef closestRef = null;
//careful with node numbering and subtract number of new taxa
// the taxon is matched by object identity (==), not by id
for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
if (treeModel.getNodeTaxon(treeModel.getExternalNode(i)) == closest) {
closestRef = treeModel.getExternalNode(i);
}
}
// NOTE(review): closestRef stays null if no external node carries the closest taxon;
// the calls below would then fail - presumably cannot happen because closest is drawn
// from currentTaxa, but verify
System.out.println(closestRef + " with height " + treeModel.getNodeHeight(closestRef));
//System.out.println("trying to set node height: " + closestRef + " from " + treeModel.getNodeHeight(closestRef) + " to " + closest.getHeight());
//treeModel.setNodeHeight(closestRef, closest.getHeight());
// convert the distance into a time using the rate on the branch above the closest taxon
double timeForDistance = distance / rateModel.getBranchRate(treeModel, closestRef);
System.out.println("timeForDistance = " + timeForDistance);
//get parent node of branch that will be split
// NOTE(review): parent is used without a null check; presumably closestRef is never the root - confirm
NodeRef parent = treeModel.getParent(closestRef);
//determine height of new node
double insertHeight;
if (treeModel.getNodeHeight(closestRef) == treeModel.getNodeHeight(newTaxon)) {
// equal tip heights: split the time evenly above the closest taxon
insertHeight = treeModel.getNodeHeight(closestRef) + timeForDistance / 2.0;
System.out.println("treeModel.getNodeHeight(closestRef) == treeModel.getNodeHeight(newTaxon): " + insertHeight);
if (insertHeight >= treeModel.getNodeHeight(parent)) {
// would reach or pass the parent: fall back to a fraction EPSILON of the branch above closestRef
insertHeight = treeModel.getNodeHeight(closestRef) + EPSILON * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(closestRef));
}
} else {
// unequal tip heights: first spend the time on the height difference, then split what is left
double remainder = (timeForDistance - Math.abs(treeModel.getNodeHeight(closestRef) - treeModel.getNodeHeight(newTaxon))) / 2.0;
if (remainder > 0) {
insertHeight = Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon)) + remainder;
System.out.println("remainder > 0: " + insertHeight);
if (insertHeight >= treeModel.getNodeHeight(parent)) {
// NOTE(review): this fallback is measured from closestRef's height, whereas the
// remainder <= 0 branch below measures from the higher of the two tips; if newTaxon is
// the higher tip the result may lie below newTaxon - confirm this is intended
insertHeight = treeModel.getNodeHeight(closestRef) + EPSILON * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(closestRef));
}
} else {
// no time left over: insert a fraction EPSILON of the way from the higher tip to the parent
insertHeight = EPSILON * (treeModel.getNodeHeight(parent) - Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon)));
insertHeight += Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon));
System.out.println("remainder <= 0: " + insertHeight);
}
}
System.out.println("insert at height: " + insertHeight);
//pass on all the necessary variables to a method that adds the new taxon to the tree
addTaxonAlongBranch(newTaxon, parent, closestRef, insertHeight);
//option to print tree after each taxon addition
System.out.println("\nTree after adding taxon " + newTaxon + ":\n" + treeModel.toString());
//add newly added Taxon to list of current taxa
// so that later new taxa may pick it as their closest taxon
currentTaxa.add(treeModel.getNodeTaxon(newTaxon));
}
return newTaxaNodes;
}
use of dr.inference.model.Likelihood in project beast-mcmc by beast-dev.
the class FactorRJMCMCOperatorParser method parseXMLObject.
/**
 * Assembles a {@code FactorRJMCMCOperator} from its XML description.
 * <p>
 * The {@code LOADINGS}, {@code ROW_PRIOR} and {@code LOADINGS_OPERATOR} elements and an
 * {@code AbstractModelLikelihood} child are read unconditionally. Every other component is
 * looked up only when its element is present and is passed on as {@code null} otherwise.
 *
 * @param xo the XML element describing the operator
 * @return the configured {@code FactorRJMCMCOperator}
 * @throws XMLParseException if one of the numeric attributes cannot be read
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    // numeric settings
    final double weight = xo.getDoubleAttribute(WEIGHT);
    final int chainLength = xo.getIntegerAttribute(CHAIN_LENGTH);
    final double sizeParameter = xo.getDoubleAttribute(SIZE_PARAMETER);

    // matrix parameters; only the loadings are read without a presence check
    final AdaptableSizeFastMatrixParameter factors = xo.hasChildNamed(FACTORS)
            ? (AdaptableSizeFastMatrixParameter) xo.getChild(FACTORS).getChild(AdaptableSizeFastMatrixParameter.class)
            : null;
    final AdaptableSizeFastMatrixParameter loadings =
            (AdaptableSizeFastMatrixParameter) xo.getChild(LOADINGS).getChild(AdaptableSizeFastMatrixParameter.class);
    final AdaptableSizeFastMatrixParameter cutoffs = xo.hasChildNamed(CUTOFFS)
            ? (AdaptableSizeFastMatrixParameter) xo.getChild(CUTOFFS).getChild(AdaptableSizeFastMatrixParameter.class)
            : null;
    final AdaptableSizeFastMatrixParameter loadingsSparsity = xo.hasChildNamed(LOADINGS_SPARSITY)
            ? (AdaptableSizeFastMatrixParameter) xo.getChild(LOADINGS_SPARSITY).getChild(AdaptableSizeFastMatrixParameter.class)
            : null;

    // priors, models and the optional negation operator
    final DeterminentalPointProcessPrior sparsityPrior = xo.getChild(SPARSITY_PRIOR) != null
            ? (DeterminentalPointProcessPrior) xo.getChild(SPARSITY_PRIOR).getChild(DeterminentalPointProcessPrior.class)
            : null;
    final SimpleMCMCOperator negationOperator = xo.getChild(NEGATION_OPERATOR) != null
            ? (SimpleMCMCOperator) xo.getChild(NEGATION_OPERATOR).getChild(SimpleMCMCOperator.class)
            : null;
    final AbstractModelLikelihood factorModel = (AbstractModelLikelihood) xo.getChild(AbstractModelLikelihood.class);
    final RowDimensionPoissonPrior rowPrior =
            (RowDimensionPoissonPrior) xo.getChild(ROW_PRIOR).getChild(RowDimensionPoissonPrior.class);
    final Likelihood loadingsPrior = xo.hasChildNamed(LOADINGS_PRIOR)
            ? (Likelihood) xo.getChild(LOADINGS_PRIOR).getChild(Likelihood.class)
            : null;

    // nested operators
    final BitFlipOperator sparsityOperator = xo.getChild(BitFlipOperator.class) != null
            ? (BitFlipOperator) xo.getChild(BitFlipOperator.class)
            : null;
    final SimpleMCMCOperator loadingsOperator =
            (SimpleMCMCOperator) xo.getChild(LOADINGS_OPERATOR).getChild(SimpleMCMCOperator.class);
    // the factor operator is looked up by the FactorTreeGibbsOperator class specifically
    final SimpleMCMCOperator factorOperator = xo.getChild(FACTOR_OPERATOR) != null
            ? (SimpleMCMCOperator) xo.getChild(FACTOR_OPERATOR).getChild(FactorTreeGibbsOperator.class)
            : null;

    return new FactorRJMCMCOperator(weight, sizeParameter, chainLength, factors, loadings, cutoffs,
            loadingsSparsity, factorModel, sparsityPrior, loadingsPrior, loadingsOperator, factorOperator,
            sparsityOperator, negationOperator, rowPrior);
}
use of dr.inference.model.Likelihood in project beast-mcmc by beast-dev.
the class MLOptimizerParser method parseXMLObject.
/**
 * Builds an {@code MLOptimizer} from the children of {@code xo}.
 * <p>
 * Children are sorted by type: a likelihood, an operator schedule, or any number of loggers.
 * When several likelihoods or schedules are present the last one wins. If none is present,
 * {@code null} is handed to the optimizer.
 *
 * @param xo the optimizer XML element
 * @return the new {@code MLOptimizer}
 * @throws XMLParseException if the chain length attribute cannot be read or a child is of an
 *         unsupported type
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final int chainLength = xo.getIntegerAttribute(CHAIN_LENGTH);

    dr.inference.model.Likelihood target = null;
    OperatorSchedule operators = null;
    final ArrayList<Logger> collected = new ArrayList<Logger>();

    for (int index = 0; index < xo.getChildCount(); index++) {
        final Object element = xo.getChild(index);

        if (element instanceof dr.inference.model.Likelihood) {
            target = (dr.inference.model.Likelihood) element;
            continue;
        }
        if (element instanceof OperatorSchedule) {
            operators = (OperatorSchedule) element;
            continue;
        }
        if (element instanceof Logger) {
            collected.add((Logger) element);
            continue;
        }

        // anything else is a configuration error
        throw new XMLParseException("Unrecognized element found in optimizer element:" + element);
    }

    final Logger[] loggerArray = collected.toArray(new Logger[collected.size()]);
    return new MLOptimizer("optimizer1", chainLength, target, operators, loggerArray);
}
use of dr.inference.model.Likelihood in project beast-mcmc by beast-dev.
the class LognormalPriorTest method testLognormalPrior.
/**
 * Runs an MCMC chain on a single population-size parameter whose prior is built from
 * {@code LogNormalDistribution(1.0, 1.0)}, then checks the sampled expectation against
 * {@code e^(1.5)}.
 */
public void testLognormalPrior() {
    // the sampled parameter, starting at 6.0, wrapped in a constant population model
    final Parameter sizeParameter = new Parameter.Default(6.0);
    sizeParameter.setId(ConstantPopulationModelParser.POPULATION_SIZE);
    final ConstantPopulationModel constantModel = new ConstantPopulationModel(sizeParameter, Units.Type.YEARS);

    // likelihood term: a DummyLikelihood around the model
    final Likelihood modelLikelihood = new DummyLikelihood(constantModel);

    // a single scale operator acts on the parameter
    final OperatorSchedule moves = new SimpleOperatorSchedule();
    final MCMCOperator scaleMove = new ScaleOperator(sizeParameter, 0.75);
    scaleMove.setWeight(1.0);
    moves.addOperator(scaleMove);

    // one logger keeps the trace in memory for the analysis below, the other writes to stdout
    final ArrayLogFormatter traceFormatter = new ArrayLogFormatter(false);
    final MCLogger traceLogger = new MCLogger(traceFormatter, 1000, false);
    traceLogger.add(sizeParameter);
    final MCLogger screenLogger = new MCLogger(new TabDelimitedFormatter(System.out), 100000, false);
    screenLogger.add(sizeParameter);
    final MCLogger[] sinks = { traceLogger, screenLogger };

    final MCMC chain = new MCMC("mcmc1");
    final MCMCOptions chainOptions = new MCMCOptions(1000000);

    // prior on the parameter (corresponds to meanInRealSpace="false")
    final DistributionLikelihood lognormalTerm = new DistributionLikelihood(new LogNormalDistribution(1.0, 1.0), 0);
    lognormalTerm.addData(sizeParameter);

    // posterior = compound(prior, likelihood); the same working list is refilled for each level
    final List<Likelihood> parts = new ArrayList<Likelihood>();
    parts.add(lognormalTerm);
    final Likelihood prior = new CompoundLikelihood(0, parts);

    parts.clear();
    parts.add(modelLikelihood);
    final Likelihood likelihood = new CompoundLikelihood(-1, parts);

    parts.clear();
    parts.add(prior);
    parts.add(likelihood);
    final Likelihood posterior = new CompoundLikelihood(0, parts);

    chain.setShowOperatorAnalysis(true);
    chain.init(chainOptions, posterior, moves, sinks);
    chain.run();

    // report the run time
    System.out.println(chain.getTimer().toString());

    // analyse every recorded trace except the one at index 0
    final List<Trace> recorded = traceFormatter.getTraces();
    final ArrayTraceList analysis = new ArrayTraceList("LognormalPriorTest", recorded, 0);
    int traceIndex = 1;
    while (traceIndex < recorded.size()) {
        analysis.analyseTrace(traceIndex);
        traceIndex++;
    }

    // expected value of the parameter: e^(1.5) = 4.48168907...
    final int sizeTrace = analysis.getTraceIndex(ConstantPopulationModelParser.POPULATION_SIZE);
    final TraceCorrelation sizeStats = analysis.getCorrelationStatistics(sizeTrace);
    System.out.println("Expectation of Log-Normal(1,1) is e^(M+S^2/2) = e^(1.5) = " + Math.exp(1.5));
    assertExpectation(ConstantPopulationModelParser.POPULATION_SIZE, sizeStats, Math.exp(1.5));
}
Aggregations