Use of dr.inference.hmc.CompoundGradient in the beast-mcmc project by beast-dev.
Class AppendedPotentialDerivativeParser, method parseXMLObject:
/**
 * Builds a {@link CompoundGradient} from the child elements of {@code xo}, in document order.
 *
 * <p>Accepted children:
 * <ul>
 *   <li>a {@code MultivariateDistributionLikelihood} whose distribution is a {@code GradientProvider};
 *       it is wrapped in a {@code GradientWrtParameterProvider.ParameterWrapper} over the
 *       likelihood's data parameter;</li>
 *   <li>any object that already implements {@code GradientWrtParameterProvider}.</li>
 * </ul>
 * A {@code DistributionLikelihood} child is recognised but not yet supported.
 *
 * @param xo the XML element whose children supply the gradients
 * @return a {@code CompoundGradient} combining the children's gradients
 * @throws XMLParseException if a child's distribution is not a {@code GradientProvider}, or the
 *     child is of an unsupported type
 * @throws UnsupportedOperationException if a child is a {@code DistributionLikelihood} with a
 *     gradient-providing distribution (not yet implemented)
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    List<GradientWrtParameterProvider> gradList = new ArrayList<GradientWrtParameterProvider>();

    for (int i = 0; i < xo.getChildCount(); ++i) {
        final Object obj = xo.getChild(i);
        final GradientWrtParameterProvider grad;

        if (obj instanceof DistributionLikelihood) {
            final DistributionLikelihood dl = (DistributionLikelihood) obj;
            if (!(dl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            // Univariate support is unfinished. UnsupportedOperationException extends
            // RuntimeException, so callers that handled the previous exception still do.
            throw new UnsupportedOperationException("Not yet implemented");
        } else if (obj instanceof MultivariateDistributionLikelihood) {
            final MultivariateDistributionLikelihood mdl = (MultivariateDistributionLikelihood) obj;
            if (!(mdl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            final GradientProvider provider = (GradientProvider) mdl.getDistribution();
            final Parameter parameter = mdl.getDataParameter();
            // Take the distribution's gradient with respect to the likelihood's data parameter.
            grad = new GradientWrtParameterProvider.ParameterWrapper(provider, parameter, mdl);
        } else if (obj instanceof GradientWrtParameterProvider) {
            grad = (GradientWrtParameterProvider) obj;
        } else {
            throw new XMLParseException("Child element " + i + " is not a gradient provider");
        }

        gradList.add(grad);
    }

    return new CompoundGradient(gradList);
}
Use of dr.inference.hmc.CompoundGradient in the beast-mcmc project by beast-dev.
Class DiffusionGradientParser, method parseXMLObject:
/**
 * Builds a {@code DiffusionParametersGradient} from the {@code AbstractDiffusionGradient} children
 * of {@code xo}.
 *
 * <p>Each child contributes three things:
 * <ul>
 *   <li>its derivation parameter, collected into a list;</li>
 *   <li>its raw parameter, added to a {@code CompoundParameter};</li>
 *   <li>itself, added to a {@code CompoundDerivative}.</li>
 * </ul>
 * The tree, the trait dimension and the continuous data delegate are all taken from the first
 * child's likelihood.
 *
 * @param xo the XML element to parse; its {@code TRAIT_NAME} attribute defaults to
 *     {@code DEFAULT_TRAIT_NAME}
 * @return the combined diffusion-parameter gradient
 * @throws XMLParseException if there are no diffusion-gradient children, the first child's
 *     likelihood is not a {@code TreeDataLikelihood}, or that likelihood's delegate is not a
 *     {@code ContinuousDataLikelihoodDelegate}
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);

    List<AbstractDiffusionGradient> diffGradients = xo.getAllChildren(AbstractDiffusionGradient.class);
    // The first gradient is dereferenced below, so fail with a parse error rather than an NPE or
    // IndexOutOfBoundsException when no diffusion gradients are present.
    if (diffGradients == null || diffGradients.isEmpty()) {
        throw new XMLParseException("At least one diffusion gradient is required");
    }

    List<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter> derivationParametersList =
            new ArrayList<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter>();
    CompoundParameter compoundParameter = new CompoundParameter(null);
    List<GradientWrtParameterProvider> derivativeList = new ArrayList<GradientWrtParameterProvider>();
    for (AbstractDiffusionGradient grad : diffGradients) {
        derivationParametersList.add(grad.getDerivationParameter());
        compoundParameter.addParameter(grad.getRawParameter());
        derivativeList.add(grad);
    }
    CompoundGradient parametersGradients = new CompoundDerivative(derivativeList);

    // TODO: all gradients are assumed to share the first gradient's likelihood; a consistency
    // check (previously a commented-out testSameModel call) is not performed.
    Object likelihood = diffGradients.get(0).getLikelihood();
    if (!(likelihood instanceof TreeDataLikelihood)) {
        throw new XMLParseException("Diffusion gradients must be defined on a TreeDataLikelihood");
    }
    TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) likelihood;

    DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
    if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
        throw new XMLParseException("Diffusion gradients require a ContinuousDataLikelihoodDelegate");
    }
    ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate;

    int dim = delegate.getTraitDim();
    Tree tree = treeDataLikelihood.getTree();

    ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient traitGradient =
            new ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient(
                    dim, tree, continuousData, derivationParametersList);
    BranchSpecificGradient branchSpecificGradient = new BranchSpecificGradient(
            traitName, treeDataLikelihood, continuousData, traitGradient, compoundParameter);

    return new DiffusionParametersGradient(branchSpecificGradient, parametersGradients);
}
Aggregations