Search in sources :

Example 16 with SitePatterns

use of dr.evolution.alignment.SitePatterns in project beast-mcmc by beast-dev.

the class AlignmentScore method main.

/**
 * Reads a NEXUS alignment from the file named by {@code args[0]}, computes a genetic
 * distance for every pair of sequences for which {@code ExtractPairs} yields a pair
 * alignment, and then collects the gap sizes of a set of non-overlapping pairs whose
 * distance is below the threshold.
 *
 * <p>Progress is printed as one character per pair: {@code .} for a pair below the
 * threshold, {@code *} for a pair at or above it, and {@code x} when no pair alignment
 * was returned. One line is printed per outer sequence.
 *
 * @param args command-line arguments; {@code args[0]} is the path of the NEXUS file
 * @throws java.io.IOException if the file cannot be opened or closed, or if the importer reports an I/O failure
 * @throws Importer.ImportException declared for the NEXUS import step
 * @throws IllegalArgumentException if no file name is supplied
 */
public static void main(String[] args) throws java.io.IOException, Importer.ImportException {
    if (args.length < 1) {
        throw new IllegalArgumentException("Usage: AlignmentScore <nexus-alignment-file>");
    }
    // Close the file as soon as the alignment has been imported, even if the import fails.
    FileReader reader = new FileReader(args[0]);
    Alignment alignment;
    try {
        alignment = new NexusImporter(reader).importAlignment();
    } finally {
        reader.close();
    }
    ExtractPairs pairs = new ExtractPairs(alignment);
    // HKY substitution model with empirical base frequencies from the alignment.
    Parameter muParam = new Parameter.Default(1.0);
    Parameter kappaParam = new Parameter.Default(1.0);
    kappaParam.addBounds(new Parameter.DefaultBounds(100.0, 0.0, 1));
    muParam.addBounds(new Parameter.DefaultBounds(1.0, 1.0, 1));
    Parameter freqParam = new Parameter.Default(alignment.getStateFrequencies());
    FrequencyModel freqModel = new FrequencyModel(Nucleotides.INSTANCE, freqParam);
    SubstitutionModel substModel = new HKY(kappaParam, freqModel);
    SiteModel siteModel = new GammaSiteModel(substModel, muParam, null, 1, null);
    ScoreMatrix scoreMatrix = new ScoreMatrix(siteModel, 0.1);
    // Pairs at or above this distance are reported with '*' and otherwise ignored.
    double threshold = 0.1;
    List<PairDistance> pairDistances = new ArrayList<PairDistance>();
    Set<Integer> sequencesUsed = new HashSet<Integer>();
    List<Integer> allGaps = new ArrayList<Integer>();
    int sequenceCount = alignment.getSequenceCount();
    for (int i = 0; i < sequenceCount; i++) {
        for (int j = i + 1; j < sequenceCount; j++) {
            Alignment pairAlignment = pairs.getPairAlignment(i, j);
            if (pairAlignment != null) {
                SitePatterns patterns = new SitePatterns(pairAlignment);
                double distance = getGeneticDistance(scoreMatrix, patterns);
                if (distance < threshold) {
                    // Typed list: the gaps are later merged into the List<Integer> allGaps.
                    List<Integer> gaps = new ArrayList<Integer>();
                    GapUtils.getGapSizes(pairAlignment, gaps);
                    pairDistances.add(new PairDistance(i, j, distance, gaps, pairAlignment.getSiteCount()));
                    System.out.print(".");
                } else {
                    System.out.print("*");
                }
            } else {
                System.out.print("x");
            }
        }
        System.out.println();
    }
    // Uses PairDistance's natural ordering (presumably ascending distance - confirm
    // against PairDistance.compareTo).
    Collections.sort(pairDistances);
    int totalLength = 0;
    // Greedy selection: take a pair only if neither of its sequences was already used.
    for (PairDistance pairDistance : pairDistances) {
        Integer x = pairDistance.x;
        Integer y = pairDistance.y;
        if (!sequencesUsed.contains(x) && !sequencesUsed.contains(y)) {
            allGaps.addAll(pairDistance.gaps);
            sequencesUsed.add(x);
            sequencesUsed.add(y);
            System.out.println("Added pair (" + x + "," + y + ") d=" + pairDistance.distance + " L=" + pairDistance.alignmentLength);
            totalLength += pairDistance.alignmentLength;
        }
    }
    printFrequencyTable(allGaps);
    System.out.println("total length=" + totalLength);
}
Also used : ExtractPairs(dr.evolution.alignment.ExtractPairs) FrequencyModel(dr.oldevomodel.substmodel.FrequencyModel) SitePatterns(dr.evolution.alignment.SitePatterns) NexusImporter(dr.evolution.io.NexusImporter) SubstitutionModel(dr.oldevomodel.substmodel.SubstitutionModel) Alignment(dr.evolution.alignment.Alignment) HKY(dr.oldevomodel.substmodel.HKY) Parameter(dr.inference.model.Parameter) FileReader(java.io.FileReader)

Example 17 with SitePatterns

use of dr.evolution.alignment.SitePatterns in project beast-mcmc by beast-dev.

the class RandomLocalClockTestProblem method testRandomLocalClock.

/**
 * Runs a 1,000,000-state MCMC chain under a random local clock model with an HKY
 * substitution model and a constant-population coalescent prior, then checks the
 * logged trace statistics against fixed expected values.
 *
 * @throws Exception declared by the test; the visible code does not catch anything
 */
public void testRandomLocalClock() throws Exception {
    // --- Demographic model and coalescent prior ---
    Parameter populationSize = new Parameter.Default(ConstantPopulationModelParser.POPULATION_SIZE, 0.077, 0, Double.POSITIVE_INFINITY);
    ConstantPopulationModel demography = createRandomInitialTree(populationSize);
    CoalescentLikelihood coalescentLikelihood = new CoalescentLikelihood(treeModel, null, new ArrayList<TaxonList>(), demography);
    coalescentLikelihood.setId("coalescent");

    // --- Random local clock and the statistics derived from it ---
    Parameter localRates = new Parameter.Default(RandomLocalClockModelParser.RATES, 10, 1);
    Parameter rateIndicators = new Parameter.Default(RandomLocalClockModelParser.RATE_INDICATORS, 10, 1);
    Parameter clockRate = new Parameter.Default(RandomLocalClockModelParser.CLOCK_RATE, 1, 1.0);
    RandomLocalClockModel localClock = new RandomLocalClockModel(treeModel, clockRate, rateIndicators, localRates, false, 0.5);
    SumStatistic rateChangeCount = new SumStatistic("rateChangeCount", true);
    rateChangeCount.addStatistic(rateIndicators);
    RateStatistic meanRateStatistic = new RateStatistic("meanRate", treeModel, localClock, true, true, RateStatisticParser.MEAN);
    RateStatistic rateCoefficientOfVariation = new RateStatistic(RateStatisticParser.COEFFICIENT_OF_VARIATION, treeModel, localClock, true, true, RateStatisticParser.COEFFICIENT_OF_VARIATION);
    RateCovarianceStatistic rateCovariance = new RateCovarianceStatistic("covariance", treeModel, localClock);

    // --- HKY substitution model with empirical frequencies ---
    Parameter baseFrequencies = new Parameter.Default(alignment.getStateFrequencies());
    Parameter kappa = new Parameter.Default(HKYParser.KAPPA, 1.0, 0, Double.POSITIVE_INFINITY);
    FrequencyModel frequencyModel = new FrequencyModel(Nucleotides.INSTANCE, baseFrequencies);
    HKY hky = new HKY(kappa, frequencyModel);

    // --- Site model ---
    GammaSiteModel siteModel = new GammaSiteModel(hky);
    Parameter mutationRate = new Parameter.Default(GammaSiteModelParser.MUTATION_RATE, 1.0, 0, Double.POSITIVE_INFINITY);
    siteModel.setMutationRateParameter(mutationRate);

    // --- Tree likelihood ---
    SitePatterns sitePatterns = new SitePatterns(alignment, null, 0, -1, 1, true);
    TreeLikelihood treeLikelihood = new TreeLikelihood(sitePatterns, treeModel, siteModel, localClock, null, false, false, true, false, false);
    treeLikelihood.setId(TreeLikelihoodParser.TREE_LIKELIHOOD);

    // --- Operator schedule (operators are added in the same order as before) ---
    OperatorSchedule schedule = new SimpleOperatorSchedule();

    MCMCOperator kappaScaler = new ScaleOperator(kappa, 0.75);
    kappaScaler.setWeight(1.0);
    schedule.addOperator(kappaScaler);

    MCMCOperator ratesScaler = new ScaleOperator(localRates, 0.75);
    ratesScaler.setWeight(10.0);
    schedule.addOperator(ratesScaler);

    schedule.addOperator(new BitFlipOperator(rateIndicators, 15.0, true));

    MCMCOperator popSizeScaler = new ScaleOperator(populationSize, 0.75);
    popSizeScaler.setWeight(3.0);
    schedule.addOperator(popSizeScaler);

    Parameter rootHeight = treeModel.getRootHeightParameter();
    rootHeight.setId(TREE_HEIGHT);
    MCMCOperator rootHeightScaler = new ScaleOperator(rootHeight, 0.75);
    rootHeightScaler.setWeight(3.0);
    schedule.addOperator(rootHeightScaler);

    Parameter internalHeights = treeModel.createNodeHeightsParameter(false, true, false);
    schedule.addOperator(new UniformOperator(internalHeights, 30.0));

    // Tree-topology moves.
    schedule.addOperator(new SubtreeSlideOperator(treeModel, 15.0, 0.0077, true, false, false, false, CoercionMode.COERCION_ON));
    schedule.addOperator(new ExchangeOperator(ExchangeOperator.NARROW, treeModel, 15.0));
    schedule.addOperator(new ExchangeOperator(ExchangeOperator.WIDE, treeModel, 3.0));
    schedule.addOperator(new WilsonBalding(treeModel, 3.0));

    // --- Priors, likelihood and posterior ---
    OneOnXPrior popSizePrior = new OneOnXPrior();
    popSizePrior.addData(populationSize);
    OneOnXPrior kappaPrior = new OneOnXPrior();
    kappaPrior.addData(kappa);
    DistributionLikelihood ratesPrior = new DistributionLikelihood(new GammaDistribution(0.5, 2.0), 0.0);
    ratesPrior.addData(localRates);
    DistributionLikelihood rateChangePrior = new DistributionLikelihood(new PoissonDistribution(1.0), 0.0);
    rateChangePrior.addData(rateChangeCount);

    // One scratch list is refilled for each compound likelihood.
    List<Likelihood> components = new ArrayList<Likelihood>();
    components.add(popSizePrior);
    components.add(kappaPrior);
    components.add(ratesPrior);
    components.add(rateChangePrior);
    components.add(coalescentLikelihood);
    Likelihood prior = new CompoundLikelihood(0, components);
    prior.setId(CompoundLikelihoodParser.PRIOR);

    components.clear();
    components.add(treeLikelihood);
    Likelihood likelihood = new CompoundLikelihood(-1, components);

    components.clear();
    components.add(prior);
    components.add(likelihood);
    Likelihood posterior = new CompoundLikelihood(0, components);
    posterior.setId(CompoundLikelihoodParser.POSTERIOR);

    // --- Loggers: an in-memory trace logger plus a tab-delimited logger on System.out ---
    ArrayLogFormatter traceFormatter = new ArrayLogFormatter(false);
    MCLogger traceLogger = new MCLogger(traceFormatter, 1000, false);
    traceLogger.add(posterior);
    traceLogger.add(prior);
    traceLogger.add(treeLikelihood);
    traceLogger.add(rootHeight);
    traceLogger.add(kappa);
    // The mean-rate statistic is intentionally not recorded in the trace logger.
    traceLogger.add(rateChangeCount);
    traceLogger.add(rateCoefficientOfVariation);
    traceLogger.add(rateCovariance);
    traceLogger.add(populationSize);
    traceLogger.add(coalescentLikelihood);

    MCLogger screenLogger = new MCLogger(new TabDelimitedFormatter(System.out), 10000, false);
    screenLogger.add(posterior);
    screenLogger.add(treeLikelihood);
    screenLogger.add(rootHeight);
    screenLogger.add(meanRateStatistic);
    screenLogger.add(rateChangeCount);

    MCLogger[] loggers = { traceLogger, screenLogger };

    // --- Run the chain ---
    MCMC mcmc = new MCMC("mcmc1");
    MCMCOptions options = new MCMCOptions(1000000);
    mcmc.setShowOperatorAnalysis(true);
    mcmc.init(options, posterior, schedule, loggers);
    mcmc.run();

    // Elapsed time.
    System.out.println(mcmc.getTimer().toString());

    // --- Analyse every trace except index 0 ---
    List<Trace> traces = traceFormatter.getTraces();
    ArrayTraceList traceList = new ArrayTraceList("RandomLocalClockTest", traces, 0);
    for (int traceIndex = 1; traceIndex < traces.size(); traceIndex++) {
        traceList.analyseTrace(traceIndex);
    }

    // --- Compare each logged quantity with its reference value ---
    assertTraceExpectation(traceList, CompoundLikelihoodParser.POSTERIOR, -1818.26);
    assertTraceExpectation(traceList, CompoundLikelihoodParser.PRIOR, -2.70143);
    assertTraceExpectation(traceList, TreeLikelihoodParser.TREE_LIKELIHOOD, -1815.56);
    assertTraceExpectation(traceList, TREE_HEIGHT, 6.363E-2);
    assertTraceExpectation(traceList, HKYParser.KAPPA, 30.0394);
    assertTraceExpectation(traceList, "rateChangeCount", 0.40786);
    assertTraceExpectation(traceList, RateStatisticParser.COEFFICIENT_OF_VARIATION, 7.02408E-2);
    assertTraceExpectation(traceList, "covariance", 0.47952);
    assertTraceExpectation(traceList, ConstantPopulationModelParser.POPULATION_SIZE, 9.67405E-2);
    assertTraceExpectation(traceList, "coalescent", 7.29521);
}

/**
 * Looks up the named trace, fetches its correlation statistics and hands them to
 * {@code assertExpectation} together with the expected value.
 */
private void assertTraceExpectation(ArrayTraceList traceList, String traceName, double expected) {
    TraceCorrelation stats = traceList.getCorrelationStatistics(traceList.getTraceIndex(traceName));
    assertExpectation(traceName, stats, expected);
}
Also used : FrequencyModel(dr.oldevomodel.substmodel.FrequencyModel) PoissonDistribution(dr.math.distributions.PoissonDistribution) TreeLikelihood(dr.oldevomodel.treelikelihood.TreeLikelihood) DistributionLikelihood(dr.inference.distribution.DistributionLikelihood) CoalescentLikelihood(dr.evomodel.coalescent.CoalescentLikelihood) TreeLikelihood(dr.oldevomodel.treelikelihood.TreeLikelihood) ExchangeOperator(dr.evomodel.operators.ExchangeOperator) ArrayList(java.util.ArrayList) MCMC(dr.inference.mcmc.MCMC) SubtreeSlideOperator(dr.evomodel.operators.SubtreeSlideOperator) RateCovarianceStatistic(dr.evomodel.tree.RateCovarianceStatistic) CoalescentLikelihood(dr.evomodel.coalescent.CoalescentLikelihood) MCMCOptions(dr.inference.mcmc.MCMCOptions) ArrayLogFormatter(dr.inference.loggers.ArrayLogFormatter) GammaDistribution(dr.math.distributions.GammaDistribution) WilsonBalding(dr.evomodel.operators.WilsonBalding) TraceCorrelation(dr.inference.trace.TraceCorrelation) SitePatterns(dr.evolution.alignment.SitePatterns) ConstantPopulationModel(dr.evomodel.coalescent.ConstantPopulationModel) TaxonList(dr.evolution.util.TaxonList) TabDelimitedFormatter(dr.inference.loggers.TabDelimitedFormatter) Trace(dr.inference.trace.Trace) GammaSiteModel(dr.oldevomodel.sitemodel.GammaSiteModel) RateStatistic(dr.evomodel.tree.RateStatistic) ArrayTraceList(dr.inference.trace.ArrayTraceList) RandomLocalClockModel(dr.evomodel.branchratemodel.RandomLocalClockModel) HKY(dr.oldevomodel.substmodel.HKY) DistributionLikelihood(dr.inference.distribution.DistributionLikelihood) MCLogger(dr.inference.loggers.MCLogger)

Example 18 with SitePatterns

use of dr.evolution.alignment.SitePatterns in project beast-mcmc by beast-dev.

the class MarkovJumpsTest method testMarkovJumps.

/**
 * Simulates Markov jump counts on a two-taxon tree under an HKY model and checks the
 * per-register averages over {@code nSim} draws against {@code valuesFromR} with a
 * tolerance of 1E-2.
 *
 * @throws RuntimeException if the Newick tree cannot be created; the underlying
 *         exception is attached as the cause
 */
public void testMarkovJumps() {
    // Fixed seed so the simulated averages are reproducible.
    MathUtils.setSeed(666);
    createAlignment(sequencesTwo, Nucleotides.INSTANCE);
    try {
        //            createSpecifiedTree("((human:1,chimp:1):1,gorilla:2)");
        createSpecifiedTree("(human:1,chimp:1)");
    } catch (Exception e) {
        // Keep the original exception as the cause so the actual failure is diagnosable.
        throw new RuntimeException("Unable to parse Newick tree", e);
    }
    //substitutionModel
    Parameter freqs = new Parameter.Default(new double[] { 0.40, 0.25, 0.25, 0.10 });
    Parameter kappa = new Parameter.Default(HKYParser.KAPPA, 10.0, 0, 100);
    FrequencyModel f = new FrequencyModel(Nucleotides.INSTANCE, freqs);
    HKY hky = new HKY(kappa, f);
    //siteModel
    //        double alpha = 0.5;
    Parameter mu = new Parameter.Default(GammaSiteModelParser.MUTATION_RATE, 0.5, 0, Double.POSITIVE_INFINITY);
    //        Parameter pInv = new Parameter.Default("pInv", 0.5, 0, 1);
    Parameter pInv = null;
    GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", mu, null, -1, pInv);
    siteRateModel.setSubstitutionModel(hky);
    //treeLikelihood
    SitePatterns patterns = new SitePatterns(alignment, null, 0, -1, 1, true);
    BranchModel branchModel = new HomogeneousBranchModel(siteRateModel.getSubstitutionModel());
    BranchRateModel branchRateModel = null;
    // NOTE(review): the source labels three of the trailing boolean flags "use MAP",
    // "return ML" and "use uniformization"; confirm against the constructor signature
    // which argument each label belongs to.
    MarkovJumpsBeagleTreeLikelihood mjTreeLikelihood = new MarkovJumpsBeagleTreeLikelihood(patterns, treeModel, branchModel, siteRateModel, branchRateModel, null, false, PartialsRescalingScheme.AUTO, true, null, hky.getDataType(), "stateTag", false, true, false, false, 1000);
    int nRegisters = registerValues.length;
    int nSim = 10000;
    // Register one jump-counting parameter per entry of the register tables.
    for (int i = 0; i < nRegisters; i++) {
        Parameter registerParameter = new Parameter.Default(registerValues[i]);
        registerParameter.setId(registerTages[i]);
        mjTreeLikelihood.addRegister(registerParameter, registerTypes[i], registerScales[i]);
    }
    double logLikelihood = mjTreeLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    double[] averages = new double[nRegisters];
    for (int i = 0; i < nSim; i++) {
        for (int r = 0; r < nRegisters; r++) {
            double[][] values = mjTreeLikelihood.getMarkovJumpsForRegister(treeModel, r);
            // Accumulate the first entry of every returned row for this register.
            for (double[] value : values) {
                averages[r] += value[0];
            }
        }
        // Mark the likelihood dirty before the next iteration queries the registers again.
        mjTreeLikelihood.makeDirty();
    }
    for (int r = 0; r < nRegisters; r++) {
        averages[r] /= (double) nSim;
        System.out.print(" " + averages[r]);
    }
    System.out.println("");
    assertEquals(valuesFromR, averages, 1E-2);
}
Also used : FrequencyModel(dr.evomodel.substmodel.FrequencyModel) SitePatterns(dr.evolution.alignment.SitePatterns) MarkovJumpsBeagleTreeLikelihood(dr.evomodel.treelikelihood.MarkovJumpsBeagleTreeLikelihood) HomogeneousBranchModel(dr.evomodel.branchmodel.HomogeneousBranchModel) GammaSiteRateModel(dr.evomodel.siteratemodel.GammaSiteRateModel) BranchModel(dr.evomodel.branchmodel.BranchModel) HomogeneousBranchModel(dr.evomodel.branchmodel.HomogeneousBranchModel) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) HKY(dr.evomodel.substmodel.nucleotide.HKY) Parameter(dr.inference.model.Parameter)

Example 19 with SitePatterns

use of dr.evolution.alignment.SitePatterns in project beast-mcmc by beast-dev.

the class DataLikelihoodTester method main.

/**
 * Manual driver that prints log-likelihoods computed through BeagleTreeLikelihood,
 * BeagleDataLikelihoodDelegate and MultiPartitionDataLikelihoodDelegate for a fixed
 * three-taxon tree (human, chimp, gorilla) under HKY models with different kappa
 * values, so that the printed numbers can be compared by eye.
 *
 * <p>NOTE(review): the unconditional {@code System.exit(0)} after the two-partition
 * check terminates the JVM, so "ADDITIONAL TEST" sections #1 to #4 below it never run.
 * It is kept as found; confirm whether it is a leftover debugging stop.
 *
 * @param args not read by this method
 * @throws RuntimeException if the Newick tree cannot be created; the underlying
 *         exception is attached as the cause
 */
public static void main(String[] args) {
    // turn off logging to avoid screen noise...
    Logger logger = Logger.getLogger("dr");
    logger.setUseParentHandlers(false);
    SimpleAlignment alignment = createAlignment(sequences, Nucleotides.INSTANCE);
    TreeModel treeModel;
    try {
        treeModel = createSpecifiedTree("((human:0.1,chimp:0.1):0.1,gorilla:0.2)");
    } catch (Exception e) {
        // Keep the original exception as the cause so the actual failure is diagnosable.
        throw new RuntimeException("Unable to parse Newick tree", e);
    }
    System.out.print("\nTest BeagleTreeLikelihood (kappa = 1): ");
    //substitutionModel
    Parameter freqs = new Parameter.Default(new double[] { 0.25, 0.25, 0.25, 0.25 });
    Parameter kappa = new Parameter.Default(HKYParser.KAPPA, 1.0, 0, 100);
    FrequencyModel f = new FrequencyModel(Nucleotides.INSTANCE, freqs);
    HKY hky = new HKY(kappa, f);
    //siteModel
    double alpha = 0.5;
    GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4);
    //        GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel");
    siteRateModel.setSubstitutionModel(hky);
    Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY);
    siteRateModel.setRelativeRateParameter(mu);
    // Second model stack: same frequencies and relative rate parameter, kappa = 10.
    FrequencyModel f2 = new FrequencyModel(Nucleotides.INSTANCE, freqs);
    Parameter kappa2 = new Parameter.Default(HKYParser.KAPPA, 10.0, 0, 100);
    HKY hky2 = new HKY(kappa2, f2);
    GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4);
    siteRateModel2.setSubstitutionModel(hky2);
    siteRateModel2.setRelativeRateParameter(mu);
    //treeLikelihood
    SitePatterns patterns = new SitePatterns(alignment, null, 0, -1, 1, true);
    BranchModel branchModel = new HomogeneousBranchModel(siteRateModel.getSubstitutionModel(), siteRateModel.getSubstitutionModel().getFrequencyModel());
    BranchModel branchModel2 = new HomogeneousBranchModel(siteRateModel2.getSubstitutionModel(), siteRateModel2.getSubstitutionModel().getFrequencyModel());
    BranchRateModel branchRateModel = new DefaultBranchRateModel();
    BeagleTreeLikelihood treeLikelihood = new BeagleTreeLikelihood(patterns, treeModel, branchModel, siteRateModel, branchRateModel, null, false, PartialsRescalingScheme.AUTO, true);
    double logLikelihood = treeLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    System.out.print("\nTest BeagleDataLikelihoodDelegate (kappa = 1): ");
    BeagleDataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate(treeModel, patterns, branchModel, siteRateModel, false, PartialsRescalingScheme.NONE, false);
    TreeDataLikelihood treeDataLikelihood = new TreeDataLikelihood(dataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    hky.setKappa(5.0);
    System.out.print("\nTest BeagleDataLikelihoodDelegate (kappa = 5): ");
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    System.out.print("\nTest BeagleDataLikelihoodDelegate (kappa = 10): ");
    dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate(treeModel, patterns, branchModel2, siteRateModel2, false, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(dataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    hky2.setKappa(11.0);
    System.out.print("\nTest BeagleDataLikelihoodDelegate (kappa = 11): ");
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    // Restore both kappa values before the multi-partition checks.
    hky.setKappa(1.0);
    hky2.setKappa(10.0);
    MultiPartitionDataLikelihoodDelegate multiPartitionDataLikelihoodDelegate;
    System.out.print("\nTest MultiPartitionDataLikelihoodDelegate 1 partition (kappa = 1):");
    multiPartitionDataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, Collections.singletonList((PatternList) patterns), Collections.singletonList((BranchModel) branchModel), Collections.singletonList((SiteRateModel) siteRateModel), true, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(multiPartitionDataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    hky.setKappa(5.0);
    System.out.print("\nTest MultiPartitionDataLikelihoodDelegate 1 partition (kappa = 5):");
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    hky.setKappa(1.0);
    System.out.print("\nTest MultiPartitionDataLikelihoodDelegate 1 partition (kappa = 10):");
    multiPartitionDataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, Collections.singletonList((PatternList) patterns), Collections.singletonList((BranchModel) branchModel2), Collections.singletonList((SiteRateModel) siteRateModel2), true, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(multiPartitionDataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    System.out.print("\nTest MultiPartitionDataLikelihoodDelegate 2 partitions (kappa = 1, 10): ");
    // The same pattern list is used for both partitions; only the models differ.
    List<PatternList> patternLists = new ArrayList<PatternList>();
    patternLists.add(patterns);
    patternLists.add(patterns);
    List<SiteRateModel> siteRateModels = new ArrayList<SiteRateModel>();
    siteRateModels.add(siteRateModel);
    siteRateModels.add(siteRateModel2);
    List<BranchModel> branchModels = new ArrayList<BranchModel>();
    branchModels.add(branchModel);
    branchModels.add(branchModel2);
    multiPartitionDataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, patternLists, branchModels, siteRateModels, true, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(multiPartitionDataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (NOT OK: this is 2x the logLikelihood of the 2nd partition)\n\n");
    // NOTE(review): this exit stops the program here; nothing below is executed at
    // runtime. Confirm whether it is an intentional stop or a debugging leftover.
    System.exit(0);
    //START ADDITIONAL TEST #1 - Guy Baele
    System.out.println("-- Test #1 SiteRateModels -- ");
    //alpha in partition 1 reject followed by alpha in partition 2 reject
    System.out.print("Adjust alpha in partition 1: ");
    siteRateModel.setAlpha(0.4);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    System.out.print("Return alpha in partition 1 to original value: ");
    siteRateModel.setAlpha(0.5);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (i.e. reject: OK)\n");
    System.out.print("Adjust alpha in partition 2: ");
    siteRateModel2.setAlpha(0.35);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    System.out.print("Return alpha in partition 2 to original value: ");
    siteRateModel2.setAlpha(0.5);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (i.e. reject: OK)\n");
    //alpha in partition 1 accept followed by alpha in partition 2 accept
    System.out.print("Adjust alpha in partition 1: ");
    siteRateModel.setAlpha(0.4);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    System.out.print("Adjust alpha in partition 2: ");
    siteRateModel2.setAlpha(0.35);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (NOT OK: same logLikelihood as only setting alpha in partition 2)");
    System.out.print("Return alpha in partition 1 to original value: ");
    siteRateModel.setAlpha(0.5);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (NOT OK: alpha in partition 2 has not been returned to original value yet)");
    System.out.print("Return alpha in partition 2 to original value: ");
    siteRateModel2.setAlpha(0.5);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + "\n");
    //adjusting alphas in both partitions without explicitly calling getLogLikelihood() in between
    System.out.print("Adjust both alphas in partitions 1 and 2: ");
    siteRateModel.setAlpha(0.4);
    siteRateModel2.setAlpha(0.35);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    System.out.print("Return alpha in partition 2 to original value: ");
    siteRateModel2.setAlpha(0.5);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (NOT OK: alpha in partition 1 has not been returned to original value yet)");
    System.out.print("Return alpha in partition 1 to original value: ");
    siteRateModel.setAlpha(0.5);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + "\n\n");
    //END ADDITIONAL TEST - Guy Baele
    //START ADDITIONAL TEST #2 - Guy Baele
    System.out.println("-- Test #2 SiteRateModels -- ");
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    //1 siteRateModel shared across 2 partitions
    siteRateModels = new ArrayList<SiteRateModel>();
    siteRateModels.add(siteRateModel);
    multiPartitionDataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, patternLists, branchModels, siteRateModels, true, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(multiPartitionDataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + "\n");
    System.out.print("Adjust alpha in shared siteRateModel: ");
    siteRateModel.setAlpha(0.4);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (NOT OK: same logLikelihood as only adjusted alpha for partition 1)");
    siteRateModel.setAlpha(0.5);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + "\n\n");
    //END ADDITIONAL TEST - Guy Baele
    //START ADDITIONAL TEST #3 - Guy Baele
    System.out.println("-- Test #3 SiteRateModels -- ");
    // Fresh site rate models built with the name-only constructor for this section.
    siteRateModel = new GammaSiteRateModel("gammaModel");
    siteRateModel.setSubstitutionModel(hky);
    siteRateModel.setRelativeRateParameter(mu);
    siteRateModel2 = new GammaSiteRateModel("gammaModel2");
    siteRateModel2.setSubstitutionModel(hky2);
    siteRateModel2.setRelativeRateParameter(mu);
    siteRateModels = new ArrayList<SiteRateModel>();
    siteRateModels.add(siteRateModel);
    siteRateModels.add(siteRateModel2);
    multiPartitionDataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, patternLists, branchModels, siteRateModels, true, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(multiPartitionDataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + "\n");
    System.out.print("Adjust kappa in partition 1: ");
    hky.setKappa(5.0);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (NOT OK: logLikelihood has not changed?)");
    System.out.print("Return kappa in partition 1 to original value: ");
    hky.setKappa(1.0);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + "\n");
    System.out.print("Adjust kappa in partition 2: ");
    hky2.setKappa(11.0);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood);
    System.out.print("Return kappa in partition 2 to original value: ");
    hky2.setKappa(10.0);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.println("logLikelihood = " + logLikelihood + " (i.e. reject: OK)\n\n");
    //END ADDITIONAL TEST - Guy Baele
    //START ADDITIONAL TEST #4 - Guy Baele
    System.out.println("-- Test #4 SiteRateModels -- ");
    // Second alignment so the two partitions carry different data.
    SimpleAlignment secondAlignment = createAlignment(moreSequences, Nucleotides.INSTANCE);
    SitePatterns morePatterns = new SitePatterns(secondAlignment, null, 0, -1, 1, true);
    BeagleDataLikelihoodDelegate dataLikelihoodDelegateOne = new BeagleDataLikelihoodDelegate(treeModel, patterns, branchModel, siteRateModel, false, PartialsRescalingScheme.NONE, false);
    TreeDataLikelihood treeDataLikelihoodOne = new TreeDataLikelihood(dataLikelihoodDelegateOne, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihoodOne.getLogLikelihood();
    System.out.println("\nBeagleDataLikelihoodDelegate logLikelihood partition 1 (kappa = 1) = " + logLikelihood);
    hky.setKappa(10.0);
    logLikelihood = treeDataLikelihoodOne.getLogLikelihood();
    System.out.println("BeagleDataLikelihoodDelegate logLikelihood partition 1 (kappa = 10) = " + logLikelihood);
    hky.setKappa(1.0);
    BeagleDataLikelihoodDelegate dataLikelihoodDelegateTwo = new BeagleDataLikelihoodDelegate(treeModel, morePatterns, branchModel2, siteRateModel2, false, PartialsRescalingScheme.NONE, false);
    TreeDataLikelihood treeDataLikelihoodTwo = new TreeDataLikelihood(dataLikelihoodDelegateTwo, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihoodTwo.getLogLikelihood();
    System.out.println("BeagleDataLikelihoodDelegate logLikelihood partition 2 (kappa = 10) = " + logLikelihood + "\n");
    multiPartitionDataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, Collections.singletonList((PatternList) patterns), Collections.singletonList((BranchModel) branchModel), Collections.singletonList((SiteRateModel) siteRateModel), true, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(multiPartitionDataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.print("Test MultiPartitionDataLikelihoodDelegate 1st partition (kappa = 1):");
    System.out.println("logLikelihood = " + logLikelihood);
    hky.setKappa(10.0);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.print("Test MultiPartitionDataLikelihoodDelegate 1st partition (kappa = 10):");
    System.out.println("logLikelihood = " + logLikelihood);
    hky.setKappa(1.0);
    multiPartitionDataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, Collections.singletonList((PatternList) morePatterns), Collections.singletonList((BranchModel) branchModel2), Collections.singletonList((SiteRateModel) siteRateModel2), true, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(multiPartitionDataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.print("Test MultiPartitionDataLikelihoodDelegate 2nd partition (kappa = 10):");
    System.out.println("logLikelihood = " + logLikelihood + "\n");
    patternLists = new ArrayList<PatternList>();
    patternLists.add(patterns);
    patternLists.add(morePatterns);
    multiPartitionDataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate(treeModel, patternLists, branchModels, siteRateModels, true, PartialsRescalingScheme.NONE, false);
    treeDataLikelihood = new TreeDataLikelihood(multiPartitionDataLikelihoodDelegate, treeModel, branchRateModel);
    logLikelihood = treeDataLikelihood.getLogLikelihood();
    System.out.print("Test MultiPartitionDataLikelihoodDelegate 2 partitions (kappa = 1, 10): ");
    System.out.println("logLikelihood = " + logLikelihood + " (NOT OK: should be the sum of both separate logLikelihoods)\nKappa value of partition 2 is used to compute logLikelihood for both partitions?");
    //END ADDITIONAL TEST - Guy Baele
}
Also used : FrequencyModel(dr.evomodel.substmodel.FrequencyModel) ArrayList(java.util.ArrayList) PatternList(dr.evolution.alignment.PatternList) HomogeneousBranchModel(dr.evomodel.branchmodel.HomogeneousBranchModel) BranchModel(dr.evomodel.branchmodel.BranchModel) DefaultBranchRateModel(dr.evomodel.branchratemodel.DefaultBranchRateModel) SiteRateModel(dr.evomodel.siteratemodel.SiteRateModel) GammaSiteRateModel(dr.evomodel.siteratemodel.GammaSiteRateModel) TreeModel(dr.evomodel.tree.TreeModel) SimpleAlignment(dr.evolution.alignment.SimpleAlignment) SitePatterns(dr.evolution.alignment.SitePatterns) BeagleTreeLikelihood(dr.evomodel.treelikelihood.BeagleTreeLikelihood) HomogeneousBranchModel(dr.evomodel.branchmodel.HomogeneousBranchModel) GammaSiteRateModel(dr.evomodel.siteratemodel.GammaSiteRateModel) DefaultBranchRateModel(dr.evomodel.branchratemodel.DefaultBranchRateModel) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) HKY(dr.evomodel.substmodel.nucleotide.HKY) Parameter(dr.inference.model.Parameter)

Example 20 with SitePatterns

use of dr.evolution.alignment.SitePatterns in project beast-mcmc by beast-dev.

the class OptimizedBeagleTreeLikelihoodParser method parseXMLObject.

/**
 * Builds a compound likelihood from the BeagleTreeLikelihood children of {@code xo},
 * greedily splitting individual likelihoods into more instances for as long as a timed
 * calibration run gets faster.
 *
 * <p>Outline:
 * <ol>
 *   <li>Read the optional CALIBRATE (number of timed evaluations, default 100) and RETRY
 *       (number of extra splits to attempt after a split that was not faster, default 0)
 *       attributes.</li>
 *   <li>Collect the children and remember, per child, everything needed to re-create it
 *       on a subset of its patterns.</li>
 *   <li>Time {@code calibrate} evaluations of the unsplit compound likelihood as baseline.</li>
 *   <li>Visit the likelihoods in order of decreasing pattern count. Split the one at the
 *       head of the list into one more instance, re-time, and either keep the split
 *       (faster) or, once the retries are used up, roll back to the last instance count
 *       that was faster and drop that likelihood from the list.</li>
 *   <li>Stop when the list is empty; unregister the original children and return the
 *       compound likelihood.</li>
 * </ol>
 *
 * <p>NOTE(review): every child is cast to BeagleTreeLikelihood, its pattern list to
 * SitePatterns and its site rate model to GammaSiteRateModel; any other type fails with a
 * ClassCastException.
 *
 * @param xo the XML element whose children are the likelihoods to optimise
 * @return the TestThreadedCompoundLikelihood over the final set of (possibly split) likelihoods
 * @throws XMLParseException if {@code xo} has no children
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    //Default of 100 likelihood calculations for calibration
    int calibrate = 100;
    if (xo.hasAttribute(CALIBRATE)) {
        calibrate = xo.getIntegerAttribute(CALIBRATE);
    }
    //Default: only try the next split up, unless a RETRY value is specified in the XML
    int retry = 0;
    if (xo.hasAttribute(RETRY)) {
        retry = xo.getIntegerAttribute(RETRY);
    }
    int childCount = xo.getChildCount();
    if (childCount == 0) {
        // Without this guard the method would run the calibration on an empty compound
        // likelihood and then fail with an ArrayIndexOutOfBoundsException.
        throw new XMLParseException("At least one BeagleTreeLikelihood child element is required.");
    }
    // 'likelihoods' is rewritten while splitting; 'originalLikelihoods' keeps the children
    // as parsed so they can be unregistered at the end.
    List<Likelihood> likelihoods = new ArrayList<Likelihood>();
    List<Likelihood> originalLikelihoods = new ArrayList<Likelihood>();
    for (int i = 0; i < childCount; i++) {
        likelihoods.add((Likelihood) xo.getChild(i));
        originalLikelihoods.add((Likelihood) xo.getChild(i));
    }
    if (DEBUG) {
        System.err.println("-----");
        System.err.println(childCount + " BeagleTreeLikelihoods added.");
    }
    // instanceCounts[k]: number of instances original likelihood k is currently split into.
    // currentLocation[k]: index in 'likelihoods' at which the instances of likelihood k start.
    int[] instanceCounts = new int[childCount];
    int[] currentLocation = new int[childCount];
    for (int i = 0; i < childCount; i++) {
        instanceCounts[i] = 1;
        currentLocation[i] = i;
    }
    int[] siteCounts = new int[childCount];
    //store everything for later use
    SitePatterns[] patterns = new SitePatterns[childCount];
    TreeModel[] treeModels = new TreeModel[childCount];
    BranchModel[] branchModels = new BranchModel[childCount];
    GammaSiteRateModel[] siteRateModels = new GammaSiteRateModel[childCount];
    BranchRateModel[] branchRateModels = new BranchRateModel[childCount];
    boolean[] ambiguities = new boolean[childCount];
    PartialsRescalingScheme[] rescalingSchemes = new PartialsRescalingScheme[childCount];
    boolean[] isDelayRescalingUntilUnderflow = new boolean[childCount];
    List<Map<Set<String>, Parameter>> partialsRestrictions = new ArrayList<Map<Set<String>, Parameter>>();
    for (int i = 0; i < childCount; i++) {
        BeagleTreeLikelihood original = (BeagleTreeLikelihood) likelihoods.get(i);
        patterns[i] = (SitePatterns) original.getPatternsList();
        siteCounts[i] = patterns[i].getPatternCount();
        treeModels[i] = original.getTreeModel();
        branchModels[i] = original.getBranchModel();
        siteRateModels[i] = (GammaSiteRateModel) original.getSiteRateModel();
        branchRateModels[i] = original.getBranchRateModel();
        ambiguities[i] = original.useAmbiguities();
        rescalingSchemes[i] = original.getRescalingScheme();
        isDelayRescalingUntilUnderflow[i] = original.isDelayRescalingUntilUnderflow();
        partialsRestrictions.add(original.getPartialsRestrictions());
    }
    if (DEBUG) {
        System.err.println("Pattern counts: ");
        for (int i = 0; i < siteCounts.length; i++) {
            System.err.println(siteCounts[i] + "   vs.    " + patterns[i].getPatternCount());
        }
        System.err.println();
        debugPrintCounts("Instance counts: ", instanceCounts);
        debugPrintCounts("Current locations: ", currentLocation);
    }
    TestThreadedCompoundLikelihood compound = new TestThreadedCompoundLikelihood(likelihoods);
    if (DEBUG) {
        System.err.println("Timing estimates for each of the " + calibrate + " likelihood calculations:");
    }
    // Baseline: time of the unsplit configuration, in nanoseconds.
    double baseResult = timeLikelihoodEvaluations(compound, calibrate);
    if (DEBUG) {
        System.err.println("Starting evaluation took: " + baseResult);
    }
    //Try splitting the same likelihood until no further improvement, then move on towards the next one
    //construct list with likelihoods to split up (indices ordered by decreasing pattern count)
    List<Integer> splitList = new ArrayList<Integer>();
    // Selection is done on a scratch copy so that siteCounts itself stays intact.
    int[] remaining = new int[childCount];
    System.arraycopy(siteCounts, 0, remaining, 0, childCount);
    for (int i = 0; i < childCount; i++) {
        int top = 0;
        for (int j = 0; j < childCount; j++) {
            if (remaining[j] > remaining[top]) {
                top = j;
            }
        }
        remaining[top] = 0;
        splitList.add(top);
    }
    if (DEBUG) {
        for (int i = 0; i < childCount; i++) {
            System.err.println("Site count " + i + " = " + siteCounts[i]);
        }
        //print list
        System.err.print("Ordered list of likelihoods to be evaluated: ");
        for (int i = 0; i < splitList.size(); i++) {
            System.err.print(splitList.get(i) + " ");
        }
        System.err.println();
    }
    int timesRetried = 0;
    boolean notFinished = true;
    while (notFinished) {
        //split it in 1 more piece
        int longestIndex = splitList.get(0);
        int instanceCount = ++instanceCounts[longestIndex];
        List<Likelihood> newList = new ArrayList<Likelihood>();
        for (int i = 0; i < instanceCount; i++) {
            // presumably selects subset i out of instanceCount subsets of the patterns — confirm against Patterns
            Patterns subPatterns = new Patterns(patterns[longestIndex], 0, 0, 1, i, instanceCount);
            BeagleTreeLikelihood treeLikelihood = createTreeLikelihood(subPatterns, treeModels[longestIndex], branchModels[longestIndex], siteRateModels[longestIndex], branchRateModels[longestIndex], null, ambiguities[longestIndex], rescalingSchemes[longestIndex], isDelayRescalingUntilUnderflow[longestIndex], partialsRestrictions.get(longestIndex), xo);
            treeLikelihood.setId(xo.getId() + "_" + longestIndex + "_" + i);
            // printed unconditionally (not guarded by DEBUG), as in the original implementation
            System.err.println(treeLikelihood.getId() + " created.");
            newList.add(treeLikelihood);
        }
        // Replace the previous (instanceCount - 1) instances by the new instanceCount ones.
        for (int i = 0; i < newList.size() - 1; i++) {
            likelihoods.remove(currentLocation[longestIndex]);
        }
        // Inserting every element at the same position places them in reverse order.
        for (int i = 0; i < newList.size(); i++) {
            likelihoods.add(currentLocation[longestIndex], newList.get(i));
        }
        // The list grew by one, so everything located after this likelihood shifts by one.
        for (int i = longestIndex + 1; i < currentLocation.length; i++) {
            currentLocation[i]++;
        }
        compound = new TestThreadedCompoundLikelihood(likelihoods);
        // Approximate per-instance pattern count after the split (integer arithmetic).
        siteCounts[longestIndex] = (instanceCount - 1) * siteCounts[longestIndex] / instanceCount;
        if (DEBUG) {
            //check number of likelihoods
            System.err.println("Number of BeagleTreeLikelihoods: " + compound.getLikelihoodCount());
            debugPrintCounts("Pattern counts: ", siteCounts);
            debugPrintCounts("Instance counts: ", instanceCounts);
            debugPrintCounts("Current locations: ", currentLocation);
        }
        //evaluate speed
        if (DEBUG) {
            System.err.println("Timing estimates for each of the " + calibrate + " likelihood calculations:");
        }
        double newResult = timeLikelihoodEvaluations(compound, calibrate);
        if (DEBUG) {
            System.err.println("New evaluation took: " + newResult + " vs. old evaluation: " + baseResult);
        }
        if (newResult < baseResult) {
            //new partitioning is faster, so partition further
            baseResult = newResult;
            //reorder split list
            if (DEBUG) {
                debugPrintSplitList("Current split list: ", "Current pattern counts: ", splitList, siteCounts);
            }
            int currentPatternCount = siteCounts[longestIndex];
            // Last position in splitList whose likelihood still has more patterns per
            // instance than the one just split; the head is moved to that position.
            int findIndex = 0;
            for (int i = 0; i < splitList.size(); i++) {
                if (siteCounts[splitList.get(i)] > currentPatternCount) {
                    findIndex = i;
                }
            }
            if (DEBUG) {
                System.err.println("Current pattern count: " + currentPatternCount);
                // findIndex is a position in splitList, so it has to be mapped to a likelihood index first
                System.err.println("Index found: " + findIndex + " with pattern count: " + siteCounts[splitList.get(findIndex)]);
                System.err.println("Moving 0 to " + findIndex);
            }
            // Bubble the head of the list down to findIndex with adjacent swaps.
            for (int i = 0; i < findIndex; i++) {
                int temp = splitList.get(i);
                splitList.set(i, splitList.get(i + 1));
                splitList.set(i + 1, temp);
            }
            if (DEBUG) {
                debugPrintSplitList("New split list: ", "New pattern counts: ", splitList, siteCounts);
            }
            timesRetried = 0;
        } else {
            if (DEBUG) {
                System.err.println("timesRetried = " + timesRetried + " vs. retry = " + retry);
            }
            //new partitioning is slower, so reinstate previous state unless RETRY is specified
            if (timesRetried < retry) {
                //try splitting further any way
                //do not set baseResult
                timesRetried++;
                if (DEBUG) {
                    System.err.println("RETRY number " + timesRetried);
                }
            } else {
                // This likelihood is done; continue with the next one, if any.
                splitList.remove(0);
                if (splitList.size() == 0) {
                    notFinished = false;
                }
                //remove (timesRetried + 1) instanceCount(s): the failed split plus every retry after it
                if (DEBUG) {
                    System.err.print("Removing " + (timesRetried + 1) + " instance count(s): " + instanceCount);
                }
                instanceCounts[longestIndex] = instanceCounts[longestIndex] - (timesRetried + 1);
                instanceCount = instanceCounts[longestIndex];
                if (DEBUG) {
                    System.err.println(" -> " + instanceCount + " for likelihood " + longestIndex);
                }
                newList = new ArrayList<Likelihood>();
                for (int i = 0; i < instanceCount; i++) {
                    Patterns subPatterns = new Patterns(patterns[longestIndex], 0, 0, 1, i, instanceCount);
                    BeagleTreeLikelihood treeLikelihood = createTreeLikelihood(subPatterns, treeModels[longestIndex], branchModels[longestIndex], siteRateModels[longestIndex], branchRateModels[longestIndex], null, ambiguities[longestIndex], rescalingSchemes[longestIndex], isDelayRescalingUntilUnderflow[longestIndex], partialsRestrictions.get(longestIndex), xo);
                    treeLikelihood.setId(xo.getId() + "_" + longestIndex + "_" + i);
                    // printed unconditionally (not guarded by DEBUG), as in the original implementation
                    System.err.println(treeLikelihood.getId() + " created.");
                    newList.add(treeLikelihood);
                }
                // Drop all instances of the rejected configuration (the restored count plus
                // the rejected extra ones), unregistering each before it is removed.
                for (int i = 0; i < newList.size() + timesRetried + 1; i++) {
                    unregisterAllModels((BeagleTreeLikelihood) likelihoods.get(currentLocation[longestIndex]));
                    likelihoods.remove(currentLocation[longestIndex]);
                }
                for (int i = 0; i < newList.size(); i++) {
                    likelihoods.add(currentLocation[longestIndex], newList.get(i));
                }
                for (int i = longestIndex + 1; i < currentLocation.length; i++) {
                    currentLocation[i] -= (timesRetried + 1);
                }
                compound = new TestThreadedCompoundLikelihood(likelihoods);
                // Approximate inverse of the per-split scaling applied above (integer arithmetic).
                siteCounts[longestIndex] = (instanceCount + timesRetried + 1) * siteCounts[longestIndex] / instanceCount;
                if (DEBUG) {
                    debugPrintCounts("Pattern counts: ", siteCounts);
                    debugPrintCounts("Instance counts: ", instanceCounts);
                    debugPrintCounts("Current locations: ", currentLocation);
                }
                timesRetried = 0;
            }
        }
    }
    // The children parsed from the XML are unregistered before the result is returned.
    for (int i = 0; i < originalLikelihoods.size(); i++) {
        unregisterAllModels((BeagleTreeLikelihood) originalLikelihoods.get(i));
    }
    return compound;
}

/**
 * Times {@code evaluations} forced re-evaluations of {@code compound}.
 *
 * @return the elapsed time in nanoseconds; the difference is taken on the raw
 *         {@code long} tick values before widening so no precision is lost
 */
private static double timeLikelihoodEvaluations(TestThreadedCompoundLikelihood compound, int evaluations) {
    long start = System.nanoTime();
    for (int i = 0; i < evaluations; i++) {
        compound.makeDirty();
        compound.getLogLikelihood();
    }
    return (double) (System.nanoTime() - start);
}

/** Debug output: prints {@code header} on its own line, then all values on one line separated by single spaces. */
private static void debugPrintCounts(String header, int[] values) {
    System.err.println(header);
    for (int i = 0; i < values.length; i++) {
        System.err.print(values[i] + " ");
    }
    System.err.println();
}

/** Debug output: prints the split list and, on a second line, the site count of each listed likelihood. */
private static void debugPrintSplitList(String listHeader, String countHeader, List<Integer> splitList, int[] siteCounts) {
    System.err.print(listHeader);
    for (int i = 0; i < splitList.size(); i++) {
        System.err.print(splitList.get(i) + "  ");
    }
    System.err.println();
    System.err.print(countHeader);
    for (int i = 0; i < splitList.size(); i++) {
        System.err.print(siteCounts[splitList.get(i)] + "  ");
    }
    System.err.println();
}
Also used : Set(java.util.Set) BeagleTreeLikelihood(dr.evomodel.treelikelihood.BeagleTreeLikelihood) ArrayList(java.util.ArrayList) PartialsRescalingScheme(dr.evomodel.treelikelihood.PartialsRescalingScheme) BranchModel(dr.evomodel.branchmodel.BranchModel) TreeModel(dr.evomodel.tree.TreeModel) Patterns(dr.evolution.alignment.Patterns) SitePatterns(dr.evolution.alignment.SitePatterns) SitePatterns(dr.evolution.alignment.SitePatterns) BeagleTreeLikelihood(dr.evomodel.treelikelihood.BeagleTreeLikelihood) GammaSiteRateModel(dr.evomodel.siteratemodel.GammaSiteRateModel) BranchRateModel(dr.evomodel.branchratemodel.BranchRateModel) Map(java.util.Map)

Aggregations

SitePatterns (dr.evolution.alignment.SitePatterns)30 Parameter (dr.inference.model.Parameter)20 FrequencyModel (dr.oldevomodel.substmodel.FrequencyModel)16 GammaSiteModel (dr.oldevomodel.sitemodel.GammaSiteModel)15 TreeLikelihood (dr.oldevomodel.treelikelihood.TreeLikelihood)15 HKY (dr.oldevomodel.substmodel.HKY)11 ArrayList (java.util.ArrayList)7 BranchRateModel (dr.evomodel.branchratemodel.BranchRateModel)6 GammaSiteRateModel (dr.evomodel.siteratemodel.GammaSiteRateModel)6 TaxonList (dr.evolution.util.TaxonList)5 BranchModel (dr.evomodel.branchmodel.BranchModel)5 ExchangeOperator (dr.evomodel.operators.ExchangeOperator)5 SubtreeSlideOperator (dr.evomodel.operators.SubtreeSlideOperator)5 WilsonBalding (dr.evomodel.operators.WilsonBalding)5 TreeModel (dr.evomodel.tree.TreeModel)5 ArrayLogFormatter (dr.inference.loggers.ArrayLogFormatter)5 MCLogger (dr.inference.loggers.MCLogger)5 TabDelimitedFormatter (dr.inference.loggers.TabDelimitedFormatter)5 MCMC (dr.inference.mcmc.MCMC)5 MCMCOptions (dr.inference.mcmc.MCMCOptions)5