use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.
the class MarkovJumpsBeagleTreeLikelihood method addRegister.
/**
 * Registers a Markov-jump register parameter and publishes the tree traits that report it.
 * <p>
 * For every substitution model held by {@code substitutionModelDelegate} a new
 * Markov-jumps model is created, added to this likelihood, and assigned the next
 * register index ({@code numRegisters}, incremented once per model).
 *
 * @param addRegisterParameter register parameter; must have dimension
 *                             {@code stateCount * stateCount} for {@code COUNTS} and
 *                             {@code stateCount} for {@code REWARDS}
 * @param type                 kind of quantity to track; {@code HISTORY} requires
 *                             uniformization, at most one simulant, and may only be
 *                             registered once per likelihood
 * @param scaleByTime          value appended to {@code this.scaleByTime} for each register created
 * @throws RuntimeException if the parameter has the wrong dimension, or if a
 *                          {@code HISTORY} register is requested under unsupported conditions
 */
public void addRegister(Parameter addRegisterParameter, MarkovJumpsType type, boolean scaleByTime) {
    if ((type == MarkovJumpsType.COUNTS && addRegisterParameter.getDimension() != stateCount * stateCount) || (type == MarkovJumpsType.REWARDS && addRegisterParameter.getDimension() != stateCount)) {
        throw new RuntimeException("Register parameter of wrong dimension");
    }
    addVariable(addRegisterParameter);
    final String tag = addRegisterParameter.getId();
    for (int i = 0; i < substitutionModelDelegate.getSubstitutionModelCount(); ++i) {
        registerParameter.add(addRegisterParameter);
        MarkovJumpsSubstitutionModel mjModel;
        SubstitutionModel substitutionModel = substitutionModelDelegate.getSubstitutionModel(i);
        if (useUniformization) {
            mjModel = new UniformizedSubstitutionModel(substitutionModel, type, nSimulants);
        } else {
            if (type == MarkovJumpsType.HISTORY) {
                throw new RuntimeException("Can only report complete history using uniformization");
            }
            mjModel = new MarkovJumpsSubstitutionModel(substitutionModel, type);
        }
        markovjumps.add(mjModel);
        branchModelNumber.add(i);
        addModel(mjModel);
        setupRegistration(numRegisters);

        // With a single substitution model the trait is named after the parameter id;
        // otherwise the model index is appended to keep the names distinct.
        String traitName;
        if (substitutionModelDelegate.getSubstitutionModelCount() == 1) {
            traitName = tag;
        } else {
            traitName = tag + i;
        }
        jumpTag.add(traitName);
        expectedJumps.add(new double[treeModel.getNodeCount()][patternCount]);
        // storedExpectedJumps.add(new double[treeModel.getNodeCount()][patternCount]);

        // Grow this.scaleByTime by one slot and store the flag for the new register.
        boolean[] oldScaleByTime = this.scaleByTime;
        int oldScaleByTimeLength = (oldScaleByTime == null ? 0 : oldScaleByTime.length);
        this.scaleByTime = new boolean[oldScaleByTimeLength + 1];
        if (oldScaleByTimeLength > 0) {
            System.arraycopy(oldScaleByTime, 0, this.scaleByTime, 0, oldScaleByTimeLength);
        }
        this.scaleByTime[oldScaleByTimeLength] = scaleByTime;

        if (type != MarkovJumpsType.HISTORY) {
            // NOTE(review): getTraitName() reports the bare tag even when several models
            // exist, while the trait is registered under traitName + "_base" - confirm intended.
            TreeTrait.DA da = new TreeTrait.DA() {
                // Captures the register index that is current when the trait is created.
                final int registerNumber = numRegisters;

                public String getTraitName() {
                    return tag;
                }

                public Intent getIntent() {
                    return Intent.BRANCH;
                }

                public double[] getTrait(Tree tree, NodeRef node) {
                    return getMarkovJumpsForNodeAndRegister(tree, node, registerNumber);
                }
            };
            treeTraits.addTrait(traitName + "_base", da);
            // NOTE(review): keyed by the parameter id, which is the same on every loop
            // iteration - confirm how treeTraits treats a repeated key.
            treeTraits.addTrait(addRegisterParameter.getId(), new TreeTrait.SumAcrossArrayD(new TreeTrait.SumOverTreeDA(da)));
        } else {
            if (histories == null) {
                histories = new String[treeModel.getNodeCount()][patternCount];
            } else {
                throw new RuntimeException("Only one complete history per markovJumpTreeLikelihood is allowed");
            }
            if (nSimulants > 1) {
                throw new RuntimeException("Only one simulant allowed when saving complete history");
            }
            // Add total number of changes over tree trait
            TreeTrait da = new TreeTrait.DA() {
                // Captures the register index that is current when the trait is created.
                final int registerNumber = numRegisters;

                public String getTraitName() {
                    return tag;
                }

                public Intent getIntent() {
                    return Intent.BRANCH;
                }

                public double[] getTrait(Tree tree, NodeRef node) {
                    return getMarkovJumpsForNodeAndRegister(tree, node, registerNumber);
                }
            };
            treeTraits.addTrait(addRegisterParameter.getId(), new TreeTrait.SumOverTreeDA(da));
            // Record the complete history for this register
            historyRegisterNumber = numRegisters;
            ((UniformizedSubstitutionModel) mjModel).setSaveCompleteHistory(true);

            if (useCompactHistory && logHistory) {
                // Single array-valued trait that gathers the events of all sites on a branch.
                treeTraits.addTrait(ALL_HISTORY, new TreeTrait.SA() {
                    public String getTraitName() {
                        return ALL_HISTORY;
                    }

                    public Intent getIntent() {
                        return Intent.BRANCH;
                    }

                    public boolean getFormatAsArray() {
                        return true;
                    }

                    public String[] getTrait(Tree tree, NodeRef node) {
                        List<String> events = new ArrayList<String>();
                        for (int site = 0; site < patternCount; site++) {
                            String eventString = getHistoryForNode(tree, node, site);
                            // Skip sites with no history or with the empty history "{}".
                            if (eventString != null && !"{}".equals(eventString)) {
                                // Drop the first and last character of the history string.
                                eventString = eventString.substring(1, eventString.length() - 1);
                                if (eventString.contains("},{")) {
                                    // There are multiple events
                                    String[] elements = eventString.split("(?<=\\}),(?=\\{)");
                                    for (String e : elements) {
                                        events.add(e);
                                    }
                                } else {
                                    events.add(eventString);
                                }
                            }
                        }
                        String[] array = new String[events.size()];
                        events.toArray(array);
                        return array;
                    }

                    public boolean getLoggable() {
                        return true;
                    }
                });
            }

            // One string-valued trait per site; only loggable when the compact form is not used.
            for (int site = 0; site < patternCount; ++site) {
                final String anonName = (patternCount == 1) ? HISTORY : HISTORY + "_" + (site + 1);
                final int anonSite = site;
                treeTraits.addTrait(anonName, new TreeTrait.S() {
                    public String getTraitName() {
                        return anonName;
                    }

                    public Intent getIntent() {
                        return Intent.BRANCH;
                    }

                    public String getTrait(Tree tree, NodeRef node) {
                        String history = getHistoryForNode(tree, node, anonSite);
                        // Return null if empty; a missing (null) history is treated the same
                        // way, matching the null guard used by the compact-history trait.
                        return (history == null || "{}".equals(history)) ? null : history;
                    }

                    public boolean getLoggable() {
                        return logHistory && !useCompactHistory;
                    }
                });
            }
        }
        numRegisters++;
    }
    // End of loop over branch models
}
use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.
the class DiscreteTraitBranchRateModelParser method parseXMLObject.
/**
 * Builds a {@code DiscreteTraitBranchRateModel} from the children and attributes of {@code xo}.
 * <p>
 * Without a {@code TreeTraitProvider} child the parsimony-based constructor is used.
 * With a provider, the trait named "states" is looked up and passed on, together with
 * the relative-rates and indicator parameters when relative rates are supplied.
 *
 * @param xo the XML element being parsed
 * @return the constructed branch rate model
 * @throws XMLParseException if the provider does not supply a required trait
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
    final PatternList patterns = (PatternList) xo.getChild(PatternList.class);
    final TreeTraitProvider provider = (TreeTraitProvider) xo.getChild(TreeTraitProvider.class);
    final DataType dataType = DataTypeUtils.getDataType(xo);

    // RATES is consulted after RATE, so it takes precedence when both are present.
    Parameter rates = null;
    if (xo.getChild(RATE) != null) {
        rates = (Parameter) xo.getElementFirstChild(RATE);
    }
    if (xo.getChild(RATES) != null) {
        rates = (Parameter) xo.getElementFirstChild(RATES);
    }
    final Parameter relativeRates = (xo.getChild(RELATIVE_RATES) != null)
            ? (Parameter) xo.getElementFirstChild(RELATIVE_RATES)
            : null;
    final Parameter indicators = (xo.getChild(INDICATORS) != null)
            ? (Parameter) xo.getElementFirstChild(INDICATORS)
            : null;

    // The attribute defaults to 1; one is subtracted before it is handed to the model.
    final int traitIndex = xo.getAttribute(TRAIT_INDEX, 1) - 1;
    // Fixed here, so the per-state branch further down is not reached as written.
    String traitName = "states";

    Logger.getLogger("dr.evomodel").info("Using discrete trait branch rate model.\n"
            + "\tIf you use this model, please cite:\n"
            + "\t\tDrummond and Suchard (in preparation)");

    if (provider == null) {
        // Use the version that reconstructs the trait using parsimony:
        return new DiscreteTraitBranchRateModel(tree, patterns, traitIndex, rates);
    }

    if (traitName == null) {
        // One trait per state of the data type, each looked up by the state's code.
        final TreeTrait[] perStateTraits = new TreeTrait[dataType.getStateCount()];
        for (int state = 0; state < dataType.getStateCount(); state++) {
            perStateTraits[state] = provider.getTreeTrait(dataType.getCode(state));
            if (perStateTraits[state] == null) {
                throw new XMLParseException("A trait called, " + dataType.getCode(state) + ", was not available from the TreeTraitProvider supplied to " + getParserName() + ", with ID " + xo.getId());
            }
        }
        return new DiscreteTraitBranchRateModel(provider, perStateTraits, tree, rates);
    }

    final TreeTrait namedTrait = provider.getTreeTrait(traitName);
    if (namedTrait == null) {
        throw new XMLParseException("A trait called, " + traitName + ", was not available from the TreeTraitProvider supplied to " + getParserName() + ", with ID " + xo.getId());
    }
    if (relativeRates == null) {
        return new DiscreteTraitBranchRateModel(provider, dataType, tree, namedTrait, traitIndex, rates);
    }
    return new DiscreteTraitBranchRateModel(provider, dataType, tree, namedTrait, traitIndex, rates, relativeRates, indicators);
}
use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.
the class CodonPartitionedRobustCounting method setupTraits.
/**
 * Creates the tree traits exposed by this object, registers them with {@code treeTraits},
 * and sets up {@code treeTraitLogger} for the whole-tree totals.
 * <p>
 * The complete-history trait is only created when {@code saveCompleteHistory} is set, and
 * the per-branch unconditioned traits only when {@code doUnconditionedPerBranch} is set.
 */
private void setupTraits() {
    // Per-branch array of expected counts; the site- and tree-summed traits below wrap it.
    final TreeTrait perBranchCounts = new TreeTrait.DA() {
        public String getTraitName() {
            return BASE_TRAIT_PREFIX + codonLabeling.getText();
        }

        public Intent getIntent() {
            return Intent.BRANCH;
        }

        public double[] getTrait(Tree tree, NodeRef node) {
            return getExpectedCountsForBranch(node);
        }

        public boolean getLoggable() {
            return false;
        }
    };

    if (saveCompleteHistory) {
        final TreeTrait historyTrait = new TreeTrait.SA() {
            public String getTraitName() {
                return COMPLETE_HISTORY_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.BRANCH;
            }

            public boolean getFormatAsArray() {
                return true;
            }

            public String[] getTrait(Tree tree, NodeRef node) {
                // Lazy simulation of complete histories
                final double[] count = getExpectedCountsForBranch(node);

                // Flatten the per-codon history strings of this node into one event list.
                final List<String> events = new ArrayList<String>();
                for (int codon = 0; codon < numCodons; codon++) {
                    final String eventString = completeHistoryPerNode[node.getNumber()][codon];
                    if (eventString == null) {
                        continue;
                    }
                    if (eventString.contains("},{")) {
                        // There are multiple events
                        for (String piece : eventString.split("(?<=\\}),(?=\\{)")) {
                            events.add(piece);
                        }
                    } else {
                        events.add(eventString);
                    }
                }

                if (DEBUG) {
                    // Compare the number of events with the number of strictly positive counts.
                    double sum = 0.0;
                    for (int k = 0; k < count.length; k++) {
                        if (count[k] > 0.0) {
                            sum += 1;
                        }
                    }
                    System.err.println(events.size() + " " + sum);
                    if (Math.abs(events.size() - sum) > 0.5) {
                        System.err.println("Error");
                        for (int k = 0; k < count.length; ++k) {
                            if (count[k] != 0.0) {
                                System.err.println(k + ": " + count[k] + completeHistoryPerNode[node.getNumber()][k]);
                            }
                        }
                        System.err.println("Error");
                        for (int k = 0; k < events.size(); k++) {
                            System.err.println((k + 1) + ":" + events.get(k));
                        }
                        System.exit(-1);
                    }
                }

                return events.toArray(new String[events.size()]);
            }

            public boolean getLoggable() {
                return true;
            }
        };
        treeTraits.addTrait(historyTrait);
    }

    // Whole-tree unconditioned trait: array-valued when TRIAL is set, a single value otherwise.
    final TreeTrait unconditionedTrait;
    if (TRIAL) {
        unconditionedTrait = new TreeTrait.DA() {
            public String getTraitName() {
                return UNCONDITIONED_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.WHOLE_TREE;
            }

            public double[] getTrait(Tree tree, NodeRef node) {
                return getUnconditionedTraitValues();
            }

            public boolean getLoggable() {
                return false;
            }
        };
    } else {
        unconditionedTrait = new TreeTrait.D() {
            public String getTraitName() {
                return UNCONDITIONED_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.WHOLE_TREE;
            }

            public Double getTrait(Tree tree, NodeRef node) {
                return getUnconditionedTraitValue();
            }

            public boolean getLoggable() {
                return false;
            }
        };
    }

    final TreeTrait siteSpecificTrait = new TreeTrait.SumOverTreeDA(SITE_SPECIFIC_PREFIX + codonLabeling.getText(), perBranchCounts, includeExternalBranches, includeInternalBranches) {
        @Override
        public boolean getLoggable() {
            return false;
        }
    };

    // This should be the default output in tree logs
    final TreeTrait perBranchTotalTrait = new TreeTrait.SumAcrossArrayD(codonLabeling.getText(), perBranchCounts) {
        @Override
        public boolean getLoggable() {
            return true;
        }
    };

    // This should be the default output in columns logs
    final String totalName;
    if (prefix != null) {
        totalName = prefix + TOTAL_PREFIX + codonLabeling.getText();
    } else {
        totalName = TOTAL_PREFIX + codonLabeling.getText();
    }
    final TreeTrait grandTotalTrait = new TreeTrait.SumOverTreeD(totalName, perBranchTotalTrait, includeExternalBranches, includeInternalBranches) {
        @Override
        public boolean getLoggable() {
            return true;
        }
    };

    treeTraitLogger = new TreeTraitLogger(tree, new TreeTrait[] { grandTotalTrait });

    treeTraits.addTrait(perBranchCounts);
    treeTraits.addTrait(unconditionedTrait);
    treeTraits.addTrait(perBranchTotalTrait);
    treeTraits.addTrait(siteSpecificTrait);
    treeTraits.addTrait(grandTotalTrait);

    if (doUnconditionedPerBranch) {
        final TreeTrait unconditionedPerBranch = new TreeTrait.DA() {
            public String getTraitName() {
                return UNCONDITIONED_PER_BRANCH_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.BRANCH;
            }

            public double[] getTrait(Tree tree, NodeRef node) {
                return getUnconditionalCountsForBranch(node);
            }

            public boolean getLoggable() {
                // TODO Should be switched to true to log unconditioned values per branch
                return false;
            }
        };

        final TreeTrait unconditionedBranchTotal = new TreeTrait.SumAcrossArrayD(UNCONDITIONED_PER_BRANCH_PREFIX + codonLabeling.getText(), unconditionedPerBranch) {
            @Override
            public boolean getLoggable() {
                return true;
            }
        };

        final String unconditionedTotalName;
        if (prefix != null) {
            unconditionedTotalName = prefix + UNCONDITIONED_TOTAL_PREFIX + codonLabeling.getText();
        } else {
            unconditionedTotalName = UNCONDITIONED_TOTAL_PREFIX + codonLabeling.getText();
        }
        final TreeTrait unconditionedGrandTotal = new TreeTrait.SumOverTreeD(unconditionedTotalName, unconditionedBranchTotal, includeExternalBranches, includeInternalBranches) {
            @Override
            public boolean getLoggable() {
                return true;
            }
        };

        // Replaces the logger created above so that both totals are written as columns.
        treeTraitLogger = new TreeTraitLogger(tree, new TreeTrait[] { grandTotalTrait, unconditionedGrandTotal });

        treeTraits.addTrait(unconditionedPerBranch);
        treeTraits.addTrait(unconditionedBranchTotal);
    }
}
use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.
the class TreeTraitLogger method getColumns.
/**
 * Exposes each loggable tree trait as one log column.
 *
 * @return one column per entry of {@code loggableTreeTraits}, labelled with the trait's
 *         name, or {@code null} when there are no loggable traits
 */
public LogColumn[] getColumns() {
    final int traitCount = loggableTreeTraits.size();
    if (traitCount == 0) {
        return null;
    }

    final LogColumn[] result = new LogColumn[traitCount];
    for (int index = 0; index < traitCount; index++) {
        final TreeTrait current = loggableTreeTraits.get(index);
        result[index] = new LogColumn.Abstract(current.getTraitName()) {
            @Override
            protected String getFormattedValue() {
                // The trait is asked for its string form with a null node on every call.
                return current.getTraitString(tree, null);
            }
        };
    }
    return result;
}
use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.
the class AncestralTraitParser method parseXMLObject.
/**
 * Builds an {@code AncestralTrait} from the children and attributes of {@code xo}.
 * <p>
 * The trait name defaults to {@code STATES}, and the statistic's name defaults to the
 * trait name. When an {@code MRCA} child is present its first child is used as the
 * taxon list; otherwise {@code null} is passed on.
 *
 * @param xo the XML element being parsed
 * @return the constructed {@code AncestralTrait}
 * @throws XMLParseException if the provider has no trait of the requested name, or if
 *                           constructing the statistic reports a missing taxon
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, STATES);
    String name = xo.getAttribute(NAME, traitName);
    Tree tree = (Tree) xo.getChild(Tree.class);
    TreeTraitProvider treeTraitProvider = (TreeTraitProvider) xo.getChild(TreeTraitProvider.class);

    TaxonList taxa = null;
    if (xo.hasChildNamed(MRCA)) {
        taxa = (TaxonList) xo.getElementFirstChild(MRCA);
    }

    TreeTrait trait = treeTraitProvider.getTreeTrait(traitName);
    if (trait == null) {
        throw new XMLParseException("A trait called, " + traitName + ", was not available from the TreeTraitProvider supplied to " + getParserName() + (xo.hasId() ? ", with ID " + xo.getId() : ""));
    }

    try {
        return new AncestralTrait(name, trait, tree, taxa);
    } catch (TreeUtils.MissingTaxonException mte) {
        // A space now separates the parser name from "was not found"; previously the
        // two ran together in the message.
        throw new XMLParseException("Taxon, " + mte + ", in " + getParserName() + " was not found in the tree.");
    }
}
Aggregations