Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
In class MLOptimizerParser, method parseXMLObject:
/**
 * Builds an {@link MLOptimizer} from a parsed optimizer element.
 * <p>
 * Every child of the element must be a likelihood, an operator schedule or a logger.
 * Any other child type is rejected. If several likelihoods or schedules appear, the last
 * one seen is used. If no likelihood or schedule child is present, {@code null} is passed
 * on to the optimizer.
 *
 * @param xo the parsed optimizer element; its chain-length attribute is read first
 * @return a new optimizer named "optimizer1"
 * @throws XMLParseException if the chain-length attribute cannot be read or an
 *                           unrecognized child element is encountered
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final int chainLength = xo.getIntegerAttribute(CHAIN_LENGTH);

    dr.inference.model.Likelihood targetLikelihood = null;
    OperatorSchedule schedule = null;
    final ArrayList<Logger> collectedLoggers = new ArrayList<Logger>();

    final int childCount = xo.getChildCount();
    for (int index = 0; index < childCount; index++) {
        final Object element = xo.getChild(index);

        // Check order matters if an object implements several of these types.
        if (element instanceof dr.inference.model.Likelihood) {
            targetLikelihood = (dr.inference.model.Likelihood) element;
            continue;
        }
        if (element instanceof OperatorSchedule) {
            schedule = (OperatorSchedule) element;
            continue;
        }
        if (element instanceof Logger) {
            collectedLoggers.add((Logger) element);
            continue;
        }
        throw new XMLParseException("Unrecognized element found in optimizer element:" + element);
    }

    final Logger[] loggers = collectedLoggers.toArray(new Logger[collectedLoggers.size()]);
    return new MLOptimizer("optimizer1", chainLength, targetLikelihood, schedule, loggers);
}
Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
In class LognormalPriorTest, method testLognormalPrior:
/**
 * Samples a constant population size under a Log-Normal(1, 1) prior (plus a dummy
 * likelihood) and checks that the sampled mean matches the analytical expectation
 * e^(M + S^2/2) = e^(1.5).
 */
public void testLognormalPrior() {
    // Demographic model: one population-size parameter, starting at 6.0.
    final Parameter populationSize = new Parameter.Default(6.0);
    populationSize.setId(ConstantPopulationModelParser.POPULATION_SIZE);
    final ConstantPopulationModel demographicModel =
            new ConstantPopulationModel(populationSize, Units.Type.YEARS);

    // Likelihood: a dummy likelihood attached to the demographic model.
    final Likelihood dataLikelihood = new DummyLikelihood(demographicModel);

    // Operators: a single scale operator (factor 0.75, weight 1) on the population size.
    final OperatorSchedule schedule = new SimpleOperatorSchedule();
    final MCMCOperator scaleOperator = new ScaleOperator(populationSize, 0.75);
    scaleOperator.setWeight(1.0);
    schedule.addOperator(scaleOperator);

    // Loggers: an in-memory logger used for trace analysis, and a sparse stdout logger.
    final ArrayLogFormatter traceFormatter = new ArrayLogFormatter(false);
    final MCLogger memoryLogger = new MCLogger(traceFormatter, 1000, false);
    memoryLogger.add(populationSize);
    final MCLogger screenLogger = new MCLogger(new TabDelimitedFormatter(System.out), 100000, false);
    screenLogger.add(populationSize);
    final MCLogger[] loggers = new MCLogger[]{memoryLogger, screenLogger};

    final MCMC mcmc = new MCMC("mcmc1");
    final MCMCOptions options = new MCMCOptions(1000000);

    // Prior: Log-Normal(1, 1) on the population size (meanInRealSpace="false").
    final DistributionLikelihood logNormalPrior =
            new DistributionLikelihood(new LogNormalDistribution(1.0, 1.0), 0);
    logNormalPrior.addData(populationSize);

    final List<Likelihood> priorTerms = new ArrayList<Likelihood>();
    priorTerms.add(logNormalPrior);
    final Likelihood prior = new CompoundLikelihood(0, priorTerms);

    final List<Likelihood> likelihoodTerms = new ArrayList<Likelihood>();
    likelihoodTerms.add(dataLikelihood);
    final Likelihood likelihood = new CompoundLikelihood(-1, likelihoodTerms);

    final List<Likelihood> posteriorTerms = new ArrayList<Likelihood>();
    posteriorTerms.add(prior);
    posteriorTerms.add(likelihood);
    final Likelihood posterior = new CompoundLikelihood(0, posteriorTerms);

    mcmc.setShowOperatorAnalysis(true);
    mcmc.init(options, posterior, schedule, loggers);
    mcmc.run();

    // Elapsed time of the run.
    System.out.println(mcmc.getTimer().toString());

    // Analyse every logged trace except the one at index 0.
    final List<Trace> traces = traceFormatter.getTraces();
    final ArrayTraceList traceList = new ArrayTraceList("LognormalPriorTest", traces, 0);
    int traceIndex = 1;
    while (traceIndex < traces.size()) {
        traceList.analyseTrace(traceIndex);
        traceIndex++;
    }

    // Expected mean: e^(1.5) ~ 4.48168907.
    final double expectedMean = Math.exp(1.5);
    final TraceCorrelation populationSizeStats = traceList.getCorrelationStatistics(
            traceList.getTraceIndex(ConstantPopulationModelParser.POPULATION_SIZE));
    System.out.println("Expectation of Log-Normal(1,1) is e^(M+S^2/2) = e^(1.5) = " + expectedMean);
    assertExpectation(ConstantPopulationModelParser.POPULATION_SIZE, populationSizeStats, expectedMean);
}
Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
In class BeastCheckpointer, method writeStateToFile:
/**
 * Writes a checkpoint of the current chain state to {@code file}, replacing any existing content.
 * <p>
 * The file is tab-delimited text containing, in order:
 * <ul>
 *   <li>{@code rng} followed by the random-number-generator state;</li>
 *   <li>{@code state} and {@code lnL} lines;</li>
 *   <li>one {@code parameter} line per connected parameter (name, dimension, values);</li>
 *   <li>one {@code operator} line per scheduled operator (name, accept count, reject count and,
 *       for coercable operators, the coercable parameter);</li>
 *   <li>for every connected {@link TreeModel}: a {@code tree} header, a node table
 *       (number, height, taxon id for external nodes) and an edge table
 *       (child, parent, left/right flag, then one value per {@link TreeParameterModel}).</li>
 * </ul>
 *
 * @param file        destination file
 * @param state       current state number of the chain
 * @param lnL         current log likelihood
 * @param markovChain chain whose operator schedule is recorded
 * @return {@code true} if the checkpoint was written, {@code false} if the file could not be
 *         opened or an error occurred while writing or closing it
 * @throws RuntimeException if a tree node is neither child 0 nor child 1 of its parent
 *                          (the output file is still closed in that case)
 */
private boolean writeStateToFile(File file, long state, double lnL, MarkovChain markovChain) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();

    PrintStream out;
    try {
        out = new PrintStream(new FileOutputStream(file));
    } catch (IOException ioe) {
        System.err.println("Unable to write file: " + ioe.getMessage());
        return false;
    }

    // Close the stream even if tree serialization below throws.
    try {
        int[] rngState = MathUtils.getRandomState();
        out.print("rng");
        for (int i = 0; i < rngState.length; i++) {
            out.print("\t");
            out.print(rngState[i]);
        }
        out.println();

        out.print("state\t");
        out.println(state);
        out.print("lnL\t");
        out.println(lnL);

        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            out.print("parameter");
            out.print("\t");
            out.print(parameter.getParameterName());
            out.print("\t");
            out.print(parameter.getDimension());
            for (int dim = 0; dim < parameter.getDimension(); dim++) {
                out.print("\t");
                out.print(parameter.getParameterValue(dim));
            }
            out.println();
        }

        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            out.print("operator");
            out.print("\t");
            out.print(operator.getOperatorName());
            out.print("\t");
            out.print(operator.getAcceptCount());
            out.print("\t");
            out.print(operator.getRejectCount());
            if (operator instanceof CoercableMCMCOperator) {
                out.print("\t");
                out.print(((CoercableMCMCOperator) operator).getCoercableParameter());
            }
            out.println();
        }

        // Collect TreeParameterModels up front so their values can be appended to every edge line.
        ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }

        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (!(model instanceof TreeModel)) {
                continue;
            }
            TreeModel tree = (TreeModel) model;

            out.print("tree");
            out.print("\t");
            out.println(model.getModelName());

            // Node table (a general graph structure rather than Newick): number, height[, taxon id].
            out.println("#node height taxon");
            int nodeCount = tree.getNodeCount();
            out.println(nodeCount);
            for (int i = 0; i < nodeCount; i++) {
                NodeRef node = tree.getNode(i);
                out.print(node.getNumber());
                out.print("\t");
                out.print(tree.getNodeHeight(node));
                if (tree.isExternal(node)) {
                    out.print("\t");
                    out.print(tree.getNodeTaxon(node).getId());
                }
                out.println();
            }

            // Edge table. The count written here is nodeCount, but nodes without a parent
            // (the root) produce no line, so nodeCount - 1 edge lines follow for a rooted tree.
            // Any reader of this format must account for that.
            out.println("#edges");
            out.println("#child-node parent-node L/R-child traits");
            out.println(nodeCount);
            for (int i = 0; i < nodeCount; i++) {
                NodeRef node = tree.getNode(i);
                NodeRef parent = tree.getParent(node);
                if (parent == null) {
                    continue;
                }
                out.print(node.getNumber());
                out.print("\t");
                out.print(parent.getNumber());
                out.print("\t");
                if (tree.getChild(parent, 0) == node) {
                    // left child
                    out.print(0);
                } else if (tree.getChild(parent, 1) == node) {
                    // right child
                    out.print(1);
                } else {
                    throw new RuntimeException("Operation currently only supported for nodes with 2 children.");
                }
                for (TreeParameterModel tpm : traitModels) {
                    out.print("\t");
                    out.print(tpm.getNodeValue(tree, node));
                }
                out.println();
            }
        }
    } finally {
        out.close();
    }

    // PrintStream never throws IOException: write and close failures (e.g. disk full)
    // are only reported through checkError().
    if (out.checkError()) {
        System.err.println("Unable to write file: " + file.getPath());
        return false;
    }

    if (DEBUG) {
        for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
            System.err.println(likelihood.getId() + ": " + likelihood.getLogLikelihood());
        }
    }
    return true;
}
Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
In class ARGAddRemoveOperatorTest, method flatPriorTester:
/**
 * Runs an MCMC on {@code arg} using only priors (a uniform ARG prior, a Poisson prior on the
 * reassortment-node count and a Gamma prior on the node heights). It then checks that the
 * sampled node count and root height match the expected values.
 * <p>
 * Example settings: nodeCountSetting = 2.0, rootHeightAlpha = 100, rootHeightBeta = 0.5.
 *
 * @param arg              the ARG being sampled
 * @param chainLength      number of MCMC iterations
 * @param sampleTreeEvery  logging interval for the ARG file and the in-memory trace logger
 * @param nodeCountSetting Poisson parameter for the reassortment-node count prior
 * @param rootHeightAlpha  first Gamma parameter for the node-height prior
 * @param rootHeightBeta   second Gamma parameter for the node-height prior
 * @param maxCount         maximum reassortment count used by the uniform prior and the
 *                         truncated Poisson mean
 * @throws IOException if the temporary ARG log file cannot be created or closed
 */
private void flatPriorTester(ARGModel arg, int chainLength, int sampleTreeEvery, double nodeCountSetting, double rootHeightAlpha, double rootHeightBeta, int maxCount) throws IOException, Importer.ImportException {
    MCMC mcmc = new MCMC("mcmc1");
    MCMCOptions options = new MCMCOptions(chainLength);
    OperatorSchedule schedule = getSchedule(arg);

    ARGUniformPrior uniformPrior = new ARGUniformPrior(arg, maxCount, arg.getExternalNodeCount());

    PoissonDistribution poisson = new PoissonDistribution(nodeCountSetting);
    DistributionLikelihood nodeCountPrior = new DistributionLikelihood(poisson, 0.0);
    ARGReassortmentNodeCountStatistic nodeCountStatistic = new ARGReassortmentNodeCountStatistic("nodeCount", arg);
    nodeCountPrior.addData(nodeCountStatistic);

    DistributionLikelihood rootPrior = new DistributionLikelihood(new GammaDistribution(rootHeightAlpha, rootHeightBeta), 0.0);
    CompoundParameter rootHeight = (CompoundParameter) arg.createNodeHeightsParameter(true, false, false);
    rootPrior.addData(rootHeight);

    List<Likelihood> likelihoods = new ArrayList<Likelihood>();
    likelihoods.add(uniformPrior);
    likelihoods.add(rootPrior);
    likelihoods.add(nodeCountPrior);
    CompoundLikelihood compoundLikelihood = new CompoundLikelihood(1, likelihoods);
    compoundLikelihood.setId("likelihood1");

    MCLogger[] loggers = new MCLogger[3];
    loggers[0] = new MCLogger(new TabDelimitedFormatter(System.out), 10000, false);
    loggers[0].add(compoundLikelihood);
    loggers[0].add(arg);

    File file = new File("test.args");
    file.deleteOnExit();
    ArrayLogFormatter formatter = new ArrayLogFormatter(false);

    FileOutputStream out = new FileOutputStream(file);
    // Close the ARG log stream even if logger setup or the chain itself fails.
    try {
        loggers[1] = new ARGLogger(arg, new TabDelimitedFormatter(out), sampleTreeEvery, "test");
        loggers[2] = new MCLogger(formatter, sampleTreeEvery, false);
        loggers[2].add(arg);
        arg.getRootHeightParameter().setId("root");
        loggers[2].add(arg.getRootHeightParameter());

        mcmc.setShowOperatorAnalysis(true);
        mcmc.init(options, compoundLikelihood, schedule, loggers);
        mcmc.run();
        out.flush();
    } finally {
        out.close();
    }

    List<Trace> traces = formatter.getTraces();
    ArrayTraceList traceList = new ArrayTraceList("ARGTest", traces, 0);
    for (int i = 1; i < traces.size(); i++) {
        traceList.analyseTrace(i);
    }

    // NOTE(review): trace indices 1 (node count) and 4 (root height) are hard-coded and depend on
    // the columns that loggers[2] produces for `arg` followed by "root". Confirm if ARGModel's
    // logged columns change.
    TraceCorrelation nodeCountStats = traceList.getCorrelationStatistics(1);
    TraceCorrelation rootHeightStats = traceList.getCorrelationStatistics(4);
    assertExpectation("nodeCount", nodeCountStats, poisson.truncatedMean(maxCount));
    // Expected root height alpha * beta; presumably the Gamma mean under a shape/scale
    // parameterization. Verify against GammaDistribution.
    assertExpectation(TreeModelParser.ROOT_HEIGHT, rootHeightStats, rootHeightAlpha * rootHeightBeta);
}
Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.
In class MCMCMC, method swapChainTemperatures:
/**
 * Proposes exchanging the temperatures of two distinct, uniformly chosen chains and returns
 * the index of the cold chain after the proposal.
 * <p>
 * The swap is accepted when {@code log(U) < (s2 - s1) * t1 + (s1 - s2) * t2}, where
 * {@code s} are the chains' current scores and {@code t} their acceptors' temperatures. On
 * acceptance, the per-operator tuning state (accept/reject counts, summed deviation and, for
 * coercable operators, the coercable parameter) is exchanged as well. This keeps that state
 * with the temperature rather than with the chain.
 *
 * @return the index of the cold chain after the (possibly rejected) swap; if fewer than two
 *         chains exist, no swap is attempted and the current cold chain is returned
 */
private int swapChainTemperatures() {
    if (DEBUG) {
        System.out.print("Current scores: ");
        for (int i = 0; i < chains.length; i++) {
            System.out.print("\t");
            if (i == coldChain) {
                System.out.print("[");
            }
            System.out.print(chains[i].getCurrentScore());
            if (i == coldChain) {
                System.out.print("]");
            }
        }
        System.out.println();
    }

    // Two distinct chains are required; otherwise the rejection loop below would never terminate.
    if (chains.length < 2) {
        return coldChain;
    }

    int newColdChain = coldChain;

    int index1 = MathUtils.nextInt(chains.length);
    int index2 = MathUtils.nextInt(chains.length);
    while (index1 == index2) {
        index2 = MathUtils.nextInt(chains.length);
    }

    double score1 = chains[index1].getCurrentScore();
    MCMCCriterion acceptor1 = ((MCMCCriterion) chains[index1].getAcceptor());
    double temperature1 = acceptor1.getTemperature();

    double score2 = chains[index2].getCurrentScore();
    MCMCCriterion acceptor2 = ((MCMCCriterion) chains[index2].getAcceptor());
    double temperature2 = acceptor2.getTemperature();

    double logRatio = ((score2 - score1) * temperature1) + ((score1 - score2) * temperature2);
    boolean swap = (Math.log(MathUtils.nextDouble()) < logRatio);

    if (swap) {
        if (DEBUG) {
            System.out.println("Swapping chain " + index1 + " and chain " + index2);
        }

        acceptor1.setTemperature(temperature2);
        acceptor2.setTemperature(temperature1);

        // Assumes both schedules list equivalent operators in the same order.
        OperatorSchedule schedule1 = schedules[index1];
        OperatorSchedule schedule2 = schedules[index2];
        for (int i = 0; i < schedule1.getOperatorCount(); i++) {
            MCMCOperator operator1 = schedule1.getOperator(i);
            MCMCOperator operator2 = schedule2.getOperator(i);

            long tmp = operator1.getAcceptCount();
            operator1.setAcceptCount(operator2.getAcceptCount());
            operator2.setAcceptCount(tmp);

            tmp = operator1.getRejectCount();
            operator1.setRejectCount(operator2.getRejectCount());
            operator2.setRejectCount(tmp);

            double tmp2 = operator1.getSumDeviation();
            operator1.setSumDeviation(operator2.getSumDeviation());
            operator2.setSumDeviation(tmp2);

            if (operator1 instanceof CoercableMCMCOperator) {
                CoercableMCMCOperator coercable1 = (CoercableMCMCOperator) operator1;
                CoercableMCMCOperator coercable2 = (CoercableMCMCOperator) operator2;
                tmp2 = coercable1.getCoercableParameter();
                coercable1.setCoercableParameter(coercable2.getCoercableParameter());
                coercable2.setCoercableParameter(tmp2);
            }
        }

        // The cold role follows the temperature.
        if (index1 == coldChain) {
            newColdChain = index2;
        } else if (index2 == coldChain) {
            newColdChain = index1;
        }
    }

    return newColdChain;
}
Aggregations