Use of org.apache.spark.api.java.function.FlatMapFunction in project beam by apache.
The class SparkCompat, method extractOutput.
/**
* Extracts the output for a given collection of WindowedAccumulators.
*
* <p>This is required because the signature of JavaPairRDD.flatMapValues differs between Spark
* versions. See https://issues.apache.org/jira/browse/SPARK-19287
*/
public static <K, InputT, AccumT, OutputT> JavaPairRDD<K, WindowedValue<OutputT>> extractOutput(
    JavaPairRDD<K, SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>> accumulatePerKey,
    SparkCombineFn<KV<K, InputT>, InputT, AccumT, OutputT> sparkCombineFn) {
  try {
    if (accumulatePerKey.context().version().startsWith("3")) {
      // Spark 3.x overload: flatMapValues takes a FlatMapFunction returning an Iterator.
      FlatMapFunction<SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>, WindowedValue<OutputT>>
          flatMapFunction =
              windowedAccumulator -> sparkCombineFn.extractOutputStream(windowedAccumulator).iterator();
      // This invokes by reflection the equivalent of:
      // return accumulatePerKey.flatMapValues(flatMapFunction);
      Method method = accumulatePerKey.getClass().getDeclaredMethod("flatMapValues", FlatMapFunction.class);
      Object result = method.invoke(accumulatePerKey, flatMapFunction);
      return (JavaPairRDD<K, WindowedValue<OutputT>>) result;
    }
    // Spark 2.x overload: flatMapValues takes a Function returning an Iterable.
    Function<SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>, Iterable<WindowedValue<OutputT>>>
        flatMapFunction =
            windowedAccumulator ->
                sparkCombineFn.extractOutputStream(windowedAccumulator).collect(Collectors.toList());
    // This invokes by reflection the equivalent of:
    // return accumulatePerKey.flatMapValues(flatMapFunction);
    Method method = accumulatePerKey.getClass().getDeclaredMethod("flatMapValues", Function.class);
    Object result = method.invoke(accumulatePerKey, flatMapFunction);
    return (JavaPairRDD<K, WindowedValue<OutputT>>) result;
  } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
    throw new RuntimeException("Error invoking Spark flatMapValues", e);
  }
}
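The runtime version check plus getDeclaredMethod lookup is what lets a single compiled artifact run against either Spark line: the incompatible overload is never referenced in bytecode, only resolved reflectively. Below is a minimal, self-contained sketch of the same technique. V2Fn, V3Fn, Target, and ReflectiveDispatchSketch are hypothetical stand-ins, not Beam or Spark classes.
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

public class ReflectiveDispatchSketch {
  // Stand-ins for the incompatible Spark 2.x / 3.x functional interfaces.
  interface V2Fn { Iterable<String> call(String in) throws Exception; }
  interface V3Fn { Iterator<String> call(String in) throws Exception; }

  // A stand-in for JavaPairRDD: only one of the two overloads exists on a given classpath.
  static class Target {
    List<String> flatMapValues(V3Fn f) throws Exception {
      List<String> out = new ArrayList<>();
      f.call("a,b").forEachRemaining(out::add);
      return out;
    }
  }

  public static void main(String[] args) throws Exception {
    V3Fn fn = in -> Arrays.asList(in.split(",")).iterator();
    // Resolve the overload by parameter type at runtime, as SparkCompat does,
    // so the caller never references a signature the classpath might lack.
    Method m = Target.class.getDeclaredMethod("flatMapValues", V3Fn.class);
    @SuppressWarnings("unchecked")
    List<String> result = (List<String>) m.invoke(new Target(), fn);
    System.out.println(result); // prints [a, b]
  }
}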
Use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.
The class SparkSharder, method computePartitionReadExtents.
/**
* For each partition, find the interval that spans it.
*/
static <L extends Locatable> List<PartitionLocatable<SimpleInterval>> computePartitionReadExtents(
        JavaRDD<L> locatables, SAMSequenceDictionary sequenceDictionary, int maxLocatableLength) {
    // Find the first locatable in each partition. This is very efficient since only the first
    // record in each partition is read. If a partition is empty then set the locatable to null.
    List<PartitionLocatable<L>> allSplitPoints = locatables
            .mapPartitions((FlatMapFunction<Iterator<L>, PartitionLocatable<L>>) it ->
                    ImmutableList.of(new PartitionLocatable<>(-1, it.hasNext() ? it.next() : null)).iterator())
            .collect();
    // Fill in the partition index and remove nulls (empty partitions).
    List<PartitionLocatable<L>> splitPoints = new ArrayList<>();
    for (int i = 0; i < allSplitPoints.size(); i++) {
        L locatable = allSplitPoints.get(i).getLocatable();
        if (locatable != null) {
            splitPoints.add(new PartitionLocatable<>(i, locatable));
        }
    }
    List<PartitionLocatable<SimpleInterval>> extents = new ArrayList<>();
    for (int i = 0; i < splitPoints.size(); i++) {
        PartitionLocatable<L> splitPoint = splitPoints.get(i);
        int partitionIndex = splitPoint.getPartitionIndex();
        Locatable current = splitPoint.getLocatable();
        int intervalContigIndex = sequenceDictionary.getSequenceIndex(current.getContig());
        final Locatable next;
        final int nextContigIndex;
        if (i < splitPoints.size() - 1) {
            next = splitPoints.get(i + 1);
            nextContigIndex = sequenceDictionary.getSequenceIndex(next.getContig());
        } else {
            next = null;
            nextContigIndex = sequenceDictionary.getSequences().size();
        }
        if (intervalContigIndex == nextContigIndex) {
            // Same contig: the extent runs from this partition's first read to just past
            // the start of the next partition's first read.
            addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(),
                    next.getStart() + maxLocatableLength);
        } else {
            // Complete the current contig.
            int contigEnd = sequenceDictionary.getSequence(current.getContig()).getSequenceLength();
            addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), contigEnd);
            // Add any whole contigs up to next (exclusive).
            for (int contigIndex = intervalContigIndex + 1; contigIndex < nextContigIndex; contigIndex++) {
                SAMSequenceRecord sequence = sequenceDictionary.getSequence(contigIndex);
                addPartitionReadExtent(extents, partitionIndex, sequence.getSequenceName(), 1,
                        sequence.getSequenceLength());
            }
            // Add the start of the next contig.
            if (next != null) {
                addPartitionReadExtent(extents, partitionIndex, next.getContig(), 1,
                        next.getStart() + maxLocatableLength);
            }
        }
    }
    return extents;
}
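addPartitionReadExtent is not shown in this excerpt. Judging from its call sites, a plausible sketch (hypothetical; GATK's actual helper may differ) simply records the span as a partition-tagged interval:
// Hypothetical sketch of the helper used above: record [start, end] on the given
// contig as a read extent for the partition.
private static void addPartitionReadExtent(List<PartitionLocatable<SimpleInterval>> extents,
        int partitionIndex, String contig, int start, int end) {
    extents.add(new PartitionLocatable<>(partitionIndex, new SimpleInterval(contig, start, end)));
}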
Use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.
The class AssemblyRegionWalkerSpark, method getAssemblyRegionsFunction.
private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
        final Broadcast<ReferenceMultiSource> bReferenceSource,
        final Broadcast<FeatureManager> bFeatureManager,
        final SAMSequenceDictionary sequenceDictionary,
        final SAMFileHeader header,
        final AssemblyRegionEvaluator evaluator,
        final int minAssemblyRegionSize,
        final int maxAssemblyRegionSize,
        final int assemblyRegionPadding,
        final double activeProbThreshold,
        final int maxProbPropagationDistance) {
    return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        SimpleInterval assemblyRegionPaddedInterval =
                paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary);
        // Materialize the reference and features for this shard from the broadcasts (if present).
        ReferenceDataSource reference = bReferenceSource == null ? null
                : new ReferenceMemorySource(
                        bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval),
                        sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval);
        FeatureContext featureContext = new FeatureContext(features, paddedInterval);
        final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(
                shardedRead, header, referenceContext, featureContext, evaluator,
                minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding,
                activeProbThreshold, maxProbPropagationDistance);
        // Wrap each assembly region with reference and feature contexts over its extended span.
        return StreamSupport.stream(assemblyRegions.spliterator(), false)
                .map(assemblyRegion -> new AssemblyRegionWalkerContext(
                        assemblyRegion,
                        new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
                        new FeatureContext(features, assemblyRegion.getExtendedSpan())))
                .iterator();
    };
}
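The returned function is intended to be applied per shard; a hedged call-site sketch (the shardedReads variable is illustrative, assuming a JavaRDD<Shard<GATKRead>> is in scope):
// Illustrative only: expand each shard into its assembly-region contexts.
JavaRDD<AssemblyRegionWalkerContext> contexts = shardedReads.flatMap(
        getAssemblyRegionsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, header,
                evaluator, minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding,
                activeProbThreshold, maxProbPropagationDistance));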
Use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.
The class LocusWalkerSpark, method getAlignmentsFunction.
/**
* Return a function that maps a {@link Shard} of reads into a tuple of alignments and their corresponding reference and features.
* @param bReferenceSource the reference source broadcast
* @param bFeatureManager the feature manager broadcast
* @param sequenceDictionary the sequence dictionary for the reads
* @param header the reads header
* @param downsamplingInfo the downsampling method for the reads
* @return a function that maps a {@link Shard} of reads into a tuple of alignments and their corresponding reference and features.
*/
private static FlatMapFunction<Shard<GATKRead>, LocusWalkerContext> getAlignmentsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource, Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary, SAMFileHeader header, LIBSDownsamplingInfo downsamplingInfo) {
    return (FlatMapFunction<Shard<GATKRead>, LocusWalkerContext>) shardedRead -> {
        SimpleInterval interval = shardedRead.getInterval();
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        Iterator<GATKRead> readIterator = shardedRead.iterator();
        ReferenceDataSource reference = bReferenceSource == null ? null
                : new ReferenceMemorySource(
                        bReferenceSource.getValue().getReferenceBases(null, paddedInterval), sequenceDictionary);
        FeatureManager fm = bFeatureManager == null ? null : bFeatureManager.getValue();
        final Set<String> samples = header.getReadGroups().stream()
                .map(SAMReadGroupRecord::getSample)
                .collect(Collectors.toSet());
        LocusIteratorByState libs =
                new LocusIteratorByState(readIterator, downsamplingInfo, false, samples, header, true, false);
        // Limit the returned alignment contexts to the shard's unpadded interval.
        IntervalOverlappingIterator<AlignmentContext> alignmentContexts =
                new IntervalOverlappingIterator<>(libs, ImmutableList.of(interval), sequenceDictionary);
        final Spliterator<AlignmentContext> alignmentContextSpliterator =
                Spliterators.spliteratorUnknownSize(alignmentContexts, 0);
        return StreamSupport.stream(alignmentContextSpliterator, false)
                .map(alignmentContext -> {
                    final SimpleInterval alignmentInterval = new SimpleInterval(alignmentContext);
                    return new LocusWalkerContext(alignmentContext,
                            new ReferenceContext(reference, alignmentInterval),
                            new FeatureContext(fm, alignmentInterval));
                })
                .iterator();
    };
}
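As with the assembly-region function above, this is designed for JavaRDD.flatMap over sharded reads; a hedged sketch (shardedReads is illustrative):
// Illustrative only: one LocusWalkerContext per alignment position in each shard.
JavaRDD<LocusWalkerContext> contexts = shardedReads.flatMap(
        getAlignmentsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, header, downsamplingInfo));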
Use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.
The class ReadWalkerSpark, method getReadsFunction.
private static FlatMapFunction<Shard<GATKRead>, ReadWalkerContext> getReadsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource, Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary, int readShardPadding) {
    return (FlatMapFunction<Shard<GATKRead>, ReadWalkerContext>) shard -> {
        SimpleInterval paddedInterval =
                shard.getInterval().expandWithinContig(readShardPadding, sequenceDictionary);
        ReferenceDataSource reference = bReferenceSource == null ? null
                : new ReferenceMemorySource(
                        bReferenceSource.getValue().getReferenceBases(null, paddedInterval), sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        return StreamSupport.stream(shard.spliterator(), false)
                .map(r -> {
                    final SimpleInterval readInterval = getReadInterval(r);
                    return new ReadWalkerContext(r,
                            new ReferenceContext(reference, readInterval),
                            new FeatureContext(features, readInterval));
                })
                .iterator();
    };
}
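getReadInterval is referenced but not defined in this excerpt. A plausible sketch consistent with its use here (hypothetical simplification; GATK's real helper may include additional coordinate-validity checks):
// Hypothetical sketch: map a mapped read to its span; unmapped reads get null,
// which the downstream contexts presumably treat as an empty window.
static SimpleInterval getReadInterval(final GATKRead read) {
    return read.isUnmapped() ? null : new SimpleInterval(read.getContig(), read.getStart(), read.getEnd());
}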