Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
Class YuleModelTest, method testYuleWithSubtreeSlide.
public void testYuleWithSubtreeSlide() {
    // Run the shared Yule checks with a schedule holding a single subtree-slide operator.
    final TreeModel model = new TreeModel("treeModel", tree);
    final OperatorSchedule operators = new SimpleOperatorSchedule();
    operators.addOperator(
            new SubtreeSlideOperator(model, 1, 1, true, false, false, false, CoercionMode.COERCION_ON));
    yuleTester(model, operators);
}
Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
Class OperatorAssert, method irreducibilityTester.
/**
 * Samples trees with this test's operator schedule under a BirthDeathGernhard08Model speciation
 * prior (birth rate 2.0, death rate 0.0) and checks that the chain visits every labelled tree.
 *
 * <p>Trees are logged every {@code sampleTreeEvery} states to the Nexus file {@code yule.trees}
 * (deleted on JVM exit), re-imported, and tallied both per labelled tree and per unlabelled
 * topology (every word character run in the Newick string replaced by {@code X}). The method
 * asserts that exactly {@code numLabelledTopologies} distinct labelled trees were seen and that
 * {@code chainLength / sampleTreeEvery + 1} samples were logged. For each topology it then prints
 * the mean squared error between the observed frequencies of its labelled trees and a uniform
 * distribution over them.
 *
 * @param tree                  starting tree
 * @param numLabelledTopologies number of distinct labelled trees the chain must visit
 * @param chainLength           number of MCMC states to run
 * @param sampleTreeEvery       tree logging interval
 * @throws IOException              if the tree log cannot be written or read back
 * @throws Importer.ImportException if the logged trees cannot be parsed
 */
private void irreducibilityTester(Tree tree, int numLabelledTopologies, int chainLength, int sampleTreeEvery) throws IOException, Importer.ImportException {
    MCMC mcmc = new MCMC("mcmc1");
    MCMCOptions options = new MCMCOptions(chainLength);
    TreeModel treeModel = new TreeModel("treeModel", tree);
    TreeLengthStatistic tls = new TreeLengthStatistic(TL, treeModel);
    TreeHeightStatistic rootHeight = new TreeHeightStatistic(TREE_HEIGHT, treeModel);
    OperatorSchedule schedule = getOperatorSchedule(treeModel);
    Parameter b = new Parameter.Default("b", 2.0, 0.0, Double.MAX_VALUE);
    Parameter d = new Parameter.Default("d", 0.0, 0.0, Double.MAX_VALUE);
    SpeciationModel speciationModel = new BirthDeathGernhard08Model(b, d, null, BirthDeathGernhard08Model.TreeType.UNSCALED, Units.Type.YEARS);
    Likelihood likelihood = new SpeciationLikelihood(treeModel, speciationModel, "yule.like");

    MCLogger[] loggers = new MCLogger[2];
    loggers[0] = new MCLogger(new TabDelimitedFormatter(System.out), 10000, false);
    loggers[0].add(likelihood);
    loggers[0].add(rootHeight);
    loggers[0].add(tls);

    File file = new File("yule.trees");
    file.deleteOnExit();
    FileOutputStream out = new FileOutputStream(file);
    try {
        loggers[1] = new TreeLogger(treeModel, new TabDelimitedFormatter(out), sampleTreeEvery, true, true, false);
        mcmc.setShowOperatorAnalysis(true);
        mcmc.init(options, likelihood, schedule, loggers);
        mcmc.run();
        out.flush();
    } finally {
        // Release the file handle even if the chain fails part-way through.
        out.close();
    }

    Set<String> uniqueTrees = new HashSet<String>();
    // topology -> number of samples with that unlabelled topology
    HashMap<String, Integer> topologies = new HashMap<String, Integer>();
    // topology -> (labelled tree -> number of samples)
    HashMap<String, HashMap<String, Integer>> treeCounts = new HashMap<String, HashMap<String, Integer>>();
    int sampleSize = 0;
    FileReader reader = new FileReader(file);
    try {
        NexusImporter importer = new NexusImporter(reader);
        while (importer.hasTree()) {
            sampleSize++;
            Tree t = importer.importNextTree();
            String uniqueNewick = TreeUtils.uniqueNewick(t, t.getRoot());
            // Strip the labels so that trees differing only in taxon placement share a key.
            String topology = uniqueNewick.replaceAll("\\w+", "X");
            uniqueTrees.add(uniqueNewick);

            HashMap<String, Integer> counts = treeCounts.get(topology);
            if (counts == null) {
                counts = new HashMap<String, Integer>();
                treeCounts.put(topology, counts);
                topologies.put(topology, 1);
            } else {
                topologies.put(topology, topologies.get(topology) + 1);
            }
            Integer previous = counts.get(uniqueNewick);
            counts.put(uniqueNewick, previous == null ? 1 : previous + 1);
        }
    } finally {
        reader.close();
    }

    TestCase.assertEquals(numLabelledTopologies, uniqueTrees.size());
    // JUnit convention: expected value first, actual value second.
    TestCase.assertEquals(chainLength / sampleTreeEvery + 1, sampleSize);

    // Per-topology and per-tree frequency assertions are disabled; only the MSE is reported.
    for (String topology : topologies.keySet()) {
        HashMap<String, Integer> counts = treeCounts.get(topology);
        double expected = 1.0 / counts.size();
        double topologyCount = topologies.get(topology);
        double mse = 0;
        for (Integer count : counts.values()) {
            double observed = count / topologyCount;
            mse += (expected - observed) * (expected - observed);
        }
        mse /= counts.size();
        System.out.println("The Mean Square Error for the topology " + topology + " is " + mse);
    }
}
Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
Class ARGAddRemoveOperatorTest, method getSchedule.
/**
 * Builds the operator schedule used by the ARG add/remove tests: one scale operator on the root
 * height (weight 5) and one on the internal node heights (weight 10), both with scale factor 0.75
 * and coercion on. The ARG add/remove event operator is currently not included.
 */
public static OperatorSchedule getSchedule(ARGModel arg) {
    final CompoundParameter rootHeightParam =
            (CompoundParameter) arg.createNodeHeightsParameter(true, false, false);
    final CompoundParameter internalHeightParam =
            (CompoundParameter) arg.createNodeHeightsParameter(false, true, false);

    final ScaleOperator[] scalers = {
            new ScaleOperator(rootHeightParam, 0.75, CoercionMode.COERCION_ON, 5),
            new ScaleOperator(internalHeightParam, 0.75, CoercionMode.COERCION_ON, 10)
    };

    final OperatorSchedule result = new SimpleOperatorSchedule();
    for (ScaleOperator scaler : scalers) {
        result.addOperator(scaler);
    }
    return result;

    // Reference XML configuration this schedule was derived from (note the differing factors):
    // <scaleOperator id="rootOperator" scaleFactor="0.5" weight="10">
    //     <parameter idref="argModel.rootHeight" />
    // </scaleOperator>
    // <scaleOperator scaleFactor="0.95" weight="10">
    //     <parameter idref="argModel.internalNodeHeights" />
    // </scaleOperator>
    // <ARGEventOperator weight="5" addProbability="0.5" autoOptimize="false">
    //     <argTreeModel idref="argModel" />
    //     <internalNodes><parameter idref="argModel.internalNodeHeights" /></internalNodes>
    //     <internalNodesPlusRoot><parameter idref="argModel.allInternalNodeHeights" /></internalNodesPlusRoot>
    //     <nodeRates><parameter idref="argModel.rates" /></nodeRates>
    // </ARGEventOperator>
}
Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
Class GibbsSubtreeSwapTestProblem, method getOperatorSchedule.
/**
 * Builds a schedule with three equally weighted operators on {@code treeModel}: a Gibbs subtree
 * swap, a scale operator (factor 0.75, coercion on) on the root height, and a uniform operator on
 * the internal node heights.
 */
public OperatorSchedule getOperatorSchedule(TreeModel treeModel) {
    final Parameter rootHeight = treeModel.createNodeHeightsParameter(true, false, false);
    final Parameter innerHeights = treeModel.createNodeHeightsParameter(false, true, false);

    final GibbsSubtreeSwap swapOp = new GibbsSubtreeSwap(treeModel, false, 1.0);
    final ScaleOperator rootScaler = new ScaleOperator(rootHeight, 0.75, CoercionMode.COERCION_ON, 1.0);
    final UniformOperator innerUniform = new UniformOperator(innerHeights, 1.0);

    final OperatorSchedule result = new SimpleOperatorSchedule();
    result.addOperator(swapOp);
    result.addOperator(rootScaler);
    result.addOperator(innerUniform);
    return result;
}
Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
Class BeastCheckpointer, method readStateFromFile.
/**
 * Restores the chain state from a checkpoint file.
 *
 * <p>The file is read line by line as tab-separated fields, in this order:
 * <ol>
 *   <li>an optional {@code rng} line holding the random number generator state;</li>
 *   <li>a {@code state} line;</li>
 *   <li>an {@code lnL} line;</li>
 *   <li>one line per connected parameter, matched purely by order;</li>
 *   <li>one line per operator in the schedule;</li>
 *   <li>zero or more {@code tree} sections, one per tree model.</li>
 * </ol>
 * The random number generator state, if present, is applied only after everything else has been
 * read successfully.
 *
 * @param file        the checkpoint file to read
 * @param markovChain chain whose operator schedule receives the saved operator statistics
 * @param lnL         if non-null, element 0 receives the saved log likelihood
 * @return the state number stored in the file
 * @throws RuntimeException if the file cannot be read, ends early, or does not match the current
 *                          parameters, operators or tree models
 */
private long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    long state = -1;
    ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
    BufferedReader in = null;
    try {
        in = new BufferedReader(new FileReader(file));
        int[] rngState = null;

        String[] fields = readRequiredLine(in, "state number").split("\t");
        if (fields[0].equals("rng")) {
            // if there is a random number generator state present then load it...
            try {
                rngState = new int[fields.length - 1];
                for (int i = 0; i < rngState.length; i++) {
                    rngState[i] = Integer.parseInt(fields[i + 1]);
                }
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read random number generator state from state file");
            }
            fields = readRequiredLine(in, "state number").split("\t");
        }

        try {
            if (fields.length < 2 || !fields[0].equals("state")) {
                throw new RuntimeException("Unable to read state number from state file");
            }
            state = Long.parseLong(fields[1]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read state number from state file");
        }

        fields = readRequiredLine(in, "lnL").split("\t");
        try {
            if (!fields[0].equals("lnL")) {
                throw new RuntimeException("Unable to read lnL from state file");
            }
            if (lnL != null) {
                lnL[0] = Double.parseDouble(fields[1]);
            }
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read lnL from state file");
        }

        // Parameter lines: fields[1] = name, fields[2] = dimension, fields[3..] = values.
        // NOTE(review): the name in the file is not compared with the expected parameter (that
        // check was disabled), so lines are assigned to parameters purely by order.
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            fields = readRequiredLine(in, "parameter " + parameter.getParameterName()).split("\t");
            int dimension = Integer.parseInt(fields[2]);
            if (dimension != parameter.getDimension()) {
                System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                System.err.print("Read from file: ");
                for (int i = 0; i < fields.length; i++) {
                    System.err.print(fields[i] + "\t");
                }
                System.err.println();
            }
            if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                // Only the first value is restored for this parameter.
                double value = Double.parseDouble(fields[3]);
                parameter.setParameterValue(0, value);
                if (DEBUG) {
                    System.out.println("restoring " + fields[1] + " with value " + value);
                }
            } else {
                if (DEBUG) {
                    System.out.print("restoring " + fields[1] + " with values ");
                }
                for (int dim = 0; dim < parameter.getDimension(); dim++) {
                    parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
                    if (DEBUG) {
                        System.out.print(Double.parseDouble(fields[dim + 3]) + " ");
                    }
                }
                if (DEBUG) {
                    System.out.println();
                }
            }
        }

        // Operator lines: fields[1] = name, fields[2] = accept count, fields[3] = reject count,
        // fields[4] = coercable parameter (coercable operators only).
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            fields = readRequiredLine(in, "operator " + operator.getOperatorName()).split("\t");
            // Check the length before touching fields[1] so a short line fails with a clear message.
            if (fields.length < 2 || !fields[1].equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match operator: " + (fields.length < 2 ? fields[0] : fields[1]));
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + fields[1]);
            }
            operator.setAcceptCount(Integer.parseInt(fields[2]));
            operator.setRejectCount(Integer.parseInt(fields[3]));
            if (operator instanceof CoercableMCMCOperator) {
                if (fields.length != 5) {
                    throw new RuntimeException("Coercable operator missing parameter: " + fields[1]);
                }
                ((CoercableMCMCOperator) operator).setCoercableParameter(Double.parseDouble(fields[4]));
            }
        }

        // Load the tree models last, as we get the node heights from the tree (not the
        // parameters, which may not be associated with the right node).
        Set<String> expectedTreeModelNames = new HashSet<String>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                if (DEBUG) {
                    System.out.println("model " + model.getModelName());
                }
                expectedTreeModelNames.add(model.getModelName());
                if (DEBUG) {
                    for (String s : expectedTreeModelNames) {
                        System.out.println(s);
                    }
                }
            }
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }

        // Read in all (possibly more than one) trees; stop at end of file or at the first
        // line that is not a "tree <name>" header.
        String treeLine = in.readLine();
        fields = (treeLine == null) ? null : treeLine.split("\t");
        while (fields != null && fields.length > 1 && fields[0].equals("tree")) {
            // Capture the name now: the tree data that follows must not be compared with model names.
            String treeName = fields[1];
            if (DEBUG) {
                System.out.println("tree: " + treeName);
            }
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (model instanceof TreeModel && treeName.equals(model.getModelName())) {
                    restoreTreeModel(in, (TreeModel) model, traitModels.size());
                    expectedTreeModelNames.remove(model.getModelName());
                    break;
                }
            }
            treeLine = in.readLine();
            fields = (treeLine == null) ? null : treeLine.split("\t");
        }

        if (expectedTreeModelNames.size() > 0) {
            StringBuilder sb = new StringBuilder();
            for (String notFoundName : expectedTreeModelNames) {
                sb.append("Expecting, but unable to match state parameter:" + notFoundName + "\n");
            }
            throw new RuntimeException(sb.toString());
        }

        if (DEBUG) {
            System.out.println("\nDouble checking:");
            for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
                if (parameter.getParameterName().equals("branchRates.categories.rootNodeNumber")) {
                    System.out.println(parameter.getParameterName() + ": " + parameter.getParameterValue(0));
                }
            }
        }

        if (rngState != null) {
            MathUtils.setRandomState(rngState);
        }
    } catch (IOException ioe) {
        throw new RuntimeException("Unable to read file: " + ioe.getMessage(), ioe);
    } finally {
        // Close on every path, including the RuntimeExceptions thrown for malformed files.
        if (in != null) {
            try {
                in.close();
            } catch (IOException ignored) {
                // Nothing useful to do: the file was only read, and any real error is already propagating.
            }
        }
    }
    return state;
}

/**
 * Reads the next line from {@code in}. At end of file it fails with a descriptive error instead
 * of letting the caller dereference {@code null}.
 */
private String readRequiredLine(BufferedReader in, String expected) throws IOException {
    String line = in.readLine();
    if (line == null) {
        throw new RuntimeException("Unexpected end of state file while reading " + expected);
    }
    return line;
}

/**
 * Reads one tree section, which follows its "tree" header line, and adopts its structure into
 * {@code treeModel}.
 *
 * <p>Section layout, as consumed below:
 * <ul>
 *   <li>one ignored line;</li>
 *   <li>the node count;</li>
 *   <li>one line per node, with the height in field 1;</li>
 *   <li>two ignored lines;</li>
 *   <li>the edge count;</li>
 *   <li>one line per edge: node, parent, child order, then one value per trait model.</li>
 * </ul>
 */
private void restoreTreeModel(BufferedReader in, TreeModel treeModel, int traitModelCount) throws IOException {
    readRequiredLine(in, "tree");
    String[] fields = readRequiredLine(in, "tree node count").split("\t");
    int nodeCount = Integer.parseInt(fields[0]);
    double[] nodeHeights = new double[nodeCount];
    for (int i = 0; i < nodeCount; i++) {
        fields = readRequiredLine(in, "tree node heights").split("\t");
        nodeHeights[i] = Double.parseDouble(fields[1]);
    }

    // on to reading edge information
    readRequiredLine(in, "tree edges");
    readRequiredLine(in, "tree edges");
    fields = readRequiredLine(in, "tree edge count").split("\t");
    int edgeCount = Integer.parseInt(fields[0]);

    // TODO: trait values are parsed but not yet copied onto the branches.
    double[][] traitValues = new double[traitModelCount][edgeCount];
    // Whether each node is the left or right child of its parent; some tree kernels depend on it.
    int[] childOrder = new int[edgeCount];
    int[] parents = new int[edgeCount];
    for (int i = 0; i < edgeCount; i++) {
        childOrder[i] = -1;
        parents[i] = -1;
    }
    for (int i = 0; i < edgeCount; i++) {
        String line = in.readLine();
        // A missing edge line leaves the -1 defaults in place (pre-existing lenient behaviour).
        if (line != null) {
            fields = line.split("\t");
            // NOTE(review): parents is indexed by the node number in the file, while childOrder
            // and traitValues use the line index; these agree only if edges are written in node order.
            parents[Integer.parseInt(fields[0])] = Integer.parseInt(fields[1]);
            childOrder[i] = Integer.parseInt(fields[2]);
            for (int j = 0; j < traitModelCount; j++) {
                traitValues[j][i] = Double.parseDouble(fields[3 + j]);
            }
        }
    }

    if (DEBUG) {
        System.out.println("adopting tree structure");
    }
    // Adopt the loaded tree structure; this does not yet copy the traits on the branches.
    treeModel.beginTreeEdit();
    treeModel.adoptTreeStructure(parents, nodeHeights, childOrder);
    treeModel.endTreeEdit();
}
Aggregations