Search in sources:

Example 1 with CoercableMCMCOperator

use of dr.inference.operators.CoercableMCMCOperator in project beast-mcmc by beast-dev.

Class BeastCheckpointer, method readStateFromFile.

/**
 * Restores the MCMC state from a checkpoint file in the format written by {@code writeStateToFile}.
 * <p>
 * The file is read line by line, tab separated: an optional {@code rng} line, the {@code state}
 * and {@code lnL} lines, one line per parameter in {@code Parameter.CONNECTED_PARAMETER_SET},
 * one line per operator of the chain's schedule (in schedule order) and finally one {@code tree}
 * record per {@link TreeModel}. Tree models are restored last because node heights are taken
 * from the stored tree rather than from the parameters, which may not be associated with the
 * right node.
 *
 * @param file        checkpoint file to read
 * @param markovChain chain whose operator schedule receives the stored operator statistics
 * @param lnL         if non-null, element 0 receives the value stored on the {@code lnL} line
 * @return the value stored on the {@code state} line
 * @throws RuntimeException if the file cannot be read, ends prematurely, is malformed, or does
 *                          not match the current operators / tree models
 */
private long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    long state = -1;
    ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
    BufferedReader in = null;
    try {
        in = new BufferedReader(new FileReader(file));
        int[] rngState = null;
        String line = readCheckpointLine(in, "state number");
        String[] fields = line.split("\t");
        if (fields[0].equals("rng")) {
            // if there is a random number generator state present then load it...
            // (it is only installed at the very end, once the whole file has been read successfully)
            try {
                rngState = new int[fields.length - 1];
                for (int i = 0; i < rngState.length; i++) {
                    rngState[i] = Integer.parseInt(fields[i + 1]);
                }
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read random number generator state from state file");
            }
            line = readCheckpointLine(in, "state number");
            fields = line.split("\t");
        }
        try {
            if (!fields[0].equals("state") || fields.length < 2) {
                throw new RuntimeException("Unable to read state number from state file");
            }
            state = Long.parseLong(fields[1]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read state number from state file");
        }
        line = readCheckpointLine(in, "lnL");
        fields = line.split("\t");
        try {
            if (!fields[0].equals("lnL")) {
                throw new RuntimeException("Unable to read lnL from state file");
            }
            if (lnL != null) {
                if (fields.length < 2) {
                    throw new RuntimeException("Unable to read lnL from state file");
                }
                lnL[0] = Double.parseDouble(fields[1]);
            }
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read lnL from state file");
        }
        // Parameter lines are matched by position (iteration order of CONNECTED_PARAMETER_SET),
        // not by name: fields[1] holds the stored name, fields[2] the dimension, fields[3..] the values.
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            line = readCheckpointLine(in, "parameter " + parameter.getParameterName());
            fields = line.split("\t");
            int dimension = Integer.parseInt(fields[2]);
            if (dimension != parameter.getDimension()) {
                System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                System.err.print("Read from file: ");
                for (int i = 0; i < fields.length; i++) {
                    System.err.print(fields[i] + "\t");
                }
                System.err.println();
            }
            if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                double value = Double.parseDouble(fields[3]);
                parameter.setParameterValue(0, value);
                if (DEBUG) {
                    System.out.println("restoring " + fields[1] + " with value " + value);
                }
            } else {
                if (DEBUG) {
                    System.out.print("restoring " + fields[1] + " with values ");
                }
                for (int dim = 0; dim < parameter.getDimension(); dim++) {
                    parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
                    if (DEBUG) {
                        System.out.print(Double.parseDouble(fields[dim + 3]) + " ");
                    }
                }
                if (DEBUG) {
                    System.out.println();
                }
            }
        }
        // Operator lines: "operator", name, accept count, reject count [, coercable parameter].
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            line = readCheckpointLine(in, "operator " + operator.getOperatorName());
            fields = line.split("\t");
            if (!fields[1].equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match operator: " + fields[1]);
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + fields[1]);
            }
            // The counts are written as longs, so parse them as longs (long runs can exceed 2^31).
            operator.setAcceptCount(Long.parseLong(fields[2]));
            operator.setRejectCount(Long.parseLong(fields[3]));
            if (operator instanceof CoercableMCMCOperator) {
                if (fields.length != 5) {
                    throw new RuntimeException("Coercable operator missing parameter: " + fields[1]);
                }
                ((CoercableMCMCOperator) operator).setCoercableParameter(Double.parseDouble(fields[4]));
            }
        }
        // load the tree models last as we get the node heights from the tree (not the parameters which
        // may not be associated with the right node)
        Set<String> expectedTreeModelNames = new HashSet<String>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                if (DEBUG) {
                    System.out.println("model " + model.getModelName());
                }
                expectedTreeModelNames.add(model.getModelName());
                if (DEBUG) {
                    for (String s : expectedTreeModelNames) {
                        System.out.println(s);
                    }
                }
            }
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }
        // Read in all (possibly more than one) trees. 'line' always holds the next unconsumed
        // line of the file, or null once the end of the file has been reached.
        line = in.readLine();
        while (line != null) {
            fields = line.split("\t");
            if (!fields[0].equals("tree")) {
                break;
            }
            if (fields.length < 2) {
                throw new RuntimeException("Tree record without a tree name in state file");
            }
            String treeName = fields[1];
            if (DEBUG) {
                System.out.println("tree: " + treeName);
            }
            String lookahead = null;
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (model instanceof TreeModel && treeName.equals(model.getModelName())) {
                    lookahead = restoreTreeFromCheckpoint(in, (TreeModel) model, traitModels.size());
                    expectedTreeModelNames.remove(model.getModelName());
                    // a tree record can only be consumed once
                    break;
                }
            }
            line = (lookahead != null) ? lookahead : in.readLine();
        }
        if (expectedTreeModelNames.size() > 0) {
            StringBuilder sb = new StringBuilder();
            for (String notFoundName : expectedTreeModelNames) {
                sb.append("Expecting, but unable to match state parameter:" + notFoundName + "\n");
            }
            throw new RuntimeException(sb.toString());
        }
        if (DEBUG) {
            System.out.println("\nDouble checking:");
            for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
                if (parameter.getParameterName().equals("branchRates.categories.rootNodeNumber")) {
                    System.out.println(parameter.getParameterName() + ": " + parameter.getParameterValue(0));
                }
            }
        }
        if (rngState != null) {
            MathUtils.setRandomState(rngState);
        }
    } catch (IOException ioe) {
        throw new RuntimeException("Unable to read file: " + ioe.getMessage(), ioe);
    } finally {
        // Always release the file handle, including when parsing fails half-way through.
        if (in != null) {
            try {
                in.close();
            } catch (IOException ignored) {
                // closing a reader: nothing useful can be done about a failure here
            }
        }
    }
    return state;
}

/**
 * Reads the next line of a checkpoint file, failing with a descriptive message instead of
 * returning null when the file ends prematurely.
 *
 * @param in       reader positioned at the line to read
 * @param expected short description of what the line should contain (used in the error message)
 * @return the line read (never null)
 * @throws RuntimeException if the end of the file has been reached
 */
private static String readCheckpointLine(BufferedReader in, String expected) throws IOException {
    String line = in.readLine();
    if (line == null) {
        throw new RuntimeException("Unexpected end of state file while reading " + expected);
    }
    return line;
}

/**
 * Reads one tree record (the lines following its {@code tree} header) and adopts the stored
 * structure into {@code treeModel}.
 * <p>
 * Record layout: a {@code #node height taxon} comment, the node count, one line per node whose
 * second field is the height, two comment lines, an edge count, then the edge lines
 * (child, parent, left/right-child flag, one value per trait model). The writer emits the node
 * count as the edge count but skips the root, so a record normally has one edge line fewer
 * than announced. An edge line starting with {@code tree} therefore belongs to the next record.
 *
 * @param in         reader positioned just after the record's {@code tree} header line
 * @param treeModel  model receiving the tree structure
 * @param traitCount number of trait values expected on each edge line
 * @return the line that follows the record if it had to be read to detect the end of the
 *         record, or null if no such line was consumed
 */
private String restoreTreeFromCheckpoint(BufferedReader in, TreeModel treeModel, int traitCount) throws IOException {
    // skip the "#node height taxon" comment line, then read the number of nodes
    readCheckpointLine(in, "tree node header");
    String[] fields = readCheckpointLine(in, "tree node count").split("\t");
    int nodeCount = Integer.parseInt(fields[0]);
    double[] nodeHeights = new double[nodeCount];
    for (int i = 0; i < nodeCount; i++) {
        fields = readCheckpointLine(in, "tree node").split("\t");
        nodeHeights[i] = Double.parseDouble(fields[1]);
    }
    // on to reading edge information: two comment lines, then the edge count
    readCheckpointLine(in, "tree edge header");
    readCheckpointLine(in, "tree edge header");
    fields = readCheckpointLine(in, "tree edge count").split("\t");
    int edgeCount = Integer.parseInt(fields[0]);
    // trait values are parsed (and thus validated) but not applied yet: adoptTreeStructure
    // below does not copy the traits on the branches
    double[][] traitValues = new double[traitCount][edgeCount];
    // whether a node is the left (0) or right (1) child of its parent;
    // can be important for certain tree transition kernels
    int[] childOrder = new int[edgeCount];
    for (int i = 0; i < childOrder.length; i++) {
        childOrder[i] = -1;
    }
    int[] parents = new int[edgeCount];
    for (int i = 0; i < edgeCount; i++) {
        parents[i] = -1;
    }
    String lookahead = null;
    for (int i = 0; i < edgeCount; i++) {
        String line = in.readLine();
        if (line == null) {
            // end of file: the (omitted) root edge was the last one of the last tree
            break;
        }
        fields = line.split("\t");
        if (fields[0].equals("tree")) {
            // the root has no edge line, so this header already belongs to the next tree record
            lookahead = line;
            break;
        }
        parents[Integer.parseInt(fields[0])] = Integer.parseInt(fields[1]);
        // NOTE(review): childOrder is indexed by edge-line position while parents is indexed by
        // child node number; confirm this matches what adoptTreeStructure expects.
        childOrder[i] = Integer.parseInt(fields[2]);
        for (int j = 0; j < traitCount; j++) {
            traitValues[j][i] = Double.parseDouble(fields[3 + j]);
        }
    }
    if (DEBUG) {
        System.out.println("adopting tree structure");
    }
    // adopt the loaded tree structure; this does not yet copy the traits on the branches
    treeModel.beginTreeEdit();
    treeModel.adoptTreeStructure(parents, nodeHeights, childOrder);
    treeModel.endTreeEdit();
    return lookahead;
}
Also used: TreeModel (dr.evomodel.tree.TreeModel), OperatorSchedule (dr.inference.operators.OperatorSchedule), TreeParameterModel (dr.evomodel.tree.TreeParameterModel), Model (dr.inference.model.Model), Parameter (dr.inference.model.Parameter), CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator), MCMCOperator (dr.inference.operators.MCMCOperator)

Example 2 with CoercableMCMCOperator

use of dr.inference.operators.CoercableMCMCOperator in project beast-mcmc by beast-dev.

Class BeastCheckpointer, method writeStateToFile.

/**
 * Writes the current MCMC state to a checkpoint file.
 * <p>
 * Tab-separated layout: an {@code rng} line with the random number generator state, the
 * {@code state} and {@code lnL} lines, one {@code parameter} line per connected parameter
 * (name, dimension, values), one {@code operator} line per scheduled operator (name, accept
 * count, reject count and, for coercable operators, the coercable parameter), then one
 * {@code tree} record per {@link TreeModel} listing node heights/taxa and parent edges.
 *
 * @param file        destination file (overwritten)
 * @param state       state number to record
 * @param lnL         log likelihood to record
 * @param markovChain chain whose operator schedule is recorded
 * @return true if the file was written completely; false if it could not be opened or an
 *         error occurred while writing it
 * @throws RuntimeException if a tree contains a node that is neither child 0 nor child 1 of its parent
 */
private boolean writeStateToFile(File file, long state, double lnL, MarkovChain markovChain) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    PrintStream out = null;
    try {
        out = new PrintStream(new FileOutputStream(file));
        ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
        int[] rngState = MathUtils.getRandomState();
        out.print("rng");
        for (int i = 0; i < rngState.length; i++) {
            out.print("\t");
            out.print(rngState[i]);
        }
        out.println();
        out.print("state\t");
        out.println(state);
        out.print("lnL\t");
        out.println(lnL);
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            out.print("parameter");
            out.print("\t");
            out.print(parameter.getParameterName());
            out.print("\t");
            out.print(parameter.getDimension());
            for (int dim = 0; dim < parameter.getDimension(); dim++) {
                out.print("\t");
                out.print(parameter.getParameterValue(dim));
            }
            out.println();
        }
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            out.print("operator");
            out.print("\t");
            out.print(operator.getOperatorName());
            out.print("\t");
            out.print(operator.getAcceptCount());
            out.print("\t");
            out.print(operator.getRejectCount());
            if (operator instanceof CoercableMCMCOperator) {
                out.print("\t");
                out.print(((CoercableMCMCOperator) operator).getCoercableParameter());
            }
            out.println();
        }
        // check up front if there are any TreeParameterModel objects: their per-node values are
        // appended to every edge line below
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                TreeModel tree = (TreeModel) model;
                out.print("tree");
                out.print("\t");
                out.println(tree.getModelName());
                // print the general graph structure (rather than Newick)
                out.println("#node height taxon");
                int nodeCount = tree.getNodeCount();
                out.println(nodeCount);
                for (int i = 0; i < nodeCount; i++) {
                    NodeRef node = tree.getNode(i);
                    out.print(node.getNumber());
                    out.print("\t");
                    out.print(tree.getNodeHeight(node));
                    if (tree.isExternal(node)) {
                        out.print("\t");
                        out.print(tree.getNodeTaxon(node).getId());
                    }
                    out.println();
                }
                out.println("#edges");
                out.println("#child-node parent-node L/R-child traits");
                // The node count is written as the edge count (readers size their arrays with it),
                // but only nodes with a parent get an edge line, i.e. the root is skipped.
                out.println(nodeCount);
                for (int i = 0; i < nodeCount; i++) {
                    NodeRef node = tree.getNode(i);
                    NodeRef parent = tree.getParent(node);
                    if (parent != null) {
                        out.print(node.getNumber());
                        out.print("\t");
                        out.print(parent.getNumber());
                        out.print("\t");
                        if (tree.getChild(parent, 0) == node) {
                            //left child
                            out.print(0);
                        } else if (tree.getChild(parent, 1) == node) {
                            //right child
                            out.print(1);
                        } else {
                            throw new RuntimeException("Operation currently only supported for nodes with 2 children.");
                        }
                        for (TreeParameterModel tpm : traitModels) {
                            out.print("\t");
                            out.print(tpm.getNodeValue(tree, node));
                        }
                        out.println();
                    }
                }
            }
        }
        out.close();
        // PrintStream never throws IOException: write and flush failures (e.g. a full disk)
        // are only reported through checkError(), so check it before claiming success.
        if (out.checkError()) {
            System.err.println("Unable to write file: " + file.getPath());
            return false;
        }
    } catch (IOException ioe) {
        System.err.println("Unable to write file: " + ioe.getMessage());
        return false;
    } finally {
        // release the file even if an exception escapes while writing (close is idempotent)
        if (out != null) {
            out.close();
        }
    }
    if (DEBUG) {
        for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
            System.err.println(likelihood.getId() + ": " + likelihood.getLogLikelihood());
        }
    }
    return true;
}
Also used: OperatorSchedule (dr.inference.operators.OperatorSchedule), Likelihood (dr.inference.model.Likelihood), TreeParameterModel (dr.evomodel.tree.TreeParameterModel), TreeModel (dr.evomodel.tree.TreeModel), NodeRef (dr.evolution.tree.NodeRef), Model (dr.inference.model.Model), Parameter (dr.inference.model.Parameter), CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator), MCMCOperator (dr.inference.operators.MCMCOperator)

Example 3 with CoercableMCMCOperator

use of dr.inference.operators.CoercableMCMCOperator in project beast-mcmc by beast-dev.

Class MCMCMC, method finish.

/**
 * Cleans up when the chain finishes (possibly early).
 * <p>
 * Writes a final log entry and stops every logger of the cold chain, prints the elapsed time
 * and, if {@code showOperatorAnalysis} is set, prints one row per operator of the cold chain's
 * schedule: name, raw coercable parameter (coercable operators only), acceptance probability,
 * a verbal rating of that probability and the operator's own performance suggestion.
 */
private void finish() {
    final NumberFormatter fmt = new NumberFormatter(8);

    for (MCLogger coldLogger : mcLoggers[coldChain]) {
        coldLogger.log(currentState);
        coldLogger.stopLogging();
    }

    System.out.println();
    System.out.println("Time taken: " + timer.toString());

    if (!showOperatorAnalysis) {
        return;
    }

    System.out.println();
    System.out.println("Operator analysis");
    String header = fmt.formatToFieldWidth("Operator", 30)
            + fmt.formatToFieldWidth("", 8)
            + fmt.formatToFieldWidth("Pr(accept)", 11)
            + " Performance suggestion";
    System.out.println(header);

    int operatorCount = schedules[coldChain].getOperatorCount();
    for (int k = 0; k < operatorCount; k++) {
        MCMCOperator operator = schedules[coldChain].getOperator(k);
        double pAccept = MCMCOperator.Utils.getAcceptanceProbability(operator);
        String rating = describeAcceptance(operator, pAccept);
        String advice = operator.getPerformanceSuggestion();

        // 8-character column: blank unless the operator has a tunable (coercable) parameter
        String rawColumn;
        if (operator instanceof CoercableMCMCOperator) {
            double raw = ((CoercableMCMCOperator) operator).getRawParameter();
            rawColumn = fmt.formatToFieldWidth(fmt.formatDecimal(raw, 3), 8);
        } else {
            rawColumn = "        ";
        }

        String row = fmt.formatToFieldWidth(operator.getOperatorName(), 30)
                + rawColumn
                + fmt.formatToFieldWidth(fmt.formatDecimal(pAccept, 4), 11)
                + " " + rating + "\t" + advice;
        System.out.println(row);
    }
    System.out.println();
}

/**
 * Rates an operator's acceptance probability against the operator's own acceptance levels.
 * <p>
 * Below the minimum "good" level the rating is "very low" (under a tenth of the minimum level),
 * "low" (under the minimum level) or "slightly low". Above the maximum "good" level it is
 * "very high" (within a tenth of the distance from the maximum level to 1 of 1),
 * "high" (above the maximum level) or "slightly high". Otherwise it is "good".
 *
 * @param operator the operator whose acceptance levels are used
 * @param pAccept  the operator's acceptance probability
 * @return the verbal rating
 */
private static String describeAcceptance(MCMCOperator operator, double pAccept) {
    if (pAccept < operator.getMinimumGoodAcceptanceLevel()) {
        double minimum = operator.getMinimumAcceptanceLevel();
        if (pAccept < (minimum / 10.0)) {
            return "very low";
        }
        return (pAccept < minimum) ? "low" : "slightly low";
    }
    if (pAccept > operator.getMaximumGoodAcceptanceLevel()) {
        double maximum = operator.getMaximumAcceptanceLevel();
        double veryHighThreshold = 1.0 - ((1.0 - maximum) / 10.0);
        if (pAccept > veryHighThreshold) {
            return "very high";
        }
        return (pAccept > maximum) ? "high" : "slightly high";
    }
    return "good";
}
Also used: CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator), MCLogger (dr.inference.loggers.MCLogger), MCMCOperator (dr.inference.operators.MCMCOperator), NumberFormatter (dr.util.NumberFormatter)

Example 4 with CoercableMCMCOperator

use of dr.inference.operators.CoercableMCMCOperator in project beast-mcmc by beast-dev.

Class MCMCMC, method swapChainTemperatures.

/**
 * Proposes swapping the temperatures of two distinct, uniformly chosen chains.
 * <p>
 * The swap is accepted with log probability
 * {@code (score2 - score1) * temperature1 + (score1 - score2) * temperature2}. On acceptance the
 * two acceptors exchange temperatures, and the operator statistics (accept/reject counts, sum of
 * deviations and, for coercable operators, the coercable parameter) are exchanged operator by
 * operator. The operator state therefore stays with the temperature rather than with the chain.
 *
 * @return index of the cold chain after the (possible) swap; unchanged if no swap happened,
 *         or if there are fewer than two chains
 */
private int swapChainTemperatures() {
    if (DEBUG) {
        System.out.print("Current scores: ");
        for (int i = 0; i < chains.length; i++) {
            System.out.print("\t");
            if (i == coldChain) {
                System.out.print("[");
            }
            System.out.print(chains[i].getCurrentScore());
            if (i == coldChain) {
                System.out.print("]");
            }
        }
        System.out.println();
    }
    int newColdChain = coldChain;
    // With fewer than two chains there is no partner to swap with, and the rejection loop
    // below (which draws index2 until it differs from index1) would never terminate.
    if (chains.length < 2) {
        return newColdChain;
    }
    int index1 = MathUtils.nextInt(chains.length);
    int index2 = MathUtils.nextInt(chains.length);
    while (index1 == index2) {
        index2 = MathUtils.nextInt(chains.length);
    }
    double score1 = chains[index1].getCurrentScore();
    MCMCCriterion acceptor1 = ((MCMCCriterion) chains[index1].getAcceptor());
    double temperature1 = acceptor1.getTemperature();
    double score2 = chains[index2].getCurrentScore();
    MCMCCriterion acceptor2 = ((MCMCCriterion) chains[index2].getAcceptor());
    double temperature2 = acceptor2.getTemperature();
    double logRatio = ((score2 - score1) * temperature1) + ((score1 - score2) * temperature2);
    boolean swap = (Math.log(MathUtils.nextDouble()) < logRatio);
    if (swap) {
        if (DEBUG) {
            System.out.println("Swapping chain " + index1 + " and chain " + index2);
        }
        acceptor1.setTemperature(temperature2);
        acceptor2.setTemperature(temperature1);
        // Exchange operator statistics so that tuning follows the temperature.
        // Assumes both schedules list corresponding operators at the same positions.
        OperatorSchedule schedule1 = schedules[index1];
        OperatorSchedule schedule2 = schedules[index2];
        for (int i = 0; i < schedule1.getOperatorCount(); i++) {
            MCMCOperator operator1 = schedule1.getOperator(i);
            MCMCOperator operator2 = schedule2.getOperator(i);
            long tmp = operator1.getAcceptCount();
            operator1.setAcceptCount(operator2.getAcceptCount());
            operator2.setAcceptCount(tmp);
            tmp = operator1.getRejectCount();
            operator1.setRejectCount(operator2.getRejectCount());
            operator2.setRejectCount(tmp);
            double tmp2 = operator1.getSumDeviation();
            operator1.setSumDeviation(operator2.getSumDeviation());
            operator2.setSumDeviation(tmp2);
            if (operator1 instanceof CoercableMCMCOperator) {
                tmp2 = ((CoercableMCMCOperator) operator1).getCoercableParameter();
                ((CoercableMCMCOperator) operator1).setCoercableParameter(((CoercableMCMCOperator) operator2).getCoercableParameter());
                ((CoercableMCMCOperator) operator2).setCoercableParameter(tmp2);
            }
        }
        if (index1 == coldChain) {
            newColdChain = index2;
        } else if (index2 == coldChain) {
            newColdChain = index1;
        }
    }
    return newColdChain;
}
Also used: OperatorSchedule (dr.inference.operators.OperatorSchedule), MCMCCriterion (dr.inference.mcmc.MCMCCriterion), CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator), MCMCOperator (dr.inference.operators.MCMCOperator)

Example 5 with CoercableMCMCOperator

use of dr.inference.operators.CoercableMCMCOperator in project beast-mcmc by beast-dev.

Class CheckPointModifier, method readStateFromFile.

private long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    long state = -1;
    this.traitModels = new ArrayList<TreeParameterModel>();
    try {
        FileReader fileIn = new FileReader(file);
        BufferedReader in = new BufferedReader(fileIn);
        int[] rngState = null;
        String line = in.readLine();
        String[] fields = line.split("\t");
        if (fields[0].equals("rng")) {
            // if there is a random number generator state present then load it...
            try {
                rngState = new int[fields.length - 1];
                for (int i = 0; i < rngState.length; i++) {
                    rngState[i] = Integer.parseInt(fields[i + 1]);
                }
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read state number from state file");
            }
            line = in.readLine();
            fields = line.split("\t");
        }
        try {
            if (!fields[0].equals("state")) {
                throw new RuntimeException("Unable to read state number from state file");
            }
            state = Long.parseLong(fields[1]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read state number from state file");
        }
        line = in.readLine();
        fields = line.split("\t");
        try {
            if (!fields[0].equals("lnL")) {
                throw new RuntimeException("Unable to read lnL from state file");
            }
            if (lnL != null) {
                lnL[0] = Double.parseDouble(fields[1]);
            }
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read lnL from state file");
        }
        line = in.readLine();
        //System.out.println(line);
        fields = line.split("\t");
        //Tree nodes have numbers as parameter ids
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            //numbers should be positive but can include zero
            if (isTreeNode(parameter.getId()) && isTreeNode(fields[1]) || parameter.getId().equals(fields[1])) {
                int dimension = Integer.parseInt(fields[2]);
                if (dimension != parameter.getDimension() && !fields[1].equals("branchRates.categories")) {
                    System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                    System.err.print("Read from file: ");
                    for (int i = 0; i < fields.length; i++) {
                        System.err.print(fields[i] + "\t");
                    }
                    System.err.println();
                }
                if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                    // System.out.println("eek");
                    double value = Double.parseDouble(fields[3]);
                    parameter.setParameterValue(0, value);
                    if (DEBUG) {
                        System.out.println("restoring " + fields[1] + " with value " + value);
                    }
                } else {
                    if (DEBUG) {
                        System.out.print("restoring " + fields[1] + " with values ");
                    }
                    if (fields[1].equals("branchRates.categories")) {
                        for (int dim = 0; dim < (fields.length - 3); dim++) {
                            //System.out.println("dim " + dim);
                            parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
                            if (DEBUG) {
                                System.out.print(Double.parseDouble(fields[dim + 3]) + " ");
                            }
                        }
                    } else {
                        for (int dim = 0; dim < parameter.getDimension(); dim++) {
                            parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
                            if (DEBUG) {
                                System.out.print(Double.parseDouble(fields[dim + 3]) + " ");
                            }
                        }
                    }
                    if (DEBUG) {
                        System.out.println();
                    }
                }
                line = in.readLine();
                //System.out.println(line);
                fields = line.split("\t");
            } else {
            //there will be more parameters in the connected set than there are lines in the checkpoint file
            //do nothing and just keep iterating over the parameters in the connected set
            }
        }
        //No changes needed for loading in operators
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            if (!fields[1].equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match operator: " + fields[1]);
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + fields[1]);
            }
            operator.setAcceptCount(Integer.parseInt(fields[2]));
            operator.setRejectCount(Integer.parseInt(fields[3]));
            if (operator instanceof CoercableMCMCOperator) {
                if (fields.length != 5) {
                    throw new RuntimeException("Coercable operator missing parameter: " + fields[1]);
                }
                ((CoercableMCMCOperator) operator).setCoercableParameter(Double.parseDouble(fields[4]));
            }
            line = in.readLine();
            fields = line.split("\t");
        }
        // load the tree models last as we get the node heights from the tree (not the parameters which
        // which may not be associated with the right node
        Set<String> expectedTreeModelNames = new HashSet<String>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                expectedTreeModelNames.add(model.getModelName());
            }
            if (model instanceof TreeParameterModel) {
                this.traitModels.add((TreeParameterModel) model);
            }
            if (model instanceof BranchRates) {
                this.rateModel = (BranchRates) model;
            }
        }
        while (fields[0].equals("tree")) {
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (model instanceof TreeModel && fields[1].equals(model.getModelName())) {
                    //AR: Can we not just add them to a Flexible tree and then make a new TreeModel
                    //taking that in the constructor?
                    //internally, we have a tree with all the taxa
                    //externally, i.e. in the checkpoint file, we have a tree representation comprising
                    //a subset of the full taxa set
                    //write method that adjusts the internal representation, i.e. the one in the connected
                    //set, according to the checkpoint file and a distance-based approach to position
                    //the additional taxa
                    //first read in all the data from the checkpoint file
                    line = in.readLine();
                    line = in.readLine();
                    fields = line.split("\t");
                    //read number of nodes
                    int nodeCount = Integer.parseInt(fields[0]);
                    double[] nodeHeights = new double[nodeCount];
                    String[] taxaNames = new String[(nodeCount + 1) / 2];
                    for (int i = 0; i < nodeCount; i++) {
                        line = in.readLine();
                        fields = line.split("\t");
                        nodeHeights[i] = Double.parseDouble(fields[1]);
                        if (i < taxaNames.length) {
                            taxaNames[i] = fields[2];
                        }
                    }
                    //on to reading edge information
                    line = in.readLine();
                    line = in.readLine();
                    line = in.readLine();
                    fields = line.split("\t");
                    int edgeCount = Integer.parseInt(fields[0]);
                    //create data matrix of doubles to store information from list of TreeParameterModels
                    double[][] traitValues = new double[traitModels.size()][edgeCount];
                    //create array to store whether a node is left or right child of its parent
                    //can be important for certain tree transition kernels
                    int[] childOrder = new int[edgeCount];
                    for (int i = 0; i < childOrder.length; i++) {
                        childOrder[i] = -1;
                    }
                    int[] parents = new int[edgeCount];
                    for (int i = 0; i < edgeCount; i++) {
                        parents[i] = -1;
                    }
                    for (int i = 0; i < edgeCount; i++) {
                        line = in.readLine();
                        if (line != null) {
                            fields = line.split("\t");
                            parents[Integer.parseInt(fields[0])] = Integer.parseInt(fields[1]);
                            childOrder[i] = Integer.parseInt(fields[2]);
                            for (int j = 0; j < traitModels.size(); j++) {
                                traitValues[j][i] = Double.parseDouble(fields[3 + j]);
                            }
                        }
                    }
                    //perform magic with the acquired information
                    //CheckPointTreeModifier modifyTree = new CheckPointTreeModifier((TreeModel) model);
                    this.modifyTree = new CheckPointTreeModifier((TreeModel) model);
                    modifyTree.adoptTreeStructure(parents, nodeHeights, childOrder, taxaNames);
                    if (traitModels.size() > 0) {
                        modifyTree.adoptTraitData(parents, this.traitModels, traitValues);
                    }
                    //adopt the loaded tree structure; this does not yet copy the traits on the branches
                    //((TreeModel) model).beginTreeEdit();
                    //((TreeModel) model).adoptTreeStructure(parents, nodeHeights, childOrder);
                    //((TreeModel) model).endTreeEdit();
                    expectedTreeModelNames.remove(model.getModelName());
                }
            }
            line = in.readLine();
            if (line != null) {
                fields = line.split("\t");
            }
        }
        if (expectedTreeModelNames.size() > 0) {
            StringBuilder sb = new StringBuilder();
            for (String notFoundName : expectedTreeModelNames) {
                sb.append("Expecting, but unable to match state parameter:" + notFoundName + "\n");
            }
            throw new RuntimeException(sb.toString());
        }
        in.close();
        fileIn.close();
    } catch (IOException ioe) {
        throw new RuntimeException("Unable to read file: " + ioe.getMessage());
    }
    return state;
}
Also used: TreeModel (dr.evomodel.tree.TreeModel), FileReader (java.io.FileReader), HashSet (java.util.HashSet), OperatorSchedule (dr.inference.operators.OperatorSchedule), TreeParameterModel (dr.evomodel.tree.TreeParameterModel), IOException (java.io.IOException), BufferedReader (java.io.BufferedReader), Model (dr.inference.model.Model), Parameter (dr.inference.model.Parameter), CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator), BranchRates (dr.evolution.tree.BranchRates), MCMCOperator (dr.inference.operators.MCMCOperator)

Aggregations

CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator): 5, MCMCOperator (dr.inference.operators.MCMCOperator): 5, OperatorSchedule (dr.inference.operators.OperatorSchedule): 4, TreeModel (dr.evomodel.tree.TreeModel): 3, TreeParameterModel (dr.evomodel.tree.TreeParameterModel): 3, Model (dr.inference.model.Model): 3, Parameter (dr.inference.model.Parameter): 3, BranchRates (dr.evolution.tree.BranchRates): 1, NodeRef (dr.evolution.tree.NodeRef): 1, MCLogger (dr.inference.loggers.MCLogger): 1, MCMCCriterion (dr.inference.mcmc.MCMCCriterion): 1, Likelihood (dr.inference.model.Likelihood): 1, NumberFormatter (dr.util.NumberFormatter): 1, BufferedReader (java.io.BufferedReader): 1, FileReader (java.io.FileReader): 1, IOException (java.io.IOException): 1, HashSet (java.util.HashSet): 1