use of org.apache.spark.api.java.JavaSparkContext in project gatk-protected by broadinstitute.
the class AlleleFractionModellerUnitTest method testMCMC.
private void testMCMC(final double meanBiasSimulated, final double biasVarianceSimulated, final double meanBiasExpected, final double biasVarianceExpected, final AllelicPanelOfNormals allelicPoN) {
    LoggingUtils.setLoggingLevel(Log.LogLevel.INFO);
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final int numSamples = 150;
    final int numBurnIn = 50;
    final double averageHetsPerSegment = 50;
    final int numSegments = 100;
    final int averageDepth = 50;
    final double outlierProbability = 0.02;
    // note: the following tolerances could actually be made much smaller if we used more segments and/or
    // more hets -- most of the error is the sampling error of a finite simulated data set, not numerical error of MCMC
    final double minorFractionTolerance = 0.02;
    final double meanBiasTolerance = 0.02;
    final double biasVarianceTolerance = 0.01;
    final double outlierProbabilityTolerance = 0.02;
    final AlleleFractionSimulatedData simulatedData = new AlleleFractionSimulatedData(averageHetsPerSegment, numSegments, averageDepth, meanBiasSimulated, biasVarianceSimulated, outlierProbability);
    final AlleleFractionModeller modeller = new AlleleFractionModeller(simulatedData.getSegmentedGenome(), allelicPoN);
    modeller.fitMCMC(numSamples, numBurnIn);
    final List<Double> meanBiasSamples = modeller.getMeanBiasSamples();
    Assert.assertEquals(meanBiasSamples.size(), numSamples - numBurnIn);
    final List<Double> biasVarianceSamples = modeller.getBiasVarianceSamples();
    Assert.assertEquals(biasVarianceSamples.size(), numSamples - numBurnIn);
    final List<Double> outlierProbabilitySamples = modeller.getOutlierProbabilitySamples();
    Assert.assertEquals(outlierProbabilitySamples.size(), numSamples - numBurnIn);
    final List<AlleleFractionState.MinorFractions> minorFractionsSamples = modeller.getMinorFractionsSamples();
    Assert.assertEquals(minorFractionsSamples.size(), numSamples - numBurnIn);
    for (final AlleleFractionState.MinorFractions sample : minorFractionsSamples) {
        Assert.assertEquals(sample.size(), numSegments);
    }
    final List<List<Double>> minorFractionsSamplesBySegment = modeller.getMinorFractionSamplesBySegment();
    final double mcmcMeanBias = meanBiasSamples.stream().mapToDouble(x -> x).average().getAsDouble();
    final double mcmcBiasVariance = biasVarianceSamples.stream().mapToDouble(x -> x).average().getAsDouble();
    final double mcmcOutlierProbability = outlierProbabilitySamples.stream().mapToDouble(x -> x).average().getAsDouble();
    final List<Double> mcmcMinorFractions = minorFractionsSamplesBySegment.stream().map(list -> list.stream().mapToDouble(x -> x).average().getAsDouble()).collect(Collectors.toList());
    double totalSegmentError = 0.0;
    for (int segment = 0; segment < numSegments; segment++) {
        totalSegmentError += Math.abs(mcmcMinorFractions.get(segment) - simulatedData.getTrueState().segmentMinorFraction(segment));
    }
    Assert.assertEquals(mcmcMeanBias, meanBiasExpected, meanBiasTolerance);
    Assert.assertEquals(mcmcBiasVariance, biasVarianceExpected, biasVarianceTolerance);
    Assert.assertEquals(mcmcOutlierProbability, outlierProbability, outlierProbabilityTolerance);
    Assert.assertEquals(totalSegmentError / numSegments, 0.0, minorFractionTolerance);
    // test posterior summaries
    final Map<AlleleFractionParameter, PosteriorSummary> globalParameterPosteriorSummaries = modeller.getGlobalParameterPosteriorSummaries(CREDIBLE_INTERVAL_ALPHA, ctx);
    final PosteriorSummary meanBiasPosteriorSummary = globalParameterPosteriorSummaries.get(AlleleFractionParameter.MEAN_BIAS);
    final double meanBiasPosteriorCenter = meanBiasPosteriorSummary.getCenter();
    Assert.assertEquals(meanBiasPosteriorCenter, meanBiasExpected, meanBiasTolerance);
    final PosteriorSummary biasVariancePosteriorSummary = globalParameterPosteriorSummaries.get(AlleleFractionParameter.BIAS_VARIANCE);
    final double biasVariancePosteriorCenter = biasVariancePosteriorSummary.getCenter();
    Assert.assertEquals(biasVariancePosteriorCenter, biasVarianceExpected, biasVarianceTolerance);
    final PosteriorSummary outlierProbabilityPosteriorSummary = globalParameterPosteriorSummaries.get(AlleleFractionParameter.OUTLIER_PROBABILITY);
    final double outlierProbabilityPosteriorCenter = outlierProbabilityPosteriorSummary.getCenter();
    Assert.assertEquals(outlierProbabilityPosteriorCenter, outlierProbability, outlierProbabilityTolerance);
    final List<PosteriorSummary> minorAlleleFractionPosteriorSummaries = modeller.getMinorAlleleFractionsPosteriorSummaries(CREDIBLE_INTERVAL_ALPHA, ctx);
    final List<Double> minorFractionsPosteriorCenters = minorAlleleFractionPosteriorSummaries.stream().map(PosteriorSummary::getCenter).collect(Collectors.toList());
    double totalPosteriorCentersSegmentError = 0.0;
    for (int segment = 0; segment < numSegments; segment++) {
        totalPosteriorCentersSegmentError += Math.abs(minorFractionsPosteriorCenters.get(segment) - simulatedData.getTrueState().segmentMinorFraction(segment));
    }
    Assert.assertEquals(totalPosteriorCentersSegmentError / numSegments, 0.0, minorFractionTolerance);
}
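For context, a test entry point would call this helper with matching simulated and expected hyperparameters when no panel of normals is applied. The sketch below is illustrative only: the specific values and the AllelicPanelOfNormals.EMPTY_PON constant are assumptions about the gatk-protected test class, not taken from this page.

@Test
public void testMCMCWithoutAllelicPoN() {
    // Hypothetical values; with no panel of normals the sampler should recover the simulated hyperparameters.
    final double meanBias = 1.2;
    final double biasVariance = 0.04;
    testMCMC(meanBias, biasVariance, meanBias, biasVariance, AllelicPanelOfNormals.EMPTY_PON);
}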
use of org.apache.spark.api.java.JavaSparkContext in project gatk-protected by broadinstitute.
the class AllelicSplitCallerUnitTest method testMakeCalls.
@Test
public void testMakeCalls() {
    // This mostly just checks that the calling does not crash and does produce results.
    final CNLOHCaller cnlohCaller = new CNLOHCaller();
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final List<ACNVModeledSegment> segs = SegmentUtils.readACNVModeledSegmentFile(ACNV_SEG_FILE);
    SparkTestUtils.roundTripInKryo(segs.get(0), ACNVModeledSegment.class, ctx.getConf());
    // Make sure the CNLOH Caller is serializable before making calls.
    SparkTestUtils.roundTripInKryo(cnlohCaller, CNLOHCaller.class, ctx.getConf());
    final List<AllelicCalls> calls = cnlohCaller.makeCalls(segs, 2, ctx);
    Assert.assertNotNull(calls);
    Assert.assertTrue(calls.size() > 0);
    Assert.assertTrue(calls.stream().allMatch(c -> c.getBalancedCall() != null));
    Assert.assertTrue(calls.stream().allMatch(c -> c.getCnlohCall() != null));
    Assert.assertTrue(calls.stream().allMatch(c -> c.getAcnvSegment() != null));
    // Make sure the CNLOH Caller is serializable after making calls.
    SparkTestUtils.roundTripInKryo(cnlohCaller, CNLOHCaller.class, ctx.getConf());
    SparkTestUtils.roundTripInKryo(calls.get(0), AllelicCalls.class, ctx.getConf());
}
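SparkTestUtils.roundTripInKryo is not shown on this page; conceptually it serializes the object with Kryo and deserializes it again, so the test fails fast if Spark could not ship the object to executors. Below is a minimal stand-alone sketch of that idea using the Kryo library directly; this is an assumption about the helper's behavior and ignores any Kryo registrator configured in the SparkConf.

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

public static <T> T kryoRoundTripSketch(final T input) {
    // Conceptual sketch only, not the actual SparkTestUtils implementation.
    final Kryo kryo = new Kryo();
    kryo.setRegistrationRequired(false);  // allow unregistered classes for this illustration
    final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (final Output output = new Output(bytes)) {
        // write the concrete class along with the object so it can be read back without a type hint
        kryo.writeClassAndObject(output, input);
    }
    try (final Input in = new Input(new ByteArrayInputStream(bytes.toByteArray()))) {
        @SuppressWarnings("unchecked")
        final T copy = (T) kryo.readClassAndObject(in);
        return copy;
    }
}

Round-tripping the caller both before and after makeCalls checks that the call itself did not leave the object holding non-serializable state.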
use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.
the class CNLOHCaller method calcNewRhos.
private double[] calcNewRhos(final List<ACNVModeledSegment> segments, final List<double[][][]> responsibilitiesBySeg, final double lambda, final double[] rhos, final int[] mVals, final int[] nVals, final JavaSparkContext ctx) {
    // Since we pass in the entire responsibilities matrix, we need the correct index for each rho. That, and the
    // fact that this is a univariate objective function, means we need to create an instance for each rho. And
    // then we blast across Spark.
    final List<Pair<? extends Function<Double, Double>, SearchInterval>> objectives = IntStream.range(0, rhos.length).mapToObj(i -> new Pair<>(new Function<Double, Double>() {
        @Override
        public Double apply(Double rho) {
            return calculateESmnObjective(rho, segments, responsibilitiesBySeg, mVals, nVals, lambda, i);
        }
    }, new SearchInterval(0.0, 1.0, rhos[i]))).collect(Collectors.toList());
    final JavaRDD<Pair<? extends Function<Double, Double>, SearchInterval>> objectivesRDD = ctx.parallelize(objectives);
    final List<Double> resultsAsDouble = objectivesRDD.map(objective -> optimizeIt(objective.getFirst(), objective.getSecond())).collect();
    return resultsAsDouble.stream().mapToDouble(Double::doubleValue).toArray();
}
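The optimizeIt helper is not included on this page. Given the SearchInterval arguments, it plausibly runs a bounded univariate search over each rho in [0, 1]; the sketch below is a hypothetical stand-in using Apache Commons Math 3, with illustrative tolerances and an assumed maximization goal.

import java.util.function.Function;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.univariate.BrentOptimizer;
import org.apache.commons.math3.optim.univariate.SearchInterval;
import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;

private static double optimizeIt(final Function<Double, Double> objective, final SearchInterval interval) {
    // Brent's method on the given interval; tolerances and evaluation budget are illustrative, not GATK's settings.
    final BrentOptimizer optimizer = new BrentOptimizer(1e-6, 1e-8);
    final UnivariatePointValuePair best = optimizer.optimize(
            new MaxEval(200),
            new UnivariateObjectiveFunction(objective::apply),
            GoalType.MAXIMIZE,  // assumption: the objective is maximized
            interval);
    return best.getPoint();
}

Whatever the real implementation, the point of the surrounding method is that each (objective, interval) pair is independent, so the per-rho searches parallelize cleanly across the cluster.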
use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.
the class CountBasesSpark method runTool.
@Override
protected void runTool(final JavaSparkContext ctx) {
    final JavaRDD<GATKRead> reads = getReads();
    final long count = reads.map(r -> (long) r.getLength()).reduce(Long::sum);
    System.out.println(count);
    if (out != null) {
        try (final PrintStream ps = new PrintStream(BucketUtils.createFile(out))) {
            ps.print(count);
        }
    }
}
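One subtlety in the aggregation above: JavaRDD.reduce throws an UnsupportedOperationException on an empty RDD, so if the read source could be empty after filtering, folding from a zero value is the safer equivalent. A minimal alternative sketch (not necessarily how the tool handles this case):

// Same total, but an empty RDD yields 0 instead of throwing.
final long count = reads.map(r -> (long) r.getLength()).fold(0L, Long::sum);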
use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.
the class CollectMultipleMetricsSpark method runTool.
@Override
protected void runTool(final JavaSparkContext ctx) {
    final JavaRDD<GATKRead> unFilteredReads = getUnfilteredReads();
    List<SparkCollectorProvider> collectorsToRun = getCollectorsToRun();
    if (collectorsToRun.size() > 1) {
        // if there is more than one collector to run, cache the
        // unfiltered RDD so we don't recompute it
        unFilteredReads.cache();
    }
    for (final SparkCollectorProvider provider : collectorsToRun) {
        MetricsCollectorSpark<? extends MetricsArgumentCollection> metricsCollector = provider.createCollector(outputBaseName, metricAccumulationLevel.accumulationLevels, getDefaultHeaders(), getHeaderForReads());
        validateCollector(metricsCollector, collectorsToRun.get(collectorsToRun.indexOf(provider)).getClass().getName());
        // Execute the collector's lifecycle
        // Bypass the framework merging of command line filters and just apply the default
        // ones specified by the collector
        ReadFilter readFilter = ReadFilter.fromList(metricsCollector.getDefaultReadFilters(), getHeaderForReads());
        metricsCollector.collectMetrics(unFilteredReads.filter(r -> readFilter.test(r)), getHeaderForReads());
        metricsCollector.saveMetrics(getReadSourceName(), getAuthHolder());
    }
}
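A small caching note: the unfiltered RDD stays cached after the loop finishes. If further Spark work followed, the cached blocks could be released explicitly; the snippet below is an optional addition, not part of the tool as shown on this page.

// Optional cleanup after the collector loop: free the cached blocks once all collectors have consumed the reads.
if (collectorsToRun.size() > 1) {
    unFilteredReads.unpersist();
}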