Use of org.apache.spark.api.java.JavaSparkContext in project gatk-protected by broadinstitute.
From class PCATangentNormalizationUtilsUnitTest, method testSparkTangentNormalizeSparkVsNoSpark:
@Test
public void testSparkTangentNormalizeSparkVsNoSpark() {
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final File ponFile = PoNTestUtils.createDummyHDF5FilePoN(TEST_PCOV_FILE, 20);
    try (final HDF5File ponHDF5File = new HDF5File(ponFile)) {
        final PCACoveragePoN pon = new HDF5PCACoveragePoN(ponHDF5File);
        final PCATangentNormalizationResult tnWithSpark = pon.normalizeNormalsInPoN(ctx);
        final PCATangentNormalizationResult tnWithoutSpark = pon.normalizeNormalsInPoN();
        PoNTestUtils.assertEqualsMatrix(tnWithSpark.getTangentNormalized().counts(), tnWithoutSpark.getTangentNormalized().counts(), false);
        PoNTestUtils.assertEqualsMatrix(tnWithSpark.getPreTangentNormalized().counts(), tnWithoutSpark.getPreTangentNormalized().counts(), false);
        PoNTestUtils.assertEqualsMatrix(tnWithSpark.getTangentBetaHats(), tnWithoutSpark.getTangentBetaHats(), false);
    }
}
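This test compares the Spark-backed and single-threaded tangent-normalization code paths on the same PoN, so the only Spark-specific setup it needs is a local test context from SparkContextFactory.getTestSparkContext(). A minimal sketch of creating a comparable local context with plain Spark, assuming the GATK test factory is not available (the class name, app name, and local[2] master below are illustrative):

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

public final class LocalContextSketch {
    public static void main(String[] args) {
        // local[2] runs Spark in-process with two worker threads, enough for small comparisons like the test above.
        final SparkConf conf = new SparkConf().setAppName("tangent-normalization-comparison").setMaster("local[2]");
        // JavaSparkContext is Closeable, so try-with-resources shuts the context down cleanly.
        try (final JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // Run the Spark-backed and non-Spark code paths here and compare their outputs.
        }
    }
}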
Use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.
From class BroadcastJoinReadsWithVariants, method join:
public static JavaPairRDD<GATKRead, Iterable<GATKVariant>> join(final JavaRDD<GATKRead> reads, final JavaRDD<GATKVariant> variants) {
    final JavaSparkContext ctx = new JavaSparkContext(reads.context());
    final IntervalsSkipList<GATKVariant> variantSkipList = new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast = ctx.broadcast(variantSkipList);
    return reads.mapToPair(r -> {
        final IntervalsSkipList<GATKVariant> intervalsSkipList = variantsBroadcast.getValue();
        if (SimpleInterval.isValid(r.getContig(), r.getStart(), r.getEnd())) {
            return new Tuple2<>(r, intervalsSkipList.getOverlapping(new SimpleInterval(r)));
        } else {
            return new Tuple2<>(r, Collections.emptyList());
        }
    });
}
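The join above is a map-side (broadcast) join: the smaller variant collection is collected to the driver, wrapped in an IntervalsSkipList for fast interval lookups, broadcast once to every executor, and each read then finds its overlapping variants locally without a shuffle. A minimal sketch of the same pattern with plain types, assuming nothing beyond core Spark (the class name and sample data below are hypothetical):

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

import scala.Tuple2;

public final class BroadcastJoinSketch {
    public static void main(String[] args) {
        try (final JavaSparkContext ctx = new JavaSparkContext(new SparkConf().setAppName("broadcast-join-sketch").setMaster("local[2]"))) {
            // Small side: collected into an ordinary map and shipped once to every executor.
            final Map<String, String> annotationsByContig = new HashMap<>();
            annotationsByContig.put("chr1", "variantA");
            final Broadcast<Map<String, String>> broadcastAnnotations = ctx.broadcast(annotationsByContig);

            // Large side: stays distributed; each element does a local lookup against the broadcast value.
            final JavaRDD<String> readContigs = ctx.parallelize(Arrays.asList("chr1", "chr2"));
            final JavaPairRDD<String, String> joined = readContigs.mapToPair(contig ->
                    new Tuple2<>(contig, broadcastAnnotations.value().getOrDefault(contig, "none")));
            joined.collect().forEach(pair -> System.out.println(pair._1() + " -> " + pair._2()));
        }
    }
}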
Use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.
From class LocusWalkerSpark, method getAlignments:
/**
 * Loads alignments and the corresponding reference and features into a {@link JavaRDD} for the intervals specified.
 *
 * If no intervals were specified, returns all the alignments.
 *
 * @return all alignments as a {@link JavaRDD}, bounded by intervals if specified.
 */
public JavaRDD<LocusWalkerContext> getAlignments(JavaSparkContext ctx) {
    SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
    List<SimpleInterval> intervals = hasIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
    final List<ShardBoundary> intervalShards = intervals.stream()
            .flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream())
            .collect(Collectors.toList());
    int maxLocatableSize = Math.min(readShardSize, readShardPadding);
    JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, maxLocatableSize, shuffle);
    Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
    Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
    return shardedReads.flatMap(getAlignmentsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, getHeaderForReads(), getDownsamplingInfo()));
}
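getAlignments shards the requested intervals, distributes the reads into those shards with SparkSharder, and expands each shard into per-locus LocusWalkerContext objects via getAlignmentsFunction. A hedged fragment (not a complete tool; the LocusWalkerSpark subclass plumbing GATK requires is omitted) showing how the resulting RDD might be consumed with ordinary Spark actions:

// Assumed to live inside a LocusWalkerSpark subclass where ctx is in scope.
final JavaRDD<LocusWalkerContext> perLocusContexts = getAlignments(ctx);
// Each element corresponds to one locus within the requested intervals; count() is a plain Spark action.
final long lociProcessed = perLocusContexts.count();
System.out.println("Processed " + lociProcessed + " loci");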
Use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.
From class VariantWalkerSpark, method getVariants:
/**
 * Loads variants and the corresponding reads, reference and features into a {@link JavaRDD} for the intervals specified.
 * For the current implementation, the reads context will always be empty.
 *
 * If no intervals were specified, returns all the variants.
 *
 * @return all variants as a {@link JavaRDD}, bounded by intervals if specified.
 */
public JavaRDD<VariantWalkerContext> getVariants(JavaSparkContext ctx) {
    SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
    List<SimpleInterval> intervals = hasIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
    // use unpadded shards (padding is only needed for reference bases)
    final List<ShardBoundary> intervalShards = intervals.stream()
            .flatMap(interval -> Shard.divideIntervalIntoShards(interval, variantShardSize, 0, sequenceDictionary).stream())
            .collect(Collectors.toList());
    JavaRDD<VariantContext> variants = variantsSource.getParallelVariantContexts(drivingVariantFile, getIntervals());
    VariantFilter variantFilter = makeVariantFilter();
    variants = variants.filter(variantFilter::test);
    JavaRDD<Shard<VariantContext>> shardedVariants = SparkSharder.shard(ctx, variants, VariantContext.class, sequenceDictionary, intervalShards, variantShardSize, shuffle);
    Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
    Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
    return shardedVariants.flatMap(getVariantsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, variantShardPadding));
}
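As with getAlignments, downstream code receives an RDD it can transform with ordinary Spark operations. A hedged fragment (assuming VariantWalkerContext#getVariant() exposes the htsjdk VariantContext for each record, and omitting the VariantWalkerSpark subclass plumbing) showing one possible use of the result:

// Assumed to live inside a VariantWalkerSpark subclass where ctx is in scope.
final JavaRDD<VariantWalkerContext> perVariantContexts = getVariants(ctx);
// Illustrative aggregation only: count the driving variants that are SNPs.
final long snpCount = perVariantContexts.filter(vwc -> vwc.getVariant().isSNP()).count();
System.out.println("SNPs seen: " + snpCount);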
Use of org.apache.spark.api.java.JavaSparkContext in project gatk by broadinstitute.
From class HDF5PCACoveragePoNCreationUtilsUnitTest, method testCalculateReducedPanelAndPInversesUsingJollifesRule:
@Test(dataProvider = "readCountOnlyWithDiverseShapeData")
public void testCalculateReducedPanelAndPInversesUsingJollifesRule(final ReadCountCollection readCounts) {
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final ReductionResult result = HDF5PCACoveragePoNCreationUtils.calculateReducedPanelAndPInverses(readCounts, OptionalInt.empty(), NULL_LOGGER, ctx);
    final RealMatrix counts = readCounts.counts();
    Assert.assertNotNull(result);
    Assert.assertNotNull(result.getPseudoInverse());
    Assert.assertNotNull(result.getReducedCounts());
    Assert.assertNotNull(result.getReducedPseudoInverse());
    Assert.assertNotNull(result.getAllSingularValues());
    Assert.assertEquals(counts.getColumnDimension(), result.getAllSingularValues().length);
    Assert.assertEquals(result.getReducedCounts().getRowDimension(), counts.getRowDimension());
    final int eigensamples = result.getReducedCounts().getColumnDimension();
    final Mean mean = new Mean();
    final double meanSingularValue = mean.evaluate(result.getAllSingularValues());
    final double threshold = HDF5PCACoveragePoNCreationUtils.JOLLIFES_RULE_MEAN_FACTOR * meanSingularValue;
    final int expectedEigensamples = (int) DoubleStream.of(result.getAllSingularValues()).filter(d -> d >= threshold).count();
    Assert.assertTrue(eigensamples <= counts.getColumnDimension());
    Assert.assertEquals(eigensamples, expectedEigensamples);
    assertPseudoInverse(counts, result.getPseudoInverse());
    assertPseudoInverse(result.getReducedCounts(), result.getReducedPseudoInverse());
}
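The expected number of eigensamples is recomputed independently in the test using Jolliffe's rule: retain only the components whose singular values are at least a fixed fraction of the mean singular value. A self-contained sketch of that criterion, using 0.7 as an assumed stand-in for JOLLIFES_RULE_MEAN_FACTOR (the constant's actual value lives in HDF5PCACoveragePoNCreationUtils) and hypothetical singular values:

import java.util.stream.DoubleStream;

public final class JolliffeRuleSketch {
    public static void main(String[] args) {
        // Hypothetical singular values from an SVD of a read-count matrix.
        final double[] singularValues = {12.0, 5.0, 1.5, 0.4, 0.1};
        final double meanSingularValue = DoubleStream.of(singularValues).average().orElse(0.0);
        // 0.7 is an assumed stand-in for JOLLIFES_RULE_MEAN_FACTOR.
        final double threshold = 0.7 * meanSingularValue;
        final long retainedComponents = DoubleStream.of(singularValues).filter(d -> d >= threshold).count();
        System.out.println("Retained " + retainedComponents + " of " + singularValues.length + " components");
    }
}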