Search in sources :

Example 36 with BranchRateModel

use of dr.evomodel.branchratemodel.BranchRateModel in project beast-mcmc by beast-dev.

the class OptimizedBeagleTreeLikelihoodParser method parseXMLObject.

/**
 * Builds a compound likelihood from the children of {@code xo}, greedily splitting individual
 * likelihoods into several sub-likelihoods as long as doing so makes a timed evaluation faster.
 * <p>
 * Outline of the procedure implemented below:
 * <ol>
 *   <li>Read the optional integer attributes CALIBRATE (number of likelihood evaluations per
 *       timing run, default 100) and RETRY (number of extra splits attempted after a split that
 *       was not faster, default 0).</li>
 *   <li>Collect every child of {@code xo} as a likelihood and record, per child, the patterns and
 *       model components needed to rebuild it on a subset of its patterns.</li>
 *   <li>Time {@code calibrate} evaluations of the unsplit compound likelihood as a baseline.</li>
 *   <li>Visit the likelihoods in order of decreasing pattern count. The likelihood at the head of
 *       the list is split into one more instance and timed again; if faster, the split is kept and
 *       the list is reordered, otherwise (once the retries are used up) the last accepted number of
 *       instances is rebuilt and the likelihood is dropped from the list.</li>
 *   <li>Unregister the models of the original child likelihoods and return the final compound.</li>
 * </ol>
 * Every child is cast to {@code BeagleTreeLikelihood}, its pattern list to {@code SitePatterns}
 * and its site rate model to {@code GammaSiteRateModel}; a child of another type fails with a
 * {@code ClassCastException}.
 * <p>
 * NOTE(review): with no children at all, {@code siteCounts[0]} and {@code splitList.get(0)} below
 * would be accessed on empty containers — confirm callers always supply at least one child.
 *
 * @param xo the XML element whose children are the likelihoods to combine
 * @return the {@code TestThreadedCompoundLikelihood} built over the final list of likelihoods
 * @throws XMLParseException declared by the signature; presumably raised by the attribute
 *         accessors or by {@code createTreeLikelihood} — confirm
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    //Default of 100 likelihood calculations for calibration
    int calibrate = 100;
    if (xo.hasAttribute(CALIBRATE)) {
        calibrate = xo.getIntegerAttribute(CALIBRATE);
    }
    //Default: only try the next split up, unless a RETRY value is specified in the XML
    int retry = 0;
    if (xo.hasAttribute(RETRY)) {
        retry = xo.getIntegerAttribute(RETRY);
    }
    int childCount = xo.getChildCount();
    // Working list: entries are replaced by their split instances as the search proceeds.
    List<Likelihood> likelihoods = new ArrayList<Likelihood>();
    // Untouched copy of the children, used at the very end to unregister their models.
    List<Likelihood> originalLikelihoods = new ArrayList<Likelihood>();
    for (int i = 0; i < childCount; i++) {
        likelihoods.add((Likelihood) xo.getChild(i));
        originalLikelihoods.add((Likelihood) xo.getChild(i));
    }
    if (DEBUG) {
        System.err.println("-----");
        System.err.println(childCount + " BeagleTreeLikelihoods added.");
    }
    // instanceCounts[k]: number of instances the k-th original likelihood is currently split into.
    int[] instanceCounts = new int[childCount];
    for (int i = 0; i < childCount; i++) {
        instanceCounts[i] = 1;
    }
    // currentLocation[k]: index in 'likelihoods' of the first instance of the k-th original likelihood.
    int[] currentLocation = new int[childCount];
    for (int i = 0; i < childCount; i++) {
        currentLocation[i] = i;
    }
    // siteCounts[k]: pattern count attributed to the k-th original likelihood (scaled down on each split).
    int[] siteCounts = new int[childCount];
    //store everything for later use
    SitePatterns[] patterns = new SitePatterns[childCount];
    TreeModel[] treeModels = new TreeModel[childCount];
    BranchModel[] branchModels = new BranchModel[childCount];
    GammaSiteRateModel[] siteRateModels = new GammaSiteRateModel[childCount];
    BranchRateModel[] branchRateModels = new BranchRateModel[childCount];
    boolean[] ambiguities = new boolean[childCount];
    PartialsRescalingScheme[] rescalingSchemes = new PartialsRescalingScheme[childCount];
    boolean[] isDelayRescalingUntilUnderflow = new boolean[childCount];
    List<Map<Set<String>, Parameter>> partialsRestrictions = new ArrayList<Map<Set<String>, Parameter>>();
    for (int i = 0; i < likelihoods.size(); i++) {
        patterns[i] = (SitePatterns) ((BeagleTreeLikelihood) likelihoods.get(i)).getPatternsList();
        siteCounts[i] = patterns[i].getPatternCount();
        treeModels[i] = ((BeagleTreeLikelihood) likelihoods.get(i)).getTreeModel();
        branchModels[i] = ((BeagleTreeLikelihood) likelihoods.get(i)).getBranchModel();
        siteRateModels[i] = (GammaSiteRateModel) ((BeagleTreeLikelihood) likelihoods.get(i)).getSiteRateModel();
        branchRateModels[i] = ((BeagleTreeLikelihood) likelihoods.get(i)).getBranchRateModel();
        ambiguities[i] = ((BeagleTreeLikelihood) likelihoods.get(i)).useAmbiguities();
        rescalingSchemes[i] = ((BeagleTreeLikelihood) likelihoods.get(i)).getRescalingScheme();
        isDelayRescalingUntilUnderflow[i] = ((BeagleTreeLikelihood) likelihoods.get(i)).isDelayRescalingUntilUnderflow();
        partialsRestrictions.add(i, ((BeagleTreeLikelihood) likelihoods.get(i)).getPartialsRestrictions());
    }
    if (DEBUG) {
        System.err.println("Pattern counts: ");
        for (int i = 0; i < siteCounts.length; i++) {
            System.err.println(siteCounts[i] + "   vs.    " + patterns[i].getPatternCount());
        }
        System.err.println();
        System.err.println("Instance counts: ");
        for (int i = 0; i < instanceCounts.length; i++) {
            System.err.print(instanceCounts[i] + " ");
        }
        System.err.println();
        System.err.println("Current locations: ");
        for (int i = 0; i < currentLocation.length; i++) {
            System.err.print(currentLocation[i] + " ");
        }
        System.err.println();
    }
    TestThreadedCompoundLikelihood compound = new TestThreadedCompoundLikelihood(likelihoods);
    if (DEBUG) {
        System.err.println("Timing estimates for each of the " + calibrate + " likelihood calculations:");
    }
    // Baseline timing: 'calibrate' forced re-evaluations of the unsplit compound likelihood.
    // Both branches of the DEBUG test below execute the same two calls; they differ only in the
    // commented-out per-evaluation timing.
    double start = System.nanoTime();
    for (int i = 0; i < calibrate; i++) {
        if (DEBUG) {
            //double debugStart = System.nanoTime();
            compound.makeDirty();
            compound.getLogLikelihood();
        //double debugEnd = System.nanoTime();
        //System.err.println(debugEnd - debugStart);
        } else {
            compound.makeDirty();
            compound.getLogLikelihood();
        }
    }
    double end = System.nanoTime();
    // Elapsed nanoseconds of the best configuration accepted so far.
    double baseResult = end - start;
    if (DEBUG) {
        System.err.println("Starting evaluation took: " + baseResult);
    }
    int longestIndex = 0;
    // longestSize is updated alongside siteCounts below but is never read by the active code in this method.
    int longestSize = siteCounts[0];
    //START TEST CODE
    /*System.err.println("Detailed evaluation times: ");
        long[] evaluationTimes = compound.getEvaluationTimes();
        int[] evaluationCounts = compound.getEvaluationCounts();
        long longest = evaluationTimes[0];
        for (int i = 0; i < evaluationTimes.length; i++) {
            System.err.println(i + ": time=" + evaluationTimes[i] + "   count=" + evaluationCounts[i]);
            if (evaluationTimes[i] > longest) {
                longest = evaluationTimes[i];
            }
        }*/
    //END TEST CODE
    // A disabled alternative strategy, formerly guarded by SPLIT_BY_PATTERN_COUNT, was kept here as
    // commented-out code: it repeatedly split whichever likelihood had the largest pattern count and
    // stopped at the first split that was not faster than the previous timing. Only the strategy
    // below is active.
    //Try splitting the same likelihood until no further improvement, then move on towards the next one
    boolean notFinished = true;
    //construct list with likelihoods to split up
    // Selection pass: repeatedly pick the index with the largest remaining site count and zero it,
    // which yields the indices in order of decreasing pattern count.
    // NOTE(review): 'top' restarts at 0 and the comparison is strict, so once siteCounts[0] has been
    // zeroed and every remaining count is also 0, index 0 is picked again — with zero-pattern
    // likelihoods the list may contain duplicates and omit an index; confirm this cannot occur.
    List<Integer> splitList = new ArrayList<Integer>();
    for (int i = 0; i < siteCounts.length; i++) {
        int top = 0;
        for (int j = 0; j < siteCounts.length; j++) {
            if (siteCounts[j] > siteCounts[top]) {
                top = j;
            }
        }
        siteCounts[top] = 0;
        splitList.add(top);
    }
    // Restore the site counts that were zeroed by the selection pass above.
    for (int i = 0; i < likelihoods.size(); i++) {
        siteCounts[i] = patterns[i].getPatternCount();
        if (DEBUG) {
            System.err.println("Site count " + i + " = " + siteCounts[i]);
        }
    }
    if (DEBUG) {
        //print list
        System.err.print("Ordered list of likelihoods to be evaluated: ");
        for (int i = 0; i < splitList.size(); i++) {
            System.err.print(splitList.get(i) + " ");
        }
        System.err.println();
    }
    // Number of consecutive non-improving splits applied to the likelihood at the head of splitList.
    int timesRetried = 0;
    while (notFinished) {
        //split it in 1 more piece
        longestIndex = splitList.get(0);
        int instanceCount = ++instanceCounts[longestIndex];
        // Rebuild the selected likelihood as 'instanceCount' instances, each on its own Patterns object.
        // Presumably the last two constructor arguments select the i-th of instanceCount pattern
        // subsets — confirm against the Patterns constructor.
        List<Likelihood> newList = new ArrayList<Likelihood>();
        for (int i = 0; i < instanceCount; i++) {
            Patterns subPatterns = new Patterns(patterns[longestIndex], 0, 0, 1, i, instanceCount);
            BeagleTreeLikelihood treeLikelihood = createTreeLikelihood(subPatterns, treeModels[longestIndex], branchModels[longestIndex], siteRateModels[longestIndex], branchRateModels[longestIndex], null, ambiguities[longestIndex], rescalingSchemes[longestIndex], isDelayRescalingUntilUnderflow[longestIndex], partialsRestrictions.get(longestIndex), xo);
            treeLikelihood.setId(xo.getId() + "_" + longestIndex + "_" + i);
            // This message is printed regardless of DEBUG.
            System.err.println(treeLikelihood.getId() + " created.");
            newList.add(treeLikelihood);
        }
        // Remove the previous (instanceCount - 1) instances of this likelihood from the working list...
        // NOTE(review): unlike the revert path further down, the removed instances are not passed to
        // unregisterAllModels here — confirm whether that is intended.
        for (int i = 0; i < newList.size() - 1; i++) {
            likelihoods.remove(currentLocation[longestIndex]);
        }
        //likelihoods.add(longestIndex, new CompoundLikelihood(newList));
        // ...and insert the new instances at the same position. Each insert goes to the same index,
        // so the instances end up in reverse order of creation.
        for (int i = 0; i < newList.size(); i++) {
            likelihoods.add(currentLocation[longestIndex], newList.get(i));
        }
        // The list grew by one entry, so every later original likelihood starts one position further on.
        for (int i = longestIndex + 1; i < currentLocation.length; i++) {
            currentLocation[i]++;
        }
        compound = new TestThreadedCompoundLikelihood(likelihoods);
        //compound = new CompoundLikelihood(likelihoods);
        //compound = new ThreadedCompoundLikelihood(likelihoods);
        // Scale the bookkeeping counts by (instanceCount - 1) / instanceCount using integer arithmetic.
        siteCounts[longestIndex] = (instanceCount - 1) * siteCounts[longestIndex] / instanceCount;
        longestSize = (instanceCount - 1) * longestSize / instanceCount;
        if (DEBUG) {
            //check number of likelihoods
            System.err.println("Number of BeagleTreeLikelihoods: " + compound.getLikelihoodCount());
            System.err.println("Pattern counts: ");
            for (int i = 0; i < siteCounts.length; i++) {
                System.err.print(siteCounts[i] + " ");
            }
            System.err.println();
            System.err.println("Instance counts: ");
            for (int i = 0; i < instanceCounts.length; i++) {
                System.err.print(instanceCounts[i] + " ");
            }
            System.err.println();
            System.err.println("Current locations: ");
            for (int i = 0; i < currentLocation.length; i++) {
                System.err.print(currentLocation[i] + " ");
            }
            System.err.println();
        }
        //evaluate speed
        if (DEBUG) {
            System.err.println("Timing estimates for each of the " + calibrate + " likelihood calculations:");
        }
        // Time the new configuration in the same way as the baseline.
        start = System.nanoTime();
        for (int i = 0; i < calibrate; i++) {
            if (DEBUG) {
                //double debugStart = System.nanoTime();
                compound.makeDirty();
                compound.getLogLikelihood();
            //double debugEnd = System.nanoTime();
            //System.err.println(debugEnd - debugStart);
            } else {
                compound.makeDirty();
                compound.getLogLikelihood();
            }
        }
        end = System.nanoTime();
        double newResult = end - start;
        if (DEBUG) {
            System.err.println("New evaluation took: " + newResult + " vs. old evaluation: " + baseResult);
        }
        if (newResult < baseResult) {
            //new partitioning is faster, so partition further
            baseResult = newResult;
            //reorder split list
            if (DEBUG) {
                System.err.print("Current split list: ");
                for (int i = 0; i < splitList.size(); i++) {
                    System.err.print(splitList.get(i) + "  ");
                }
                System.err.println();
                System.err.print("Current pattern counts: ");
                for (int i = 0; i < splitList.size(); i++) {
                    System.err.print(siteCounts[splitList.get(i)] + "  ");
                }
                System.err.println();
            }
            // Find the last position in splitList whose likelihood still has a larger site count than
            // the one just split; the head entry is then moved to that position.
            int currentPatternCount = siteCounts[longestIndex];
            int findIndex = 0;
            for (int i = 0; i < splitList.size(); i++) {
                if (siteCounts[splitList.get(i)] > currentPatternCount) {
                    findIndex = i;
                }
            }
            if (DEBUG) {
                System.err.println("Current pattern count: " + currentPatternCount);
                // NOTE(review): findIndex is a position in splitList, yet siteCounts is indexed with it
                // directly here, whereas the search above uses siteCounts[splitList.get(i)] — the
                // printed count may belong to a different likelihood; confirm.
                System.err.println("Index found: " + findIndex + " with pattern count: " + siteCounts[findIndex]);
                System.err.println("Moving 0 to " + findIndex);
            }
            // Adjacent swaps carry the head entry to position findIndex, shifting the entries in between forward by one.
            for (int i = 0; i < findIndex; i++) {
                int temp = splitList.get(i);
                splitList.set(i, splitList.get(i + 1));
                splitList.set(i + 1, temp);
            }
            if (DEBUG) {
                System.err.print("New split list: ");
                for (int i = 0; i < splitList.size(); i++) {
                    System.err.print(splitList.get(i) + "  ");
                }
                System.err.println();
                System.err.print("New pattern counts: ");
                for (int i = 0; i < splitList.size(); i++) {
                    System.err.print(siteCounts[splitList.get(i)] + "  ");
                }
                System.err.println();
            }
            timesRetried = 0;
        } else {
            if (DEBUG) {
                System.err.println("timesRetried = " + timesRetried + " vs. retry = " + retry);
            }
            //new partitioning is slower, so reinstate previous state unless RETRY is specified
            if (timesRetried < retry) {
                //try splitting further any way
                //do not set baseResult
                timesRetried++;
                if (DEBUG) {
                    System.err.println("RETRY number " + timesRetried);
                }
            } else {
                // Retries exhausted: this likelihood is not split any further.
                splitList.remove(0);
                if (splitList.size() == 0) {
                    // Nothing left to try; the loop ends after the revert below has completed.
                    notFinished = false;
                }
                //remove timesTried instanceCount(s)
                if (DEBUG) {
                    System.err.print("Removing " + (timesRetried + 1) + " instance count(s): " + instanceCount);
                }
                //instanceCount = --instanceCounts[longestIndex];
                // Undo the (timesRetried + 1) non-improving splits applied since the last accepted configuration.
                instanceCounts[longestIndex] = instanceCounts[longestIndex] - (timesRetried + 1);
                instanceCount = instanceCounts[longestIndex];
                if (DEBUG) {
                    System.err.println(" -> " + instanceCount + " for likelihood " + longestIndex);
                }
                // Rebuild the likelihood with the reverted number of instances.
                newList = new ArrayList<Likelihood>();
                for (int i = 0; i < instanceCount; i++) {
                    Patterns subPatterns = new Patterns(patterns[longestIndex], 0, 0, 1, i, instanceCount);
                    BeagleTreeLikelihood treeLikelihood = createTreeLikelihood(subPatterns, treeModels[longestIndex], branchModels[longestIndex], siteRateModels[longestIndex], branchRateModels[longestIndex], null, ambiguities[longestIndex], rescalingSchemes[longestIndex], isDelayRescalingUntilUnderflow[longestIndex], partialsRestrictions.get(longestIndex), xo);
                    treeLikelihood.setId(xo.getId() + "_" + longestIndex + "_" + i);
                    // This message is printed regardless of DEBUG.
                    System.err.println(treeLikelihood.getId() + " created.");
                    newList.add(treeLikelihood);
                }
                /*for (int i = 0; i < newList.size()+1; i++) {
                        likelihoods.remove(currentLocation[longestIndex]);
                    }*/
                // Remove all instances currently in the list for this likelihood (the reverted count
                // plus the splits being undone), unregistering the models of each before removal.
                for (int i = 0; i < newList.size() + timesRetried + 1; i++) {
                    //TEST CODE START
                    unregisterAllModels((BeagleTreeLikelihood) likelihoods.get(currentLocation[longestIndex]));
                    //TEST CODE END
                    likelihoods.remove(currentLocation[longestIndex]);
                }
                for (int i = 0; i < newList.size(); i++) {
                    likelihoods.add(currentLocation[longestIndex], newList.get(i));
                }
                // The list shrank by (timesRetried + 1) entries; shift the start positions of later likelihoods back.
                for (int i = longestIndex + 1; i < currentLocation.length; i++) {
                    currentLocation[i] -= (timesRetried + 1);
                }
                //likelihoods.remove(longestIndex);
                //likelihoods.add(longestIndex, new CompoundLikelihood(newList));
                compound = new TestThreadedCompoundLikelihood(likelihoods);
                //compound = new CompoundLikelihood(likelihoods);
                //compound = new ThreadedCompoundLikelihood(likelihoods);
                // Scale the bookkeeping counts back up; integer division means the original values are
                // not necessarily recovered exactly.
                siteCounts[longestIndex] = (instanceCount + timesRetried + 1) * siteCounts[longestIndex] / instanceCount;
                longestSize = (instanceCount + timesRetried + 1) * longestSize / instanceCount;
                if (DEBUG) {
                    System.err.println("Pattern counts: ");
                    for (int i = 0; i < siteCounts.length; i++) {
                        System.err.print(siteCounts[i] + " ");
                    }
                    System.err.println();
                    System.err.println("Instance counts: ");
                    for (int i = 0; i < instanceCounts.length; i++) {
                        System.err.print(instanceCounts[i] + " ");
                    }
                    System.err.println();
                    System.err.println("Current locations: ");
                    for (int i = 0; i < currentLocation.length; i++) {
                        System.err.print(currentLocation[i] + " ");
                    }
                    System.err.println();
                }
                timesRetried = 0;
            }
        }
    }
    // Every original child likelihood has been replaced by newly created instances at this point
    // (each is split and rebuilt at least once), so the originals' models are unregistered.
    for (int i = 0; i < originalLikelihoods.size(); i++) {
        unregisterAllModels((BeagleTreeLikelihood) originalLikelihoods.get(i));
    }
    return compound;
}
Also used : Set(java.util.Set) BeagleTreeLikelihood(dr.evomodel.treelikelihood.BeagleTreeLikelihood) ArrayList(java.util.ArrayList) PartialsRescalingScheme(dr.evomodel.treelikelihood.PartialsRescalingScheme) BranchModel(dr.evomodel.branchmodel.BranchModel) TreeModel(dr.evomodel.tree.TreeModel) Patterns(dr.evolution.alignment.Patterns) SitePatterns(dr.evolution.alignment.SitePatterns) GammaSiteRateModel(dr.evomodel.siteratemodel.GammaSiteRateModel) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) Map(java.util.Map)

Example 37 with BranchRateModel

use of dr.evomodel.branchratemodel.BranchRateModel in project beast-mcmc by beast-dev.

the class ALSTreeLikelihoodParser method parseXMLObject.

/**
 * Creates an {@code ALSTreeLikelihood} from its XML element.
 * <p>
 * The optional attributes {@code TreeLikelihoodParser.USE_AMBIGUITIES} (default {@code false}),
 * {@code TreeLikelihoodParser.STORE_PARTIALS} (default {@code true}) and FORCE_RESCALING
 * (default {@code false}) are read from {@code xo}, together with INTEGRATE_GAIN_RATE.
 * The death rate is taken from the site model's substitution model, which is cast to
 * {@code MutationDeathModel}. The gain rate is the first child of the IMMIGRATION_RATE element,
 * unless the gain rate is integrated, in which case a new "gainRate" parameter is created.
 * <p>
 * Each child element named OBSERVATION_PROCESS yields an observation process: a
 * {@code SingleTipObservationProcess} when its OBSERVATION_TYPE equals "singleTip", otherwise an
 * {@code AnyTipObservationProcess}. If several such children exist the last one is used; if none
 * exists, {@code null} is handed to the likelihood constructor.
 *
 * @param xo the XML element describing the likelihood
 * @return the newly constructed {@code ALSTreeLikelihood}
 * @throws XMLParseException declared by the signature; presumably raised by the XML accessors — confirm
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    // Optional flags: fall back to the default when the attribute is absent.
    final boolean useAmbiguities = xo.hasAttribute(TreeLikelihoodParser.USE_AMBIGUITIES)
            && xo.getBooleanAttribute(TreeLikelihoodParser.USE_AMBIGUITIES);
    final boolean storePartials = !xo.hasAttribute(TreeLikelihoodParser.STORE_PARTIALS)
            || xo.getBooleanAttribute(TreeLikelihoodParser.STORE_PARTIALS);
    final boolean integrateGainRate = xo.getBooleanAttribute(INTEGRATE_GAIN_RATE);

    final PatternList patternList = (PatternList) xo.getChild(PatternList.class);
    final TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);
    final SiteModel siteModel = (SiteModel) xo.getChild(SiteModel.class);
    final BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);

    final MutationDeathModel deathModel = (MutationDeathModel) siteModel.getSubstitutionModel();
    final Parameter mu = deathModel.getDeathParameter();
    final Parameter lam = integrateGainRate
            ? new Parameter.Default("gainRate", 1.0, 0.001, 1.999)
            : (Parameter) xo.getElementFirstChild(IMMIGRATION_RATE);

    final Logger log = Logger.getLogger("dr.evolution");
    log.info("\n ---------------------------------\nCreating ALSTreeLikelihood model.");

    AbstractObservationProcess observationProcess = null;
    for (int childIndex = 0; childIndex < xo.getChildCount(); childIndex++) {
        final Object child = xo.getChild(childIndex);
        if (!(child instanceof XMLObject)) {
            continue;
        }
        final XMLObject element = (XMLObject) child;
        if (!element.getName().equals(OBSERVATION_PROCESS)) {
            continue;
        }

        final boolean singleTip = element.getStringAttribute(OBSERVATION_TYPE).equals("singleTip");
        if (singleTip) {
            final String taxonName = element.getStringAttribute(OBSERVATION_TAXON);
            final int taxonIndex = treeModel.getTaxonIndex(taxonName);
            final Taxon taxon = treeModel.getTaxon(taxonIndex);
            observationProcess = new SingleTipObservationProcess(treeModel, patternList, siteModel, branchRateModel, mu, lam, taxon);
            log.info("All traits are assumed extant in " + taxonName);
        } else {
            // any other observation type is treated as "anyTip"
            observationProcess = new AnyTipObservationProcess(ANY_TIP, treeModel, patternList, siteModel, branchRateModel, mu, lam);
            log.info("Observed traits are assumed to be extant in at least one tip node.");
        }
        observationProcess.setIntegrateGainRate(integrateGainRate);
    }

    log.info("\tIf you publish results using Acquisition-Loss-Mutation (ALS) Model likelihood, please reference Alekseyenko, Lee and Suchard (2008) Syst. Biol 57: 772-784.\n---------------------------------\n");

    final boolean forceRescaling = xo.getAttribute(FORCE_RESCALING, false);
    return new ALSTreeLikelihood(observationProcess, patternList, treeModel, siteModel, branchRateModel, useAmbiguities, storePartials, forceRescaling);
}
Also used : Taxon(dr.evolution.util.Taxon) AnyTipObservationProcess(dr.oldevomodel.MSSD.AnyTipObservationProcess) PatternList(dr.evolution.alignment.PatternList) MutationDeathModel(dr.oldevomodel.substmodel.MutationDeathModel) SiteModel(dr.oldevomodel.sitemodel.SiteModel) ALSTreeLikelihood(dr.oldevomodel.MSSD.ALSTreeLikelihood) TreeModel(dr.evomodel.tree.TreeModel) SingleTipObservationProcess(dr.oldevomodel.MSSD.SingleTipObservationProcess) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) AbstractObservationProcess(dr.oldevomodel.MSSD.AbstractObservationProcess) Parameter(dr.inference.model.Parameter)

Example 38 with BranchRateModel

use of dr.evomodel.branchratemodel.BranchRateModel in project beast-mcmc by beast-dev.

the class ALSTreeLikelihoodParser method parseXMLObject.

/**
 * Creates an {@code ALSTreeLikelihood} from its XML element.
 * <p>
 * The optional attributes {@code TreeLikelihoodParser.USE_AMBIGUITIES} (default {@code false}),
 * {@code TreeLikelihoodParser.STORE_PARTIALS} (default {@code true}) and FORCE_RESCALING
 * (default {@code false}) are read from {@code xo}, together with INTEGRATE_GAIN_RATE.
 * The death rate is taken from the site model's substitution model, which is cast to
 * {@code MutationDeathModel}. The gain rate is the first child of the IMMIGRATION_RATE element,
 * unless the gain rate is integrated, in which case a new "gainRate" parameter is created.
 * <p>
 * Each child element named OBSERVATION_PROCESS yields an observation process: a
 * {@code SingleTipObservationProcess} when its OBSERVATION_TYPE equals "singleTip", otherwise an
 * {@code AnyTipObservationProcess}. If several such children exist the last one is used; if none
 * exists, {@code null} is handed to the likelihood constructor.
 *
 * @param xo the XML element describing the likelihood
 * @return the newly constructed {@code ALSTreeLikelihood}
 * @throws XMLParseException declared by the signature; presumably raised by the XML accessors — confirm
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    // Optional flags: fall back to the default when the attribute is absent.
    final boolean useAmbiguities = xo.hasAttribute(TreeLikelihoodParser.USE_AMBIGUITIES)
            && xo.getBooleanAttribute(TreeLikelihoodParser.USE_AMBIGUITIES);
    final boolean storePartials = !xo.hasAttribute(TreeLikelihoodParser.STORE_PARTIALS)
            || xo.getBooleanAttribute(TreeLikelihoodParser.STORE_PARTIALS);
    final boolean integrateGainRate = xo.getBooleanAttribute(INTEGRATE_GAIN_RATE);

    final PatternList patternList = (PatternList) xo.getChild(PatternList.class);
    final TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);
    final SiteModel siteModel = (SiteModel) xo.getChild(SiteModel.class);
    final BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);

    final MutationDeathModel deathModel = (MutationDeathModel) siteModel.getSubstitutionModel();
    final Parameter mu = deathModel.getDeathParameter();
    final Parameter lam = integrateGainRate
            ? new Parameter.Default("gainRate", 1.0, 0.001, 1.999)
            : (Parameter) xo.getElementFirstChild(IMMIGRATION_RATE);

    final Logger log = Logger.getLogger("dr.evolution");
    log.info("\n ---------------------------------\nCreating ALSTreeLikelihood model.");

    AbstractObservationProcess observationProcess = null;
    for (int childIndex = 0; childIndex < xo.getChildCount(); childIndex++) {
        final Object child = xo.getChild(childIndex);
        if (!(child instanceof XMLObject)) {
            continue;
        }
        final XMLObject element = (XMLObject) child;
        if (!element.getName().equals(OBSERVATION_PROCESS)) {
            continue;
        }

        final boolean singleTip = element.getStringAttribute(OBSERVATION_TYPE).equals("singleTip");
        if (singleTip) {
            final String taxonName = element.getStringAttribute(OBSERVATION_TAXON);
            final int taxonIndex = treeModel.getTaxonIndex(taxonName);
            final Taxon taxon = treeModel.getTaxon(taxonIndex);
            observationProcess = new SingleTipObservationProcess(treeModel, patternList, siteModel, branchRateModel, mu, lam, taxon);
            log.info("All traits are assumed extant in " + taxonName);
        } else {
            // any other observation type is treated as "anyTip"
            observationProcess = new AnyTipObservationProcess(ANY_TIP, treeModel, patternList, siteModel, branchRateModel, mu, lam);
            log.info("Observed traits are assumed to be extant in at least one tip node.");
        }
        observationProcess.setIntegrateGainRate(integrateGainRate);
    }

    log.info("\tIf you publish results using Acquisition-Loss-Mutation (ALS) Model likelihood, please reference Alekseyenko, Lee and Suchard (2008) Syst. Biol 57: 772-784.\n---------------------------------\n");

    final boolean forceRescaling = xo.getAttribute(FORCE_RESCALING, false);
    return new ALSTreeLikelihood(observationProcess, patternList, treeModel, siteModel, branchRateModel, useAmbiguities, storePartials, forceRescaling);
}
Also used : Taxon(dr.evolution.util.Taxon) AnyTipObservationProcess(dr.oldevomodel.MSSD.AnyTipObservationProcess) PatternList(dr.evolution.alignment.PatternList) MutationDeathModel(dr.oldevomodel.substmodel.MutationDeathModel) SiteModel(dr.oldevomodel.sitemodel.SiteModel) ALSTreeLikelihood(dr.oldevomodel.MSSD.ALSTreeLikelihood) TreeModel(dr.evomodel.tree.TreeModel) SingleTipObservationProcess(dr.oldevomodel.MSSD.SingleTipObservationProcess) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) AbstractObservationProcess(dr.oldevomodel.MSSD.AbstractObservationProcess) Parameter(dr.inference.model.Parameter)

Example 39 with BranchRateModel

use of dr.evomodel.branchratemodel.BranchRateModel in project beast-mcmc by beast-dev.

the class AnyTipObservationProcessParser method parseXMLObject.

/**
 * Creates an {@code AnyTipObservationProcess} from its XML element.
 * <p>
 * The death rate and immigration rate are the first children of the DEATH_RATE and
 * IMMIGRATION_RATE elements; the tree model, pattern list, site model and branch rate model are
 * looked up among the children of {@code xo} by type. The initial values of both rates are
 * logged at info level before the process is constructed.
 *
 * @param xo the XML element describing the observation process
 * @return the newly constructed {@code AnyTipObservationProcess}
 * @throws XMLParseException declared by the signature; presumably raised by the XML accessors — confirm
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    final Parameter deathRate = (Parameter) xo.getElementFirstChild(DEATH_RATE);
    final Parameter immigrationRate = (Parameter) xo.getElementFirstChild(IMMIGRATION_RATE);

    final TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
    final PatternList patternList = (PatternList) xo.getChild(PatternList.class);
    final SiteModel sites = (SiteModel) xo.getChild(SiteModel.class);
    final BranchRateModel clock = (BranchRateModel) xo.getChild(BranchRateModel.class);

    final Logger log = Logger.getLogger("dr.evomodel.MSSD");
    final String announcement = "Creating AnyTipObservationProcess model. Observed traits are assumed to be extant in at least one tip node. Initial mu = "
            + deathRate.getParameterValue(0)
            + " initial lam = "
            + immigrationRate.getParameterValue(0);
    log.info(announcement);

    return new AnyTipObservationProcess(MODEL_NAME, tree, patternList, sites, clock, deathRate, immigrationRate);
}
Also used : TreeModel(dr.evomodel.tree.TreeModel) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) AnyTipObservationProcess(dr.oldevomodel.MSSD.AnyTipObservationProcess) PatternList(dr.evolution.alignment.PatternList) Parameter(dr.inference.model.Parameter) SiteModel(dr.oldevomodel.sitemodel.SiteModel)

Example 40 with BranchRateModel

use of dr.evomodel.branchratemodel.BranchRateModel in project beast-mcmc by beast-dev.

the class PartitionData method createClockRateModel.

/**
 * Builds the branch-rate (clock) model selected by {@code clockModelIndex}.
 *
 * <ul>
 *   <li>{@code 0} - strict clock, rate taken from {@code clockParameterValues[0]}</li>
 *   <li>{@code LRC_INDEX} - lognormal relaxed clock ({@code clockParameterValues[1..3]})</li>
 *   <li>{@code 2} - exponential relaxed clock ({@code clockParameterValues[4..5]})</li>
 *   <li>{@code 3} - inverse Gaussian relaxed clock ({@code clockParameterValues[6..8]})</li>
 * </ul>
 *
 * @return the configured model, or {@code null} if {@code clockModelIndex}
 *         matches none of the cases above (in which case
 *         "Not yet implemented" is printed to standard output)
 */
public BranchRateModel createClockRateModel() {
    BranchRateModel branchRateModel = null;
    if (this.clockModelIndex == 0) {
        // Strict Clock
        Parameter rateParameter = new Parameter.Default(1, clockParameterValues[0]);
        branchRateModel = new StrictClockBranchRates(rateParameter);
    } else if (this.clockModelIndex == LRC_INDEX) {
        // Lognormal relaxed clock
        Parameter rateCategoryParameter = newRateCategoryParameter();
        Parameter mean = new Parameter.Default(LogNormalDistributionModelParser.MEAN, 1, clockParameterValues[1]);
        Parameter stdev = new Parameter.Default(LogNormalDistributionModelParser.STDEV, 1, clockParameterValues[2]);
        //TODO: choose between log scale / real scale
        ParametricDistributionModel distributionModel = new LogNormalDistributionModel(mean, stdev, clockParameterValues[3], lrcParametersInRealSpace, lrcParametersInRealSpace);
        branchRateModel = newDiscretizedBranchRates(rateCategoryParameter, distributionModel);
    } else if (this.clockModelIndex == 2) {
        // Exponential relaxed clock
        Parameter rateCategoryParameter = newRateCategoryParameter();
        Parameter mean = new Parameter.Default(DistributionModelParser.MEAN, 1, clockParameterValues[4]);
        ParametricDistributionModel distributionModel = new ExponentialDistributionModel(mean, clockParameterValues[5]);
        branchRateModel = newDiscretizedBranchRates(rateCategoryParameter, distributionModel);
    } else if (this.clockModelIndex == 3) {
        // Inverse Gaussian
        Parameter rateCategoryParameter = newRateCategoryParameter();
        Parameter mean = new Parameter.Default(InverseGaussianDistributionModelParser.MEAN, 1, clockParameterValues[6]);
        Parameter stdev = new Parameter.Default(InverseGaussianDistributionModelParser.STDEV, 1, clockParameterValues[7]);
        ParametricDistributionModel distributionModel = new InverseGaussianDistributionModel(mean, stdev, clockParameterValues[8], false);
        branchRateModel = newDiscretizedBranchRates(rateCategoryParameter, distributionModel);
    } else {
        // Unknown index: report it and fall through to return null.
        System.out.println("Not yet implemented");
    }
    return branchRateModel;
}

/**
 * Creates the rate-category parameter shared by all relaxed-clock variants,
 * seeded with {@code 2 * (taxonCount - 1)} computed from a freshly created tree model.
 */
private Parameter newRateCategoryParameter() {
    double numberOfBranches = 2 * (createTreeModel().getTaxonCount() - 1);
    // NOTE(review): the argument is a double, so this resolves to a
    // Parameter.Default overload that accepts a double; confirm whether that
    // overload treats it as a dimension or as an initial value.
    return new Parameter.Default(numberOfBranches);
}

/**
 * Wraps the given rate distribution in a {@code DiscretizedBranchRates} model
 * over a freshly created tree model, using the fixed settings that every
 * relaxed-clock variant in {@link #createClockRateModel()} shares.
 */
private BranchRateModel newDiscretizedBranchRates(Parameter rateCategoryParameter,
                                                  ParametricDistributionModel distributionModel) {
    // NOTE(review): the three labels below come from the original trailing
    // comments, which appeared one line too early (next to Double.NaN, true
    // and the first false); they presumably describe the final three boolean
    // arguments - confirm against the DiscretizedBranchRates constructor.
    return new DiscretizedBranchRates(
            createTreeModel(),
            rateCategoryParameter,
            distributionModel,
            1,
            false,
            Double.NaN,
            true, // randomizeRates
            false, // keepRates
            false); // cacheRates
}
Also used : DiscretizedBranchRates(dr.evomodel.branchratemodel.DiscretizedBranchRates) InverseGaussianDistributionModel(dr.inference.distribution.InverseGaussianDistributionModel) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) LogNormalDistributionModel(dr.inference.distribution.LogNormalDistributionModel) ParametricDistributionModel(dr.inference.distribution.ParametricDistributionModel) ExponentialDistributionModel(dr.inference.distribution.ExponentialDistributionModel) Parameter(dr.inference.model.Parameter) StrictClockBranchRates(dr.evomodel.branchratemodel.StrictClockBranchRates)

Aggregations

BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel) 44; Parameter (dr.inference.model.Parameter) 31; TreeModel (dr.evomodel.tree.TreeModel) 28; GammaSiteRateModel (dr.evomodel.siteratemodel.GammaSiteRateModel) 26; FrequencyModel (dr.evomodel.substmodel.FrequencyModel) 22; DefaultBranchRateModel (dr.evomodel.branchratemodel.DefaultBranchRateModel) 21; HomogeneousBranchModel (dr.evomodel.branchmodel.HomogeneousBranchModel) 18; Tree (dr.evolution.tree.Tree) 15; ArrayList (java.util.ArrayList) 14; PatternList (dr.evolution.alignment.PatternList) 12; BranchModel (dr.evomodel.branchmodel.BranchModel) 12; HKY (dr.evomodel.substmodel.nucleotide.HKY) 12; Partition (dr.app.beagle.tools.Partition) 11; NewickImporter (dr.evolution.io.NewickImporter) 11; BeagleSequenceSimulator (dr.app.beagle.tools.BeagleSequenceSimulator) 10; SubstitutionModel (dr.evomodel.substmodel.SubstitutionModel) 8; BeagleTreeLikelihood (dr.evomodel.treelikelihood.BeagleTreeLikelihood) 8; SimpleAlignment (dr.evolution.alignment.SimpleAlignment) 7; ImportException (dr.evolution.io.Importer.ImportException) 7; Taxon (dr.evolution.util.Taxon) 7