Search in sources :

Example 11 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Method parseXMLObject of class HamiltonianMonteCarloOperatorParser.

/**
 * Parses a Hamiltonian Monte Carlo operator element and builds the operator via {@code factory}.
 *
 * <p>Reads the operator weight, step settings, run mode, preconditioning settings, adaptation mode,
 * the gradient provider, the (optional) target parameter, transform and mask, plus gradient-check
 * and step-size-search settings, then delegates construction to {@code factory(...)}.
 *
 * @param xo the XML element describing the operator
 * @return the operator returned by {@code factory(...)}
 * @throws XMLParseException if the random step fraction magnitude exceeds 1, if preconditioning is
 *     requested without a {@link HessianWrtParameterProvider}, if the parameter dimension does not
 *     match the gradient (or multivariable transform) dimension, or if the mask dimension does not
 *     match the gradient dimension
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    int nSteps = xo.getAttribute(N_STEPS, 10);
    double stepSize = xo.getDoubleAttribute(STEP_SIZE);
    int runMode = parseRunMode(xo);
    MassPreconditioner.Type preconditioningType = parsePreconditioning(xo);

    // Only the magnitude of the fraction is used; a negative value is silently made positive.
    double randomStepFraction = Math.abs(xo.getAttribute(RANDOM_STEP_FRACTION, 0.0));
    if (randomStepFraction > 1) {
        // The check admits exactly 1.0, so the message states an inclusive bound.
        throw new XMLParseException("Random step count fraction must be <= 1.0");
    }

    int preconditioningUpdateFrequency = xo.getAttribute(PRECONDITIONING_UPDATE_FREQUENCY, 0);
    int preconditioningDelay = xo.getAttribute(PRECONDITIONING_DELAY, 0);
    int preconditioningMemory = xo.getAttribute(PRECONDITIONING_MEMORY, 0);
    AdaptationMode adaptationMode = AdaptationMode.parseMode(xo);

    GradientWrtParameterProvider derivative =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    // Any preconditioning other than NONE needs second-derivative information.
    if (preconditioningType != MassPreconditioner.Type.NONE
            && !(derivative instanceof HessianWrtParameterProvider)) {
        throw new XMLParseException("Unable to precondition without a Hessian provider");
    }

    // An explicit Parameter child takes precedence over the gradient's own parameter.
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);
    if (parameter == null) {
        parameter = derivative.getParameter();
    }

    Transform transform = parseTransform(xo);

    // With a multivariable transform the parameter is checked against the transform's dimension;
    // otherwise it is checked against the gradient's dimension. The error reports whichever was used.
    final int expectedDimension;
    final String expectedSource;
    if (transform instanceof Transform.MultivariableTransform) { // instanceof is false for null
        expectedDimension = ((Transform.MultivariableTransform) transform).getDimension();
        expectedSource = "Transform";
    } else {
        expectedDimension = derivative.getDimension();
        expectedSource = "Gradient";
    }
    if (expectedDimension != parameter.getDimension()) {
        throw new XMLParseException(expectedSource + " (" + expectedDimension
                + ") must be the same dimensions as the parameter (" + parameter.getDimension() + ")");
    }

    Parameter mask = null;
    if (xo.hasChildNamed(MASK)) {
        mask = (Parameter) xo.getElementFirstChild(MASK);
        if (mask.getDimension() != derivative.getDimension()) {
            throw new XMLParseException("Mask (" + mask.getDimension()
                    + ") must be the same dimension as the gradient (" + derivative.getDimension() + ")");
        }
    }

    int gradientCheckCount = xo.getAttribute(GRADIENT_CHECK_COUNT, 0);
    double gradientCheckTolerance = xo.getAttribute(GRADIENT_CHECK_TOLERANCE, 1E-3);
    int maxIterations = xo.getAttribute(MAX_ITERATIONS, 10);
    double reductionFactor = xo.getAttribute(REDUCTION_FACTOR, 0.1);
    // Default of 0.8 is annotated in the original source as the Stan default.
    double targetAcceptanceProbability = xo.getAttribute(TARGET_ACCEPTANCE_PROBABILITY, 0.8);

    HamiltonianMonteCarloOperator.Options runtimeOptions = new HamiltonianMonteCarloOperator.Options(
            stepSize, nSteps, randomStepFraction,
            preconditioningUpdateFrequency, preconditioningDelay, preconditioningMemory,
            gradientCheckCount, gradientCheckTolerance,
            maxIterations, reductionFactor, targetAcceptanceProbability);

    return factory(adaptationMode, weight, derivative, parameter, transform, mask,
            runtimeOptions, preconditioningType, runMode);
}
Also used : HessianWrtParameterProvider(dr.inference.hmc.HessianWrtParameterProvider) MassPreconditioner(dr.inference.operators.hmc.MassPreconditioner) AdaptationMode(dr.inference.operators.AdaptationMode) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) Parameter(dr.inference.model.Parameter) Transform(dr.util.Transform) Util.parseTransform(dr.util.Transform.Util.parseTransform) HamiltonianMonteCarloOperator(dr.inference.operators.hmc.HamiltonianMonteCarloOperator)

Example 12 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Method parseXMLObject of class DiffusionGradientParser.

/**
 * Parses a diffusion-gradient element into a {@link DiffusionParametersGradient}.
 *
 * <p>Each {@link AbstractDiffusionGradient} child contributes its derivation parameter, its raw
 * parameter (collected into one {@link CompoundParameter}) and itself as a gradient provider
 * (combined into a {@link CompoundDerivative}). The tree, trait dimension and likelihood delegate are
 * taken from the likelihood of the first child only.
 *
 * @param xo the XML element; reads the {@code TRAIT_NAME} attribute (default {@code DEFAULT_TRAIT_NAME})
 * @return the combined {@link DiffusionParametersGradient}
 * @throws XMLParseException if no diffusion-gradient children are present, if the first child's
 *     likelihood is not a {@link TreeDataLikelihood}, or if its delegate is not a
 *     {@link ContinuousDataLikelihoodDelegate}
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);

    List<AbstractDiffusionGradient> diffGradients = xo.getAllChildren(AbstractDiffusionGradient.class);
    // get(0) below needs at least one child; report a parse error rather than NPE / IndexOutOfBounds.
    if (diffGradients == null || diffGradients.isEmpty()) {
        throw new XMLParseException("At least one diffusion gradient must be provided");
    }

    List<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter> derivationParametersList =
            new ArrayList<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter>();
    CompoundParameter compoundParameter = new CompoundParameter(null);
    List<GradientWrtParameterProvider> derivativeList = new ArrayList<GradientWrtParameterProvider>();
    for (AbstractDiffusionGradient grad : diffGradients) {
        derivationParametersList.add(grad.getDerivationParameter());
        compoundParameter.addParameter(grad.getRawParameter());
        derivativeList.add(grad);
    }
    CompoundGradient parametersGradients = new CompoundDerivative(derivativeList);

    // NOTE(review): the children are not verified to share the same model; only the first child's
    // likelihood is used below.
    Object likelihood = diffGradients.get(0).getLikelihood();
    if (!(likelihood instanceof TreeDataLikelihood)) {
        throw new XMLParseException("Diffusion gradients must be defined on a TreeDataLikelihood");
    }
    TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) likelihood;

    DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
    if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
        throw new XMLParseException("Diffusion gradients require a continuous data likelihood delegate");
    }
    ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate;

    int dim = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
    Tree tree = treeDataLikelihood.getTree();

    ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient traitGradient =
            new ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient(
                    dim, tree, continuousData, derivationParametersList);
    BranchSpecificGradient branchSpecificGradient = new BranchSpecificGradient(
            traitName, treeDataLikelihood, continuousData, traitGradient, compoundParameter);
    return new DiffusionParametersGradient(branchSpecificGradient, parametersGradients);
}
Also used : CompoundGradient(dr.inference.hmc.CompoundGradient) BranchSpecificGradient(dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient) AbstractDiffusionGradient(dr.evomodel.treedatalikelihood.hmc.AbstractDiffusionGradient) ArrayList(java.util.ArrayList) CompoundParameter(dr.inference.model.CompoundParameter) DiffusionParametersGradient(dr.evomodel.treedatalikelihood.hmc.DiffusionParametersGradient) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate) DataLikelihoodDelegate(dr.evomodel.treedatalikelihood.DataLikelihoodDelegate) CompoundDerivative(dr.inference.hmc.CompoundDerivative) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) Tree(dr.evolution.tree.Tree) ContinuousTraitGradientForBranch(dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch)

Example 13 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Method parseXMLObject of class MaximizeWrtParameterParser.

/**
 * Parses a maximizer element, runs the maximization immediately, and returns the maximizer.
 *
 * <p>The parameter and likelihood come from the {@link GradientWrtParameterProvider} child when one
 * is present. Otherwise they are read from the {@code DENSITY} child element.
 *
 * @param xo the XML element; reads {@code N_ITERATIONS} (default 0, sign discarded),
 *     {@code INITIAL_GUESS} (default true) and {@code PRINT_SCREEN} (default false)
 * @return the {@link MaximizerWrtParameter} after {@code maximize()} has been called on it
 * @throws XMLParseException if neither a gradient provider nor a {@code DENSITY} element is given
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    GradientWrtParameterProvider gradient =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    int nIterations = Math.abs(xo.getAttribute(N_ITERATIONS, 0));
    boolean initialGuess = xo.getAttribute(INITIAL_GUESS, true);
    boolean printScreen = xo.getAttribute(PRINT_SCREEN, false);

    final Parameter parameter;
    final Likelihood likelihood;
    if (gradient != null) {
        parameter = gradient.getParameter();
        likelihood = gradient.getLikelihood();
    } else {
        XMLObject cxo = xo.getChild(DENSITY);
        // Without this check a missing DENSITY element surfaces as a NullPointerException.
        if (cxo == null) {
            throw new XMLParseException("Must provide either a gradient or a " + DENSITY + " element");
        }
        parameter = (Parameter) cxo.getChild(Parameter.class);
        likelihood = (Likelihood) cxo.getChild(Likelihood.class);
    }

    Transform transform = (Transform) xo.getChild(Transform.class);
    // gradient may be null here (DENSITY branch); it is passed through unchanged.
    MaximizerWrtParameter maximizer = new MaximizerWrtParameter(likelihood, parameter, gradient, transform,
            new MaximizerWrtParameter.Settings(nIterations, initialGuess, printScreen));
    // Side effect at parse time: the optimisation runs before the object is returned.
    maximizer.maximize();
    return maximizer;
}
Also used : Likelihood(dr.inference.model.Likelihood) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) Parameter(dr.inference.model.Parameter) MaximizerWrtParameter(dr.evomodel.treedatalikelihood.discrete.MaximizerWrtParameter) Transform(dr.util.Transform)

Example 14 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Method parseXMLObject of class ZigZagOperatorParser.

/**
 * Builds a zig-zag operator from the XML element.
 *
 * <p>The {@code REVERSIBLE_FLG} attribute (default {@code true}) selects between
 * {@link ReversibleZigZagOperator} and {@link IrreversibleZigZagOperator}. Both are built from
 * the same arguments.
 *
 * @param xo the XML element describing the operator
 * @return the constructed zig-zag operator
 * @throws XMLParseException propagated from attribute, mask or runtime-option parsing
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final double operatorWeight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    final GradientWrtParameterProvider gradientProvider =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    final PrecisionMatrixVectorProductProvider precisionProduct =
            (PrecisionMatrixVectorProductProvider) xo.getChild(PrecisionMatrixVectorProductProvider.class);
    final PrecisionColumnProvider precisionColumns =
            (PrecisionColumnProvider) xo.getChild(PrecisionColumnProvider.class);
    final Parameter maskParameter = parseMask(xo);
    final AbstractParticleOperator.Options options = parseRuntimeOptions(xo);
    final int threads = xo.getAttribute(THREAD_COUNT, 1);
    final boolean useReversible = xo.getAttribute(REVERSIBLE_FLG, true);

    return useReversible
            ? new ReversibleZigZagOperator(gradientProvider, precisionProduct, precisionColumns,
                    operatorWeight, options, maskParameter, threads)
            : new IrreversibleZigZagOperator(gradientProvider, precisionProduct, precisionColumns,
                    operatorWeight, options, maskParameter, threads);
}
Also used : PrecisionColumnProvider(dr.inference.hmc.PrecisionColumnProvider) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) Parameter(dr.inference.model.Parameter) ReversibleZigZagOperator(dr.inference.operators.hmc.ReversibleZigZagOperator) PrecisionMatrixVectorProductProvider(dr.inference.hmc.PrecisionMatrixVectorProductProvider) AbstractParticleOperator(dr.inference.operators.hmc.AbstractParticleOperator) IrreversibleZigZagOperator(dr.inference.operators.hmc.IrreversibleZigZagOperator)

Example 15 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Method parseXMLObject of class BranchRateGradientWrtIncrementsParser.

/**
 * Parses a branch-rate-gradient-wrt-increments element.
 *
 * <p>Combines a branch rate gradient with an {@link AutoCorrelatedGradientWrtIncrements} prior
 * provider into a {@link BranchRateGradientWrtIncrements}.
 *
 * @param xo the XML element; expects a {@link GradientWrtParameterProvider} child that is a
 *     {@link BranchRateGradient} or {@link BranchRateGradientForDiscreteTrait}
 * @return the combined {@link BranchRateGradientWrtIncrements}
 * @throws XMLParseException if the gradient child is missing or is not one of the accepted
 *     branch rate gradient types
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    AutoCorrelatedGradientWrtIncrements priorProvider =
            (AutoCorrelatedGradientWrtIncrements) xo.getChild(AutoCorrelatedGradientWrtIncrements.class);
    GradientWrtParameterProvider rateProvider =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);

    // instanceof is false for null, so a missing gradient child is also rejected here.
    if (!(rateProvider instanceof BranchRateGradient)
            && !(rateProvider instanceof BranchRateGradientForDiscreteTrait)) {
        throw new XMLParseException("Must provide a branch rate gradient");
    }
    return new BranchRateGradientWrtIncrements(rateProvider, priorProvider);
}
Also used : BranchRateGradient(dr.evomodel.treedatalikelihood.continuous.BranchRateGradient) BranchRateGradientForDiscreteTrait(dr.evomodel.treedatalikelihood.discrete.BranchRateGradientForDiscreteTrait) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) AutoCorrelatedGradientWrtIncrements(dr.evomodel.branchratemodel.AutoCorrelatedGradientWrtIncrements) BranchRateGradientWrtIncrements(dr.evomodel.branchratemodel.BranchRateGradientWrtIncrements)

Aggregations

- GradientWrtParameterProvider (dr.inference.hmc.GradientWrtParameterProvider): 16
- ArrayList (java.util.ArrayList): 6
- Parameter (dr.inference.model.Parameter): 4
- TreeDataLikelihood (dr.evomodel.treedatalikelihood.TreeDataLikelihood): 3
- SumDerivative (dr.inference.hmc.SumDerivative): 3
- Likelihood (dr.inference.model.Likelihood): 3
- Tree (dr.evolution.tree.Tree): 2
- DistributionLikelihood (dr.inference.distribution.DistributionLikelihood): 2
- MultivariateDistributionLikelihood (dr.inference.distribution.MultivariateDistributionLikelihood): 2
- CompoundDerivative (dr.inference.hmc.CompoundDerivative): 2
- CompoundGradient (dr.inference.hmc.CompoundGradient): 2
- CompoundLikelihood (dr.inference.model.CompoundLikelihood): 2
- Transform (dr.util.Transform): 2
- AutoCorrelatedGradientWrtIncrements (dr.evomodel.branchratemodel.AutoCorrelatedGradientWrtIncrements): 1
- BranchRateGradientWrtIncrements (dr.evomodel.branchratemodel.BranchRateGradientWrtIncrements): 1
- DataLikelihoodDelegate (dr.evomodel.treedatalikelihood.DataLikelihoodDelegate): 1
- BranchRateGradient (dr.evomodel.treedatalikelihood.continuous.BranchRateGradient): 1
- BranchSpecificGradient (dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient): 1
- ContinuousDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate): 1
- ContinuousTraitGradientForBranch (dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch): 1