Search in sources:

Example 1 with TreeTrait

Use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.

Class MarkovJumpsBeagleTreeLikelihood, method addRegister.

/**
 * Adds a Markov-jumps register, creating one register per substitution model supplied by
 * {@code substitutionModelDelegate}, and exposes the results as tree traits.
 *
 * <p>For each substitution model {@code i} this method:
 * <ul>
 *   <li>wraps the model in a {@link UniformizedSubstitutionModel} (when {@code useUniformization})
 *       or a {@link MarkovJumpsSubstitutionModel};</li>
 *   <li>names the register {@code tag} (single model) or {@code tag + i} (several models), where
 *       {@code tag} is the register parameter's id;</li>
 *   <li>allocates an expected-jumps buffer of {@code [nodeCount][patternCount]};</li>
 *   <li>appends {@code scaleByTime} to {@code this.scaleByTime};</li>
 *   <li>registers either count/reward traits (non-HISTORY types) or history traits (HISTORY).</li>
 * </ul>
 *
 * @param addRegisterParameter register parameter. Its dimension must be
 *                             {@code stateCount * stateCount} for COUNTS and {@code stateCount}
 *                             for REWARDS. Its id is used as the trait tag.
 * @param type                 kind of register to record
 * @param scaleByTime          flag appended to {@code this.scaleByTime} once per substitution model
 * @throws RuntimeException if the register parameter has the wrong dimension, if HISTORY is
 *                          requested without uniformization, if a second complete history is
 *                          requested, or if HISTORY is requested with more than one simulant
 */
public void addRegister(Parameter addRegisterParameter, MarkovJumpsType type, boolean scaleByTime) {
    if ((type == MarkovJumpsType.COUNTS && addRegisterParameter.getDimension() != stateCount * stateCount) || (type == MarkovJumpsType.REWARDS && addRegisterParameter.getDimension() != stateCount)) {
        throw new RuntimeException("Register parameter of wrong dimension");
    }
    addVariable(addRegisterParameter);
    final String tag = addRegisterParameter.getId();
    for (int i = 0; i < substitutionModelDelegate.getSubstitutionModelCount(); ++i) {
        registerParameter.add(addRegisterParameter);
        MarkovJumpsSubstitutionModel mjModel;
        SubstitutionModel substitutionModel = substitutionModelDelegate.getSubstitutionModel(i);
        if (useUniformization) {
            mjModel = new UniformizedSubstitutionModel(substitutionModel, type, nSimulants);
        } else {
            if (type == MarkovJumpsType.HISTORY) {
                throw new RuntimeException("Can only report complete history using uniformization");
            }
            mjModel = new MarkovJumpsSubstitutionModel(substitutionModel, type);
        }
        markovjumps.add(mjModel);
        branchModelNumber.add(i);
        addModel(mjModel);
        setupRegistration(numRegisters);
        String traitName;
        if (substitutionModelDelegate.getSubstitutionModelCount() == 1) {
            traitName = tag;
        } else {
            traitName = tag + i;
        }
        jumpTag.add(traitName);
        expectedJumps.add(new double[treeModel.getNodeCount()][patternCount]);
        //        storedExpectedJumps.add(new double[treeModel.getNodeCount()][patternCount]);
        // Grow this.scaleByTime by one slot, keeping existing flags, and record this register's flag.
        boolean[] oldScaleByTime = this.scaleByTime;
        int oldScaleByTimeLength = (oldScaleByTime == null ? 0 : oldScaleByTime.length);
        this.scaleByTime = new boolean[oldScaleByTimeLength + 1];
        if (oldScaleByTimeLength > 0) {
            System.arraycopy(oldScaleByTime, 0, this.scaleByTime, 0, oldScaleByTimeLength);
        }
        this.scaleByTime[oldScaleByTimeLength] = scaleByTime;
        if (type != MarkovJumpsType.HISTORY) {
            TreeTrait.DA da = new TreeTrait.DA() {

                // Capture the register index at creation time; numRegisters is incremented below.
                final int registerNumber = numRegisters;

                // NOTE(review): reports `tag`, not the per-model `traitName` used as the "_base" key.
                public String getTraitName() {
                    return tag;
                }

                public Intent getIntent() {
                    return Intent.BRANCH;
                }

                public double[] getTrait(Tree tree, NodeRef node) {
                    return getMarkovJumpsForNodeAndRegister(tree, node, registerNumber);
                }
            };
            treeTraits.addTrait(traitName + "_base", da);
            // NOTE(review): this key is the same for every substitution model; confirm how
            // treeTraits.addTrait handles a repeated key when there are several models.
            treeTraits.addTrait(addRegisterParameter.getId(), new TreeTrait.SumAcrossArrayD(new TreeTrait.SumOverTreeDA(da)));
        } else {
            if (histories == null) {
                histories = new String[treeModel.getNodeCount()][patternCount];
            } else {
                throw new RuntimeException("Only one complete history per markovJumpTreeLikelihood is allowed");
            }
            if (nSimulants > 1) {
                throw new RuntimeException("Only one simulant allowed when saving complete history");
            }
            // Add total number of changes over tree trait
            TreeTrait da = new TreeTrait.DA() {

                // Capture the register index at creation time; numRegisters is incremented below.
                final int registerNumber = numRegisters;

                public String getTraitName() {
                    return tag;
                }

                public Intent getIntent() {
                    return Intent.BRANCH;
                }

                public double[] getTrait(Tree tree, NodeRef node) {
                    return getMarkovJumpsForNodeAndRegister(tree, node, registerNumber);
                }
            };
            treeTraits.addTrait(addRegisterParameter.getId(), new TreeTrait.SumOverTreeDA(da));
            // Record the complete history for this register
            historyRegisterNumber = numRegisters;
            ((UniformizedSubstitutionModel) mjModel).setSaveCompleteHistory(true);
            if (useCompactHistory && logHistory) {
                treeTraits.addTrait(ALL_HISTORY, new TreeTrait.SA() {

                    public String getTraitName() {
                        return ALL_HISTORY;
                    }

                    public Intent getIntent() {
                        return Intent.BRANCH;
                    }

                    public boolean getFormatAsArray() {
                        return true;
                    }

                    /**
                     * Flattens the per-site history strings of a branch into one array of events.
                     * Null and empty ("{}") histories are skipped. The outer braces are stripped,
                     * and multi-event strings are split at each "},{" boundary.
                     */
                    public String[] getTrait(Tree tree, NodeRef node) {
                        List<String> events = new ArrayList<String>();
                        for (int patternIndex = 0; patternIndex < patternCount; patternIndex++) {
                            String eventString = getHistoryForNode(tree, node, patternIndex);
                            if (eventString != null && !"{}".equals(eventString)) {
                                eventString = eventString.substring(1, eventString.length() - 1);
                                if (eventString.contains("},{")) {
                                    // There are multiple events
                                    String[] elements = eventString.split("(?<=\\}),(?=\\{)");
                                    for (String e : elements) {
                                        events.add(e);
                                    }
                                } else {
                                    events.add(eventString);
                                }
                            }
                        }
                        String[] array = new String[events.size()];
                        events.toArray(array);
                        return array;
                    }

                    public boolean getLoggable() {
                        return true;
                    }
                });
            }
            // One string trait per site pattern; the name is suffixed with the 1-based site index
            // unless there is only a single pattern.
            for (int site = 0; site < patternCount; ++site) {
                final String anonName = (patternCount == 1) ? HISTORY : HISTORY + "_" + (site + 1);
                final int anonSite = site;
                treeTraits.addTrait(anonName, new TreeTrait.S() {

                    public String getTraitName() {
                        return anonName;
                    }

                    public Intent getIntent() {
                        return Intent.BRANCH;
                    }

                    public String getTrait(Tree tree, NodeRef node) {
                        String history = getHistoryForNode(tree, node, anonSite);
                        // Return null if empty or missing. The null check mirrors the ALL_HISTORY
                        // trait above; the previous compareTo call threw an NPE on a null history.
                        return "{}".equals(history) ? null : history;
                    }

                    public boolean getLoggable() {
                        return logHistory && !useCompactHistory;
                    }
                });
            }
        }
        numRegisters++;
    }
    // End of loop over branch models
}
Also used: MarkovJumpsSubstitutionModel (dr.evomodel.substmodel.MarkovJumpsSubstitutionModel), UniformizedSubstitutionModel (dr.evomodel.substmodel.UniformizedSubstitutionModel), SubstitutionModel (dr.evomodel.substmodel.SubstitutionModel), TreeTrait (dr.evolution.tree.TreeTrait), NodeRef (dr.evolution.tree.NodeRef), Tree (dr.evolution.tree.Tree), PatternList (dr.evolution.alignment.PatternList)

Example 2 with TreeTrait

Use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.

Class DiscreteTraitBranchRateModelParser, method parseXMLObject.

/**
 * Parses a discrete-trait branch rate model element.
 *
 * <p>If no {@link TreeTraitProvider} child is present, the model reconstructs the trait from
 * the pattern list using parsimony. Otherwise the trait named {@code "states"} is taken from the
 * provider. The model also receives the relative rates and indicators when a relative-rates
 * element is given.
 *
 * @param xo the XML element being parsed
 * @return a new {@link DiscreteTraitBranchRateModel}
 * @throws XMLParseException if the provider does not supply the {@code "states"} trait, or if
 *                           reading any child element fails
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);
    PatternList patternList = (PatternList) xo.getChild(PatternList.class);
    TreeTraitProvider traitProvider = (TreeTraitProvider) xo.getChild(TreeTraitProvider.class);
    DataType dataType = DataTypeUtils.getDataType(xo);
    Parameter rateParameter = null;
    Parameter relativeRatesParameter = null;
    Parameter indicatorsParameter = null;
    // RATES is read after RATE, so it takes precedence when both are supplied.
    if (xo.getChild(RATE) != null) {
        rateParameter = (Parameter) xo.getElementFirstChild(RATE);
    }
    if (xo.getChild(RATES) != null) {
        rateParameter = (Parameter) xo.getElementFirstChild(RATES);
    }
    if (xo.getChild(RELATIVE_RATES) != null) {
        relativeRatesParameter = (Parameter) xo.getElementFirstChild(RELATIVE_RATES);
    }
    if (xo.getChild(INDICATORS) != null) {
        indicatorsParameter = (Parameter) xo.getElementFirstChild(INDICATORS);
    }
    // The XML attribute is 1-based; the model expects a 0-based index.
    int traitIndex = xo.getAttribute(TRAIT_INDEX, 1) - 1;
    // The trait name is fixed. The former per-state lookup (taken only when traitName == null)
    // was unreachable and has been removed.
    final String traitName = "states";
    Logger.getLogger("dr.evomodel").info("Using discrete trait branch rate model.\n" + "\tIf you use this model, please cite:\n" + "\t\tDrummond and Suchard (in preparation)");
    if (traitProvider == null) {
        // Use the version that reconstructs the trait using parsimony:
        return new DiscreteTraitBranchRateModel(treeModel, patternList, traitIndex, rateParameter);
    }
    TreeTrait trait = traitProvider.getTreeTrait(traitName);
    if (trait == null) {
        throw new XMLParseException("A trait called, " + traitName + ", was not available from the TreeTraitProvider supplied to " + getParserName() + ", with ID " + xo.getId());
    }
    if (relativeRatesParameter != null) {
        return new DiscreteTraitBranchRateModel(traitProvider, dataType, treeModel, trait, traitIndex, rateParameter, relativeRatesParameter, indicatorsParameter);
    }
    return new DiscreteTraitBranchRateModel(traitProvider, dataType, treeModel, trait, traitIndex, rateParameter);
}
Also used: TreeModel (dr.evomodel.tree.TreeModel), TreeTraitProvider (dr.evolution.tree.TreeTraitProvider), PatternList (dr.evolution.alignment.PatternList), DataType (dr.evolution.datatype.DataType), Parameter (dr.inference.model.Parameter), DiscreteTraitBranchRateModel (dr.evomodel.branchratemodel.DiscreteTraitBranchRateModel), TreeTrait (dr.evolution.tree.TreeTrait)

Example 3 with TreeTrait

Use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.

Class CodonPartitionedRobustCounting, method setupTraits.

/**
 * Builds the tree traits for this robust-counting instance, registers them in
 * {@code treeTraits}, and (re)creates {@code treeTraitLogger}.
 *
 * <p>Registered traits:
 * <ul>
 *   <li>{@code baseTrait}: per-branch expected counts from {@code getExpectedCountsForBranch};</li>
 *   <li>a complete-history string-array trait, only when {@code saveCompleteHistory};</li>
 *   <li>an unconditioned whole-tree trait: a scalar, or a per-site array when {@code TRIAL};</li>
 *   <li>sums of the base trait over the tree, over sites, and over both;</li>
 *   <li>unconditioned per-branch traits, only when {@code doUnconditionedPerBranch}.</li>
 * </ul>
 *
 * <p>The column logger always logs the sum over sites and tree. It also logs the unconditioned
 * equivalent when {@code doUnconditionedPerBranch} is set.
 */
private void setupTraits() {
    TreeTrait baseTrait = new TreeTrait.DA() {

        public String getTraitName() {
            return BASE_TRAIT_PREFIX + codonLabeling.getText();
        }

        public Intent getIntent() {
            return Intent.BRANCH;
        }

        public double[] getTrait(Tree tree, NodeRef node) {
            return getExpectedCountsForBranch(node);
        }

        public boolean getLoggable() {
            return false;
        }
    };
    if (saveCompleteHistory) {
        TreeTrait stringTrait = new TreeTrait.SA() {

            public String getTraitName() {
                return COMPLETE_HISTORY_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.BRANCH;
            }

            public boolean getFormatAsArray() {
                return true;
            }

            public String[] getTrait(Tree tree, NodeRef node) {
                // Lazy simulation of complete histories. This call is kept even when DEBUG is off,
                // presumably because it fills completeHistoryPerNode for this node (TODO confirm).
                double[] count = getExpectedCountsForBranch(node);
                List<String> events = new ArrayList<String>();
                for (int i = 0; i < numCodons; i++) {
                    String eventString = completeHistoryPerNode[node.getNumber()][i];
                    if (eventString != null) {
                        if (eventString.contains("},{")) {
                            // There are multiple events
                            String[] elements = eventString.split("(?<=\\}),(?=\\{)");
                            for (String e : elements) {
                                events.add(e);
                            }
                        } else {
                            events.add(eventString);
                        }
                    }
                }
                if (DEBUG) {
                    // Sanity check: the number of logged events should match the number of codons
                    // with a positive expected count.
                    double sum = 0.0;
                    for (double d : count) {
                        if (d > 0.0) {
                            sum += 1;
                        }
                    }
                    System.err.println(events.size() + " " + sum);
                    if (Math.abs(events.size() - sum) > 0.5) {
                        System.err.println("Error");
                        for (int i = 0; i < count.length; ++i) {
                            if (count[i] != 0.0) {
                                // Separate the count from the history so the two values stay readable.
                                System.err.println(i + ": " + count[i] + " " + completeHistoryPerNode[node.getNumber()][i]);
                            }
                        }
                        System.err.println("Error");
                        int c = 0;
                        for (String s : events) {
                            c++;
                            System.err.println(c + ":" + s);
                        }
                        System.exit(-1);
                    }
                }
                String[] array = new String[events.size()];
                events.toArray(array);
                return array;
            }

            public boolean getLoggable() {
                return true;
            }
        };
        treeTraits.addTrait(stringTrait);
    }
    TreeTrait unconditionedSum;
    if (!TRIAL) {
        unconditionedSum = new TreeTrait.D() {

            public String getTraitName() {
                return UNCONDITIONED_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.WHOLE_TREE;
            }

            public Double getTrait(Tree tree, NodeRef node) {
                return getUnconditionedTraitValue();
            }

            public boolean getLoggable() {
                return false;
            }
        };
    } else {
        unconditionedSum = new TreeTrait.DA() {

            public String getTraitName() {
                return UNCONDITIONED_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.WHOLE_TREE;
            }

            public double[] getTrait(Tree tree, NodeRef node) {
                return getUnconditionedTraitValues();
            }

            public boolean getLoggable() {
                return false;
            }
        };
    }
    TreeTrait sumOverTreeTrait = new TreeTrait.SumOverTreeDA(SITE_SPECIFIC_PREFIX + codonLabeling.getText(), baseTrait, includeExternalBranches, includeInternalBranches) {

        @Override
        public boolean getLoggable() {
            return false;
        }
    };
    // This should be the default output in tree logs
    TreeTrait sumOverSitesTrait = new TreeTrait.SumAcrossArrayD(codonLabeling.getText(), baseTrait) {

        @Override
        public boolean getLoggable() {
            return true;
        }
    };
    // This should be the default output in columns logs
    String name = prefix != null ? prefix + TOTAL_PREFIX + codonLabeling.getText() : TOTAL_PREFIX + codonLabeling.getText();
    TreeTrait sumOverSitesAndTreeTrait = new TreeTrait.SumOverTreeD(name, sumOverSitesTrait, includeExternalBranches, includeInternalBranches) {

        @Override
        public boolean getLoggable() {
            return true;
        }
    };
    treeTraitLogger = new TreeTraitLogger(tree, new TreeTrait[] { sumOverSitesAndTreeTrait });
    treeTraits.addTrait(baseTrait);
    treeTraits.addTrait(unconditionedSum);
    treeTraits.addTrait(sumOverSitesTrait);
    treeTraits.addTrait(sumOverTreeTrait);
    treeTraits.addTrait(sumOverSitesAndTreeTrait);
    if (doUnconditionedPerBranch) {
        TreeTrait unconditionedBase = new TreeTrait.DA() {

            public String getTraitName() {
                return UNCONDITIONED_PER_BRANCH_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.BRANCH;
            }

            public double[] getTrait(Tree tree, NodeRef node) {
                return getUnconditionalCountsForBranch(node);
            }

            public boolean getLoggable() {
                // TODO Should be switched to true to log unconditioned values per branch
                return false;
            }
        };
        TreeTrait sumUnconditionedOverSitesTrait = new TreeTrait.SumAcrossArrayD(UNCONDITIONED_PER_BRANCH_PREFIX + codonLabeling.getText(), unconditionedBase) {

            @Override
            public boolean getLoggable() {
                return true;
            }
        };
        String nameU = prefix != null ? prefix + UNCONDITIONED_TOTAL_PREFIX + codonLabeling.getText() : UNCONDITIONED_TOTAL_PREFIX + codonLabeling.getText();
        TreeTrait sumUnconditionedOverSitesAndTreeTrait = new TreeTrait.SumOverTreeD(nameU, sumUnconditionedOverSitesTrait, includeExternalBranches, includeInternalBranches) {

            @Override
            public boolean getLoggable() {
                return true;
            }
        };
        // Replace the logger so it also records the unconditioned totals.
        // NOTE(review): unlike its conditioned counterpart, this total is not added to treeTraits.
        treeTraitLogger = new TreeTraitLogger(tree, new TreeTrait[] { sumOverSitesAndTreeTrait, sumUnconditionedOverSitesAndTreeTrait });
        treeTraits.addTrait(unconditionedBase);
        treeTraits.addTrait(sumUnconditionedOverSitesTrait);
    }
}
Also used: TreeTraitLogger (dr.evomodel.treelikelihood.utilities.TreeTraitLogger), ArrayList (java.util.ArrayList), TreeTrait (dr.evolution.tree.TreeTrait), NodeRef (dr.evolution.tree.NodeRef), Tree (dr.evolution.tree.Tree)

Example 4 with TreeTrait

Use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.

Class TreeTraitLogger, method getColumns.

/**
 * Creates one log column for each loggable tree trait.
 *
 * <p>Each column is named after its trait and formats its value with
 * {@code trait.getTraitString(tree, null)}, passing no particular node.
 *
 * @return the columns, in the same order as {@code loggableTreeTraits}; {@code null} when no
 *         trait is loggable
 */
public LogColumn[] getColumns() {
    final int traitCount = loggableTreeTraits.size();
    if (traitCount == 0) {
        return null;
    }
    final LogColumn[] result = new LogColumn[traitCount];
    int index = 0;
    while (index < traitCount) {
        final TreeTrait loggedTrait = loggableTreeTraits.get(index);
        final String columnLabel = loggedTrait.getTraitName();
        result[index] = new LogColumn.Abstract(columnLabel) {

            @Override
            protected String getFormattedValue() {
                return loggedTrait.getTraitString(tree, null);
            }
        };
        index++;
    }
    return result;
}
Also used: LogColumn (dr.inference.loggers.LogColumn), TreeTrait (dr.evolution.tree.TreeTrait)

Example 5 with TreeTrait

Use of dr.evolution.tree.TreeTrait in project beast-mcmc by beast-dev.

Class AncestralTraitParser, method parseXMLObject.

/**
 * Parses an ancestral-trait element into an {@link AncestralTrait}.
 *
 * <p>The trait is looked up by the {@code TRAIT_NAME} attribute (default {@code STATES}) on the
 * supplied {@link TreeTraitProvider}. The {@code NAME} attribute defaults to the trait name. An
 * optional {@code MRCA} child supplies a taxon list.
 *
 * @param xo the XML element being parsed
 * @return a new {@link AncestralTrait}
 * @throws XMLParseException if no TreeTraitProvider is supplied, if the provider lacks the named
 *                           trait, or if a taxon of the MRCA list is missing from the tree
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String traitName = xo.getAttribute(TRAIT_NAME, STATES);
    String name = xo.getAttribute(NAME, traitName);
    Tree tree = (Tree) xo.getChild(Tree.class);
    TreeTraitProvider treeTraitProvider = (TreeTraitProvider) xo.getChild(TreeTraitProvider.class);
    TaxonList taxa = null;
    if (xo.hasChildNamed(MRCA)) {
        taxa = (TaxonList) xo.getElementFirstChild(MRCA);
    }
    String idSuffix = xo.hasId() ? ", with ID " + xo.getId() : "";
    // getChild returns null when the child is absent; report that as a parse error rather than an NPE.
    if (treeTraitProvider == null) {
        throw new XMLParseException("No TreeTraitProvider was supplied to " + getParserName() + idSuffix);
    }
    TreeTrait trait = treeTraitProvider.getTreeTrait(traitName);
    if (trait == null) {
        throw new XMLParseException("A trait called, " + traitName + ", was not available from the TreeTraitProvider supplied to " + getParserName() + idSuffix);
    }
    try {
        return new AncestralTrait(name, trait, tree, taxa);
    } catch (TreeUtils.MissingTaxonException mte) {
        // A space is needed between the parser name and "was"; it was previously missing.
        throw new XMLParseException("Taxon, " + mte + ", in " + getParserName() + " was not found in the tree.");
    }
}
Also used: TreeTraitProvider (dr.evolution.tree.TreeTraitProvider), TaxonList (dr.evolution.util.TaxonList), Tree (dr.evolution.tree.Tree), AncestralTrait (dr.evomodel.tree.AncestralTrait), TreeTrait (dr.evolution.tree.TreeTrait), TreeUtils (dr.evolution.tree.TreeUtils)

Aggregations (usage counts)

TreeTrait (dr.evolution.tree.TreeTrait): 8, Tree (dr.evolution.tree.Tree): 5, NodeRef (dr.evolution.tree.NodeRef): 3, TreeTraitProvider (dr.evolution.tree.TreeTraitProvider): 3, PatternList (dr.evolution.alignment.PatternList): 2, TreeModel (dr.evomodel.tree.TreeModel): 2, TreeTraitLogger (dr.evomodel.treelikelihood.utilities.TreeTraitLogger): 2, DataType (dr.evolution.datatype.DataType): 1, TreeUtils (dr.evolution.tree.TreeUtils): 1, TaxonList (dr.evolution.util.TaxonList): 1, DiscreteTraitBranchRateModel (dr.evomodel.branchratemodel.DiscreteTraitBranchRateModel): 1, CodonPartitionedRobustCounting (dr.evomodel.substmodel.CodonPartitionedRobustCounting): 1, MarkovJumpsSubstitutionModel (dr.evomodel.substmodel.MarkovJumpsSubstitutionModel): 1, SubstitutionModel (dr.evomodel.substmodel.SubstitutionModel): 1, UniformizedSubstitutionModel (dr.evomodel.substmodel.UniformizedSubstitutionModel): 1, AncestralTrait (dr.evomodel.tree.AncestralTrait): 1, DnDsLogger (dr.evomodel.treelikelihood.utilities.DnDsLogger): 1, LogColumn (dr.inference.loggers.LogColumn): 1, Parameter (dr.inference.model.Parameter): 1, ArrayList (java.util.ArrayList): 1