Use of `dr.inference.hmc.GradientWrtParameterProvider` in project beast-mcmc by beast-dev.
Class `AppendedPotentialDerivativeParser`, method `parseXMLObject`.
/**
 * Builds a {@link CompoundGradient} from the gradient-providing children of {@code xo}.
 *
 * <p>Each child must be one of:
 * <ul>
 *   <li>a {@code MultivariateDistributionLikelihood} whose distribution is a {@code GradientProvider};
 *       it is wrapped so the gradient is taken w.r.t. its data parameter;</li>
 *   <li>a {@code GradientWrtParameterProvider}, used directly.</li>
 * </ul>
 * Univariate {@code DistributionLikelihood} children are recognised but not yet supported.
 *
 * @param xo the XML element whose children are the component gradients
 * @return a {@code CompoundGradient} over the children, in document order
 * @throws XMLParseException if a child is not a supported gradient provider
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    List<GradientWrtParameterProvider> gradList = new ArrayList<GradientWrtParameterProvider>();
    for (int i = 0; i < xo.getChildCount(); ++i) {
        Object obj = xo.getChild(i);
        GradientWrtParameterProvider grad;
        if (obj instanceof DistributionLikelihood) {
            DistributionLikelihood dl = (DistributionLikelihood) obj;
            if (!(dl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            // Report through the parser's checked exception so the XML error is surfaced to the user.
            throw new XMLParseException("Not yet implemented");
        } else if (obj instanceof MultivariateDistributionLikelihood) {
            final MultivariateDistributionLikelihood mdl = (MultivariateDistributionLikelihood) obj;
            if (!(mdl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            final GradientProvider provider = (GradientProvider) mdl.getDistribution();
            final Parameter parameter = mdl.getDataParameter();
            // Expose the distribution's gradient as a gradient w.r.t. the likelihood's data parameter.
            grad = new GradientWrtParameterProvider.ParameterWrapper(provider, parameter, mdl);
        } else if (obj instanceof GradientWrtParameterProvider) {
            grad = (GradientWrtParameterProvider) obj;
        } else {
            final String type = (obj == null) ? "null" : obj.getClass().getName();
            throw new XMLParseException("Not a gradient provider: " + type);
        }
        gradList.add(grad);
    }
    return new CompoundGradient(gradList);
}
Use of `dr.inference.hmc.GradientWrtParameterProvider` in project beast-mcmc by beast-dev.
Class `CompoundGradient`, method `getGradientLogDensity`.
// @Override
// public void getGradientLogDensity(final double[] destination, final int offset) {
// double[] grad = getGradientLogDensity();
// System.arraycopy(grad, 0, destination, offset, grad.length);
// }
/**
 * Returns the gradient of the log density w.r.t. all component parameters.
 * Each component's gradient is laid out back to back, in list order.
 *
 * @return a new array of length {@code dimension} holding the stacked component gradients
 */
@Override
public double[] getGradientLogDensity() {
    final double[] stacked = new double[dimension];
    int position = 0;
    for (final GradientWrtParameterProvider provider : derivativeList) {
        final double[] partial = provider.getGradientLogDensity();
        final int length = provider.getDimension();
        System.arraycopy(partial, 0, stacked, position, length);
        position += length;
    }
    return stacked;
}
Use of `dr.inference.hmc.GradientWrtParameterProvider` in project beast-mcmc by beast-dev.
Class `DiffusionParametersGradient`, method `checkAndSetParametersGradients`.
/**
 * Validates the component gradients and wires them together.
 *
 * <p>For each component, in order, this method:
 * <ul>
 *   <li>assigns its offset;</li>
 *   <li>appends its parameter to {@code parameter};</li>
 *   <li>advances the offset by the dimension of its derivation parameter for the current trait dimension.</li>
 * </ul>
 *
 * @param parametersGradients compound gradient whose components must all be {@code AbstractDiffusionGradient}s
 * @param parameter compound parameter that receives each component's parameter
 * @return the sum of the components' {@code getDimension()} values
 * @throws IllegalArgumentException if a component is not an {@code AbstractDiffusionGradient}
 */
private int checkAndSetParametersGradients(CompoundGradient parametersGradients, CompoundParameter parameter) {
    int offset = 0;
    int dim = 0;
    final int dimTrait = likelihood.getDataLikelihoodDelegate().getTraitDim();
    for (GradientWrtParameterProvider gradient : parametersGradients.getDerivativeList()) {
        // Explicit check: an `assert` is skipped unless the JVM runs with -ea.
        if (!(gradient instanceof AbstractDiffusionGradient)) {
            throw new IllegalArgumentException(
                    "Gradients must all be instances of AbstractDiffusionGradient; found "
                            + (gradient == null ? "null" : gradient.getClass().getName()));
        }
        final AbstractDiffusionGradient diffusionGradient = (AbstractDiffusionGradient) gradient;
        diffusionGradient.setOffset(offset);
        parameter.addParameter(gradient.getParameter());
        offset += diffusionGradient.getDerivationParameter().getDimension(dimTrait);
        dim += gradient.getDimension();
    }
    return dim;
}
Use of `dr.inference.hmc.GradientWrtParameterProvider` in project beast-mcmc by beast-dev.
Class `DiffusionParametersGradient`, method `getGradientLogDensity`.
/**
 * Maps {@code gradient} through each component diffusion gradient.
 * Results are stacked in component order.
 *
 * @param gradient input passed unchanged to every component's {@code getGradientLogDensity(double[])}
 * @return a new array of length {@code dim}; each component fills its own {@code getDimension()} slots
 */
private double[] getGradientLogDensity(double[] gradient) {
    final double[] assembled = new double[dim];
    int cursor = 0;
    for (final GradientWrtParameterProvider provider : parametersGradients.getDerivativeList()) {
        final AbstractDiffusionGradient diffusion = (AbstractDiffusionGradient) provider;
        final double[] partial = diffusion.getGradientLogDensity(gradient);
        final int length = provider.getDimension();
        System.arraycopy(partial, 0, assembled, cursor, length);
        cursor += length;
    }
    return assembled;
}
Use of `dr.inference.hmc.GradientWrtParameterProvider` in project beast-mcmc by beast-dev.
Class `LocationScaleGradientParser`, method `parseXMLObject`.
/**
 * Builds a location/scale gradient provider from either a single {@code TreeDataLikelihood} child
 * or a {@code CompoundLikelihood} child.
 *
 * <p>In the compound case, every member likelihood must be a {@code TreeDataLikelihood}. The
 * per-likelihood providers are passed to {@code checkBranchRateModels} and summed into a
 * {@code SumDerivative}.
 *
 * @param xo the XML element being parsed
 * @return the gradient provider for the single likelihood, or a {@code SumDerivative} over the compound members
 * @throws XMLParseException if neither supported child is present, or a compound member is not a
 *         {@code TreeDataLikelihood}
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    final boolean useHessian = xo.getAttribute(USE_HESSIAN, false);

    final Object child = xo.getChild(TreeDataLikelihood.class);
    if (child != null) {
        return parseTreeDataLikelihood(xo, (TreeDataLikelihood) child, traitName, useHessian);
    }

    final CompoundLikelihood compoundLikelihood = (CompoundLikelihood) xo.getChild(CompoundLikelihood.class);
    if (compoundLikelihood == null) {
        // Previously this fell through to an NPE on getLikelihoods().
        throw new XMLParseException("Expected a TreeDataLikelihood or CompoundLikelihood child element");
    }

    final List<GradientWrtParameterProvider> providers = new ArrayList<>();
    for (Likelihood likelihood : compoundLikelihood.getLikelihoods()) {
        if (!(likelihood instanceof TreeDataLikelihood)) {
            throw new XMLParseException("Unknown likelihood type");
        }
        final GradientWrtParameterProvider provider =
                parseTreeDataLikelihood(xo, (TreeDataLikelihood) likelihood, traitName, useHessian);
        providers.add(provider);
    }
    checkBranchRateModels(providers);
    return new SumDerivative(providers);
}
Aggregations