Use of `dr.util.Transform` in project beast-mcmc by beast-dev.
Class `OrderedLatentLiabilityTransformParser`, method `parseXMLObject`:
/**
 * Builds a per-entry {@link Transform.Array} over the tip-trait (latent liability) parameter of an
 * {@link OrderedLatentLiabilityLikelihood} whose data are binary ({@code TwoStates}).
 *
 * <p>For every (tip, trait) entry, in tip-major order:
 * <ul>
 *   <li>observed state 0 maps to {@code Transform.LOG_NEGATE}, and the current latent value must not be {@code >= 0};</li>
 *   <li>observed state 1 maps to {@code Transform.LOG}, and the current latent value must not be {@code <= 0};</li>
 *   <li>any other state value maps to {@code Transform.NONE} with no sign check.</li>
 * </ul>
 * When a masking element is present, only entries whose mask value is exactly {@code 1.0} contribute a
 * transform to the array; masked-out entries are still sign-checked.
 *
 * @param xo the XML element being parsed
 * @return a {@code Transform.Array} built from the collected transforms and the tip-trait parameter
 * @throws XMLParseException if the data type is not binary, or a latent value disagrees in sign with its observed state
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final OrderedLatentLiabilityLikelihood liabilityLikelihood =
            (OrderedLatentLiabilityLikelihood) xo.getChild(OrderedLatentLiabilityLikelihood.class);
    final CompoundParameter tipTraits = liabilityLikelihood.getTipTraitParameter();

    if (!(liabilityLikelihood.getPatternList().getDataType() instanceof TwoStates)) {
        throw new XMLParseException("Liability transformation is currently only implemented for binary traits");
    }

    final Parameter mask = xo.hasChildNamed(MaskedParameterParser.MASKING)
            ? (Parameter) xo.getElementFirstChild(MaskedParameterParser.MASKING)
            : null;

    final List<Transform> transforms = new ArrayList<Transform>();
    // Flattened position of the current entry inside the compound tip-trait parameter.
    int flatIndex = 0;

    for (int tip = 0; tip < tipTraits.getParameterCount(); ++tip) {
        for (final int discreteState : liabilityLikelihood.getData(tip)) {
            final Transform transform;
            final boolean consistent;
            switch (discreteState) {
                case 0:
                    transform = Transform.LOG_NEGATE;
                    // Written as a negation so that a NaN latent value is not rejected.
                    consistent = !(tipTraits.getParameterValue(flatIndex) >= 0.0);
                    break;
                case 1:
                    transform = Transform.LOG;
                    consistent = !(tipTraits.getParameterValue(flatIndex) <= 0.0);
                    break;
                default:
                    transform = Transform.NONE;
                    consistent = true;
                    break;
            }

            if (!consistent) {
                throw new XMLParseException("Incompatible binary trait and latent value in tip '" + tipTraits.getParameter(tip).getId() + "'");
            }

            // Only unmasked entries contribute a transform.
            final boolean included = mask == null || mask.getParameterValue(flatIndex) == 1.0;
            if (included) {
                transforms.add(transform);
            }
            flatIndex++;
        }
    }

    return new Transform.Array(transforms, tipTraits);
}
Use of `dr.util.Transform` in project beast-mcmc by beast-dev.
Class `MaximizeWrtParameterParser`, method `parseXMLObject`:
/**
 * Builds a {@link MaximizerWrtParameter}, immediately runs {@code maximize()} on it, and returns it.
 *
 * <p>The parameter and likelihood come from the {@link GradientWrtParameterProvider} child when one is present.
 * Otherwise they are read from the {@code DENSITY} child element. An optional {@link Transform} child is passed
 * through as-is (it may be {@code null}). The iteration count is the absolute value of the {@code N_ITERATIONS}
 * attribute (default 0).
 *
 * @param xo the XML element being parsed
 * @return the maximizer, after maximization has been run
 * @throws XMLParseException if there is no gradient and the density element is missing or incomplete
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    GradientWrtParameterProvider gradient = (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);

    int nIterations = Math.abs(xo.getAttribute(N_ITERATIONS, 0));
    boolean initialGuess = xo.getAttribute(INITIAL_GUESS, true);
    boolean printScreen = xo.getAttribute(PRINT_SCREEN, false);

    Parameter parameter;
    Likelihood likelihood;
    if (gradient != null) {
        parameter = gradient.getParameter();
        likelihood = gradient.getLikelihood();
    } else {
        XMLObject cxo = xo.getChild(DENSITY);
        // Without a gradient, the density element is the only source of parameter and likelihood;
        // report its absence instead of failing later with a NullPointerException.
        if (cxo == null) {
            throw new XMLParseException("Must provide either a gradient or a <" + DENSITY + "> element");
        }
        parameter = (Parameter) cxo.getChild(Parameter.class);
        likelihood = (Likelihood) cxo.getChild(Likelihood.class);
        if (parameter == null || likelihood == null) {
            throw new XMLParseException("The <" + DENSITY + "> element must contain both a parameter and a likelihood");
        }
    }

    Transform transform = (Transform) xo.getChild(Transform.class);

    MaximizerWrtParameter maximizer = new MaximizerWrtParameter(likelihood, parameter, gradient, transform,
            new MaximizerWrtParameter.Settings(nIterations, initialGuess, printScreen));
    maximizer.maximize();
    return maximizer;
}
Use of `dr.util.Transform` in project beast-mcmc by beast-dev.
Class `NodeHeightTransformParser`, method `parseXMLObject`:
/**
 * Builds a node-height transform in one of two modes.
 *
 * <ul>
 *   <li><b>Ratio mode</b> (a {@code RATIO} element containing a parameter):
 *     <ul>
 *       <li>A 1-dimensional ratio parameter is resized to the node-height dimension.</li>
 *       <li>The ratio parameter is given bounds [0, 1].</li>
 *       <li>A {@link NodeHeightTransform} over heights and ratios is built.</li>
 *       <li>If the {@code REAL_LINE} attribute on the ratio element is true, that transform is composed with a
 *           {@code Transform.Array}: one {@code LogitTransform} per ratio entry, preceded by a single
 *           {@code LogTransform} when the two dimensions differ.</li>
 *     </ul>
 *   </li>
 *   <li><b>Coalescent-interval mode</b> (otherwise):
 *     <ul>
 *       <li>A {@link NodeHeightTransform} is built from the heights, the tree and the optional skyride likelihood
 *           found in the {@code COALESCENT_INTERVAL} element.</li>
 *       <li>The transform's parameter is registered as that element's native object, under the element's id.</li>
 *     </ul>
 *   </li>
 * </ul>
 *
 * @param xo the XML element being parsed
 * @return the (possibly composed) node-height transform
 * @throws XMLParseException if the node-height element is missing, or if neither a ratio parameter nor a
 *         coalescent-interval element is provided
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    XMLObject nodeHeightXo = xo.getChild(NODEHEIGHT);
    if (nodeHeightXo == null) {
        throw new XMLParseException("Missing <" + NODEHEIGHT + "> element");
    }
    Parameter nodeHeightParameter = (Parameter) nodeHeightXo.getChild(Parameter.class);

    XMLObject ratioXo = xo.hasChildNamed(RATIO) ? xo.getChild(RATIO) : null;
    Parameter ratioParameter = ratioXo == null ? null : (Parameter) ratioXo.getChild(Parameter.class);

    if (ratioParameter != null) {
        // A scalar ratio is expanded to one ratio per node height.
        if (ratioParameter.getDimension() == 1) {
            ratioParameter.setDimension(nodeHeightParameter.getDimension());
        }
        ratioParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, ratioParameter.getDimension()));
    }

    XMLObject coalescentXo = xo.hasChildNamed(COALESCENT_INTERVAL) ? xo.getChild(COALESCENT_INTERVAL) : null;
    OldGMRFSkyrideLikelihood skyrideLikelihood = coalescentXo == null
            ? null
            : (OldGMRFSkyrideLikelihood) coalescentXo.getChild(OldGMRFSkyrideLikelihood.class);

    TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
    BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);

    if (ratioParameter != null) {
        NodeHeightTransform transform = new NodeHeightTransform(nodeHeightParameter, ratioParameter, tree, branchRateModel);
        if (!ratioXo.getAttribute(REAL_LINE, false)) {
            return transform;
        }
        List<Transform> transforms = new ArrayList<Transform>();
        if (nodeHeightParameter.getDimension() != ratioParameter.getDimension()) {
            transforms.add(new Transform.LogTransform());
        }
        for (int i = 0; i < ratioParameter.getDimension(); i++) {
            transforms.add(new Transform.LogitTransform());
        }
        return new Transform.ComposeMultivariable(new Transform.Array(transforms, nodeHeightParameter), transform);
    }

    // Coalescent-interval mode needs the element so its id/native object can be set below;
    // previously a missing element surfaced as a NullPointerException.
    if (coalescentXo == null) {
        throw new XMLParseException("Must provide either a <" + RATIO + "> parameter or a <" + COALESCENT_INTERVAL + "> element");
    }
    NodeHeightTransform transform = new NodeHeightTransform(nodeHeightParameter, tree, skyrideLikelihood);
    Parameter coalescentIntervals = transform.getParameter();
    coalescentIntervals.setId(coalescentXo.getId());
    coalescentXo.setNativeObject(coalescentIntervals);
    return transform;
}
Use of `dr.util.Transform` in project beast-mcmc by beast-dev.
Class `SignTransformParser`, method `parseXMLObject`:
/**
 * Builds a sign-aware log transform.
 *
 * <p>Without a parameter child, returns a scalar {@link Transform.LogTransform}. With a parameter, returns a
 * {@link Transform.Array} holding one transform per parameter dimension:
 * <ul>
 *   <li>If a start and/or end attribute is given (1-based, both inclusive), entries inside that range get:
 *     <ul>
 *       <li>{@code Transform.LOG_NEGATE} when their current value is negative;</li>
 *       <li>{@code Transform.LOG} otherwise.</li>
 *     </ul>
 *     Entries outside the range get {@code Transform.NONE}. A missing start defaults to 1, and a missing end
 *     defaults to the parameter dimension.</li>
 *   <li>Otherwise the parameter bounds decide:
 *     <ul>
 *       <li>lower limit 0 gives {@code LOG};</li>
 *       <li>upper limit 0 gives {@code LOG_NEGATE};</li>
 *       <li>anything else gives {@code NONE}.</li>
 *     </ul>
 *   </li>
 * </ul>
 *
 * @param xo the XML element being parsed
 * @return the transform
 * @throws XMLParseException if start/end is given without a parameter, or the range is empty or out of bounds
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final boolean hasStart = xo.hasAttribute(START);
    final boolean hasEnd = xo.hasAttribute(END);

    Parameter parameter = (Parameter) xo.getChild(Parameter.class);
    if (parameter == null) {
        // TODO: generalize to multivariate or move out
        if (hasStart || hasEnd) {
            throw new XMLParseException("Cannot provide dimension start/end without a parameter");
        }
        return new Transform.LogTransform();
    }

    final int dim = parameter.getDimension();
    List<Transform> transforms = new ArrayList<Transform>(dim);

    if (hasStart || hasEnd) {
        // Previously a lone start or end attribute was silently ignored; honour it and default the other bound.
        // Convert the 1-based inclusive [START, END] range to the 0-based half-open range [start, end).
        int start = hasStart ? xo.getIntegerAttribute(START) - 1 : 0;
        int end = hasEnd ? xo.getIntegerAttribute(END) : dim;
        if (start < 0 || end > dim || start >= end) {
            throw new XMLParseException("Invalid start/end values for parameter '" + parameter.getId()
                    + "' (start=" + (start + 1) + ", end=" + end + ", dimension=" + dim + ")");
        }
        for (int i = 0; i < dim; ++i) {
            if (i >= start && i < end) {
                transforms.add(parameter.getParameterValue(i) < 0 ? Transform.LOG_NEGATE : Transform.LOG);
            } else {
                transforms.add(Transform.NONE);
            }
        }
    } else {
        Bounds<Double> bounds = parameter.getBounds();
        for (int i = 0; i < dim; i++) {
            // TODO much better checking is necessary (here we assumed bounds <0 or >0 )
            if (bounds.getLowerLimit(i) == 0.0) {
                transforms.add(Transform.LOG);
            } else if (bounds.getUpperLimit(i) == 0.0) {
                transforms.add(Transform.LOG_NEGATE);
            } else {
                transforms.add(Transform.NONE);
            }
        }
    }
    return new Transform.Array(transforms, parameter);
}
Use of `dr.util.Transform` in project beast-mcmc by beast-dev.
Class `TransformedRandomWalkOperatorParser`, method `parseXMLObject`:
/**
 * Builds a {@link TransformedRandomWalkOperator} on a parameter.
 *
 * <p>Every dimension starts with {@code Transform.NONE}. Each {@code Transform.ParsedTransform} child then
 * assigns its transform to the indices {@code [start, end)`; the chosen transform names are printed to stderr.
 * Optional {@code LOWER}/{@code UPPER} attributes and the {@code BOUNDARY_CONDITION} attribute (default
 * {@code reflecting}) are passed to the operator. An optional {@code UPDATE_INDEX} element supplies an
 * update-index parameter that must match the parameter's dimension.
 *
 * @param xo the XML element being parsed
 * @return the operator
 * @throws XMLParseException if a parsed transform range falls outside the parameter, the boundary condition is
 *         unknown, or the update-index dimension does not match
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    AdaptationMode mode = AdaptationMode.parseMode(xo);
    double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
    double windowSize = xo.getDoubleAttribute(WINDOW_SIZE);
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);
    int dim = parameter.getDimension();

    Transform[] transformations = new Transform[dim];
    for (int i = 0; i < dim; i++) {
        transformations[i] = Transform.NONE;
    }

    for (int i = 0; i < xo.getChildCount(); i++) {
        Object child = xo.getChild(i);
        if (child instanceof Transform.ParsedTransform) {
            Transform.ParsedTransform parsed = (Transform.ParsedTransform) child;
            // Report a non-empty range outside the parameter as a parse error instead of
            // an ArrayIndexOutOfBoundsException.
            if (parsed.start < parsed.end && (parsed.start < 0 || parsed.end > dim)) {
                throw new XMLParseException("Transform range [" + parsed.start + ", " + parsed.end
                        + ") is outside parameter dimension " + dim);
            }
            System.err.println("Transformations:");
            for (int j = parsed.start; j < parsed.end; ++j) {
                transformations[j] = parsed.transform;
                System.err.print(transformations[j].getTransformName() + " ");
            }
            System.err.println();
        }
    }

    Double lower = null;
    Double upper = null;
    if (xo.hasAttribute(LOWER)) {
        lower = xo.getDoubleAttribute(LOWER);
    }
    if (xo.hasAttribute(UPPER)) {
        upper = xo.getDoubleAttribute(UPPER);
    }

    String conditionName = xo.getAttribute(BOUNDARY_CONDITION,
            TransformedRandomWalkOperator.BoundaryCondition.reflecting.name());
    TransformedRandomWalkOperator.BoundaryCondition condition;
    try {
        condition = TransformedRandomWalkOperator.BoundaryCondition.valueOf(conditionName);
    } catch (IllegalArgumentException e) {
        XMLParseException parseError = new XMLParseException("Unknown " + BOUNDARY_CONDITION + " '" + conditionName + "'");
        parseError.initCause(e);
        throw parseError;
    }

    Parameter updateIndex = null;
    if (xo.hasChildNamed(UPDATE_INDEX)) {
        XMLObject cxo = xo.getChild(UPDATE_INDEX);
        updateIndex = (Parameter) cxo.getChild(Parameter.class);
        // Reported as a parse error, consistent with the declared throws clause.
        if (updateIndex.getDimension() != dim) {
            throw new XMLParseException("Parameter to update and missing indices must have the same dimension");
        }
    }

    return new TransformedRandomWalkOperator(parameter, transformations, updateIndex, windowSize, condition, weight, mode, lower, upper);
}
Aggregations