
Example 6 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

From the class TaskPoolParser, method parseXMLObject:

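@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    Tree tree = (Tree) xo.getChild(Tree.class);
    GradientWrtParameterProvider gradient = (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    // Without either child the task count is undefined (and gradient.getDimension() would throw an NPE).
    if (tree == null && gradient == null) {
        throw new XMLParseException("Must provide either a tree or a gradient");
    }
    // One task per tip when a tree is supplied; otherwise one task per gradient dimension.
    int taskCount = (tree != null) ? tree.getExternalNodeCount() : gradient.getDimension();
    int threadCount = xo.getAttribute(THREAD_COUNT, 1);
    return new TaskPool(taskCount, threadCount);
}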
Also used: TaskPool(dr.util.TaskPool) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) Tree(dr.evolution.tree.Tree)
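The parser's sizing rule (one task per tip when a tree is present, otherwise one task per gradient dimension) can be illustrated with a small self-contained stand-in. `SimpleTaskPool`, `chooseTaskCount`, and `map` are hypothetical names; this is not dr.util.TaskPool, whose internals are not shown here, and the strided assignment of tasks to threads is an assumption made only for this sketch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.IntUnaryOperator;

// Hypothetical stand-in for a task pool; not dr.util.TaskPool.
final class SimpleTaskPool {
    final int taskCount;
    final int threadCount;

    SimpleTaskPool(int taskCount, int threadCount) {
        if (taskCount < 1 || threadCount < 1) {
            throw new IllegalArgumentException("taskCount and threadCount must be positive");
        }
        this.taskCount = taskCount;
        this.threadCount = threadCount;
    }

    // Same rule as the parser: tip count if a tree is given, else the gradient dimension.
    static int chooseTaskCount(Integer tipCount, Integer gradientDimension) {
        if (tipCount != null) return tipCount;
        if (gradientDimension != null) return gradientDimension;
        throw new IllegalArgumentException("need either a tree or a gradient");
    }

    // Evaluates task(i) for every i in [0, taskCount) using threadCount worker threads.
    int[] map(IntUnaryOperator task) {
        int[] results = new int[taskCount];
        ExecutorService executor = Executors.newFixedThreadPool(threadCount);
        try {
            List<Future<?>> futures = new ArrayList<>();
            for (int t = 0; t < threadCount; t++) {
                final int first = t;
                // Assumed strided split: worker t handles tasks t, t + threadCount, ...
                futures.add(executor.submit(() -> {
                    for (int i = first; i < taskCount; i += threadCount) {
                        results[i] = task.applyAsInt(i);
                    }
                }));
            }
            for (Future<?> f : futures) {
                f.get(); // waits for the worker and rethrows any failure
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            executor.shutdown();
        }
        return results;
    }
}
```

For example, `new SimpleTaskPool(7, 3).map(i -> i * i)` fills the result array in index order regardless of which worker computed each entry, because every task writes only to its own slot.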

Example 7 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

From the class BranchRateGradientParser, method parseXMLObject:

@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    boolean useHessian = xo.getAttribute(USE_HESSIAN, false);
    final Object child = xo.getChild(TreeDataLikelihood.class);
    if (child != null) {
        return parseTreeDataLikelihood((TreeDataLikelihood) child, traitName, useHessian);
    } else {
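        CompoundLikelihood compoundLikelihood = (CompoundLikelihood) xo.getChild(CompoundLikelihood.class);
        // Neither child present: fail with a parse error instead of an NPE in the loop below.
        if (compoundLikelihood == null) {
            throw new XMLParseException("Expected a TreeDataLikelihood or a CompoundLikelihood");
        }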
        List<GradientWrtParameterProvider> providers = new ArrayList<>();
        for (Likelihood likelihood : compoundLikelihood.getLikelihoods()) {
            if (!(likelihood instanceof TreeDataLikelihood)) {
                throw new XMLParseException("Unknown likelihood type");
            }
            GradientWrtParameterProvider provider = parseTreeDataLikelihood((TreeDataLikelihood) likelihood, traitName, useHessian);
            providers.add(provider);
        }
        checkBranchRateModels(providers);
        return new SumDerivative(providers);
    }
}
Also used: TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) CompoundLikelihood(dr.inference.model.CompoundLikelihood) Likelihood(dr.inference.model.Likelihood) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) ArrayList(java.util.ArrayList) SumDerivative(dr.inference.hmc.SumDerivative)
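When the parser receives a CompoundLikelihood, it builds one gradient per TreeDataLikelihood and wraps them in a SumDerivative. Presumably this works because the log-likelihood terms add, so their gradients with respect to a shared parameter add elementwise. The sketch below uses a hypothetical minimal interface, `SimpleGradient`, whose method names only loosely mirror GradientWrtParameterProvider; `SumOfGradients` is not the BEAST SumDerivative class, and its dimension check is merely analogous to, not a copy of, the consistency check performed by checkBranchRateModels.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal gradient interface; not dr.inference.hmc.GradientWrtParameterProvider.
interface SimpleGradient {
    int getDimension();
    double[] getGradientLogDensity();
}

// Sums gradients of additive log-likelihood terms w.r.t. the same parameter.
final class SumOfGradients implements SimpleGradient {
    private final List<SimpleGradient> terms;

    SumOfGradients(List<SimpleGradient> terms) {
        if (terms.isEmpty()) throw new IllegalArgumentException("no gradient terms");
        int dim = terms.get(0).getDimension();
        for (SimpleGradient g : terms) {
            // Every term must be a gradient w.r.t. a parameter of the same size.
            if (g.getDimension() != dim) {
                throw new IllegalArgumentException("gradient dimensions differ");
            }
        }
        this.terms = new ArrayList<>(terms);
    }

    @Override
    public int getDimension() {
        return terms.get(0).getDimension();
    }

    @Override
    public double[] getGradientLogDensity() {
        double[] total = new double[getDimension()];
        for (SimpleGradient g : terms) {
            double[] part = g.getGradientLogDensity();
            for (int i = 0; i < total.length; i++) {
                total[i] += part[i];
            }
        }
        return total;
    }

    // Illustration-only factory for a gradient with fixed values.
    static SimpleGradient constant(double... values) {
        return new SimpleGradient() {
            public int getDimension() { return values.length; }
            public double[] getGradientLogDensity() { return values.clone(); }
        };
    }
}
```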

Example 8 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

From the class CompoundGradientParser, method parseXMLObject:

@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    List<GradientWrtParameterProvider> gradList = new ArrayList<GradientWrtParameterProvider>();
    // TODO Remove?
    List<Likelihood> likelihoodList = new ArrayList<Likelihood>();
    for (int i = 0; i < xo.getChildCount(); ++i) {
        Object obj = xo.getChild(i);
        GradientWrtParameterProvider grad;
        Likelihood likelihood;
        if (obj instanceof DistributionLikelihood) {
            DistributionLikelihood dl = (DistributionLikelihood) obj;
            if (!(dl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            throw new RuntimeException("Not yet implemented");
        } else if (obj instanceof MultivariateDistributionLikelihood) {
            final MultivariateDistributionLikelihood mdl = (MultivariateDistributionLikelihood) obj;
            if (!(mdl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            final GradientProvider provider = (GradientProvider) mdl.getDistribution();
            final Parameter parameter = mdl.getDataParameter();
            likelihood = mdl;
            grad = new GradientWrtParameterProvider.ParameterWrapper(provider, parameter, mdl);
        } else if (obj instanceof GradientWrtParameterProvider) {
            grad = (GradientWrtParameterProvider) obj;
            likelihood = grad.getLikelihood();
        } else {
            throw new XMLParseException("Child element is not a gradient provider or a supported likelihood");
        }
        gradList.add(grad);
        likelihoodList.add(likelihood);
    }
    return new CompoundDerivative(gradList);
}
Also used: MultivariateDistributionLikelihood(dr.inference.distribution.MultivariateDistributionLikelihood) DistributionLikelihood(dr.inference.distribution.DistributionLikelihood) ArrayList(java.util.ArrayList) CompoundDerivative(dr.inference.hmc.CompoundDerivative) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider)
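CompoundDerivative receives the list of gradients collected from the child elements. A plausible reading, assumed here rather than confirmed by the snippet, is that each child differentiates with respect to its own parameter and the compound gradient stacks the pieces end to end, so its dimension is the sum of the children's dimensions. `Grad` and `StackedGradient` are hypothetical names used only for this sketch.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal gradient type; not a BEAST interface.
interface Grad {
    double[] values();
}

// Concatenates child gradients, each taken w.r.t. a different parameter block.
final class StackedGradient implements Grad {
    private final List<Grad> blocks;

    StackedGradient(List<Grad> blocks) {
        this.blocks = new ArrayList<>(blocks);
    }

    @Override
    public double[] values() {
        List<double[]> parts = new ArrayList<>();
        int total = 0;
        for (Grad block : blocks) {
            double[] part = block.values();
            parts.add(part);
            total += part.length;
        }
        double[] out = new double[total];
        int offset = 0;
        for (double[] part : parts) {
            // Copy each block into its own slice of the output vector.
            System.arraycopy(part, 0, out, offset, part.length);
            offset += part.length;
        }
        return out;
    }
}
```

Unlike the summing approach of SumDerivative in Example 7, stacking does not require the children to share a dimension; each block keeps its own size and position.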

Example 9 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

From the class PathGradientParser, method parseXMLObject:

@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    GradientWrtParameterProvider source = (GradientWrtParameterProvider) xo.getElementFirstChild(PathLikelihood.SOURCE);
    GradientWrtParameterProvider destination = (GradientWrtParameterProvider) xo.getElementFirstChild(PathLikelihood.DESTINATION);
    return new PathGradient(source, destination);
}
Also used: PathGradient(dr.inference.hmc.PathGradient) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider)
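PathGradient combines a source gradient and a destination gradient. Assuming a geometric (power-posterior style) path, log p_beta(theta) = (1 - beta) * log p_source(theta) + beta * log p_destination(theta) for beta in [0, 1], the gradient along the path is the same convex combination of the two endpoint gradients. The path form, the `beta` argument, and the class name `GeometricPathGradient` are assumptions for illustration; the snippet does not show how PathGradient or PathLikelihood actually weights the endpoints.

```java
// Hypothetical sketch; the geometric path is an assumption, not taken from PathGradient.
final class GeometricPathGradient {
    private final double[] sourceGradient;
    private final double[] destinationGradient;

    GeometricPathGradient(double[] sourceGradient, double[] destinationGradient) {
        if (sourceGradient.length != destinationGradient.length) {
            throw new IllegalArgumentException("endpoint gradients must have equal length");
        }
        this.sourceGradient = sourceGradient.clone();
        this.destinationGradient = destinationGradient.clone();
    }

    // d/dtheta [(1 - beta) log p_source(theta) + beta log p_destination(theta)]
    double[] at(double beta) {
        if (beta < 0.0 || beta > 1.0) {
            throw new IllegalArgumentException("beta must lie in [0, 1]");
        }
        double[] g = new double[sourceGradient.length];
        for (int i = 0; i < g.length; i++) {
            g[i] = (1.0 - beta) * sourceGradient[i] + beta * destinationGradient[i];
        }
        return g;
    }
}
```

At beta = 0 the result equals the source gradient and at beta = 1 the destination gradient, which is what one would expect of any path that starts at the source and ends at the destination.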

Example 10 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

From the class SumDerivativeParser, method parseXMLObject:

@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    List<GradientWrtParameterProvider> derivativeList = new ArrayList<GradientWrtParameterProvider>();
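    for (int i = 0; i < xo.getChildCount(); i++) {
        Object child = xo.getChild(i);
        // Reject non-gradient children with a parse error instead of a ClassCastException.
        if (!(child instanceof GradientWrtParameterProvider)) {
            throw new XMLParseException("Child " + i + " is not a GradientWrtParameterProvider");
        }
        derivativeList.add((GradientWrtParameterProvider) child);
    }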
    return new SumDerivative(derivativeList);
}
Also used: GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) ArrayList(java.util.ArrayList) SumDerivative(dr.inference.hmc.SumDerivative)
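SumDerivativeParser treats every child as a GradientWrtParameterProvider. A generic helper that checks each child's type and reports the offending index is one way to turn a stray child into a readable error rather than a ClassCastException. `ChildCollector.requireAll` is a hypothetical helper, and a plain `List<?>` stands in for BEAST's XMLObject children.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper; not part of BEAST's XML parsing API.
final class ChildCollector {
    private ChildCollector() {}

    // Returns the children cast to type, or throws naming the first child of the wrong type.
    static <T> List<T> requireAll(List<?> children, Class<T> type) {
        List<T> out = new ArrayList<>(children.size());
        for (int i = 0; i < children.size(); i++) {
            Object child = children.get(i);
            if (!type.isInstance(child)) {
                String found = (child == null) ? "null" : child.getClass().getSimpleName();
                throw new IllegalArgumentException(
                        "child " + i + " is " + found + ", expected " + type.getSimpleName());
            }
            out.add(type.cast(child));
        }
        return out;
    }
}
```

In a parser the `IllegalArgumentException` would typically be rethrown as the parser's own checked exception type so the message reaches the user.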

Aggregations

Classes used across the examples, with usage counts:

GradientWrtParameterProvider (dr.inference.hmc.GradientWrtParameterProvider): 16
ArrayList (java.util.ArrayList): 6
Parameter (dr.inference.model.Parameter): 4
TreeDataLikelihood (dr.evomodel.treedatalikelihood.TreeDataLikelihood): 3
SumDerivative (dr.inference.hmc.SumDerivative): 3
Likelihood (dr.inference.model.Likelihood): 3
Tree (dr.evolution.tree.Tree): 2
DistributionLikelihood (dr.inference.distribution.DistributionLikelihood): 2
MultivariateDistributionLikelihood (dr.inference.distribution.MultivariateDistributionLikelihood): 2
CompoundDerivative (dr.inference.hmc.CompoundDerivative): 2
CompoundGradient (dr.inference.hmc.CompoundGradient): 2
CompoundLikelihood (dr.inference.model.CompoundLikelihood): 2
Transform (dr.util.Transform): 2
AutoCorrelatedGradientWrtIncrements (dr.evomodel.branchratemodel.AutoCorrelatedGradientWrtIncrements): 1
BranchRateGradientWrtIncrements (dr.evomodel.branchratemodel.BranchRateGradientWrtIncrements): 1
DataLikelihoodDelegate (dr.evomodel.treedatalikelihood.DataLikelihoodDelegate): 1
BranchRateGradient (dr.evomodel.treedatalikelihood.continuous.BranchRateGradient): 1
BranchSpecificGradient (dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient): 1
ContinuousDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate): 1
ContinuousTraitGradientForBranch (dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch): 1