Search in sources:

Example 96 with JavaSparkContext

Use of org.apache.spark.api.java.JavaSparkContext in project gatk-protected by broadinstitute.

The class AlleleFractionModellerUnitTest, method testMCMC.

/**
 * Simulates allele-fraction data, fits the allele-fraction model by MCMC and checks the result.
 *
 * <p>The checks cover the number of retained samples, the posterior means of the global parameters and
 * per-segment minor fractions, and the centers of the posterior summaries. Values are compared against
 * the expected/true values within fixed tolerances.
 *
 * @param meanBiasSimulated     mean bias used to generate the simulated data
 * @param biasVarianceSimulated bias variance used to generate the simulated data
 * @param meanBiasExpected      mean bias the fitted model is expected to recover
 * @param biasVarianceExpected  bias variance the fitted model is expected to recover
 * @param allelicPoN            allelic panel of normals handed to the modeller
 */
private void testMCMC(final double meanBiasSimulated, final double biasVarianceSimulated, final double meanBiasExpected, final double biasVarianceExpected, final AllelicPanelOfNormals allelicPoN) {
    LoggingUtils.setLoggingLevel(Log.LogLevel.INFO);
    final JavaSparkContext sparkContext = SparkContextFactory.getTestSparkContext();

    // MCMC and simulation configuration
    final int totalIterations = 150;
    final int burnInIterations = 50;
    final int retainedIterations = totalIterations - burnInIterations;
    final double hetsPerSegment = 50;
    final int segmentCount = 100;
    final int depth = 50;
    final double trueOutlierProbability = 0.02;

    // The tolerances could be far tighter with more segments and/or hets: most of the error comes from
    // sampling a finite simulated data set, not from numerical error in the MCMC itself.
    final double minorFractionTol = 0.02;
    final double meanBiasTol = 0.02;
    final double biasVarianceTol = 0.01;
    final double outlierProbabilityTol = 0.02;

    final AlleleFractionSimulatedData data = new AlleleFractionSimulatedData(hetsPerSegment, segmentCount, depth, meanBiasSimulated, biasVarianceSimulated, trueOutlierProbability);
    final AlleleFractionModeller modeller = new AlleleFractionModeller(data.getSegmentedGenome(), allelicPoN);
    modeller.fitMCMC(totalIterations, burnInIterations);

    // every sample collection must hold exactly the post-burn-in draws
    final List<Double> meanBiasDraws = modeller.getmeanBiasSamples();
    Assert.assertEquals(meanBiasDraws.size(), retainedIterations);
    final List<Double> biasVarianceDraws = modeller.getBiasVarianceSamples();
    Assert.assertEquals(biasVarianceDraws.size(), retainedIterations);
    final List<Double> outlierProbabilityDraws = modeller.getOutlierProbabilitySamples();
    Assert.assertEquals(outlierProbabilityDraws.size(), retainedIterations);
    final List<AlleleFractionState.MinorFractions> minorFractionDraws = modeller.getMinorFractionsSamples();
    Assert.assertEquals(minorFractionDraws.size(), retainedIterations);
    minorFractionDraws.forEach(draw -> Assert.assertEquals(draw.size(), segmentCount));

    // posterior means computed directly from the retained draws
    final List<List<Double>> drawsBySegment = modeller.getMinorFractionSamplesBySegment();
    final double posteriorMeanBias = meanBiasDraws.stream().mapToDouble(Double::doubleValue).average().getAsDouble();
    final double posteriorBiasVariance = biasVarianceDraws.stream().mapToDouble(Double::doubleValue).average().getAsDouble();
    final double posteriorOutlierProbability = outlierProbabilityDraws.stream().mapToDouble(Double::doubleValue).average().getAsDouble();
    final List<Double> posteriorMinorFractions = drawsBySegment.stream()
            .map(segmentDraws -> segmentDraws.stream().mapToDouble(Double::doubleValue).average().getAsDouble())
            .collect(Collectors.toList());

    double absoluteErrorSum = 0.0;
    for (int s = 0; s < segmentCount; s++) {
        final double estimate = posteriorMinorFractions.get(s);
        final double truth = data.getTrueState().segmentMinorFraction(s);
        absoluteErrorSum += Math.abs(estimate - truth);
    }

    Assert.assertEquals(posteriorMeanBias, meanBiasExpected, meanBiasTol);
    Assert.assertEquals(posteriorBiasVariance, biasVarianceExpected, biasVarianceTol);
    Assert.assertEquals(posteriorOutlierProbability, trueOutlierProbability, outlierProbabilityTol);
    Assert.assertEquals(absoluteErrorSum / segmentCount, 0.0, minorFractionTol);

    // posterior summaries produced by the modeller
    final Map<AlleleFractionParameter, PosteriorSummary> globalSummaries = modeller.getGlobalParameterPosteriorSummaries(CREDIBLE_INTERVAL_ALPHA, sparkContext);

    final PosteriorSummary meanBiasSummary = globalSummaries.get(AlleleFractionParameter.MEAN_BIAS);
    final double meanBiasCenter = meanBiasSummary.getCenter();
    Assert.assertEquals(meanBiasCenter, meanBiasExpected, meanBiasTol);

    final PosteriorSummary biasVarianceSummary = globalSummaries.get(AlleleFractionParameter.BIAS_VARIANCE);
    final double biasVarianceCenter = biasVarianceSummary.getCenter();
    Assert.assertEquals(biasVarianceCenter, biasVarianceExpected, biasVarianceTol);

    final PosteriorSummary outlierProbabilitySummary = globalSummaries.get(AlleleFractionParameter.OUTLIER_PROBABILITY);
    final double outlierProbabilityCenter = outlierProbabilitySummary.getCenter();
    Assert.assertEquals(outlierProbabilityCenter, trueOutlierProbability, outlierProbabilityTol);

    final List<PosteriorSummary> minorFractionSummaries = modeller.getMinorAlleleFractionsPosteriorSummaries(CREDIBLE_INTERVAL_ALPHA, sparkContext);
    final List<Double> minorFractionCenters = minorFractionSummaries.stream().map(PosteriorSummary::getCenter).collect(Collectors.toList());

    double centerErrorSum = 0.0;
    for (int s = 0; s < segmentCount; s++) {
        final double center = minorFractionCenters.get(s);
        final double truth = data.getTrueState().segmentMinorFraction(s);
        centerErrorSum += Math.abs(center - truth);
    }
    Assert.assertEquals(centerErrorSum / segmentCount, 0.0, minorFractionTol);
}
Also used : Arrays(java.util.Arrays) DataProvider(org.testng.annotations.DataProvider) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Genome(org.broadinstitute.hellbender.tools.exome.Genome) Test(org.testng.annotations.Test) SegmentUtils(org.broadinstitute.hellbender.tools.exome.SegmentUtils) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) File(java.io.File) List(java.util.List) Log(htsjdk.samtools.util.Log) Assert(org.testng.Assert) PosteriorSummary(org.broadinstitute.hellbender.utils.mcmc.PosteriorSummary) Map(java.util.Map) AllelicPanelOfNormals(org.broadinstitute.hellbender.tools.pon.allelic.AllelicPanelOfNormals) SparkContextFactory(org.broadinstitute.hellbender.engine.spark.SparkContextFactory) SegmentedGenome(org.broadinstitute.hellbender.tools.exome.SegmentedGenome) AllelicCountCollection(org.broadinstitute.hellbender.tools.exome.alleliccount.AllelicCountCollection) LoggingUtils(org.broadinstitute.hellbender.utils.LoggingUtils) PosteriorSummary(org.broadinstitute.hellbender.utils.mcmc.PosteriorSummary) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext)

Example 97 with JavaSparkContext

Use of org.apache.spark.api.java.JavaSparkContext in project gatk-protected by broadinstitute.

The class AllelicSplitCallerUnitTest, method testMakeCalls.

/**
 * Smoke test for {@link CNLOHCaller#makeCalls}.
 *
 * <p>Checks that calling does not crash, that it yields a non-empty list of calls in which every call
 * has its balanced call, CNLOH call and ACNV segment populated, and that the caller and its inputs and
 * outputs survive a Kryo round trip.
 */
@Test
public void testMakeCalls() {
    final CNLOHCaller caller = new CNLOHCaller();
    final JavaSparkContext sparkContext = SparkContextFactory.getTestSparkContext();
    final List<ACNVModeledSegment> segments = SegmentUtils.readACNVModeledSegmentFile(ACNV_SEG_FILE);
    SparkTestUtils.roundTripInKryo(segments.get(0), ACNVModeledSegment.class, sparkContext.getConf());

    // The caller must be Kryo-serializable before it has made any calls...
    SparkTestUtils.roundTripInKryo(caller, CNLOHCaller.class, sparkContext.getConf());

    final List<AllelicCalls> allelicCalls = caller.makeCalls(segments, 2, sparkContext);
    Assert.assertNotNull(allelicCalls);
    Assert.assertTrue(allelicCalls.size() >= 1);
    Assert.assertTrue(allelicCalls.stream().noneMatch(call -> call.getBalancedCall() == null));
    Assert.assertTrue(allelicCalls.stream().noneMatch(call -> call.getCnlohCall() == null));
    Assert.assertTrue(allelicCalls.stream().noneMatch(call -> call.getAcnvSegment() == null));

    // ...and still afterwards; the produced calls must round-trip as well.
    SparkTestUtils.roundTripInKryo(caller, CNLOHCaller.class, sparkContext.getConf());
    SparkTestUtils.roundTripInKryo(allelicCalls.get(0), AllelicCalls.class, sparkContext.getConf());
}
Also used : Arrays(java.util.Arrays) Pair(org.apache.commons.math3.util.Pair) DataProvider(org.testng.annotations.DataProvider) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Test(org.testng.annotations.Test) SegmentUtils(org.broadinstitute.hellbender.tools.exome.SegmentUtils) File(java.io.File) List(java.util.List) ACNVModeledSegment(org.broadinstitute.hellbender.tools.exome.ACNVModeledSegment) Assert(org.testng.Assert) SparkContextFactory(org.broadinstitute.hellbender.engine.spark.SparkContextFactory) HomoSapiensConstants(org.broadinstitute.hellbender.utils.variant.HomoSapiensConstants) SparkTestUtils(org.broadinstitute.hellbender.utils.test.SparkTestUtils) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ACNVModeledSegment(org.broadinstitute.hellbender.tools.exome.ACNVModeledSegment) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) Test(org.testng.annotations.Test)

Example 98 with JavaSparkContext

Use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.

The class CNLOHCaller, method calcNewRhos.

/**
 * Re-optimizes every rho independently and returns the new values in the same order as {@code rhos}.
 *
 * <p>Because the full responsibilities matrix is passed to the objective, each objective has to know
 * which rho index it is optimizing. The objective is univariate, so one objective is built per rho.
 * Each one is searched on [0, 1] starting from the current rho, and the objectives are optimized in
 * parallel on Spark.
 *
 * @param segments              modeled segments passed through to the objective
 * @param responsibilitiesBySeg responsibilities passed through to the objective
 * @param lambda                value passed through to the objective
 * @param rhos                  current rho values; also used as the search starting points
 * @param mVals                 values passed through to the objective
 * @param nVals                 values passed through to the objective
 * @param ctx                   Spark context used to distribute the optimizations
 * @return the optimized rho values, one per entry of {@code rhos}
 */
private double[] calcNewRhos(final List<ACNVModeledSegment> segments, final List<double[][][]> responsibilitiesBySeg, final double lambda, final double[] rhos, final int[] mVals, final int[] nVals, final JavaSparkContext ctx) {
    // NOTE(review): the objective is an anonymous class rather than a lambda. Presumably this is so the
    // RDD elements stay serializable by the configured Spark serializer; confirm before converting it.
    final List<Pair<? extends Function<Double, Double>, SearchInterval>> objectives = IntStream.range(0, rhos.length)
            .mapToObj(rhoIndex -> new Pair<>(new Function<Double, Double>() {

                @Override
                public Double apply(final Double candidateRho) {
                    return calculateESmnObjective(candidateRho, segments, responsibilitiesBySeg, mVals, nVals, lambda, rhoIndex);
                }
            }, new SearchInterval(0.0, 1.0, rhos[rhoIndex])))
            .collect(Collectors.toList());

    final JavaRDD<Pair<? extends Function<Double, Double>, SearchInterval>> objectivesRDD = ctx.parallelize(objectives);
    final List<Double> optimizedRhos = objectivesRDD
            .map(objectiveAndInterval -> optimizeIt(objectiveAndInterval.getFirst(), objectiveAndInterval.getSecond()))
            .collect();

    // unbox the collected results into a primitive array, preserving order
    final double[] newRhos = new double[optimizedRhos.size()];
    for (int k = 0; k < newRhos.length; k++) {
        newRhos[k] = optimizedRhos.get(k);
    }
    return newRhos;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) RealVector(org.apache.commons.math3.linear.RealVector) Function(java.util.function.Function) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) GammaDistribution(org.apache.commons.math3.distribution.GammaDistribution) Gamma(org.apache.commons.math3.special.Gamma) ACNVModeledSegment(org.broadinstitute.hellbender.tools.exome.ACNVModeledSegment) BaseAbstractUnivariateIntegrator(org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) MatrixUtils(org.apache.commons.math3.linear.MatrixUtils) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) SimpsonIntegrator(org.apache.commons.math3.analysis.integration.SimpsonIntegrator) JavaRDD(org.apache.spark.api.java.JavaRDD) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Pair(org.apache.commons.math3.util.Pair) Collectors(java.util.stream.Collectors) BrentOptimizer(org.apache.commons.math3.optim.univariate.BrentOptimizer) Serializable(java.io.Serializable) List(java.util.List) Percentile(org.apache.commons.math3.stat.descriptive.rank.Percentile) Logger(org.apache.logging.log4j.Logger) MathUtils(org.broadinstitute.hellbender.utils.MathUtils) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) Utils(org.broadinstitute.hellbender.utils.Utils) RealMatrix(org.apache.commons.math3.linear.RealMatrix) VisibleForTesting(com.google.common.annotations.VisibleForTesting) 
HomoSapiensConstants(org.broadinstitute.hellbender.utils.variant.HomoSapiensConstants) MaxEval(org.apache.commons.math3.optim.MaxEval) LogManager(org.apache.logging.log4j.LogManager) Function(java.util.function.Function) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) Pair(org.apache.commons.math3.util.Pair)

Example 99 with JavaSparkContext

Use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.

The class CountBasesSpark, method runTool.

/**
 * Counts the total number of bases across all reads.
 *
 * <p>The total is printed to standard output. If {@code out} is set, the total is also written to that
 * location.
 *
 * @param ctx the Spark context (the reads are obtained via {@code getReads()})
 */
@Override
protected void runTool(final JavaSparkContext ctx) {
    final JavaRDD<GATKRead> reads = getReads();
    // fold with the additive identity instead of reduce: Spark's reduce throws on an empty RDD,
    // whereas an input with no reads should simply report 0 bases.
    final long count = reads.map(r -> (long) r.getLength()).fold(0L, Long::sum);
    System.out.println(count);
    if (out != null) {
        try (final PrintStream ps = new PrintStream(BucketUtils.createFile(out))) {
            ps.print(count);
        }
    }
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) DocumentedFeature(org.broadinstitute.barclay.help.DocumentedFeature) PrintStream(java.io.PrintStream) CommandLineProgramProperties(org.broadinstitute.barclay.argparser.CommandLineProgramProperties) BucketUtils(org.broadinstitute.hellbender.utils.gcs.BucketUtils) SparkProgramGroup(org.broadinstitute.hellbender.cmdline.programgroups.SparkProgramGroup) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKSparkTool(org.broadinstitute.hellbender.engine.spark.GATKSparkTool) StandardArgumentDefinitions(org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) JavaRDD(org.apache.spark.api.java.JavaRDD) PrintStream(java.io.PrintStream)

Example 100 with JavaSparkContext

Use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.

The class CollectMultipleMetricsSpark, method runTool.

/**
 * Runs every requested metrics collector over the unfiltered reads.
 *
 * <p>Each collector applies only its own default read filters, bypassing the framework's merging of
 * command-line filters. It then collects its metrics and saves them.
 *
 * @param ctx the Spark context (the reads are obtained via {@code getUnfilteredReads()})
 */
@Override
protected void runTool(final JavaSparkContext ctx) {
    final JavaRDD<GATKRead> unFilteredReads = getUnfilteredReads();
    final List<SparkCollectorProvider> collectorsToRun = getCollectorsToRun();
    // With more than one collector the unfiltered RDD is traversed repeatedly, so cache it
    // rather than recomputing it for every collector.
    final boolean cacheReads = collectorsToRun.size() > 1;
    if (cacheReads) {
        unFilteredReads.cache();
    }
    try {
        for (final SparkCollectorProvider provider : collectorsToRun) {
            final MetricsCollectorSpark<? extends MetricsArgumentCollection> metricsCollector = provider.createCollector(outputBaseName, metricAccumulationLevel.accumulationLevels, getDefaultHeaders(), getHeaderForReads());
            validateCollector(metricsCollector, provider.getClass().getName());
            // Execute the collector's lifecycle.
            // Bypass the framework merging of command line filters and just apply the default
            // ones specified by the collector.
            final ReadFilter readFilter = ReadFilter.fromList(metricsCollector.getDefaultReadFilters(), getHeaderForReads());
            metricsCollector.collectMetrics(unFilteredReads.filter(r -> readFilter.test(r)), getHeaderForReads());
            metricsCollector.saveMetrics(getReadSourceName(), getAuthHolder());
        }
    } finally {
        if (cacheReads) {
            // release the cached partitions once all collectors are done (or one has failed)
            unFilteredReads.unpersist();
        }
    }
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) DocumentedFeature(org.broadinstitute.barclay.help.DocumentedFeature) CommandLineProgramProperties(org.broadinstitute.barclay.argparser.CommandLineProgramProperties) java.util(java.util) SparkProgramGroup(org.broadinstitute.hellbender.cmdline.programgroups.SparkProgramGroup) Header(htsjdk.samtools.metrics.Header) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKSparkTool(org.broadinstitute.hellbender.engine.spark.GATKSparkTool) StandardArgumentDefinitions(org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions) ArgumentCollection(org.broadinstitute.barclay.argparser.ArgumentCollection) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SAMFileHeader(htsjdk.samtools.SAMFileHeader) ReadUtils(org.broadinstitute.hellbender.utils.read.ReadUtils) org.broadinstitute.hellbender.metrics(org.broadinstitute.hellbender.metrics) MetricAccumulationLevelArgumentCollection(org.broadinstitute.hellbender.cmdline.argumentcollections.MetricAccumulationLevelArgumentCollection) JavaRDD(org.apache.spark.api.java.JavaRDD) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter)

Aggregations

JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)251 Test (org.testng.annotations.Test)65 BaseTest (org.broadinstitute.hellbender.utils.test.BaseTest)64 Tuple2 (scala.Tuple2)48 SparkConf (org.apache.spark.SparkConf)46 Test (org.junit.Test)43 ArrayList (java.util.ArrayList)41 GATKRead (org.broadinstitute.hellbender.utils.read.GATKRead)32 List (java.util.List)26 Configuration (org.apache.hadoop.conf.Configuration)23 JavaRDD (org.apache.spark.api.java.JavaRDD)23 File (java.io.File)22 SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval)20 Collectors (java.util.stream.Collectors)16 TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline)15 DataSet (org.nd4j.linalg.dataset.DataSet)15 IOException (java.io.IOException)13 SAMFileHeader (htsjdk.samtools.SAMFileHeader)12 RealMatrix (org.apache.commons.math3.linear.RealMatrix)12 SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary)11