Search in sources:

Example 6 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.

From class MLOptimizerParser, method parseXMLObject.

/**
 * Builds an {@code MLOptimizer} from an optimizer XML element.
 *
 * <p>The chain length is read from the {@code CHAIN_LENGTH} attribute. Every child element must
 * be a likelihood, an operator schedule or a logger. If several likelihoods or schedules are
 * present, the last one seen is used. If none is present, {@code null} is passed on for it.
 *
 * @param xo the parsed optimizer element
 * @return a new optimizer named "optimizer1"
 * @throws XMLParseException if a child element is of any other type
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final int chainLength = xo.getIntegerAttribute(CHAIN_LENGTH);
    OperatorSchedule schedule = null;
    dr.inference.model.Likelihood targetLikelihood = null;
    final ArrayList<Logger> loggerList = new ArrayList<Logger>();

    final int childCount = xo.getChildCount();
    for (int c = 0; c < childCount; c++) {
        final Object element = xo.getChild(c);
        // Test order matters: an object implementing several of these types is classified
        // by the first matching check.
        if (element instanceof dr.inference.model.Likelihood) {
            targetLikelihood = (dr.inference.model.Likelihood) element;
            continue;
        }
        if (element instanceof OperatorSchedule) {
            schedule = (OperatorSchedule) element;
            continue;
        }
        if (element instanceof Logger) {
            loggerList.add((Logger) element);
            continue;
        }
        throw new XMLParseException("Unrecognized element found in optimizer element:" + element);
    }

    final Logger[] loggerArray = loggerList.toArray(new Logger[loggerList.size()]);
    return new MLOptimizer("optimizer1", chainLength, targetLikelihood, schedule, loggerArray);
}
Also used: Likelihood(dr.inference.model.Likelihood) OperatorSchedule(dr.inference.operators.OperatorSchedule) ArrayList(java.util.ArrayList) Logger(dr.inference.loggers.Logger) MLOptimizer(dr.inference.ml.MLOptimizer)

Example 7 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.

From class LognormalPriorTest, method testLognormalPrior.

/**
 * Samples a population-size parameter under a Log-Normal(1, 1) prior and checks that the mean
 * of the sampled values matches the analytic Log-Normal mean e^(M + S^2/2) = e^(1.5).
 *
 * <p>The data likelihood is a {@code DummyLikelihood} wrapping a constant-population model.
 * NOTE(review): presumably it contributes nothing, so the posterior equals the prior; confirm.
 * The only operator is a scale operator on the population size.
 */
public void testLognormalPrior() {
    //        ConstantPopulation constant = new ConstantPopulation(Units.Type.YEARS);
    //        constant.setN0(popSize); // popSize
    Parameter popSize = new Parameter.Default(6.0);
    // The same constant is used at the end of the test to look up this parameter's trace.
    popSize.setId(ConstantPopulationModelParser.POPULATION_SIZE);
    ConstantPopulationModel demo = new ConstantPopulationModel(popSize, Units.Type.YEARS);
    //Likelihood
    Likelihood dummyLikelihood = new DummyLikelihood(demo);
    // Operators: a single scale operator with weight 1.0 on popSize
    // (0.75 is presumably the scale factor).
    OperatorSchedule schedule = new SimpleOperatorSchedule();
    MCMCOperator operator = new ScaleOperator(popSize, 0.75);
    operator.setWeight(1.0);
    schedule.addOperator(operator);
    // Log: loggers[0] records popSize into the in-memory formatter analysed below.
    // loggers[1] echoes popSize to stdout.
    // (1000 / 100000 are presumably the logging intervals in steps.)
    ArrayLogFormatter formatter = new ArrayLogFormatter(false);
    MCLogger[] loggers = new MCLogger[2];
    loggers[0] = new MCLogger(formatter, 1000, false);
    //        loggers[0].add(treeLikelihood);
    loggers[0].add(popSize);
    loggers[1] = new MCLogger(new TabDelimitedFormatter(System.out), 100000, false);
    //        loggers[1].add(treeLikelihood);
    loggers[1].add(popSize);
    // MCMC (1000000 is presumably the chain length)
    MCMC mcmc = new MCMC("mcmc1");
    MCMCOptions options = new MCMCOptions(1000000);
    // Log-Normal(M = 1, S = 1) prior on popSize, with M and S given in log space.
    // meanInRealSpace="false"
    DistributionLikelihood logNormalLikelihood = new DistributionLikelihood(new LogNormalDistribution(1.0, 1.0), 0);
    logNormalLikelihood.addData(popSize);
    // One list is reused to build prior, likelihood and posterior. Clearing it between
    // constructions assumes CompoundLikelihood copies its input list (NOTE(review): confirm).
    List<Likelihood> likelihoods = new ArrayList<Likelihood>();
    likelihoods.add(logNormalLikelihood);
    Likelihood prior = new CompoundLikelihood(0, likelihoods);
    likelihoods.clear();
    likelihoods.add(dummyLikelihood);
    Likelihood likelihood = new CompoundLikelihood(-1, likelihoods);
    likelihoods.clear();
    likelihoods.add(prior);
    likelihoods.add(likelihood);
    // The posterior combines the prior and the (dummy) likelihood.
    Likelihood posterior = new CompoundLikelihood(0, likelihoods);
    mcmc.setShowOperatorAnalysis(true);
    mcmc.init(options, posterior, schedule, loggers);
    mcmc.run();
    // time
    System.out.println(mcmc.getTimer().toString());
    // Tracer: analyse every logged trace except index 0
    // (NOTE(review): index 0 is presumably the state column; confirm).
    List<Trace> traces = formatter.getTraces();
    ArrayTraceList traceList = new ArrayTraceList("LognormalPriorTest", traces, 0);
    for (int i = 1; i < traces.size(); i++) {
        traceList.analyseTrace(i);
    }
    // Expected mean e^(1.5) is approximately 4.48168907:
    //      <expectation name="param" value="4.48168907"/>
    TraceCorrelation popSizeStats = traceList.getCorrelationStatistics(traceList.getTraceIndex(ConstantPopulationModelParser.POPULATION_SIZE));
    System.out.println("Expectation of Log-Normal(1,1) is e^(M+S^2/2) = e^(1.5) = " + Math.exp(1.5));
    assertExpectation(ConstantPopulationModelParser.POPULATION_SIZE, popSizeStats, Math.exp(1.5));
}
Also used: CompoundLikelihood(dr.inference.model.CompoundLikelihood) Likelihood(dr.inference.model.Likelihood) DistributionLikelihood(dr.inference.distribution.DistributionLikelihood) DummyLikelihood(dr.inference.model.DummyLikelihood) MCMC(dr.inference.mcmc.MCMC) ArrayList(java.util.ArrayList) LogNormalDistribution(dr.math.distributions.LogNormalDistribution) MCMCOptions(dr.inference.mcmc.MCMCOptions) ArrayLogFormatter(dr.inference.loggers.ArrayLogFormatter) TraceCorrelation(dr.inference.trace.TraceCorrelation) ConstantPopulationModel(dr.evomodel.coalescent.ConstantPopulationModel) OperatorSchedule(dr.inference.operators.OperatorSchedule) SimpleOperatorSchedule(dr.inference.operators.SimpleOperatorSchedule) TabDelimitedFormatter(dr.inference.loggers.TabDelimitedFormatter) Trace(dr.inference.trace.Trace) ArrayTraceList(dr.inference.trace.ArrayTraceList) Parameter(dr.inference.model.Parameter) ScaleOperator(dr.inference.operators.ScaleOperator) MCMCOperator(dr.inference.operators.MCMCOperator) MCLogger(dr.inference.loggers.MCLogger)

Example 8 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.

From class BeastCheckpointer, method writeStateToFile.

/**
 * Writes a plain-text checkpoint of the current MCMC state to {@code file}.
 *
 * <p>The file contains, in order:
 * <ul>
 *   <li>the random-number-generator state;</li>
 *   <li>the chain state;</li>
 *   <li>the log likelihood;</li>
 *   <li>every parameter in {@code Parameter.CONNECTED_PARAMETER_SET};</li>
 *   <li>the accept/reject counts (plus the coercable parameter, where applicable) of every
 *       operator in the chain's schedule;</li>
 *   <li>for every {@code TreeModel} in {@code Model.CONNECTED_MODEL_SET}, its node table and
 *       edge table, with the values of all {@code TreeParameterModel}s appended to each edge.</li>
 * </ul>
 *
 * @param file        destination file (created, or truncated if it exists)
 * @param state       current chain state, written on the "state" line
 * @param lnL         log likelihood, written on the "lnL" line
 * @param markovChain chain whose operator schedule is checkpointed
 * @return {@code true} if the whole file was written and closed successfully; {@code false}
 *         if it could not be opened or any write/close failed (a message goes to System.err)
 * @throws RuntimeException if a tree node is neither the first nor the second child of its parent
 */
private boolean writeStateToFile(File file, long state, double lnL, MarkovChain markovChain) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    PrintStream out;
    try {
        out = new PrintStream(new FileOutputStream(file));
    } catch (IOException ioe) {
        System.err.println("Unable to write file: " + ioe.getMessage());
        return false;
    }
    try {
        writeCheckpointHeader(out, state, lnL);
        writeCheckpointParameters(out);
        writeCheckpointOperators(out, operatorSchedule);
        // Collect the TreeParameterModels up front: their values are appended to every edge
        // line of every tree written below.
        ArrayList<TreeParameterModel> traitModels = findTreeParameterModels();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                writeCheckpointTree(out, (TreeModel) model, traitModels);
            }
        }
    } finally {
        // Always release the file, even if a tree with a non-binary node aborts the write.
        out.close();
    }
    // PrintStream never throws IOException. Write, flush and close failures (e.g. a full disk)
    // only set an internal error flag, reported by checkError(). Without this check a
    // truncated checkpoint would be reported as a success.
    if (out.checkError()) {
        System.err.println("Unable to write file: " + file.getPath());
        return false;
    }
    if (DEBUG) {
        for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
            System.err.println(likelihood.getId() + ": " + likelihood.getLogLikelihood());
        }
    }
    return true;
}

/** Writes the "rng", "state" and "lnL" lines of a checkpoint. */
private void writeCheckpointHeader(PrintStream out, long state, double lnL) {
    int[] rngState = MathUtils.getRandomState();
    out.print("rng");
    for (int value : rngState) {
        out.print("\t");
        out.print(value);
    }
    out.println();
    out.print("state\t");
    out.println(state);
    out.print("lnL\t");
    out.println(lnL);
}

/** Writes one "parameter" line (name, dimension, values) per connected parameter. */
private void writeCheckpointParameters(PrintStream out) {
    for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
        out.print("parameter");
        out.print("\t");
        out.print(parameter.getParameterName());
        out.print("\t");
        out.print(parameter.getDimension());
        for (int dim = 0; dim < parameter.getDimension(); dim++) {
            out.print("\t");
            out.print(parameter.getParameterValue(dim));
        }
        out.println();
    }
}

/**
 * Writes one "operator" line per scheduled operator: name, accept count and reject count,
 * followed by the coercable parameter for coercable operators.
 */
private void writeCheckpointOperators(PrintStream out, OperatorSchedule operatorSchedule) {
    for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
        MCMCOperator operator = operatorSchedule.getOperator(i);
        out.print("operator");
        out.print("\t");
        out.print(operator.getOperatorName());
        out.print("\t");
        out.print(operator.getAcceptCount());
        out.print("\t");
        out.print(operator.getRejectCount());
        if (operator instanceof CoercableMCMCOperator) {
            out.print("\t");
            out.print(((CoercableMCMCOperator) operator).getCoercableParameter());
        }
        out.println();
    }
}

/** Returns every TreeParameterModel in the connected model set, in iteration order. */
private ArrayList<TreeParameterModel> findTreeParameterModels() {
    ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
    for (Model model : Model.CONNECTED_MODEL_SET) {
        if (model instanceof TreeParameterModel) {
            traitModels.add((TreeParameterModel) model);
        }
    }
    return traitModels;
}

/**
 * Writes one tree as a general graph (instead of Newick): a node table (number, height, and
 * taxon id for external nodes) followed by an edge table (child, parent, left/right position,
 * then one value per trait model).
 *
 * @throws RuntimeException if a node is neither child 0 nor child 1 of its parent
 */
private void writeCheckpointTree(PrintStream out, TreeModel tree, ArrayList<TreeParameterModel> traitModels) {
    out.print("tree");
    out.print("\t");
    out.println(tree.getModelName());
    out.println("#node height taxon");
    int nodeCount = tree.getNodeCount();
    out.println(nodeCount);
    for (int i = 0; i < nodeCount; i++) {
        NodeRef node = tree.getNode(i);
        out.print(node.getNumber());
        out.print("\t");
        out.print(tree.getNodeHeight(node));
        if (tree.isExternal(node)) {
            out.print("\t");
            out.print(tree.getNodeTaxon(node).getId());
        }
        out.println();
    }
    out.println("#edges");
    out.println("#child-node parent-node L/R-child traits");
    out.println(nodeCount);
    for (int i = 0; i < nodeCount; i++) {
        NodeRef node = tree.getNode(i);
        NodeRef parent = tree.getParent(node);
        if (parent == null) {
            // A node without a parent has no incoming edge to record.
            continue;
        }
        out.print(node.getNumber());
        out.print("\t");
        out.print(parent.getNumber());
        out.print("\t");
        if (tree.getChild(parent, 0) == node) {
            //left child
            out.print(0);
        } else if (tree.getChild(parent, 1) == node) {
            //right child
            out.print(1);
        } else {
            throw new RuntimeException("Operation currently only supported for nodes with 2 children.");
        }
        for (TreeParameterModel tpm : traitModels) {
            out.print("\t");
            out.print(tpm.getNodeValue(tree, node));
        }
        out.println();
    }
}
Also used: OperatorSchedule(dr.inference.operators.OperatorSchedule) Likelihood(dr.inference.model.Likelihood) TreeParameterModel(dr.evomodel.tree.TreeParameterModel) TreeModel(dr.evomodel.tree.TreeModel) NodeRef(dr.evolution.tree.NodeRef) Model(dr.inference.model.Model) Parameter(dr.inference.model.Parameter) CoercableMCMCOperator(dr.inference.operators.CoercableMCMCOperator) MCMCOperator(dr.inference.operators.MCMCOperator)

Example 9 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.

From class ARGAddRemoveOperatorTest, method flatPriorTester.

/**
 * Runs an MCMC chain over {@code arg} and checks the sampled means against their analytic
 * expectations.
 *
 * <p>The target is a flat ARG prior combined with two further priors:
 * <ul>
 *   <li>a Poisson prior on the reassortment-node count;</li>
 *   <li>a gamma prior on the node-heights parameter created from {@code arg} (presumably the
 *       root height only, given the flags passed; confirm).</li>
 * </ul>
 *
 * @param arg              the ancestral recombination graph to sample
 * @param chainLength      value passed to {@code MCMCOptions} (presumably the number of steps)
 * @param sampleTreeEvery  logging interval for the ARG file logger and the in-memory logger
 * @param nodeCountSetting mean of the Poisson prior on the reassortment-node count
 * @param rootHeightAlpha  first gamma parameter; the test expects a root-height mean of alpha * beta
 * @param rootHeightBeta   second gamma parameter
 * @param maxCount         maximum count passed to {@code ARGUniformPrior}; also truncates the
 *                         expected Poisson mean
 * @throws IOException if the temporary ARG log file cannot be created, written or closed
 */
private void flatPriorTester(ARGModel arg, int chainLength, int sampleTreeEvery, double nodeCountSetting, double rootHeightAlpha, double rootHeightBeta, int maxCount) throws IOException, Importer.ImportException {
    MCMC mcmc = new MCMC("mcmc1");
    MCMCOptions options = new MCMCOptions(chainLength);
    OperatorSchedule schedule = getSchedule(arg);
    ARGUniformPrior uniformPrior = new ARGUniformPrior(arg, maxCount, arg.getExternalNodeCount());
    PoissonDistribution poisson = new PoissonDistribution(nodeCountSetting);
    DistributionLikelihood nodeCountPrior = new DistributionLikelihood(poisson, 0.0);
    ARGReassortmentNodeCountStatistic nodeCountStatistic = new ARGReassortmentNodeCountStatistic("nodeCount", arg);
    nodeCountPrior.addData(nodeCountStatistic);
    DistributionLikelihood rootPrior = new DistributionLikelihood(new GammaDistribution(rootHeightAlpha, rootHeightBeta), 0.0);
    CompoundParameter rootHeight = (CompoundParameter) arg.createNodeHeightsParameter(true, false, false);
    rootPrior.addData(rootHeight);
    List<Likelihood> likelihoods = new ArrayList<Likelihood>();
    likelihoods.add(uniformPrior);
    likelihoods.add(rootPrior);
    likelihoods.add(nodeCountPrior);
    CompoundLikelihood compoundLikelihood = new CompoundLikelihood(1, likelihoods);
    compoundLikelihood.setId("likelihood1");
    MCLogger[] loggers = new MCLogger[3];
    loggers[0] = new MCLogger(new TabDelimitedFormatter(System.out), 10000, false);
    loggers[0].add(compoundLikelihood);
    loggers[0].add(arg);
    // Use a unique temporary file rather than a fixed "test.args" in the working directory,
    // so concurrent or repeated runs cannot collide or leave stray files behind.
    File file = File.createTempFile("test", ".args");
    file.deleteOnExit();
    ArrayLogFormatter formatter = new ArrayLogFormatter(false);
    // try-with-resources closes the ARG log stream even if the chain throws.
    try (FileOutputStream out = new FileOutputStream(file)) {
        loggers[1] = new ARGLogger(arg, new TabDelimitedFormatter(out), sampleTreeEvery, "test");
        loggers[2] = new MCLogger(formatter, sampleTreeEvery, false);
        loggers[2].add(arg);
        arg.getRootHeightParameter().setId("root");
        loggers[2].add(arg.getRootHeightParameter());
        mcmc.setShowOperatorAnalysis(true);
        mcmc.init(options, compoundLikelihood, schedule, loggers);
        mcmc.run();
        out.flush();
    }
    List<Trace> traces = formatter.getTraces();
    ArrayTraceList traceList = new ArrayTraceList("ARGTest", traces, 0);
    // Trace 0 is skipped (NOTE(review): presumably the state column; confirm).
    for (int i = 1; i < traces.size(); i++) {
        traceList.analyseTrace(i);
    }
    // NOTE(review): the hard-coded trace indices depend on which columns ARGModel contributes
    // to loggers[2] ahead of the "root" parameter; confirm if ARGModel's logging changes.
    TraceCorrelation nodeCountStats = traceList.getCorrelationStatistics(1);
    TraceCorrelation rootHeightStats = traceList.getCorrelationStatistics(4);
    assertExpectation("nodeCount", nodeCountStats, poisson.truncatedMean(maxCount));
    assertExpectation(TreeModelParser.ROOT_HEIGHT, rootHeightStats, rootHeightAlpha * rootHeightBeta);
}
Also used: PoissonDistribution(dr.math.distributions.PoissonDistribution) CompoundLikelihood(dr.inference.model.CompoundLikelihood) Likelihood(dr.inference.model.Likelihood) DistributionLikelihood(dr.inference.distribution.DistributionLikelihood) MCMC(dr.inference.mcmc.MCMC) ArrayList(java.util.ArrayList) ARGUniformPrior(dr.evomodel.arg.coalescent.ARGUniformPrior) CompoundParameter(dr.inference.model.CompoundParameter) MCMCOptions(dr.inference.mcmc.MCMCOptions) ArrayLogFormatter(dr.inference.loggers.ArrayLogFormatter) ARGReassortmentNodeCountStatistic(dr.evomodel.arg.ARGReassortmentNodeCountStatistic) GammaDistribution(dr.math.distributions.GammaDistribution) TraceCorrelation(dr.inference.trace.TraceCorrelation) OperatorSchedule(dr.inference.operators.OperatorSchedule) SimpleOperatorSchedule(dr.inference.operators.SimpleOperatorSchedule) TabDelimitedFormatter(dr.inference.loggers.TabDelimitedFormatter) Trace(dr.inference.trace.Trace) ARGLogger(dr.evomodel.arg.ARGLogger) ArrayTraceList(dr.inference.trace.ArrayTraceList) FileOutputStream(java.io.FileOutputStream) File(java.io.File) MCLogger(dr.inference.loggers.MCLogger)

Example 10 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in project beast-mcmc by beast-dev.

From class MCMCMC, method swapChainTemperatures.

/**
 * Proposes exchanging the temperatures of two distinct, uniformly chosen chains and accepts
 * the exchange with a Metropolis test.
 *
 * <p>On acceptance the two acceptors swap temperatures. Operator by operator, the two schedules
 * also swap accept/reject counts, summed deviations and, for coercable operators, the coercable
 * parameter, presumably so that operator tuning stays with the temperature it was adapted under.
 * The schedules are assumed to list corresponding operators in the same order.
 *
 * @return index of the cold chain after the (possible) swap; unchanged if there are fewer than
 *         two chains or the swap is rejected
 */
private int swapChainTemperatures() {
    if (DEBUG) {
        System.out.print("Current scores: ");
        for (int i = 0; i < chains.length; i++) {
            System.out.print("\t");
            if (i == coldChain) {
                System.out.print("[");
            }
            System.out.print(chains[i].getCurrentScore());
            if (i == coldChain) {
                System.out.print("]");
            }
        }
        System.out.println();
    }
    // With fewer than two chains there is no pair to swap. Without this guard, the loop that
    // re-draws index2 until it differs from index1 could never terminate.
    if (chains.length < 2) {
        return coldChain;
    }
    int newColdChain = coldChain;
    int index1 = MathUtils.nextInt(chains.length);
    int index2 = MathUtils.nextInt(chains.length);
    while (index1 == index2) {
        index2 = MathUtils.nextInt(chains.length);
    }
    double score1 = chains[index1].getCurrentScore();
    MCMCCriterion acceptor1 = ((MCMCCriterion) chains[index1].getAcceptor());
    double temperature1 = acceptor1.getTemperature();
    double score2 = chains[index2].getCurrentScore();
    MCMCCriterion acceptor2 = ((MCMCCriterion) chains[index2].getAcceptor());
    double temperature2 = acceptor2.getTemperature();
    // Equals (score2 - score1) * (temperature1 - temperature2): the parallel-tempering swap ratio
    // when the temperature multiplies the score (NOTE(review): confirm in MCMCCriterion).
    double logRatio = ((score2 - score1) * temperature1) + ((score1 - score2) * temperature2);
    boolean swap = (Math.log(MathUtils.nextDouble()) < logRatio);
    if (swap) {
        if (DEBUG) {
            System.out.println("Swapping chain " + index1 + " and chain " + index2);
        }
        acceptor1.setTemperature(temperature2);
        acceptor2.setTemperature(temperature1);
        OperatorSchedule schedule1 = schedules[index1];
        OperatorSchedule schedule2 = schedules[index2];
        for (int i = 0; i < schedule1.getOperatorCount(); i++) {
            MCMCOperator operator1 = schedule1.getOperator(i);
            MCMCOperator operator2 = schedule2.getOperator(i);
            long tmp = operator1.getAcceptCount();
            operator1.setAcceptCount(operator2.getAcceptCount());
            operator2.setAcceptCount(tmp);
            tmp = operator1.getRejectCount();
            operator1.setRejectCount(operator2.getRejectCount());
            operator2.setRejectCount(tmp);
            double tmp2 = operator1.getSumDeviation();
            operator1.setSumDeviation(operator2.getSumDeviation());
            operator2.setSumDeviation(tmp2);
            if (operator1 instanceof CoercableMCMCOperator) {
                // operator2 is assumed to be the same operator type at the same position.
                tmp2 = ((CoercableMCMCOperator) operator1).getCoercableParameter();
                ((CoercableMCMCOperator) operator1).setCoercableParameter(((CoercableMCMCOperator) operator2).getCoercableParameter());
                ((CoercableMCMCOperator) operator2).setCoercableParameter(tmp2);
            }
        }
        // The cold chain is the one now holding the cold chain's former temperature.
        if (index1 == coldChain) {
            newColdChain = index2;
        } else if (index2 == coldChain) {
            newColdChain = index1;
        }
    }
    return newColdChain;
}
Also used: OperatorSchedule(dr.inference.operators.OperatorSchedule) MCMCCriterion(dr.inference.mcmc.MCMCCriterion) CoercableMCMCOperator(dr.inference.operators.CoercableMCMCOperator) MCMCOperator(dr.inference.operators.MCMCOperator)

Aggregations

OperatorSchedule (dr.inference.operators.OperatorSchedule)14 Parameter (dr.inference.model.Parameter)9 SimpleOperatorSchedule (dr.inference.operators.SimpleOperatorSchedule)7 TreeModel (dr.evomodel.tree.TreeModel)6 Likelihood (dr.inference.model.Likelihood)6 MCMCOperator (dr.inference.operators.MCMCOperator)6 MCMC (dr.inference.mcmc.MCMC)4 MCMCOptions (dr.inference.mcmc.MCMCOptions)4 Model (dr.inference.model.Model)4 CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator)4 ScaleOperator (dr.inference.operators.ScaleOperator)4 ArrayList (java.util.ArrayList)4 TreeParameterModel (dr.evomodel.tree.TreeParameterModel)3 MCLogger (dr.inference.loggers.MCLogger)3 TabDelimitedFormatter (dr.inference.loggers.TabDelimitedFormatter)3 CompoundLikelihood (dr.inference.model.CompoundLikelihood)3 DistributionLikelihood (dr.inference.distribution.DistributionLikelihood)2 ArrayLogFormatter (dr.inference.loggers.ArrayLogFormatter)2 Logger (dr.inference.loggers.Logger)2 CompoundParameter (dr.inference.model.CompoundParameter)2