Search in sources :

Example 11 with TreeDataLikelihood

Use of dr.evomodel.treedatalikelihood.TreeDataLikelihood in project beast-mcmc by beast-dev.

The class FullyConjugateTreeTipsPotentialDerivativeParser, method parseXMLObject.

/**
 * Builds a tip-trait gradient from the XML element.
 *
 * <p>A {@code FullyConjugateMultivariateTraitLikelihood} child takes precedence. Otherwise a
 * {@code TreeDataLikelihood} child backed by a {@code ContinuousDataLikelihoodDelegate} is
 * required. An optional {@code MASKING} child supplies a mask parameter.
 *
 * @param xo the XML element being parsed
 * @return a {@code FullyConjugateTreeTipsPotentialDerivative} or a {@code TreeTipGradient}
 * @throws XMLParseException if no tree likelihood is given, or the tree data likelihood does
 *         not hold continuous data
 */
@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);

    final FullyConjugateMultivariateTraitLikelihood conjugateLikelihood =
            (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class);
    final TreeDataLikelihood dataLikelihood =
            (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);

    // The mask is optional; null means "no masking".
    final Parameter maskParameter = xo.hasChildNamed(MASKING)
            ? (Parameter) xo.getElementFirstChild(MASKING)
            : null;

    if (conjugateLikelihood != null) {
        return new FullyConjugateTreeTipsPotentialDerivative(conjugateLikelihood, maskParameter);
    }

    if (dataLikelihood == null) {
        throw new XMLParseException("Must provide a tree likelihood");
    }

    final DataLikelihoodDelegate likelihoodDelegate = dataLikelihood.getDataLikelihoodDelegate();
    // Tip-trait gradients are only defined for continuous-trait likelihoods.
    if (!(likelihoodDelegate instanceof ContinuousDataLikelihoodDelegate)) {
        throw new XMLParseException("May not provide a sequence data likelihood to compute tip trait gradient");
    }

    return new TreeTipGradient(traitName, dataLikelihood,
            (ContinuousDataLikelihoodDelegate) likelihoodDelegate, maskParameter);
}
Also used: TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate) DataLikelihoodDelegate(dr.evomodel.treedatalikelihood.DataLikelihoodDelegate) Parameter(dr.inference.model.Parameter) TreeTipGradient(dr.evomodel.treedatalikelihood.continuous.TreeTipGradient) FullyConjugateTreeTipsPotentialDerivative(dr.evomodel.continuous.hmc.FullyConjugateTreeTipsPotentialDerivative) FullyConjugateMultivariateTraitLikelihood(dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood)

Example 12 with TreeDataLikelihood

use of dr.evomodel.treedatalikelihood.TreeDataLikelihood in project beast-mcmc by beast-dev.

The class ContinuousDataLikelihoodParser, method parseXMLObject.

/**
 * Builds a continuous-trait {@code TreeDataLikelihood} from the XML element.
 *
 * <p>The trait data come either from a {@code TRAIT_PARAMETER} child (parsed from taxon
 * attributes) or from a {@code ContinuousTraitPartialsProvider} child. The diffusion process is
 * OU when both optimal-value models and a strength-of-selection matrix are given (or
 * {@code FORCE_OU} is set), drift when drift models are given (or {@code FORCE_DRIFT} is set),
 * and homogeneous Brownian otherwise. When {@code RECONSTRUCT_TRAITS} is true (default), trait
 * reconstruction providers are attached to the returned likelihood.
 *
 * @param xo the XML element being parsed
 * @return the configured {@code TreeDataLikelihood}
 * @throws XMLParseException if reciprocal rates are requested, or no trait data source is given
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    Tree treeModel = (Tree) xo.getChild(Tree.class);
    MultivariateDiffusionModel diffusionModel = (MultivariateDiffusionModel) xo.getChild(MultivariateDiffusionModel.class);
    BranchRateModel rateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);
    boolean useTreeLength = xo.getAttribute(USE_TREE_LENGTH, false);
    boolean scaleByTime = xo.getAttribute(SCALE_BY_TIME, false);
    boolean reciprocalRates = xo.getAttribute(RECIPROCAL_RATES, false);
    if (reciprocalRates) {
        throw new XMLParseException("Reciprocal rates are not yet implemented.");
    }
    if (rateModel == null) {
        rateModel = new DefaultBranchRateModel();
    }
    ContinuousRateTransformation rateTransformation = new ContinuousRateTransformation.Default(treeModel, scaleByTime, useTreeLength);
    final int dim = diffusionModel.getPrecisionmatrix().length;
    String traitName = TreeTraitParserUtilities.DEFAULT_TRAIT_NAME;
    ContinuousTraitPartialsProvider dataModel;
    // Stays true when the data model is supplied directly as a child element.
    boolean useMissingIndices = true;
    boolean integratedProcess = xo.getAttribute(INTEGRATED_PROCESS, false);

    if (xo.hasChildNamed(TreeTraitParserUtilities.TRAIT_PARAMETER)) {
        TreeTraitParserUtilities utilities = new TreeTraitParserUtilities();
        TreeTraitParserUtilities.TraitsAndMissingIndices returnValue = utilities.parseTraitsFromTaxonAttributes(xo, traitName, treeModel, true);
        CompoundParameter traitParameter = returnValue.traitParameter;
        List<Integer> missingIndices = returnValue.missingIndices;
        traitName = returnValue.traitName;
        useMissingIndices = returnValue.useMissingIndices;

        // Full precision is needed when forced, or when there is missing data that is not
        // declared completely missing.
        PrecisionType precisionType = PrecisionType.SCALAR;
        if (xo.getAttribute(FORCE_FULL_PRECISION, false) || (useMissingIndices && !xo.getAttribute(FORCE_COMPLETELY_MISSING, false))) {
            precisionType = PrecisionType.FULL;
        }
        if (xo.hasChildNamed(TreeTraitParserUtilities.JITTER)) {
            utilities.jitter(xo, dim, missingIndices);
        }
        if (!integratedProcess) {
            dataModel = new ContinuousTraitDataModel(traitName, traitParameter, missingIndices, useMissingIndices, dim, precisionType);
        } else {
            dataModel = new IntegratedProcessTraitDataModel(traitName, traitParameter, missingIndices, useMissingIndices, dim, precisionType);
        }
    } else {
        dataModel = (ContinuousTraitPartialsProvider) xo.getChild(ContinuousTraitPartialsProvider.class);
        // Fail with a parse error rather than an NPE on dataModel.getTraitDimension() below.
        if (dataModel == null) {
            throw new XMLParseException("Must provide either a " + TreeTraitParserUtilities.TRAIT_PARAMETER
                    + " element or a continuous trait data model");
        }
    }

    ConjugateRootTraitPrior rootPrior = ConjugateRootTraitPrior.parseConjugateRootTraitPrior(xo, dataModel.getTraitDimension());

    final boolean allowSingular;
    if (dataModel instanceof IntegratedFactorAnalysisLikelihood) {
        // Compare by value: traitName may be a distinct String parsed from the XML.
        if (TreeTraitParserUtilities.DEFAULT_TRAIT_NAME.equals(traitName)) {
            traitName = FACTOR_NAME;
        }
        // Factor models allow singular precision unless the XML says otherwise.
        allowSingular = xo.getAttribute(ALLOW_SINGULAR, true);
    } else if (dataModel instanceof RepeatedMeasuresTraitDataModel) {
        traitName = ((RepeatedMeasuresTraitDataModel) dataModel).getTraitName();
        allowSingular = xo.getAttribute(ALLOW_SINGULAR, false);
    } else {
        allowSingular = xo.getAttribute(ALLOW_SINGULAR, false);
    }

    List<BranchRateModel> driftModels = AbstractMultivariateTraitLikelihood.parseDriftModels(xo, diffusionModel);
    List<BranchRateModel> optimalTraitsModels = AbstractMultivariateTraitLikelihood.parseOptimalValuesModels(xo, diffusionModel);

    MultivariateElasticModel elasticModel = null;
    if (xo.hasChildNamed(STRENGTH_OF_SELECTION_MATRIX)) {
        XMLObject cxo = xo.getChild(STRENGTH_OF_SELECTION_MATRIX);
        MatrixParameterInterface strengthOfSelectionMatrixParam = (MatrixParameterInterface) cxo.getChild(MatrixParameterInterface.class);
        if (strengthOfSelectionMatrixParam != null) {
            elasticModel = new MultivariateElasticModel(strengthOfSelectionMatrixParam);
        }
    }

    DiffusionProcessDelegate diffusionProcessDelegate;
    if ((optimalTraitsModels != null && elasticModel != null) || xo.getAttribute(FORCE_OU, false)) {
        if (!integratedProcess) {
            diffusionProcessDelegate = new OUDiffusionModelDelegate(treeModel, diffusionModel, optimalTraitsModels, elasticModel);
        } else {
            diffusionProcessDelegate = new IntegratedOUDiffusionModelDelegate(treeModel, diffusionModel, optimalTraitsModels, elasticModel);
        }
    } else if (driftModels != null || xo.getAttribute(FORCE_DRIFT, false)) {
        diffusionProcessDelegate = new DriftDiffusionModelDelegate(treeModel, diffusionModel, driftModels);
    } else {
        diffusionProcessDelegate = new HomogeneousDiffusionModelDelegate(treeModel, diffusionModel);
    }

    ContinuousDataLikelihoodDelegate delegate = new ContinuousDataLikelihoodDelegate(treeModel, diffusionProcessDelegate, dataModel, rootPrior, rateTransformation, rateModel, allowSingular);
    if (dataModel instanceof IntegratedFactorAnalysisLikelihood) {
        ((IntegratedFactorAnalysisLikelihood) dataModel).setLikelihoodDelegate(delegate);
    }
    TreeDataLikelihood treeDataLikelihood = new TreeDataLikelihood(delegate, treeModel, rateModel);

    boolean reconstructTraits = xo.getAttribute(RECONSTRUCT_TRAITS, true);
    if (reconstructTraits) {
        // Realized-trait reconstruction is attached in every case; the scalar or multivariate
        // variant is chosen by the delegate's precision type.
        ProcessSimulationDelegate simulationDelegate = delegate.getPrecisionType() == PrecisionType.SCALAR
                ? new ConditionalOnTipsRealizedDelegate(traitName, treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, delegate)
                : new MultivariateConditionalOnTipsRealizedDelegate(traitName, treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, delegate);
        TreeTraitProvider traitProvider = new ProcessSimulation(treeDataLikelihood, simulationDelegate);
        treeDataLikelihood.addTraits(traitProvider.getTreeTraits());

        if (useMissingIndices) {
            // With missing data, also expose tip values drawn from their full conditionals.
            ProcessSimulationDelegate fullConditionalDelegate = new TipRealizedValuesViaFullConditionalDelegate(traitName, treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, delegate);
            treeDataLikelihood.addTraits(new ProcessSimulation(treeDataLikelihood, fullConditionalDelegate).getTreeTraits());
            // TODO: a ConditionalOnPartiallyMissingTipsDelegate trait (named via
            // getPartiallyMissingTraitName(traitName)) was once added here; it is currently disabled.
        }
    }
    return treeDataLikelihood;
}
Also used: MultivariateConditionalOnTipsRealizedDelegate(dr.evomodel.treedatalikelihood.preorder.MultivariateConditionalOnTipsRealizedDelegate) MultivariateElasticModel(dr.evomodel.continuous.MultivariateElasticModel) PrecisionType(dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType) DefaultBranchRateModel(dr.evomodel.branchratemodel.DefaultBranchRateModel) CompoundParameter(dr.inference.model.CompoundParameter) MultivariateDiffusionModel(dr.evomodel.continuous.MultivariateDiffusionModel) Tree(dr.evolution.tree.Tree) TreeTraitProvider(dr.evolution.tree.TreeTraitProvider) MatrixParameterInterface(dr.inference.model.MatrixParameterInterface) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) ProcessSimulation(dr.evomodel.treedatalikelihood.ProcessSimulation) TipRealizedValuesViaFullConditionalDelegate(dr.evomodel.treedatalikelihood.preorder.TipRealizedValuesViaFullConditionalDelegate) TreeTraitParserUtilities(dr.evomodelxml.treelikelihood.TreeTraitParserUtilities) ConditionalOnTipsRealizedDelegate(dr.evomodel.treedatalikelihood.preorder.ConditionalOnTipsRealizedDelegate) ProcessSimulationDelegate(dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate)

Example 13 with TreeDataLikelihood

use of dr.evomodel.treedatalikelihood.TreeDataLikelihood in project beast-mcmc by beast-dev.

The class TreeDataLikelihoodParser, method createTreeDataLikelihood.

/**
 * Builds the likelihood for one or more sequence partitions that share {@code treeModel}.
 *
 * <p>With more than one partition, the BEAGLE multi-partition delegate is tried when it is
 * recommended, or when it is forced via the {@code USE_BEAGLE3_EXTENSIONS} or
 * {@code beagle.multipartition.extensions} system properties, and {@code java.only} is not set.
 * If that delegate cannot be created, or is not selected, one {@code BeagleDataLikelihoodDelegate}
 * per partition is built. A single resulting likelihood is returned directly; several are
 * wrapped in a {@code CompoundLikelihood}.
 *
 * <p>If {@code BEAGLE_THREAD_COUNT} is unset and {@code THREAD_COUNT} is set, this method writes
 * {@code BEAGLE_THREAD_COUNT} as a side effect. The value is the full thread count for the
 * multi-partition path, or that count divided evenly across partitions otherwise.
 *
 * @throws XMLParseException if a tip-state model is supplied, no pattern lists are given, or the
 *         tree taxa differ from the taxa of the first pattern list
 */
protected Likelihood createTreeDataLikelihood(List<PatternList> patternLists, List<BranchModel> branchModels, List<SiteRateModel> siteRateModels, Tree treeModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean useAmbiguities, boolean preferGPU, PartialsRescalingScheme scalingScheme, boolean delayRescalingUntilUnderflow, PreOrderSettings settings) throws XMLParseException {
    if (tipStatesModel != null) {
        throw new XMLParseException("Tip State Error models are not supported yet with TreeDataLikelihood");
    }
    // Guard against an opaque IndexOutOfBoundsException from patternLists.get(0) below.
    if (patternLists == null || patternLists.isEmpty()) {
        throw new XMLParseException("TreeDataLikelihood requires at least one partition pattern list");
    }
    List<Taxon> treeTaxa = treeModel.asList();
    List<Taxon> patternTaxa = patternLists.get(0).asList();
    if (!patternTaxa.containsAll(treeTaxa)) {
        throw new XMLParseException("TreeModel " + treeModel.getId() + " contains more taxa (" + treeModel.getExternalNodeCount() + ") than the partition pattern list (" + patternTaxa.size() + ").");
    }
    if (!treeTaxa.containsAll(patternTaxa)) {
        throw new XMLParseException("TreeModel " + treeModel.getId() + " contains fewer taxa (" + treeModel.getExternalNodeCount() + ") than the partition pattern list (" + patternTaxa.size() + ").");
    }

    boolean useBeagle3MultiPartition = false;
    if (patternLists.size() > 1) {
        // will currently recommend true if using GPU, CUDA or OpenCL.
        useBeagle3MultiPartition = MultiPartitionDataLikelihoodDelegate.IS_MULTI_PARTITION_RECOMMENDED();
        String beagle3Extensions = System.getProperty("USE_BEAGLE3_EXTENSIONS");
        if (beagle3Extensions != null) {
            useBeagle3MultiPartition = Boolean.parseBoolean(beagle3Extensions);
        }
        // "auto" keeps the recommendation; any other value overrides it.
        String multiPartitionExtensions = System.getProperty("beagle.multipartition.extensions");
        if (multiPartitionExtensions != null && !multiPartitionExtensions.equals("auto")) {
            useBeagle3MultiPartition = Boolean.parseBoolean(multiPartitionExtensions);
        }
    }

    boolean useJava = Boolean.parseBoolean(System.getProperty("java.only", "false"));

    // -1 means "not specified". THREAD_COUNT is only consulted when BEAGLE_THREAD_COUNT is absent.
    int threadCount = -1;
    int beagleThreadCount = -1;
    String beagleThreadCountProperty = System.getProperty(BEAGLE_THREAD_COUNT);
    if (beagleThreadCountProperty != null) {
        beagleThreadCount = Integer.parseInt(beagleThreadCountProperty);
    }
    if (beagleThreadCount == -1) {
        String threadCountProperty = System.getProperty(THREAD_COUNT);
        if (threadCountProperty != null) {
            threadCount = Integer.parseInt(threadCountProperty);
        }
    }

    if (useBeagle3MultiPartition && !useJava) {
        if (beagleThreadCount == -1 && threadCount >= 0) {
            System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount));
        }
        try {
            DataLikelihoodDelegate dataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, patternLists, branchModels, siteRateModels, useAmbiguities, scalingScheme, delayRescalingUntilUnderflow);
            return new TreeDataLikelihood(dataLikelihoodDelegate, treeModel, branchRateModel);
        } catch (DataLikelihoodDelegate.DelegateTypeException ignored) {
            // Deliberate fallback: the multi-partition delegate is unavailable, so fall through
            // and build one single-partition delegate per partition below.
        }
    }

    // The multipartition data likelihood isn't available so make a set of single partition data likelihoods
    List<Likelihood> treeDataLikelihoods = new ArrayList<Likelihood>(patternLists.size());
    // Todo: allow for different number of threads per beagle instance according to pattern counts
    if (beagleThreadCount == -1 && threadCount >= 0) {
        System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount / patternLists.size()));
    }
    for (int i = 0; i < patternLists.size(); i++) {
        DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate(treeModel, patternLists.get(i), branchModels.get(i), siteRateModels.get(i), useAmbiguities, preferGPU, scalingScheme, delayRescalingUntilUnderflow, settings);
        treeDataLikelihoods.add(new TreeDataLikelihood(dataLikelihoodDelegate, treeModel, branchRateModel));
    }
    if (treeDataLikelihoods.size() == 1) {
        return treeDataLikelihoods.get(0);
    }
    return new CompoundLikelihood(treeDataLikelihoods);
}
Also used: CompoundLikelihood(dr.inference.model.CompoundLikelihood) Likelihood(dr.inference.model.Likelihood) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) Taxon(dr.evolution.util.Taxon) ArrayList(java.util.ArrayList) BeagleDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate) MultiPartitionDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate) DataLikelihoodDelegate(dr.evomodel.treedatalikelihood.DataLikelihoodDelegate)

Example 14 with TreeDataLikelihood

use of dr.evomodel.treedatalikelihood.TreeDataLikelihood in project beast-mcmc by beast-dev.

The class TraitValidationProviderParser, method parseTraitValidationProvider.

/**
 * Builds a {@code TraitValidationProvider} that compares true trait values against inferred ones.
 *
 * <p>True values are parsed from taxon attributes named by {@code TRAIT_NAME}. Inferred values
 * are identified by {@code INFERRED_NAME}. The {@code TreeDataLikelihood} child must be backed
 * by a {@code ContinuousDataLikelihoodDelegate}. An optional {@code MASK} child supplies a
 * missing-value parameter.
 *
 * @param xo the XML element being parsed
 * @return the configured validation provider
 * @throws XMLParseException if the tree data likelihood is absent or does not hold continuous
 *         trait data, or if required attributes are missing
 */
public static TraitValidationProvider parseTraitValidationProvider(XMLObject xo) throws XMLParseException {
    String trueValuesName = xo.getStringAttribute(TreeTraitParserUtilities.TRAIT_NAME);
    String inferredValuesName = xo.getStringAttribute(INFERRED_NAME);
    TreeDataLikelihood treeLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);
    // Report configuration errors as parse errors instead of NPE / ClassCastException.
    if (treeLikelihood == null) {
        throw new XMLParseException("Must provide a tree data likelihood for trait validation");
    }
    if (!(treeLikelihood.getDataLikelihoodDelegate() instanceof ContinuousDataLikelihoodDelegate)) {
        throw new XMLParseException("Trait validation requires a continuous trait data likelihood, not a sequence data likelihood");
    }
    ContinuousDataLikelihoodDelegate delegate = (ContinuousDataLikelihoodDelegate) treeLikelihood.getDataLikelihoodDelegate();
    ContinuousTraitPartialsProvider dataModel = delegate.getDataModel();
    Tree treeModel = treeLikelihood.getTree();

    TreeTraitParserUtilities utilities = new TreeTraitParserUtilities();
    TreeTraitParserUtilities.TraitsAndMissingIndices returnValue = utilities.parseTraitsFromTaxonAttributes(xo, trueValuesName, treeModel, true);
    Parameter trueParameter = returnValue.traitParameter;
    List<Integer> trueMissing = returnValue.missingIndices;

    // Optional mask; null when absent.
    Parameter missingParameter = null;
    if (xo.hasChildNamed(MASK)) {
        missingParameter = (Parameter) xo.getElementFirstChild(MASK);
    }

    String id = xo.getId();
    return new TraitValidationProvider(trueParameter, dataModel, treeModel, id, missingParameter, treeLikelihood, inferredValuesName, trueMissing);
}
Also used : TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) Tree(dr.evolution.tree.Tree) TreeTraitParserUtilities(dr.evomodelxml.treelikelihood.TreeTraitParserUtilities) Parameter(dr.inference.model.Parameter) ContinuousDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate) ContinuousTraitPartialsProvider(dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider) TraitValidationProvider(dr.inference.model.TraitValidationProvider)

Example 15 with TreeDataLikelihood

use of dr.evomodel.treedatalikelihood.TreeDataLikelihood in project beast-mcmc by beast-dev.

The class CheckPointTreeModifier, method incorporateAdditionalTaxa.

/**
 * Add the remaining taxa, which can be identified through the TreeDataLikelihood XML elements.
 */
public ArrayList<NodeRef> incorporateAdditionalTaxa(CheckPointUpdaterApp.UpdateChoice choice, BranchRates rateModel, ArrayList<TreeParameterModel> traitModels) {
    // public ArrayList<NodeRef> incorporateAdditionalTaxa(CheckPointUpdaterApp.UpdateChoice choice, BranchRates rateModel) {
    System.out.println("Tree before adding taxa:\n" + treeModel.toString() + "\n");
    ArrayList<NodeRef> newTaxaNodes = new ArrayList<NodeRef>();
    for (String str : newTaxaNames) {
        for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
            if (treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId().equals(str)) {
                newTaxaNodes.add(treeModel.getExternalNode(i));
                // always take into account Taxon dates vs. dates set through a TreeModel
                System.out.println(treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId() + " with height " + treeModel.getNodeHeight(treeModel.getExternalNode(i)) + " or " + treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getHeight());
            }
        }
    }
    System.out.println("newTaxaNodes length = " + newTaxaNodes.size());
    ArrayList<Taxon> currentTaxa = new ArrayList<Taxon>();
    for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
        boolean taxonFound = false;
        for (String str : newTaxaNames) {
            if (str.equals((treeModel.getNodeTaxon(treeModel.getExternalNode(i))).getId())) {
                taxonFound = true;
            }
        }
        if (!taxonFound) {
            System.out.println("Adding " + treeModel.getNodeTaxon(treeModel.getExternalNode(i)).getId() + " to list of current taxa");
            currentTaxa.add(treeModel.getNodeTaxon(treeModel.getExternalNode(i)));
        }
    }
    System.out.println("Current taxa count = " + currentTaxa.size());
    // iterate over both current taxa and to be added taxa
    boolean originTaxon = true;
    for (Taxon taxon : currentTaxa) {
        if (taxon.getHeight() == 0.0) {
            originTaxon = false;
            System.out.println("Current taxon " + taxon.getId() + " has node height 0.0");
        }
    }
    for (NodeRef newTaxon : newTaxaNodes) {
        if (treeModel.getNodeTaxon(newTaxon).getHeight() == 0.0) {
            originTaxon = false;
            System.out.println("New taxon " + treeModel.getNodeTaxon(newTaxon).getId() + " has node height 0.0");
        }
    }
    // check the Tree(Data)Likelihoods in the connected set of likelihoods
    // focus on TreeDataLikelihood, which has getTree() to get the tree for each likelihood
    // also get the DataLikelihoodDelegate from TreeDataLikelihood
    ArrayList<TreeDataLikelihood> likelihoods = new ArrayList<TreeDataLikelihood>();
    ArrayList<Tree> trees = new ArrayList<Tree>();
    ArrayList<DataLikelihoodDelegate> delegates = new ArrayList<DataLikelihoodDelegate>();
    for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) {
        if (likelihood instanceof TreeDataLikelihood) {
            likelihoods.add((TreeDataLikelihood) likelihood);
            trees.add(((TreeDataLikelihood) likelihood).getTree());
            delegates.add(((TreeDataLikelihood) likelihood).getDataLikelihoodDelegate());
        }
    }
    // suggested to go through TreeDataLikelihoodParser and give it an extra option to create a HashMap
    // keyed by the tree; am currently not overly fond of this approach
    ArrayList<PatternList> patternLists = new ArrayList<PatternList>();
    for (DataLikelihoodDelegate del : delegates) {
        if (del instanceof BeagleDataLikelihoodDelegate) {
            patternLists.add(((BeagleDataLikelihoodDelegate) del).getPatternList());
        } else if (del instanceof MultiPartitionDataLikelihoodDelegate) {
            MultiPartitionDataLikelihoodDelegate mpdld = (MultiPartitionDataLikelihoodDelegate) del;
            List<PatternList> list = mpdld.getPatternLists();
            for (PatternList pList : list) {
                patternLists.add(pList);
            }
        }
    }
    if (patternLists.size() == 0) {
        throw new RuntimeException("No patterns detected. Please make sure the XML file is BEAST 1.9 compatible.");
    }
    // aggregate all patterns to create distance matrix
    // TODO What about different trees for different partitions?
    Patterns patterns = new Patterns(patternLists.get(0));
    if (patternLists.size() > 1) {
        for (int i = 1; i < patternLists.size(); i++) {
            patterns.addPatterns(patternLists.get(i));
        }
    }
    // set the patterns for the distance matrix computations
    choice.setPatterns(patterns);
    // add new taxa one at a time
    System.out.println("Adding " + newTaxaNodes.size() + " taxa ...");
    if (NEW_APPROACH) {
        System.out.println("Branch rates are being updated after each new sequence is inserted");
        int numTaxaSoFar = treeModel.getExternalNodeCount() - newTaxaNodes.size();
        for (NodeRef newTaxon : newTaxaNodes) {
            // check for zero-length and negative length branches and internal nodes that don't have 2 child nodes
            if (DEBUG) {
                for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
                    NodeRef startingTip = treeModel.getExternalNode(i);
                    while (treeModel.getParent(startingTip) != null) {
                        if (treeModel.getChildCount(treeModel.getParent(startingTip)) != 2) {
                            System.out.println(treeModel.getChildCount(treeModel.getParent(startingTip)) + " children for node " + treeModel.getParent(startingTip));
                            System.out.println("Exiting ...");
                            System.exit(0);
                        }
                        double branchLength = treeModel.getNodeHeight(treeModel.getParent(startingTip)) - treeModel.getNodeHeight(startingTip);
                        if (branchLength == 0.0) {
                            System.out.println("Zero-length branch detected:");
                            System.out.println("  parent node: " + treeModel.getParent(startingTip));
                            System.out.println("  child node: " + startingTip);
                            System.out.println("Exiting ...");
                            System.exit(0);
                        } else if (branchLength < 0.0) {
                            System.out.println("Negative branch length detected:");
                            System.out.println("  parent node: " + treeModel.getParent(startingTip));
                            System.out.println("  child node: " + startingTip);
                            System.out.println("Exiting ...");
                            System.exit(0);
                        } else {
                            startingTip = treeModel.getParent(startingTip);
                        }
                    }
                }
            }
            treeModel.setNodeHeight(newTaxon, treeModel.getNodeTaxon(newTaxon).getHeight());
            System.out.println("\nadding Taxon: " + newTaxon + " (height = " + treeModel.getNodeHeight(newTaxon) + ")");
            // check if this taxon has a more recent sampling date than all other nodes in the current TreeModel
            double offset = checkCurrentTreeNodes(newTaxon, treeModel.getRoot());
            System.out.println("Sampling date offset when adding " + newTaxon + " = " + offset);
            // AND set its current node height to 0.0 IF no originTaxon has been found
            if (offset < 0.0) {
                if (!originTaxon) {
                    System.out.println("Updating all node heights with offset " + Math.abs(offset));
                    updateAllTreeNodes(Math.abs(offset), treeModel.getRoot());
                    treeModel.setNodeHeight(newTaxon, 0.0);
                }
            } else if (offset == 0.0) {
                if (!originTaxon) {
                    treeModel.setNodeHeight(newTaxon, 0.0);
                }
            }
            // get the closest Taxon to the Taxon that needs to be added
            // take into account which taxa can currently be chosen
            Taxon closest = choice.getClosestTaxon(treeModel.getNodeTaxon(newTaxon), currentTaxa);
            System.out.println("\nclosest Taxon: " + closest + " with original height: " + closest.getHeight());
            // get the distance between these two taxa
            double distance = choice.getDistance(treeModel.getNodeTaxon(newTaxon), closest);
            if (distance == 0.0) {
                // employ minimum insertion distance but add in a random factor to avoid identical insertion heights
                // this to avoid multifurcations in the case of (many) identical sequences
                distance = MIN_DIST * MathUtils.nextDouble();
                System.out.println("Sequences are identical, setting minimum distance to " + distance);
            }
            System.out.println("at distance: " + distance);
            // find the NodeRef for the closest Taxon (do not rely on node numbering)
            NodeRef closestRef = null;
            // careful with node numbering and subtract number of new taxa
            for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
                if (treeModel.getNodeTaxon(treeModel.getExternalNode(i)) == closest) {
                    closestRef = treeModel.getExternalNode(i);
                    System.out.println("  closest external nodeRef: " + closestRef);
                }
            }
            System.out.println("closest node : " + closestRef + " with height " + treeModel.getNodeHeight(closestRef));
            System.out.println("parent node: " + treeModel.getParent(closestRef));
            // begin change
            // TODO: only for Sam, revert back to the line below !!!
            // double timeForDistance = distance / rateModel.getBranchRate(treeModel, closestRef);
            double timeForDistance = distance / getBranchRate(traitModels.get(0), closestRef, numTaxaSoFar);
            // end change
            System.out.println("timeForDistance = " + timeForDistance);
            // get parent node of branch that will be split
            NodeRef parent = treeModel.getParent(closestRef);
            // child node of branch that will be split (will be assigned value later)
            NodeRef splitBranchChild;
            /*if ((treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(closestRef)) == 0.0) {
                    System.out.println("Zero-length branch:");
                    System.out.println(parent);
                    System.out.println(closestRef);
                    System.out.println("Exiting ...");
                    System.exit(0);
                }*/
            // determine height of new node
            double insertHeight;
            if (treeModel.getNodeHeight(closestRef) == treeModel.getNodeHeight(newTaxon)) {
                // if the sequences have the same sampling date/time, then simply split the distance/time between the two sequences in half
                // both the closest sequence and the new sequence are equidistant from the newly inserted internal node
                insertHeight = treeModel.getNodeHeight(closestRef) + timeForDistance / 2.0;
                System.out.println("equal sampling times (" + treeModel.getNodeHeight(closestRef) + ") ; insertHeight = " + insertHeight);
                splitBranchChild = closestRef;
                if (insertHeight >= treeModel.getNodeHeight(parent)) {
                    while (insertHeight >= treeModel.getNodeHeight(parent)) {
                        if (treeModel.getParent(parent) == null) {
                            // Use this insertHeight value in case parent doesn't have parent
                            // Otherwise, move up tree
                            insertHeight = treeModel.getNodeHeight(splitBranchChild) + EPSILON * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(splitBranchChild));
                            break;
                        } else {
                            splitBranchChild = parent;
                            parent = treeModel.getParent(splitBranchChild);
                        }
                    }
                }
                // now that a suitable branch has been found, perform additional checks
                // parent of branch = parent; child of branch = splitBranchChild
                System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
                System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
                if (((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) && ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST)) {
                    throw new RuntimeException("No suitable branch found for sequence insertion (all branches < minimum branch length).");
                } else if ((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) {
                    System.out.println("  insertion height too close to parent height");
                    double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
                    while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
                        newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
                    }
                    System.out.println("  new insertion height = " + newInsertHeight);
                    insertHeight = newInsertHeight;
                } else if ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST) {
                    System.out.println("  insertion height too close to child height");
                    double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
                    while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
                        newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
                    }
                    System.out.println("  new insertion height = " + newInsertHeight);
                    insertHeight = newInsertHeight;
                }
            } else {
                // first calculate if the new internal node is older than both the new sequence and its closest sequence already present in the tree
                double remainder = (timeForDistance - Math.abs(treeModel.getNodeHeight(closestRef) - treeModel.getNodeHeight(newTaxon))) / 2.0;
                if (remainder > 0) {
                    insertHeight = Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon)) + remainder;
                    System.out.println("remainder > 0 (" + remainder + "): " + insertHeight);
                    splitBranchChild = closestRef;
                    if (insertHeight >= treeModel.getNodeHeight(parent)) {
                        while (insertHeight >= treeModel.getNodeHeight(parent)) {
                            if (treeModel.getParent(parent) == null) {
                                // use this insertHeight value in case parent doesn't have parent (and we can't move up the tree)
                                insertHeight = treeModel.getNodeHeight(splitBranchChild) + EPSILON * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(splitBranchChild));
                                break;
                            } else {
                                // otherwise move up in the tree
                                splitBranchChild = parent;
                                parent = treeModel.getParent(splitBranchChild);
                            }
                        }
                    }
                    // now that a suitable branch has been found, perform additional checks
                    // parent of branch = parent; child of branch = splitBranchChild
                    System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
                    System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
                    if (((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) && ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST)) {
                        throw new RuntimeException("No suitable branch found for sequence insertion (all branches < minimum branch length).");
                    } else if ((treeModel.getNodeHeight(parent) - insertHeight) < MIN_DIST) {
                        System.out.println("  insertion height too close to parent height");
                        double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
                        while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
                            newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (insertHeight - treeModel.getNodeHeight(splitBranchChild));
                        }
                        System.out.println("  new insertion height = " + newInsertHeight);
                        insertHeight = newInsertHeight;
                    } else if ((insertHeight - treeModel.getNodeHeight(splitBranchChild)) < MIN_DIST) {
                        System.out.println("  insertion height too close to child height");
                        double newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
                        while (((treeModel.getNodeHeight(parent) - newInsertHeight) < MIN_DIST) && (newInsertHeight - treeModel.getNodeHeight(splitBranchChild) < MIN_DIST)) {
                            newInsertHeight = treeModel.getNodeHeight(splitBranchChild) + Math.abs(MathUtils.nextDouble()) * (treeModel.getNodeHeight(parent) - insertHeight);
                        }
                        System.out.println("  new insertion height = " + newInsertHeight);
                        insertHeight = newInsertHeight;
                    }
                } else {
                    // Come up with better way to handle this?
                    insertHeight = EPSILON * (treeModel.getNodeHeight(parent) - Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon)));
                    System.out.println("insertHeight after EPSILON: " + insertHeight);
                    insertHeight += Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon));
                    System.out.println("remainder <= 0: " + insertHeight);
                    double insertDifference = treeModel.getNodeHeight(parent) - insertHeight;
                    System.out.println("height difference = " + insertDifference);
                    splitBranchChild = closestRef;
                    if (insertDifference < MIN_DIST) {
                        System.out.println("branch too short ...");
                        boolean suitableBranch = false;
                        // 2. no negative branch length can be introduced so the parent node needs to be older than the new taxon and its closest reference
                        while (!suitableBranch) {
                            double parentBranchLength = treeModel.getNodeHeight(treeModel.getParent(parent)) - treeModel.getNodeHeight(parent);
                            if ((parentBranchLength < MIN_DIST) || ((treeModel.getNodeHeight(parent) - Math.max(treeModel.getNodeHeight(closestRef), treeModel.getNodeHeight(newTaxon))) < 0.0)) {
                                // find another branch by moving upwards in the tree
                                splitBranchChild = parent;
                                parent = treeModel.getParent(splitBranchChild);
                            } else {
                                // node needs to be inserted along the branch from parent to splitBranchChild
                                double positiveRandom = Math.abs(MathUtils.nextDouble());
                                // insertion height needs to be higher than taxon being inserted
                                double minimumHeight = Math.max(Math.max(treeModel.getNodeHeight(splitBranchChild), treeModel.getNodeHeight(newTaxon)), treeModel.getNodeHeight(closestRef));
                                insertHeight = treeModel.getNodeHeight(parent) - (treeModel.getNodeHeight(parent) - minimumHeight) * positiveRandom;
                                suitableBranch = true;
                            }
                        }
                    }
                }
            }
            System.out.println("insert at height: " + insertHeight);
            System.out.printf("parent height: %.25f \n", treeModel.getNodeHeight(parent));
            System.out.printf("height difference = %.25f\n", (insertHeight - treeModel.getNodeHeight(parent)));
            if (treeModel.getParent(parent) != null) {
                System.out.printf("grandparent height: %.25f \n", treeModel.getNodeHeight(treeModel.getParent(parent)));
            } else {
                System.out.println("parent == root");
            }
            System.out.printf("child height: %.25f \n", treeModel.getNodeHeight(splitBranchChild));
            System.out.printf("insert at height: %.25f \n", insertHeight);
            // pass on all the necessary variables to a method that adds the new taxon to the tree
            addTaxonAlongBranch(newTaxon, parent, splitBranchChild, insertHeight);
            // option to print tree after each taxon addition
            System.out.println("\nTree after adding taxon " + newTaxon + ":\n" + treeModel.toString());
            System.out.println(">>" + treeModel.toString());
            // add newly added Taxon to list of current taxa
            currentTaxa.add(treeModel.getNodeTaxon(newTaxon));
            // Update rate categories here
            interpolateTraitValuesOneInsertion(traitModels, newTaxon);
            numTaxaSoFar++;
        }
    }
    return newTaxaNodes;
}
Also used : Likelihood(dr.inference.model.Likelihood) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) Taxon(dr.evolution.util.Taxon) ArrayList(java.util.ArrayList) PatternList(dr.evolution.alignment.PatternList) BeagleDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate) NodeRef(dr.evolution.tree.NodeRef) TreeDataLikelihood(dr.evomodel.treedatalikelihood.TreeDataLikelihood) BeagleDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate) MultiPartitionDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate) DataLikelihoodDelegate(dr.evomodel.treedatalikelihood.DataLikelihoodDelegate) Tree(dr.evolution.tree.Tree) PatternList(dr.evolution.alignment.PatternList) ArrayList(java.util.ArrayList) List(java.util.List) MultiPartitionDataLikelihoodDelegate(dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate) Patterns(dr.evolution.alignment.Patterns)

Aggregations

TreeDataLikelihood (dr.evomodel.treedatalikelihood.TreeDataLikelihood) 45, ArrayList (java.util.ArrayList) 29, BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel) 25, DefaultBranchRateModel (dr.evomodel.branchratemodel.DefaultBranchRateModel) 23, StrictClockBranchRates (dr.evomodel.branchratemodel.StrictClockBranchRates) 21, Parameter (dr.inference.model.Parameter) 20, MultivariateElasticModel (dr.evomodel.continuous.MultivariateElasticModel) 17, MatrixParameter (dr.inference.model.MatrixParameter) 15, DataLikelihoodDelegate (dr.evomodel.treedatalikelihood.DataLikelihoodDelegate) 12, ArbitraryBranchRates (dr.evomodel.branchratemodel.ArbitraryBranchRates) 11, Tree (dr.evolution.tree.Tree) 8, ContinuousDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate) 8, DiagonalMatrix (dr.inference.model.DiagonalMatrix) 8, BeagleDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate) 6, Likelihood (dr.inference.model.Likelihood) 6, Taxon (dr.evolution.util.Taxon) 4, MultiPartitionDataLikelihoodDelegate (dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate) 4, CompoundLikelihood (dr.inference.model.CompoundLikelihood) 4, MatrixParameterInterface (dr.inference.model.MatrixParameterInterface) 4, GradientWrtParameterProvider (dr.inference.hmc.GradientWrtParameterProvider) 3