Search in sources:

Example 1 with Model

Use of dr.inference.model.Model in project beast-mcmc by beast-dev.

Class BeastCheckpointer, method readStateFromFile.

/**
 * Restores MCMC state from a tab-delimited checkpoint file in the layout written by
 * writeStateToFile. The file contains:
 * <ul>
 * <li>an optional "rng" line;</li>
 * <li>"state" and "lnL" lines;</li>
 * <li>one line per connected parameter;</li>
 * <li>one line per scheduled operator;</li>
 * <li>zero or more "tree" sections.</li>
 * </ul>
 *
 * Side effects:
 * <ul>
 * <li>overwrites the values of every parameter in Parameter.CONNECTED_PARAMETER_SET;</li>
 * <li>overwrites the accept/reject counts (and the coercable parameter, where applicable) of
 * every scheduled operator;</li>
 * <li>adopts the stored structure into every TreeModel whose name matches a tree section;</li>
 * <li>replaces the global random number generator state when an "rng" line is present. This
 * happens only after the whole file has been read.</li>
 * </ul>
 *
 * @param file        the checkpoint file to read
 * @param markovChain the chain whose operator schedule is restored
 * @param lnL         if non-null, element 0 receives the stored log likelihood
 * @return the state (iteration) number stored in the file
 * @throws RuntimeException if the file cannot be read, ends prematurely, is malformed, or
 *                          does not match the current operators / tree models
 */
private long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    long state = -1;
    ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
    // try-with-resources closes the file even when parsing fails part-way through.
    try (FileReader fileIn = new FileReader(file);
         BufferedReader in = new BufferedReader(fileIn)) {
        int[] rngState = null;
        String line = readRequiredLine(in, "state header");
        String[] fields = line.split("\t");
        if (fields[0].equals("rng")) {
            // if there is a random number generator state present then load it...
            try {
                rngState = new int[fields.length - 1];
                for (int i = 0; i < rngState.length; i++) {
                    rngState[i] = Integer.parseInt(fields[i + 1]);
                }
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read random number generator state from state file", nfe);
            }
            line = readRequiredLine(in, "state number");
            fields = line.split("\t");
        }
        try {
            if (!fields[0].equals("state")) {
                throw new RuntimeException("Unable to read state number from state file");
            }
            state = Long.parseLong(fields[1]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read state number from state file", nfe);
        }
        fields = readRequiredLine(in, "lnL").split("\t");
        try {
            if (!fields[0].equals("lnL")) {
                throw new RuntimeException("Unable to read lnL from state file");
            }
            if (lnL != null) {
                lnL[0] = Double.parseDouble(fields[1]);
            }
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read lnL from state file", nfe);
        }
        // Parameter lines are matched positionally (the name check is intentionally disabled).
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            fields = readRequiredLine(in, "parameter " + parameter.getParameterName()).split("\t");
            int dimension = Integer.parseInt(fields[2]);
            if (dimension != parameter.getDimension()) {
                System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                System.err.print("Read from file: ");
                for (int i = 0; i < fields.length; i++) {
                    System.err.print(fields[i] + "\t");
                }
                System.err.println();
            }
            if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                // only the first value is restored for this parameter
                double value = Double.parseDouble(fields[3]);
                parameter.setParameterValue(0, value);
                if (DEBUG) {
                    System.out.println("restoring " + fields[1] + " with value " + value);
                }
            } else {
                if (DEBUG) {
                    System.out.print("restoring " + fields[1] + " with values ");
                }
                for (int dim = 0; dim < parameter.getDimension(); dim++) {
                    parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
                    if (DEBUG) {
                        System.out.print(Double.parseDouble(fields[dim + 3]) + " ");
                    }
                }
                if (DEBUG) {
                    System.out.println();
                }
            }
        }
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            fields = readRequiredLine(in, "operator " + operator.getOperatorName()).split("\t");
            if (!fields[1].equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match operator: " + fields[1]);
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + fields[1]);
            }
            operator.setAcceptCount(Integer.parseInt(fields[2]));
            operator.setRejectCount(Integer.parseInt(fields[3]));
            if (operator instanceof CoercableMCMCOperator) {
                if (fields.length != 5) {
                    throw new RuntimeException("Coercable operator missing parameter: " + fields[1]);
                }
                ((CoercableMCMCOperator) operator).setCoercableParameter(Double.parseDouble(fields[4]));
            }
        }
        // load the tree models last as we get the node heights from the tree (not the parameters
        // which may not be associated with the right node)
        Set<String> expectedTreeModelNames = new HashSet<String>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                if (DEBUG) {
                    System.out.println("model " + model.getModelName());
                }
                expectedTreeModelNames.add(model.getModelName());
                if (DEBUG) {
                    for (String s : expectedTreeModelNames) {
                        System.out.println(s);
                    }
                }
            }
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }
        // Read in all (possibly more than one) trees. A file written for a model without any
        // TreeModel simply ends here, so end-of-file is not an error at this point.
        line = in.readLine();
        while (line != null) {
            fields = line.split("\t");
            if (!fields[0].equals("tree")) {
                break;
            }
            String treeName = fields[1];
            if (DEBUG) {
                System.out.println("tree: " + treeName);
            }
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (model instanceof TreeModel && treeName.equals(model.getModelName())) {
                    readTreeState(in, (TreeModel) model, traitModels.size());
                    expectedTreeModelNames.remove(model.getModelName());
                    break;
                }
            }
            line = in.readLine();
        }
        if (expectedTreeModelNames.size() > 0) {
            StringBuilder sb = new StringBuilder();
            for (String notFoundName : expectedTreeModelNames) {
                sb.append("Expecting, but unable to match state parameter:" + notFoundName + "\n");
            }
            throw new RuntimeException(sb.toString());
        }
        if (DEBUG) {
            System.out.println("\nDouble checking:");
            for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
                if (parameter.getParameterName().equals("branchRates.categories.rootNodeNumber")) {
                    System.out.println(parameter.getParameterName() + ": " + parameter.getParameterValue(0));
                }
            }
        }
        // The RNG is restored only once everything else has been read successfully.
        if (rngState != null) {
            MathUtils.setRandomState(rngState);
        }
        // This shouldn't be necessary and if it is then it might be hiding a bug...
        //            for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
        //                likelihood.makeDirty();
        //            }
    } catch (IOException ioe) {
        throw new RuntimeException("Unable to read file: " + ioe.getMessage(), ioe);
    }
    return state;
}

/**
 * Reads the next line of a checkpoint file. The file ending early is reported as a descriptive
 * RuntimeException rather than a NullPointerException from a later {@code split}.
 */
private static String readRequiredLine(BufferedReader in, String expected) throws IOException {
    String line = in.readLine();
    if (line == null) {
        throw new RuntimeException("Unexpected end of state file while reading " + expected);
    }
    return line;
}

/**
 * Reads the body of one "tree" section (everything after its header line) and adopts the stored
 * topology and node heights into {@code treeModel}. Trait values are parsed but not yet applied.
 *
 * @param traitCount number of TreeParameterModel columns expected on each edge line
 */
private void readTreeState(BufferedReader in, TreeModel treeModel, int traitCount) throws IOException {
    // skip the "#node height taxon" comment line
    readRequiredLine(in, "tree node header");
    //read number of nodes
    int nodeCount = Integer.parseInt(readRequiredLine(in, "tree node count").split("\t")[0]);
    double[] nodeHeights = new double[nodeCount];
    for (int i = 0; i < nodeCount; i++) {
        String[] fields = readRequiredLine(in, "tree node").split("\t");
        nodeHeights[i] = Double.parseDouble(fields[1]);
    }
    // skip the "#edges" and "#child-node parent-node L/R-child traits" comment lines
    readRequiredLine(in, "tree edge header");
    readRequiredLine(in, "tree edge header");
    int edgeCount = Integer.parseInt(readRequiredLine(in, "tree edge count").split("\t")[0]);
    //create data matrix of doubles to store information from list of TreeParameterModels
    double[][] traitValues = new double[traitCount][edgeCount];
    //array storing whether a node is left (0) or right (1) child of its parent, indexed by node
    //number like parents; can be important for certain tree transition kernels
    int[] childOrder = new int[edgeCount];
    int[] parents = new int[edgeCount];
    for (int i = 0; i < edgeCount; i++) {
        childOrder[i] = -1;
        parents[i] = -1;
    }
    // The writer announces nodeCount edges but writes one line per non-root node only. The last
    // expected line may therefore be EOF or the header of the next tree. Peek so that header is
    // left for the caller instead of being consumed here.
    final int headerLookahead = 8192;
    for (int i = 0; i < edgeCount; i++) {
        in.mark(headerLookahead);
        String line = in.readLine();
        if (line == null) {
            break;
        }
        String[] fields = line.split("\t");
        if (fields[0].equals("tree")) {
            in.reset();
            break;
        }
        int child = Integer.parseInt(fields[0]);
        parents[child] = Integer.parseInt(fields[1]);
        childOrder[child] = Integer.parseInt(fields[2]);
        for (int j = 0; j < traitCount; j++) {
            traitValues[j][child] = Double.parseDouble(fields[3 + j]);
        }
    }
    if (DEBUG) {
        System.out.println("adopting tree structure");
    }
    //adopt the loaded tree structure; this does not yet copy the traits on the branches
    treeModel.beginTreeEdit();
    treeModel.adoptTreeStructure(parents, nodeHeights, childOrder);
    treeModel.endTreeEdit();
}
Also used: TreeModel (dr.evomodel.tree.TreeModel), OperatorSchedule (dr.inference.operators.OperatorSchedule), TreeParameterModel (dr.evomodel.tree.TreeParameterModel), Model (dr.inference.model.Model), Parameter (dr.inference.model.Parameter), CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator), MCMCOperator (dr.inference.operators.MCMCOperator)

Example 2 with Model

Use of dr.inference.model.Model in project beast-mcmc by beast-dev.

Class TreeLoggerParser, method parseXMLParameters.

/**
 * Configures this tree logger from its XML element. It reads the logged tree, output format
 * options, the tree attribute providers (likelihoods, loggables, other providers) and the tree
 * trait providers (named, filtered or complete trait sets). It also reads the optional branch
 * rates used for substitution-unit trees and the output file.
 *
 * @param xo the tree logger XML element
 * @throws XMLParseException if substitution-unit branch lengths are requested without a
 *                           BranchRates child, or a named trait does not exist
 */
protected void parseXMLParameters(XMLObject xo) throws XMLParseException {
    // reset this every time...
    branchRates = null;
    tree = (Tree) xo.getChild(Tree.class);
    title = xo.getAttribute(TITLE, "");
    nexusFormat = xo.getAttribute(NEXUS_FORMAT, false);
    sortTranslationTable = xo.getAttribute(SORT_TRANSLATION_TABLE, true);
    boolean substitutions = xo.getAttribute(BRANCH_LENGTHS, "").equals(SUBSTITUTIONS);
    // prefix list used to select a subset of traits; null when no filter is given
    final String filterSpec = xo.hasAttribute(FILTER_TRAITS) ? (String) xo.getAttribute(FILTER_TRAITS) : null;
    List<TreeAttributeProvider> taps = new ArrayList<TreeAttributeProvider>();
    List<TreeTraitProvider> ttps = new ArrayList<TreeTraitProvider>();
    // ttps2 are for TTPs that are not specified within a Trait element. These are only
    // included if not already added through a trait element to avoid duplication of
    // (in particular) the BranchRates which is required for substitution trees.
    // NOTE(review): the code below appends ttps2 unconditionally; no de-duplication against
    // ttps is performed — confirm whether that is intended.
    List<TreeTraitProvider> ttps2 = new ArrayList<TreeTraitProvider>();
    for (int i = 0; i < xo.getChildCount(); i++) {
        Object cxo = xo.getChild(i);
        if (cxo instanceof Likelihood) {
            final Likelihood likelihood = (Likelihood) cxo;
            taps.add(new TreeAttributeProvider() {

                public String[] getTreeAttributeLabel() {
                    return new String[] { "lnP" };
                }

                public String[] getAttributeForTree(Tree tree) {
                    return new String[] { Double.toString(likelihood.getLogLikelihood()) };
                }
            });
        }
        if (cxo instanceof TreeAttributeProvider) {
            taps.add((TreeAttributeProvider) cxo);
        }
        if (cxo instanceof TreeTraitProvider) {
            if (filterSpec != null) {
                List<TreeTrait> filteredTraits = filterTraitsByPrefix((TreeTraitProvider) cxo, filterSpec);
                if (filteredTraits.size() > 0) {
                    ttps2.add(new TreeTraitProvider.Helper(filteredTraits));
                }
            } else {
                // Add all of them
                ttps2.add((TreeTraitProvider) cxo);
            }
        }
        if (cxo instanceof XMLObject) {
            XMLObject xco = (XMLObject) cxo;
            if (xco.getName().equals(TREE_TRAIT)) {
                TreeTraitProvider ttp = (TreeTraitProvider) xco.getChild(TreeTraitProvider.class);
                if (xco.hasAttribute(NAME)) {
                    // a specific named trait is required (optionally with a tag to name it in the tree file)
                    String name = xco.getStringAttribute(NAME);
                    final TreeTrait trait = ttp.getTreeTrait(name);
                    if (trait == null) {
                        String childName = "TreeTraitProvider";
                        if (ttp instanceof Likelihood) {
                            childName = ((Likelihood) ttp).prettyName();
                        } else if (ttp instanceof Model) {
                            childName = ((Model) ttp).getModelName();
                        }
                        throw new XMLParseException("Trait named, " + name + ", not found for " + childName);
                    }
                    final String tag;
                    if (xco.hasAttribute(TAG)) {
                        tag = xco.getStringAttribute(TAG);
                    } else {
                        tag = name;
                    }
                    // wrap the trait so that it is reported under the tag rather than its own name
                    ttps.add(new TreeTraitProvider.Helper(tag, new TreeTrait() {

                        public String getTraitName() {
                            return tag;
                        }

                        public Intent getIntent() {
                            return trait.getIntent();
                        }

                        public Class getTraitClass() {
                            return trait.getTraitClass();
                        }

                        public Object getTrait(Tree tree, NodeRef node) {
                            return trait.getTrait(tree, node);
                        }

                        public String getTraitString(Tree tree, NodeRef node) {
                            return trait.getTraitString(tree, node);
                        }

                        public boolean getLoggable() {
                            return trait.getLoggable();
                        }
                    }));
                } else if (filterSpec != null) {
                    // else a filter attribute is given to ask for all traits that start with a
                    // specific string
                    List<TreeTrait> filteredTraits = filterTraitsByPrefix(ttp, filterSpec);
                    if (filteredTraits.size() > 0) {
                        ttps.add(new TreeTraitProvider.Helper(filteredTraits));
                    }
                } else {
                    // neither named or filtered traits so just add them all
                    ttps.add(ttp);
                }
            }
        }
        // be able to put arbitrary statistics in as tree attributes
        if (cxo instanceof Loggable) {
            final Loggable loggable = (Loggable) cxo;
            taps.add(new TreeAttributeProvider() {

                public String[] getTreeAttributeLabel() {
                    String[] labels = new String[loggable.getColumns().length];
                    for (int i = 0; i < loggable.getColumns().length; i++) {
                        labels[i] = loggable.getColumns()[i].getLabel();
                    }
                    return labels;
                }

                public String[] getAttributeForTree(Tree tree) {
                    String[] values = new String[loggable.getColumns().length];
                    for (int i = 0; i < loggable.getColumns().length; i++) {
                        values[i] = loggable.getColumns()[i].getFormatted();
                    }
                    return values;
                }
            });
        }
    }
    // inclusion of the codon partitioned robust counting TTP...
    if (ttps2.size() > 0) {
        ttps.addAll(ttps2);
    }
    if (substitutions) {
        branchRates = (BranchRates) xo.getChild(BranchRates.class);
    }
    if (substitutions && branchRates == null) {
        throw new XMLParseException("To log trees in units of substitutions a BranchRateModel must be provided");
    }
    // logEvery of zero only displays at the end
    logEvery = xo.getAttribute(LOG_EVERY, 0);
    //        double normaliseMeanRateTo = xo.getAttribute(NORMALISE_MEAN_RATE_TO, Double.NaN);
    // decimal places
    final int dp = xo.getAttribute(DECIMAL_PLACES, -1);
    if (dp != -1) {
        format = NumberFormat.getNumberInstance(Locale.ENGLISH);
        format.setMaximumFractionDigits(dp);
    }
    final PrintWriter pw = getLogFile(xo, getParserName());
    formatter = new TabDelimitedFormatter(pw);
    treeAttributeProviders = new TreeAttributeProvider[taps.size()];
    taps.toArray(treeAttributeProviders);
    treeTraitProviders = new TreeTraitProvider[ttps.size()];
    ttps.toArray(treeTraitProviders);
    // I think the default should be to have names rather than numbers, thus the false default - AJD
    // I think the default should be numbers - using names results in larger files and end user never
    // sees the numbers anyway as any software loading the nexus files does the translation - JH
    mapNames = xo.getAttribute(MAP_NAMES, true);
    condition = logEvery == 0 ? (TreeLogger.LogUpon) xo.getChild(TreeLogger.LogUpon.class) : null;
}

/**
 * Selects the traits of {@code ttp} whose names start with any of the whitespace- or
 * comma-separated prefixes in {@code filterSpec}. Traits are ordered by prefix and then by
 * provider order.
 *
 * Empty prefixes (produced by leading separators) are ignored, because {@code startsWith("")}
 * would otherwise select every trait. A trait matched by several prefixes is included only once.
 */
private static List<TreeTrait> filterTraitsByPrefix(TreeTraitProvider ttp, String filterSpec) {
    TreeTrait[] traits = ttp.getTreeTraits();
    List<TreeTrait> filteredTraits = new ArrayList<TreeTrait>();
    for (String prefix : filterSpec.trim().split("[\\s,]+")) {
        if (prefix.isEmpty()) {
            continue;
        }
        for (TreeTrait trait : traits) {
            if (trait.getTraitName().startsWith(prefix) && !filteredTraits.contains(trait)) {
                filteredTraits.add(trait);
            }
        }
    }
    return filteredTraits;
}
Also used: Likelihood (dr.inference.model.Likelihood), ArrayList (java.util.ArrayList), Loggable (dr.inference.loggers.Loggable), List (java.util.List), PrintWriter (java.io.PrintWriter), TabDelimitedFormatter (dr.inference.loggers.TabDelimitedFormatter), Model (dr.inference.model.Model)

Example 3 with Model

Use of dr.inference.model.Model in project beast-mcmc by beast-dev.

Class DummyLikelihoodParser, method parseXMLObject.

/**
 * Builds a DummyLikelihood around the element's Model child, or around a fresh DefaultModel
 * when no model is given.
 *
 * The Parameter child, if any, is attached as a variable only when the model is a
 * DefaultModel. Previously the model was cast to DefaultModel unconditionally, which threw a
 * ClassCastException for any other Model. A missing parameter was also passed to addVariable
 * as null.
 *
 * @param xo the dummy-likelihood XML element
 * @return the constructed DummyLikelihood
 */
public Object parseXMLObject(XMLObject xo) {
    Model model = (Model) xo.getChild(Model.class);
    Parameter parameter = (Parameter) xo.getChild(Parameter.class);
    if (model == null) {
        model = new DefaultModel();
    }
    final DummyLikelihood likelihood = new DummyLikelihood(model);
    // Only a DefaultModel exposes addVariable here; other models own their variables already.
    if (parameter != null && model instanceof DefaultModel) {
        ((DefaultModel) model).addVariable(parameter);
    }
    return likelihood;
}
Also used: DefaultModel (dr.inference.model.DefaultModel), Model (dr.inference.model.Model), Parameter (dr.inference.model.Parameter), DummyLikelihood (dr.inference.model.DummyLikelihood)

Example 4 with Model

Use of dr.inference.model.Model in project beast-mcmc by beast-dev.

Class BeastCheckpointer, method writeStateToFile.

/**
 * Writes the current MCMC state to a tab-delimited checkpoint file. The file contains, in order:
 * <ul>
 * <li>an "rng" line with the global random number generator state;</li>
 * <li>"state" and "lnL" lines;</li>
 * <li>one "parameter" line per connected parameter (name, dimension, values);</li>
 * <li>one "operator" line per scheduled operator (name, accept count, reject count and, for
 * coercable operators, the coercable parameter);</li>
 * <li>one "tree" section per connected TreeModel.</li>
 * </ul>
 *
 * @param file        destination file (overwritten)
 * @param state       current state (iteration) number
 * @param lnL         current log likelihood
 * @param markovChain chain whose operator schedule is saved
 * @return true on success; false if the file could not be opened or a write failed
 * @throws RuntimeException if a tree node has more than two children
 */
private boolean writeStateToFile(File file, long state, double lnL, MarkovChain markovChain) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    // Closing the PrintStream also closes the FileOutputStream. try-with-resources guarantees
    // this even when a RuntimeException escapes (e.g. a non-binary tree).
    try (PrintStream out = new PrintStream(new FileOutputStream(file))) {
        ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>();
        int[] rngState = MathUtils.getRandomState();
        out.print("rng");
        for (int i = 0; i < rngState.length; i++) {
            out.print("\t");
            out.print(rngState[i]);
        }
        out.println();
        out.print("state\t");
        out.println(state);
        out.print("lnL\t");
        out.println(lnL);
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            out.print("parameter");
            out.print("\t");
            out.print(parameter.getParameterName());
            out.print("\t");
            out.print(parameter.getDimension());
            for (int dim = 0; dim < parameter.getDimension(); dim++) {
                out.print("\t");
                out.print(parameter.getParameterValue(dim));
            }
            out.println();
        }
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            out.print("operator");
            out.print("\t");
            out.print(operator.getOperatorName());
            out.print("\t");
            out.print(operator.getAcceptCount());
            out.print("\t");
            out.print(operator.getRejectCount());
            if (operator instanceof CoercableMCMCOperator) {
                out.print("\t");
                out.print(((CoercableMCMCOperator) operator).getCoercableParameter());
            }
            out.println();
        }
        //check up front if there are any TreeParameterModel objects
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeParameterModel) {
                traitModels.add((TreeParameterModel) model);
            }
        }
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                writeTreeState(out, (TreeModel) model, traitModels);
            }
        }
        // PrintStream never throws IOException from print/println. It records failures
        // (e.g. disk full) internally, so they must be polled or a truncated checkpoint would
        // be reported as success.
        if (out.checkError()) {
            System.err.println("Unable to write file: " + file.getPath());
            return false;
        }
    } catch (IOException ioe) {
        System.err.println("Unable to write file: " + ioe.getMessage());
        return false;
    }
    if (DEBUG) {
        for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
            System.err.println(likelihood.getId() + ": " + likelihood.getLogLikelihood());
        }
    }
    return true;
}

/**
 * Writes one "tree" section: the header line, the node list, then one line per non-root node.
 * Each node line holds number, height, and taxon id for external nodes. Each edge line holds
 * child number, parent number, 0/1 for left/right child, and one value per trait model.
 */
private void writeTreeState(PrintStream out, TreeModel tree, ArrayList<TreeParameterModel> traitModels) {
    out.print("tree");
    out.print("\t");
    out.println(tree.getModelName());
    //replace Newick format by printing general graph structure
    out.println("#node height taxon");
    int nodeCount = tree.getNodeCount();
    out.println(nodeCount);
    for (int i = 0; i < nodeCount; i++) {
        NodeRef node = tree.getNode(i);
        out.print(node.getNumber());
        out.print("\t");
        out.print(tree.getNodeHeight(node));
        if (tree.isExternal(node)) {
            out.print("\t");
            out.print(tree.getNodeTaxon(node).getId());
        }
        out.println();
    }
    out.println("#edges");
    out.println("#child-node parent-node L/R-child traits");
    // NOTE: the announced count is nodeCount although the root gets no edge line. Existing
    // readers expect this value, so the format is kept as-is.
    out.println(nodeCount);
    for (int i = 0; i < nodeCount; i++) {
        NodeRef node = tree.getNode(i);
        NodeRef parent = tree.getParent(node);
        if (parent == null) {
            continue;
        }
        out.print(node.getNumber());
        out.print("\t");
        out.print(parent.getNumber());
        out.print("\t");
        if (tree.getChild(parent, 0) == node) {
            //left child
            out.print(0);
        } else if (tree.getChild(parent, 1) == node) {
            //right child
            out.print(1);
        } else {
            throw new RuntimeException("Operation currently only supported for nodes with 2 children.");
        }
        for (TreeParameterModel tpm : traitModels) {
            out.print("\t");
            out.print(tpm.getNodeValue(tree, node));
        }
        out.println();
    }
}
Also used: OperatorSchedule (dr.inference.operators.OperatorSchedule), Likelihood (dr.inference.model.Likelihood), TreeParameterModel (dr.evomodel.tree.TreeParameterModel), TreeModel (dr.evomodel.tree.TreeModel), NodeRef (dr.evolution.tree.NodeRef), Model (dr.inference.model.Model), Parameter (dr.inference.model.Parameter), CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator), MCMCOperator (dr.inference.operators.MCMCOperator)

Example 5 with Model

Use of dr.inference.model.Model in project beast-mcmc by beast-dev.

Class CheckPointModifier, method readStateFromFile.

/**
 * Reads a checkpoint file and restores the saved state into the connected model objects.
 * <p>
 * The file is consumed in this order: an optional {@code rng} line, a {@code state} line, an
 * {@code lnL} line, one line per saved parameter, one line per operator in the schedule, and
 * finally zero or more {@code tree} sections.
 * <ul>
 * <li>Parameter values are written into matching parameters of
 * {@code Parameter.CONNECTED_PARAMETER_SET}.</li>
 * <li>Operator accept/reject counts (and the coercable parameter where applicable) are restored.</li>
 * <li>Every {@link TreeParameterModel} in {@code Model.CONNECTED_MODEL_SET} is collected into
 * {@code this.traitModels}, and the last {@link BranchRates} found is stored in {@code this.rateModel}.</li>
 * <li>Each tree section is applied to the {@link TreeModel} of the same name through a new
 * {@code CheckPointTreeModifier}, which is kept in {@code this.modifyTree}.</li>
 * </ul>
 *
 * @param file        the checkpoint file to read
 * @param markovChain supplies the operator schedule whose operators are restored
 * @param lnL         if non-null, element 0 receives the log likelihood stored in the file
 * @return the state number stored in the file
 * @throws RuntimeException if the file cannot be read or parsed, if an operator does not match
 *                          the schedule, or if a tree model in the connected set has no tree
 *                          section in the file
 */
private long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) {
    OperatorSchedule operatorSchedule = markovChain.getSchedule();
    long state = -1;
    this.traitModels = new ArrayList<TreeParameterModel>();
    // declared outside the try so the finally block can release the file handle on every exit path
    BufferedReader in = null;
    try {
        in = new BufferedReader(new FileReader(file));
        int[] rngState = null;
        String line = in.readLine();
        String[] fields = line.split("\t");
        if (fields[0].equals("rng")) {
            // if there is a random number generator state present then parse it
            // (rngState is not used anywhere else in this method)
            try {
                rngState = new int[fields.length - 1];
                for (int i = 0; i < rngState.length; i++) {
                    rngState[i] = Integer.parseInt(fields[i + 1]);
                }
            } catch (NumberFormatException nfe) {
                throw new RuntimeException("Unable to read state number from state file", nfe);
            }
            line = in.readLine();
            fields = line.split("\t");
        }
        try {
            if (!fields[0].equals("state")) {
                throw new RuntimeException("Unable to read state number from state file");
            }
            state = Long.parseLong(fields[1]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read state number from state file", nfe);
        }
        line = in.readLine();
        fields = line.split("\t");
        try {
            if (!fields[0].equals("lnL")) {
                throw new RuntimeException("Unable to read lnL from state file");
            }
            if (lnL != null) {
                lnL[0] = Double.parseDouble(fields[1]);
            }
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Unable to read lnL from state file", nfe);
        }

        line = in.readLine();
        // End of file maps to an empty array, so the sections below can detect it
        // instead of failing with a NullPointerException on line.split
        fields = (line == null) ? new String[0] : line.split("\t");
        //Tree nodes have numbers as parameter ids
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            //numbers should be positive but can include zero
            // A line must carry at least a name field to match. Then either both ids are
            // tree-node ids, or the parameter id equals the name stored in the file.
            // NOTE(review): isTreeNode presumably tests for a numeric id, so any numerically
            // named parameter matches the next numerically named line; confirm this is intended.
            if (fields.length > 1
                    && (isTreeNode(parameter.getId()) && isTreeNode(fields[1])
                    || parameter.getId().equals(fields[1]))) {
                int dimension = Integer.parseInt(fields[2]);
                if (dimension != parameter.getDimension() && !fields[1].equals("branchRates.categories")) {
                    System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                    System.err.print("Read from file: ");
                    for (int i = 0; i < fields.length; i++) {
                        System.err.print(fields[i] + "\t");
                    }
                    System.err.println();
                }
                if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                    double value = Double.parseDouble(fields[3]);
                    parameter.setParameterValue(0, value);
                    if (DEBUG) {
                        System.out.println("restoring " + fields[1] + " with value " + value);
                    }
                } else {
                    if (DEBUG) {
                        System.out.print("restoring " + fields[1] + " with values ");
                    }
                    if (fields[1].equals("branchRates.categories")) {
                        // the number of values is taken from the file, not from the parameter
                        for (int dim = 0; dim < (fields.length - 3); dim++) {
                            parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
                            if (DEBUG) {
                                System.out.print(Double.parseDouble(fields[dim + 3]) + " ");
                            }
                        }
                    } else {
                        for (int dim = 0; dim < parameter.getDimension(); dim++) {
                            parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3]));
                            if (DEBUG) {
                                System.out.print(Double.parseDouble(fields[dim + 3]) + " ");
                            }
                        }
                    }
                    if (DEBUG) {
                        System.out.println();
                    }
                }
                line = in.readLine();
                fields = (line == null) ? new String[0] : line.split("\t");
            }
            // otherwise: there will be more parameters in the connected set than there are lines
            // in the checkpoint file, so just keep iterating over the parameters in the connected set
        }
        //No changes needed for loading in operators
        for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) {
            MCMCOperator operator = operatorSchedule.getOperator(i);
            if (fields.length < 2) {
                throw new RuntimeException("Unable to match operator: " + operator.getOperatorName()
                        + " (line missing or truncated in state file)");
            }
            if (!fields[1].equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match operator: " + fields[1]);
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + fields[1]);
            }
            operator.setAcceptCount(Integer.parseInt(fields[2]));
            operator.setRejectCount(Integer.parseInt(fields[3]));
            if (operator instanceof CoercableMCMCOperator) {
                if (fields.length != 5) {
                    throw new RuntimeException("Coercable operator missing parameter: " + fields[1]);
                }
                ((CoercableMCMCOperator) operator).setCoercableParameter(Double.parseDouble(fields[4]));
            }
            line = in.readLine();
            // the operator section may be the last section in the file (e.g. no tree sections)
            fields = (line == null) ? new String[0] : line.split("\t");
        }
        // load the tree models last as we get the node heights from the tree (not the parameters which
        // which may not be associated with the right node
        Set<String> expectedTreeModelNames = new HashSet<String>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                expectedTreeModelNames.add(model.getModelName());
            }
            if (model instanceof TreeParameterModel) {
                this.traitModels.add((TreeParameterModel) model);
            }
            if (model instanceof BranchRates) {
                // if several BranchRates are connected, the last one iterated wins
                this.rateModel = (BranchRates) model;
            }
        }
        while (fields.length > 0 && fields[0].equals("tree")) {
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (model instanceof TreeModel && fields[1].equals(model.getModelName())) {
                    //AR: Can we not just add them to a Flexible tree and then make a new TreeModel
                    //taking that in the constructor?
                    //internally, we have a tree with all the taxa
                    //externally, i.e. in the checkpoint file, we have a tree representation comprising
                    //a subset of the full taxa set
                    //write method that adjusts the internal representation, i.e. the one in the connected
                    //set, according to the checkpoint file and a distance-based approach to position
                    //the additional taxa
                    //first read in all the data from the checkpoint file
                    line = in.readLine();
                    line = in.readLine();
                    fields = line.split("\t");
                    //read number of nodes
                    int nodeCount = Integer.parseInt(fields[0]);
                    double[] nodeHeights = new double[nodeCount];
                    // assumes a binary tree: n taxa give 2n - 1 nodes, and taxa are the first lines
                    String[] taxaNames = new String[(nodeCount + 1) / 2];
                    for (int i = 0; i < nodeCount; i++) {
                        line = in.readLine();
                        fields = line.split("\t");
                        nodeHeights[i] = Double.parseDouble(fields[1]);
                        if (i < taxaNames.length) {
                            taxaNames[i] = fields[2];
                        }
                    }
                    //on to reading edge information
                    line = in.readLine();
                    line = in.readLine();
                    line = in.readLine();
                    fields = line.split("\t");
                    int edgeCount = Integer.parseInt(fields[0]);
                    //create data matrix of doubles to store information from list of TreeParameterModels
                    double[][] traitValues = new double[traitModels.size()][edgeCount];
                    //create array to store whether a node is left or right child of its parent
                    //can be important for certain tree transition kernels
                    int[] childOrder = new int[edgeCount];
                    for (int i = 0; i < childOrder.length; i++) {
                        childOrder[i] = -1;
                    }
                    int[] parents = new int[edgeCount];
                    for (int i = 0; i < edgeCount; i++) {
                        parents[i] = -1;
                    }
                    for (int i = 0; i < edgeCount; i++) {
                        line = in.readLine();
                        // tolerates fewer edge lines than edgeCount
                        if (line != null) {
                            fields = line.split("\t");
                            // NOTE(review): parents is indexed by the node number in fields[0],
                            // while childOrder and traitValues are indexed by line position i.
                            // These agree only if edge lines are written in node order with no
                            // gaps; confirm against CheckPointTreeModifier.
                            parents[Integer.parseInt(fields[0])] = Integer.parseInt(fields[1]);
                            childOrder[i] = Integer.parseInt(fields[2]);
                            for (int j = 0; j < traitModels.size(); j++) {
                                traitValues[j][i] = Double.parseDouble(fields[3 + j]);
                            }
                        }
                    }
                    //perform magic with the acquired information
                    //CheckPointTreeModifier modifyTree = new CheckPointTreeModifier((TreeModel) model);
                    this.modifyTree = new CheckPointTreeModifier((TreeModel) model);
                    modifyTree.adoptTreeStructure(parents, nodeHeights, childOrder, taxaNames);
                    if (traitModels.size() > 0) {
                        modifyTree.adoptTraitData(parents, this.traitModels, traitValues);
                    }
                    //adopt the loaded tree structure; this does not yet copy the traits on the branches
                    //((TreeModel) model).beginTreeEdit();
                    //((TreeModel) model).adoptTreeStructure(parents, nodeHeights, childOrder);
                    //((TreeModel) model).endTreeEdit();
                    expectedTreeModelNames.remove(model.getModelName());
                    // fields now holds edge data rather than the tree header, so stop matching models
                    break;
                }
            }
            line = in.readLine();
            // on end of file stop, instead of re-testing a stale "tree" line forever
            fields = (line == null) ? new String[0] : line.split("\t");
        }
        if (expectedTreeModelNames.size() > 0) {
            StringBuilder sb = new StringBuilder();
            for (String notFoundName : expectedTreeModelNames) {
                sb.append("Expecting, but unable to match state parameter:" + notFoundName + "\n");
            }
            throw new RuntimeException(sb.toString());
        }
        // closing the BufferedReader also closes the underlying FileReader
        in.close();
        // closed successfully, so the finally block has nothing left to release
        in = null;
    } catch (IOException ioe) {
        throw new RuntimeException("Unable to read file: " + ioe.getMessage(), ioe);
    } finally {
        if (in != null) {
            try {
                in.close();
            } catch (IOException ignored) {
                // best-effort cleanup on the failure path; the original exception is already propagating
            }
        }
    }
    return state;
}
Also used : TreeModel(dr.evomodel.tree.TreeModel) FileReader(java.io.FileReader) HashSet(java.util.HashSet) OperatorSchedule(dr.inference.operators.OperatorSchedule) TreeParameterModel(dr.evomodel.tree.TreeParameterModel) IOException(java.io.IOException) BufferedReader(java.io.BufferedReader) TreeParameterModel(dr.evomodel.tree.TreeParameterModel) Model(dr.inference.model.Model) TreeModel(dr.evomodel.tree.TreeModel) Parameter(dr.inference.model.Parameter) CoercableMCMCOperator(dr.inference.operators.CoercableMCMCOperator) BranchRates(dr.evolution.tree.BranchRates) MCMCOperator(dr.inference.operators.MCMCOperator) CoercableMCMCOperator(dr.inference.operators.CoercableMCMCOperator)

Aggregations

Model (dr.inference.model.Model)7 Parameter (dr.inference.model.Parameter)5 OperatorSchedule (dr.inference.operators.OperatorSchedule)4 TreeModel (dr.evomodel.tree.TreeModel)3 TreeParameterModel (dr.evomodel.tree.TreeParameterModel)3 Likelihood (dr.inference.model.Likelihood)3 CoercableMCMCOperator (dr.inference.operators.CoercableMCMCOperator)3 MCMCOperator (dr.inference.operators.MCMCOperator)3 CompoundLikelihood (dr.inference.model.CompoundLikelihood)2 ArrayList (java.util.ArrayList)2 BranchRates (dr.evolution.tree.BranchRates)1 NodeRef (dr.evolution.tree.NodeRef)1 GibbsIndependentCoalescentOperator (dr.evomodel.continuous.GibbsIndependentCoalescentOperator)1 Loggable (dr.inference.loggers.Loggable)1 Logger (dr.inference.loggers.Logger)1 TabDelimitedFormatter (dr.inference.loggers.TabDelimitedFormatter)1 MarkovChain (dr.inference.markovchain.MarkovChain)1 MCMC (dr.inference.mcmc.MCMC)1 MCMCOptions (dr.inference.mcmc.MCMCOptions)1 DefaultModel (dr.inference.model.DefaultModel)1