Search in sources:

Example 1 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Source: class AppendedPotentialDerivativeParser, method parseXMLObject.

/**
 * Builds a {@link CompoundGradient} from the child elements of {@code xo}.
 *
 * <p>Children are handled as follows.
 * <ul>
 *   <li>A {@link MultivariateDistributionLikelihood} whose distribution is a
 *       {@link GradientProvider} is wrapped in a
 *       {@link GradientWrtParameterProvider.ParameterWrapper} over its data parameter.</li>
 *   <li>A {@link GradientWrtParameterProvider} is used as-is.</li>
 *   <li>A {@link DistributionLikelihood} is recognised but not yet supported.</li>
 * </ul>
 * The resulting gradients are combined in document order.
 *
 * @param xo the XML element being parsed
 * @return a {@link CompoundGradient} over the children's gradients
 * @throws XMLParseException if a likelihood child's distribution is not a
 *         {@link GradientProvider}, or a child is of an unsupported type
 * @throws UnsupportedOperationException if a child is a {@link DistributionLikelihood}
 *         whose distribution is a {@link GradientProvider}; this path is unimplemented
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    List<GradientWrtParameterProvider> gradList = new ArrayList<GradientWrtParameterProvider>();
    for (int i = 0; i < xo.getChildCount(); ++i) {
        Object obj = xo.getChild(i);
        GradientWrtParameterProvider grad;
        if (obj instanceof DistributionLikelihood) {
            DistributionLikelihood dl = (DistributionLikelihood) obj;
            if (!(dl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            // Univariate distribution likelihoods are not wired up yet.
            throw new UnsupportedOperationException("Not yet implemented");
        } else if (obj instanceof MultivariateDistributionLikelihood) {
            final MultivariateDistributionLikelihood mdl = (MultivariateDistributionLikelihood) obj;
            if (!(mdl.getDistribution() instanceof GradientProvider)) {
                throw new XMLParseException("Not a gradient provider");
            }
            final GradientProvider provider = (GradientProvider) mdl.getDistribution();
            final Parameter parameter = mdl.getDataParameter();
            // Take the distribution's gradient with respect to the likelihood's data parameter.
            grad = new GradientWrtParameterProvider.ParameterWrapper(provider, parameter, mdl);
        } else if (obj instanceof GradientWrtParameterProvider) {
            grad = (GradientWrtParameterProvider) obj;
        } else {
            final String typeName = (obj == null) ? "null" : obj.getClass().getName();
            throw new XMLParseException("Child element " + i + " of type " + typeName
                    + " is not a gradient provider");
        }
        gradList.add(grad);
    }
    return new CompoundGradient(gradList);
}
Also used : CompoundGradient(dr.inference.hmc.CompoundGradient) MultivariateDistributionLikelihood(dr.inference.distribution.MultivariateDistributionLikelihood) DistributionLikelihood(dr.inference.distribution.DistributionLikelihood) MultivariateDistributionLikelihood(dr.inference.distribution.MultivariateDistributionLikelihood) ArrayList(java.util.ArrayList) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) DistributionLikelihood(dr.inference.distribution.DistributionLikelihood) MultivariateDistributionLikelihood(dr.inference.distribution.MultivariateDistributionLikelihood)

Example 2 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Source: class CompoundGradient, method getGradientLogDensity.

//    @Override
//    public void getGradientLogDensity(final double[] destination, final int offset) {
//        double[] grad = getGradientLogDensity();
//        System.arraycopy(grad, 0, destination, offset, grad.length);
//    }
/**
 * Concatenates the log-density gradients of every provider in {@code derivativeList}.
 *
 * <p>Each provider contributes {@code getDimension()} entries, placed immediately after the
 * entries of the providers that precede it in the list.
 *
 * @return a newly allocated array of length {@code dimension} holding the stacked gradients
 */
@Override
public double[] getGradientLogDensity() {
    final double[] concatenated = new double[dimension];
    int position = 0;
    for (final GradientWrtParameterProvider component : derivativeList) {
        final double[] partial = component.getGradientLogDensity();
        final int length = component.getDimension();
        System.arraycopy(partial, 0, concatenated, position, length);
        position += length;
    }
    return concatenated;
}
Also used : GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider)

Example 3 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Source: class DiffusionParametersGradient, method checkAndSetParametersGradients.

/**
 * Registers each component gradient's parameter with {@code parameter} and assigns each
 * gradient its offset.
 *
 * <p>Each gradient's offset is the running total of the dimensions of the preceding gradients'
 * derivation parameters, evaluated at the trait dimension of the data likelihood delegate.
 *
 * @param parametersGradients the compound gradient whose components are visited in order;
 *        every component is asserted to be an {@code AbstractDiffusionGradient}
 * @param parameter the compound parameter that receives each component's parameter
 * @return the sum of {@code getDimension()} over all component gradients
 */
private int checkAndSetParametersGradients(CompoundGradient parametersGradients, CompoundParameter parameter) {
    final int traitDimension = likelihood.getDataLikelihoodDelegate().getTraitDim();
    int nextOffset = 0;
    int totalDimension = 0;
    for (GradientWrtParameterProvider provider : parametersGradients.getDerivativeList()) {
        assert provider instanceof AbstractDiffusionGradient : "Gradients must all be instances of AbstractDiffusionGradient.";
        final AbstractDiffusionGradient diffusionGradient = (AbstractDiffusionGradient) provider;
        diffusionGradient.setOffset(nextOffset);
        parameter.addParameter(provider.getParameter());
        nextOffset += diffusionGradient.getDerivationParameter().getDimension(traitDimension);
        totalDimension += provider.getDimension();
    }
    return totalDimension;
}
Also used : GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider)

Example 4 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Source: class DiffusionParametersGradient, method getGradientLogDensity.

/**
 * Computes every component's gradient from {@code gradient} and stacks the results.
 *
 * <p>Each component of {@code parametersGradients} is cast to {@code AbstractDiffusionGradient}.
 * Its {@code getGradientLogDensity(gradient)} output is copied, truncated to
 * {@code getDimension()} entries, directly after the previous component's entries.
 *
 * @param gradient the input passed unchanged to each component's
 *        {@code getGradientLogDensity(double[])}
 * @return a newly allocated array of length {@code dim}
 */
private double[] getGradientLogDensity(double[] gradient) {
    final double[] stacked = new double[dim];
    int cursor = 0;
    for (GradientWrtParameterProvider provider : parametersGradients.getDerivativeList()) {
        final double[] piece = ((AbstractDiffusionGradient) provider).getGradientLogDensity(gradient);
        final int width = provider.getDimension();
        System.arraycopy(piece, 0, stacked, cursor, width);
        cursor += width;
    }
    return stacked;
}
Also used : GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider)

Example 5 with GradientWrtParameterProvider

Use of dr.inference.hmc.GradientWrtParameterProvider in project beast-mcmc by beast-dev.

Source: class LocationScaleGradientParser, method parseXMLObject.

/**
 * Builds a location-scale gradient from a tree data likelihood or a compound of such
 * likelihoods.
 *
 * <p>If a {@link TreeDataLikelihood} child is present, its gradient is returned directly.
 * Otherwise a {@link CompoundLikelihood} child is required. Every likelihood inside it must be a
 * {@link TreeDataLikelihood}, and their gradients are combined into a {@link SumDerivative}.
 *
 * @param xo the XML element being parsed; reads the trait-name and use-Hessian attributes
 * @return the gradient provider for the single likelihood, or a {@link SumDerivative} over all
 *         component gradients
 * @throws XMLParseException if neither a {@link TreeDataLikelihood} nor a
 *         {@link CompoundLikelihood} child is present, or if the compound contains a
 *         likelihood that is not a {@link TreeDataLikelihood}
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
    boolean useHessian = xo.getAttribute(USE_HESSIAN, false);
    final Object child = xo.getChild(TreeDataLikelihood.class);
    if (child != null) {
        return parseTreeDataLikelihood(xo, (TreeDataLikelihood) child, traitName, useHessian);
    }

    final CompoundLikelihood compoundLikelihood = (CompoundLikelihood) xo.getChild(CompoundLikelihood.class);
    // getChild(Class) yields null when the child is absent (see the TreeDataLikelihood check
    // above); fail with a parse error instead of a NullPointerException below.
    if (compoundLikelihood == null) {
        throw new XMLParseException("Expected a TreeDataLikelihood or CompoundLikelihood child element");
    }

    List<GradientWrtParameterProvider> providers = new ArrayList<>();
    for (Likelihood likelihood : compoundLikelihood.getLikelihoods()) {
        if (!(likelihood instanceof TreeDataLikelihood)) {
            throw new XMLParseException("Unknown likelihood type: "
                    + (likelihood == null ? "null" : likelihood.getClass().getName()));
        }
        GradientWrtParameterProvider provider = parseTreeDataLikelihood(xo, (TreeDataLikelihood) likelihood, traitName, useHessian);
        providers.add(provider);
    }
    checkBranchRateModels(providers);
    return new SumDerivative(providers);
}
Also used : TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) CompoundLikelihood(dr.inference.model.CompoundLikelihood) Likelihood(dr.inference.model.Likelihood) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) CompoundLikelihood(dr.inference.model.CompoundLikelihood) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) ArrayList(java.util.ArrayList) SumDerivative(dr.inference.hmc.SumDerivative)

Aggregations

GradientWrtParameterProvider (dr.inference.hmc.GradientWrtParameterProvider)16 ArrayList (java.util.ArrayList)6 Parameter (dr.inference.model.Parameter)4 TreeDataLikelihood (dr.evomodel.treedatalikelihood.TreeDataLikelihood)3 SumDerivative (dr.inference.hmc.SumDerivative)3 Likelihood (dr.inference.model.Likelihood)3 Tree (dr.evolution.tree.Tree)2 DistributionLikelihood (dr.inference.distribution.DistributionLikelihood)2 MultivariateDistributionLikelihood (dr.inference.distribution.MultivariateDistributionLikelihood)2 CompoundDerivative (dr.inference.hmc.CompoundDerivative)2 CompoundGradient (dr.inference.hmc.CompoundGradient)2 CompoundLikelihood (dr.inference.model.CompoundLikelihood)2 Transform (dr.util.Transform)2 AutoCorrelatedGradientWrtIncrements (dr.evomodel.branchratemodel.AutoCorrelatedGradientWrtIncrements)1 BranchRateGradientWrtIncrements (dr.evomodel.branchratemodel.BranchRateGradientWrtIncrements)1 DataLikelihoodDelegate (dr.evomodel.treedatalikelihood.DataLikelihoodDelegate)1 BranchRateGradient (dr.evomodel.treedatalikelihood.continuous.BranchRateGradient)1 BranchSpecificGradient (dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient)1 ContinuousDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate)1 ContinuousTraitGradientForBranch (dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch)1