use of dr.util.Transform in project beast-mcmc by beast-dev.
the class BranchSpecificFixedEffectsParser method parseXMLObject.
/**
 * Builds a branch-specific fixed-effects model from the given XML element.
 *
 * <p>Reads the {@link TreeModel} and coefficient {@link Parameter} children, one clade
 * category model per {@code CATEGORY} child element, an optional mid-point value provider
 * (enabled through the {@code TIME_DEPENDENT_EFFECT} attribute, default {@code false}) and
 * every direct child that is a {@link BranchRates}. The initial design matrix is logged at
 * INFO level to the {@code dr.evomodel} logger.
 *
 * @param xo the XML element to parse
 * @return the {@code BranchSpecificFixedEffects.Default} instance, or a
 *         {@code BranchSpecificFixedEffects.Transformed} wrapper around it when the element
 *         has a {@link Transform} child
 * @throws XMLParseException if reading an attribute or parsing a clade category fails
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
    final Parameter coefficients = (Parameter) xo.getChild(Parameter.class);
    final boolean includeIntercept = xo.getAttribute(INCLUDE_INTERCEPT, true);

    // One clade-based category model per CATEGORY element, each backed by a fresh
    // parameter of dimension (node count - 1).
    final List<CountableBranchCategoryProvider> categories = new ArrayList<CountableBranchCategoryProvider>();
    for (XMLObject categoryXo : xo.getAllChildren(CATEGORY)) {
        final Parameter.Default allocation = new Parameter.Default(tree.getNodeCount() - 1);
        final CountableBranchCategoryProvider.CladeBranchCategoryModel cladeModel =
                new CountableBranchCategoryProvider.CladeBranchCategoryModel(tree, allocation);
        parseCladeCategories(categoryXo, cladeModel);
        categories.add(cladeModel);
    }

    // Continuous covariates: only the branch mid-point, and only on request.
    final List<ContinuousBranchValueProvider> values = new ArrayList<ContinuousBranchValueProvider>();
    if (xo.getAttribute(TIME_DEPENDENT_EFFECT, false)) {
        values.add(new ContinuousBranchValueProvider.MidPoint());
    }

    // Collect every direct child that is a BranchRates, in document order.
    final List<BranchRates> branchRates = new ArrayList<BranchRates>();
    final int childCount = xo.getChildCount();
    for (int index = 0; index < childCount; index++) {
        final Object child = xo.getChild(index);
        if (child instanceof BranchRates) {
            branchRates.add((BranchRates) child);
        }
    }

    final Transform transform = (Transform) xo.getChild(Transform.class);

    final BranchSpecificFixedEffects.Default fixedEffects = new BranchSpecificFixedEffects.Default(
            xo.getId(), categories, values, branchRates, coefficients, includeIntercept);

    final double[][] designMatrix = fixedEffects.getDesignMatrix(tree);
    Logger.getLogger("dr.evomodel").info(
            "Using a fixed effects model with initial design matrix:\n" + annotateDesignMatrix(designMatrix, tree));

    if (transform == null) {
        return fixedEffects;
    }
    return new BranchSpecificFixedEffects.Transformed(fixedEffects, transform);
}
use of dr.util.Transform in project beast-mcmc by beast-dev.
the class TransformedTreeTraitParser method parseXMLObject.
/**
 * Creates a {@code TransformedTreeTraitProvider} from the {@link TreeTraitProvider} child
 * and the {@link Transform} child of the given XML element.
 *
 * @param xo the XML element to parse
 * @return a new {@code TransformedTreeTraitProvider} built from the two children
 * @throws XMLParseException declared by the parser contract
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    // Arguments are evaluated left to right, so the trait provider is looked up first.
    return new TransformedTreeTraitProvider(
            (TreeTraitProvider) xo.getChild(TreeTraitProvider.class),
            (Transform) xo.getChild(Transform.class));
}
use of dr.util.Transform in project beast-mcmc by beast-dev.
the class RandomWalkGammaPrecisionGibbsOperator method main.
/**
 * Minimal manual check: wraps a five-element parameter in a {@code TransformedParameter}
 * using a log transform with the inverse flag set to {@code true}. The result is discarded.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    final double[] logValues = {1.0, 2.0, 3.0, 4.0, 5.0};
    final Parameter logPop = new Parameter.Default(logValues);
    final Transform logTransform = new Transform.LogTransform();
    // Constructed only for its side effects; nothing reads the instance afterwards.
    new TransformedParameter(logPop, logTransform, true);
}
use of dr.util.Transform in project beast-mcmc by beast-dev.
the class RandomWalkOperatorParser method parseXMLObject.
/**
 * Parses a random-walk operator from the given XML element.
 *
 * <p>The operator acts on the element's {@link Parameter} child. Bounds must be set on the
 * parameter itself; supplying {@code LOWER} or {@code UPPER} attributes on the operator is
 * rejected. The boundary condition defaults to {@code reflecting}. The {@code logit}
 * condition requires finite lower and upper bounds in every dimension, and the {@code log}
 * condition requires finite lower bounds in every dimension.
 *
 * @param xo the XML element to parse
 * @return the {@link RandomWalkOperator}, or a {@code TransformedParameterRandomWalkOperator}
 *         wrapping it when {@code parseTransform(xo)} yields a transform
 * @throws XMLParseException if bounds are given on the operator, a required attribute cannot
 *         be read, or the boundary condition is incompatible with the parameter's bounds
 * @throws RuntimeException if an {@code UPDATE_INDEX} parameter is supplied whose dimension
 *         differs from that of the operated parameter
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    AdaptationMode mode = AdaptationMode.parseMode(xo);
    double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    double windowSize = xo.getDoubleAttribute(WINDOW_SIZE);
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);

    if (xo.hasAttribute(LOWER) || xo.hasAttribute(UPPER)) {
        throw new XMLParseException("Do not provide lower/upper bounds for a RandomWalkOperator; set these values as parameter bounds");
    }

    RandomWalkOperator.BoundaryCondition condition = RandomWalkOperator.BoundaryCondition.valueOf(
            xo.getAttribute(BOUNDARY_CONDITION, RandomWalkOperator.BoundaryCondition.reflecting.name()));

    // A bound counts as "set" only if it is non-null and finite in every dimension.
    final Bounds<Double> bounds = parameter.getBounds();
    final int dim = parameter.getDimension();
    boolean lowerBoundsSet = true;
    boolean upperBoundsSet = true;
    for (int i = 0; i < dim; ++i) {
        final Double lower = bounds.getLowerLimit(i);
        if (lower == null || Double.isInfinite(lower)) {
            lowerBoundsSet = false;
        }
        final Double upper = bounds.getUpperLimit(i);
        if (upper == null || Double.isInfinite(upper)) {
            upperBoundsSet = false;
        }
    }

    if (condition == RandomWalkOperator.BoundaryCondition.logit) {
        if (!lowerBoundsSet || !upperBoundsSet) {
            throw new XMLParseException("The logit transformed RandomWalkOperator cannot be used on a parameter without bounds.");
        }
    }
    if (condition == RandomWalkOperator.BoundaryCondition.log) {
        if (!lowerBoundsSet) {
            throw new XMLParseException("The log transformed RandomWalkOperator cannot be used on a parameter without lower bounds.");
        }
    }

    RandomWalkOperator randomWalk;
    if (xo.hasChildNamed(UPDATE_INDEX)) {
        XMLObject cxo = xo.getChild(UPDATE_INDEX);
        Parameter updateIndex = (Parameter) cxo.getChild(Parameter.class);
        if (updateIndex.getDimension() != parameter.getDimension()) {
            throw new RuntimeException("Parameter to update and missing indices must have the same dimension");
        }
        randomWalk = new RandomWalkOperator(parameter, updateIndex, windowSize, condition, weight, mode);
    } else {
        randomWalk = new RandomWalkOperator(parameter, null, windowSize, condition, weight, mode);
    }

    final Transform transform = parseTransform(xo);
    if (transform == null) {
        return randomWalk;
    }

    // With a transform, wrap the parameter and the plain operator together.
    final boolean inverse = xo.getAttribute(INVERSE, false);
    final TransformedParameter transformedParameter;
    if (transform.isMultivariate()) {
        transformedParameter = new TransformedMultivariateParameter(parameter, (Transform.MultivariableTransform) transform, inverse);
    } else {
        transformedParameter = new TransformedParameter(parameter, transform, inverse);
    }
    return new TransformedParameterRandomWalkOperator(transformedParameter, randomWalk);
}
use of dr.util.Transform in project beast-mcmc by beast-dev.
the class HamiltonianMonteCarloOperatorParser method parseXMLObject.
/**
 * Parses a Hamiltonian Monte Carlo operator from the given XML element and delegates its
 * construction to {@code factory}.
 *
 * <p>The operated parameter is the element's {@link Parameter} child or, when absent, the
 * parameter reported by the gradient provider. Any preconditioning type other than
 * {@code NONE} requires the gradient provider to also be a
 * {@code HessianWrtParameterProvider}. When the parsed transform is a
 * {@code Transform.MultivariableTransform}, its dimension must equal the parameter's;
 * otherwise the gradient's dimension must equal the parameter's. An optional {@code MASK}
 * parameter must have the gradient's dimension.
 *
 * @param xo the XML element to parse
 * @return the object produced by {@code factory}
 * @throws XMLParseException if a required attribute cannot be read, the random step fraction
 *         exceeds 1 in absolute value, preconditioning is requested without a Hessian
 *         provider, or any of the dimension checks fails
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    int nSteps = xo.getAttribute(N_STEPS, 10);
    double stepSize = xo.getDoubleAttribute(STEP_SIZE);
    int runMode = parseRunMode(xo);
    MassPreconditioner.Type preconditioningType = parsePreconditioning(xo);

    // The sign of the fraction is discarded before validation.
    // NOTE(review): the check admits exactly 1.0 while the message says "< 1.0" - confirm
    // which of the two is intended.
    double randomStepFraction = Math.abs(xo.getAttribute(RANDOM_STEP_FRACTION, 0.0));
    if (randomStepFraction > 1) {
        throw new XMLParseException("Random step count fraction must be < 1.0");
    }

    int preconditioningUpdateFrequency = xo.getAttribute(PRECONDITIONING_UPDATE_FREQUENCY, 0);
    int preconditioningDelay = xo.getAttribute(PRECONDITIONING_DELAY, 0);
    int preconditioningMemory = xo.getAttribute(PRECONDITIONING_MEMORY, 0);
    AdaptationMode adaptationMode = AdaptationMode.parseMode(xo);

    GradientWrtParameterProvider derivative = (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
    if (preconditioningType != MassPreconditioner.Type.NONE && !(derivative instanceof HessianWrtParameterProvider)) {
        throw new XMLParseException("Unable precondition without a Hessian provider");
    }

    // Fall back to the gradient's own parameter when none is given explicitly.
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);
    if (parameter == null) {
        parameter = derivative.getParameter();
    }

    Transform transform = parseTransform(xo);

    // instanceof is false for null, so no separate null check is needed.
    if (transform instanceof Transform.MultivariableTransform) {
        // With a multivariable transform the check is transform vs. parameter, so the
        // message reports the transform's dimension rather than the gradient's.
        final int transformDimension = ((Transform.MultivariableTransform) transform).getDimension();
        if (transformDimension != parameter.getDimension()) {
            throw new XMLParseException("Transform (" + transformDimension + ") must be the same dimensions as the parameter (" + parameter.getDimension() + ")");
        }
    } else if (derivative.getDimension() != parameter.getDimension()) {
        throw new XMLParseException("Gradient (" + derivative.getDimension() + ") must be the same dimensions as the parameter (" + parameter.getDimension() + ")");
    }

    Parameter mask = null;
    if (xo.hasChildNamed(MASK)) {
        mask = (Parameter) xo.getElementFirstChild(MASK);
        if (mask.getDimension() != derivative.getDimension()) {
            throw new XMLParseException("Mask (" + mask.getDimension() + ") must be the same dimension as the gradient (" + derivative.getDimension() + ")");
        }
    }

    int gradientCheckCount = xo.getAttribute(GRADIENT_CHECK_COUNT, 0);
    double gradientCheckTolerance = xo.getAttribute(GRADIENT_CHECK_TOLERANCE, 1E-3);
    int maxIterations = xo.getAttribute(MAX_ITERATIONS, 10);
    double reductionFactor = xo.getAttribute(REDUCTION_FACTOR, 0.1);
    double targetAcceptanceProbability = xo.getAttribute(TARGET_ACCEPTANCE_PROBABILITY, 0.8); // Stan default

    HamiltonianMonteCarloOperator.Options runtimeOptions = new HamiltonianMonteCarloOperator.Options(
            stepSize, nSteps, randomStepFraction,
            preconditioningUpdateFrequency, preconditioningDelay, preconditioningMemory,
            gradientCheckCount, gradientCheckTolerance,
            maxIterations, reductionFactor, targetAcceptanceProbability);

    return factory(adaptationMode, weight, derivative, parameter, transform, mask, runtimeOptions, preconditioningType, runMode);
}
Aggregations