use of dr.inference.model.Parameter in project beast-mcmc by beast-dev.
the class BeastCheckpointer method readStateFromFile.
private long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();

    long state = -1;

    ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();

    try {
        FileReader fileIn = new FileReader(file);
        BufferedReader in = new BufferedReader(fileIn);

        int[] rngState = null;

        String line = in.readLine();
        String[] fields = line.split("\t");
        if (fields[0].equals("rng")) {
            // if there is a random number generator state present then load it...
            try {
                rngState = new int[fields.length - 1];
                for (int i = 0; i < rngState.length; i++) {
                    rngState[i] = Integer.parseInt(fields[i + 1]);
                }
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read state number from state file");
            }

            line = in.readLine();
            fields = line.split("\t");
        }

        try {
            if (!fields[0].equals("state")) {
                throw new RuntimeException("Unable to read state number from state file");
            }
            state = Long.parseLong(fields[1]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read state number from state file");
        }

        line = in.readLine();
        fields = line.split("\t");
        try {
            if (!fields[0].equals("lnL")) {
                throw new RuntimeException("Unable to read lnL from state file");
            }
            if (lnL != null) {
                lnL[0] = Double.parseDouble(fields[1]);
            }
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read lnL from state file");
        }

        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            line = in.readLine();
            fields = line.split("\t");
            //if (!fields[0].equals(parameter.getParameterName())) {
            //    System.err.println("Unable to match state parameter: " + fields[0] + ", expecting " + parameter.getParameterName());
            //}
            int dimension = Integer.parseInt(fields[2]);

            if (dimension != parameter.getDimension()) {
                System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                System.err.print("Read from file: ");
                for (int i = 0; i < fields.length; i++) {
                    System.err.print(fields[i] + "\t");
                }
                System.err.println();
            }

            if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                // System.out.println("eek");
                double value = Double.parseDouble(fields[3]);
                parameter.setParameterValue(0, value);
                if (DEBUG) {
                    System.out.println("restoring " + fields[1] + " with value " + value);
                }
            } else {
                if (DEBUG) {
                    System.out.print("restoring " + fields[1] + " with values ");
                }
                for (int dim = 0; dim < parameter.getDimension(); dim++) {
                    parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
                    if (DEBUG) {
                        System.out.print(Double.parseDouble(fields[dim + 3]) + " ");
                    }
                }
                if (DEBUG) {
                    System.out.println();
                }
            }
        }

        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            line = in.readLine();
            fields = line.split("\t");
            if (!fields[1].equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match operator: " + fields[1]);
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + fields[1]);
            }
            operator.setAcceptCount(Integer.parseInt(fields[2]));
            operator.setRejectCount(Integer.parseInt(fields[3]));
            if (operator instanceof CoercableMCMCOperator) {
                if (fields.length != 5) {
                    throw new RuntimeException("Coercable operator missing parameter: " + fields[1]);
                }
                ((CoercableMCMCOperator) operator).setCoercableParameter(Double.parseDouble(fields[4]));
            }
        }

        // load the tree models last as we get the node heights from the tree (not the parameters,
        // which may not be associated with the right node)
        Set<String> expectedTreeModelNames = new HashSet<String>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                if (DEBUG) {
                    System.out.println("model " + model.getModelName());
                }
                expectedTreeModelNames.add(model.getModelName());
                if (DEBUG) {
                    for (String s : expectedTreeModelNames) {
                        System.out.println(s);
                    }
                }
            }
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }

        line = in.readLine();
        fields = line.split("\t");

        // Read in all (possibly more than one) trees
        while (fields[0].equals("tree")) {
            if (DEBUG) {
                System.out.println("tree: " + fields[1]);
            }
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (model instanceof TreeModel && fields[1].equals(model.getModelName())) {
                    line = in.readLine();
                    line = in.readLine();
                    fields = line.split("\t");
                    // read number of nodes
                    int nodeCount = Integer.parseInt(fields[0]);
                    double[] nodeHeights = new double[nodeCount];
                    for (int i = 0; i < nodeCount; i++) {
                        line = in.readLine();
                        fields = line.split("\t");
                        nodeHeights[i] = Double.parseDouble(fields[1]);
                    }

                    // on to reading edge information
                    line = in.readLine();
                    line = in.readLine();
                    line = in.readLine();
                    fields = line.split("\t");
                    int edgeCount = Integer.parseInt(fields[0]);

                    // create data matrix of doubles to store information from list of TreeParameterModels
                    double[][] traitValues = new double[traitModels.size()][edgeCount];

                    // create array to store whether a node is left or right child of its parent;
                    // can be important for certain tree transition kernels
                    int[] childOrder = new int[edgeCount];
                    for (int i = 0; i < childOrder.length; i++) {
                        childOrder[i] = -1;
                    }

                    int[] parents = new int[edgeCount];
                    for (int i = 0; i < edgeCount; i++) {
                        parents[i] = -1;
                    }
                    for (int i = 0; i < edgeCount; i++) {
                        line = in.readLine();
                        if (line != null) {
                            fields = line.split("\t");
                            parents[Integer.parseInt(fields[0])] = Integer.parseInt(fields[1]);
                            childOrder[i] = Integer.parseInt(fields[2]);
                            for (int j = 0; j < traitModels.size(); j++) {
                                traitValues[j][i] = Double.parseDouble(fields[3 + j]);
                            }
                        }
                    }

                    // perform magic with the acquired information
                    if (DEBUG) {
                        System.out.println("adopting tree structure");
                    }

                    // adopt the loaded tree structure; this does not yet copy the traits on the branches
                    ((TreeModel) model).beginTreeEdit();
                    ((TreeModel) model).adoptTreeStructure(parents, nodeHeights, childOrder);
                    ((TreeModel) model).endTreeEdit();

                    expectedTreeModelNames.remove(model.getModelName());
                }
            }
            line = in.readLine();
            if (line != null) {
                fields = line.split("\t");
            }
        }

        if (expectedTreeModelNames.size() > 0) {
            StringBuilder sb = new StringBuilder();
            for (String notFoundName : expectedTreeModelNames) {
                sb.append("Expecting, but unable to match state parameter: " + notFoundName + "\n");
            }
            throw new RuntimeException(sb.toString());
        }

        if (DEBUG) {
            System.out.println("\nDouble checking:");
            for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
                if (parameter.getParameterName().equals("branchRates.categories.rootNodeNumber")) {
                    System.out.println(parameter.getParameterName() + ": " + parameter.getParameterValue(0));
                }
            }
        }

        if (rngState != null) {
            MathUtils.setRandomState(rngState);
        }

        in.close();
        fileIn.close();

        // This shouldn't be necessary and if it is then it might be hiding a bug...
        // for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
        //     likelihood.makeDirty();
        // }
    } catch (IOException ioe) {
        throw new RuntimeException("Unable to read file: " + ioe.getMessage());
    }

    return state;
}
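The loop over Parameter.CONNECTED_PARAMETER_SET above restores each parameter from a tab-separated line whose columns are a tag, the parameter name, its dimension, and one value per dimension. Below is a minimal, self-contained sketch of that round trip; the class name, the "parameter" tag and the two helper methods are illustrative assumptions and not part of BeastCheckpointer, and the real checkpoint writer may lay out its columns slightly differently.

import dr.inference.model.Parameter;

public class ParameterStateLineDemo {

    // Build a tab-separated line in the layout the reader above expects:
    // tag \t name \t dimension \t value_0 ... value_{dimension-1}
    static String toStateLine(String tag, String name, Parameter parameter) {
        StringBuilder sb = new StringBuilder();
        sb.append(tag).append('\t').append(name).append('\t').append(parameter.getDimension());
        for (int dim = 0; dim < parameter.getDimension(); dim++) {
            sb.append('\t').append(parameter.getParameterValue(dim));
        }
        return sb.toString();
    }

    // Restore values into an existing parameter from such a line.
    static void fromStateLine(String line, Parameter parameter) {
        String[] fields = line.split("\t");
        int dimension = Integer.parseInt(fields[2]);
        if (dimension != parameter.getDimension()) {
            throw new IllegalArgumentException("dimension mismatch for " + fields[1]);
        }
        for (int dim = 0; dim < dimension; dim++) {
            parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
        }
    }

    public static void main(String[] args) {
        Parameter written = new Parameter.Default(3);
        for (int i = 0; i < written.getDimension(); i++) {
            written.setParameterValue(i, 0.1 * (i + 1));
        }
        String line = toStateLine("parameter", "example.parameter", written);

        Parameter restored = new Parameter.Default(3);
        fromStateLine(line, restored);
        System.out.println(line + " -> " + restored.getParameterValue(2));
    }
}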
use of dr.inference.model.Parameter in project beast-mcmc by beast-dev.
the class AlloppSpeciesNetworkModel method testExampleNetworkToMulLabTree.
/* *********************** TEST CODE **********************************/

/*
 * Test of conversion from network to mullab tree
 *
 * 2011-05-07 It is called from testAlloppSpeciesNetworkModel.java.
 * I don't know how to put the code in there without
 * making lots public here.
 */
// grjtodo-oneday. should be possible to pass stuff in nmltTEST. Currently
// it just signals that this is indeed a test.
//AR - removing this as it creates a dependency to test.dr.* which is bad...
public String testExampleNetworkToMulLabTree(int testcase) {
    int ntaxa = apsp.numberOfSpecies();
    Taxon[] spp = new Taxon[ntaxa];
    for (int tx = 0; tx < ntaxa; ++tx) {
        spp[tx] = new Taxon(apsp.apspeciesName(tx));
    }
    // 1,2,3 (names b,c,d) are tets, 0,4 are dips (names a,e)
    double tetheight0 = 0.0;
    double tetheight1 = 0.0;
    double tetheight2 = 0.0;
    // case 1. one tettree with one foot in each diploid branch
    // case 2. one tettree with both feet in one diploid branch
    // case 3. one tettree with one joined
    // case 4. two tettrees, 2+1, first with one foot in each diploid
    //         branch, second joined
    // case 5. three tettrees, 1+1+1, one of each type of feet, as in cases 1-3
    int ntettrees = 0;
    switch (testcase) {
        case 1:
        case 2:
        case 3:
            ntettrees = 1;
            break;
        case 4:
            ntettrees = 2;
            break;
        case 5:
            ntettrees = 3;
            break;
    }
    tettrees = new ArrayList<AlloppLeggedTree>(ntettrees);

    Taxon l0 = new Taxon("L0");
    Taxon l1 = new Taxon("L1");
    Taxon l2 = new Taxon("L2");
    Taxon r0 = new Taxon("R0");
    Taxon r1 = new Taxon("R1");
    Taxon r2 = new Taxon("R2");

    Taxon[] tets123 = { spp[1], spp[2], spp[3] };
    Taxon[] tets12 = { spp[1], spp[2] };
    Taxon[] tets1 = { spp[1] };
    Taxon[] tets2 = { spp[2] };
    Taxon[] tets3 = { spp[3] };
    Taxon[] dips = new Taxon[0];

    switch (testcase) {
        case 1:
            tettrees.add(new AlloppLeggedTree(tets123));
            tetheight0 = tettrees.get(0).getRootHeight();
            dips = new Taxon[] { spp[0], l0, r0, spp[4] };
            break;
        case 2:
            tettrees.add(new AlloppLeggedTree(tets123));
            tetheight0 = tettrees.get(0).getRootHeight();
            dips = new Taxon[] { spp[0], l0, r0, spp[4] };
            break;
        case 3:
            tettrees.add(new AlloppLeggedTree(tets123));
            tetheight0 = tettrees.get(0).getRootHeight();
            dips = new Taxon[] { spp[0], l0, r0, spp[4] };
            break;
        case 4:
            tettrees.add(new AlloppLeggedTree(tets12));
            tettrees.add(new AlloppLeggedTree(tets3));
            tetheight0 = tettrees.get(0).getRootHeight();
            tetheight1 = tettrees.get(1).getRootHeight();
            dips = new Taxon[] { spp[0], l0, r0, l1, r1, spp[4] };
            break;
        case 5:
            tettrees.add(new AlloppLeggedTree(tets1));
            tettrees.add(new AlloppLeggedTree(tets2));
            tettrees.add(new AlloppLeggedTree(tets3));
            tetheight0 = tettrees.get(0).getRootHeight();
            tetheight1 = tettrees.get(1).getRootHeight();
            tetheight2 = tettrees.get(2).getRootHeight();
            dips = new Taxon[] { spp[0], l0, r0, l1, r1, l2, r2, spp[4] };
            break;
    }
    assert dips.length >= 2;

    int ndhnodes = 2 * dips.length - 1;
    SimpleNode[] dhnodes = new SimpleNode[ndhnodes];
    for (int n = 0; n < ndhnodes; n++) {
        dhnodes[n] = new SimpleNode();
        if (n < dips.length) {
            dhnodes[n].setTaxon(dips[n]);
        } else {
            dhnodes[n].setTaxon(new Taxon(""));
        }
    }

    int dhroot = -1;
    switch (testcase) {
        case 1:
            dhnodes[1].setHeight(tetheight0 + 1.0);
            dhnodes[2].setHeight(tetheight0 + 1.0);
            addSimpleNodeChildren(dhnodes[4], dhnodes[0], dhnodes[1], 1.0);
            addSimpleNodeChildren(dhnodes[5], dhnodes[2], dhnodes[3], 1.0);
            addSimpleNodeChildren(dhnodes[6], dhnodes[4], dhnodes[5], 1.0);
            dhroot = 6;
            break;
        case 2:
            dhnodes[1].setHeight(tetheight0 + 1.0);
            dhnodes[2].setHeight(tetheight0 + 1.0);
            addSimpleNodeChildren(dhnodes[4], dhnodes[0], dhnodes[1], 1.0);
            addSimpleNodeChildren(dhnodes[5], dhnodes[2], dhnodes[4], 1.0);
            addSimpleNodeChildren(dhnodes[6], dhnodes[3], dhnodes[5], 1.0);
            dhroot = 6;
            break;
        case 3:
            dhnodes[1].setHeight(tetheight0 + 1.0);
            dhnodes[2].setHeight(tetheight0 + 1.0);
            addSimpleNodeChildren(dhnodes[4], dhnodes[1], dhnodes[2], 1.0);
            addSimpleNodeChildren(dhnodes[5], dhnodes[0], dhnodes[4], 1.0);
            addSimpleNodeChildren(dhnodes[6], dhnodes[3], dhnodes[5], 1.0);
            dhroot = 6;
            break;
        case 4:
            dhnodes[1].setHeight(tetheight0 + 1.0);
            dhnodes[2].setHeight(tetheight0 + 1.0);
            dhnodes[3].setHeight(tetheight1 + 1.0);
            dhnodes[4].setHeight(tetheight1 + 1.0);
            addSimpleNodeChildren(dhnodes[6], dhnodes[0], dhnodes[1], 1.0);
            addSimpleNodeChildren(dhnodes[7], dhnodes[3], dhnodes[4], 1.0);
            addSimpleNodeChildren(dhnodes[8], dhnodes[6], dhnodes[7], 1.0);
            addSimpleNodeChildren(dhnodes[9], dhnodes[2], dhnodes[5], 1.0);
            addSimpleNodeChildren(dhnodes[10], dhnodes[8], dhnodes[9], 1.0);
            dhroot = 10;
            break;
        case 5:
            dhnodes[1].setHeight(tetheight0 + 1.0);
            dhnodes[2].setHeight(tetheight0 + 1.0);
            dhnodes[3].setHeight(tetheight1 + 1.0);
            dhnodes[4].setHeight(tetheight1 + 1.0);
            dhnodes[5].setHeight(tetheight2 + 1.0);
            dhnodes[6].setHeight(tetheight2 + 1.0);
            addSimpleNodeChildren(dhnodes[8], dhnodes[0], dhnodes[1], 1.0);
            addSimpleNodeChildren(dhnodes[9], dhnodes[5], dhnodes[6], 1.0);
            addSimpleNodeChildren(dhnodes[10], dhnodes[2], dhnodes[7], 1.0);
            addSimpleNodeChildren(dhnodes[11], dhnodes[3], dhnodes[8], 1.0);
            addSimpleNodeChildren(dhnodes[12], dhnodes[4], dhnodes[11], 1.0);
            addSimpleNodeChildren(dhnodes[13], dhnodes[9], dhnodes[12], 1.0);
            addSimpleNodeChildren(dhnodes[14], dhnodes[10], dhnodes[13], 1.0);
            dhroot = 14;
            break;
    }
    AlloppDiploidHistory adhist = new AlloppDiploidHistory(dhnodes, dhroot, tettrees, true, apsp);

    int ntippopparams = numberOfTipPopParameters();
    int nrootpopparams = numberOfRootPopParameters();
    int maxnhybpopparams = maxNumberOfHybPopParameters();
    Parameter testtippopvalues = new Parameter.Default(ntippopparams);
    Parameter testrootpopvalues = new Parameter.Default(nrootpopparams);
    double[] testhybpopvalues = new double[maxnhybpopparams];
    for (int pp = 0; pp < ntippopparams; pp++) {
        testtippopvalues.setParameterValue(pp, 1000 + pp);
    }
    for (int pp = 0; pp < nrootpopparams; pp++) {
        testrootpopvalues.setParameterValue(pp, 2000 + pp);
    }
    for (int pp = 0; pp < maxnhybpopparams; pp++) {
        testhybpopvalues[pp] = 3000 + pp;
    }

    AlloppMulLabTree testmullabtree = new AlloppMulLabTree(adhist, tettrees, apsp, testtippopvalues, testrootpopvalues, testhybpopvalues);
    System.out.println(testmullabtree.asText());
    String newick = testmullabtree.mullabTreeAsNewick();
    return newick;
}
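The population parameters near the end of this test follow a simple fixture pattern: allocate a Parameter.Default of the required dimension and fill it with a recognisable offset so each slot can be spotted in the printed tree. A stripped-down sketch of that pattern, using a hypothetical helper class that is not part of AlloppSpeciesNetworkModel:

import dr.inference.model.Parameter;

public class PopParameterFixture {

    // Fill a parameter with an offset-based pattern so each slot is recognisable in test output.
    static Parameter filledParameter(int dimension, double offset) {
        Parameter values = new Parameter.Default(dimension);
        for (int pp = 0; pp < dimension; pp++) {
            values.setParameterValue(pp, offset + pp);
        }
        return values;
    }

    public static void main(String[] args) {
        Parameter tipPops = filledParameter(4, 1000);
        Parameter rootPops = filledParameter(3, 2000);
        System.out.println(tipPops.getParameterValue(0) + " ... " + rootPops.getParameterValue(2));
    }
}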
use of dr.inference.model.Parameter in project beast-mcmc by beast-dev.
the class ClusterSplitMergeOperator method doOperation.
/**
 * Change the parameter and return the Hastings ratio.
 */
public final double doOperation() {
    // get a copy of the allocations to work with...
    int[] allocations = new int[allocationParameter.getDimension()];

    // construct the cluster occupancy vector and count the occupied clusters
    int[] occupancy = new int[N];

    int[] occupiedIndices = new int[N];
    // used but not set
    for (int i = 0; i < occupiedIndices.length; i++) {
        occupiedIndices[i] = -1;
    }

    // K = number of occupied clusters
    int K = 0;
    for (int i = 0; i < allocations.length; i++) {
        allocations[i] = (int) allocationParameter.getParameterValue(i);
        occupancy[allocations[i]] += 1;
        if (occupancy[allocations[i]] == 1) {
            // first item in cluster
            occupiedIndices[K] = allocations[i];
            K++;
        }
    }

    // Container for split/merge random variable (only 2 draws in 2D)
    int paramDim = clusterLocations.getParameter(0).getDimension();
    // Need to keep these for computing MHG ratio
    double[] splitDraw = new double[paramDim];

    // TODO make tunable
    double scale = 1.0;
    // TODO Make tunable
    double newClusterProb = 0.5;

    // always split when K = 1, always merge when K = N, otherwise 50:50
    boolean doSplit = K == 1 || (K != N && MathUtils.nextBoolean());

    if (doSplit) {
        // Split operation
        int cluster1;
        do {
            // pick an occupied cluster
            cluster1 = occupiedIndices[MathUtils.nextInt(K)];
            // For reversibility, merge step requires that both resulting clusters are occupied,
            // so we should resample until condition is true
        } while (occupancy[cluster1] == 0);

        // find the first unoccupied cluster
        int cluster2 = 0;
        while (occupancy[cluster2] > 0) {
            cluster2++;
        }

        int oldCount = occupancy[cluster1];
        do {
            occupancy[cluster1] = 0;
            occupancy[cluster2] = 0;
            for (int i = 0; i < allocations.length; i++) {
                if (allocations[i] == cluster1 || allocations[i] == cluster2) {
                    boolean putInNewCluster = MathUtils.nextDouble() < newClusterProb;
                    if (putInNewCluster) {
                        allocations[i] = cluster2;
                        occupancy[cluster2]++;
                    } else {
                        allocations[i] = cluster1;
                        occupancy[cluster1]++;
                    }
                }
            }
        } while (occupancy[cluster1] != 0 && occupancy[cluster2] != 0);

        K++;

        // set both clusters to a location based on the first cluster with some random jitter...
        Parameter param1 = clusterLocations.getParameter(cluster1);
        Parameter param2 = clusterLocations.getParameter(cluster2);
        double[] loc = param1.getParameterValues();
        for (int dim = 0; dim < param1.getDimension(); dim++) {
            splitDraw[dim] = MathUtils.nextGaussian();
            param1.setParameterValue(dim, loc[dim] + (splitDraw[dim] * scale));
            // Move in opposite direction
            param2.setParameterValue(dim, loc[dim] - (splitDraw[dim] * scale));
        }

        if (DEBUG) {
            System.err.println("Split: " + (oldCount - occupancy[cluster1]) + " items from cluster " + cluster1 + " to create cluster " + cluster2);
        }
    } else {
        // Merge operation
        // pick 2 occupied clusters
        int cluster1 = occupiedIndices[MathUtils.nextInt(K)];
        int cluster2;
        do {
            cluster2 = occupiedIndices[MathUtils.nextInt(K)];
            // resample until cluster1 != cluster2 to maintain reversibility, because split assumes they are different
        } while (cluster1 == cluster2);

        for (int i = 0; i < allocations.length; i++) {
            if (allocations[i] == cluster2) {
                allocations[i] = cluster1;
                // keep occupancy up to date (remove if not needed)
                occupancy[cluster1]++;
                occupancy[cluster2]--;
            }
        }

        K--;

        // set the merged cluster to the mean location of the two original clusters
        Parameter loc1 = clusterLocations.getParameter(cluster1);
        Parameter loc2 = clusterLocations.getParameter(cluster2);
        for (int dim = 0; dim < loc1.getDimension(); dim++) {
            double average = (loc1.getParameterValue(dim) + loc2.getParameterValue(dim)) / 2.0;
            // Record what the reverse step would need to draw
            splitDraw[dim] = (loc1.getParameterValue(dim) - average) / scale;
            loc1.setParameterValue(dim, average);
            // Consider loc2 as the extra dimensions for dimension-matching
            // On second thought, maybe not a good idea
            // loc2.setParameterValue(dim, splitDraw[dim]);
        }

        if (DEBUG) {
            System.err.println("Merge: " + occupancy[cluster1] + " items into cluster " + cluster1 + " from " + cluster2);
        }
    }

    // set the final allocations (only for those that have changed)
    for (int i = 0; i < allocations.length; i++) {
        int k = (int) allocationParameter.getParameterValue(i);
        if (allocations[i] != k) {
            allocationParameter.setParameterValue(i, allocations[i]);
        }
    }

    // todo the Hastings ratio
    return 0.0;
}
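The only places this split/merge move touches Parameter directly are the location updates: the split jitters two cluster locations symmetrically around the original, and the merge averages them back. A minimal sketch of the split side in isolation is given below; the class and method names are illustrative and not part of ClusterSplitMergeOperator.

import dr.inference.model.Parameter;
import dr.math.MathUtils;

public class SplitJitterDemo {

    // Mirror of the split move above: place two cluster locations symmetrically
    // around the original location, and return the draw so a reverse (merge)
    // move could recover it for the Hastings ratio.
    static double[] splitLocations(Parameter loc1, Parameter loc2, double scale) {
        double[] original = loc1.getParameterValues();
        double[] draw = new double[loc1.getDimension()];
        for (int dim = 0; dim < loc1.getDimension(); dim++) {
            draw[dim] = MathUtils.nextGaussian();
            loc1.setParameterValue(dim, original[dim] + draw[dim] * scale);
            loc2.setParameterValue(dim, original[dim] - draw[dim] * scale);
        }
        return draw;
    }

    public static void main(String[] args) {
        Parameter a = new Parameter.Default(2);
        Parameter b = new Parameter.Default(2);
        double[] draw = splitLocations(a, b, 1.0);
        System.out.println("draw[0] = " + draw[0] + ", a[0] = " + a.getParameterValue(0) + ", b[0] = " + b.getParameterValue(0));
    }
}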
use of dr.inference.model.Parameter in project beast-mcmc by beast-dev.
the class BirthDeathCollapseModelParser method parseXMLObject.
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final Units.Type units = XMLUnits.Utils.getUnitsAttr(xo);
    final double collH = xo.getDoubleAttribute(COLLAPSE_HEIGHT);

    XMLObject cxo = xo.getChild(TREE);
    final Tree tree = (Tree) cxo.getChild(Tree.class);

    final Parameter birthMinusDeath = (Parameter) xo.getElementFirstChild(BIRTHDIFF_RATE);
    final Parameter relativeDeathRate = (Parameter) xo.getElementFirstChild(RELATIVE_DEATH_RATE);
    final Parameter originHeight = (Parameter) xo.getElementFirstChild(ORIGIN_HEIGHT);
    final Parameter collapseWeight = (Parameter) xo.getElementFirstChild(COLLAPSE_WEIGHT);

    final String modelName = xo.getId();

    return new BirthDeathCollapseModel(modelName, tree, units, birthMinusDeath, relativeDeathRate, originHeight, collapseWeight, collH);
}
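Each xo.getElementFirstChild(...) call above hands back a Parameter built from a child element of the XML. The sketch below shows a programmatic equivalent, building four single-value parameters with Parameter.Default and passing them to the constructor in the same order as the parser; the starting values, the helper class and the import path of BirthDeathCollapseModel are assumptions, and only the constructor argument order and types are taken from the parser above.

import dr.evolution.tree.Tree;
import dr.evolution.util.Units;
import dr.inference.model.Parameter;
// assumed package for the model class; check the repository for the actual location
import dr.evomodel.speciation.BirthDeathCollapseModel;

public class CollapseModelAssembly {

    // Programmatic stand-in for the XML parse above: the four Parameter children
    // become single-value Parameter.Default instances with placeholder values.
    static BirthDeathCollapseModel build(String name, Tree tree, Units.Type units, double collapseHeight) {
        Parameter birthMinusDeath = new Parameter.Default(1);
        birthMinusDeath.setParameterValue(0, 1.0);
        Parameter relativeDeathRate = new Parameter.Default(1);
        relativeDeathRate.setParameterValue(0, 0.5);
        Parameter originHeight = new Parameter.Default(1);
        originHeight.setParameterValue(0, 10.0);
        Parameter collapseWeight = new Parameter.Default(1);
        collapseWeight.setParameterValue(0, 0.5);
        return new BirthDeathCollapseModel(name, tree, units, birthMinusDeath, relativeDeathRate, originHeight, collapseWeight, collapseHeight);
    }
}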
use of dr.inference.model.Parameter in project beast-mcmc by beast-dev.
the class MulTreeNodeSlide method operateOneNode.
public void operateOneNode(final double factor) {
    // #print "operate: tree", ut.treerep(t)
    // if( verbose) System.out.println(" Mau at start: " + tree.getSimpleTree());
    final int count = multree.getExternalNodeCount();
    assert count == species.nSpSeqs();

    NodeRef[] order = new NodeRef[2 * count - 1];
    boolean[] swapped = new boolean[count - 1];
    mauCanonical(multree, order, swapped);

    // internal node to change
    // count-1 - number of internal nodes
    int which = MathUtils.nextInt(count - 1);

    FixedBitSet left = new FixedBitSet(count);
    FixedBitSet right = new FixedBitSet(count);
    for (int k = 0; k < 2 * which + 1; k += 2) {
        left.set(multree.speciesIndex(order[k]));
    }
    for (int k = 2 * (which + 1); k < 2 * count; k += 2) {
        right.set(multree.speciesIndex(order[k]));
    }

    double newHeight;
    if (factor > 0) {
        newHeight = multree.getNodeHeight(order[2 * which + 1]) * factor;
    } else {
        final double limit = species.speciationUpperBound(left, right);
        newHeight = MathUtils.nextDouble() * limit;
    }

    multree.beginTreeEdit();
    multree.setPreorderIndices(preOrderIndexBefore);

    final NodeRef node = order[2 * which + 1];
    multree.setNodeHeight(node, newHeight);

    mauReconstruct(multree, order, swapped);

    // restore pre-order of pops -
    {
        multree.setPreorderIndices(preOrderIndexAfter);

        double[] splitPopValues = null;

        for (int k = 0; k < preOrderIndexBefore.length; ++k) {
            final int b = preOrderIndexBefore[k];
            if (b >= 0) {
                final int a = preOrderIndexAfter[k];
                if (a != b) {
                    //if( verbose) System.out.println("pops: " + a + " <- " + b);
                    final Parameter p1 = multree.sppSplitPopulations;
                    if (splitPopValues == null) {
                        splitPopValues = p1.getParameterValues();
                    }

                    if (multree.constPopulation()) {
                        p1.setParameterValue(count + a, splitPopValues[count + b]);
                    } else {
                        for (int i = 0; i < 2; ++i) {
                            p1.setParameterValue(count + 2 * a + i, splitPopValues[count + 2 * b + i]);
                        }
                    }
                }
            }
        }
    }

    multree.endTreeEdit();
}
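The block that restores the pre-order of populations boils down to copying per-node values between slots of a single Parameter, using a snapshot taken before the tree edit. A simplified, self-contained sketch of that copy step follows; the class and method names are hypothetical, and a flat single-value-per-node layout stands in for the constant/two-value cases handled above.

import dr.inference.model.Parameter;

public class PopSlotRemapDemo {

    // Mirror of the population-copy step above: when a node's pre-order index moves
    // from slot b to slot a, copy its stored population value from a snapshot taken
    // before the edit, so the value follows the node rather than the slot.
    static void remap(Parameter splitPops, int[] before, int[] after, int offset) {
        double[] snapshot = splitPops.getParameterValues();
        for (int k = 0; k < before.length; ++k) {
            int b = before[k];
            int a = after[k];
            if (b >= 0 && a != b) {
                splitPops.setParameterValue(offset + a, snapshot[offset + b]);
            }
        }
    }

    public static void main(String[] args) {
        Parameter pops = new Parameter.Default(6);
        for (int i = 0; i < pops.getDimension(); i++) {
            pops.setParameterValue(i, i);
        }
        // nodes 0 and 1 swap pre-order positions; their values in slots 3 and 4 follow them
        remap(pops, new int[] { 0, 1, 2 }, new int[] { 1, 0, 2 }, 3);
        System.out.println(pops.getParameterValue(3) + " " + pops.getParameterValue(4));
    }
}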