Search in sources:

Example 1 with Transform

Use of dr.util.Transform in project beast-mcmc by beast-dev.

The class BranchSpecificFixedEffectsParser, method parseXMLObject.

/**
 * Builds a {@link BranchSpecificFixedEffects.Default} model from this element. If a
 * {@link Transform} child is present, the model is wrapped in a
 * {@link BranchSpecificFixedEffects.Transformed}.
 *
 * <p>Inputs gathered from the element:
 * <ul>
 *   <li>each {@code CATEGORY} child becomes a clade-based branch category model;</li>
 *   <li>the {@code TIME_DEPENDENT_EFFECT} attribute (default {@code false}) adds a
 *       {@code ContinuousBranchValueProvider.MidPoint} value provider;</li>
 *   <li>every {@link BranchRates} child is passed to the model as well.</li>
 * </ul>
 * The model's initial design matrix is logged to the {@code dr.evomodel} logger.
 *
 * @param xo the XML element being parsed
 * @return the fixed-effects model, or its transformed wrapper when a transform is supplied
 * @throws XMLParseException if parsing the element or one of its category children fails
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
    final Parameter effectCoefficients = (Parameter) xo.getChild(Parameter.class);
    final boolean withIntercept = xo.getAttribute(INCLUDE_INTERCEPT, true);

    // One clade category model per CATEGORY child. Each gets a fresh parameter sized
    // getNodeCount() - 1 (presumably one entry per branch, i.e. excluding the root).
    final List<CountableBranchCategoryProvider> categoryProviders = new ArrayList<CountableBranchCategoryProvider>();
    for (XMLObject categoryXo : xo.getAllChildren(CATEGORY)) {
        final CountableBranchCategoryProvider.CladeBranchCategoryModel clades =
                new CountableBranchCategoryProvider.CladeBranchCategoryModel(
                        tree, new Parameter.Default(tree.getNodeCount() - 1));
        parseCladeCategories(categoryXo, clades);
        categoryProviders.add(clades);
    }

    final List<ContinuousBranchValueProvider> valueProviders = new ArrayList<ContinuousBranchValueProvider>();
    if (xo.getAttribute(TIME_DEPENDENT_EFFECT, false)) {
        valueProviders.add(new ContinuousBranchValueProvider.MidPoint());
    }

    // Collect every child that is a BranchRates, in document order.
    final List<BranchRates> rateCovariates = new ArrayList<BranchRates>();
    int childIndex = 0;
    while (childIndex < xo.getChildCount()) {
        final Object child = xo.getChild(childIndex);
        if (child instanceof BranchRates) {
            rateCovariates.add((BranchRates) child);
        }
        childIndex++;
    }

    final Transform transform = (Transform) xo.getChild(Transform.class);

    final BranchSpecificFixedEffects.Default model = new BranchSpecificFixedEffects.Default(
            xo.getId(), categoryProviders, valueProviders, rateCovariates, effectCoefficients, withIntercept);

    final double[][] initialDesign = model.getDesignMatrix(tree);
    Logger.getLogger("dr.evomodel").info(
            "Using a fixed effects model with initial design matrix:\n" + annotateDesignMatrix(initialDesign, tree));

    if (transform == null) {
        return model;
    }
    return new BranchSpecificFixedEffects.Transformed(model, transform);
}
Also used: ContinuousBranchValueProvider (dr.evomodel.branchratemodel.ContinuousBranchValueProvider), ArrayList (java.util.ArrayList), TreeModel (dr.evomodel.tree.TreeModel), Parameter (dr.inference.model.Parameter), CountableBranchCategoryProvider (dr.evomodel.branchratemodel.CountableBranchCategoryProvider), Transform (dr.util.Transform), BranchRates (dr.evolution.tree.BranchRates), BranchSpecificFixedEffects (dr.evomodel.branchratemodel.BranchSpecificFixedEffects)

Example 2 with Transform

Use of dr.util.Transform in project beast-mcmc by beast-dev.

The class TransformedTreeTraitParser, method parseXMLObject.

/**
 * Builds a {@link TransformedTreeTraitProvider} from this element's {@link TreeTraitProvider}
 * child and its {@link Transform} child.
 *
 * @param xo the XML element being parsed
 * @return the new {@link TransformedTreeTraitProvider}
 * @throws XMLParseException declared by the parser contract
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    // Arguments are evaluated left to right: the trait provider is looked up before the transform.
    return new TransformedTreeTraitProvider(
            (TreeTraitProvider) xo.getChild(TreeTraitProvider.class),
            (Transform) xo.getChild(Transform.class));
}
Also used: TransformedTreeTraitProvider (dr.evolution.tree.TransformedTreeTraitProvider), TreeTraitProvider (dr.evolution.tree.TreeTraitProvider), Transform (dr.util.Transform)

Example 3 with Transform

Use of dr.util.Transform in project beast-mcmc by beast-dev.

The class RandomWalkGammaPrecisionGibbsOperator, method main.

/**
 * Scratch entry point. Builds a five-element parameter holding the values 1 to 5 and wraps it
 * in a {@link TransformedParameter} that uses a {@link Transform.LogTransform}.
 * Nothing is printed or returned; the transformed parameter is only constructed.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    final double[] logValues = {1.0, 2.0, 3.0, 4.0, 5.0};
    final Parameter logPopulationSize = new Parameter.Default(logValues);
    final Transform logTransform = new Transform.LogTransform();
    // NOTE(review): the third argument appears to be the "inverse" flag (the RandomWalkOperator
    // parser passes its INVERSE attribute in this position). Confirm against TransformedParameter.
    final Parameter populationSize = new TransformedParameter(logPopulationSize, logTransform, true);
}
Also used: TransformedParameter (dr.inference.model.TransformedParameter), Parameter (dr.inference.model.Parameter), Transform (dr.util.Transform)

Example 4 with Transform

Use of dr.util.Transform in project beast-mcmc by beast-dev.

The class RandomWalkOperatorParser, method parseXMLObject.

/**
 * Parses a random-walk operator element.
 *
 * <p>Builds a {@link RandomWalkOperator} on the element's {@link Parameter} child. If an
 * {@code UPDATE_INDEX} child is present, its parameter is passed to the operator as well.
 * If a transform is specified, the operator is wrapped in a
 * {@link TransformedParameterRandomWalkOperator} acting on a {@link TransformedParameter},
 * or on a {@link TransformedMultivariateParameter} when the transform is multivariate.
 *
 * @param xo the XML element being parsed
 * @return the random-walk operator, or a transformed-parameter wrapper around it
 * @throws XMLParseException if any of the following holds:
 *     explicit LOWER/UPPER attributes are given;
 *     the {@code logit} condition is used without finite lower and upper parameter bounds;
 *     the {@code log} condition is used without finite lower bounds;
 *     the update-index parameter's dimension differs from the parameter's
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    AdaptationMode mode = AdaptationMode.parseMode(xo);
    double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    double windowSize = xo.getDoubleAttribute(WINDOW_SIZE);
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);

    if (xo.hasAttribute(LOWER) || xo.hasAttribute(UPPER)) {
        throw new XMLParseException("Do not provide lower/upper bounds for a RandomWalkOperator; set these values as parameter bounds");
    }

    RandomWalkOperator.BoundaryCondition condition = RandomWalkOperator.BoundaryCondition.valueOf(
            xo.getAttribute(BOUNDARY_CONDITION, RandomWalkOperator.BoundaryCondition.reflecting.name()));

    // A side counts as "set" only if every dimension has a non-null, non-infinite limit.
    final Bounds<Double> bounds = parameter.getBounds();
    final int dim = parameter.getDimension();
    boolean lowerBoundsSet = true;
    boolean upperBoundsSet = true;
    for (int i = 0; i < dim; ++i) {
        final Double lower = bounds.getLowerLimit(i);
        if (lower == null || lower.isInfinite()) {
            lowerBoundsSet = false;
        }
        final Double upper = bounds.getUpperLimit(i);
        if (upper == null || upper.isInfinite()) {
            upperBoundsSet = false;
        }
    }

    if (condition == RandomWalkOperator.BoundaryCondition.logit && (!lowerBoundsSet || !upperBoundsSet)) {
        throw new XMLParseException("The logit transformed RandomWalkOperator cannot be used on a parameter without bounds.");
    }
    if (condition == RandomWalkOperator.BoundaryCondition.log && !lowerBoundsSet) {
        throw new XMLParseException("The log transformed RandomWalkOperator cannot be used on a parameter without lower bounds.");
    }

    // Optional index parameter; it must match the target parameter's dimension. Reported as a
    // parse error, like every other invalid input in this method.
    Parameter updateIndex = null;
    if (xo.hasChildNamed(UPDATE_INDEX)) {
        XMLObject cxo = xo.getChild(UPDATE_INDEX);
        updateIndex = (Parameter) cxo.getChild(Parameter.class);
        if (updateIndex.getDimension() != parameter.getDimension()) {
            throw new XMLParseException("Parameter to update (" + parameter.getDimension()
                    + ") and missing indices (" + updateIndex.getDimension()
                    + ") must have the same dimension");
        }
    }
    RandomWalkOperator randomWalk = new RandomWalkOperator(parameter, updateIndex, windowSize, condition, weight, mode);

    final Transform transform = parseTransform(xo);
    if (transform == null) {
        return randomWalk;
    }

    final boolean inverse = xo.getAttribute(INVERSE, false);
    final TransformedParameter transformedParameter;
    if (transform.isMultivariate()) {
        transformedParameter = new TransformedMultivariateParameter(parameter, (Transform.MultivariableTransform) transform, inverse);
    } else {
        transformedParameter = new TransformedParameter(parameter, transform, inverse);
    }
    return new TransformedParameterRandomWalkOperator(transformedParameter, randomWalk);
}
Also used: TransformedMultivariateParameter (dr.inference.model.TransformedMultivariateParameter), TransformedParameter (dr.inference.model.TransformedParameter), Parameter (dr.inference.model.Parameter), Transform (dr.util.Transform), Util.parseTransform (dr.util.Transform.Util.parseTransform)

Example 5 with Transform

Use of dr.util.Transform in project beast-mcmc by beast-dev.

The class HamiltonianMonteCarloOperatorParser, method parseXMLObject.

/**
 * Parses a Hamiltonian Monte Carlo operator element.
 *
 * <p>Reads the step, preconditioning and gradient-check settings into a
 * {@link HamiltonianMonteCarloOperator.Options}. It validates the gradient, parameter,
 * transform and mask dimensions, then delegates construction to {@code factory(...)}.
 * If no {@link Parameter} child is given, the gradient provider's own parameter is used.
 *
 * @param xo the XML element being parsed
 * @return the operator built by {@code factory(...)}
 * @throws XMLParseException if any of the following holds:
 *     the random step fraction's magnitude exceeds 1;
 *     preconditioning is requested without a {@link HessianWrtParameterProvider};
 *     the gradient (or multivariable transform) dimension differs from the parameter's;
 *     the mask dimension differs from the gradient's
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    int nSteps = xo.getAttribute(N_STEPS, 10);
    double stepSize = xo.getDoubleAttribute(STEP_SIZE);
    int runMode = parseRunMode(xo);
    MassPreconditioner.Type preconditioningType = parsePreconditioning(xo);

    // The sign is discarded; only the magnitude is used. A value of exactly 1.0 is accepted.
    double randomStepFraction = Math.abs(xo.getAttribute(RANDOM_STEP_FRACTION, 0.0));
    if (randomStepFraction > 1) {
        throw new XMLParseException("Random step count fraction (" + randomStepFraction + ") must be <= 1.0");
    }

    int preconditioningUpdateFrequency = xo.getAttribute(PRECONDITIONING_UPDATE_FREQUENCY, 0);
    int preconditioningDelay = xo.getAttribute(PRECONDITIONING_DELAY, 0);
    int preconditioningMemory = xo.getAttribute(PRECONDITIONING_MEMORY, 0);
    AdaptationMode adaptationMode = AdaptationMode.parseMode(xo);

    GradientWrtParameterProvider derivative = (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    if (preconditioningType != MassPreconditioner.Type.NONE && !(derivative instanceof HessianWrtParameterProvider)) {
        throw new XMLParseException("Unable to precondition without a Hessian provider");
    }

    Parameter parameter = (Parameter) xo.getChild(Parameter.class);
    if (parameter == null) {
        parameter = derivative.getParameter();
    }

    Transform transform = parseTransform(xo);
    // With a multivariable transform, the transform's dimension (not the gradient's) is checked
    // against the parameter. NOTE(review): presumably because such a transform may map between
    // spaces of different dimension; confirm. Each case reports the dimension that was checked.
    if (transform instanceof Transform.MultivariableTransform) {
        int transformDimension = ((Transform.MultivariableTransform) transform).getDimension();
        if (transformDimension != parameter.getDimension()) {
            throw new XMLParseException("Transform (" + transformDimension + ") must be the same dimension as the parameter (" + parameter.getDimension() + ")");
        }
    } else if (derivative.getDimension() != parameter.getDimension()) {
        throw new XMLParseException("Gradient (" + derivative.getDimension() + ") must be the same dimension as the parameter (" + parameter.getDimension() + ")");
    }

    Parameter mask = null;
    if (xo.hasChildNamed(MASK)) {
        mask = (Parameter) xo.getElementFirstChild(MASK);
        if (mask.getDimension() != derivative.getDimension()) {
            throw new XMLParseException("Mask (" + mask.getDimension() + ") must be the same dimension as the gradient (" + derivative.getDimension() + ")");
        }
    }

    int gradientCheckCount = xo.getAttribute(GRADIENT_CHECK_COUNT, 0);
    double gradientCheckTolerance = xo.getAttribute(GRADIENT_CHECK_TOLERANCE, 1E-3);
    int maxIterations = xo.getAttribute(MAX_ITERATIONS, 10);
    double reductionFactor = xo.getAttribute(REDUCTION_FACTOR, 0.1);
    double targetAcceptanceProbability = xo.getAttribute(TARGET_ACCEPTANCE_PROBABILITY, 0.8); // Stan default

    HamiltonianMonteCarloOperator.Options runtimeOptions = new HamiltonianMonteCarloOperator.Options(
            stepSize, nSteps, randomStepFraction,
            preconditioningUpdateFrequency, preconditioningDelay, preconditioningMemory,
            gradientCheckCount, gradientCheckTolerance,
            maxIterations, reductionFactor, targetAcceptanceProbability);

    return factory(adaptationMode, weight, derivative, parameter, transform, mask, runtimeOptions, preconditioningType, runMode);
}
Also used: HessianWrtParameterProvider (dr.inference.hmc.HessianWrtParameterProvider), MassPreconditioner (dr.inference.operators.hmc.MassPreconditioner), AdaptationMode (dr.inference.operators.AdaptationMode), GradientWrtParameterProvider (dr.inference.hmc.GradientWrtParameterProvider), Parameter (dr.inference.model.Parameter), Transform (dr.util.Transform), Util.parseTransform (dr.util.Transform.Util.parseTransform), HamiltonianMonteCarloOperator (dr.inference.operators.hmc.HamiltonianMonteCarloOperator)

Aggregations

Transform (dr.util.Transform): 13, Parameter (dr.inference.model.Parameter): 10, ArrayList (java.util.ArrayList): 5, TreeModel (dr.evomodel.tree.TreeModel): 2, GradientWrtParameterProvider (dr.inference.hmc.GradientWrtParameterProvider): 2, TransformedMultivariateParameter (dr.inference.model.TransformedMultivariateParameter): 2, TransformedParameter (dr.inference.model.TransformedParameter): 2, AdaptationMode (dr.inference.operators.AdaptationMode): 2, Util.parseTransform (dr.util.Transform.Util.parseTransform): 2, DataType (dr.evolution.datatype.DataType): 1, TwoStates (dr.evolution.datatype.TwoStates): 1, BranchRates (dr.evolution.tree.BranchRates): 1, TransformedTreeTraitProvider (dr.evolution.tree.TransformedTreeTraitProvider): 1, TreeTraitProvider (dr.evolution.tree.TreeTraitProvider): 1, BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel): 1, BranchSpecificFixedEffects (dr.evomodel.branchratemodel.BranchSpecificFixedEffects): 1, ContinuousBranchValueProvider (dr.evomodel.branchratemodel.ContinuousBranchValueProvider): 1, CountableBranchCategoryProvider (dr.evomodel.branchratemodel.CountableBranchCategoryProvider): 1, OldGMRFSkyrideLikelihood (dr.evomodel.coalescent.OldGMRFSkyrideLikelihood): 1, OrderedLatentLiabilityLikelihood (dr.evomodel.continuous.OrderedLatentLiabilityLikelihood): 1