Search in sources:

Example 1 with CorrelationPrecisionGradient

Use of dr.evomodel.treedatalikelihood.hmc.CorrelationPrecisionGradient in project beast-mcmc by beast-dev.

From class DiffusionGradientTest, method testGradient.

/**
 * Exercises the analytic gradients of a continuous-trait tree likelihood with respect to every
 * model parameter that is supplied (non-null).
 *
 * <p>A {@link ContinuousDataLikelihoodDelegate} and {@link TreeDataLikelihood} are built on the
 * fixture's {@code treeModel}/{@code rateModel}. Conditional-on-tips and full-conditional trait
 * simulations are attached. Each gradient provider below is then passed to {@code testOneGradient}:
 * <ul>
 *   <li>correlation and diagonal precision gradients (branch-specific), additionally compared
 *       element-wise against the Wishart-statistics gradients when {@code wishart} is true;</li>
 *   <li>diagonal attenuation (selection-strength) gradient;</li>
 *   <li>root-mean gradient (jointly with the constant drift when {@code drift} is the very same
 *       object as {@code meanRoot});</li>
 *   <li>constant-drift gradient;</li>
 *   <li>diagonal sampling-precision gradient.</li>
 * </ul>
 *
 * @param diffusionModel           diffusion model handed to the trait-simulation delegates
 * @param diffusionProcessDelegate process delegate handed to the likelihood delegate
 * @param dataModel                tip-data model; must be a
 *                                 {@code ModelExtensionProvider.NormalExtensionProvider} when
 *                                 {@code samplingPrecision} is non-null (it is cast to one)
 * @param rootPrior                root prior; the length of its mean sets the trait dimension
 * @param meanRoot                 root-mean parameter; its gradient is always tested
 * @param precision                diffusion precision; precision gradients are skipped if null
 * @param wishart                  whether to cross-check precision gradients against the Wishart
 *                                 statistics; {@code null} is treated as {@code false}
 * @param attenuation              attenuation parameter; skipped if null
 * @param drift                    constant drift parameter; skipped if null or if it is the same
 *                                 object as {@code meanRoot} (then tested jointly with the root)
 * @param samplingPrecision        sampling (residual) precision; skipped if null
 */
private void testGradient(MultivariateDiffusionModel diffusionModel, DiffusionProcessDelegate diffusionProcessDelegate, ContinuousTraitPartialsProvider dataModel, ConjugateRootTraitPrior rootPrior, Parameter meanRoot, MatrixParameterInterface precision, Boolean wishart, MatrixParameterInterface attenuation, Parameter drift, MatrixParameterInterface samplingPrecision) {
    final int dimLocal = rootPrior.getMean().length;

    // CDL
    ContinuousDataLikelihoodDelegate likelihoodDelegate = new ContinuousDataLikelihoodDelegate(treeModel, diffusionProcessDelegate, dataModel, rootPrior, rateTransformation, rateModel, true);

    // Likelihood Computation
    TreeDataLikelihood dataLikelihood = new TreeDataLikelihood(likelihoodDelegate, treeModel, rateModel);
    // Scalar-precision models use the plain conditional delegate; all others the multivariate one.
    ProcessSimulationDelegate simulationDelegate = likelihoodDelegate.getPrecisionType() == PrecisionType.SCALAR ? new ConditionalOnTipsRealizedDelegate("trait", treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate) : new MultivariateConditionalOnTipsRealizedDelegate("trait", treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate);
    TreeTraitProvider traitProvider = new ProcessSimulation(dataLikelihood, simulationDelegate);
    dataLikelihood.addTraits(traitProvider.getTreeTraits());
    ProcessSimulationDelegate fullConditionalDelegate = new TipRealizedValuesViaFullConditionalDelegate("trait", treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate);
    dataLikelihood.addTraits(new ProcessSimulation(dataLikelihood, fullConditionalDelegate).getTreeTraits());

    // Variance
    ContinuousDataLikelihoodDelegate cdld = (ContinuousDataLikelihoodDelegate) dataLikelihood.getDataLikelihoodDelegate();
    if (precision != null) {
        // Branch Specific
        ContinuousProcessParameterGradient traitGradient = new ContinuousProcessParameterGradient(dimLocal, treeModel, cdld, new ArrayList<>(Arrays.asList(DerivationParameter.WRT_VARIANCE)));
        BranchSpecificGradient branchSpecificGradient = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradient, precision);
        GradientWrtPrecisionProvider gPPBranchSpecific = new GradientWrtPrecisionProvider.BranchSpecificGradientWrtPrecisionProvider(branchSpecificGradient);

        // Correlation Gradient Branch Specific
        CorrelationPrecisionGradient gradientProviderBranchSpecific = new CorrelationPrecisionGradient(gPPBranchSpecific, dataLikelihood, precision);
        double[] gradientAnalyticalBS = testOneGradient(gradientProviderBranchSpecific);

        // Diagonal Gradient Branch Specific
        DiagonalPrecisionGradient gradientDiagonalProviderBS = new DiagonalPrecisionGradient(gPPBranchSpecific, dataLikelihood, precision);
        double[] gradientDiagonalAnalyticalBS = testOneGradient(gradientDiagonalProviderBS);

        // Boolean.TRUE.equals(...) treats a null flag as "no cross-check" instead of throwing
        // a NullPointerException on auto-unboxing.
        if (Boolean.TRUE.equals(wishart)) {
            // Wishart Statistic
            WishartStatisticsWrapper wishartStatistics = new WishartStatisticsWrapper("wishart", "trait", dataLikelihood, cdld);
            GradientWrtPrecisionProvider gPPWishart = new GradientWrtPrecisionProvider.WishartGradientWrtPrecisionProvider(wishartStatistics);

            // Correlation Gradient: Wishart-based analytic result must match the branch-specific one.
            CorrelationPrecisionGradient gradientProviderWishart = new CorrelationPrecisionGradient(gPPWishart, dataLikelihood, precision);
            assertSameAnalyticGradient("gradient correlation", gradientProviderWishart.getReport(), gradientAnalyticalBS);

            // Diagonal Gradient: same cross-check for the diagonal parametrisation.
            DiagonalPrecisionGradient gradientDiagonalProviderW = new DiagonalPrecisionGradient(gPPWishart, dataLikelihood, precision);
            assertSameAnalyticGradient("gradient diagonal", gradientDiagonalProviderW.getReport(), gradientDiagonalAnalyticalBS);
        }
    }

    // Diagonal Attenuation Gradient Branch Specific
    if (attenuation != null) {
        ContinuousProcessParameterGradient traitGradientAtt = new ContinuousProcessParameterGradient(dimLocal, treeModel, cdld, new ArrayList<>(Arrays.asList(DerivationParameter.WRT_DIAGONAL_SELECTION_STRENGTH)));
        BranchSpecificGradient branchSpecificGradientAtt = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientAtt, attenuation);
        AbstractDiffusionGradient.ParameterDiffusionGradient gABranchSpecific = createDiagonalAttenuationGradient(branchSpecificGradientAtt, dataLikelihood, attenuation);
        testOneGradient(gABranchSpecific);
    }

    // WRT root mean. Reference equality is deliberate: when the drift and root mean are the
    // same Parameter object, the joint drift-and-root-mean derivative is requested instead.
    boolean sameRoot = (drift == meanRoot);
    ContinuousProcessParameterGradient traitGradientRoot = new ContinuousProcessParameterGradient(dimLocal, treeModel, cdld, new ArrayList<>(Arrays.asList(sameRoot ? DerivationParameter.WRT_CONSTANT_DRIFT_AND_ROOT_MEAN : DerivationParameter.WRT_ROOT_MEAN)));
    BranchSpecificGradient branchSpecificGradientRoot = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientRoot, meanRoot);
    AbstractDiffusionGradient.ParameterDiffusionGradient gRootBranchSpecific = createDriftGradient(branchSpecificGradientRoot, dataLikelihood, meanRoot);
    testOneGradient(gRootBranchSpecific);

    // Drift Gradient Branch Specific (already covered jointly above when sameRoot)
    if (drift != null && !sameRoot) {
        ContinuousProcessParameterGradient traitGradientDrift = new ContinuousProcessParameterGradient(dimLocal, treeModel, cdld, new ArrayList<>(Arrays.asList(DerivationParameter.WRT_CONSTANT_DRIFT)));
        BranchSpecificGradient branchSpecificGradientDrift = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientDrift, drift);
        AbstractDiffusionGradient.ParameterDiffusionGradient gDriftBranchSpecific = createDriftGradient(branchSpecificGradientDrift, dataLikelihood, drift);
        testOneGradient(gDriftBranchSpecific);
    }

    // Sampling Precision
    if (samplingPrecision != null) {
        ContinuousTraitGradientForBranch.SamplingVarianceGradient traitGradientSampling = new ContinuousTraitGradientForBranch.SamplingVarianceGradient(dimLocal, treeModel, likelihoodDelegate, (ModelExtensionProvider.NormalExtensionProvider) dataModel);
        BranchSpecificGradient branchSpecificGradientSampling = new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientSampling, samplingPrecision);
        GradientWrtPrecisionProvider gPPBranchSpecificSampling = new GradientWrtPrecisionProvider.BranchSpecificGradientWrtPrecisionProvider(branchSpecificGradientSampling);
        // NOTE(review): only the diagonal parametrisation is tested for the sampling precision;
        // a CorrelationPrecisionGradient check on samplingPrecision is not exercised here.
        DiagonalPrecisionGradient gradientDiagonalProviderBSSampling = new DiagonalPrecisionGradient(gPPBranchSpecificSampling, dataLikelihood, samplingPrecision);
        testOneGradient(gradientDiagonalProviderBSSampling);
    }
}

/**
 * Prints a gradient report to stderr. Parses its {@code "analytic"} gradient and asserts that it
 * matches {@code branchSpecific}: the lengths must be equal, and each element must agree within
 * {@code delta}.
 *
 * @param label          prefix for the per-element assertion message (followed by {@code " k=<index>"})
 * @param report         report string produced by a gradient provider's {@code getReport()}
 * @param branchSpecific branch-specific analytic gradient to compare against
 */
private void assertSameAnalyticGradient(String label, String report, double[] branchSpecific) {
    System.err.println(report);
    double[] analytic = parseGradient(report, "analytic");
    assertEquals("Sizes", analytic.length, branchSpecific.length);
    for (int k = 0; k < analytic.length; k++) {
        assertEquals(label + " k=" + k, analytic[k], branchSpecific[k], delta);
    }
}
Also used: CorrelationPrecisionGradient (dr.evomodel.treedatalikelihood.hmc.CorrelationPrecisionGradient), DiagonalPrecisionGradient (dr.evomodel.treedatalikelihood.hmc.DiagonalPrecisionGradient), ContinuousProcessParameterGradient (dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient), TreeTraitProvider (dr.evolution.tree.TreeTraitProvider), AbstractDiffusionGradient (dr.evomodel.treedatalikelihood.hmc.AbstractDiffusionGradient), TreeDataLikelihood (dr.evomodel.treedatalikelihood.TreeDataLikelihood), GradientWrtPrecisionProvider (dr.evomodel.treedatalikelihood.hmc.GradientWrtPrecisionProvider), ProcessSimulation (dr.evomodel.treedatalikelihood.ProcessSimulation)

Aggregations

TreeTraitProvider (dr.evolution.tree.TreeTraitProvider): 1; ProcessSimulation (dr.evomodel.treedatalikelihood.ProcessSimulation): 1; TreeDataLikelihood (dr.evomodel.treedatalikelihood.TreeDataLikelihood): 1; ContinuousProcessParameterGradient (dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient): 1; AbstractDiffusionGradient (dr.evomodel.treedatalikelihood.hmc.AbstractDiffusionGradient): 1; CorrelationPrecisionGradient (dr.evomodel.treedatalikelihood.hmc.CorrelationPrecisionGradient): 1; DiagonalPrecisionGradient (dr.evomodel.treedatalikelihood.hmc.DiagonalPrecisionGradient): 1; GradientWrtPrecisionProvider (dr.evomodel.treedatalikelihood.hmc.GradientWrtPrecisionProvider): 1