Search in sources :

Example 66 with Stream

use of java.util.stream.Stream in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateCopyRatioPosteriorExpectationsSpark.

/**
     * The Spark implementation of the E-step update of copy ratio posteriors
     *
     * @return a {@link SubroutineSignal} containing the update size
     */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateCopyRatioPosteriorExpectationsSpark(final double admixingRatio) {
    /* local final member variables for lambda capture */
    final List<LinearlySpacedIndexBlock> targetBlocks = new ArrayList<>();
    targetBlocks.addAll(this.targetBlocks);
    final List<Target> targetList = new ArrayList<>();
    targetList.addAll(processedTargetList);
    final List<String> sampleNameList = new ArrayList<>();
    sampleNameList.addAll(processedSampleNameList);
    final List<SexGenotypeData> sampleSexGenotypeData = new ArrayList<>();
    sampleSexGenotypeData.addAll(processedSampleSexGenotypeData);
    final int numTargetBlocks = targetBlocks.size();
    final CopyRatioExpectationsCalculator<CoverageModelCopyRatioEmissionData, STATE> calculator = this.copyRatioExpectationsCalculator;
    final INDArray sampleReadDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    /* make an RDD of copy ratio posterior expectations */
    final JavaPairRDD<Integer, CopyRatioExpectations> copyRatioPosteriorExpectationsPairRDD = /* fetch copy ratio emission data from workers */
    fetchCopyRatioEmissionDataSpark().mapPartitionsToPair(it -> {
        final List<Tuple2<Integer, CopyRatioExpectations>> newPartitionData = new ArrayList<>();
        while (it.hasNext()) {
            final Tuple2<Integer, List<CoverageModelCopyRatioEmissionData>> prevDatum = it.next();
            final int si = prevDatum._1;
            final CopyRatioCallingMetadata copyRatioCallingMetadata = CopyRatioCallingMetadata.builder().sampleName(sampleNameList.get(si)).sampleSexGenotypeData(sampleSexGenotypeData.get(si)).sampleCoverageDepth(sampleReadDepths.getDouble(si)).emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN).build();
            newPartitionData.add(new Tuple2<>(prevDatum._1, calculator.getCopyRatioPosteriorExpectations(copyRatioCallingMetadata, targetList, prevDatum._2)));
        }
        return newPartitionData.iterator();
    }, true);
    /* we need to do two things to copyRatioPosteriorExpectationsPairRDD; so we cache it */
    /* step 1. update log chain posterior expectation on the driver node */
    final double[] newSampleLogChainPosteriors = copyRatioPosteriorExpectationsPairRDD.mapValues(CopyRatioExpectations::getLogChainPosteriorProbability).collect().stream().sorted(Comparator.comparingInt(t -> t._1)).mapToDouble(t -> t._2).toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(newSampleLogChainPosteriors, new int[] { numSamples, 1 }));
    /* step 2. repartition in target space */
    final JavaPairRDD<LinearlySpacedIndexBlock, ImmutablePair<INDArray, INDArray>> blockifiedCopyRatioPosteriorResultsPairRDD = copyRatioPosteriorExpectationsPairRDD.flatMapToPair(dat -> targetBlocks.stream().map(tb -> new Tuple2<>(tb, new Tuple2<>(dat._1, ImmutablePair.of(dat._2.getLogCopyRatioMeans(tb), dat._2.getLogCopyRatioVariances(tb))))).iterator()).combineByKey(/* recipe to create an singleton list */
    Collections::singletonList, /* recipe to add an element to the list */
    (list, element) -> Stream.concat(list.stream(), Stream.of(element)).collect(Collectors.toList()), /* recipe to concatenate two lists */
    (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()), /* repartition with respect to target space blocks */
    new HashPartitioner(numTargetBlocks)).mapValues(list -> list.stream().sorted(Comparator.comparingInt(t -> t._1)).map(p -> p._2).map(t -> ImmutablePair.of(Nd4j.create(t.left), Nd4j.create(t.right))).collect(Collectors.toList())).mapValues(CoverageModelEMWorkspace::stackCopyRatioPosteriorDataForAllSamples);
    /* we do not need copy ratio expectations anymore */
    copyRatioPosteriorExpectationsPairRDD.unpersist();
    /* step 3. merge with computeRDD and update */
    computeRDD = computeRDD.join(blockifiedCopyRatioPosteriorResultsPairRDD).mapValues(t -> t._1.cloneWithUpdatedCopyRatioPosteriors(t._2.left, t._2.right, admixingRatio));
    cacheWorkers("after E-step for copy ratio update");
    /* collect subroutine signals */
    final List<SubroutineSignal> sigs = mapWorkersAndCollect(CoverageModelEMComputeBlock::getLatestMStepSignal);
    final double errorNormInfinity = Collections.max(sigs.stream().map(sig -> sig.<Double>get(StandardSubroutineSignals.RESIDUAL_ERROR_NORM)).collect(Collectors.toList()));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Tuple2(scala.Tuple2) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) HashPartitioner(org.apache.spark.HashPartitioner)

Example 67 with Stream

use of java.util.stream.Stream in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateSampleUnexplainedVariance.

/**
     * E-step update of the sample-specific unexplained variance
     *
     * @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
     * number of function evaluations per sample (key: "iterations")
     */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
    cacheWorkers("after E-step for sample unexplained variance initialization");
    /* create a compound objective function for simultaneous multi-sample queries */
    final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
        if (arg.isEmpty()) {
            return Collections.emptyMap();
        }
        final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
        final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(), new int[] { sampleIndices.length, 1 });
        final INDArray eval = mapWorkersAndReduce(cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues), INDArray::add);
        final Map<Integer, Double> output = new HashMap<>();
        IntStream.range(0, sampleIndices.length).forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
        return output;
    };
    final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory = spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(), spec.getFunctionValueAccuracy(), null, config.getSampleSpecificVarianceSolverNumBisections(), config.getSampleSpecificVarianceSolverRefinementDepth());
    /* instantiate a synchronized multi-sample root finder and add jobs */
    final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
    IntStream.range(0, numSamples).forEach(si -> {
        final double x0 = 0.5 * config.getSampleSpecificVarianceUpperLimit();
        syncSolver.add(si, 0, config.getSampleSpecificVarianceUpperLimit(), x0, config.getSampleSpecificVarianceAbsoluteTolerance(), config.getSampleSpecificVarianceRelativeTolerance(), config.getSampleSpecificVarianceMaximumIterations());
    });
    /* solve and collect statistics */
    final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
    final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
    try {
        final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap = syncSolver.solve();
        newSampleSpecificVarianceMap.entrySet().forEach(entry -> {
            final int sampleIndex = entry.getKey();
            final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
            double val = 0;
            switch(summary.status) {
                case SUCCESS:
                    val = summary.x;
                    break;
                case TOO_MANY_EVALUATIONS:
                    logger.warn("Could not locate the root of gamma -- increase the maximum number of" + "function evaluations");
                    break;
            }
            newSampleUnexplainedVariance.put(sampleIndex, 0, val);
            numberOfEvaluations.add(summary.evaluations);
        });
    } catch (final InterruptedException ex) {
        throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue");
    }
    /* admix */
    final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance.mul(config.getMeanFieldAdmixingRatio()).addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));
    /* calculate the error */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));
    /* update local copy */
    sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);
    /* push to workers */
    pushToWorkers(newSampleUnexplainedVarianceAdmixed, (arr, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, newSampleUnexplainedVarianceAdmixed));
    final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(d -> d).sum() / numSamples);
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).put(StandardSubroutineSignals.ITERATIONS, iterations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications)

Example 68 with Stream

use of java.util.stream.Stream in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method initializeWorkerPosteriors.

@UpdatesRDD
private void initializeWorkerPosteriors() {
    /* make a local copy for lambda capture (these are small, so no broadcasting is necessary) */
    final INDArray sampleMeanLogReadDepths = this.sampleMeanLogReadDepths;
    final INDArray sampleVarLogReadDepths = this.sampleVarLogReadDepths;
    final INDArray sampleUnexplainedVariance = this.sampleUnexplainedVariance;
    final INDArray sampleBiasLatentPosteriorFirstMoments = this.sampleBiasLatentPosteriorFirstMoments;
    final INDArray sampleBiasLatentPosteriorSecondMoments = this.sampleBiasLatentPosteriorSecondMoments;
    /* calculate copy ratio prior expectations */
    logger.info("Calculating copy ratio priors on the driver node...");
    final List<CopyRatioExpectations> copyRatioPriorExpectationsList = sampleIndexStream().mapToObj(si -> copyRatioExpectationsCalculator.getCopyRatioPriorExpectations(CopyRatioCallingMetadata.builder().sampleName(processedSampleNameList.get(si)).sampleSexGenotypeData(processedSampleSexGenotypeData.get(si)).sampleAverageMappingErrorProbability(config.getMappingErrorRate()).build(), processedTargetList)).collect(Collectors.toList());
    /* update log chain posterior expectation */
    sampleLogChainPosteriors.assign(Nd4j.create(copyRatioPriorExpectationsList.stream().mapToDouble(CopyRatioExpectations::getLogChainPosteriorProbability).toArray(), new int[] { numSamples, 1 }));
    /* push per-target copy ratio expectations to workers */
    final List<Tuple2<LinearlySpacedIndexBlock, Tuple2<INDArray, INDArray>>> copyRatioPriorsList = new ArrayList<>();
    for (final LinearlySpacedIndexBlock tb : targetBlocks) {
        final double[] logCopyRatioPriorMeansBlock = IntStream.range(0, tb.getNumElements()).mapToObj(rti -> copyRatioPriorExpectationsList.stream().mapToDouble(cre -> cre.getLogCopyRatioMeans()[rti + tb.getBegIndex()]).toArray()).flatMapToDouble(Arrays::stream).toArray();
        final double[] logCopyRatioPriorVariancesBlock = IntStream.range(0, tb.getNumElements()).mapToObj(rti -> copyRatioPriorExpectationsList.stream().mapToDouble(cre -> cre.getLogCopyRatioVariances()[rti + tb.getBegIndex()]).toArray()).flatMapToDouble(Arrays::stream).toArray();
        /* we do not need to take care of log copy ratio means and variances on masked targets here.
             * potential NaNs will be rectified in the compute blocks by calling the method
             * {@link CoverageModelEMComputeBlock#cloneWithInitializedData} */
        copyRatioPriorsList.add(new Tuple2<>(tb, new Tuple2<INDArray, INDArray>(Nd4j.create(logCopyRatioPriorMeansBlock, new int[] { numSamples, tb.getNumElements() }, 'f'), Nd4j.create(logCopyRatioPriorVariancesBlock, new int[] { numSamples, tb.getNumElements() }, 'f'))));
    }
    /* push to compute blocks */
    logger.info("Pushing posteriors to worker(s)...");
    /* copy ratio priors */
    joinWithWorkersAndMap(copyRatioPriorsList, p -> p._1.cloneWithUpdateCopyRatioPriors(p._2._1, p._2._2));
    /* read depth and sample-specific unexplained variance */
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.log_d_s, sampleMeanLogReadDepths).cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.var_log_d_s, sampleVarLogReadDepths).cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, sampleUnexplainedVariance));
    /* if bias covariates are enabled, initialize E[z] and E[z z^T] as well */
    if (biasCovariatesEnabled) {
        mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.z_sl, sampleBiasLatentPosteriorFirstMoments).cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.zz_sll, sampleBiasLatentPosteriorSecondMoments));
    }
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2)

Example 69 with Stream

use of java.util.stream.Stream in project gatk-protected by broadinstitute.

the class CoverageModelEMWorkspace method updateCopyRatioPosteriorExpectationsSpark.

/**
     * The Spark implementation of the E-step update of copy ratio posteriors
     *
     * @return a {@link SubroutineSignal} containing the update size
     */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateCopyRatioPosteriorExpectationsSpark(final double admixingRatio) {
    /* local final member variables for lambda capture */
    final List<LinearlySpacedIndexBlock> targetBlocks = new ArrayList<>();
    targetBlocks.addAll(this.targetBlocks);
    final List<Target> targetList = new ArrayList<>();
    targetList.addAll(processedTargetList);
    final List<String> sampleNameList = new ArrayList<>();
    sampleNameList.addAll(processedSampleNameList);
    final List<SexGenotypeData> sampleSexGenotypeData = new ArrayList<>();
    sampleSexGenotypeData.addAll(processedSampleSexGenotypeData);
    final int numTargetBlocks = targetBlocks.size();
    final CopyRatioExpectationsCalculator<CoverageModelCopyRatioEmissionData, STATE> calculator = this.copyRatioExpectationsCalculator;
    final INDArray sampleReadDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    /* make an RDD of copy ratio posterior expectations */
    final JavaPairRDD<Integer, CopyRatioExpectations> copyRatioPosteriorExpectationsPairRDD = /* fetch copy ratio emission data from workers */
    fetchCopyRatioEmissionDataSpark().mapPartitionsToPair(it -> {
        final List<Tuple2<Integer, CopyRatioExpectations>> newPartitionData = new ArrayList<>();
        while (it.hasNext()) {
            final Tuple2<Integer, List<CoverageModelCopyRatioEmissionData>> prevDatum = it.next();
            final int si = prevDatum._1;
            final CopyRatioCallingMetadata copyRatioCallingMetadata = CopyRatioCallingMetadata.builder().sampleName(sampleNameList.get(si)).sampleSexGenotypeData(sampleSexGenotypeData.get(si)).sampleCoverageDepth(sampleReadDepths.getDouble(si)).emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN).build();
            newPartitionData.add(new Tuple2<>(prevDatum._1, calculator.getCopyRatioPosteriorExpectations(copyRatioCallingMetadata, targetList, prevDatum._2)));
        }
        return newPartitionData.iterator();
    }, true);
    /* we need to do two things to copyRatioPosteriorExpectationsPairRDD; so we cache it */
    /* step 1. update log chain posterior expectation on the driver node */
    final double[] newSampleLogChainPosteriors = copyRatioPosteriorExpectationsPairRDD.mapValues(CopyRatioExpectations::getLogChainPosteriorProbability).collect().stream().sorted(Comparator.comparingInt(t -> t._1)).mapToDouble(t -> t._2).toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(newSampleLogChainPosteriors, new int[] { numSamples, 1 }));
    /* step 2. repartition in target space */
    final JavaPairRDD<LinearlySpacedIndexBlock, ImmutablePair<INDArray, INDArray>> blockifiedCopyRatioPosteriorResultsPairRDD = copyRatioPosteriorExpectationsPairRDD.flatMapToPair(dat -> targetBlocks.stream().map(tb -> new Tuple2<>(tb, new Tuple2<>(dat._1, ImmutablePair.of(dat._2.getLogCopyRatioMeans(tb), dat._2.getLogCopyRatioVariances(tb))))).iterator()).combineByKey(/* recipe to create an singleton list */
    Collections::singletonList, /* recipe to add an element to the list */
    (list, element) -> Stream.concat(list.stream(), Stream.of(element)).collect(Collectors.toList()), /* recipe to concatenate two lists */
    (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()), /* repartition with respect to target space blocks */
    new HashPartitioner(numTargetBlocks)).mapValues(list -> list.stream().sorted(Comparator.comparingInt(t -> t._1)).map(p -> p._2).map(t -> ImmutablePair.of(Nd4j.create(t.left), Nd4j.create(t.right))).collect(Collectors.toList())).mapValues(CoverageModelEMWorkspace::stackCopyRatioPosteriorDataForAllSamples);
    /* we do not need copy ratio expectations anymore */
    copyRatioPosteriorExpectationsPairRDD.unpersist();
    /* step 3. merge with computeRDD and update */
    computeRDD = computeRDD.join(blockifiedCopyRatioPosteriorResultsPairRDD).mapValues(t -> t._1.cloneWithUpdatedCopyRatioPosteriors(t._2.left, t._2.right, admixingRatio));
    cacheWorkers("after E-step for copy ratio update");
    /* collect subroutine signals */
    final List<SubroutineSignal> sigs = mapWorkersAndCollect(CoverageModelEMComputeBlock::getLatestMStepSignal);
    final double errorNormInfinity = Collections.max(sigs.stream().map(sig -> sig.<Double>get(StandardSubroutineSignals.RESIDUAL_ERROR_NORM)).collect(Collectors.toList()));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Tuple2(scala.Tuple2) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) HashPartitioner(org.apache.spark.HashPartitioner)

Example 70 with Stream

use of java.util.stream.Stream in project chilo-producer by cccties.

the class Util method cleanUpDirectory.

public static void cleanUpDirectory(Path targetPath) throws IOException {
    if (Files.exists(targetPath)) {
        try (Stream<Path> stream = Files.walk(targetPath)) {
            List<Path> filePaths = new ArrayList<>();
            // rootDirectory以下にあるディレクトリ以外のファイルをすべてitemsにaddしていく。
            stream.filter(entry -> !(Files.isDirectory(entry))).forEach(filePaths::add);
            for (Path file : filePaths) {
                Files.delete(file);
            }
        }
    }
}
Also used : Path(java.nio.file.Path) BufferedImage(java.awt.image.BufferedImage) Files(java.nio.file.Files) Predicate(java.util.function.Predicate) LocalDateTime(java.time.LocalDateTime) IOException(java.io.IOException) BasicFileAttributes(java.nio.file.attribute.BasicFileAttributes) File(java.io.File) StandardCopyOption(java.nio.file.StandardCopyOption) ArrayList(java.util.ArrayList) BiPredicate(java.util.function.BiPredicate) List(java.util.List) Stream(java.util.stream.Stream) DateTimeFormatter(java.time.format.DateTimeFormatter) FileTypeMap(javax.activation.FileTypeMap) ImageIO(javax.imageio.ImageIO) Clock(java.time.Clock) Log(org.apache.commons.logging.Log) LogFactory(org.apache.commons.logging.LogFactory) Path(java.nio.file.Path) ArrayList(java.util.ArrayList)

Aggregations

Stream (java.util.stream.Stream)161 Collectors (java.util.stream.Collectors)98 List (java.util.List)89 ArrayList (java.util.ArrayList)66 Map (java.util.Map)66 Set (java.util.Set)59 IOException (java.io.IOException)58 Optional (java.util.Optional)45 Collections (java.util.Collections)43 HashMap (java.util.HashMap)43 Arrays (java.util.Arrays)33 HashSet (java.util.HashSet)33 File (java.io.File)32 Path (java.nio.file.Path)32 Function (java.util.function.Function)28 Logger (org.slf4j.Logger)26 LoggerFactory (org.slf4j.LoggerFactory)26 java.util (java.util)25 Predicate (java.util.function.Predicate)23 Objects (java.util.Objects)22