Example 11 with FlatMapFunction

Use of org.apache.spark.api.java.function.FlatMapFunction in project beam by apache.

The class SparkCompat, method extractOutput.

/**
 * Extracts the output for a given collection of WindowedAccumulators.
 *
 * <p>This is required because the API of JavaPairRDD.flatMapValues is different among Spark
 * versions. See https://issues.apache.org/jira/browse/SPARK-19287
 */
public static <K, InputT, AccumT, OutputT> JavaPairRDD<K, WindowedValue<OutputT>> extractOutput(
        JavaPairRDD<K, SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>> accumulatePerKey,
        SparkCombineFn<KV<K, InputT>, InputT, AccumT, OutputT> sparkCombineFn) {
    try {
        if (accumulatePerKey.context().version().startsWith("3")) {
            FlatMapFunction<SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>, WindowedValue<OutputT>> flatMapFunction =
                    windowedAccumulator ->
                            sparkCombineFn.extractOutputStream(windowedAccumulator).iterator();
            // This invokes by reflection the equivalent of:
            // return accumulatePerKey.flatMapValues(flatMapFunction);
            Method method = accumulatePerKey.getClass().getDeclaredMethod("flatMapValues", FlatMapFunction.class);
            Object result = method.invoke(accumulatePerKey, flatMapFunction);
            return (JavaPairRDD<K, WindowedValue<OutputT>>) result;
        }
        Function<SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>, Iterable<WindowedValue<OutputT>>> flatMapFunction =
                windowedAccumulator ->
                        sparkCombineFn.extractOutputStream(windowedAccumulator).collect(Collectors.toList());
        // This invokes by reflection the equivalent of:
        // return accumulatePerKey.flatMapValues(flatMapFunction);
        Method method = accumulatePerKey.getClass().getDeclaredMethod("flatMapValues", Function.class);
        Object result = method.invoke(accumulatePerKey, flatMapFunction);
        return (JavaPairRDD<K, WindowedValue<OutputT>>) result;
    } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
        throw new RuntimeException("Error invoking Spark flatMapValues", e);
    }
}
Also used : SparkListenerApplicationStart(org.apache.spark.scheduler.SparkListenerApplicationStart) SparkCombineFn(org.apache.beam.runners.spark.translation.SparkCombineFn) KV(org.apache.beam.sdk.values.KV) WindowedValue(org.apache.beam.sdk.util.WindowedValue) JavaStreamingContext(org.apache.spark.streaming.api.java.JavaStreamingContext) PipelineResult(org.apache.beam.sdk.PipelineResult) ApplicationNameOptions(org.apache.beam.sdk.options.ApplicationNameOptions) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Option(scala.Option) Constructor(java.lang.reflect.Constructor) Collectors(java.util.stream.Collectors) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) InvocationTargetException(java.lang.reflect.InvocationTargetException) SparkBeamMetric(org.apache.beam.runners.spark.metrics.SparkBeamMetric) List(java.util.List) JavaConverters(scala.collection.JavaConverters) JavaDStream(org.apache.spark.streaming.api.java.JavaDStream) Function(org.apache.spark.api.java.function.Function) Method(java.lang.reflect.Method) SparkPipelineOptions(org.apache.beam.runners.spark.SparkPipelineOptions) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction)
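
The reflection above bridges the signature change tracked in SPARK-19287: JavaPairRDD.flatMapValues takes a Function<V, Iterable<U>> in Spark 2.x but a FlatMapFunction<V, U> (returning an Iterator) in Spark 3.x. A minimal standalone sketch of the two call shapes, assuming a Spark 3.x dependency (the data here is illustrative):

import java.util.Arrays;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import scala.Tuple2;

public class FlatMapValuesSketch {
    public static void main(String[] args) {
        try (JavaSparkContext jsc = new JavaSparkContext("local[*]", "sketch")) {
            JavaPairRDD<String, List<Integer>> pairs = jsc.parallelizePairs(
                    Arrays.asList(new Tuple2<>("k", Arrays.asList(1, 2, 3))));

            // Spark 3.x signature: FlatMapFunction<V, U>, returning an Iterator<U>.
            JavaPairRDD<String, Integer> flattened = pairs.flatMapValues(
                    (FlatMapFunction<List<Integer>, Integer>) List::iterator);

            // Spark 2.x signature (does not compile against Spark 3):
            // pairs.flatMapValues((Function<List<Integer>, Iterable<Integer>>) vs -> vs);

            System.out.println(flattened.collect()); // [(k,1), (k,2), (k,3)]
        }
    }
}

Because only one of the two overloads exists in any given Spark version, SparkCompat cannot reference both at compile time and instead looks the matching method up by reflection at runtime.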

Example 12 with FlatMapFunction

Use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.

The class SparkSharder, method computePartitionReadExtents.

/**
 * For each partition, find the interval that spans it.
 */
static <L extends Locatable> List<PartitionLocatable<SimpleInterval>> computePartitionReadExtents(JavaRDD<L> locatables, SAMSequenceDictionary sequenceDictionary, int maxLocatableLength) {
    // Find the first locatable in each partition. This is very efficient since only the
    // first record in each partition is read. If a partition is empty, its locatable is null.
    List<PartitionLocatable<L>> allSplitPoints = locatables
            .mapPartitions((FlatMapFunction<Iterator<L>, PartitionLocatable<L>>) it ->
                    ImmutableList.of(new PartitionLocatable<>(-1, it.hasNext() ? it.next() : null)).iterator())
            .collect();
    // fill in index and remove nulls (empty partitions)
    List<PartitionLocatable<L>> splitPoints = new ArrayList<>();
    for (int i = 0; i < allSplitPoints.size(); i++) {
        L locatable = allSplitPoints.get(i).getLocatable();
        if (locatable != null) {
            splitPoints.add(new PartitionLocatable<L>(i, locatable));
        }
    }
    List<PartitionLocatable<SimpleInterval>> extents = new ArrayList<>();
    for (int i = 0; i < splitPoints.size(); i++) {
        PartitionLocatable<L> splitPoint = splitPoints.get(i);
        int partitionIndex = splitPoint.getPartitionIndex();
        Locatable current = splitPoint.getLocatable();
        int intervalContigIndex = sequenceDictionary.getSequenceIndex(current.getContig());
        final Locatable next;
        final int nextContigIndex;
        if (i < splitPoints.size() - 1) {
            // PartitionLocatable implements Locatable (delegating to the wrapped element),
            // so the next split point can be assigned to a Locatable directly.
            next = splitPoints.get(i + 1);
            nextContigIndex = sequenceDictionary.getSequenceIndex(next.getContig());
        } else {
            next = null;
            nextContigIndex = sequenceDictionary.getSequences().size();
        }
        if (intervalContigIndex == nextContigIndex) {
            // same contig
            addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), next.getStart() + maxLocatableLength);
        } else {
            // complete current contig
            int contigEnd = sequenceDictionary.getSequence(current.getContig()).getSequenceLength();
            addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), contigEnd);
            // add any whole contigs up to next (exclusive)
            for (int contigIndex = intervalContigIndex + 1; contigIndex < nextContigIndex; contigIndex++) {
                SAMSequenceRecord sequence = sequenceDictionary.getSequence(contigIndex);
                addPartitionReadExtent(extents, partitionIndex, sequence.getSequenceName(), 1, sequence.getSequenceLength());
            }
            // add start of next contig
            if (next != null) {
                addPartitionReadExtent(extents, partitionIndex, next.getContig(), 1, next.getStart() + maxLocatableLength);
            }
        }
    }
    return extents;
}
Also used : FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) SAMSequenceRecord(htsjdk.samtools.SAMSequenceRecord) Locatable(htsjdk.samtools.util.Locatable)
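
The mapPartitions call at the top of computePartitionReadExtents reads only the first record of each partition. A minimal standalone sketch of that trick, with illustrative data and names:

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;

public class FirstPerPartitionSketch {
    public static void main(String[] args) {
        try (JavaSparkContext jsc = new JavaSparkContext("local[*]", "sketch")) {
            // Four elements spread over three partitions; a partition may be empty.
            JavaRDD<String> rdd = jsc.parallelize(Arrays.asList("a", "b", "c", "d"), 3);

            // Emit exactly one element per partition: its first record, or null if empty.
            // Spark materializes only what the iterator consumes, so each partition
            // reads a single record.
            List<String> firstPerPartition = rdd.mapPartitions(
                    (FlatMapFunction<Iterator<String>, String>) it ->
                            Collections.singletonList(it.hasNext() ? it.next() : null).iterator())
                    .collect();

            System.out.println(firstPerPartition);
        }
    }
}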

Example 13 with FlatMapFunction

Use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.

The class AssemblyRegionWalkerSpark, method getAssemblyRegionsFunction.

private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
        final Broadcast<ReferenceMultiSource> bReferenceSource,
        final Broadcast<FeatureManager> bFeatureManager,
        final SAMSequenceDictionary sequenceDictionary,
        final SAMFileHeader header,
        final AssemblyRegionEvaluator evaluator,
        final int minAssemblyRegionSize,
        final int maxAssemblyRegionSize,
        final int assemblyRegionPadding,
        final double activeProbThreshold,
        final int maxProbPropagationDistance) {
    return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        SimpleInterval assemblyRegionPaddedInterval =
                paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary);
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(
                        bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval),
                        sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval);
        FeatureContext featureContext = new FeatureContext(features, paddedInterval);
        final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(
                shardedRead, header, referenceContext, featureContext, evaluator,
                minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding,
                activeProbThreshold, maxProbPropagationDistance);
        return StreamSupport.stream(assemblyRegions.spliterator(), false)
                .map(assemblyRegion -> new AssemblyRegionWalkerContext(assemblyRegion,
                        new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
                        new FeatureContext(features, assemblyRegion.getExtendedSpan())))
                .iterator();
    };
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Advanced(org.broadinstitute.barclay.argparser.Advanced) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) SAMFileHeader(htsjdk.samtools.SAMFileHeader) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) WellformedReadFilter(org.broadinstitute.hellbender.engine.filters.WellformedReadFilter) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ReadFilterLibrary(org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary)
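
Like the other walker snippets below, this method is a static factory: it captures its configuration in a serializable lambda so the executors never need the enclosing tool instance. A hedged sketch of the same pattern with a trivial payload (the class and function names here are illustrative, not GATK API):

import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;

public class FactoryFunctionSketch {
    // A static factory returning a FlatMapFunction, mirroring getAssemblyRegionsFunction:
    // only the captured parameter is serialized, not the enclosing class.
    private static FlatMapFunction<String, String> splitOn(final String separator) {
        return (FlatMapFunction<String, String>) line ->
                Arrays.asList(line.split(separator)).iterator();
    }

    public static void main(String[] args) {
        try (JavaSparkContext jsc = new JavaSparkContext("local[*]", "sketch")) {
            JavaRDD<String> lines = jsc.parallelize(Arrays.asList("a,b", "c,d,e"));
            // The returned function is consumed with JavaRDD.flatMap, just as
            // getAssemblyRegionsFunction's result is applied to an RDD of shards.
            System.out.println(lines.flatMap(splitOn(",")).collect()); // [a, b, c, d, e]
        }
    }
}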

Example 14 with FlatMapFunction

Use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.

The class LocusWalkerSpark, method getAlignmentsFunction.

/**
 * Return a function that maps a {@link Shard} of reads into a tuple of alignments and their
 * corresponding reference and features.
 *
 * @param bReferenceSource the reference source broadcast
 * @param bFeatureManager the feature manager broadcast
 * @param sequenceDictionary the sequence dictionary for the reads
 * @param header the reads header
 * @param downsamplingInfo the downsampling method for the reads
 * @return a function that maps a {@link Shard} of reads into a tuple of alignments and their
 *     corresponding reference and features.
 */
private static FlatMapFunction<Shard<GATKRead>, LocusWalkerContext> getAlignmentsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource,
        Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary,
        SAMFileHeader header,
        LIBSDownsamplingInfo downsamplingInfo) {
    return (FlatMapFunction<Shard<GATKRead>, LocusWalkerContext>) shardedRead -> {
        SimpleInterval interval = shardedRead.getInterval();
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        Iterator<GATKRead> readIterator = shardedRead.iterator();
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(
                        bReferenceSource.getValue().getReferenceBases(null, paddedInterval),
                        sequenceDictionary);
        FeatureManager fm = bFeatureManager == null ? null : bFeatureManager.getValue();
        final Set<String> samples = header.getReadGroups().stream().map(SAMReadGroupRecord::getSample).collect(Collectors.toSet());
        LocusIteratorByState libs = new LocusIteratorByState(readIterator, downsamplingInfo, false, samples, header, true, false);
        IntervalOverlappingIterator<AlignmentContext> alignmentContexts = new IntervalOverlappingIterator<>(libs, ImmutableList.of(interval), sequenceDictionary);
        final Spliterator<AlignmentContext> alignmentContextSpliterator = Spliterators.spliteratorUnknownSize(alignmentContexts, 0);
        return StreamSupport.stream(alignmentContextSpliterator, false).map(alignmentContext -> {
            final SimpleInterval alignmentInterval = new SimpleInterval(alignmentContext);
            return new LocusWalkerContext(alignmentContext, new ReferenceContext(reference, alignmentInterval), new FeatureContext(fm, alignmentInterval));
        }).iterator();
    };
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) java.util(java.util) IntervalOverlappingIterator(org.broadinstitute.hellbender.utils.iterators.IntervalOverlappingIterator) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) LocusIteratorByState(org.broadinstitute.hellbender.utils.locusiterator.LocusIteratorByState) SAMFileHeader(htsjdk.samtools.SAMFileHeader) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) SAMReadGroupRecord(htsjdk.samtools.SAMReadGroupRecord) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ImmutableList(com.google.common.collect.ImmutableList) StreamSupport(java.util.stream.StreamSupport) LIBSDownsamplingInfo(org.broadinstitute.hellbender.utils.locusiterator.LIBSDownsamplingInfo) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) CommandLineException(org.broadinstitute.barclay.argparser.CommandLineException)
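
The Spliterators.spliteratorUnknownSize call above is the standard JDK bridge from an Iterator to a lazy Stream and back; the FlatMapFunction never materializes the whole shard. The same shape in isolation, with illustrative data:

import java.util.Arrays;
import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.StreamSupport;

public class IteratorStreamBridgeSketch {
    public static void main(String[] args) {
        Iterator<String> source = Arrays.asList("chr1:100", "chr1:200").iterator();

        // Wrap the iterator in an unsized spliterator, stream it lazily
        // ('false' means sequential), map each element, and hand back an iterator.
        Spliterator<String> spliterator = Spliterators.spliteratorUnknownSize(source, 0);
        Iterator<String> mapped = StreamSupport.stream(spliterator, false)
                .map(s -> "context(" + s + ")")
                .iterator();

        mapped.forEachRemaining(System.out::println);
    }
}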

Example 15 with FlatMapFunction

Use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.

The class ReadWalkerSpark, method getReadsFunction.

private static FlatMapFunction<Shard<GATKRead>, ReadWalkerContext> getReadsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource,
        Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary,
        int readShardPadding) {
    return (FlatMapFunction<Shard<GATKRead>, ReadWalkerContext>) shard -> {
        SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(readShardPadding, sequenceDictionary);
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(
                        bReferenceSource.getValue().getReferenceBases(null, paddedInterval),
                        sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        return StreamSupport.stream(shard.spliterator(), false).map(r -> {
            final SimpleInterval readInterval = getReadInterval(r);
            return new ReadWalkerContext(r, new ReferenceContext(reference, readInterval), new FeatureContext(features, readInterval));
        }).iterator();
    };
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction)
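
All three walker snippets guard their broadcasts with null checks, because the reference and feature sources are optional and a null broadcast is deliberately passed when absent. A minimal sketch of that pattern, with illustrative data and names:

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class OptionalBroadcastSketch {
    public static void main(String[] args) {
        try (JavaSparkContext jsc = new JavaSparkContext("local[*]", "sketch")) {
            boolean haveReference = false; // e.g. the user supplied no reference
            Map<String, String> referenceBases = new HashMap<>();
            referenceBases.put("read1", "ACGT");

            // Broadcast only when present; the executor-side lambda re-checks for null.
            Broadcast<Map<String, String>> bRef =
                    haveReference ? jsc.broadcast(referenceBases) : null;

            JavaRDD<String> reads = jsc.parallelize(Arrays.asList("read1", "read2"));
            JavaRDD<String> bases = reads.map(read ->
                    bRef == null ? "N/A" : bRef.getValue().getOrDefault(read, "N/A"));

            System.out.println(bases.collect()); // [N/A, N/A]
        }
    }
}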

Aggregations

FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction): 15 usages
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 12 usages
List (java.util.List): 10 usages
JavaRDD (org.apache.spark.api.java.JavaRDD): 9 usages
ArrayList (java.util.ArrayList): 8 usages
Collectors (java.util.stream.Collectors): 5 usages
Function (org.apache.spark.api.java.function.Function): 5 usages
Tuple2 (scala.Tuple2): 5 usages
SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary): 4 usages
JavaPairRDD (org.apache.spark.api.java.JavaPairRDD): 4 usages
Iterator (java.util.Iterator): 3 usages
StreamSupport (java.util.stream.StreamSupport): 3 usages
SparkConf (org.apache.spark.SparkConf): 3 usages
Function2 (org.apache.spark.api.java.function.Function2): 3 usages
Broadcast (org.apache.spark.broadcast.Broadcast): 3 usages
Argument (org.broadinstitute.barclay.argparser.Argument): 3 usages
org.broadinstitute.hellbender.engine (org.broadinstitute.hellbender.engine): 3 usages
ReferenceMultiSource (org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource): 3 usages
IntervalUtils (org.broadinstitute.hellbender.utils.IntervalUtils): 3 usages
SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval): 3 usages