Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class TaskPoolParser, method parseXMLObject.
/**
 * Builds a {@code TaskPool} from the XML element.
 *
 * <p>The task count is the number of external nodes of the {@code Tree} child when one is
 * present; otherwise it is the dimension of the {@code GradientWrtParameterProvider} child.
 * The thread count is read from the {@code THREAD_COUNT} attribute and defaults to 1.
 *
 * @param xo the parsed XML element
 * @return the new {@code TaskPool}
 * @throws XMLParseException if neither a tree nor a gradient provider child is supplied
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    Tree tree = (Tree) xo.getChild(Tree.class);
    GradientWrtParameterProvider gradient =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);

    final int taskCount;
    if (tree != null) {
        taskCount = tree.getExternalNodeCount();
    } else if (gradient != null) {
        taskCount = gradient.getDimension();
    } else {
        // Without either child there is nothing to size the pool from; fail with a parse
        // error rather than a NullPointerException.
        throw new XMLParseException(
                "A task pool requires either a tree or a gradient provider to determine the task count");
    }

    int threadCount = xo.getAttribute(THREAD_COUNT, 1);
    return new TaskPool(taskCount, threadCount);
}
Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class BranchRateGradientParser, method parseXMLObject.
/**
 * Builds the branch-rate gradient for the XML element.
 *
 * <p>With a {@code TreeDataLikelihood} child, the gradient comes from
 * {@code parseTreeDataLikelihood}. Otherwise every likelihood inside the
 * {@code CompoundLikelihood} child must be a {@code TreeDataLikelihood}. Their individual
 * gradients are passed to {@code checkBranchRateModels} and then summed in a
 * {@code SumDerivative}.
 *
 * @param xo the parsed XML element
 * @return a single gradient provider, or a {@code SumDerivative} over several of them
 * @throws XMLParseException if neither a tree-data nor a compound likelihood is supplied, or if
 *     the compound likelihood contains a likelihood that is not a {@code TreeDataLikelihood}
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    boolean useHessian = xo.getAttribute(USE_HESSIAN, false);

    final Object child = xo.getChild(TreeDataLikelihood.class);
    if (child != null) {
        return parseTreeDataLikelihood((TreeDataLikelihood) child, traitName, useHessian);
    }

    CompoundLikelihood compoundLikelihood = (CompoundLikelihood) xo.getChild(CompoundLikelihood.class);
    if (compoundLikelihood == null) {
        // Previously this path dereferenced null and surfaced as a NullPointerException.
        throw new XMLParseException("Expected either a TreeDataLikelihood or a CompoundLikelihood");
    }

    List<GradientWrtParameterProvider> providers = new ArrayList<>();
    for (Likelihood likelihood : compoundLikelihood.getLikelihoods()) {
        if (!(likelihood instanceof TreeDataLikelihood)) {
            throw new XMLParseException("Unknown likelihood type");
        }
        GradientWrtParameterProvider provider =
                parseTreeDataLikelihood((TreeDataLikelihood) likelihood, traitName, useHessian);
        providers.add(provider);
    }
    checkBranchRateModels(providers);
    return new SumDerivative(providers);
}
Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class CompoundGradientParser, method parseXMLObject.
/**
 * Combines the gradients of all child elements into a {@code CompoundDerivative}.
 *
 * <p>Each child must be one of the following:
 * <ul>
 *   <li>a {@code MultivariateDistributionLikelihood} whose distribution is a
 *       {@code GradientProvider}; it is wrapped, together with its data parameter, in a
 *       {@code GradientWrtParameterProvider.ParameterWrapper};</li>
 *   <li>a {@code GradientWrtParameterProvider}, which is used directly.</li>
 * </ul>
 * A {@code DistributionLikelihood} child is recognised but not yet supported.
 *
 * @param xo the parsed XML element
 * @return a {@code CompoundDerivative} over the children's gradients, in child order
 * @throws XMLParseException if a child is not a gradient source, or its distribution does not
 *     provide gradients
 * @throws RuntimeException if a child is a {@code DistributionLikelihood} with a
 *     gradient-providing distribution, which is not yet implemented
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final int childCount = xo.getChildCount();
    List<GradientWrtParameterProvider> gradList = new ArrayList<>(childCount);

    for (int i = 0; i < childCount; ++i) {
        Object obj = xo.getChild(i);
        GradientWrtParameterProvider grad;

        if (obj instanceof DistributionLikelihood) {
            DistributionLikelihood dl = (DistributionLikelihood) obj;
            if (!(dl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            throw new RuntimeException("Not yet implemented");
        } else if (obj instanceof MultivariateDistributionLikelihood) {
            final MultivariateDistributionLikelihood mdl = (MultivariateDistributionLikelihood) obj;
            if (!(mdl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            final GradientProvider provider = (GradientProvider) mdl.getDistribution();
            final Parameter parameter = mdl.getDataParameter();
            grad = new GradientWrtParameterProvider.ParameterWrapper(provider, parameter, mdl);
        } else if (obj instanceof GradientWrtParameterProvider) {
            grad = (GradientWrtParameterProvider) obj;
        } else {
            // The old message ("Not a Gaussian process") did not describe this check.
            final String type = (obj == null) ? "null" : obj.getClass().getName();
            throw new XMLParseException("Not a gradient provider: " + type);
        }

        gradList.add(grad);
    }
    return new CompoundDerivative(gradList);
}
Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class PathGradientParser, method parseXMLObject.
/**
 * Builds a {@code PathGradient} between two gradient providers.
 *
 * <p>The first child of the {@code PathLikelihood.SOURCE} element is the start of the path. The
 * first child of the {@code PathLikelihood.DESTINATION} element is the end. Both are cast to
 * {@code GradientWrtParameterProvider}.
 * NOTE(review): the casts assume the syntax rules already restrict these children to gradient
 * providers; confirm.
 *
 * @param xo the parsed XML element
 * @return the new {@code PathGradient}
 * @throws XMLParseException if the XML element cannot be read
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final GradientWrtParameterProvider from =
            (GradientWrtParameterProvider) xo.getElementFirstChild(PathLikelihood.SOURCE);
    final GradientWrtParameterProvider to =
            (GradientWrtParameterProvider) xo.getElementFirstChild(PathLikelihood.DESTINATION);

    final PathGradient pathGradient = new PathGradient(from, to);
    return pathGradient;
}
Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class SumDerivativeParser, method parseXMLObject.
/**
 * Sums the gradients of all child elements into a {@code SumDerivative}.
 *
 * <p>Every child is cast to {@code GradientWrtParameterProvider}, and the casts keep the
 * children's order.
 *
 * @param xo the parsed XML element
 * @return a {@code SumDerivative} over all child gradient providers
 * @throws XMLParseException if the XML element cannot be read
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final int childTotal = xo.getChildCount();
    final List<GradientWrtParameterProvider> summands = new ArrayList<>(childTotal);

    for (int index = 0; index < childTotal; index++) {
        summands.add((GradientWrtParameterProvider) xo.getChild(index));
    }

    return new SumDerivative(summands);
}
Aggregations