Search in sources :

Example 11 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in the project beast-mcmc by beast-dev.

Class CheckPointModifier, method readStateFromFile.

/**
 * Restores the MCMC state stored in a checkpoint file and adapts the stored trees to the tree
 * models of the current analysis.
 *
 * <p>The file is consumed in this order:
 * <ol>
 *   <li>an optional {@code rng} line;</li>
 *   <li>a {@code state} line;</li>
 *   <li>an {@code lnL} line;</li>
 *   <li>one line per stored parameter, matched against {@code Parameter.CONNECTED_PARAMETER_SET}
 *       in that set's iteration order;</li>
 *   <li>one line per operator of the chain's schedule, in schedule order;</li>
 *   <li>zero or more {@code tree} sections, each handed to a {@link CheckPointTreeModifier}.</li>
 * </ol>
 *
 * <p>As side effects this method rebuilds {@code traitModels}, assigns {@code rateModel} when a
 * {@link BranchRates} model is connected, and assigns {@code modifyTree} for every restored tree.
 * The file is always closed, including when parsing fails.
 *
 * @param file        checkpoint file to read
 * @param markovChain chain whose operator schedule receives the stored accept/reject counts and
 *                    coercable parameters
 * @param lnL         if non-null, element 0 receives the stored log likelihood
 * @return the state number stored on the {@code state} line
 * @throws RuntimeException if the file cannot be read, or if its contents do not match the current
 *                          parameters, operators or tree models
 */
private long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    long state = -1;
    this.traitModels = new ArrayList<TreeParameterModel>();
    BufferedReader in = null;
    try {
        // closed in the finally block so the file handle is released even when parsing fails
        in = new BufferedReader(new FileReader(file));
        String[] fields = readStateFileFields(in);

        // optional random number generator state
        int[] rngState = null;
        if (fields[0].equals("rng")) {
            // NOTE(review): rngState is parsed (and thereby validated) but never applied in this method.
            try {
                rngState = new int[fields.length - 1];
                for (int i = 0; i < rngState.length; i++) {
                    rngState[i] = Integer.parseInt(fields[i + 1]);
                }
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read random number generator state from state file", nfe);
            }
            fields = readStateFileFields(in);
        }

        // mandatory state number
        if (!fields[0].equals("state") || fields.length < 2) {
            throw new RuntimeException("Unable to read state number from state file");
        }
        try {
            state = Long.parseLong(fields[1]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read state number from state file", nfe);
        }

        // mandatory log likelihood
        fields = readStateFileFields(in);
        if (!fields[0].equals("lnL")) {
            throw new RuntimeException("Unable to read lnL from state file");
        }
        if (lnL != null) {
            if (fields.length < 2) {
                throw new RuntimeException("Unable to read lnL from state file");
            }
            try {
                lnL[0] = Double.parseDouble(fields[1]);
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read lnL from state file", nfe);
            }
        }

        // Parameter lines. The connected set holds more parameters than the file has lines: a
        // parameter that does not match the current line is skipped and the line is kept for the
        // next parameter, so lines are consumed in the connected set's iteration order.
        fields = readStateFileFields(in);
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            if (fields.length < 2) {
                // end of file (or a blank line): no further parameter line can match
                break;
            }
            // Tree nodes have numbers (positive, possibly zero) as parameter ids, so any tree-node
            // parameter matches any tree-node line; all other parameters must match by id.
            boolean matches = (isTreeNode(parameter.getId()) && isTreeNode(fields[1]))
                    || parameter.getId().equals(fields[1]);
            if (!matches) {
                continue;
            }
            int dimension = Integer.parseInt(fields[2]);
            if (dimension != parameter.getDimension() && !fields[1].equals("branchRates.categories")) {
                System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                System.err.print("Read from file: ");
                for (String field : fields) {
                    System.err.print(field + "\t");
                }
                System.err.println();
            }
            if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                double value = Double.parseDouble(fields[3]);
                parameter.setParameterValue(0, value);
                if (DEBUG) {
                    System.out.println("restoring " + fields[1] + " with value " + value);
                }
            } else {
                if (DEBUG) {
                    System.out.print("restoring " + fields[1] + " with values ");
                }
                // branchRates.categories is exempt from the dimension check above and takes every
                // value on its line; every other parameter reads exactly getDimension() values.
                // NOTE(review): a mismatched parameter whose line holds fewer values than its
                // dimension still fails here (ArrayIndexOutOfBoundsException) after the warning above.
                int valueCount = fields[1].equals("branchRates.categories")
                        ? fields.length - 3
                        : parameter.getDimension();
                for (int dim = 0; dim < valueCount; dim++) {
                    double value = Double.parseDouble(fields[dim + 3]);
                    parameter.setParameterValue(dim, value);
                    if (DEBUG) {
                        System.out.print(value + " ");
                    }
                }
                if (DEBUG) {
                    System.out.println();
                }
            }
            fields = readStateFileFields(in);
        }

        // Operator lines, one per operator in schedule order. Field 1 is the operator name, fields
        // 2 and 3 are the accept and reject counts, and field 4 (coercable operators only) is the
        // coercable parameter.
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            if (fields.length < 2) {
                throw new RuntimeException("Unable to match operator: " + operator.getOperatorName() + " (missing from state file)");
            }
            if (!fields[1].equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match operator: " + fields[1]);
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + fields[1]);
            }
            operator.setAcceptCount(Integer.parseInt(fields[2]));
            operator.setRejectCount(Integer.parseInt(fields[3]));
            if (operator instanceof CoercableMCMCOperator) {
                if (fields.length != 5) {
                    throw new RuntimeException("Coercable operator missing parameter: " + fields[1]);
                }
                ((CoercableMCMCOperator) operator).setCoercableParameter(Double.parseDouble(fields[4]));
            }
            fields = readStateFileFields(in);
        }

        // Load the tree models last: node heights are taken from the stored tree rather than from
        // the parameters, which may not be associated with the right node.
        Set<String> expectedTreeModelNames = new HashSet<String>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                expectedTreeModelNames.add(model.getModelName());
            }
            if (model instanceof TreeParameterModel) {
                this.traitModels.add((TreeParameterModel) model);
            }
            if (model instanceof BranchRates) {
                this.rateModel = (BranchRates) model;
            }
        }
        while (fields[0].equals("tree")) {
            // capture the name now: reading the tree below overwrites fields
            String treeName = fields[1];
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (!(model instanceof TreeModel) || !treeName.equals(model.getModelName())) {
                    continue;
                }
                // AR: Can we not just add them to a Flexible tree and then make a new TreeModel
                // taking that in the constructor?
                // Internally, we have a tree with all the taxa; externally, i.e. in the checkpoint
                // file, we have a tree representation comprising a subset of the full taxa set.
                // The internal representation (the one in the connected set) is adjusted according
                // to the checkpoint file by CheckPointTreeModifier.

                // Node section: one skipped line, the node count, then one line per node with the
                // height in field 1 and, for the first (nodeCount + 1) / 2 nodes (presumably the
                // tips), the taxon name in field 2.
                in.readLine();
                fields = readStateFileFields(in);
                int nodeCount = Integer.parseInt(fields[0]);
                double[] nodeHeights = new double[nodeCount];
                String[] taxaNames = new String[(nodeCount + 1) / 2];
                for (int i = 0; i < nodeCount; i++) {
                    fields = readStateFileFields(in);
                    nodeHeights[i] = Double.parseDouble(fields[1]);
                    if (i < taxaNames.length) {
                        taxaNames[i] = fields[2];
                    }
                }

                // Edge section: two skipped lines, the edge count, then one line per edge holding
                // the node index, its parent's index, the child order and one value per trait model.
                in.readLine();
                in.readLine();
                fields = readStateFileFields(in);
                int edgeCount = Integer.parseInt(fields[0]);
                double[][] traitValues = new double[traitModels.size()][edgeCount];
                // whether a node is the left or right child of its parent; can be important for
                // certain tree transition kernels. Entries not read from the file stay -1.
                int[] childOrder = new int[edgeCount];
                int[] parents = new int[edgeCount];
                for (int i = 0; i < edgeCount; i++) {
                    childOrder[i] = -1;
                    parents[i] = -1;
                }
                for (int i = 0; i < edgeCount; i++) {
                    String edgeLine = in.readLine();
                    if (edgeLine == null) {
                        // file ended early: leave the remaining entries at their defaults
                        continue;
                    }
                    String[] edgeFields = edgeLine.split("\t");
                    parents[Integer.parseInt(edgeFields[0])] = Integer.parseInt(edgeFields[1]);
                    childOrder[i] = Integer.parseInt(edgeFields[2]);
                    for (int j = 0; j < traitModels.size(); j++) {
                        traitValues[j][i] = Double.parseDouble(edgeFields[3 + j]);
                    }
                }

                // adopt the stored tree structure, then copy the branch traits if there are any
                this.modifyTree = new CheckPointTreeModifier((TreeModel) model);
                modifyTree.adoptTreeStructure(parents, nodeHeights, childOrder, taxaNames);
                if (traitModels.size() > 0) {
                    modifyTree.adoptTraitData(parents, this.traitModels, traitValues);
                }
                expectedTreeModelNames.remove(model.getModelName());
                break;
            }
            fields = readStateFileFields(in);
        }
        if (expectedTreeModelNames.size() > 0) {
            StringBuilder sb = new StringBuilder();
            for (String notFoundName : expectedTreeModelNames) {
                sb.append("Expecting, but unable to match state parameter:" + notFoundName + "\n");
            }
            throw new RuntimeException(sb.toString());
        }
    } catch (IOException ioe) {
        throw new RuntimeException("Unable to read file: " + ioe.getMessage(), ioe);
    } finally {
        if (in != null) {
            try {
                in.close();
            } catch (IOException ignored) {
                // nothing useful to do: either the state was fully read or a more relevant
                // exception is already propagating
            }
        }
    }
    return state;
}

/**
 * Reads the next line of a checkpoint file and splits it on tabs.
 *
 * @param in reader positioned at the line to read
 * @return the tab-separated fields of the line; at end of file a single empty field, so keyword
 *         checks such as {@code fields[0].equals("tree")} fail cleanly instead of throwing a
 *         NullPointerException
 * @throws IOException if the line cannot be read
 */
private static String[] readStateFileFields(BufferedReader in) throws IOException {
    String line = in.readLine();
    return line == null ? new String[] {""} : line.split("\t");
}
Also used: TreeModel (dr.evomodel.tree.TreeModel), FileReader (java.io.FileReader), HashSet (java.util.HashSet), OperatorSchedule (dr.inference.operators.OperatorSchedule), TreeParameterModel (dr.evomodel.tree.TreeParameterModel), IOException (java.io.IOException), BufferedReader (java.io.BufferedReader), Model (dr.inference.model.Model), Parameter (dr.inference.model.Parameter), CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator), BranchRates (dr.evolution.tree.BranchRates), MCMCOperator (dr.inference.operators.MCMCOperator)

Example 12 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in the project beast-mcmc by beast-dev.

Class MCMCParser, method parseXMLObject.

/**
 * Builds an {@link MCMC} from the XML element being parsed.
 *
 * <p>Reads the chain options: chain length, auto-optimization (coercion) and its delay,
 * temperature, and full-evaluation settings. It also collects the operator schedule, the
 * likelihood and every {@link Logger} child. It then initialises the chain and rejects a starting
 * state whose posterior is zero.
 *
 * @param xo the XML element describing the chain
 * @return an mcmc object based on the XML element it was passed
 * @throws XMLParseException propagated from the attribute accessors when an attribute cannot be read
 * @throws IllegalArgumentException if the initial posterior is zero
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    MCMC mcmc = new MCMC(xo.getAttribute(NAME, "mcmc1"));
    long chainLength = xo.getLongIntegerAttribute(CHAIN_LENGTH);
    boolean useCoercion = xo.getAttribute(COERCION, true);

    // Auto-optimization delay: 1% of the chain unless PRE_BURNIN is given; the result is then the
    // default for COERCION_DELAY. PRE_BURNIN is read as a long, like the chain length, so that
    // delays beyond the int range are accepted.
    long coercionDelay = chainLength / 100;
    if (xo.hasAttribute(PRE_BURNIN)) {
        coercionDelay = xo.getLongIntegerAttribute(PRE_BURNIN);
    }
    coercionDelay = xo.getAttribute(COERCION_DELAY, coercionDelay);

    double temperature = xo.getAttribute(TEMPERATURE, 1.0);
    long fullEvaluationCount = xo.getAttribute(FULL_EVALUATION, 2000);

    // evaluation threshold: built-in default, then the system property, then the XML attribute
    double evaluationTestThreshold = MarkovChain.EVALUATION_TEST_THRESHOLD;
    String thresholdProperty = System.getProperty("mcmc.evaluation.threshold");
    if (thresholdProperty != null) {
        evaluationTestThreshold = Double.parseDouble(thresholdProperty);
    }
    evaluationTestThreshold = xo.getAttribute(EVALUATION_THRESHOLD, evaluationTestThreshold);

    int minOperatorCountForFullEvaluation = xo.getAttribute(MIN_OPS_EVALUATIONS, 1);
    MCMCOptions options = new MCMCOptions(chainLength, fullEvaluationCount, minOperatorCountForFullEvaluation, evaluationTestThreshold, useCoercion, coercionDelay, temperature);

    OperatorSchedule opsched = (OperatorSchedule) xo.getChild(OperatorSchedule.class);
    Likelihood likelihood = (Likelihood) xo.getChild(Likelihood.class);
    // mark before the warning scan below so the chain's own likelihood is not reported as unused
    likelihood.setUsed();

    java.util.logging.Logger inferenceLogger = java.util.logging.Logger.getLogger("dr.inference");
    if (Boolean.parseBoolean(System.getProperty("show_warnings", "false"))) {
        // check that all models, parameters and likelihoods are being used
        for (Likelihood l : Likelihood.FULL_LIKELIHOOD_SET) {
            if (!l.isUsed()) {
                inferenceLogger.warning("Likelihood, " + l.getId() + ", of class " + l.getClass().getName() + " is not being handled by the MCMC.");
            }
        }
        for (Model m : Model.FULL_MODEL_SET) {
            if (!m.isUsed()) {
                inferenceLogger.warning("Model, " + m.getId() + ", of class " + m.getClass().getName() + " is not being handled by the MCMC.");
            }
        }
        for (Parameter p : Parameter.FULL_PARAMETER_SET) {
            if (!p.isUsed()) {
                inferenceLogger.warning("Parameter, " + p.getId() + ", of class " + p.getClass().getName() + " is not being handled by the MCMC.");
            }
        }
    }

    ArrayList<Logger> loggers = new ArrayList<Logger>();
    for (int i = 0; i < xo.getChildCount(); i++) {
        Object child = xo.getChild(i);
        if (child instanceof Logger) {
            loggers.add((Logger) child);
        }
    }

    mcmc.setShowOperatorAnalysis(true);
    if (xo.hasAttribute(OPERATOR_ANALYSIS)) {
        mcmc.setOperatorAnalysisFile(XMLParser.getLogFile(xo, OPERATOR_ANALYSIS));
    }

    Logger[] loggerArray = loggers.toArray(new Logger[loggers.size()]);
    inferenceLogger.info("\nCreating the MCMC chain:" + "\n  chainLength=" + options.getChainLength() + "\n  autoOptimize=" + options.useCoercion() + (options.useCoercion() ? "\n  autoOptimize delayed for " + options.getCoercionDelay() + " steps" : "") + (options.getFullEvaluationCount() == 0 ? "\n  full evaluation test off" : ""));
    mcmc.init(options, likelihood, opsched, loggerArray);

    MarkovChain mc = mcmc.getMarkovChain();
    double initialScore = mc.getCurrentScore();
    if (initialScore == Double.NEGATIVE_INFINITY) {
        String message = "The initial posterior is zero";
        if (likelihood instanceof CompoundLikelihood) {
            message += ": " + ((CompoundLikelihood) likelihood).getDiagnosis(2);
        } else {
            message += "!";
        }
        throw new IllegalArgumentException(message);
    }
    if (!xo.getAttribute(SPAWN, true)) {
        mcmc.setSpawnable(false);
    }
    return mcmc;
}
Also used: OperatorSchedule (dr.inference.operators.OperatorSchedule), CompoundLikelihood (dr.inference.model.CompoundLikelihood), Likelihood (dr.inference.model.Likelihood), MCMC (dr.inference.mcmc.MCMC), ArrayList (java.util.ArrayList), MarkovChain (dr.inference.markovchain.MarkovChain), Logger (dr.inference.loggers.Logger), MCMCOptions (dr.inference.mcmc.MCMCOptions), Model (dr.inference.model.Model), Parameter (dr.inference.model.Parameter)

Example 13 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in the project beast-mcmc by beast-dev.

Class NarrowExchangeTest, method getOperatorSchedule.

/**
 * Builds the operator schedule for the narrow-exchange test.
 *
 * <p>The schedule contains, in order:
 * <ul>
 *   <li>a narrow {@link ExchangeOperator} on the tree;</li>
 *   <li>a {@link ScaleOperator} (0.75, coercion on) on the root-height parameter;</li>
 *   <li>a {@link UniformOperator} on the internal-node-height parameter.</li>
 * </ul>
 * All three receive 1.0 as their last constructor argument (presumably the operator weight).
 *
 * @param treeModel tree whose node heights the operators act on
 * @return a new schedule holding the three operators
 */
public OperatorSchedule getOperatorSchedule(TreeModel treeModel) {
    final Parameter rootHeight = treeModel.createNodeHeightsParameter(true, false, false);
    final Parameter internalNodeHeights = treeModel.createNodeHeightsParameter(false, true, false);

    final ExchangeOperator narrowExchange = new ExchangeOperator(ExchangeOperator.NARROW, treeModel, 1.0);
    final ScaleOperator rootScaler = new ScaleOperator(rootHeight, 0.75, CoercionMode.COERCION_ON, 1.0);
    final UniformOperator internalHeightSampler = new UniformOperator(internalNodeHeights, 1.0);

    final OperatorSchedule result = new SimpleOperatorSchedule();
    result.addOperator(narrowExchange);
    result.addOperator(rootScaler);
    result.addOperator(internalHeightSampler);
    return result;
}
Also used: SimpleOperatorSchedule (dr.inference.operators.SimpleOperatorSchedule), OperatorSchedule (dr.inference.operators.OperatorSchedule), ExchangeOperator (dr.evomodel.operators.ExchangeOperator), Parameter (dr.inference.model.Parameter), UniformOperator (dr.inference.operators.UniformOperator), ScaleOperator (dr.inference.operators.ScaleOperator)

Example 14 with OperatorSchedule

Use of dr.inference.operators.OperatorSchedule in the project beast-mcmc by beast-dev.

Class RLYModelTest, method testTreeBitRandomWalk.

/**
 * Runs the random-local-Yule tester with a two-operator schedule.
 *
 * <p>The schedule contains a {@link TreeBitRandomWalkOperator} over the birth-rate indicator and
 * birth-rate node traits, followed by a {@link BitFlipOperator} on the indicator parameter.
 */
public void testTreeBitRandomWalk() {
    final TreeModel treeModel = new TreeModel("treeModel", tree);
    // node-trait parameters for the indicator and the rate, each created with initial value 1
    final Parameter indicators = treeModel.createNodeTraitsParameter(birthRateIndicator, new double[] { 1.0 });
    final Parameter rates = treeModel.createNodeTraitsParameter(birthRate, new double[] { 1.0 });

    final OperatorSchedule schedule = new SimpleOperatorSchedule();
    final TreeBitRandomWalkOperator randomWalk = new TreeBitRandomWalkOperator(treeModel, birthRateIndicator, birthRate, 1.0, 4, true);
    final BitFlipOperator bitFlip = new BitFlipOperator(indicators, 1.0, true);
    schedule.addOperator(randomWalk);
    schedule.addOperator(bitFlip);

    randomLocalYuleTester(treeModel, indicators, rates, schedule);
}
Also used: TreeModel (dr.evomodel.tree.TreeModel), SimpleOperatorSchedule (dr.inference.operators.SimpleOperatorSchedule), OperatorSchedule (dr.inference.operators.OperatorSchedule), BitFlipOperator (dr.inference.operators.BitFlipOperator), Parameter (dr.inference.model.Parameter), TreeBitRandomWalkOperator (dr.evomodel.operators.TreeBitRandomWalkOperator)

Aggregations

OperatorSchedule (dr.inference.operators.OperatorSchedule): 14, Parameter (dr.inference.model.Parameter): 9, SimpleOperatorSchedule (dr.inference.operators.SimpleOperatorSchedule): 7, TreeModel (dr.evomodel.tree.TreeModel): 6, Likelihood (dr.inference.model.Likelihood): 6, MCMCOperator (dr.inference.operators.MCMCOperator): 6, MCMC (dr.inference.mcmc.MCMC): 4, MCMCOptions (dr.inference.mcmc.MCMCOptions): 4, Model (dr.inference.model.Model): 4, CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator): 4, ScaleOperator (dr.inference.operators.ScaleOperator): 4, ArrayList (java.util.ArrayList): 4, TreeParameterModel (dr.evomodel.tree.TreeParameterModel): 3, MCLogger (dr.inference.loggers.MCLogger): 3, TabDelimitedFormatter (dr.inference.loggers.TabDelimitedFormatter): 3, CompoundLikelihood (dr.inference.model.CompoundLikelihood): 3, DistributionLikelihood (dr.inference.distribution.DistributionLikelihood): 2, ArrayLogFormatter (dr.inference.loggers.ArrayLogFormatter): 2, Logger (dr.inference.loggers.Logger): 2, CompoundParameter (dr.inference.model.CompoundParameter): 2