Search in sources :

Example 6 with Transform

use of dr.util.Transform in project beast-mcmc by beast-dev.

the class OrderedLatentLiabilityTransformParser method parseXMLObject.

/**
 * Builds a {@code Transform.Array} over the tip-trait parameter of the
 * {@code OrderedLatentLiabilityLikelihood} child of {@code xo}.
 *
 * <p>For every (tip, trait) entry the observed discrete state selects the transform:
 * state 0 maps to {@code Transform.LOG_NEGATE}, state 1 to {@code Transform.LOG}, and
 * any other state to {@code Transform.NONE}. When a {@code MaskedParameterParser.MASKING}
 * child is present, only entries whose mask value equals 1.0 contribute a transform.
 *
 * @param xo the XML element being parsed
 * @return the assembled {@code Transform.Array}
 * @throws XMLParseException if the data type is not {@code TwoStates}, or if a latent
 *                           value is {@code >= 0} for state 0 or {@code <= 0} for state 1
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    OrderedLatentLiabilityLikelihood likelihood =
            (OrderedLatentLiabilityLikelihood) xo.getChild(OrderedLatentLiabilityLikelihood.class);
    CompoundParameter parameter = likelihood.getTipTraitParameter();

    if (!(likelihood.getPatternList().getDataType() instanceof TwoStates)) {
        throw new XMLParseException("Liability transformation is currently only implemented for binary traits");
    }

    // The mask is optional; null means every entry is kept.
    Parameter mask = xo.hasChildNamed(MaskedParameterParser.MASKING)
            ? (Parameter) xo.getElementFirstChild(MaskedParameterParser.MASKING)
            : null;

    List<Transform> selected = new ArrayList<Transform>();
    // Running position into the flattened parameter, advanced once per (tip, trait) entry.
    int flatIndex = 0;
    for (int tip = 0; tip < parameter.getParameterCount(); ++tip) {
        final int[] observedStates = likelihood.getData(tip);
        for (int state : observedStates) {
            Transform choice;
            boolean compatible;
            switch (state) {
                case 0:
                    choice = Transform.LOG_NEGATE;
                    compatible = !(parameter.getParameterValue(flatIndex) >= 0.0);
                    break;
                case 1:
                    choice = Transform.LOG;
                    compatible = !(parameter.getParameterValue(flatIndex) <= 0.0);
                    break;
                default:
                    // Any other state leaves the latent value untransformed and unchecked.
                    choice = Transform.NONE;
                    compatible = true;
                    break;
            }
            if (!compatible) {
                throw new XMLParseException("Incompatible binary trait and latent value in tip '" + parameter.getParameter(tip).getId() + "'");
            }
            boolean keep = (mask == null) || (mask.getParameterValue(flatIndex) == 1.0);
            if (keep) {
                selected.add(choice);
            }
            ++flatIndex;
        }
    }
    return new Transform.Array(selected, parameter);
}
Also used : OrderedLatentLiabilityLikelihood(dr.evomodel.continuous.OrderedLatentLiabilityLikelihood) TwoStates(dr.evolution.datatype.TwoStates) ArrayList(java.util.ArrayList) DataType(dr.evolution.datatype.DataType) Transform(dr.util.Transform)

Example 7 with Transform

use of dr.util.Transform in project beast-mcmc by beast-dev.

the class MaximizeWrtParameterParser method parseXMLObject.

/**
 * Creates a {@code MaximizerWrtParameter} from the XML element, runs
 * {@code maximize()} on it, and returns it.
 *
 * <p>If a {@code GradientWrtParameterProvider} child is present, the parameter and
 * likelihood are taken from it; otherwise both are read from the {@code DENSITY}
 * child element. The optional {@code Transform} child and the {@code N_ITERATIONS}
 * (absolute value, default 0), {@code INITIAL_GUESS} (default true) and
 * {@code PRINT_SCREEN} (default false) attributes are passed through to the maximizer.
 *
 * @param xo the XML element being parsed
 * @return the maximizer, after {@code maximize()} has been called
 * @throws XMLParseException if reading the element fails
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final GradientWrtParameterProvider gradient =
            (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);

    final int nIterations = Math.abs(xo.getAttribute(N_ITERATIONS, 0));
    final boolean initialGuess = xo.getAttribute(INITIAL_GUESS, true);
    final boolean printScreen = xo.getAttribute(PRINT_SCREEN, false);

    final Parameter parameter;
    final Likelihood likelihood;
    if (gradient == null) {
        // No gradient provider: the density element supplies both pieces.
        XMLObject density = xo.getChild(DENSITY);
        parameter = (Parameter) density.getChild(Parameter.class);
        likelihood = (Likelihood) density.getChild(Likelihood.class);
    } else {
        parameter = gradient.getParameter();
        likelihood = gradient.getLikelihood();
    }

    final Transform transform = (Transform) xo.getChild(Transform.class);
    final MaximizerWrtParameter.Settings settings =
            new MaximizerWrtParameter.Settings(nIterations, initialGuess, printScreen);

    final MaximizerWrtParameter maximizer =
            new MaximizerWrtParameter(likelihood, parameter, gradient, transform, settings);
    maximizer.maximize();
    return maximizer;
}
Also used : Likelihood(dr.inference.model.Likelihood) GradientWrtParameterProvider(dr.inference.hmc.GradientWrtParameterProvider) Parameter(dr.inference.model.Parameter) MaximizerWrtParameter(dr.evomodel.treedatalikelihood.discrete.MaximizerWrtParameter) Transform(dr.util.Transform) MaximizerWrtParameter(dr.evomodel.treedatalikelihood.discrete.MaximizerWrtParameter)

Example 8 with Transform

use of dr.util.Transform in project beast-mcmc by beast-dev.

the class NodeHeightTransformParser method parseXMLObject.

/**
 * Builds the node-height transform described by the XML element.
 *
 * <p>With a {@code RATIO} parameter present, a ratio-based {@code NodeHeightTransform}
 * is created; a one-dimensional ratio parameter is first resized to the node-height
 * dimension, and bounds {@code DefaultBounds(1.0, 0.0, dim)} are added to it. If the
 * {@code RATIO} element has {@code REAL_LINE} set, that transform is composed with an
 * element-wise array of logit transforms (preceded by one log transform when the
 * node-height and ratio dimensions differ).
 *
 * <p>Without a ratio parameter, a {@code NodeHeightTransform} is built from the tree and
 * the (possibly null) skyride likelihood; its parameter takes the id of the
 * {@code COALESCENT_INTERVAL} element and is registered as that element's native object.
 *
 * @param xo the XML element being parsed
 * @return the resulting transform
 * @throws XMLParseException if reading the element fails
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    Parameter heights = (Parameter) xo.getChild(NODEHEIGHT).getChild(Parameter.class);

    Parameter ratios = null;
    if (xo.hasChildNamed(RATIO)) {
        ratios = (Parameter) xo.getChild(RATIO).getChild(Parameter.class);
        if (ratios != null) {
            if (ratios.getDimension() == 1) {
                ratios.setDimension(heights.getDimension());
            }
            ratios.addBounds(new Parameter.DefaultBounds(1.0, 0.0, ratios.getDimension()));
        }
    }

    OldGMRFSkyrideLikelihood skyride = null;
    if (xo.hasChildNamed(COALESCENT_INTERVAL)) {
        XMLObject intervalElement = xo.getChild(COALESCENT_INTERVAL);
        skyride = (OldGMRFSkyrideLikelihood) intervalElement.getChild(OldGMRFSkyrideLikelihood.class);
    }

    TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
    BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);

    if (ratios == null) {
        // Coalescent-interval variant: expose the transform's own parameter under the
        // id of the COALESCENT_INTERVAL element.
        NodeHeightTransform intervalTransform = new NodeHeightTransform(heights, tree, skyride);
        Parameter coalescentIntervals = intervalTransform.getParameter();
        XMLObject intervalElement = xo.getChild(COALESCENT_INTERVAL);
        coalescentIntervals.setId(intervalElement.getId());
        intervalElement.setNativeObject(coalescentIntervals);
        return intervalTransform;
    }

    NodeHeightTransform ratioTransform = new NodeHeightTransform(heights, ratios, tree, branchRateModel);
    if (!xo.getChild(RATIO).getAttribute(REAL_LINE, false)) {
        return ratioTransform;
    }

    List<Transform> elementwise = new ArrayList<Transform>();
    if (heights.getDimension() != ratios.getDimension()) {
        elementwise.add(new Transform.LogTransform());
    }
    for (int i = 0; i < ratios.getDimension(); i++) {
        elementwise.add(new Transform.LogitTransform());
    }
    return new Transform.ComposeMultivariable(new Transform.Array(elementwise, heights), ratioTransform);
}
Also used : NodeHeightTransform(dr.evomodel.treedatalikelihood.discrete.NodeHeightTransform) OldGMRFSkyrideLikelihood(dr.evomodel.coalescent.OldGMRFSkyrideLikelihood) ArrayList(java.util.ArrayList) TreeModel(dr.evomodel.tree.TreeModel) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) Parameter(dr.inference.model.Parameter) NodeHeightTransform(dr.evomodel.treedatalikelihood.discrete.NodeHeightTransform) Transform(dr.util.Transform)

Example 9 with Transform

use of dr.util.Transform in project beast-mcmc by beast-dev.

the class SignTransformParser method parseXMLObject.

/**
 * Builds a sign-dependent transform for a parameter.
 *
 * <p>Without a {@code Parameter} child, a bare {@code Transform.LogTransform} is returned
 * (and supplying {@code START} or {@code END} is an error). With a parameter, one
 * transform per dimension is collected into a {@code Transform.Array}:
 * <ul>
 *   <li>if both {@code START} and {@code END} are given, indices in
 *       {@code [START - 1, END)} get {@code LOG_NEGATE} when the current value is
 *       negative and {@code LOG} otherwise; all other indices get {@code NONE};</li>
 *   <li>otherwise the choice is made from the bounds: lower limit 0.0 gives {@code LOG},
 *       upper limit 0.0 gives {@code LOG_NEGATE}, anything else {@code NONE}.</li>
 * </ul>
 *
 * @param xo the XML element being parsed
 * @return a {@code Transform.LogTransform} or a {@code Transform.Array}
 * @throws XMLParseException if start/end are given without a parameter, or fail the
 *                           range check against the parameter dimension
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    boolean hasStartOrEnd = xo.hasAttribute(START) || xo.hasAttribute(END);
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);

    if (parameter == null) {
        // TODO: generalize to multivariate or move out
        if (!hasStartOrEnd) {
            return new Transform.LogTransform();
        }
        throw new XMLParseException("Cannot provide dimension start/end without a parameter");
    }

    Bounds<Double> bounds = parameter.getBounds();
    List<Transform> perDimension = new ArrayList<Transform>();

    boolean explicitRange = xo.hasAttribute(START) && xo.hasAttribute(END);
    if (!explicitRange) {
        for (int d = 0; d < parameter.getDimension(); d++) {
            // TODO much better checking is necessary (here we assumed bounds <0 or >0 )
            Transform chosen;
            if (bounds.getLowerLimit(d) == 0.0) {
                chosen = Transform.LOG;
            } else if (bounds.getUpperLimit(d) == 0.0) {
                chosen = Transform.LOG_NEGATE;
            } else {
                chosen = Transform.NONE;
            }
            perDimension.add(chosen);
        }
        return new Transform.Array(perDimension, parameter);
    }

    // START is shifted down by one; END is used as an exclusive upper index.
    int first = xo.getIntegerAttribute(START) - 1;
    int limit = xo.getIntegerAttribute(END);
    if (first > parameter.getDimension() || limit > parameter.getDimension() || first > limit) {
        throw new XMLParseException("Invalid start/end values for parameter");
    }

    for (int d = 0; d < parameter.getDimension(); ++d) {
        Transform chosen;
        if (d < first || d >= limit) {
            chosen = Transform.NONE;
        } else if (parameter.getParameterValue(d) < 0) {
            chosen = Transform.LOG_NEGATE;
        } else {
            chosen = Transform.LOG;
        }
        perDimension.add(chosen);
    }
    return new Transform.Array(perDimension, parameter);
}
Also used : ArrayList(java.util.ArrayList) Parameter(dr.inference.model.Parameter) Transform(dr.util.Transform)

Example 10 with Transform

use of dr.util.Transform in project beast-mcmc by beast-dev.

the class TransformedRandomWalkOperatorParser method parseXMLObject.

/**
 * Builds a {@code TransformedRandomWalkOperator} from the XML element.
 *
 * <p>Every dimension of the parameter starts with {@code Transform.NONE}. Each
 * {@code Transform.ParsedTransform} child then overwrites the transforms for indices
 * {@code [start, end)} and the chosen transform names are echoed to {@code System.err}.
 * Optional {@code LOWER}/{@code UPPER} attributes and an optional {@code UPDATE_INDEX}
 * parameter are passed on to the operator; the boundary condition defaults to
 * {@code reflecting}.
 *
 * @param xo the XML element being parsed
 * @return the configured operator
 * @throws XMLParseException if a required attribute cannot be read, or if a parsed
 *                           transform's non-empty index range does not fit inside the
 *                           parameter dimension
 * @throws RuntimeException  if the {@code UPDATE_INDEX} parameter's dimension differs
 *                           from the operated parameter's dimension
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    AdaptationMode mode = AdaptationMode.parseMode(xo);
    double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    double windowSize = xo.getDoubleAttribute(WINDOW_SIZE);
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);
    int dim = parameter.getDimension();
    Transform[] transformations = new Transform[dim];
    for (int i = 0; i < dim; i++) {
        transformations[i] = Transform.NONE;
    }
    for (int i = 0; i < xo.getChildCount(); i++) {
        Object child = xo.getChild(i);
        if (child instanceof Transform.ParsedTransform) {
            Transform.ParsedTransform thisObject = (Transform.ParsedTransform) child;
            // Reject ranges that would index outside the transformations array, so the
            // user gets a parse error rather than an ArrayIndexOutOfBoundsException.
            // Empty ranges (start >= end) touch nothing and are left alone.
            if (thisObject.start < thisObject.end && (thisObject.start < 0 || thisObject.end > dim)) {
                throw new XMLParseException("Transform index range [" + thisObject.start + ", " + thisObject.end
                        + ") does not fit within the parameter dimension (" + dim + ")");
            }
            System.err.println("Transformations:");
            for (int j = thisObject.start; j < thisObject.end; ++j) {
                transformations[j] = thisObject.transform;
                System.err.print(transformations[j].getTransformName() + " ");
            }
            System.err.println();
        }
    }
    // Bounds stay null when the corresponding attribute is absent.
    Double lower = null;
    Double upper = null;
    if (xo.hasAttribute(LOWER)) {
        lower = xo.getDoubleAttribute(LOWER);
    }
    if (xo.hasAttribute(UPPER)) {
        upper = xo.getDoubleAttribute(UPPER);
    }
    TransformedRandomWalkOperator.BoundaryCondition condition = TransformedRandomWalkOperator.BoundaryCondition.valueOf(xo.getAttribute(BOUNDARY_CONDITION, TransformedRandomWalkOperator.BoundaryCondition.reflecting.name()));
    if (xo.hasChildNamed(UPDATE_INDEX)) {
        XMLObject cxo = xo.getChild(UPDATE_INDEX);
        Parameter updateIndex = (Parameter) cxo.getChild(Parameter.class);
        if (updateIndex.getDimension() != parameter.getDimension()) {
            throw new RuntimeException("Parameter to update and missing indices must have the same dimension");
        }
        return new TransformedRandomWalkOperator(parameter, transformations, updateIndex, windowSize, condition, weight, mode, lower, upper);
    }
    return new TransformedRandomWalkOperator(parameter, transformations, null, windowSize, condition, weight, mode, lower, upper);
}
Also used : TransformedRandomWalkOperator(dr.inference.operators.TransformedRandomWalkOperator) XMLObject(dr.xml.XMLObject) AdaptationMode(dr.inference.operators.AdaptationMode) Parameter(dr.inference.model.Parameter) XMLObject(dr.xml.XMLObject) Transform(dr.util.Transform)

Aggregations

Transform (dr.util.Transform): 13, Parameter (dr.inference.model.Parameter): 10, ArrayList (java.util.ArrayList): 5, TreeModel (dr.evomodel.tree.TreeModel): 2, GradientWrtParameterProvider (dr.inference.hmc.GradientWrtParameterProvider): 2, TransformedMultivariateParameter (dr.inference.model.TransformedMultivariateParameter): 2, TransformedParameter (dr.inference.model.TransformedParameter): 2, AdaptationMode (dr.inference.operators.AdaptationMode): 2, Util.parseTransform (dr.util.Transform.Util.parseTransform): 2, DataType (dr.evolution.datatype.DataType): 1, TwoStates (dr.evolution.datatype.TwoStates): 1, BranchRates (dr.evolution.tree.BranchRates): 1, TransformedTreeTraitProvider (dr.evolution.tree.TransformedTreeTraitProvider): 1, TreeTraitProvider (dr.evolution.tree.TreeTraitProvider): 1, BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel): 1, BranchSpecificFixedEffects (dr.evomodel.branchratemodel.BranchSpecificFixedEffects): 1, ContinuousBranchValueProvider (dr.evomodel.branchratemodel.ContinuousBranchValueProvider): 1, CountableBranchCategoryProvider (dr.evomodel.branchratemodel.CountableBranchCategoryProvider): 1, OldGMRFSkyrideLikelihood (dr.evomodel.coalescent.OldGMRFSkyrideLikelihood): 1, OrderedLatentLiabilityLikelihood (dr.evomodel.continuous.OrderedLatentLiabilityLikelihood): 1