Search in sources:

Example 6 with Broadcast

use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

In the class CoverageModelEMWorkspace, method initializeWorkersWithPCA.

/**
 * Initialize model parameters by performing PCA.
 *
 * <p>Steps:
 * <ol>
 *   <li>zero m_t and Psi_t (and W_tl when bias covariates are enabled) on every worker and update the
 *       read-depth posteriors without any bias-covariate correction;</li>
 *   <li>load and cache the PCA initialization data on the workers and reduce the per-block covariance
 *       matrices into a single matrix;</li>
 *   <li>eigen-decompose it and derive the closed-form probabilistic-PCA estimates of the isotropic
 *       unexplained variance (Bishop 12.46) and of the bias covariates (Bishop 12.45);</li>
 *   <li>push the resulting parameters to the workers, drop the PCA data, and update the bias latent
 *       posterior expectations.</li>
 * </ol>
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private void initializeWorkersWithPCA() {
    logger.info("Initializing model parameters using PCA...");
    /* initially, set m_t, Psi_t and W_tl to zero to get an estimate of the read depth */
    final int numLatents = config.getNumLatents();
    mapWorkers(cb -> cb
            .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.m_t,
                    Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() }))
            .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t,
                    Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() })));
    if (biasCovariatesEnabled) {
        mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl,
                Nd4j.zeros(new int[] { cb.getTargetSpaceBlock().getNumElements(), numLatents })));
    }
    /* update read depth without taking into account correction from bias covariates */
    updateReadDepthPosteriorExpectations(1.0, true);

    /* fetch sample covariance matrix */
    final int minPCAInitializationReadCount = config.getMinPCAInitializationReadCount();
    mapWorkers(cb -> cb.cloneWithPCAInitializationData(minPCAInitializationReadCount, Integer.MAX_VALUE));
    cacheWorkers("PCA initialization");
    final INDArray targetCovarianceMatrix = mapWorkersAndReduce(
            CoverageModelEMComputeBlock::calculateTargetCovarianceMatrixForPCAInitialization, INDArray::add);

    /* perform eigen-decomposition on the target covariance matrix */
    final ImmutablePair<INDArray, INDArray> targetCovarianceEigensystem =
            CoverageModelEMWorkspaceMathUtils.eig(targetCovarianceMatrix, false, logger);
    /* the eigenvalues of sample covariance matrix can be immediately inferred by scaling */
    final INDArray sampleCovarianceEigenvalues = targetCovarianceEigensystem.getLeft().div(numSamples);

    /* estimate the isotropic unexplained variance -- see Bishop 12.46; the eigenvalues beyond the
     * numSamples-th are not computed here (presumably zero, since at most numSamples are non-zero), so
     * summing up to numSamples and dividing by (numTargets - numLatents) matches the formula.
     * NOTE(review): assumes eig returns eigenvalues in descending order -- TODO confirm */
    final int residualDim = numTargets - numLatents;
    final double isotropicVariance = sampleCovarianceEigenvalues
            .get(NDArrayIndex.interval(numLatents, numSamples)).sumNumber().doubleValue() / residualDim;
    logger.info(String.format("PCA estimate of isotropic unexplained variance: %f", isotropicVariance));

    /* estimate bias factors -- see Bishop 12.45: W = U_M (L_M - sigma^2 I)^{1/2} */
    final INDArray scaleFactors = Transforms.sqrt(
            sampleCovarianceEigenvalues.get(NDArrayIndex.interval(0, numLatents)).sub(isotropicVariance), false);
    final INDArray biasCovariatesPCA = Nd4j.create(new int[] { numTargets, numLatents });
    for (int li = 0; li < numLatents; li++) {
        final INDArray v = targetCovarianceEigensystem.getRight().getColumn(li);
        /* calculate [Delta_PCA_st]^T v, assembled from the per-worker target blocks */
        /* note: we do not need to broadcast v since it is small and lambda capture is just fine */
        final INDArray unnormedBiasCovariate = CoverageModelSparkUtils.assembleINDArrayBlocksFromCollection(
                mapWorkersAndCollect(cb -> ImmutablePair.of(cb.getTargetSpaceBlock(),
                        cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Delta_PCA_st)
                                .transpose().mmul(v))), 0);
        /* Bishop 12.45 requires the columns of U_M to be orthonormal, i.e. of unit Euclidean (L2) norm.
         * Normalizing by the L1 norm (as previously done) under-scales every column of W. */
        final double norm = unnormedBiasCovariate.norm2Number().doubleValue();
        final INDArray normedBiasCovariate = unnormedBiasCovariate.divi(norm).muli(scaleFactors.getDouble(li));
        biasCovariatesPCA.getColumn(li).assign(normedBiasCovariate);
    }
    if (ardEnabled) {
        /* a better estimate of ARD coefficients, relative to the PCA noise estimate */
        biasCovariatesARDCoefficients.assign(Nd4j.zeros(new int[] { 1, numLatents })
                .addi(config.getInitialARDPrecisionRelativeToNoise() / isotropicVariance));
    }
    final CoverageModelParameters modelParamsFromPCA = new CoverageModelParameters(
            processedTargetList,
            Nd4j.zeros(new int[] { 1, numTargets }),
            Nd4j.zeros(new int[] { 1, numTargets }).addi(isotropicVariance),
            biasCovariatesPCA,
            biasCovariatesARDCoefficients);

    /* clear PCA initialization data from workers */
    mapWorkers(CoverageModelEMComputeBlock::cloneWithRemovedPCAInitializationData);
    /* push model parameters to workers */
    initializeWorkersWithGivenModel(modelParamsFromPCA);
    /* update bias latent posterior expectations without admixing */
    updateBiasLatentPosteriorExpectations(1.0);
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray)

Example 7 with Broadcast

use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

In the class CoverageModelWPreconditionerSpark, method operate.

/**
 * Applies the Fourier-space preconditioner to {@code W_tl}.
 *
 * <p>Each latent column is forward-FFT'd locally. The Fourier-space matrix is then partitioned into
 * blocks, broadcast, and solved block-by-block against the linear operators in {@code linOpPairRDD} on
 * Spark. Finally the result is inverse-FFT'd locally back to target space.
 *
 * @param W_tl a rank-2 matrix of shape (numTargets, numLatents)
 * @return the preconditioned matrix of shape (numTargets, numLatents)
 * @throws DimensionMismatchException if {@code W_tl} does not have rank 2 and shape (numTargets, numLatents)
 */
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }

    /* forward rfft, one latent column at a time */
    final long startTimeRFFT = System.nanoTime();
    final INDArray W_kl = Nd4j.create(fftSize, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(
                    Nd4j.create(F_tt.getForwardFFT(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))),
                            new int[] { fftSize, 1 })));
    final long endTimeRFFT = System.nanoTime();

    /* apply the preconditioner in the Fourier space (broadcast hash join against linOpPairRDD) */
    final long startTimePrecond = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_kl_map =
            CoverageModelSparkUtils.partitionINDArrayToMap(fourierSpaceBlocks, W_kl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_kl_bc = ctx.broadcast(W_kl_map);
    try {
        final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> preconditionedWRDD = linOpPairRDD.mapToPair(p -> {
            final INDArray W_kl_chunk = W_kl_bc.value().get(p._1);
            final INDArray linOp_chunk = p._2;
            final int blockSize = linOp_chunk.shape()[0];
            final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel()
                    .mapToObj(k -> CoverageModelEMWorkspaceMathUtils.linsolve(
                            linOp_chunk.get(NDArrayIndex.point(k)), W_kl_chunk.get(NDArrayIndex.point(k))))
                    .collect(Collectors.toList());
            return new Tuple2<>(p._1, Nd4j.vstack(linOpWList));
        });
        W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(preconditionedWRDD, 0));
    } finally {
        /* release the broadcast even if the Spark job fails; otherwise it would leak */
        W_kl_bc.destroy();
    }
    final long endTimePrecond = System.nanoTime();

    /* irfft back to target space, one latent column at a time */
    final long startTimeIRFFT = System.nanoTime();
    final INDArray res = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            res.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(
                    F_tt.getInverseFFT(W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    final long endTimeIRFFT = System.nanoTime();

    logger.debug("Local FFT timing: " + (endTimeRFFT - startTimeRFFT + endTimeIRFFT - startTimeIRFFT) / 1000000 + " ms");
    logger.debug("Spark preconditioner application timing: " + (endTimePrecond - startTimePrecond) / 1000000 + " ms");
    return res;
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) IntStream(java.util.stream.IntStream) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) DimensionMismatchException(org.apache.commons.math3.exception.DimensionMismatchException) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) Collectors(java.util.stream.Collectors) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) List(java.util.List) Logger(org.apache.logging.log4j.Logger) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Map(java.util.Map) LogManager(org.apache.logging.log4j.LogManager) Nonnull(javax.annotation.Nonnull) DimensionMismatchException(org.apache.commons.math3.exception.DimensionMismatchException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) Map(java.util.Map)

Example 8 with Broadcast

use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

In the class CoverageModelWLinearOperatorSpark, method operate.

/**
 * Applies the linear operator to {@code W_tl}, returning {@code Q W + Z F W}.
 *
 * <p>The {@code Z F W} term is computed locally: F is applied column-by-column, then the result is
 * right-multiplied by {@code Z_ll}. The {@code Q W} term is computed on Spark via a broadcast hash join.
 * Each compute block multiplies its cached {@code Q_tll} slices by the matching rows of {@code W_tl}.
 *
 * @param W_tl a rank-2 matrix of shape (numTargets, numLatents)
 * @return the result matrix of shape (numTargets, numLatents)
 * @throws DimensionMismatchException if {@code W_tl} does not have rank 2 and shape (numTargets, numLatents)
 */
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }

    /* Z F W */
    final long startTimeZFW = System.nanoTime();
    final INDArray Z_F_W_tl = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            Z_F_W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(
                    F_tt.operate(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    Z_F_W_tl.assign(Nd4j.gemm(Z_F_W_tl, Z_ll, false, false));
    final long endTimeZFW = System.nanoTime();

    /* Q W, via a broadcast hash join */
    final long startTimeQW = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_tl_map =
            CoverageModelSparkUtils.partitionINDArrayToMap(targetSpaceBlocks, W_tl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_tl_bc = ctx.broadcast(W_tl_map);
    final INDArray Q_W_tl;
    try {
        Q_W_tl = CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(computeRDD.mapValues(cb -> {
            final INDArray W_tl_chunk = W_tl_bc.value().get(cb.getTargetSpaceBlock());
            final INDArray Q_tll_chunk = cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Q_tll);
            final Collection<INDArray> W_Q_chunk = IntStream.range(0, cb.getTargetSpaceBlock().getNumElements()).parallel()
                    .mapToObj(ti -> Q_tll_chunk.get(NDArrayIndex.point(ti))
                            .mmul(W_tl_chunk.get(NDArrayIndex.point(ti)).transpose()))
                    .collect(Collectors.toList());
            return Nd4j.vstack(W_Q_chunk);
        }), 0);
    } finally {
        /* release the broadcast even if the Spark job fails; otherwise it would leak */
        W_tl_bc.destroy();
    }
    final long endTimeQW = System.nanoTime();

    logger.debug("Local [Z] [F] [W] timing: " + (endTimeZFW - startTimeZFW) / 1000000 + " ms");
    logger.debug("Spark [Q] [W] timing: " + (endTimeQW - startTimeQW) / 1000000 + " ms");
    return Q_W_tl.addi(Z_F_W_tl);
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) IntStream(java.util.stream.IntStream) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) Collection(java.util.Collection) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) DimensionMismatchException(org.apache.commons.math3.exception.DimensionMismatchException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) Collectors(java.util.stream.Collectors) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) List(java.util.List) Logger(org.apache.logging.log4j.Logger) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Map(java.util.Map) LogManager(org.apache.logging.log4j.LogManager) Nonnull(javax.annotation.Nonnull) DimensionMismatchException(org.apache.commons.math3.exception.DimensionMismatchException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Collection(java.util.Collection) Map(java.util.Map)

Example 9 with Broadcast

use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

In the class LocusWalkerSpark, method getAlignments.

/**
 * Builds an RDD of per-locus contexts (reads, reference and features) over the traversal intervals.
 *
 * <p>The traversal intervals are the user-specified ones when present, or every contig in the best
 * available sequence dictionary otherwise. Each interval is cut into padded shards, and the reads are
 * sharded accordingly. The reference and feature sources, when available, are broadcast so the
 * per-shard function can use them.
 *
 * @param ctx the Spark context used for sharding and broadcasting
 * @return all alignments as a {@link JavaRDD}, bounded by intervals if specified.
 */
public JavaRDD<LocusWalkerContext> getAlignments(JavaSparkContext ctx) {
    final SAMSequenceDictionary dictionary = getBestAvailableSequenceDictionary();

    final List<SimpleInterval> traversalIntervals;
    if (hasIntervals()) {
        traversalIntervals = getIntervals();
    } else {
        traversalIntervals = IntervalUtils.getAllIntervalsForReference(dictionary);
    }

    final List<ShardBoundary> shardBoundaries = traversalIntervals.stream()
            .flatMap(iv -> Shard.divideIntervalIntoShards(iv, readShardSize, readShardPadding, dictionary).stream())
            .collect(Collectors.toList());

    final int maxReadSpan = Math.min(readShardSize, readShardPadding);
    final JavaRDD<Shard<GATKRead>> readShards = SparkSharder.shard(
            ctx, getReads(), GATKRead.class, dictionary, shardBoundaries, maxReadSpan, shuffle);

    final Broadcast<ReferenceMultiSource> referenceBroadcast = hasReference() ? ctx.broadcast(getReference()) : null;
    final Broadcast<FeatureManager> featureBroadcast = (features != null) ? ctx.broadcast(features) : null;

    return readShards.flatMap(getAlignmentsFunction(
            referenceBroadcast, featureBroadcast, dictionary, getHeaderForReads(), getDownsamplingInfo()));
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) java.util(java.util) IntervalOverlappingIterator(org.broadinstitute.hellbender.utils.iterators.IntervalOverlappingIterator) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) LocusIteratorByState(org.broadinstitute.hellbender.utils.locusiterator.LocusIteratorByState) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SAMFileHeader(htsjdk.samtools.SAMFileHeader) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) SAMReadGroupRecord(htsjdk.samtools.SAMReadGroupRecord) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ImmutableList(com.google.common.collect.ImmutableList) StreamSupport(java.util.stream.StreamSupport) LIBSDownsamplingInfo(org.broadinstitute.hellbender.utils.locusiterator.LIBSDownsamplingInfo) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) CommandLineException(org.broadinstitute.barclay.argparser.CommandLineException) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval)

Example 10 with Broadcast

use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

In the class VariantWalkerSpark, method getVariantsFunction.

/**
 * Returns a function that turns a shard of variants into walker contexts.
 *
 * <p>Only variants whose start lies inside the shard's [start, end] range are emitted. When a reference
 * broadcast is supplied, the bases of the shard interval, expanded by {@code variantShardPadding} within
 * its contig, are loaded into an in-memory reference source.
 */
private static FlatMapFunction<Shard<VariantContext>, VariantWalkerContext> getVariantsFunction(final Broadcast<ReferenceMultiSource> bReferenceSource, final Broadcast<FeatureManager> bFeatureManager, final SAMSequenceDictionary sequenceDictionary, final int variantShardPadding) {
    return shard -> {
        final SimpleInterval windowWithPadding =
                shard.getInterval().expandWithinContig(variantShardPadding, sequenceDictionary);

        final ReferenceDataSource referenceWindow;
        if (bReferenceSource != null) {
            referenceWindow = new ReferenceMemorySource(
                    bReferenceSource.getValue().getReferenceBases(null, windowWithPadding), sequenceDictionary);
        } else {
            referenceWindow = null;
        }

        final FeatureManager featureSource = (bFeatureManager != null) ? bFeatureManager.getValue() : null;

        return StreamSupport.stream(shard.spliterator(), false)
                .filter(variant -> variant.getStart() >= shard.getStart() && variant.getStart() <= shard.getEnd())
                .map(variant -> {
                    final SimpleInterval locus = new SimpleInterval(variant);
                    return new VariantWalkerContext(variant, new ReadsContext(),
                            new ReferenceContext(referenceWindow, locus), new FeatureContext(featureSource, locus));
                })
                .iterator();
    };
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) VCFHeader(htsjdk.variant.vcf.VCFHeader) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) IndexUtils(org.broadinstitute.hellbender.utils.IndexUtils) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) VariantFilterLibrary(org.broadinstitute.hellbender.engine.filters.VariantFilterLibrary) StandardArgumentDefinitions(org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) VariantFilter(org.broadinstitute.hellbender.engine.filters.VariantFilter) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) VariantContext(htsjdk.variant.variantcontext.VariantContext) VariantsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.VariantsSparkSource) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) VariantContext(htsjdk.variant.variantcontext.VariantContext) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval)

Aggregations

Broadcast (org.apache.spark.broadcast.Broadcast)25 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)23 Collectors (java.util.stream.Collectors)21 List (java.util.List)15 JavaRDD (org.apache.spark.api.java.JavaRDD)15 IntervalUtils (org.broadinstitute.hellbender.utils.IntervalUtils)15 Argument (org.broadinstitute.barclay.argparser.Argument)13 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)12 Tuple2 (scala.Tuple2)12 SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary)11 IntStream (java.util.stream.IntStream)11 LogManager (org.apache.logging.log4j.LogManager)11 Logger (org.apache.logging.log4j.Logger)11 ReferenceMultiSource (org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource)11 SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval)11 StreamSupport (java.util.stream.StreamSupport)10 org.broadinstitute.hellbender.engine (org.broadinstitute.hellbender.engine)10 UserException (org.broadinstitute.hellbender.exceptions.UserException)10 FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction)9 GATKException (org.broadinstitute.hellbender.exceptions.GATKException)9