Search in sources :

Example 11 with TreeTraitProvider

use of dr.evolution.tree.TreeTraitProvider in project beast-mcmc by beast-dev.

In class ContinuousDataLikelihoodDelegate, method addBranchConditionalDensityTrait.

/**
 * Builds a {@code BranchConditionalDistributionDelegate} for the given trait name, wraps it in
 * a {@code ProcessSimulation} bound to the callback likelihood, and adds the tree traits that
 * simulation exposes to the callback likelihood.
 *
 * @param traitName trait name handed to the {@code BranchConditionalDistributionDelegate}
 */
void addBranchConditionalDensityTrait(String traitName) {
    final ProcessSimulationDelegate conditionalDensityDelegate =
            new BranchConditionalDistributionDelegate(
                    traitName,
                    getCallbackLikelihood().getTree(),
                    getDiffusionModel(),
                    getDataModel(),
                    getRootPrior(),
                    getRateTransformation(),
                    this);

    // The simulation is only needed as a source of tree traits.
    final TreeTraitProvider simulation =
            new ProcessSimulation(getCallbackLikelihood(), conditionalDensityDelegate);

    getCallbackLikelihood().addTraits(simulation.getTreeTraits());
}
Also used : TreeTraitProvider(dr.evolution.tree.TreeTraitProvider)

Example 12 with TreeTraitProvider

use of dr.evolution.tree.TreeTraitProvider in project beast-mcmc by beast-dev.

In class AncestralTraitParser, method parseXMLObject.

/**
 * Builds an {@code AncestralTrait} from the given XML element.
 *
 * <p>Reads the trait name (attribute {@code TRAIT_NAME}, defaulting to {@code STATES}) and the
 * statistic name (attribute {@code NAME}, defaulting to the trait name), takes the {@code Tree}
 * and {@code TreeTraitProvider} children, and, when a child named {@code MRCA} is present, the
 * taxon list found under it. The trait is then looked up on the provider by name.
 *
 * @param xo the XML element being parsed
 * @return a new {@code AncestralTrait} for the named trait, tree and (possibly {@code null}) taxa
 * @throws XMLParseException if the provider has no trait with the requested name, or if
 *         constructing the {@code AncestralTrait} raises a {@code TreeUtils.MissingTaxonException}
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, STATES);
    String name = xo.getAttribute(NAME, traitName);
    Tree tree = (Tree) xo.getChild(Tree.class);
    TreeTraitProvider treeTraitProvider = (TreeTraitProvider) xo.getChild(TreeTraitProvider.class);
    // taxa stays null when no MRCA child is supplied; it is passed on to AncestralTrait as-is.
    TaxonList taxa = null;
    if (xo.hasChildNamed(MRCA)) {
        taxa = (TaxonList) xo.getElementFirstChild(MRCA);
    }
    TreeTrait trait = treeTraitProvider.getTreeTrait(traitName);
    if (trait == null) {
        throw new XMLParseException("A trait called, " + traitName + ", was not available from the TreeTraitProvider supplied to " + getParserName() + (xo.hasId() ? ", with ID " + xo.getId() : ""));
    }
    try {
        return new AncestralTrait(name, trait, tree, taxa);
    } catch (TreeUtils.MissingTaxonException mte) {
        // A space now separates the parser name from "was"; the message previously ran them together.
        throw new XMLParseException("Taxon, " + mte + ", in " + getParserName() + " was not found in the tree.");
    }
}
Also used : TreeTraitProvider(dr.evolution.tree.TreeTraitProvider) TaxonList(dr.evolution.util.TaxonList) Tree(dr.evolution.tree.Tree) AncestralTrait(dr.evomodel.tree.AncestralTrait) TreeTrait(dr.evolution.tree.TreeTrait) TreeUtils(dr.evolution.tree.TreeUtils)

Example 13 with TreeTraitProvider

use of dr.evolution.tree.TreeTraitProvider in project beast-mcmc by beast-dev.

In class DiffusionGradientTest, method testGradient.

/**
 * Builds a continuous-trait tree likelihood from the supplied model pieces and runs a gradient
 * check for each parameter that was provided.
 *
 * <p>{@code treeModel}, {@code rateModel}, {@code rateTransformation} and {@code delta} are not
 * declared in this method; they come from the enclosing scope. The method:
 * <ol>
 *   <li>creates a {@code ContinuousDataLikelihoodDelegate} and a {@code TreeDataLikelihood},
 *       and registers on the likelihood the tree traits of two {@code ProcessSimulation}s
 *       (a conditional-on-tips realized delegate chosen by precision type, and a
 *       full-conditional tip delegate);</li>
 *   <li>if {@code precision} is non-null, runs {@code testOneGradient} on the branch-specific
 *       correlation and diagonal precision gradients and, if {@code wishart} is true, asserts
 *       that they match the "analytic" values parsed from the reports of the Wishart-based
 *       providers, element by element within {@code delta};</li>
 *   <li>if {@code attenuation} is non-null, runs {@code testOneGradient} on the diagonal
 *       attenuation gradient;</li>
 *   <li>always runs {@code testOneGradient} on the root-mean gradient;</li>
 *   <li>if {@code drift} is non-null and is not the same object as {@code meanRoot}, runs
 *       {@code testOneGradient} on the drift gradient;</li>
 *   <li>if {@code samplingPrecision} is non-null, runs {@code testOneGradient} on the diagonal
 *       sampling-precision gradient.</li>
 * </ol>
 *
 * <p>NOTE(review): {@code testOneGradient} and {@code parseGradient} are defined elsewhere;
 * presumably the former compares an analytic gradient against a numeric one and returns the
 * analytic values - confirm against their definitions.
 *
 * @param diffusionModel diffusion model passed to the simulation delegates
 * @param diffusionProcessDelegate process delegate passed to the likelihood delegate
 * @param dataModel partials provider; cast to {@code ModelExtensionProvider.NormalExtensionProvider}
 *        when {@code samplingPrecision} is non-null
 * @param rootPrior root prior; the length of its mean is used as the trait dimension
 * @param meanRoot parameter the root-mean gradient is taken with respect to
 * @param precision precision matrix parameter, or {@code null} to skip the precision checks
 * @param wishart whether to also compare against the Wishart-statistics gradient; it is a boxed
 *        {@code Boolean} that is unboxed whenever {@code precision} is non-null, so it must not
 *        be {@code null} in that case
 * @param attenuation attenuation parameter, or {@code null} to skip the attenuation check
 * @param drift drift parameter, or {@code null} to skip the drift check; may be the same object
 *        as {@code meanRoot}
 * @param samplingPrecision sampling precision parameter, or {@code null} to skip that check
 */
private void testGradient(MultivariateDiffusionModel diffusionModel, DiffusionProcessDelegate diffusionProcessDelegate, ContinuousTraitPartialsProvider dataModel, ConjugateRootTraitPrior rootPrior, Parameter meanRoot, MatrixParameterInterface precision, Boolean wishart, MatrixParameterInterface attenuation, Parameter drift, MatrixParameterInterface samplingPrecision) {
    // Dimension handed to most of the gradient constructors below.
    int dimLocal = rootPrior.getMean().length;
    // CDL
    ContinuousDataLikelihoodDelegate likelihoodDelegate = new ContinuousDataLikelihoodDelegate(treeModel, diffusionProcessDelegate, dataModel, rootPrior, rateTransformation, rateModel, true);
    // Likelihood Computation
    TreeDataLikelihood dataLikelihood = new TreeDataLikelihood(likelihoodDelegate, treeModel, rateModel);
    // Scalar precision uses the plain realized delegate; any other precision type uses the multivariate one.
    ProcessSimulationDelegate simulationDelegate = likelihoodDelegate.getPrecisionType() == PrecisionType.SCALAR ? new ConditionalOnTipsRealizedDelegate("trait", treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate) : new MultivariateConditionalOnTipsRealizedDelegate("trait", treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate);
    TreeTraitProvider traitProvider = new ProcessSimulation(dataLikelihood, simulationDelegate);
    dataLikelihood.addTraits(traitProvider.getTreeTraits());
    // Second set of traits, from the full-conditional tip delegate.
    ProcessSimulationDelegate fullConditionalDelegate = new TipRealizedValuesViaFullConditionalDelegate("trait", treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate);
    dataLikelihood.addTraits(new ProcessSimulation(dataLikelihood, fullConditionalDelegate).getTreeTraits());
    // Variance
    // Delegate as reported by the likelihood, cast back to the continuous delegate type.
    ContinuousDataLikelihoodDelegate cdld = (ContinuousDataLikelihoodDelegate) dataLikelihood.getDataLikelihoodDelegate();
    if (precision != null) {
        // Branch Specific
        ContinuousProcessParameterGradient traitGradient = new ContinuousProcessParameterGradient(rootPrior.getMean().length, treeModel, cdld, new ArrayList<>(Arrays.asList(DerivationParameter.WRT_VARIANCE)));
        BranchSpecificGradient branchSpecificGradient = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradient, precision);
        GradientWrtPrecisionProvider gPPBranchSpecific = new GradientWrtPrecisionProvider.BranchSpecificGradientWrtPrecisionProvider(branchSpecificGradient);
        // Correlation Gradient Branch Specific
        CorrelationPrecisionGradient gradientProviderBranchSpecific = new CorrelationPrecisionGradient(gPPBranchSpecific, dataLikelihood, precision);
        double[] gradientAnalyticalBS = testOneGradient(gradientProviderBranchSpecific);
        // Diagonal Gradient Branch Specific
        DiagonalPrecisionGradient gradientDiagonalProviderBS = new DiagonalPrecisionGradient(gPPBranchSpecific, dataLikelihood, precision);
        double[] gradientDiagonalAnalyticalBS = testOneGradient(gradientDiagonalProviderBS);
        // Unboxes the Boolean argument; a null value here would throw a NullPointerException.
        if (wishart) {
            // Wishart Statistic
            WishartStatisticsWrapper wishartStatistics = new WishartStatisticsWrapper("wishart", "trait", dataLikelihood, cdld);
            GradientWrtPrecisionProvider gPPWiwhart = new GradientWrtPrecisionProvider.WishartGradientWrtPrecisionProvider(wishartStatistics);
            // Correlation Gradient
            CorrelationPrecisionGradient gradientProviderWishart = new CorrelationPrecisionGradient(gPPWiwhart, dataLikelihood, precision);
            // The Wishart-based values are obtained by parsing the provider's textual report.
            String sW = gradientProviderWishart.getReport();
            System.err.println(sW);
            double[] gradientAnalyticalW = parseGradient(sW, "analytic");
            // Wishart-based and branch-specific correlation gradients must agree within delta.
            assertEquals("Sizes", gradientAnalyticalW.length, gradientAnalyticalBS.length);
            for (int k = 0; k < gradientAnalyticalW.length; k++) {
                assertEquals("gradient correlation k=" + k, gradientAnalyticalW[k], gradientAnalyticalBS[k], delta);
            }
            // Diagonal Gradient
            DiagonalPrecisionGradient gradientDiagonalProviderW = new DiagonalPrecisionGradient(gPPWiwhart, dataLikelihood, precision);
            String sDiagW = gradientDiagonalProviderW.getReport();
            System.err.println(sDiagW);
            double[] gradientDiagonalAnalyticalW = parseGradient(sDiagW, "analytic");
            // Same agreement check for the diagonal gradients.
            assertEquals("Sizes", gradientDiagonalAnalyticalW.length, gradientDiagonalAnalyticalBS.length);
            for (int k = 0; k < gradientDiagonalAnalyticalW.length; k++) {
                assertEquals("gradient diagonal k=" + k, gradientDiagonalAnalyticalW[k], gradientDiagonalAnalyticalBS[k], delta);
            }
        }
    }
    // Diagonal Attenuation Gradient Branch Specific
    if (attenuation != null) {
        ContinuousProcessParameterGradient traitGradientAtt = new ContinuousProcessParameterGradient(dimLocal, treeModel, cdld, new ArrayList<>(Arrays.asList(DerivationParameter.WRT_DIAGONAL_SELECTION_STRENGTH)));
        BranchSpecificGradient branchSpecificGradientAtt = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientAtt, attenuation);
        AbstractDiffusionGradient.ParameterDiffusionGradient gABranchSpecific = createDiagonalAttenuationGradient(branchSpecificGradientAtt, dataLikelihood, attenuation);
        testOneGradient(gABranchSpecific);
    }
    // WRT root mean
    // Reference comparison: true only when drift and meanRoot are the same object.
    boolean sameRoot = (drift == meanRoot);
    ContinuousProcessParameterGradient traitGradientRoot = new ContinuousProcessParameterGradient(dimLocal, treeModel, cdld, new ArrayList<>(Arrays.asList(sameRoot ? DerivationParameter.WRT_CONSTANT_DRIFT_AND_ROOT_MEAN : DerivationParameter.WRT_ROOT_MEAN)));
    BranchSpecificGradient branchSpecificGradientRoot = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientRoot, meanRoot);
    AbstractDiffusionGradient.ParameterDiffusionGradient gRootBranchSpecific = createDriftGradient(branchSpecificGradientRoot, dataLikelihood, meanRoot);
    testOneGradient(gRootBranchSpecific);
    // Drift Gradient Branch Specific
    // Skipped when drift shares the meanRoot object, since the combined derivative was used above.
    if (drift != null && !sameRoot) {
        ContinuousProcessParameterGradient traitGradientDrift = new ContinuousProcessParameterGradient(dimLocal, treeModel, cdld, new ArrayList<>(Arrays.asList(DerivationParameter.WRT_CONSTANT_DRIFT)));
        BranchSpecificGradient branchSpecificGradientDrift = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientDrift, drift);
        AbstractDiffusionGradient.ParameterDiffusionGradient gDriftBranchSpecific = createDriftGradient(branchSpecificGradientDrift, dataLikelihood, drift);
        testOneGradient(gDriftBranchSpecific);
    }
    // Sampling Precision
    if (samplingPrecision != null) {
        // dataModel is cast unchecked; a ClassCastException is thrown if it is not a NormalExtensionProvider.
        ContinuousTraitGradientForBranch.SamplingVarianceGradient traitGradientSampling = new ContinuousTraitGradientForBranch.SamplingVarianceGradient(dimLocal, treeModel, likelihoodDelegate, (ModelExtensionProvider.NormalExtensionProvider) dataModel);
        BranchSpecificGradient branchSpecificGradientSampling = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientSampling, samplingPrecision);
        GradientWrtPrecisionProvider gPPBranchSpecificSampling = new GradientWrtPrecisionProvider.BranchSpecificGradientWrtPrecisionProvider(branchSpecificGradientSampling);
        // Correlation Gradient Branch Specific
        // CorrelationPrecisionGradient gradientProviderBranchSpecificSampling = new CorrelationPrecisionGradient(gPPBranchSpecificSampling, dataLikelihood, samplingPrecision);
        // 
        // testOneGradient(gradientProviderBranchSpecificSampling);
        // Diagonal Gradient Branch Specific
        DiagonalPrecisionGradient gradientDiagonalProviderBSSampling = new DiagonalPrecisionGradient(gPPBranchSpecificSampling, dataLikelihood, samplingPrecision);
        testOneGradient(gradientDiagonalProviderBSSampling);
    }
}
Also used : CorrelationPrecisionGradient(dr.evomodel.treedatalikelihood.hmc.CorrelationPrecisionGradient) DiagonalPrecisionGradient(dr.evomodel.treedatalikelihood.hmc.DiagonalPrecisionGradient) ContinuousProcessParameterGradient(dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient) TreeTraitProvider(dr.evolution.tree.TreeTraitProvider) AbstractDiffusionGradient(dr.evomodel.treedatalikelihood.hmc.AbstractDiffusionGradient) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) GradientWrtPrecisionProvider(dr.evomodel.treedatalikelihood.hmc.GradientWrtPrecisionProvider) ProcessSimulation(dr.evomodel.treedatalikelihood.ProcessSimulation)

Aggregations

TreeTraitProvider (dr.evolution.tree.TreeTraitProvider): 13; Tree (dr.evolution.tree.Tree): 3; TreeTrait (dr.evolution.tree.TreeTrait): 3; DataType (dr.evolution.datatype.DataType): 2; ProcessSimulation (dr.evomodel.treedatalikelihood.ProcessSimulation): 2; TreeDataLikelihood (dr.evomodel.treedatalikelihood.TreeDataLikelihood): 2; NumberFormat (java.text.NumberFormat): 2; AncestralSequenceTrait (dr.app.bss.test.AncestralSequenceTrait): 1; PatternList (dr.evolution.alignment.PatternList): 1; ImportException (dr.evolution.io.Importer.ImportException): 1; NewickImporter (dr.evolution.io.NewickImporter): 1; NodeRef (dr.evolution.tree.NodeRef): 1; TransformedTreeTraitProvider (dr.evolution.tree.TransformedTreeTraitProvider): 1; TreeUtils (dr.evolution.tree.TreeUtils): 1; TaxonList (dr.evolution.util.TaxonList): 1; BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel): 1; DefaultBranchRateModel (dr.evomodel.branchratemodel.DefaultBranchRateModel): 1; DiscreteTraitBranchRateModel (dr.evomodel.branchratemodel.DiscreteTraitBranchRateModel): 1; MultivariateDiffusionModel (dr.evomodel.continuous.MultivariateDiffusionModel): 1; MultivariateElasticModel (dr.evomodel.continuous.MultivariateElasticModel): 1