use of dr.inference.operators.MCMCOperator in project beast-mcmc by beast-dev.
the class Tutorial1 method main.
/**
 * Tutorial entry point: fits the mean and standard deviation of a normal
 * distribution to nine observations by MCMC, logging to the screen and to a
 * file, and then prints a trace analysis of the file log.
 *
 * @param arg command-line arguments (not used)
 * @throws IOException    declared by the original signature
 * @throws TraceException declared by the original signature
 */
public static void main(String[] arg) throws IOException, TraceException {
    // the same file is written by the file logger and read back by the trace analysis
    final String logFileName = "tutorial1.log";

    // random variable for the mean of the normal distribution; uniform prior on [-1000, 1000]
    Variable.D meanVariable = new Variable.D("mean", 1.0);
    meanVariable.addBounds(new Parameter.DefaultBounds(1000, -1000, 1));

    // random variable for the standard deviation; uniform prior on [0, 1000]
    Variable.D stdevVariable = new Variable.D("stdev", 1.0);
    stdevVariable.addBounds(new Parameter.DefaultBounds(1000, 0, 1));

    // normal distribution model over the two variables, wrapped in a likelihood
    NormalDistributionModel normalModel = new NormalDistributionModel(meanVariable, stdevVariable);
    DistributionLikelihood dataLikelihood = new DistributionLikelihood(normalModel);

    // the data: nine independent observations
    double[] values = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    Attribute.Default<double[]> observations = new Attribute.Default<double[]>("x", values);
    dataLikelihood.addData(observations);

    // proposal distribution: one scale operator per variable
    MCMCOperator meanMove = new ScaleOperator(meanVariable, 0.75);
    MCMCOperator stdevMove = new ScaleOperator(stdevVariable, 0.75);

    // progress logger writing to stdout (screen)
    MCLogger screenLogger = new MCLogger(100);
    screenLogger.add(meanVariable);
    screenLogger.add(stdevVariable);

    // logger writing to a log file for later analysis
    MCLogger fileLogger = new MCLogger(logFileName, 100, false, 0);
    fileLogger.add(meanVariable);
    fileLogger.add(stdevVariable);

    // assemble the chain: chain length, likelihood, operators and loggers
    MCMC mcmc = new MCMC("tutorial1:normal");
    MCMCOperator[] operators = new MCMCOperator[] { meanMove, stdevMove };
    Logger[] loggers = new Logger[] { screenLogger, fileLogger };
    mcmc.init(100000, dataLikelihood, operators, loggers);

    // run the chain, then perform the post-analysis on the file log
    mcmc.chain();
    TraceAnalysis.report(logFileName);
}
use of dr.inference.operators.MCMCOperator in project beast-mcmc by beast-dev.
the class ValuesPoolSwapOperatorParser method parseXMLObject.
/**
 * Builds a {@link ValuesPoolSwapOperator} from its XML element.
 * <p>
 * The {@link ValuesPool} child of the element becomes the operator's target and
 * the {@code MCMCOperator.WEIGHT} attribute becomes its weight.
 *
 * @param xo the XML element describing the operator
 * @return the configured operator
 * @throws XMLParseException declared by the original signature
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final ValuesPool pool = (ValuesPool) xo.getChild(ValuesPool.class);
    final double operatorWeight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);

    final MCMCOperator swapOperator = new ValuesPoolSwapOperator(pool);
    swapOperator.setWeight(operatorWeight);
    return swapOperator;
}
use of dr.inference.operators.MCMCOperator in project beast-mcmc by beast-dev.
the class BeastCheckpointer method readStateFromFile.
/**
 * Restores a checkpointed MCMC state from a tab-delimited state file.
 * <p>
 * Lines are consumed in this order: an optional {@code rng} line holding the random
 * number generator state, a {@code state} line, a {@code lnL} line, one line per
 * non-immutable parameter in {@code Parameter.CONNECTED_PARAMETER_SET}, one line per
 * operator in the chain's schedule, and finally zero or more {@code tree} sections
 * (node table followed by edge table).
 *
 * @param file        the state file to read
 * @param markovChain the chain whose operator schedule is restored
 * @param lnL         if non-null, element 0 receives the log likelihood stored in the file
 * @return the state number stored in the file
 * @throws RuntimeException if the file cannot be read, ends before all expected lines
 *                          were found, or does not match the connected operators/tree models
 */
protected long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) {
    DoubleParser parser = useFullPrecision ? DoubleParser.HEX : DoubleParser.TEXT;
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    long state = -1;
    ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
    BufferedReader in = null;
    try {
        in = new BufferedReader(new FileReader(file));
        int[] rngState = null;
        String line = readRequiredLine(in);
        String[] fields = line.split("\t");
        if (fields[0].equals("rng")) {
            // if there is a random number generator state present then load it...
            try {
                rngState = new int[fields.length - 1];
                for (int i = 0; i < rngState.length; i++) {
                    rngState[i] = Integer.parseInt(fields[i + 1]);
                }
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read rng state from state file", nfe);
            }
            line = readRequiredLine(in);
            fields = line.split("\t");
        }
        try {
            if (!fields[0].equals("state")) {
                throw new RuntimeException("Unable to read state number from state file");
            }
            state = Long.parseLong(fields[1]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read state number from state file", nfe);
        }
        line = readRequiredLine(in);
        fields = line.split("\t");
        try {
            if (!fields[0].equals("lnL")) {
                throw new RuntimeException("Unable to read lnL from state file");
            }
            if (lnL != null) {
                lnL[0] = parser.parseDouble(fields[1]);
            }
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read lnL from state file", nfe);
        }
        // parameters are matched to file lines purely by iteration order; the name in the
        // file (fields[1]) is not checked against the parameter
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            if (!parameter.isImmutable()) {
                line = readRequiredLine(in);
                fields = line.split("\t");
                // if (!fields[0].equals(parameter.getParameterName())) {
                // System.err.println("Unable to match state parameter: " + fields[0] + ", expecting " + parameter.getParameterName());
                // }
                int dimension = Integer.parseInt(fields[2]);
                if (dimension != parameter.getDimension()) {
                    // a mismatch is only reported, not treated as fatal
                    System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                    System.err.print("Read from file: ");
                    for (int i = 0; i < fields.length; i++) {
                        System.err.print(fields[i] + "\t");
                    }
                    System.err.println();
                }
                if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                    // this parameter is restored through setParameterValue on index 0 only
                    double value = parser.parseDouble(fields[3]);
                    parameter.setParameterValue(0, value);
                    if (DEBUG) {
                        System.out.println("restoring " + fields[1] + " with value " + value);
                    }
                } else {
                    if (DEBUG) {
                        System.out.print("restoring " + fields[1] + " with values ");
                    }
                    for (int dim = 0; dim < parameter.getDimension(); dim++) {
                        try {
                            parameter.setParameterUntransformedValue(dim, parser.parseDouble(fields[dim + 3]));
                        } catch (RuntimeException rte) {
                            // best effort: report the failure and move on to the next dimension
                            System.err.println(rte);
                            continue;
                        }
                        if (DEBUG) {
                            System.out.print(parser.parseDouble(fields[dim + 3]) + " ");
                        }
                    }
                    if (DEBUG) {
                        System.out.println();
                    }
                }
            }
        }
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            // TODO we can no longer assume these are in the right order
            // TODO best parse all the "operator" lines and store them so we can mix and match within this for loop
            // TODO does not only apply to the operators but also to the parameters
            // TODO test using additional tip-date sampling compared to previous run
            MCMCOperator operator = operatorSchedule.getOperator(i);
            line = readRequiredLine(in);
            fields = line.split("\t");
            if (!fields[1].equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match " + operator.getOperatorName() + " operator: " + fields[1]);
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + fields[1]);
            }
            operator.setAcceptCount(Integer.parseInt(fields[2]));
            operator.setRejectCount(Integer.parseInt(fields[3]));
            if (operator instanceof AdaptableMCMCOperator) {
                if (fields.length != 6) {
                    throw new RuntimeException("Coercable operator missing parameter: " + fields[1]);
                }
                ((AdaptableMCMCOperator) operator).setAdaptableParameter(parser.parseDouble(fields[4]));
                ((AdaptableMCMCOperator) operator).setAdaptationCount(Long.parseLong(fields[5]));
            }
        }
        // load the tree models last as we get the node heights from the tree (not the parameters
        // which may not be associated with the right node)
        Set<String> expectedTreeModelNames = new HashSet<String>();
        // store list of TreeModels for debugging purposes
        ArrayList<TreeModel> treeModelList = new ArrayList<TreeModel>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                if (DEBUG) {
                    System.out.println("model " + model.getModelName());
                }
                treeModelList.add((TreeModel) model);
                expectedTreeModelNames.add(model.getModelName());
                if (DEBUG) {
                    System.out.println("\nexpectedTreeModelNames:");
                    for (String s : expectedTreeModelNames) {
                        System.out.println(s);
                    }
                    System.out.println();
                }
            }
            // first add all TreeParameterModels to a list
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }
        // explicitly link TreeModel (using its unique ID) to a list of TreeParameterModels
        // NOTE(review): the map is keyed by getModelName() but matched and looked up via getId();
        // this presumably relies on both being equal for TreeModels - confirm
        HashMap<String, ArrayList<TreeParameterModel>> linkedModels = new HashMap<String, ArrayList<TreeParameterModel>>();
        for (String name : expectedTreeModelNames) {
            ArrayList<TreeParameterModel> tpmList = new ArrayList<TreeParameterModel>();
            for (TreeParameterModel tpm : traitModels) {
                if (tpm.getTreeModel().getId().equals(name)) {
                    tpmList.add(tpm);
                    if (DEBUG) {
                        System.out.println("TreeModel: " + name + " has been assigned TreeParameterModel: " + tpm.toString());
                    }
                }
            }
            linkedModels.put(name, tpmList);
        }
        // the tree section is optional: a file without trees simply ends here
        line = in.readLine();
        if (line != null) {
            fields = line.split("\t");
        }
        // Read in all (possibly more than one) trees
        while (line != null && fields[0].equals("tree")) {
            // remember the name now; 'fields' is overwritten while the tree body is read
            String treeName = fields[1];
            if (DEBUG) {
                System.out.println("\ntree: " + treeName);
            }
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (model instanceof TreeModel && treeName.equals(model.getModelName())) {
                    // skip the node table header line
                    in.readLine();
                    line = readRequiredLine(in);
                    fields = line.split("\t");
                    // read number of nodes
                    int nodeCount = Integer.parseInt(fields[0]);
                    double[] nodeHeights = new double[nodeCount];
                    String[] taxaNames = new String[(nodeCount + 1) / 2];
                    for (int i = 0; i < nodeCount; i++) {
                        line = readRequiredLine(in);
                        fields = line.split("\t");
                        nodeHeights[i] = parser.parseDouble(fields[1]);
                        // only the first taxaNames.length node lines are read for a taxon name
                        if (i < taxaNames.length) {
                            taxaNames[i] = fields[2];
                        }
                    }
                    // on to reading edge information: skip the two edge table header lines
                    in.readLine();
                    in.readLine();
                    line = readRequiredLine(in);
                    fields = line.split("\t");
                    int edgeCount = Integer.parseInt(fields[0]);
                    if (DEBUG) {
                        System.out.println("edge count = " + edgeCount);
                    }
                    // create data matrix of doubles to store information from list of TreeParameterModels
                    // size of matrix depends on the number of TreeParameterModels assigned to a TreeModel
                    double[][] traitValues = new double[linkedModels.get(model.getId()).size()][edgeCount];
                    // create array to store whether a node is left or right child of its parent
                    // can be important for certain tree transition kernels
                    int[] childOrder = new int[edgeCount];
                    for (int i = 0; i < childOrder.length; i++) {
                        childOrder[i] = -1;
                    }
                    int[] parents = new int[edgeCount];
                    for (int i = 0; i < edgeCount; i++) {
                        parents[i] = -1;
                    }
                    // one edge line fewer than the stored count is read
                    for (int i = 0; i < edgeCount - 1; i++) {
                        line = in.readLine();
                        if (line != null) {
                            if (DEBUG) {
                                System.out.println("DEBUG: " + line);
                            }
                            fields = line.split("\t");
                            parents[Integer.parseInt(fields[0])] = Integer.parseInt(fields[1]);
                            // childOrder[i] = Integer.parseInt(fields[2]);
                            childOrder[Integer.parseInt(fields[0])] = Integer.parseInt(fields[2]);
                            for (int j = 0; j < linkedModels.get(model.getId()).size(); j++) {
                                // traitValues[j][i] = parser.parseDouble(fields[3+j]);
                                traitValues[j][Integer.parseInt(fields[0])] = parser.parseDouble(fields[3 + j]);
                            }
                        }
                    }
                    // perform magic with the acquired information
                    if (DEBUG) {
                        System.out.println("adopting tree structure");
                    }
                    // adopt the loaded tree structure;
                    ((TreeModel) model).beginTreeEdit();
                    ((TreeModel) model).adoptTreeStructure(parents, nodeHeights, childOrder, taxaNames);
                    if (traitModels.size() > 0) {
                        System.out.println("adopting " + traitModels.size() + " trait models to treeModel " + ((TreeModel) model).getId());
                        ((TreeModel) model).adoptTraitData(parents, traitModels, traitValues, taxaNames);
                    }
                    ((TreeModel) model).endTreeEdit();
                    expectedTreeModelNames.remove(model.getModelName());
                    // this tree section has been consumed; do not test further models against it
                    break;
                }
            }
            // advance to the next section; end of file terminates the loop
            line = in.readLine();
            if (line != null) {
                fields = line.split("\t");
            }
        }
        if (expectedTreeModelNames.size() > 0) {
            StringBuilder sb = new StringBuilder();
            for (String notFoundName : expectedTreeModelNames) {
                sb.append("Expecting, but unable to match state parameter:" + notFoundName + "\n");
            }
            throw new RuntimeException("\n" + sb.toString());
        }
        if (DEBUG) {
            System.out.println("\nDouble checking:");
            for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
                if (parameter.getParameterName().equals("branchRates.categories.rootNodeNumber")) {
                    System.out.println(parameter.getParameterName() + ": " + parameter.getParameterValue(0));
                }
            }
            System.out.println("\nPrinting trees:");
            for (TreeModel tm : treeModelList) {
                System.out.println(tm.getId() + ": ");
                System.out.println(tm.getNewick());
            }
        }
        // an explicitly supplied checkpoint seed takes precedence over the stored rng state
        if (System.getProperty(BeastCheckpointer.CHECKPOINT_SEED) != null) {
            MathUtils.setSeed(Long.parseLong(System.getProperty(BeastCheckpointer.CHECKPOINT_SEED)));
        } else if (rngState != null) {
            MathUtils.setRandomState(rngState);
        }
        // This shouldn't be necessary and if it is then it might be hiding a bug...
        /*for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
        likelihood.makeDirty();
        }*/
    } catch (IOException ioe) {
        throw new RuntimeException("Unable to read file: " + ioe.getMessage(), ioe);
    } finally {
        // always release the file handle, including when restoring fails half-way
        if (in != null) {
            try {
                in.close();
            } catch (IOException ignored) {
                // a failure to close the reader must not mask the outcome of the restore
            }
        }
    }
    return state;
}

/**
 * Reads the next line of the state file, failing with a descriptive error
 * instead of returning {@code null} when the file ends prematurely.
 */
private static String readRequiredLine(BufferedReader in) throws IOException {
    String line = in.readLine();
    if (line == null) {
        throw new RuntimeException("Unexpected end of state file");
    }
    return line;
}
use of dr.inference.operators.MCMCOperator in project beast-mcmc by beast-dev.
the class LognormalPriorTest method testLognormalPrior.
/**
 * Samples a population-size parameter under a LogNormal(1, 1) prior combined
 * with a dummy likelihood, and checks the sampled expectation against
 * e^(M + S^2/2) = e^1.5.
 */
public void testLognormalPrior() {
    // the parameter under test, wrapped in a constant population model
    Parameter populationSize = new Parameter.Default(6.0);
    populationSize.setId(ConstantPopulationModelParser.POPULATION_SIZE);
    ConstantPopulationModel demographicModel = new ConstantPopulationModel(populationSize, Units.Type.YEARS);

    // likelihood
    Likelihood dummyLikelihood = new DummyLikelihood(demographicModel);

    // operators: a single scale operator on the population size
    OperatorSchedule schedule = new SimpleOperatorSchedule();
    MCMCOperator scaleMove = new ScaleOperator(populationSize, 0.75);
    scaleMove.setWeight(1.0);
    schedule.addOperator(scaleMove);

    // loggers: one feeding an in-memory formatter for the analysis, one printing to stdout
    ArrayLogFormatter traceFormatter = new ArrayLogFormatter(false);
    MCLogger traceLogger = new MCLogger(traceFormatter, 1000, false);
    traceLogger.add(populationSize);
    MCLogger screenLogger = new MCLogger(new TabDelimitedFormatter(System.out), 100000, false);
    screenLogger.add(populationSize);
    MCLogger[] loggers = { traceLogger, screenLogger };

    // MCMC
    MCMC mcmc = new MCMC("mcmc1");
    MCMCOptions options = new MCMCOptions(1000000);

    // prior on the population size (meanInRealSpace="false")
    DistributionLikelihood logNormalLikelihood = new DistributionLikelihood(new LogNormalDistribution(1.0, 1.0), 0);
    logNormalLikelihood.addData(populationSize);

    // one scratch list is reused, via clear(), to assemble prior, likelihood and posterior
    List<Likelihood> components = new ArrayList<Likelihood>();
    components.add(logNormalLikelihood);
    Likelihood prior = new CompoundLikelihood(0, components);

    components.clear();
    components.add(dummyLikelihood);
    Likelihood likelihood = new CompoundLikelihood(-1, components);

    components.clear();
    components.add(prior);
    components.add(likelihood);
    Likelihood posterior = new CompoundLikelihood(0, components);

    mcmc.setShowOperatorAnalysis(true);
    mcmc.init(options, posterior, schedule, loggers);
    mcmc.run();

    // time
    String timing = mcmc.getTimer().toString();
    System.out.println(timing);

    // Tracer: analyse every trace from index 1 upwards
    List<Trace> traces = formatter(traceFormatter);
    ArrayTraceList traceList = new ArrayTraceList("LognormalPriorTest", traces, 0);
    int traceIndex = 1;
    while (traceIndex < traces.size()) {
        traceList.analyseTrace(traceIndex);
        traceIndex++;
    }

    // <expectation name="param" value="4.48168907"/>
    TraceCorrelation popSizeStats = traceList.getCorrelationStatistics(traceList.getTraceIndex(ConstantPopulationModelParser.POPULATION_SIZE));
    double expectedMean = Math.exp(1.5);
    System.out.println("Expectation of Log-Normal(1,1) is e^(M+S^2/2) = e^(1.5) = " + expectedMean);
    assertExpectation(ConstantPopulationModelParser.POPULATION_SIZE, popSizeStats, expectedMean);
}

/** Returns the traces collected so far by the given in-memory log formatter. */
private static List<Trace> formatter(ArrayLogFormatter traceFormatter) {
    return traceFormatter.getTraces();
}
use of dr.inference.operators.MCMCOperator in project beast-mcmc by beast-dev.
the class BeastCheckpointer method writeStateToFile.
/**
 * Writes the current MCMC state to a tab-delimited checkpoint file.
 * <p>
 * The file contains, in order: the random number generator state ({@code rng}), the
 * state number, the log likelihood, one {@code parameter} line per non-immutable
 * parameter in {@code Parameter.CONNECTED_PARAMETER_SET}, one {@code operator} line per
 * operator in the chain's schedule, and one {@code tree} section (node table followed
 * by edge table) per connected {@code TreeModel}.
 *
 * @param file        the file to write to
 * @param state       the state number to record
 * @param lnL         the log likelihood to record
 * @param markovChain the chain whose operator schedule is recorded
 * @return {@code true} if the state was written, {@code false} if the file could not be
 *         opened or a write error was detected
 * @throws RuntimeException if a tree node is neither child 0 nor child 1 of its parent
 */
protected boolean writeStateToFile(File file, long state, double lnL, MarkovChain markovChain) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    OutputStream fileOut = null;
    PrintStream out = null;
    try {
        fileOut = new FileOutputStream(file);
        out = useFullPrecision ? new CheckpointPrintStream(fileOut) : new PrintStream(fileOut);
        ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
        int[] rngState = MathUtils.getRandomState();
        out.print("rng");
        for (int i = 0; i < rngState.length; i++) {
            out.print("\t");
            out.print(rngState[i]);
        }
        out.println();
        out.print("state\t");
        out.println(state);
        out.print("lnL\t");
        out.println(lnL);
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            if (!parameter.isImmutable()) {
                out.print("parameter");
                out.print("\t");
                out.print(parameter.getParameterName());
                out.print("\t");
                out.print(parameter.getDimension());
                for (int dim = 0; dim < parameter.getDimension(); dim++) {
                    out.print("\t");
                    out.print(parameter.getParameterUntransformedValue(dim));
                }
                out.print("\n");
            }
        }
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            out.print("operator");
            out.print("\t");
            out.print(operator.getOperatorName());
            out.print("\t");
            out.print(operator.getAcceptCount());
            out.print("\t");
            out.print(operator.getRejectCount());
            if (operator instanceof AdaptableMCMCOperator) {
                // adaptable operators carry two extra columns
                out.print("\t");
                out.print(((AdaptableMCMCOperator) operator).getAdaptableParameter());
                out.print("\t");
                out.print(((AdaptableMCMCOperator) operator).getAdaptationCount());
            }
            out.println();
        }
        // check up front if there are any TreeParameterModel objects
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                TreeModel tree = (TreeModel) model;
                out.print("tree");
                out.print("\t");
                out.println(model.getModelName());
                // the tree is written as a general graph structure (node table, then edge table)
                out.println("#node height taxon");
                int nodeCount = tree.getNodeCount();
                out.println(nodeCount);
                for (int i = 0; i < nodeCount; i++) {
                    NodeRef node = tree.getNode(i);
                    out.print(node.getNumber());
                    out.print("\t");
                    out.print(tree.getNodeHeight(node));
                    if (tree.isExternal(node)) {
                        // only external nodes get a taxon column
                        out.print("\t");
                        out.print(tree.getNodeTaxon(node).getId());
                    }
                    out.println();
                }
                out.println("#edges");
                out.println("#child-node parent-node L/R-child traits");
                // the node count is written as the edge count; nodes without a parent get no edge line
                out.println(nodeCount);
                for (int i = 0; i < nodeCount; i++) {
                    NodeRef node = tree.getNode(i);
                    NodeRef parent = tree.getParent(node);
                    if (parent != null) {
                        out.print(node.getNumber());
                        out.print("\t");
                        out.print(parent.getNumber());
                        out.print("\t");
                        if (tree.getChild(parent, 0) == node) {
                            // left child
                            out.print(0);
                        } else if (tree.getChild(parent, 1) == node) {
                            // right child
                            out.print(1);
                        } else {
                            throw new RuntimeException("Operation currently only supported for nodes with 2 children.");
                        }
                        // only print the TreeParameterModel that matches the TreeModel currently being written
                        for (TreeParameterModel tpm : traitModels) {
                            if (model == tpm.getTreeModel()) {
                                out.print("\t");
                                out.print(tpm.getNodeValue(tree, node));
                            }
                        }
                        out.println();
                    } else {
                        if (DEBUG) {
                            System.out.println(node + " has no parent.");
                        }
                    }
                }
            }
        }
        // PrintStream never throws IOException on write; checkError() flushes the stream and
        // reports whether any write failed, so a broken checkpoint is not reported as success
        if (out.checkError()) {
            System.err.println("Unable to write file: " + file.getAbsolutePath());
            return false;
        }
    } catch (IOException ioe) {
        System.err.println("Unable to write file: " + ioe.getMessage());
        return false;
    } finally {
        // always release the file handle, including when a RuntimeException aborts the write
        if (out != null) {
            out.close();
        }
        if (fileOut != null) {
            try {
                fileOut.close();
            } catch (IOException ignored) {
                // the stream was already flushed and checked above; nothing more to report
            }
        }
    }
    if (DEBUG) {
        for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
            System.err.println(likelihood.getId() + ": " + likelihood.getLogLikelihood());
        }
    }
    return true;
}
Aggregations