Usage of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class HamiltonianMonteCarloOperatorParser, method parseXMLObject.
/**
 * Builds a Hamiltonian Monte Carlo operator from its XML element.
 *
 * <p>Reads the operator weight, step settings, run mode, preconditioning and adaptation
 * options, the gradient provider, the (optional) target parameter, transform and mask, and
 * delegates construction to {@code factory(...)}.
 *
 * @param xo the XML element describing the operator
 * @return the operator produced by {@code factory(...)}
 * @throws XMLParseException if the random step fraction exceeds 1, no gradient provider is
 *     given, preconditioning is requested without a Hessian provider, or the gradient/transform,
 *     parameter and mask dimensions disagree
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    int nSteps = xo.getAttribute(N_STEPS, 10);
    double stepSize = xo.getDoubleAttribute(STEP_SIZE);
    int runMode = parseRunMode(xo);
    MassPreconditioner.Type preconditioningType = parsePreconditioning(xo);

    // The sign of the attribute is discarded; only its magnitude is validated.
    double randomStepFraction = Math.abs(xo.getAttribute(RANDOM_STEP_FRACTION, 0.0));
    if (randomStepFraction > 1) {
        // Exactly 1.0 passes this check, so the message states an inclusive bound.
        throw new XMLParseException("Random step count fraction must be <= 1.0");
    }

    int preconditioningUpdateFrequency = xo.getAttribute(PRECONDITIONING_UPDATE_FREQUENCY, 0);
    int preconditioningDelay = xo.getAttribute(PRECONDITIONING_DELAY, 0);
    int preconditioningMemory = xo.getAttribute(PRECONDITIONING_MEMORY, 0);
    AdaptationMode adaptationMode = AdaptationMode.parseMode(xo);

    GradientWrtParameterProvider derivative =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    if (preconditioningType != MassPreconditioner.Type.NONE
            && !(derivative instanceof HessianWrtParameterProvider)) {
        throw new XMLParseException("Unable to precondition without a Hessian provider");
    }
    if (derivative == null) {
        // Without this check the dereferences below would fail with a NullPointerException.
        throw new XMLParseException("A gradient provider must be supplied");
    }

    // An explicit parameter child takes precedence over the gradient's own parameter.
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);
    if (parameter == null) {
        parameter = derivative.getParameter();
    }

    Transform transform = parseTransform(xo);

    // A multivariable transform defines the dimension the parameter must match;
    // otherwise the gradient's dimension is used.
    int expectedDimension = derivative.getDimension();
    String dimensionSource = "Gradient";
    if (transform instanceof Transform.MultivariableTransform) {
        expectedDimension = ((Transform.MultivariableTransform) transform).getDimension();
        dimensionSource = "Transform";
    }
    if (expectedDimension != parameter.getDimension()) {
        throw new XMLParseException(dimensionSource + " (" + expectedDimension
                + ") must be the same dimensions as the parameter (" + parameter.getDimension() + ")");
    }

    Parameter mask = null;
    if (xo.hasChildNamed(MASK)) {
        mask = (Parameter) xo.getElementFirstChild(MASK);
        if (mask.getDimension() != derivative.getDimension()) {
            throw new XMLParseException("Mask (" + mask.getDimension()
                    + ") must be the same dimension as the gradient (" + derivative.getDimension() + ")");
        }
    }

    int gradientCheckCount = xo.getAttribute(GRADIENT_CHECK_COUNT, 0);
    double gradientCheckTolerance = xo.getAttribute(GRADIENT_CHECK_TOLERANCE, 1E-3);
    int maxIterations = xo.getAttribute(MAX_ITERATIONS, 10);
    double reductionFactor = xo.getAttribute(REDUCTION_FACTOR, 0.1);
    // 0.8 matches the Stan default target acceptance probability.
    double targetAcceptanceProbability = xo.getAttribute(TARGET_ACCEPTANCE_PROBABILITY, 0.8);

    HamiltonianMonteCarloOperator.Options runtimeOptions = new HamiltonianMonteCarloOperator.Options(
            stepSize, nSteps, randomStepFraction,
            preconditioningUpdateFrequency, preconditioningDelay, preconditioningMemory,
            gradientCheckCount, gradientCheckTolerance,
            maxIterations, reductionFactor, targetAcceptanceProbability);

    return factory(adaptationMode, weight, derivative, parameter, transform, mask,
            runtimeOptions, preconditioningType, runMode);
}
Usage of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class DiffusionGradientParser, method parseXMLObject.
/**
 * Builds a {@code DiffusionParametersGradient} from the diffusion-gradient children of the element.
 *
 * <p>Each {@code AbstractDiffusionGradient} child contributes its derivation parameter, its raw
 * parameter (collected into a compound parameter) and itself as a gradient provider. The tree,
 * trait dimension and continuous delegate are taken from the first child's likelihood.
 *
 * @param xo the XML element to parse
 * @return the combined diffusion-parameters gradient
 * @throws XMLParseException if no diffusion gradient is supplied, or the first gradient's
 *     likelihood is not a {@code TreeDataLikelihood} backed by a continuous data delegate
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);

    List<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter> derivationParametersList =
            new ArrayList<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter>();
    CompoundParameter compoundParameter = new CompoundParameter(null);
    List<GradientWrtParameterProvider> derivativeList = new ArrayList<GradientWrtParameterProvider>();

    List<AbstractDiffusionGradient> diffGradients = xo.getAllChildren(AbstractDiffusionGradient.class);
    // The first gradient is dereferenced below, so an absent or empty list is a parse error
    // rather than a NullPointerException / IndexOutOfBoundsException.
    if (diffGradients == null || diffGradients.isEmpty()) {
        throw new XMLParseException("At least one diffusion gradient must be provided");
    }
    for (AbstractDiffusionGradient grad : diffGradients) {
        derivationParametersList.add(grad.getDerivationParameter());
        compoundParameter.addParameter(grad.getRawParameter());
        derivativeList.add(grad);
    }

    CompoundGradient parametersGradients = new CompoundDerivative(derivativeList);

    // The tree and delegate are taken from the first gradient's likelihood only.
    Object likelihood = diffGradients.get(0).getLikelihood();
    if (!(likelihood instanceof TreeDataLikelihood)) {
        throw new XMLParseException("Diffusion gradients require a TreeDataLikelihood");
    }
    TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) likelihood;

    DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
    if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
        throw new XMLParseException("Diffusion gradients require a continuous data likelihood delegate");
    }
    ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate;
    int dim = delegate.getTraitDim();
    Tree tree = treeDataLikelihood.getTree();

    ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient traitGradient =
            new ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient(
                    dim, tree, continuousData, derivationParametersList);
    BranchSpecificGradient branchSpecificGradient = new BranchSpecificGradient(
            traitName, treeDataLikelihood, continuousData, traitGradient, compoundParameter);

    return new DiffusionParametersGradient(branchSpecificGradient, parametersGradients);
}
Usage of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class MaximizeWrtParameterParser, method parseXMLObject.
/**
 * Builds a {@code MaximizerWrtParameter}, runs {@code maximize()} on it, and returns it.
 *
 * <p>The parameter and likelihood come from a gradient-provider child when one is present;
 * otherwise they are read from the DENSITY child element. An optional {@code Transform} child
 * is passed through to the maximizer.
 *
 * @param xo the XML element to parse
 * @return the maximizer, after {@code maximize()} has been called
 * @throws XMLParseException if neither a gradient provider nor a DENSITY element is given, or
 *     the DENSITY element lacks a parameter or a likelihood
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    GradientWrtParameterProvider gradient =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    Parameter parameter;
    Likelihood likelihood;
    // The sign of the iteration count is discarded.
    int nIterations = Math.abs(xo.getAttribute(N_ITERATIONS, 0));
    boolean initialGuess = xo.getAttribute(INITIAL_GUESS, true);
    boolean printScreen = xo.getAttribute(PRINT_SCREEN, false);

    if (gradient != null) {
        parameter = gradient.getParameter();
        likelihood = gradient.getLikelihood();
    } else {
        XMLObject cxo = xo.getChild(DENSITY);
        // Without a gradient, the DENSITY element is the only source of parameter and likelihood.
        if (cxo == null) {
            throw new XMLParseException("Must provide either a gradient or a " + DENSITY + " element");
        }
        parameter = (Parameter) cxo.getChild(Parameter.class);
        likelihood = (Likelihood) cxo.getChild(Likelihood.class);
        if (parameter == null || likelihood == null) {
            throw new XMLParseException("The " + DENSITY + " element must contain a parameter and a likelihood");
        }
    }

    Transform transform = (Transform) xo.getChild(Transform.class);
    MaximizerWrtParameter maximizer = new MaximizerWrtParameter(likelihood, parameter, gradient, transform,
            new MaximizerWrtParameter.Settings(nIterations, initialGuess, printScreen));
    maximizer.maximize();
    return maximizer;
}
Usage of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class ZigZagOperatorParser, method parseXMLObject.
/**
 * Builds a Zig-Zag operator from its XML element.
 *
 * <p>Reads the operator weight, the gradient provider, the precision matrix-vector product and
 * precision column providers, the mask, runtime options and thread count. The REVERSIBLE_FLG
 * attribute (default {@code true}) selects between the reversible and irreversible variants.
 *
 * @param xo the XML element to parse
 * @return a {@code ReversibleZigZagOperator} or an {@code IrreversibleZigZagOperator}
 * @throws XMLParseException if a required attribute cannot be read
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final double operatorWeight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);

    final GradientWrtParameterProvider gradientProvider =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    final PrecisionMatrixVectorProductProvider matrixVectorProvider =
            (PrecisionMatrixVectorProductProvider) xo.getChild(PrecisionMatrixVectorProductProvider.class);
    final PrecisionColumnProvider precisionColumns =
            (PrecisionColumnProvider) xo.getChild(PrecisionColumnProvider.class);

    final Parameter maskParameter = parseMask(xo);
    final AbstractParticleOperator.Options options = parseRuntimeOptions(xo);
    final int threads = xo.getAttribute(THREAD_COUNT, 1);
    final boolean isReversible = xo.getAttribute(REVERSIBLE_FLG, true);

    if (!isReversible) {
        return new IrreversibleZigZagOperator(gradientProvider, matrixVectorProvider, precisionColumns,
                operatorWeight, options, maskParameter, threads);
    }
    return new ReversibleZigZagOperator(gradientProvider, matrixVectorProvider, precisionColumns,
            operatorWeight, options, maskParameter, threads);
}
Usage of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.
Class BranchRateGradientWrtIncrementsParser, method parseXMLObject.
/**
 * Builds a {@code BranchRateGradientWrtIncrements} from a branch-rate gradient and an
 * {@code AutoCorrelatedGradientWrtIncrements} prior provider.
 *
 * @param xo the XML element to parse
 * @return the gradient with respect to the increments
 * @throws XMLParseException if the gradient child is neither a {@code BranchRateGradient} nor a
 *     {@code BranchRateGradientForDiscreteTrait} (including when it is absent)
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    AutoCorrelatedGradientWrtIncrements priorProvider =
            (AutoCorrelatedGradientWrtIncrements) xo.getChild(AutoCorrelatedGradientWrtIncrements.class);
    GradientWrtParameterProvider rateProvider =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);

    // instanceof is false for null, so a missing gradient is also rejected here.
    if (!(rateProvider instanceof BranchRateGradient) && !(rateProvider instanceof BranchRateGradientForDiscreteTrait)) {
        throw new XMLParseException("Must provide a branch rate gradient");
    }

    return new BranchRateGradientWrtIncrements(rateProvider, priorProvider);
}
Aggregations