Search in sources :

Example 11 with MutableTreeModel

use of dr.evolution.tree.MutableTreeModel in project beast-mcmc by beast-dev.

Method parseXMLObject of the class OptimizedBeagleTreeLikelihoodParser.

/**
 * Parses an optimized compound BEAGLE tree likelihood.
 * <p>
 * Every child element is expected to be a {@link BeagleTreeLikelihood} (one per partition).
 * The children are first evaluated together {@code calibrate} times to obtain a baseline
 * timing. Then, starting with the partition that has the most site patterns, a partition is
 * repeatedly split into one more piece (each piece being a new {@link BeagleTreeLikelihood}
 * over a subset of that partition's patterns) for as long as the re-timed compound likelihood
 * gets faster. After a slower split, up to {@code retry} further splits are still attempted;
 * if none of them beats the best timing, the partition is restored to its last faster piece
 * count and the next partition in the list is tried.
 *
 * @param xo the XML element; optional attributes {@code CALIBRATE} (number of timed
 *           evaluations, default 100) and {@code RETRY} (extra splits attempted after a
 *           slowdown, default 0)
 * @return the compound likelihood over the fastest partitioning found
 * @throws XMLParseException if the element has no child likelihoods
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    // Default of 100 likelihood calculations for calibration
    int calibrate = 100;
    if (xo.hasAttribute(CALIBRATE)) {
        calibrate = xo.getIntegerAttribute(CALIBRATE);
    }
    // Default: only try the next split up, unless a RETRY value is specified in the XML
    int retry = 0;
    if (xo.hasAttribute(RETRY)) {
        retry = xo.getIntegerAttribute(RETRY);
    }

    int childCount = xo.getChildCount();
    if (childCount == 0) {
        // Without this check the method fails further down with an index-out-of-bounds error.
        throw new XMLParseException("Expected at least one BeagleTreeLikelihood child element");
    }

    // `likelihoods` is the working list that gets split up; the originals are kept so that
    // their model listeners can be released once the optimization has finished.
    List<Likelihood> likelihoods = new ArrayList<Likelihood>();
    List<Likelihood> originalLikelihoods = new ArrayList<Likelihood>();
    for (int i = 0; i < childCount; i++) {
        likelihoods.add((Likelihood) xo.getChild(i));
        originalLikelihoods.add((Likelihood) xo.getChild(i));
    }
    if (DEBUG) {
        System.err.println("-----");
        System.err.println(childCount + " BeagleTreeLikelihoods added.");
    }

    // Per-partition bookkeeping, indexed by the partition's position among the children:
    //   templates[p]       - original likelihood whose settings every piece of p copies
    //   patterns[p]        - the full site patterns of partition p
    //   siteCounts[p]      - approximate pattern count of ONE piece of partition p
    //   instanceCounts[p]  - number of pieces partition p is currently split into
    //   currentLocation[p] - index in `likelihoods` where partition p's pieces start
    BeagleTreeLikelihood[] templates = new BeagleTreeLikelihood[childCount];
    SitePatterns[] patterns = new SitePatterns[childCount];
    int[] siteCounts = new int[childCount];
    int[] instanceCounts = new int[childCount];
    int[] currentLocation = new int[childCount];
    for (int i = 0; i < childCount; i++) {
        templates[i] = (BeagleTreeLikelihood) likelihoods.get(i);
        patterns[i] = (SitePatterns) templates[i].getPatternsList();
        siteCounts[i] = patterns[i].getPatternCount();
        instanceCounts[i] = 1;
        currentLocation[i] = i;
    }
    if (DEBUG) {
        System.err.println("Pattern counts: ");
        for (int i = 0; i < siteCounts.length; i++) {
            System.err.println(siteCounts[i] + "   vs.    " + patterns[i].getPatternCount());
        }
        System.err.println();
        System.err.println("Instance counts: ");
        for (int count : instanceCounts) {
            System.err.print(count + " ");
        }
        System.err.println();
        System.err.println("Current locations: ");
        for (int location : currentLocation) {
            System.err.print(location + " ");
        }
        System.err.println();
    }

    // Baseline timing with every partition unsplit.
    TestThreadedCompoundLikelihood compound = new TestThreadedCompoundLikelihood(likelihoods);
    if (DEBUG) {
        System.err.println("Timing estimates for each of the " + calibrate + " likelihood calculations:");
    }
    double baseResult = timeLikelihoodEvaluations(compound, calibrate);
    if (DEBUG) {
        System.err.println("Starting evaluation took: " + baseResult);
    }

    // Try splitting the same likelihood until no further improvement, then move on to the
    // next one. Partitions are visited largest (by pattern count) first.
    List<Integer> splitList = orderPartitionsBySiteCount(siteCounts);
    if (DEBUG) {
        for (int i = 0; i < siteCounts.length; i++) {
            System.err.println("Site count " + i + " = " + siteCounts[i]);
        }
        System.err.print("Ordered list of likelihoods to be evaluated: ");
        for (int partition : splitList) {
            System.err.print(partition + " ");
        }
        System.err.println();
    }

    int timesRetried = 0;
    while (!splitList.isEmpty()) {
        // Split the partition at the head of the list into one more piece.
        int longestIndex = splitList.get(0);
        int instanceCount = ++instanceCounts[longestIndex];
        List<Likelihood> newList = createPartitionLikelihoods(
                templates[longestIndex], patterns[longestIndex], longestIndex, instanceCount, xo);

        // Replace the partition's (instanceCount - 1) previous pieces with the new ones.
        for (int i = 0; i < instanceCount - 1; i++) {
            Likelihood superseded = likelihoods.remove(currentLocation[longestIndex]);
            // Pieces from an earlier split are discarded for good, so release their model
            // listeners now; the original likelihoods are released together at the end.
            if (!originalLikelihoods.contains(superseded)) {
                unregisterAllModels((BeagleTreeLikelihood) superseded);
            }
        }
        insertPieces(likelihoods, currentLocation[longestIndex], newList);
        for (int i = longestIndex + 1; i < currentLocation.length; i++) {
            currentLocation[i]++;
        }
        compound = new TestThreadedCompoundLikelihood(likelihoods);
        siteCounts[longestIndex] = (instanceCount - 1) * siteCounts[longestIndex] / instanceCount;
        if (DEBUG) {
            System.err.println("Number of BeagleTreeLikelihoods: " + compound.getLikelihoodCount());
            printPartitionState(siteCounts, instanceCounts, currentLocation);
        }

        // Re-time the compound likelihood with the new partitioning.
        if (DEBUG) {
            System.err.println("Timing estimates for each of the " + calibrate + " likelihood calculations:");
        }
        double newResult = timeLikelihoodEvaluations(compound, calibrate);
        if (DEBUG) {
            System.err.println("New evaluation took: " + newResult + " vs. old evaluation: " + baseResult);
        }

        if (newResult < baseResult) {
            // New partitioning is faster: keep it and move the partition back in the list
            // behind every partition whose pieces are still larger than its own.
            baseResult = newResult;
            if (DEBUG) {
                printSplitList("Current", splitList, siteCounts);
            }
            int currentPatternCount = siteCounts[longestIndex];
            int findIndex = 0;
            for (int i = 0; i < splitList.size(); i++) {
                if (siteCounts[splitList.get(i)] > currentPatternCount) {
                    findIndex = i;
                }
            }
            if (DEBUG) {
                System.err.println("Current pattern count: " + currentPatternCount);
                System.err.println("Index found: " + findIndex + " with pattern count: "
                        + siteCounts[splitList.get(findIndex)]);
                System.err.println("Moving 0 to " + findIndex);
            }
            // Bubble the head element forward to position findIndex.
            for (int i = 0; i < findIndex; i++) {
                int temp = splitList.get(i);
                splitList.set(i, splitList.get(i + 1));
                splitList.set(i + 1, temp);
            }
            if (DEBUG) {
                printSplitList("New", splitList, siteCounts);
            }
            timesRetried = 0;
        } else {
            if (DEBUG) {
                System.err.println("timesRetried = " + timesRetried + " vs. retry = " + retry);
            }
            if (timesRetried < retry) {
                // Slower, but RETRY allows splitting further anyway; baseResult stays as is.
                timesRetried++;
                if (DEBUG) {
                    System.err.println("RETRY number " + timesRetried);
                }
            } else {
                // Slower and no retries left: this partition is done, so undo every split made
                // since its last improvement and continue with the next partition.
                splitList.remove(0);
                int removedPieces = timesRetried + 1;
                if (DEBUG) {
                    System.err.print("Removing " + removedPieces + " instance count(s): " + instanceCount);
                }
                instanceCounts[longestIndex] -= removedPieces;
                instanceCount = instanceCounts[longestIndex];
                if (DEBUG) {
                    System.err.println(" -> " + instanceCount + " for likelihood " + longestIndex);
                }
                newList = createPartitionLikelihoods(
                        templates[longestIndex], patterns[longestIndex], longestIndex, instanceCount, xo);
                for (int i = 0; i < instanceCount + removedPieces; i++) {
                    unregisterAllModels((BeagleTreeLikelihood) likelihoods.get(currentLocation[longestIndex]));
                    likelihoods.remove(currentLocation[longestIndex]);
                }
                insertPieces(likelihoods, currentLocation[longestIndex], newList);
                for (int i = longestIndex + 1; i < currentLocation.length; i++) {
                    currentLocation[i] -= removedPieces;
                }
                compound = new TestThreadedCompoundLikelihood(likelihoods);
                siteCounts[longestIndex] = (instanceCount + removedPieces) * siteCounts[longestIndex] / instanceCount;
                if (DEBUG) {
                    printPartitionState(siteCounts, instanceCounts, currentLocation);
                }
                timesRetried = 0;
            }
        }
    }

    // Every original likelihood has been replaced by freshly created pieces; release them.
    for (Likelihood original : originalLikelihoods) {
        unregisterAllModels((BeagleTreeLikelihood) original);
    }
    return compound;
}

/**
 * Evaluates {@code compound} {@code calibrate} times from a dirty state.
 *
 * @return the elapsed wall-clock time in nanoseconds
 */
private static double timeLikelihoodEvaluations(TestThreadedCompoundLikelihood compound, int calibrate) {
    long start = System.nanoTime();
    for (int i = 0; i < calibrate; i++) {
        compound.makeDirty();
        compound.getLogLikelihood();
    }
    return System.nanoTime() - start;
}

/**
 * Creates {@code instanceCount} new BEAGLE likelihoods for one partition. The i-th one covers the
 * pattern subset built by {@code new Patterns(partitionPatterns, 0, 0, 1, i, instanceCount)} and
 * copies all other settings from {@code template}. Ids are {@code <xo id>_<partition>_<i>}.
 */
private List<Likelihood> createPartitionLikelihoods(BeagleTreeLikelihood template,
                                                    SitePatterns partitionPatterns,
                                                    int partition,
                                                    int instanceCount,
                                                    XMLObject xo) throws XMLParseException {
    List<Likelihood> pieces = new ArrayList<Likelihood>(instanceCount);
    for (int i = 0; i < instanceCount; i++) {
        Patterns subPatterns = new Patterns(partitionPatterns, 0, 0, 1, i, instanceCount);
        BeagleTreeLikelihood treeLikelihood = createTreeLikelihood(subPatterns,
                template.getTreeModel(), template.getBranchModel(),
                (GammaSiteRateModel) template.getSiteRateModel(), template.getBranchRateModel(),
                null, template.useAmbiguities(), template.getRescalingScheme(),
                template.isDelayRescalingUntilUnderflow(), template.getPartialsRestrictions(), xo);
        treeLikelihood.setId(xo.getId() + "_" + partition + "_" + i);
        if (DEBUG) {
            System.err.println(treeLikelihood.getId() + " created.");
        }
        pieces.add(treeLikelihood);
    }
    return pieces;
}

/**
 * Inserts each of {@code pieces} at position {@code offset} of {@code likelihoods}. Because every
 * insertion targets the same offset, the pieces end up in reverse order relative to {@code pieces}.
 */
private static void insertPieces(List<Likelihood> likelihoods, int offset, List<Likelihood> pieces) {
    for (Likelihood piece : pieces) {
        likelihoods.add(offset, piece);
    }
}

/**
 * Returns every partition index exactly once, ordered by descending site count; ties keep the
 * lower index first. {@code siteCounts} is not modified.
 */
private static List<Integer> orderPartitionsBySiteCount(int[] siteCounts) {
    int n = siteCounts.length;
    boolean[] taken = new boolean[n];
    List<Integer> order = new ArrayList<Integer>(n);
    for (int k = 0; k < n; k++) {
        int top = -1;
        for (int j = 0; j < n; j++) {
            if (!taken[j] && (top < 0 || siteCounts[j] > siteCounts[top])) {
                top = j;
            }
        }
        taken[top] = true;
        order.add(top);
    }
    return order;
}

/** Debug aid: prints per-partition pattern counts, piece counts and list offsets to stderr. */
private static void printPartitionState(int[] siteCounts, int[] instanceCounts, int[] currentLocation) {
    System.err.println("Pattern counts: ");
    for (int count : siteCounts) {
        System.err.print(count + " ");
    }
    System.err.println();
    System.err.println("Instance counts: ");
    for (int count : instanceCounts) {
        System.err.print(count + " ");
    }
    System.err.println();
    System.err.println("Current locations: ");
    for (int location : currentLocation) {
        System.err.print(location + " ");
    }
    System.err.println();
}

/** Debug aid: prints the split order and the per-piece pattern count of each listed partition. */
private static void printSplitList(String prefix, List<Integer> splitList, int[] siteCounts) {
    System.err.print(prefix + " split list: ");
    for (int partition : splitList) {
        System.err.print(partition + "  ");
    }
    System.err.println();
    System.err.print(prefix + " pattern counts: ");
    for (int partition : splitList) {
        System.err.print(siteCounts[partition] + "  ");
    }
    System.err.println();
}
Also used: Set(java.util.Set), BeagleTreeLikelihood(dr.evomodel.treelikelihood.BeagleTreeLikelihood), ArrayList(java.util.ArrayList), PartialsRescalingScheme(dr.evomodel.treelikelihood.PartialsRescalingScheme), BranchModel(dr.evomodel.branchmodel.BranchModel), MutableTreeModel(dr.evolution.tree.MutableTreeModel), TreeModel(dr.evomodel.tree.TreeModel), Patterns(dr.evolution.alignment.Patterns), SitePatterns(dr.evolution.alignment.SitePatterns), GammaSiteRateModel(dr.evomodel.siteratemodel.GammaSiteRateModel), BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel), Map(java.util.Map)

Example 12 with MutableTreeModel

use of dr.evolution.tree.MutableTreeModel in project beast-mcmc by beast-dev.

Method parseXMLObject of the class AncestralTraitTreeModelParser.

/**
 * Builds an ancestral-trait tree model from its XML element.
 * <p>
 * The wrapped tree is taken from the element's {@link MutableTreeModel} child. Every ancestral
 * taxon returned by {@code parseAllAncestors} is assigned the tree's external node count as its
 * index and a placeholder {@link NodeRef}. Afterwards, each {@code NODE_TRAITS} child (if any) is
 * handed to {@code parseNodeTraits}.
 *
 * @param xo the XML element being parsed
 * @return a tree object based on the XML element it was passed.
 * @throws XMLParseException if the element or one of its children cannot be parsed
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final MutableTreeModel treeModel = (MutableTreeModel) xo.getChild(MutableTreeModel.class);
    final List<AncestralTaxonInTree> ancestralTaxa = parseAllAncestors(treeModel, xo);

    // NOTE(review): this index is never advanced, so every ancestor receives the same value.
    // Confirm whether AncestralTraitTreeModel reassigns indices or whether it should increment.
    final int ancestorIndex = treeModel.getExternalNodeCount();
    for (int k = 0; k < ancestralTaxa.size(); k++) {
        final AncestralTaxonInTree taxon = ancestralTaxa.get(k);
        final NodeRef placeholder = new NodeRef() {

            @Override
            public int getNumber() {
                // Placeholder node: always reports number 0.
                return 0;
            }

            @Override
            public void setNumber(int n) {
                // Intentionally ignored for the placeholder.
            }
        };
        taxon.setIndex(ancestorIndex);
        taxon.setNode(placeholder);
    }

    if (xo.hasChildNamed(NODE_TRAITS)) {
        for (XMLObject traitXo : xo.getAllChildren(NODE_TRAITS)) {
            parseNodeTraits(traitXo, treeModel, ancestralTaxa);
        }
    }

    return new AncestralTraitTreeModel(xo.getId(), treeModel, ancestralTaxa);
}
Also used : NodeRef(dr.evolution.tree.NodeRef) AncestralTaxonInTree(dr.evomodel.continuous.AncestralTaxonInTree) MutableTreeModel(dr.evolution.tree.MutableTreeModel)

Example 13 with MutableTreeModel

use of dr.evolution.tree.MutableTreeModel in project beast-mcmc by beast-dev.

Method parseXMLObject of the class TransformedTreeModelParser.

/**
 * Builds a transformed tree model that wraps a tree and a scale parameter.
 * <p>
 * The model id is the element's own id when present, otherwise {@code "transformed."}
 * followed by the wrapped tree's id. The {@code VERSION} attribute (default
 * {@code "generic"}) selects the transform:
 * {@code "new"} and {@code "branch"} give a {@link ProgressiveScalarTreeTransform}
 * ({@code "branch"} also passes the tree), {@code "ou"} gives an {@link OuScalarTreeTransform},
 * and any other value gives a {@link SingleScalarTreeTransform}.
 *
 * @param xo the XML element being parsed
 * @return a tree object based on the XML element it was passed.
 * @throws XMLParseException if the element cannot be parsed
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final MutableTreeModel tree = (MutableTreeModel) xo.getChild(MutableTreeModel.class);
    final Parameter scale = (Parameter) xo.getChild(Parameter.class);

    // Fall back to an id derived from the wrapped tree when the element has none.
    final String treeId = tree.getId();
    final String id = xo.hasId() ? xo.getId() : "transformed." + treeId;
    Logger.getLogger("dr.evomodel").info("Creating a transformed tree model, '" + id + "'");

    final String version = xo.getAttribute(VERSION, "generic");
    final TreeTransform transform;
    switch (version) {
        case "new":
            transform = new ProgressiveScalarTreeTransform(scale);
            break;
        case "branch":
            transform = new ProgressiveScalarTreeTransform(tree, scale);
            break;
        case "ou":
            transform = new OuScalarTreeTransform(scale);
            break;
        default:
            transform = new SingleScalarTreeTransform(scale);
            break;
    }
    return new TransformedTreeModel(id, tree, transform);
}
Also used : MutableTreeModel(dr.evolution.tree.MutableTreeModel) Parameter(dr.inference.model.Parameter)

Example 14 with MutableTreeModel

use of dr.evolution.tree.MutableTreeModel in project beast-mcmc by beast-dev.

Method getStatisticValue of the class ContinuousDiffusionStatistic.

/**
 * Computes the configured continuous-diffusion summary statistic for one height window.
 *
 * <p>The window for {@code dim} is {@code [heightLowers[dim], upperHeight]}, where
 * {@code upperHeight} is:
 * <ul>
 *   <li>{@code heightUpper} when only one lower bound is configured;</li>
 *   <li>{@code heightLowers[dim - 1]} when {@code dim > 0} and the statistic is not
 *       cumulative;</li>
 *   <li>{@code Double.MAX_VALUE} otherwise.</li>
 * </ul>
 *
 * <p>Every non-root branch of every trait likelihood's tree contributes when it
 * overlaps that window and passes the branch-set filter (clade membership or
 * backbone membership when {@code branchset} is CLADE or BACKBONE; otherwise all
 * branches pass). Branch ends lying outside the window are replaced by trait
 * values imputed at the window boundary.
 *
 * <p>When {@code stateString} is set, per-branch times and distances are taken
 * from the Markov-jump history, restricted to the requested discrete state. A
 * branch whose history reports a mismatch contributes {@code NaN}.
 *
 * @param dim index into {@code heightLowers} selecting the height window
 * @return the statistic chosen by {@code summaryStat}, aggregated per {@code summaryMode}
 * @throws RuntimeException wrapping a {@code TreeUtils.MissingTaxonException} when a
 *         clade/backbone taxon is absent from a tree
 */
public double getStatisticValue(int dim) {
    double treeLength = 0;
    double treeDistance = 0;
    double totalMaxDistanceFromRoot = 0;
    // can only be used when cumulative and not associated with discrete state (not based on the distances on the branches from the root up that point)
    double maxDistanceFromRootCumulative = 0;
    double maxBranchDistanceFromRoot = 0;
    // can only be used when cumulative and not associated with discrete state (not based on the distances on the branches from the root up that point)
    double maxDistanceOverTimeFromRootWA = 0;
    double maxBranchDistanceOverTimeFromRootWA = 0;
    List<Double> rates = new ArrayList<Double>();
    List<Double> distances = new ArrayList<Double>();
    List<Double> times = new ArrayList<Double>();
    List<Double> traits = new ArrayList<Double>();
    List<double[]> traits2D = new ArrayList<double[]>();
    List<Double> diffusionCoefficients = new ArrayList<Double>();
    double waDiffusionCoefficient = 0;
    double lowerHeight = heightLowers[dim];
    double upperHeight = Double.MAX_VALUE;
    if (heightLowers.length == 1) {
        upperHeight = heightUpper;
    } else {
        if (dim > 0) {
            if (!cumulative) {
                upperHeight = heightLowers[dim - 1];
            }
        }
    }
    for (AbstractMultivariateTraitLikelihood traitLikelihood : traitLikelihoods) {
        MutableTreeModel tree = traitLikelihood.getTreeModel();
        BranchRateModel branchRates = traitLikelihood.getBranchRateModel();
        String traitName = traitLikelihood.getTraitName();
        for (int i = 0; i < tree.getNodeCount(); i++) {
            NodeRef node = tree.getNode(i);
            if (node != tree.getRoot()) {
                NodeRef parentNode = tree.getParent(node);
                // Branch-set filter: decide whether this branch participates at all.
                boolean testNode = true;
                if (branchset.equals(BranchSet.CLADE)) {
                    try {
                        testNode = inClade(tree, node, taxonList);
                    } catch (TreeUtils.MissingTaxonException mte) {
                        throw new RuntimeException(mte.toString());
                    }
                } else if (branchset.equals(BranchSet.BACKBONE)) {
                    if (backboneTime > 0) {
                        testNode = onAncestralPathTime(tree, node, backboneTime);
                    } else {
                        try {
                            testNode = onAncestralPathTaxa(tree, node, taxonList);
                        } catch (TreeUtils.MissingTaxonException mte) {
                            throw new RuntimeException(mte.toString());
                        }
                    }
                }
                if (testNode) {
                    // Only branches that overlap the (lowerHeight, upperHeight) window contribute.
                    if ((tree.getNodeHeight(parentNode) > lowerHeight) && (tree.getNodeHeight(node) < upperHeight)) {
                        double[] trait = traitLikelihood.getTraitForNode(tree, node, traitName);
                        double[] parentTrait = traitLikelihood.getTraitForNode(tree, parentNode, traitName);
                        double[] traitUp = parentTrait;
                        double[] traitLow = trait;
                        double timeUp = tree.getNodeHeight(parentNode);
                        double timeLow = tree.getNodeHeight(node);
                        double rate = (branchRates != null ? branchRates.getBranchRate(tree, node) : 1.0);
                        MultivariateDiffusionModel diffModel = traitLikelihood.diffusionModel;
                        double[] precision = diffModel.getPrecisionParameter().getParameterValues();
                        History history = null;
                        if (stateString != null) {
                            history = setUpHistory(markovJumpLikelihood.getHistoryForNode(tree, node, SITE), markovJumpLikelihood.getStatesForNode(tree, node)[SITE], markovJumpLikelihood.getStatesForNode(tree, parentNode)[SITE], timeLow, timeUp);
                        }
                        // Truncate the branch at the upper window boundary, imputing the trait there.
                        if (tree.getNodeHeight(parentNode) > upperHeight) {
                            timeUp = upperHeight;
                            traitUp = imputeValue(trait, parentTrait, upperHeight, tree.getNodeHeight(node), tree.getNodeHeight(parentNode), precision, rate, trueNoise);
                            if (stateString != null) {
                                history.truncateUpper(timeUp);
                            }
                        }
                        // Truncate the branch at the lower window boundary, imputing the trait there.
                        if (tree.getNodeHeight(node) < lowerHeight) {
                            timeLow = lowerHeight;
                            traitLow = imputeValue(trait, parentTrait, lowerHeight, tree.getNodeHeight(node), tree.getNodeHeight(parentNode), precision, rate, trueNoise);
                            if (stateString != null) {
                                history.truncateLower(timeLow);
                            }
                        }
                        if (dimension > traitLow.length) {
                            System.err.println("specified trait dimension for continuous trait summary, " + dimension + ", is > dimensionality of trait, " + traitLow.length + ". No trait summarized.");
                        } else {
                            // 'dimension' is 1-based.
                            traits.add(traitLow[(dimension - 1)]);
                        }
                        if (traitLow.length == 2) {
                            traits2D.add(traitLow);
                        }
                        double time;
                        if (stateString != null) {
                            if (!history.returnMismatch()) {
                                time = history.getStateTime(stateString);
                            } else {
                                time = NaN;
                            }
                        } else {
                            time = timeUp - timeLow;
                        }
                        treeLength += time;
                        times.add(time);
                        // setting up continuous trait values for heights in discrete trait history
                        if (stateString != null) {
                            history.setTraitsforHeights(traitUp, traitLow, precision, rate, trueNoise);
                        }
                        double[] rootTrait = traitLikelihood.getTraitForNode(tree, tree.getRoot(), traitName);
                        double timeFromRoot = (tree.getNodeHeight(tree.getRoot()) - timeLow);
                        if (useGreatCircleDistances && (trait.length == 2)) {
                            // Great Circle distance
                            double distance;
                            if (stateString != null) {
                                if (!history.returnMismatch()) {
                                    distance = history.getStateGreatCircleDistance(stateString);
                                } else {
                                    distance = NaN;
                                }
                            } else {
                                distance = getGreatCircleDistance(traitLow, traitUp);
                            }
                            distances.add(distance);
                            if (time > 0) {
                                treeDistance += distance;
                                // diffusion coefficient: distance^2 / (4 * time)
                                double dc = Math.pow(distance, 2) / (4 * time);
                                diffusionCoefficients.add(dc);
                                waDiffusionCoefficient += (dc * time);
                                rates.add(distance / time);
                            }
                            SphericalPolarCoordinates rootCoord = new SphericalPolarCoordinates(rootTrait[0], rootTrait[1]);
                            // Distance from the root to the LOWER end of the (possibly truncated) branch,
                            // matching the native-distance branch below and the timeLow-based timeFromRoot.
                            // (Previously this used traitUp by mistake.)
                            double tempDistanceFromRootLow = rootCoord.distance(new SphericalPolarCoordinates(traitLow[0], traitLow[1]));
                            if (tempDistanceFromRootLow > totalMaxDistanceFromRoot) {
                                totalMaxDistanceFromRoot = tempDistanceFromRootLow;
                                if (stateString != null) {
                                    double[] stateTimeDistance = getStateTimeAndDistanceFromRoot(tree, node, timeLow, traitLikelihood, traitName, traitLow, precision, branchRates, true);
                                    if (stateTimeDistance[0] > 0) {
                                        if (!history.returnMismatch()) {
                                            maxDistanceFromRootCumulative = tempDistanceFromRootLow * (stateTimeDistance[0] / timeFromRoot);
                                            maxDistanceOverTimeFromRootWA = maxDistanceFromRootCumulative / stateTimeDistance[0];
                                            maxBranchDistanceFromRoot = stateTimeDistance[1];
                                            maxBranchDistanceOverTimeFromRootWA = stateTimeDistance[1] / stateTimeDistance[0];
                                        } else {
                                            maxDistanceFromRootCumulative = NaN;
                                            maxDistanceOverTimeFromRootWA = NaN;
                                            maxBranchDistanceFromRoot = NaN;
                                            maxBranchDistanceOverTimeFromRootWA = NaN;
                                        }
                                    }
                                } else {
                                    maxDistanceFromRootCumulative = tempDistanceFromRootLow;
                                    maxDistanceOverTimeFromRootWA = tempDistanceFromRootLow / timeFromRoot;
                                    double[] timeDistance = getTimeAndDistanceFromRoot(tree, node, timeLow, traitLikelihood, traitName, traitLow, true);
                                    maxBranchDistanceFromRoot = timeDistance[1];
                                    maxBranchDistanceOverTimeFromRootWA = timeDistance[1] / timeDistance[0];
                                }
                                // distance between traitLow and traitUp for maxDistanceFromRootCumulative
                                if (timeUp == upperHeight) {
                                    if (time > 0) {
                                        maxDistanceFromRootCumulative = distance;
                                        maxDistanceOverTimeFromRootWA = distance / time;
                                        maxBranchDistanceFromRoot = distance;
                                        maxBranchDistanceOverTimeFromRootWA = distance / time;
                                    }
                                }
                            }
                        } else {
                            // Native (trait-space) distance
                            double distance;
                            if (stateString != null) {
                                if (!history.returnMismatch()) {
                                    distance = history.getStateNativeDistance(stateString);
                                } else {
                                    distance = NaN;
                                }
                            } else {
                                distance = getNativeDistance(traitLow, traitUp);
                            }
                            distances.add(distance);
                            if (time > 0) {
                                treeDistance += distance;
                                // diffusion coefficient: distance^2 / (4 * time)
                                double dc = Math.pow(distance, 2) / (4 * time);
                                diffusionCoefficients.add(dc);
                                waDiffusionCoefficient += dc * time;
                                rates.add(distance / time);
                            }
                            double tempDistanceFromRoot = getNativeDistance(traitLow, rootTrait);
                            if (tempDistanceFromRoot > totalMaxDistanceFromRoot) {
                                totalMaxDistanceFromRoot = tempDistanceFromRoot;
                                if (stateString != null) {
                                    double[] stateTimeDistance = getStateTimeAndDistanceFromRoot(tree, node, timeLow, traitLikelihood, traitName, traitLow, precision, branchRates, false);
                                    if (stateTimeDistance[0] > 0) {
                                        if (!history.returnMismatch()) {
                                            maxDistanceFromRootCumulative = tempDistanceFromRoot * (stateTimeDistance[0] / timeFromRoot);
                                            maxDistanceOverTimeFromRootWA = maxDistanceFromRootCumulative / stateTimeDistance[0];
                                            maxBranchDistanceFromRoot = stateTimeDistance[1];
                                            maxBranchDistanceOverTimeFromRootWA = stateTimeDistance[1] / stateTimeDistance[0];
                                        } else {
                                            maxDistanceFromRootCumulative = NaN;
                                            maxDistanceOverTimeFromRootWA = NaN;
                                            maxBranchDistanceFromRoot = NaN;
                                            maxBranchDistanceOverTimeFromRootWA = NaN;
                                        }
                                    }
                                } else {
                                    maxDistanceFromRootCumulative = tempDistanceFromRoot;
                                    maxDistanceOverTimeFromRootWA = tempDistanceFromRoot / timeFromRoot;
                                    double[] timeDistance = getTimeAndDistanceFromRoot(tree, node, timeLow, traitLikelihood, traitName, traitLow, false);
                                    maxBranchDistanceFromRoot = timeDistance[1];
                                    maxBranchDistanceOverTimeFromRootWA = timeDistance[1] / timeDistance[0];
                                }
                                // distance between traitLow and traitUp for maxDistanceFromRootCumulative
                                if (timeUp == upperHeight) {
                                    if (time > 0) {
                                        maxDistanceFromRootCumulative = distance;
                                        maxDistanceOverTimeFromRootWA = distance / time;
                                        maxBranchDistanceFromRoot = distance;
                                        maxBranchDistanceOverTimeFromRootWA = distance / time;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    if (summaryStat == summaryStatistic.DIFFUSION_RATE) {
        if (summaryMode == Mode.AVERAGE) {
            return DiscreteStatistics.mean(toArray(rates));
        } else if (summaryMode == Mode.MEDIAN) {
            return DiscreteStatistics.median(toArray(rates));
        } else if (summaryMode == Mode.COEFFICIENT_OF_VARIATION) {
            final double mean = DiscreteStatistics.mean(toArray(rates));
            return Math.sqrt(DiscreteStatistics.variance(toArray(rates), mean)) / mean;
        } else {
            // weighted average: total distance over total time
            return treeDistance / treeLength;
        }
    } else if (summaryStat == summaryStatistic.TRAIT) {
        if (summaryMode == Mode.MEDIAN) {
            return DiscreteStatistics.median(toArray(traits));
        } else if (summaryMode == Mode.COEFFICIENT_OF_VARIATION) {
            // don't compute mean twice
            final double mean = DiscreteStatistics.mean(toArray(traits));
            return Math.sqrt(DiscreteStatistics.variance(toArray(traits), mean)) / mean;
        } else {
            // default is average. A warning is thrown by the parser when trying to use WEIGHTED_AVERAGE
            return DiscreteStatistics.mean(toArray(traits));
        }
    } else if (summaryStat == summaryStatistic.TRAIT2DAREA) {
        double area = getAreaFrom2Dtraits(traits2D, 0.99);
        return area;
    } else if (summaryStat == summaryStatistic.DIFFUSION_COEFFICIENT) {
        if (summaryMode == Mode.AVERAGE) {
            return DiscreteStatistics.mean(toArray(diffusionCoefficients));
        } else if (summaryMode == Mode.MEDIAN) {
            return DiscreteStatistics.median(toArray(diffusionCoefficients));
        } else if (summaryMode == Mode.COEFFICIENT_OF_VARIATION) {
            // don't compute mean twice
            final double mean = DiscreteStatistics.mean(toArray(diffusionCoefficients));
            return Math.sqrt(DiscreteStatistics.variance(toArray(diffusionCoefficients), mean)) / mean;
        } else {
            // time-weighted average diffusion coefficient
            return waDiffusionCoefficient / treeLength;
        }
    } else if (summaryStat == summaryStatistic.WAVEFRONT_DISTANCE) {
        // wavefront distance
        // TODO: restrict to non state-specific wavefrontDistance/rate
        return maxDistanceFromRootCumulative;
    } else if (summaryStat == summaryStatistic.WAVEFRONT_DISTANCE_PHYLO) {
        return maxBranchDistanceFromRoot;
    } else if (summaryStat == summaryStatistic.WAVEFRONT_RATE) {
        // wavefront rate, only weighted average TODO: extend for average, median, COEFFICIENT_OF_VARIATION?
        return maxDistanceOverTimeFromRootWA;
    } else if (summaryStat == summaryStatistic.DIFFUSION_DISTANCE) {
        return treeDistance;
    } else if (summaryStat == summaryStatistic.DISTANCE_TIME_CORRELATION) {
        if (summaryMode == Mode.SPEARMAN) {
            return getSpearmanRho(convertDoubles(times), convertDoubles(distances));
        } else if (summaryMode == Mode.R_SQUARED) {
            Regression r = new Regression(convertDoubles(times), convertDoubles(distances));
            return r.getRSquared();
        } else {
            Regression r = new Regression(convertDoubles(times), convertDoubles(distances));
            return r.getCorrelationCoefficient();
        }
    } else {
        // any other statistic (presumably DIFFUSION_TIME): total branch time within the window
        return treeLength;
    }
}
Also used : SphericalPolarCoordinates(dr.geo.math.SphericalPolarCoordinates) Regression(dr.stats.Regression) NodeRef(dr.evolution.tree.NodeRef) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) MutableTreeModel(dr.evolution.tree.MutableTreeModel) TreeUtils(dr.evolution.tree.TreeUtils)

Example 15 with MutableTreeModel

use of dr.evolution.tree.MutableTreeModel in project beast-mcmc by beast-dev.

In the class MultivariateTraitUtils, method computeTreeTraitMean.

/**
 * Builds the expected tip-trait vector for a fully conjugate trait likelihood.
 *
 * <p>The root mean is repeated once per external node of the trait's tree. When
 * the likelihood has drift models, each tip's shift contribution (from
 * {@code getShiftContributionToMean}) is then added to that tip's slice.
 *
 * @param trait           likelihood supplying the tree, prior mean and drift models
 * @param rootValue       root trait used in place of the prior mean when
 *                        {@code conditionOnRoot} is true
 * @param conditionOnRoot whether to use {@code rootValue} instead of the prior mean
 *                        (prints a "not yet fully implemented" warning to stderr)
 * @return a new array of length {@code rootMean.length * externalNodeCount},
 *         laid out tip by tip
 */
public static double[] computeTreeTraitMean(FullyConjugateMultivariateTraitLikelihood trait, double[] rootValue, boolean conditionOnRoot) {
    double[] rootMean = trait.getPriorMean();
    if (conditionOnRoot) {
        System.err.println("WARNING: Not yet fully implemented (conditioning on root in simulator)");
        rootMean = rootValue;
    }

    final int tipCount = trait.getTreeModel().getExternalNodeCount();
    final int rootDim = rootMean.length;
    final double[] mean = new double[rootDim * tipCount];

    // Tile the root mean across all tips.
    for (int tip = 0, offset = 0; tip < tipCount; tip++, offset += rootDim) {
        System.arraycopy(rootMean, 0, mean, offset, rootDim);
    }

    if (trait.driftModels == null) {
        return mean;
    }

    // Add each tip's drift (shift) contribution on top of the tiled root mean.
    final MutableTreeModel treeModel = trait.getTreeModel();
    for (int tip = 0; tip < tipCount; tip++) {
        final double[] shift = getShiftContributionToMean(treeModel.getExternalNode(tip), trait);
        final int base = tip * trait.dimTrait;
        for (int j = 0; j < trait.dimTrait; j++) {
            mean[base + j] += shift[j];
        }
    }
    return mean;
}
Also used : MutableTreeModel(dr.evolution.tree.MutableTreeModel)

Aggregations

- MutableTreeModel (dr.evolution.tree.MutableTreeModel): 17
- NodeRef (dr.evolution.tree.NodeRef): 9
- BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel): 4
- Parameter (dr.inference.model.Parameter): 4
- SphericalPolarCoordinates (dr.geo.math.SphericalPolarCoordinates): 3
- ArrayList (java.util.ArrayList): 3
- Patterns (dr.evolution.alignment.Patterns): 2
- TreeUtils (dr.evolution.tree.TreeUtils): 2
- BranchModel (dr.evomodel.branchmodel.BranchModel): 2
- EpochBranchModel (dr.evomodel.branchmodel.EpochBranchModel): 2
- GammaSiteRateModel (dr.evomodel.siteratemodel.GammaSiteRateModel): 2
- SubstitutionModel (dr.evomodel.substmodel.SubstitutionModel): 2
- BeagleTreeLikelihood (dr.evomodel.treelikelihood.BeagleTreeLikelihood): 2
- PartialsRescalingScheme (dr.evomodel.treelikelihood.PartialsRescalingScheme): 2
- Set (java.util.Set): 2
- PatternList (dr.evolution.alignment.PatternList): 1
- SitePatterns (dr.evolution.alignment.SitePatterns): 1
- TaxonList (dr.evolution.util.TaxonList): 1
- HomogeneousBranchModel (dr.evomodel.branchmodel.HomogeneousBranchModel): 1
- AncestralTaxonInTree (dr.evomodel.continuous.AncestralTaxonInTree): 1