Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.
Class NodeHeightGradientParser, method parseXMLObject.
/**
 * Builds a node-height gradient for a discrete trait from the given XML element.
 * <p>
 * The element must contain a {@code TreeDataLikelihood} whose branch rate model is either a
 * {@code DefaultBranchRateModel} or an {@code ArbitraryBranchRates}. The rate parameter is only
 * extracted for {@code ArbitraryBranchRates}; otherwise {@code null} is handed to the gradient.
 *
 * @param xo the XML element being parsed
 * @return a {@code NodeHeightGradientForDiscreteTrait} for a BEAGLE-backed likelihood
 * @throws XMLParseException if the branch rate model is of another type, if the likelihood
 *                           delegate is a continuous-trait delegate (not implemented), or if the
 *                           delegate type is not recognised
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    final TreeDataLikelihood treeDataLikelihood =
            (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);

    // Reject unsupported branch rate models up front.
    final BranchRateModel rateModel = treeDataLikelihood.getBranchRateModel();
    final boolean hasArbitraryRates = rateModel instanceof ArbitraryBranchRates;
    if (!hasArbitraryRates && !(rateModel instanceof DefaultBranchRateModel)) {
        throw new XMLParseException("Only implemented for an arbitrary rates model");
    }

    // Only an arbitrary rates model exposes a rate parameter; the default model passes null.
    final Parameter branchRates = hasArbitraryRates
            ? ((ArbitraryBranchRates) rateModel).getRateParameter()
            : null;

    final DataLikelihoodDelegate likelihoodDelegate = treeDataLikelihood.getDataLikelihoodDelegate();
    if (likelihoodDelegate instanceof ContinuousDataLikelihoodDelegate) {
        throw new XMLParseException("Not yet implemented! ");
    }
    if (!(likelihoodDelegate instanceof BeagleDataLikelihoodDelegate)) {
        throw new XMLParseException("Unknown likelihood delegate type");
    }

    return new NodeHeightGradientForDiscreteTrait(
            traitName,
            treeDataLikelihood,
            (BeagleDataLikelihoodDelegate) likelihoodDelegate,
            branchRates);
}
Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.
Class PrecisionGradientParser, method parseXMLObject.
/**
 * Builds a precision-matrix gradient from the given XML element.
 * <p>
 * If the element supplies a {@code ConjugateWishartStatisticsProvider}, the gradient is derived
 * from those Wishart statistics. Otherwise a branch-specific gradient is assembled from the
 * continuous-trait likelihood delegate: a sampling-variance gradient when a
 * {@code NormalExtensionProvider} child is present, or a process-parameter gradient with respect
 * to the variance when it is not.
 *
 * @param xo the XML element being parsed
 * @return the object produced by the parsed {@code ParameterMode} factory
 * @throws XMLParseException if parsing fails, or if no Wishart statistics are supplied and the
 *                           tree data likelihood is not backed by a
 *                           {@code ContinuousDataLikelihoodDelegate}
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    MatrixParameterInterface parameter = (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class);
    TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);
    ConjugateWishartStatisticsProvider wishartStatistics =
            (ConjugateWishartStatisticsProvider) xo.getChild(ConjugateWishartStatisticsProvider.class);

    GradientWrtPrecisionProvider gradientWrtPrecisionProvider;
    if (wishartStatistics != null) {
        gradientWrtPrecisionProvider =
                new GradientWrtPrecisionProvider.WishartGradientWrtPrecisionProvider(wishartStatistics);
    } else {
        DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
        // Fail with a parse error rather than a raw ClassCastException when the likelihood
        // is not a continuous-trait likelihood.
        if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
            throw new XMLParseException("Precision gradient requires a continuous trait data likelihood");
        }
        ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate;
        int dim = delegate.getTraitDim();
        Tree tree = treeDataLikelihood.getTree();

        ModelExtensionProvider.NormalExtensionProvider extensionProvider =
                (ModelExtensionProvider.NormalExtensionProvider) xo.getChild(ModelExtensionProvider.NormalExtensionProvider.class);
        ContinuousTraitGradientForBranch traitGradient;
        if (extensionProvider != null) {
            traitGradient = new ContinuousTraitGradientForBranch.SamplingVarianceGradient(
                    dim, tree, continuousData, extensionProvider);
        } else {
            traitGradient = new ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient(
                    dim, tree, continuousData,
                    new ArrayList<>(Arrays.asList(
                            ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter.WRT_VARIANCE)));
        }
        BranchSpecificGradient branchSpecificGradient =
                new BranchSpecificGradient(traitName, treeDataLikelihood, continuousData, traitGradient, parameter);
        gradientWrtPrecisionProvider =
                new GradientWrtPrecisionProvider.BranchSpecificGradientWrtPrecisionProvider(branchSpecificGradient);
    }

    ParameterMode parameterMode = parseParameterMode(xo);
    return parameterMode.factory(gradientWrtPrecisionProvider, treeDataLikelihood, parameter);
}
Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.
Class AttenuationGradientParser, method parseXMLObject.
/**
 * Builds an attenuation gradient from the given XML element.
 * <p>
 * A process-parameter gradient with respect to the diagonal selection strength is created on the
 * continuous-trait likelihood delegate and wrapped in a {@code BranchSpecificGradient}.
 *
 * @param xo the XML element being parsed
 * @return the object produced by the parsed {@code ParameterMode} factory
 * @throws XMLParseException if parsing fails, or if the tree data likelihood is not backed by a
 *                           {@code ContinuousDataLikelihoodDelegate}
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    MatrixParameterInterface parameter = (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class);
    TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);

    DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
    // Fail with a parse error rather than a raw ClassCastException when the likelihood
    // is not a continuous-trait likelihood.
    if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
        throw new XMLParseException("Attenuation gradient requires a continuous trait data likelihood");
    }
    ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate;
    int dim = delegate.getTraitDim();
    Tree tree = treeDataLikelihood.getTree();

    ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient traitGradient =
            new ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient(
                    dim, tree, continuousData,
                    new ArrayList<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter>(
                            Arrays.asList(
                                    ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter.WRT_DIAGONAL_SELECTION_STRENGTH)));
    BranchSpecificGradient branchSpecificGradient =
            new BranchSpecificGradient(traitName, treeDataLikelihood, continuousData, traitGradient, parameter);

    ParameterMode parameterMode = parseParameterMode(xo);
    return parameterMode.factory(branchSpecificGradient, treeDataLikelihood, parameter);
}
Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.
Class FullyConjugateTreeTipsPotentialDerivativeParser, method parseXMLObject.
/**
 * Builds a tip-trait gradient from the given XML element.
 * <p>
 * A {@code FullyConjugateMultivariateTraitLikelihood} child takes precedence and yields a
 * {@code FullyConjugateTreeTipsPotentialDerivative}. Failing that, a {@code TreeDataLikelihood}
 * child backed by a continuous-trait delegate yields a {@code TreeTipGradient}. An optional
 * {@code MASKING} element supplies the mask parameter; without it the mask is {@code null}.
 *
 * @param xo the XML element being parsed
 * @return the gradient object matching the likelihood that was supplied
 * @throws XMLParseException if neither likelihood type is present, or if the tree data likelihood
 *                           is not backed by a {@code ContinuousDataLikelihoodDelegate}
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);

    final FullyConjugateMultivariateTraitLikelihood conjugateLikelihood =
            (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class);
    final TreeDataLikelihood treeDataLikelihood =
            (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);

    final Parameter mask = xo.hasChildNamed(MASKING)
            ? (Parameter) xo.getElementFirstChild(MASKING)
            : null;

    // The fully conjugate likelihood wins when both kinds are supplied.
    if (conjugateLikelihood != null) {
        return new FullyConjugateTreeTipsPotentialDerivative(conjugateLikelihood, mask);
    }

    if (treeDataLikelihood == null) {
        throw new XMLParseException("Must provide a tree likelihood");
    }

    final DataLikelihoodDelegate likelihoodDelegate = treeDataLikelihood.getDataLikelihoodDelegate();
    if (likelihoodDelegate instanceof ContinuousDataLikelihoodDelegate) {
        return new TreeTipGradient(
                traitName,
                treeDataLikelihood,
                (ContinuousDataLikelihoodDelegate) likelihoodDelegate,
                mask);
    }
    throw new XMLParseException("May not provide a sequence data likelihood to compute tip trait gradient");
}
Use of dr.evomodel.treedatalikelihood.DataLikelihoodDelegate in project beast-mcmc by beast-dev.
Class TreeDataLikelihoodParser, method createTreeDataLikelihood.
/**
 * Creates the likelihood for one or more data partitions evaluated on a single tree.
 * <p>
 * With more than one partition, and when the multi-partition extensions are enabled and the
 * "java.only" system property is not true, a single {@code TreeDataLikelihood} backed by a
 * {@code MultiPartitionDataLikelihoodDelegate} is tried first. If that path is not taken, or the
 * delegate constructor throws {@code DelegateTypeException}, one {@code TreeDataLikelihood} backed
 * by a {@code BeagleDataLikelihoodDelegate} is built per partition instead.
 * <p>
 * Side effect: when the BEAGLE_THREAD_COUNT system property is unset (or parses to -1) and
 * THREAD_COUNT holds a non-negative value, this method writes BEAGLE_THREAD_COUNT before
 * constructing the delegates.
 *
 * @param patternLists                 one pattern list per partition; only the first is compared
 *                                     against the tree's taxa
 * @param branchModels                 branch models, indexed in parallel with {@code patternLists}
 * @param siteRateModels               site rate models, indexed in parallel with {@code patternLists}
 * @param treeModel                    the tree shared by all partitions
 * @param branchRateModel              branch rate model passed to every {@code TreeDataLikelihood}
 * @param tipStatesModel               must be {@code null}; any other value is rejected
 * @param useAmbiguities               forwarded to the delegate constructors
 * @param preferGPU                    forwarded only to {@code BeagleDataLikelihoodDelegate}
 * @param scalingScheme                forwarded to the delegate constructors
 * @param delayRescalingUntilUnderflow forwarded to the delegate constructors
 * @param settings                     forwarded only to {@code BeagleDataLikelihoodDelegate}
 * @return a single {@code TreeDataLikelihood} when exactly one is built, otherwise a
 *         {@code CompoundLikelihood} wrapping the per-partition likelihoods
 * @throws XMLParseException if a tip states model is supplied, or if the taxa of the tree and of
 *                           the first pattern list do not contain each other
 */
protected Likelihood createTreeDataLikelihood(List<PatternList> patternLists, List<BranchModel> branchModels, List<SiteRateModel> siteRateModels, Tree treeModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean useAmbiguities, boolean preferGPU, PartialsRescalingScheme scalingScheme, boolean delayRescalingUntilUnderflow, PreOrderSettings settings) throws XMLParseException {
if (tipStatesModel != null) {
throw new XMLParseException("Tip State Error models are not supported yet with TreeDataLikelihood");
}
// The tree and the first pattern list must contain each other's taxa (checked in both directions).
List<Taxon> treeTaxa = treeModel.asList();
List<Taxon> patternTaxa = patternLists.get(0).asList();
if (!patternTaxa.containsAll(treeTaxa)) {
throw new XMLParseException("TreeModel " + treeModel.getId() + " contains more taxa (" + treeModel.getExternalNodeCount() + ") than the partition pattern list (" + patternTaxa.size() + ").");
}
if (!treeTaxa.containsAll(patternTaxa)) {
throw new XMLParseException("TreeModel " + treeModel.getId() + " contains fewer taxa (" + treeModel.getExternalNodeCount() + ") than the partition pattern list (" + patternTaxa.size() + ").");
}
// The multi-partition delegate is only considered when there is more than one partition.
boolean useBeagle3MultiPartition = false;
if (patternLists.size() > 1) {
// will currently recommend true if using GPU, CUDA or OpenCL.
useBeagle3MultiPartition = MultiPartitionDataLikelihoodDelegate.IS_MULTI_PARTITION_RECOMMENDED();
// System properties override the recommendation; the later check wins when both are set.
if (System.getProperty("USE_BEAGLE3_EXTENSIONS") != null) {
useBeagle3MultiPartition = Boolean.parseBoolean(System.getProperty("USE_BEAGLE3_EXTENSIONS"));
}
// A value of "auto" leaves the decision made so far untouched.
if (System.getProperty("beagle.multipartition.extensions") != null && !System.getProperty("beagle.multipartition.extensions").equals("auto")) {
useBeagle3MultiPartition = Boolean.parseBoolean(System.getProperty("beagle.multipartition.extensions"));
}
}
boolean useJava = Boolean.parseBoolean(System.getProperty("java.only", "false"));
// -1 marks "not specified" for both thread counts.
int threadCount = -1;
int beagleThreadCount = -1;
if (System.getProperty(BEAGLE_THREAD_COUNT) != null) {
beagleThreadCount = Integer.parseInt(System.getProperty(BEAGLE_THREAD_COUNT));
}
// THREAD_COUNT is only consulted when BEAGLE_THREAD_COUNT gave no value.
if (beagleThreadCount == -1) {
if (System.getProperty(THREAD_COUNT) != null) {
threadCount = Integer.parseInt(System.getProperty(THREAD_COUNT));
}
}
if (useBeagle3MultiPartition && !useJava) {
// A single multi-partition delegate gets the whole thread count.
if (beagleThreadCount == -1 && threadCount >= 0) {
System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount));
}
try {
DataLikelihoodDelegate dataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, patternLists, branchModels, siteRateModels, useAmbiguities, scalingScheme, delayRescalingUntilUnderflow);
return new TreeDataLikelihood(dataLikelihoodDelegate, treeModel, branchRateModel);
} catch (DataLikelihoodDelegate.DelegateTypeException dte) {
// Fall through to the per-partition path below; the flag is not read again after this point.
useBeagle3MultiPartition = false;
}
}
// The multipartition data likelihood isn't available so make a set of single partition data likelihoods
List<Likelihood> treeDataLikelihoods = new ArrayList<Likelihood>();
// Todo: allow for different number of threads per beagle instance according to pattern counts
// The thread count is split evenly across partitions using integer division; this overwrites
// any value written to BEAGLE_THREAD_COUNT in the multi-partition branch above.
if (beagleThreadCount == -1 && threadCount >= 0) {
System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount / patternLists.size()));
}
// One delegate and one TreeDataLikelihood per partition, all sharing the tree and branch rate model.
for (int i = 0; i < patternLists.size(); i++) {
DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate(treeModel, patternLists.get(i), branchModels.get(i), siteRateModels.get(i), useAmbiguities, preferGPU, scalingScheme, delayRescalingUntilUnderflow, settings);
treeDataLikelihoods.add(new TreeDataLikelihood(dataLikelihoodDelegate, treeModel, branchRateModel));
}
// A single partition is returned directly rather than wrapped in a compound likelihood.
if (treeDataLikelihoods.size() == 1) {
return treeDataLikelihoods.get(0);
}
return new CompoundLikelihood(treeDataLikelihoods);
}
Aggregations