Use of dr.evomodel.treedatalikelihood.hmc.GradientWrtPrecisionProvider in project beast-mcmc by beast-dev.
In class DiffusionGradientTest, method testGradient:
/**
 * Builds a tree data likelihood for the given diffusion configuration and exercises the
 * branch-specific gradient of every supplied process parameter.
 *
 * <p>Each gradient provider is passed to {@code testOneGradient} (defined elsewhere in this
 * class). When {@code wishart} is true, the branch-specific correlation and diagonal precision
 * gradients are additionally compared element-wise, within {@code delta}, against the
 * "analytic" gradients parsed from the reports of providers built on a
 * {@code WishartStatisticsWrapper}.
 *
 * @param diffusionModel           diffusion model handed to the tip-realization simulation delegates
 * @param diffusionProcessDelegate process delegate used to build the continuous data likelihood delegate
 * @param dataModel                trait data model; must be a
 *                                 {@code ModelExtensionProvider.NormalExtensionProvider} when
 *                                 {@code samplingPrecision} is non-null (it is cast to that type)
 * @param rootPrior                conjugate root prior; the length of its mean is used as the
 *                                 trait dimension for every gradient
 * @param meanRoot                 root mean parameter; its gradient is always exercised
 * @param precision                diffusion precision matrix, or null to skip the precision gradients
 * @param wishart                  whether to cross-check the precision gradients against the
 *                                 Wishart-statistics providers; null is treated as false
 * @param attenuation              diagonal attenuation (selection strength) matrix, or null to skip
 * @param drift                    constant drift parameter, or null to skip; when it is the very same
 *                                 object as {@code meanRoot}, a joint drift/root-mean derivative is
 *                                 used for the root gradient and no separate drift gradient is run
 * @param samplingPrecision        sampling precision matrix, or null to skip
 */
private void testGradient(MultivariateDiffusionModel diffusionModel,
                          DiffusionProcessDelegate diffusionProcessDelegate,
                          ContinuousTraitPartialsProvider dataModel,
                          ConjugateRootTraitPrior rootPrior,
                          Parameter meanRoot,
                          MatrixParameterInterface precision,
                          Boolean wishart,
                          MatrixParameterInterface attenuation,
                          Parameter drift,
                          MatrixParameterInterface samplingPrecision) {

    int dimLocal = rootPrior.getMean().length;

    // CDL
    ContinuousDataLikelihoodDelegate likelihoodDelegate = new ContinuousDataLikelihoodDelegate(treeModel,
            diffusionProcessDelegate, dataModel, rootPrior, rateTransformation, rateModel, true);

    // Likelihood Computation
    TreeDataLikelihood dataLikelihood = new TreeDataLikelihood(likelihoodDelegate, treeModel, rateModel);

    // Tip-realization traits: ConditionalOnTipsRealizedDelegate for SCALAR precision,
    // the multivariate variant for every other precision type.
    ProcessSimulationDelegate simulationDelegate =
            likelihoodDelegate.getPrecisionType() == PrecisionType.SCALAR
                    ? new ConditionalOnTipsRealizedDelegate("trait", treeModel, diffusionModel, dataModel,
                            rootPrior, rateTransformation, likelihoodDelegate)
                    : new MultivariateConditionalOnTipsRealizedDelegate("trait", treeModel, diffusionModel,
                            dataModel, rootPrior, rateTransformation, likelihoodDelegate);

    TreeTraitProvider traitProvider = new ProcessSimulation(dataLikelihood, simulationDelegate);
    dataLikelihood.addTraits(traitProvider.getTreeTraits());

    ProcessSimulationDelegate fullConditionalDelegate = new TipRealizedValuesViaFullConditionalDelegate(
            "trait", treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate);
    dataLikelihood.addTraits(new ProcessSimulation(dataLikelihood, fullConditionalDelegate).getTreeTraits());

    // Variance
    ContinuousDataLikelihoodDelegate cdld =
            (ContinuousDataLikelihoodDelegate) dataLikelihood.getDataLikelihoodDelegate();

    if (precision != null) {
        // Branch Specific
        ContinuousProcessParameterGradient traitGradient = new ContinuousProcessParameterGradient(
                dimLocal, treeModel, cdld,
                new ArrayList<>(Arrays.asList(DerivationParameter.WRT_VARIANCE)));
        BranchSpecificGradient branchSpecificGradient =
                new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradient, precision);
        GradientWrtPrecisionProvider gPPBranchSpecific =
                new GradientWrtPrecisionProvider.BranchSpecificGradientWrtPrecisionProvider(branchSpecificGradient);

        // Correlation Gradient Branch Specific
        CorrelationPrecisionGradient gradientProviderBranchSpecific =
                new CorrelationPrecisionGradient(gPPBranchSpecific, dataLikelihood, precision);
        double[] gradientAnalyticalBS = testOneGradient(gradientProviderBranchSpecific);

        // Diagonal Gradient Branch Specific
        DiagonalPrecisionGradient gradientDiagonalProviderBS =
                new DiagonalPrecisionGradient(gPPBranchSpecific, dataLikelihood, precision);
        double[] gradientDiagonalAnalyticalBS = testOneGradient(gradientDiagonalProviderBS);

        // Null-safe: auto-unboxing a null Boolean in "if (wishart)" would throw NullPointerException.
        if (Boolean.TRUE.equals(wishart)) {
            // Wishart Statistic
            WishartStatisticsWrapper wishartStatistics =
                    new WishartStatisticsWrapper("wishart", "trait", dataLikelihood, cdld);
            GradientWrtPrecisionProvider gPPWishart =
                    new GradientWrtPrecisionProvider.WishartGradientWrtPrecisionProvider(wishartStatistics);

            // Correlation Gradient: Wishart-based analytic values must match the branch-specific ones.
            CorrelationPrecisionGradient gradientProviderWishart =
                    new CorrelationPrecisionGradient(gPPWishart, dataLikelihood, precision);
            String sW = gradientProviderWishart.getReport();
            System.err.println(sW);
            double[] gradientAnalyticalW = parseGradient(sW, "analytic");
            assertEquals("Sizes", gradientAnalyticalW.length, gradientAnalyticalBS.length);
            for (int k = 0; k < gradientAnalyticalW.length; k++) {
                assertEquals("gradient correlation k=" + k, gradientAnalyticalW[k], gradientAnalyticalBS[k], delta);
            }

            // Diagonal Gradient: same cross-check for the diagonal parameterisation.
            DiagonalPrecisionGradient gradientDiagonalProviderW =
                    new DiagonalPrecisionGradient(gPPWishart, dataLikelihood, precision);
            String sDiagW = gradientDiagonalProviderW.getReport();
            System.err.println(sDiagW);
            double[] gradientDiagonalAnalyticalW = parseGradient(sDiagW, "analytic");
            assertEquals("Sizes", gradientDiagonalAnalyticalW.length, gradientDiagonalAnalyticalBS.length);
            for (int k = 0; k < gradientDiagonalAnalyticalW.length; k++) {
                assertEquals("gradient diagonal k=" + k,
                        gradientDiagonalAnalyticalW[k], gradientDiagonalAnalyticalBS[k], delta);
            }
        }
    }

    // Diagonal Attenuation Gradient Branch Specific
    if (attenuation != null) {
        ContinuousProcessParameterGradient traitGradientAtt = new ContinuousProcessParameterGradient(
                dimLocal, treeModel, cdld,
                new ArrayList<>(Arrays.asList(DerivationParameter.WRT_DIAGONAL_SELECTION_STRENGTH)));
        BranchSpecificGradient branchSpecificGradientAtt =
                new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientAtt, attenuation);
        AbstractDiffusionGradient.ParameterDiffusionGradient gABranchSpecific =
                createDiagonalAttenuationGradient(branchSpecificGradientAtt, dataLikelihood, attenuation);
        testOneGradient(gABranchSpecific);
    }

    // WRT root mean. Reference equality is intentional: a drift that IS the root-mean parameter
    // needs the joint derivative, and then gets no separate drift check below.
    boolean sameRoot = (drift == meanRoot);
    ContinuousProcessParameterGradient traitGradientRoot = new ContinuousProcessParameterGradient(
            dimLocal, treeModel, cdld,
            new ArrayList<>(Arrays.asList(sameRoot
                    ? DerivationParameter.WRT_CONSTANT_DRIFT_AND_ROOT_MEAN
                    : DerivationParameter.WRT_ROOT_MEAN)));
    BranchSpecificGradient branchSpecificGradientRoot =
            new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientRoot, meanRoot);
    AbstractDiffusionGradient.ParameterDiffusionGradient gRootBranchSpecific =
            createDriftGradient(branchSpecificGradientRoot, dataLikelihood, meanRoot);
    testOneGradient(gRootBranchSpecific);

    // Drift Gradient Branch Specific
    if (drift != null && !sameRoot) {
        ContinuousProcessParameterGradient traitGradientDrift = new ContinuousProcessParameterGradient(
                dimLocal, treeModel, cdld,
                new ArrayList<>(Arrays.asList(DerivationParameter.WRT_CONSTANT_DRIFT)));
        BranchSpecificGradient branchSpecificGradientDrift =
                new BranchSpecificGradient("trait", dataLikelihood, cdld, traitGradientDrift, drift);
        AbstractDiffusionGradient.ParameterDiffusionGradient gDriftBranchSpecific =
                createDriftGradient(branchSpecificGradientDrift, dataLikelihood, drift);
        testOneGradient(gDriftBranchSpecific);
    }

    // Sampling Precision
    if (samplingPrecision != null) {
        // The cast throws ClassCastException if dataModel is not a NormalExtensionProvider.
        ContinuousTraitGradientForBranch.SamplingVarianceGradient traitGradientSampling =
                new ContinuousTraitGradientForBranch.SamplingVarianceGradient(dimLocal, treeModel,
                        likelihoodDelegate, (ModelExtensionProvider.NormalExtensionProvider) dataModel);
        BranchSpecificGradient branchSpecificGradientSampling = new BranchSpecificGradient(
                "trait", dataLikelihood, cdld, traitGradientSampling, samplingPrecision);
        GradientWrtPrecisionProvider gPPBranchSpecificSampling =
                new GradientWrtPrecisionProvider.BranchSpecificGradientWrtPrecisionProvider(
                        branchSpecificGradientSampling);

        // Only the diagonal gradient is exercised for the sampling precision;
        // the correlation parameterisation is not checked here.
        DiagonalPrecisionGradient gradientDiagonalProviderBSSampling =
                new DiagonalPrecisionGradient(gPPBranchSpecificSampling, dataLikelihood, samplingPrecision);
        testOneGradient(gradientDiagonalProviderBSSampling);
    }
}
Aggregations